Skip to content

Commit

Permalink
[mlir][scf] refactor scf structuralOpConversion to better support 1:N…
Browse files Browse the repository at this point in the history
… type conversion

This patch moves the 1:N type mapping into its own classes to allow better code reuse in D137100.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D137099
  • Loading branch information
PeimingLiu committed Nov 2, 2022
1 parent 85c2d92 commit f4cd367
Showing 1 changed file with 71 additions and 33 deletions.
104 changes: 71 additions & 33 deletions mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,82 @@ static void unpackUnrealizedConversionCast(Value v,
unpacked.push_back(v);
}

class ConvertForOpTypes : public OpConversionPattern<ForOp> {
// CRTP
// A base class that takes care of 1:N type conversion, which maps the converted
// op results (computed by the derived class) and materializes 1:N conversion.
template <typename SourceOp, typename ConcretePattern>
class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
public:
using OpConversionPattern::OpConversionPattern;
using OpConversionPattern<SourceOp>::typeConverter;
using OpConversionPattern<SourceOp>::OpConversionPattern;
using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;

//
// Derived classes should provide the following method which performs the
// actual conversion. It should return llvm::None upon conversion failure and
// return the converted operation upon success.
//
// Optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
// ConversionPatternRewriter &rewriter,
// TypeRange dstTypes) const;

LogicalResult
matchAndRewrite(ForOp op, OpAdaptor adaptor,
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> newResultTypes;
SmallVector<Type> dstTypes;
SmallVector<unsigned> offsets;
offsets.push_back(0);
// Do the type conversion and record the offsets.
for (Type type : op.getResultTypes()) {
if (failed(typeConverter->convertTypes(type, newResultTypes)))
return rewriter.notifyMatchFailure(op, "could not convert result");
offsets.push_back(newResultTypes.size());
if (failed(typeConverter->convertTypes(type, dstTypes)))
return rewriter.notifyMatchFailure(op, "could not convert result type");
offsets.push_back(dstTypes.size());
}

// Calls the actual converter implementation to convert the operation.
Optional<SourceOp> newOp =
static_cast<const ConcretePattern *>(this)->convertSourceOp(
op, adaptor, rewriter, dstTypes);

if (!newOp)
return rewriter.notifyMatchFailure(op, "could not convert operation");

// Packs the return value.
SmallVector<Value> packedRets;
for (unsigned i = 1, e = offsets.size(); i < e; i++) {
unsigned start = offsets[i - 1], end = offsets[i];
unsigned len = end - start;
ValueRange mappedValue = newOp->getResults().slice(start, len);
if (len != 1) {
// 1 : N type conversion.
Type origType = op.getResultTypes()[i - 1];
Value mat = typeConverter->materializeSourceConversion(
rewriter, op.getLoc(), origType, mappedValue);
if (!mat) {
return rewriter.notifyMatchFailure(
op, "Failed to materialize 1:N type conversion");
}
packedRets.push_back(mat);
} else {
// 1 : 1 type conversion.
packedRets.push_back(mappedValue.front());
}
}

rewriter.replaceOp(op, packedRets);
return success();
}
};

class ConvertForOpTypes
: public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> {
public:
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;

// The callback required by CRTP.
Optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
// Create a empty new op and inline the regions from the old op.
//
// This is a little bit tricky. We have two concerns here:
Expand All @@ -67,15 +127,15 @@ class ConvertForOpTypes : public OpConversionPattern<ForOp> {

// convertRegionTypes already takes care of 1:N conversion.
if (failed(rewriter.convertRegionTypes(&op.getLoopBody(), *typeConverter)))
return failure();
return llvm::None;

// Unpacked the iteration arguments.
SmallVector<Value> flatArgs;
for (Value arg : adaptor.getInitArgs())
unpackUnrealizedConversionCast(arg, flatArgs);

// We can not do clone as the number of result types after conversion might
// be different.
// We can not do clone as the number of result types after conversion
// might be different.
ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
adaptor.getUpperBound(),
adaptor.getStep(), flatArgs);
Expand All @@ -89,29 +149,7 @@ class ConvertForOpTypes : public OpConversionPattern<ForOp> {
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
newOp.getLoopBody().end());

// Pack the return value.
SmallVector<Value, 6> packedRets;
for (unsigned i = 1, e = offsets.size(); i < e; i++) {
unsigned start = offsets[i - 1], end = offsets[i];
unsigned len = end - start;
ValueRange mappedValue = newOp.getResults().slice(start, len);
if (len != 1) {
// 1 : N type conversion.
Type origType = op.getResultTypes()[i - 1];
Value mat = typeConverter->materializeSourceConversion(
rewriter, op.getLoc(), origType, mappedValue);
if (!mat)
return rewriter.notifyMatchFailure(
op, "Failed to materialize 1:N type conversion");
packedRets.push_back(mat);
} else {
// 1 : 1 type conversion.
packedRets.push_back(mappedValue.front());
}
}

rewriter.replaceOp(op, packedRets);
return success();
return newOp;
}
};
} // namespace
Expand Down

0 comments on commit f4cd367

Please sign in to comment.