[mlir][sparse] add conversion rules for storage_get/set/callOp
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133175
PeimingLiu committed Sep 2, 2022
1 parent 46b293c commit 928b5b0
Showing 3 changed files with 161 additions and 16 deletions.
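In essence, the pass flattens a sparse-tensor storage tuple into its individual fields and uses UnrealizedConversionCastOps to bridge the resulting 1:N type conversion. A minimal sketch of the idea (hypothetical value names, illustrative IR only, not taken from this commit's tests):

    // The type converter materializes the flattened fields back into a
    // single tuple-typed SSA value via an unrealized cast:
    %s = builtin.unrealized_conversion_cast %ptrs, %vals, %fill
           : memref<?xf64>, memref<?xf64>, f64
           to tuple<memref<?xf64>, memref<?xf64>, f64>
    // A storage_get on field #1 then simply forwards %vals; a storage_set
    // rebuilds the cast with the updated field list.
    %v = sparse_tensor.storage_get %s[1]
           : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>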
@@ -217,10 +217,12 @@ struct SparseTensorStorageExpansionPass
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
return converter.isLegal(op.getOperandTypes());
});
// We generate UnrealizedConversionCastOps to bridge between tuples and
// flattened lists of types.
target.addLegalOp<UnrealizedConversionCastOp>();
// Populate the conversion patterns and apply them.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
populateCallOpTypeConversionPattern(patterns, converter);
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
target);
populateSparseTensorStorageExpansionPatterns(converter, patterns);
@@ -41,10 +41,69 @@ convertSparseTensorStorageTuple(Type t, SmallVectorImpl<Type> &result) {
return llvm::None;
}

/// Flatten a list of operands that may contain tuples.
static void flattenOperands(ValueRange operands,
SmallVectorImpl<Value> &flattened) {
// In case of
// tuple<a, b>, c, tuple<d, e>
// ==>
// a, b, c, d, e
for (auto operand : operands) {
  // Note: getDefiningOp<OpTy>() is null-safe, so operands that are block
  // arguments are handled gracefully as well.
  if (auto cast = operand.getDefiningOp<UnrealizedConversionCastOp>();
      cast && cast->getResultTypes()[0].isa<TupleType>())
    // An unrealized_conversion_cast is inserted by the type converter to
    // bridge the 1:N conversion between a tuple and its flattened types.
    // In this case, take the operands of the cast and use them in place
    // of the tuple output.
    flattened.append(cast.getOperands().begin(), cast.getOperands().end());
  else
    flattened.push_back(operand);
}
}
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//

/// Sparse tensor storage conversion rule for sparse_tensor::storage_get.
class SparseStorageGetConverter : public OpConversionPattern<StorageGetOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(StorageGetOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto castOp =
cast<UnrealizedConversionCastOp>(adaptor.getStorage().getDefiningOp());
uint64_t idx = op.getIdx().getZExtValue();
assert(idx < castOp.getOperands().size());

rewriter.replaceOp(op, castOp.getOperand(idx));
return success();
}
};
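
// For illustration only (hypothetical IR): assuming the storage tuple was
// materialized by a cast, e.g.
//   %s = builtin.unrealized_conversion_cast %f0, %f1, %f2
//          : memref<?xf64>, memref<?xf64>, f64
//          to tuple<memref<?xf64>, memref<?xf64>, f64>
// then
//   %v = sparse_tensor.storage_get %s[1] : tuple<...> to memref<?xf64>
// is folded away by the pattern above: all uses of %v are replaced with
// %f1, the cast operand at the requested index.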

/// Sparse tensor storage conversion rule for sparse_tensor::storage_set.
class SparseStorageSetConverter : public OpConversionPattern<StorageSetOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(StorageSetOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto castOp =
cast<UnrealizedConversionCastOp>(adaptor.getStorage().getDefiningOp());
uint64_t idx = op.getIdx().getZExtValue();

SmallVector<Value, 8> values(castOp.getOperands());
assert(idx < values.size());

// Updates the corresponding element.
values[idx] = adaptor.getValue();
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
op, TypeRange{op.getType()}, values);
return success();
}
};
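
// For illustration only (hypothetical IR, same cast as above): rewriting
//   %t = sparse_tensor.storage_set %s[1], %new
//          : tuple<...>, memref<?xf64> to tuple<...>
// produces a fresh cast over the updated field list,
//   %t = builtin.unrealized_conversion_cast %f0, %new, %f2 : ... to tuple<...>
// which later patterns (gets, returns, calls) consume transparently.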

/// Sparse tensor storage conversion rule for returns.
class SparseStorageReturnConverter
: public OpConversionPattern<func::ReturnOp> {
@@ -54,24 +113,69 @@ class SparseStorageReturnConverter
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value, 8> flattened;
-    for (auto operand : adaptor.getOperands()) {
-      if (auto cast =
-              dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
-          cast && cast->getResultTypes()[0].isa<TupleType>())
-        // An unrealized_conversion_cast will be inserted by type converter to
-        // inter-mix the gap between 1:N conversion between tuple and types.
-        // In this case, take the operands in the cast and replace the tuple
-        // output with the flattened type array.
-        flattened.append(cast.getOperands().begin(), cast.getOperands().end());
-      else
-        flattened.push_back(operand);
-    }
+    flattenOperands(adaptor.getOperands(), flattened);
// Create a return with the flattened values extracted from the tuples.
rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
return success();
}
};

/// Sparse tensor storage conversion rule for calls.
class SparseStorageCallConverter : public OpConversionPattern<func::CallOp> {
public:
// The default CallOp converter cannot handle 1:N type conversions properly.
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// In case of:
//   tuple<a, b>, f, tuple<c, d> = call @foo(...)
// ==>
//   a, b, f, c, d = call @foo(...)
//   cast(a, b) -> tuple, f, cast(c, d) -> tuple
SmallVector<Type, 8> finalRetTy;
if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
return failure();

// (1) Generate a new call with flattened return values.
SmallVector<Value, 8> flattened;
flattenOperands(adaptor.getOperands(), flattened);
auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
finalRetTy, flattened);

// (2) Create cast operations for the tuple returns.
SmallVector<Value, 4> castedRet;
// Tracks the offset of the current return value (of the original call)
// relative to the new call (after tuple flattening).
unsigned retOffset = 0;
for (auto ret : op.getResults()) {
assert(retOffset < newCall.getNumResults());
auto tupleRet = ret.getType().dyn_cast<TupleType>();
if (tupleRet) {
auto tupleSize = tupleRet.size();
// NOTE: The range is computed under the assumption of non-recursive
// tuple type.
ValueRange tupleElem(iterator_range<ResultRange::iterator>(
newCall.result_begin() + retOffset,
newCall.result_begin() + retOffset + tupleSize));
auto castOp = rewriter.create<UnrealizedConversionCastOp>(
loc, TypeRange({tupleRet}), tupleElem);
castedRet.push_back(castOp.getResult(0));
retOffset += tupleSize;
} else {
// If this is not a tuple, simply forward the corresponding result of the
// new call (forwarding the original result would leave a use of the
// erased op behind).
castedRet.push_back(newCall.getResult(retOffset));
retOffset++;
}
}

assert(castedRet.size() == op.getNumResults());
rewriter.replaceOp(op, castedRet);
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
@@ -91,6 +195,7 @@ mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() {
/// to expand compounded sparse tensor tuples.
void mlir::populateSparseTensorStorageExpansionPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<SparseStorageReturnConverter>(typeConverter,
-               patterns.getContext());
+  patterns.add<SparseStorageGetConverter, SparseStorageSetConverter,
+               SparseStorageReturnConverter, SparseStorageCallConverter>(
+      typeConverter, patterns.getContext());
}
40 changes: 39 additions & 1 deletion mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparse-tensor-storage-expansion | FileCheck %s
+// RUN: mlir-opt %s -sparse-tensor-storage-expansion -cse | FileCheck %s

// CHECK-LABEL: func @sparse_storage_expand(
// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
@@ -9,3 +9,41 @@ func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
-> tuple<memref<?xf64>, memref<?xf64>, f64> {
return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
}

// CHECK-LABEL: func @call_sparse_storage_expand(
// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
// CHECK: %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]])
// CHECK: return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2 : memref<?xf64>, memref<?xf64>, f64
func.func @call_sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
-> tuple<memref<?xf64>, memref<?xf64>, f64> {
%1 = call @sparse_storage_expand(%arg0) : (tuple<memref<?xf64>, memref<?xf64>, f64>) ->
tuple<memref<?xf64>, memref<?xf64>, f64>
return %1 : tuple<memref<?xf64>, memref<?xf64>, f64>
}

// CHECK-LABEL: func @sparse_storage_get(
// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
// CHECK: return %[[TMP_arg0]] : memref<?xf64>
func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
%0 = sparse_tensor.storage_get %arg0[0]
: tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
return %0 : memref<?xf64>
}

// CHECK-LABEL: func @sparse_storage_set(
// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
// CHECK-SAME: %[[TMP_arg2:.*]]: f64,
// CHECK-SAME: %[[TMP_arg3:.*]]: memref<?xf64>)
// CHECK: return %[[TMP_arg3]], %[[TMP_arg1]], %[[TMP_arg2]] : memref<?xf64>, memref<?xf64>, f64
func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>,
%arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
%0 = sparse_tensor.storage_set %arg0[0], %arg1
: tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
tuple<memref<?xf64>, memref<?xf64>, f64>
return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
}
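
As an illustrative sketch only (hypothetical functions, not part of this
commit's tests), a call with mixed tuple and scalar results would exercise
the retOffset bookkeeping in SparseStorageCallConverter: the flattened call
produces five results, and two casts re-materialize the tuple values.

func.func private @mixed(f64) -> (tuple<memref<?xf64>, f64>, i1,
                                  tuple<memref<?xf64>, f64>)

func.func @call_mixed(%arg0: f64) -> (tuple<memref<?xf64>, f64>, i1,
                                      tuple<memref<?xf64>, f64>) {
  %0:3 = call @mixed(%arg0) : (f64) -> (tuple<memref<?xf64>, f64>, i1,
                                        tuple<memref<?xf64>, f64>)
  return %0#0, %0#1, %0#2
      : tuple<memref<?xf64>, f64>, i1, tuple<memref<?xf64>, f64>
}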
