104 changes: 38 additions & 66 deletions mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,16 @@ using namespace mlir::scf;

namespace {

// Unpacks the single unrealized_conversion_cast using the list of inputs
// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
static void unpackUnrealizedConversionCast(Value v,
SmallVectorImpl<Value> &unpacked) {
if (auto cast =
dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
if (cast.getInputs().size() != 1) {
// 1 : N type conversion.
unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
return;
}
}
// 1 : 1 type conversion.
unpacked.push_back(v);
static SmallVector<Value> flattenValues(ArrayRef<ArrayRef<Value>> values) {
SmallVector<Value> result;
for (ArrayRef<Value> v : values)
llvm::append_range(result, v);
return result;
}

static Value getSingleValue(ArrayRef<Value> values) {
assert(values.size() == 1 && "expected single value");
return values.front();
}

// CRTP
Expand All @@ -40,19 +36,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
public:
using OpConversionPattern<SourceOp>::typeConverter;
using OpConversionPattern<SourceOp>::OpConversionPattern;
using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
using OneToNOpAdaptor =
typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;

//
// Derived classes should provide the following method which performs the
// actual conversion. It should return std::nullopt upon conversion failure
// and return the converted operation upon success.
//
// std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
// ConversionPatternRewriter &rewriter,
// TypeRange dstTypes) const;
// std::optional<SourceOp> convertSourceOp(
// SourceOp op, OneToNOpAdaptor adaptor,
// ConversionPatternRewriter &rewriter,
// TypeRange dstTypes) const;

LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> dstTypes;
SmallVector<unsigned> offsets;
Expand All @@ -73,28 +71,15 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
return rewriter.notifyMatchFailure(op, "could not convert operation");

// Packs the return value.
SmallVector<Value> packedRets;
SmallVector<ValueRange> packedRets;
for (unsigned i = 1, e = offsets.size(); i < e; i++) {
unsigned start = offsets[i - 1], end = offsets[i];
unsigned len = end - start;
ValueRange mappedValue = newOp->getResults().slice(start, len);
if (len != 1) {
// 1 : N type conversion.
Type origType = op.getResultTypes()[i - 1];
Value mat = typeConverter->materializeSourceConversion(
rewriter, op.getLoc(), origType, mappedValue);
if (!mat) {
return rewriter.notifyMatchFailure(
op, "Failed to materialize 1:N type conversion");
}
packedRets.push_back(mat);
} else {
// 1 : 1 type conversion.
packedRets.push_back(mappedValue.front());
}
packedRets.push_back(mappedValue);
}

rewriter.replaceOp(op, packedRets);
rewriter.replaceOpWithMultiple(op, packedRets);
return success();
}
};
Expand All @@ -105,7 +90,7 @@ class ConvertForOpTypes
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;

// The callback required by CRTP.
std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
// Create a empty new op and inline the regions from the old op.
Expand All @@ -129,16 +114,13 @@ class ConvertForOpTypes
if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
return std::nullopt;

// Unpacked the iteration arguments.
SmallVector<Value> flatArgs;
for (Value arg : adaptor.getInitArgs())
unpackUnrealizedConversionCast(arg, flatArgs);

// We can not do clone as the number of result types after conversion
// might be different.
ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
adaptor.getUpperBound(),
adaptor.getStep(), flatArgs);
ForOp newOp = rewriter.create<ForOp>(
op.getLoc(), getSingleValue(adaptor.getLowerBound()),
getSingleValue(adaptor.getUpperBound()),
getSingleValue(adaptor.getStep()),
flattenValues(adaptor.getInitArgs()));

// Reserve whatever attributes in the original op.
newOp->setAttrs(op->getAttrs());
Expand All @@ -160,12 +142,12 @@ class ConvertIfOpTypes
public:
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;

std::optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {

IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
adaptor.getCondition(), true);
IfOp newOp = rewriter.create<IfOp>(
op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
newOp->setAttrs(op->getAttrs());

// We do not need the empty blocks created by rewriter.
Expand All @@ -189,15 +171,11 @@ class ConvertWhileOpTypes
public:
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;

std::optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
// Unpacked the iteration arguments.
SmallVector<Value> flatArgs;
for (Value arg : adaptor.getOperands())
unpackUnrealizedConversionCast(arg, flatArgs);

auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes,
flattenValues(adaptor.getOperands()));

for (auto i : {0u, 1u}) {
if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
Expand All @@ -218,13 +196,10 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> unpackedYield;
for (Value operand : adaptor.getOperands())
unpackUnrealizedConversionCast(operand, unpackedYield);

rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
rewriter.replaceOpWithNewOp<scf::YieldOp>(
op, flattenValues(adaptor.getOperands()));
return success();
}
};
Expand All @@ -235,13 +210,10 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
public:
using OpConversionPattern<ConditionOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> unpackedYield;
for (Value operand : adaptor.getOperands())
unpackUnrealizedConversionCast(operand, unpackedYield);

rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); });
rewriter.modifyOpInPlace(
op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
return success();
}
};
Expand Down
236 changes: 133 additions & 103 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Large diffs are not rendered by default.

502 changes: 329 additions & 173 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,7 @@ struct TestUpdateConsumerType : public ConversionPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
llvm::errs() << "TestUpdateConsumerType operand: " << operands.front() << "\n";
// Verify that the incoming operand has been successfully remapped to F64.
if (!operands[0].getType().isF64())
return failure();
Expand Down
2 changes: 1 addition & 1 deletion mlir/unittests/ExecutionEngine/Invoke.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ static struct LLVMInitializer {
/// dialects lowering to LLVM Dialect.
static LogicalResult lowerToLLVMDialect(ModuleOp module) {
PassManager pm(module->getName());
pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
pm.addNestedPass<func::FuncOp>(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createConvertFuncToLLVMPass());
pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
pm.addPass(mlir::createReconcileUnrealizedCastsPass());
return pm.run(module);
}
Expand Down