Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,28 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();

/// Attempt to legalize the given operation. This can be used within
/// conversion patterns to change the default pre-order legalization order.
/// Returns "success" if the operation was legalized, "failure" otherwise.
///
/// Note: In a partial conversion, this function returns "success" even if
/// the operation could not be legalized, as long as it was not explicitly
/// marked as illegal in the conversion target.
LogicalResult legalize(Operation *op);

/// Attempt to legalize the given region. This can be used within
/// conversion patterns to change the default pre-order legalization order.
/// Returns "success" if the region was legalized, "failure" otherwise.
///
/// If the current pattern runs with a type converter, the entry block
/// signature will be converted before legalizing the operations in the
/// region.
///
/// Note: In a partial conversion, this function returns "success" even if
/// an operation could not be legalized, as long as it was not explicitly
/// marked as illegal in the conversion target.
LogicalResult legalize(Region *r);

private:
// Allow OperationConverter to construct new rewriters.
friend struct OperationConverter;
Expand All @@ -989,7 +1011,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// conversions. They apply some IR rewrites in a delayed fashion and could
/// bring the IR into an inconsistent state when used standalone.
explicit ConversionPatternRewriter(MLIRContext *ctx,
const ConversionConfig &config);
const ConversionConfig &config,
OperationConverter &converter);

// Hide unsupported pattern rewriter API.
using OpBuilder::setListener;
Expand Down
120 changes: 85 additions & 35 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,22 @@ static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
return pt;
}

namespace {
enum OpConversionMode {
/// In this mode, the conversion will ignore failed conversions to allow
/// illegal operations to co-exist in the IR.
Partial,

/// In this mode, all operations must be legal for the given target for the
/// conversion to succeed.
Full,

/// In this mode, operations are analyzed for legality. No actual rewrites are
/// applied to the operations on success.
Analysis,
};
} // namespace

//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -866,8 +882,9 @@ namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
const ConversionConfig &config)
: rewriter(rewriter), config(config),
const ConversionConfig &config,
OperationConverter &opConverter)
: rewriter(rewriter), config(config), opConverter(opConverter),
notifyingRewriter(rewriter.getContext(), config.listener) {}

//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -1124,6 +1141,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Dialect conversion configuration.
const ConversionConfig &config;

/// The operation converter to use for recursive legalization.
OperationConverter &opConverter;

/// A set of erased operations. This set is utilized only if
/// `allowPatternRollback` is set to "false". Conceptually, this set is
/// similar to `replacedOps` (which is maintained when the flag is set to
Expand Down Expand Up @@ -2084,9 +2104,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
//===----------------------------------------------------------------------===//

ConversionPatternRewriter::ConversionPatternRewriter(
MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
MLIRContext *ctx, const ConversionConfig &config,
OperationConverter &opConverter)
: PatternRewriter(ctx), impl(new detail::ConversionPatternRewriterImpl(
*this, config, opConverter)) {
setListener(impl.get());
}

Expand Down Expand Up @@ -2207,6 +2228,37 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
return success();
}

LogicalResult ConversionPatternRewriter::legalize(Region *r) {
// Fast path: If the region is empty, there is nothing to legalize.
if (r->empty())
return success();

// Gather a list of all operations to legalize. This is done before
// converting the entry block signature because unrealized_conversion_cast
// ops should not be included.
SmallVector<Operation *> ops;
for (Block &b : *r)
for (Operation &op : b)
ops.push_back(&op);

// If the current pattern runs with a type converter, convert the entry block
// signature.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this case tested right now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, by this test case.

if (const TypeConverter *converter = impl->currentTypeConverter) {
std::optional<TypeConverter::SignatureConversion> conversion =
converter->convertBlockSignature(&r->front());
if (!conversion)
return failure();
applySignatureConversion(&r->front(), *conversion, converter);
}

// Legalize all operations in the region.
for (Operation *op : ops)
if (failed(legalize(op)))
return failure();

return success();
}

void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
Block::iterator before,
ValueRange argValues) {
Expand Down Expand Up @@ -3192,22 +3244,6 @@ static void reconcileUnrealizedCasts(
// OperationConverter
//===----------------------------------------------------------------------===//

namespace {
enum OpConversionMode {
/// In this mode, the conversion will ignore failed conversions to allow
/// illegal operations to co-exist in the IR.
Partial,

/// In this mode, all operations must be legal for the given target for the
/// conversion to succeed.
Full,

/// In this mode, operations are analyzed for legality. No actual rewrites are
/// applied to the operations on success.
Analysis,
};
} // namespace

namespace mlir {
// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
Expand All @@ -3217,16 +3253,20 @@ struct OperationConverter {
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
: rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
: rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
mode(mode) {}

/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);

private:
/// Converts an operation with the given rewriter.
LogicalResult convert(Operation *op);
/// Converts a single operation. If `isRecursiveLegalization` is "true", the
/// conversion is a recursive legalization request, triggered from within a
/// pattern. In that case, do not emit errors because there will be another
/// attempt at legalizing the operation later (via the regular pre-order
/// legalization mechanism).
LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);

private:
/// The rewriter to use when converting operations.
ConversionPatternRewriter rewriter;

Expand All @@ -3238,32 +3278,42 @@ struct OperationConverter {
};
} // namespace mlir

LogicalResult OperationConverter::convert(Operation *op) {
LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
}

LogicalResult OperationConverter::convert(Operation *op,
bool isRecursiveLegalization) {
const ConversionConfig &config = rewriter.getConfig();

// Legalize the given operation.
if (failed(opLegalizer.legalize(op))) {
// Handle the case of a failed conversion for each of the different modes.
// Full conversions expect all operations to be converted.
if (mode == OpConversionMode::Full)
return op->emitError()
<< "failed to legalize operation '" << op->getName() << "'";
if (mode == OpConversionMode::Full) {
if (!isRecursiveLegalization)
op->emitError() << "failed to legalize operation '" << op->getName()
<< "'";
return failure();
}
// Partial conversions allow conversions to fail iff the operation was not
// explicitly marked as illegal. If the user provided a `unlegalizedOps`
// set, non-legalizable ops are added to that set.
if (mode == OpConversionMode::Partial) {
if (opLegalizer.isIllegal(op))
return op->emitError()
<< "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
if (config.unlegalizedOps)
if (opLegalizer.isIllegal(op)) {
if (!isRecursiveLegalization)
op->emitError() << "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
return failure();
}
if (config.unlegalizedOps && !isRecursiveLegalization)
config.unlegalizedOps->insert(op);
}
} else if (mode == OpConversionMode::Analysis) {
// Analysis conversions don't fail if any operations fail to legalize,
// they are only interested in the operations that were successfully
// legalized.
if (config.legalizableOps)
if (config.legalizableOps && !isRecursiveLegalization)
config.legalizableOps->insert(op);
}
return success();
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Transforms/test-legalizer-full.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,21 @@ builtin.module {
}

}

// -----

// The region of "test.post_order_legalization" is converted before the op.

// expected-remark@+1 {{applyFullConversion failed}}
builtin.module {
func.func @test_preorder_legalization() {
// expected-error@+1 {{failed to legalize operation 'test.post_order_legalization'}}
"test.post_order_legalization"() ({
^bb0(%arg0: i64):
// Not-explicitly-legal ops are not allowed to survive.
"test.remaining_consumer"(%arg0) : (i64) -> ()
"test.invalid"(%arg0) : (i64) -> ()
}) : () -> ()
return
}
}
19 changes: 19 additions & 0 deletions mlir/test/Transforms/test-legalizer-rollback.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,22 @@ func.func @create_unregistered_op_in_pattern() -> i32 {
"test.return"(%0) : (i32) -> ()
}
}

// -----

// CHECK-LABEL: func @test_failed_preorder_legalization
// CHECK: "test.post_order_legalization"() ({
// CHECK: %[[r:.*]] = "test.illegal_op_g"() : () -> i32
// CHECK: "test.return"(%[[r]]) : (i32) -> ()
// CHECK: }) : () -> ()
// expected-remark @+1 {{applyPartialConversion failed}}
module {
func.func @test_failed_preorder_legalization() {
// expected-error @+1 {{failed to legalize operation 'test.post_order_legalization' that was explicitly marked illegal}}
"test.post_order_legalization"() ({
%0 = "test.illegal_op_g"() : () -> (i32)
"test.return"(%0) : (i32) -> ()
}) : () -> ()
return
}
}
32 changes: 32 additions & 0 deletions mlir/test/Transforms/test-legalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,35 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
"test.type_consumer"(%arg0) : (f16) -> ()
"test.return"() : () -> ()
}

// -----

// The region of "test.post_order_legalization" is converted before the op.

// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
// CHECK: notifyOperationInserted: test.invalid
// CHECK: notifyBlockErased
// CHECK: notifyOperationInserted: test.valid, was unlinked
// CHECK: notifyOperationReplaced: test.invalid
// CHECK: notifyOperationErased: test.invalid
// CHECK: notifyOperationModified: test.post_order_legalization

// CHECK-LABEL: func @test_preorder_legalization
// CHECK: "test.post_order_legalization"() ({
// CHECK: ^{{.*}}(%[[arg0:.*]]: f64):
// Note: The survival of a not-explicitly-invalid operation does *not* cause
// a conversion failure in when applying a partial conversion.
// CHECK: %[[cast:.*]] = "test.cast"(%[[arg0]]) : (f64) -> i64
// CHECK: "test.remaining_consumer"(%[[cast]]) : (i64) -> ()
// CHECK: "test.valid"(%[[arg0]]) : (f64) -> ()
// CHECK: }) {is_legal} : () -> ()
func.func @test_preorder_legalization() {
"test.post_order_legalization"() ({
^bb0(%arg0: i64):
// expected-remark @+1 {{'test.remaining_consumer' is not legalizable}}
"test.remaining_consumer"(%arg0) : (i64) -> ()
"test.invalid"(%arg0) : (i64) -> ()
}) : () -> ()
// expected-remark @+1 {{'func.return' is not legalizable}}
return
}
22 changes: 21 additions & 1 deletion mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,22 @@ class TestTypeConsumerOpPattern
}
};

class TestPostOrderLegalization : public ConversionPattern {
public:
TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter)
: ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
for (Region &r : op->getRegions())
if (failed(rewriter.legalize(&r)))
return failure();
rewriter.modifyOpInPlace(
op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); });
return success();
}
};

/// Test unambiguous overload resolution of replaceOpWithMultiple. This
/// function is just to trigger compiler errors. It is never executed.
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
Expand Down Expand Up @@ -1532,7 +1548,8 @@ struct TestLegalizePatternDriver
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
TestValueReplace, TestReplaceWithValidConsumer,
TestTypeConsumerOpPattern>(&getContext(), converter);
TestTypeConsumerOpPattern, TestPostOrderLegalization>(
&getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
Expand Down Expand Up @@ -1560,6 +1577,9 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp(
OperationName("test.value_replace", &getContext()),
[](Operation *op) { return op->hasAttr("is_legal"); });
target.addDynamicallyLegalOp(
OperationName("test.post_order_legalization", &getContext()),
[](Operation *op) { return op->hasAttr("is_legal"); });

// TestCreateUnregisteredOp creates `arith.constant` operation,
// which was not added to target intentionally to test
Expand Down
Loading