[mlir][scf] support 1:N type conversion for scf.for.
scf.for used to support only 1:1 type conversion; this patch adds support for 1:N type conversion.
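For readers unfamiliar with the term: a conversion is 1:N when the type converter maps one source type to several result types, as the sparse-tensor codegen converter does for the test added below. A minimal sketch of registering such a conversion (the multi-result addConversion overload is real upstream API; the helper name and the concrete memref expansion are illustrative only):

// Sketch only: expanding one tensor type into several memref types makes the
// conversion 1:N. The buffers chosen here are illustrative, not the real
// sparse-tensor lowering.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static void addIllustrative1ToNConversion(TypeConverter &converter) {
  // Leave all other types unchanged (1:1).
  converter.addConversion([](Type type) { return type; });
  // This overload may append any number of types to `results`; appending
  // more than one is what makes the conversion 1:N.
  converter.addConversion(
      [](RankedTensorType type, SmallVectorImpl<Type> &results)
          -> Optional<LogicalResult> {
        auto idxBuf = MemRefType::get({ShapedType::kDynamicSize},
                                      IndexType::get(type.getContext()));
        results.push_back(idxBuf); // e.g., a pointer buffer
        results.push_back(idxBuf); // e.g., an index buffer
        results.push_back(MemRefType::get({ShapedType::kDynamicSize},
                                          type.getElementType())); // values
        return success();
      });
}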

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136314
PeimingLiu committed Oct 21, 2022
1 parent 4153f98 commit d3f5f33
Showing 2 changed files with 114 additions and 37 deletions.
122 changes: 85 additions & 37 deletions mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -15,58 +15,102 @@ using namespace mlir;
 using namespace mlir::scf;
 
 namespace {
 
+// Unpacks the single unrealized_conversion_cast using the list of inputs,
+// e.g., returns [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d).
+static void unpackUnrealizedConversionCast(Value v,
+                                           SmallVectorImpl<Value> &unpacked) {
+  if (auto cast =
+          dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
+    if (cast.getInputs().size() != 1) {
+      // 1 : N type conversion.
+      unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
+      return;
+    }
+  }
+  // 1 : 1 type conversion.
+  unpacked.push_back(v);
+}

 class ConvertForOpTypes : public OpConversionPattern<ForOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(ForOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Type, 6> newResultTypes;
-    for (auto type : op.getResultTypes()) {
-      Type newType = typeConverter->convertType(type);
-      if (!newType)
-        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
-      newResultTypes.push_back(newType);
+    SmallVector<Type> newResultTypes;
+    SmallVector<unsigned> offsets;
+    offsets.push_back(0);
+    // Do the type conversion and record the offsets.
+    for (Type type : op.getResultTypes()) {
+      if (failed(typeConverter->convertTypes(type, newResultTypes)))
+        return rewriter.notifyMatchFailure(op, "could not convert result");
+      offsets.push_back(newResultTypes.size());
     }

-    // Clone the op without the regions and inline the regions from the old op.
+    // Create an empty new op and inline the regions from the old op.
     //
     // This is a little bit tricky. We have two concerns here:
     //
     // 1. We cannot update the op in place because the dialect conversion
     // framework does not track type changes for ops updated in place, so it
     // won't insert appropriate materializations on the changed result types.
-    // PR47938 tracks this issue, but it seems hard to fix. Instead, we need to
-    // clone the op.
+    // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
+    // to clone the op.
     //
-    // 2. We cannot simply call `op.clone()` to get the cloned op. Besides being
-    // inefficient to recursively clone the regions, there is a correctness
-    // issue: if we clone with the regions, then the dialect conversion
-    // framework thinks that we just inserted all the cloned child ops. But what
-    // we want is to "take" the child regions and let the dialect conversion
-    // framework continue recursively into ops inside those regions (which are
-    // already in its worklist; inlining them into the new op's regions doesn't
-    // remove the child ops from the worklist).
-    ForOp newOp = cast<ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
-    // Take the region from the old op and put it in the new op.
+    // 2. We need to reuse the original region instead of cloning it; otherwise
+    // the dialect conversion framework thinks that we just inserted all the
+    // cloned child ops. But what we want is to "take" the child regions and let
+    // the dialect conversion framework continue recursively into ops inside
+    // those regions (which are already in its worklist; inlining them into the
+    // new op's regions doesn't remove the child ops from the worklist).
+
+    // convertRegionTypes already takes care of 1:N conversion.
+    if (failed(rewriter.convertRegionTypes(&op.getLoopBody(), *typeConverter)))
+      return failure();

+    // Unpack the iteration arguments.
+    SmallVector<Value> flatArgs;
+    for (Value arg : adaptor.getInitArgs())
+      unpackUnrealizedConversionCast(arg, flatArgs);
+
+    // We cannot clone the op, as the number of result types after conversion
+    // might be different.
+    ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
+                                         adaptor.getUpperBound(),
+                                         adaptor.getStep(), flatArgs);
+
+    // Preserve all attributes of the original op.
+    newOp->setAttrs(op->getAttrs());
+
+    // We do not need the empty block created by the rewriter.
+    rewriter.eraseBlock(newOp.getBody(0));
+    // Inline the type-converted region from the original operation.
+    rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
+                                newOp.getLoopBody().end());

-    // Now, update all the types.
-
-    // Convert the type of the entry block of the ForOp's body.
-    if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
-                                           *getTypeConverter()))) {
-      return rewriter.notifyMatchFailure(op, "could not convert body types");
+    // Pack the return values.
+    SmallVector<Value, 6> packedRets;
+    for (unsigned i = 1, e = offsets.size(); i < e; i++) {
+      unsigned start = offsets[i - 1], end = offsets[i];
+      unsigned len = end - start;
+      ValueRange mappedValue = newOp.getResults().slice(start, len);
+      if (len != 1) {
+        // 1 : N type conversion.
+        Type origType = op.getResultTypes()[i - 1];
+        Value mat = typeConverter->materializeSourceConversion(
+            rewriter, op.getLoc(), origType, mappedValue);
+        if (!mat)
+          return rewriter.notifyMatchFailure(
+              op, "failed to materialize 1:N type conversion");
+        packedRets.push_back(mat);
+      } else {
+        // 1 : 1 type conversion.
+        packedRets.push_back(mappedValue.front());
+      }
     }
-    // Change the clone to use the updated operands. We could have cloned with
-    // a BlockAndValueMapping, but this seems a bit more direct.
-    newOp->setOperands(adaptor.getOperands());
-    // Update the result types to the new converted types.
-    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
-      std::get<0>(t).setType(std::get<1>(t));
 
-    rewriter.replaceOp(op, newOp.getResults());
+    rewriter.replaceOp(op, packedRets);
     return success();
   }
 };
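To make the offsets bookkeeping in ConvertForOpTypes concrete: offsets records, per original result, where its converted types start in newResultTypes, so original result i maps to the half-open slice [offsets[i], offsets[i+1]) of the new op's results. A self-contained toy of that logic, with plain ints standing in for MLIR types (the two-result {5, 1} shape is an assumed example matching the five-memref expansion in the test below):

#include <cassert>
#include <vector>

int main() {
  // Assume two loop results: the first converts 1:5 (like the sparse tensor
  // in the test below), the second 1:1.
  std::vector<int> newResultTypes;
  std::vector<unsigned> offsets{0};
  for (unsigned numConverted : {5u, 1u}) {
    for (unsigned j = 0; j < numConverted; ++j)
      newResultTypes.push_back(0); // stand-in for one converted type
    offsets.push_back(newResultTypes.size());
  }
  // offsets == {0, 5, 6}; each slice length is offsets[i] - offsets[i - 1],
  // exactly what newOp.getResults().slice(start, len) extracts above.
  assert(offsets.size() == 3);
  assert(offsets[1] - offsets[0] == 5); // 1:N, needs a source materialization
  assert(offsets[2] - offsets[1] == 1); // 1:1, passed through directly
  return 0;
}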
@@ -81,12 +125,12 @@ class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
                   ConversionPatternRewriter &rewriter) const override {
     // TODO: Generalize this to any type conversion, not just 1:1.
     //
-    // We need to implement something more sophisticated here that tracks which
-    // types convert to which other types and does the appropriate
+    // We need to implement something more sophisticated here that tracks
+    // which types convert to which other types and does the appropriate
     // materialization logic.
     // For example, it's possible that one result type converts to 0 types and
-    // another to 2 types, so newResultTypes would at least be the right size to
-    // not crash in the llvm::zip call below, but then we would set the the
+    // another to 2 types, so newResultTypes would at least be the right size
+    // to not crash in the llvm::zip call below, but then we would set the
     // wrong type on the SSA values! These edge cases are also why we cannot
     // safely use the TypeConverter::convertTypes helper here.
     SmallVector<Type, 6> newResultTypes;
@@ -125,7 +169,11 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
   LogicalResult
   matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
+    SmallVector<Value> unpackedYield;
+    for (Value operand : adaptor.getOperands())
+      unpackUnrealizedConversionCast(operand, unpackedYield);
+
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
     return success();
   }
 };
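Both patterns above lean on the same framework behavior: when a value is converted 1:N, the N replacement values are stitched back into one value of the original type through a single unrealized_conversion_cast, and that cast is what the adaptor hands to the pattern (adaptor.getInitArgs() in ConvertForOpTypes, adaptor.getOperands() here). A standalone toy of the unpacking contract, with plain structs standing in for Value and the cast op:

#include <cassert>
#include <vector>

// Toy stand-in for an SSA value; castInputs is non-empty only when the value
// is produced by an unrealized_conversion_cast.
struct ToyValue {
  int id = 0;
  std::vector<ToyValue> castInputs;
};

// Mirrors unpackUnrealizedConversionCast: a multi-input cast is replaced by
// its original N inputs; anything else passes through untouched.
static void unpack(const ToyValue &v, std::vector<ToyValue> &out) {
  if (v.castInputs.size() > 1) {
    out.insert(out.end(), v.castInputs.begin(), v.castInputs.end());
    return;
  }
  out.push_back(v);
}

int main() {
  ToyValue plain{1, {}};
  ToyValue cast{2, {{3, {}}, {4, {}}, {5, {}}}}; // %2 = cast(%3, %4, %5)
  std::vector<ToyValue> flat;
  unpack(plain, flat); // 1:1, kept as-is
  unpack(cast, flat);  // 1:N, replaced by its three inputs
  assert(flat.size() == 4);
  return 0;
}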
29 changes: 29 additions & 0 deletions mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
@@ -0,0 +1,29 @@
// RUN: mlir-opt %s -sparse-tensor-codegen -cse | FileCheck %s

#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
// CHECK-LABEL: func @for(
// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>,
// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>,
// CHECK-SAME: %[[POINTER:.*2]]: memref<?xindex>,
// CHECK-SAME: %[[INDICES:.*3]]: memref<?xindex>,
// CHECK-SAME: %[[VALUE:.*4]]: memref<?xf32>,
// CHECK-SAME: %[[TMP_arg5:.*5]]: index,
// CHECK-SAME: %[[TMP_arg6:.*6]]: index,
// CHECK-SAME: %[[TMP_arg7:.*7]]: index
// CHECK: %[[TMP_0:.*]]:5 = scf.for %[[TMP_arg8:.*]] = %[[TMP_arg5]] to %[[TMP_arg6]] step %[[TMP_arg7]] iter_args(
// CHECK-SAME: %[[TMP_arg9:.*]] = %[[DIM_SIZE]],
// CHECK-SAME: %[[TMP_arg10:.*]] = %[[MEM_SIZE]],
// CHECK-SAME: %[[TMP_arg11:.*]] = %[[POINTER]],
// CHECK-SAME: %[[TMP_arg12:.*]] = %[[INDICES]],
// CHECK-SAME: %[[TMP_arg13:.*]] = %[[VALUE]])
// CHECK: scf.yield %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]], %[[TMP_arg13]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
// CHECK: }
// CHECK: return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2, %[[TMP_0]]#3, %[[TMP_0]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
func.func @for(%in: tensor<1024xf32, #SparseVector>,
               %lb: index, %ub: index, %step: index)
    -> tensor<1024xf32, #SparseVector> {
  %1 = scf.for %i = %lb to %ub step %step iter_args(%vin = %in)
      -> tensor<1024xf32, #SparseVector> {
    scf.yield %vin : tensor<1024xf32, #SparseVector>
  }
  return %1 : tensor<1024xf32, #SparseVector>
}
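For completeness, a sketch of how a pass might wire these patterns up. populateSCFStructuralTypeConversionsAndLegality is the entry point actually defined in StructuralTypeConversions.cpp; the surrounding helper, and the assumption that its declaration is reachable from the SCF transforms headers, are illustrative:

#include "mlir/Transforms/DialectConversion.h"
// Also include the SCF transforms header that declares
// scf::populateSCFStructuralTypeConversionsAndLegality (path omitted here).

using namespace mlir;

// Hypothetical driver: run the structural SCF patterns under `converter`,
// which may perform 1:N conversions as of this commit.
static LogicalResult convertLoopTypes(Operation *root,
                                      TypeConverter &converter) {
  MLIRContext *ctx = root->getContext();
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);
  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                       target);
  return applyPartialConversion(root, target, std::move(patterns));
}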
