Expand Up
@@ -16,20 +16,16 @@ using namespace mlir::scf;
namespace {
// Unpacks the single unrealized_conversion_cast using the list of inputs
// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
static void unpackUnrealizedConversionCast (Value v,
SmallVectorImpl<Value> &unpacked) {
if (auto cast =
dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp ())) {
if (cast.getInputs ().size () != 1 ) {
// 1 : N type conversion.
unpacked.append (cast.getInputs ().begin (), cast.getInputs ().end ());
return ;
}
}
// 1 : 1 type conversion.
unpacked.push_back (v);
static SmallVector<Value> flattenValues (ArrayRef<ArrayRef<Value>> values) {
SmallVector<Value> result;
for (ArrayRef<Value> v : values)
llvm::append_range (result, v);
return result;
}
static Value getSingleValue (ArrayRef<Value> values) {
assert (values.size () == 1 && " expected single value" );
return values.front ();
}
// CRTP
Expand All
@@ -40,19 +36,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
public:
using OpConversionPattern<SourceOp>::typeConverter;
using OpConversionPattern<SourceOp>::OpConversionPattern;
using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
using OneToNOpAdaptor =
typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;
//
// Derived classes should provide the following method which performs the
// actual conversion. It should return std::nullopt upon conversion failure
// and return the converted operation upon success.
//
// std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
// ConversionPatternRewriter &rewriter,
// TypeRange dstTypes) const;
// std::optional<SourceOp> convertSourceOp(
// SourceOp op, OneToNOpAdaptor adaptor,
// ConversionPatternRewriter &rewriter,
// TypeRange dstTypes) const;
LogicalResult
matchAndRewrite (SourceOp op, OpAdaptor adaptor,
matchAndRewrite (SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> dstTypes;
SmallVector<unsigned > offsets;
Expand All
@@ -73,28 +71,15 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
return rewriter.notifyMatchFailure (op, " could not convert operation" );
// Packs the return value.
SmallVector<Value > packedRets;
SmallVector<ValueRange > packedRets;
for (unsigned i = 1 , e = offsets.size (); i < e; i++) {
unsigned start = offsets[i - 1 ], end = offsets[i];
unsigned len = end - start;
ValueRange mappedValue = newOp->getResults ().slice (start, len);
if (len != 1 ) {
// 1 : N type conversion.
Type origType = op.getResultTypes ()[i - 1 ];
Value mat = typeConverter->materializeSourceConversion (
rewriter, op.getLoc (), origType, mappedValue);
if (!mat) {
return rewriter.notifyMatchFailure (
op, " Failed to materialize 1:N type conversion" );
}
packedRets.push_back (mat);
} else {
// 1 : 1 type conversion.
packedRets.push_back (mappedValue.front ());
}
packedRets.push_back (mappedValue);
}
rewriter.replaceOp (op, packedRets);
rewriter.replaceOpWithMultiple (op, packedRets);
return success ();
}
};
Expand All
@@ -105,7 +90,7 @@ class ConvertForOpTypes
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
// The callback required by CRTP.
std::optional<ForOp> convertSourceOp (ForOp op, OpAdaptor adaptor,
std::optional<ForOp> convertSourceOp (ForOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
// Create a empty new op and inline the regions from the old op.
Expand All
@@ -129,16 +114,13 @@ class ConvertForOpTypes
if (failed (rewriter.convertRegionTypes (&op.getRegion (), *typeConverter)))
return std::nullopt;
// Unpacked the iteration arguments.
SmallVector<Value> flatArgs;
for (Value arg : adaptor.getInitArgs ())
unpackUnrealizedConversionCast (arg, flatArgs);
// We can not do clone as the number of result types after conversion
// might be different.
ForOp newOp = rewriter.create <ForOp>(op.getLoc (), adaptor.getLowerBound (),
adaptor.getUpperBound (),
adaptor.getStep (), flatArgs);
ForOp newOp = rewriter.create <ForOp>(
op.getLoc (), getSingleValue (adaptor.getLowerBound ()),
getSingleValue (adaptor.getUpperBound ()),
getSingleValue (adaptor.getStep ()),
flattenValues (adaptor.getInitArgs ()));
// Reserve whatever attributes in the original op.
newOp->setAttrs (op->getAttrs ());
Expand All
@@ -160,12 +142,12 @@ class ConvertIfOpTypes
public:
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
std::optional<IfOp> convertSourceOp (IfOp op, OpAdaptor adaptor,
std::optional<IfOp> convertSourceOp (IfOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
IfOp newOp = rewriter.create <IfOp>(op. getLoc (), dstTypes,
adaptor.getCondition (), true );
IfOp newOp = rewriter.create <IfOp>(
op. getLoc (), dstTypes, getSingleValue ( adaptor.getCondition () ), true );
newOp->setAttrs (op->getAttrs ());
// We do not need the empty blocks created by rewriter.
Expand All
@@ -189,15 +171,11 @@ class ConvertWhileOpTypes
public:
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
std::optional<WhileOp> convertSourceOp (WhileOp op, OpAdaptor adaptor,
std::optional<WhileOp> convertSourceOp (WhileOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
// Unpacked the iteration arguments.
SmallVector<Value> flatArgs;
for (Value arg : adaptor.getOperands ())
unpackUnrealizedConversionCast (arg, flatArgs);
auto newOp = rewriter.create <WhileOp>(op.getLoc (), dstTypes, flatArgs);
auto newOp = rewriter.create <WhileOp>(op.getLoc (), dstTypes,
flattenValues (adaptor.getOperands ()));
for (auto i : {0u , 1u }) {
if (failed (rewriter.convertRegionTypes (&op.getRegion (i), *typeConverter)))
Expand All
@@ -218,13 +196,10 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite (scf::YieldOp op, OpAdaptor adaptor,
matchAndRewrite (scf::YieldOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> unpackedYield;
for (Value operand : adaptor.getOperands ())
unpackUnrealizedConversionCast (operand, unpackedYield);
rewriter.replaceOpWithNewOp <scf::YieldOp>(op, unpackedYield);
rewriter.replaceOpWithNewOp <scf::YieldOp>(
op, flattenValues (adaptor.getOperands ()));
return success ();
}
};
Expand All
@@ -235,13 +210,10 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
public:
using OpConversionPattern<ConditionOp>::OpConversionPattern;
LogicalResult
matchAndRewrite (ConditionOp op, OpAdaptor adaptor,
matchAndRewrite (ConditionOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> unpackedYield;
for (Value operand : adaptor.getOperands ())
unpackUnrealizedConversionCast (operand, unpackedYield);
rewriter.modifyOpInPlace (op, [&]() { op->setOperands (unpackedYield); });
rewriter.modifyOpInPlace (
op, [&]() { op->setOperands (flattenValues (adaptor.getOperands ())); });
return success ();
}
};
Expand Down