diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index ed7e2a08ebfd9..5ac9e26e8636d 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -981,6 +981,28 @@ class ConversionPatternRewriter final : public PatternRewriter { /// Return a reference to the internal implementation. detail::ConversionPatternRewriterImpl &getImpl(); + /// Attempt to legalize the given operation. This can be used within + /// conversion patterns to change the default pre-order legalization order. + /// Returns "success" if the operation was legalized, "failure" otherwise. + /// + /// Note: In a partial conversion, this function returns "success" even if + /// the operation could not be legalized, as long as it was not explicitly + /// marked as illegal in the conversion target. + LogicalResult legalize(Operation *op); + + /// Attempt to legalize the given region. This can be used within + /// conversion patterns to change the default pre-order legalization order. + /// Returns "success" if the region was legalized, "failure" otherwise. + /// + /// If the current pattern runs with a type converter, the entry block + /// signature will be converted before legalizing the operations in the + /// region. + /// + /// Note: In a partial conversion, this function returns "success" even if + /// an operation could not be legalized, as long as it was not explicitly + /// marked as illegal in the conversion target. + LogicalResult legalize(Region *r); + private: // Allow OperationConverter to construct new rewriters. friend struct OperationConverter; @@ -989,7 +1011,8 @@ class ConversionPatternRewriter final : public PatternRewriter { /// conversions. They apply some IR rewrites in a delayed fashion and could /// bring the IR into an inconsistent state when used standalone. explicit ConversionPatternRewriter(MLIRContext *ctx, - const ConversionConfig &config); + const ConversionConfig &config, + OperationConverter &converter); // Hide unsupported pattern rewriter API. using OpBuilder::setListener; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 2fe06970eb568..f8c38fadbd229 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -92,6 +92,22 @@ static OpBuilder::InsertPoint computeInsertPoint(ArrayRef vals) { return pt; } +namespace { +enum OpConversionMode { + /// In this mode, the conversion will ignore failed conversions to allow + /// illegal operations to co-exist in the IR. + Partial, + + /// In this mode, all operations must be legal for the given target for the + /// conversion to succeed. + Full, + + /// In this mode, operations are analyzed for legality. No actual rewrites are + /// applied to the operations on success. + Analysis, +}; +} // namespace + //===----------------------------------------------------------------------===// // ConversionValueMapping //===----------------------------------------------------------------------===// @@ -866,8 +882,9 @@ namespace mlir { namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, - const ConversionConfig &config) - : rewriter(rewriter), config(config), + const ConversionConfig &config, + OperationConverter &opConverter) + : rewriter(rewriter), config(config), opConverter(opConverter), notifyingRewriter(rewriter.getContext(), config.listener) {} //===--------------------------------------------------------------------===// @@ -1124,6 +1141,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Dialect conversion configuration. const ConversionConfig &config; + /// The operation converter to use for recursive legalization. + OperationConverter &opConverter; + /// A set of erased operations. This set is utilized only if /// `allowPatternRollback` is set to "false". Conceptually, this set is /// similar to `replacedOps` (which is maintained when the flag is set to @@ -2084,9 +2104,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure( //===----------------------------------------------------------------------===// ConversionPatternRewriter::ConversionPatternRewriter( - MLIRContext *ctx, const ConversionConfig &config) - : PatternRewriter(ctx), - impl(new detail::ConversionPatternRewriterImpl(*this, config)) { + MLIRContext *ctx, const ConversionConfig &config, + OperationConverter &opConverter) + : PatternRewriter(ctx), impl(new detail::ConversionPatternRewriterImpl( + *this, config, opConverter)) { setListener(impl.get()); } @@ -2207,6 +2228,37 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, return success(); } +LogicalResult ConversionPatternRewriter::legalize(Region *r) { + // Fast path: If the region is empty, there is nothing to legalize. + if (r->empty()) + return success(); + + // Gather a list of all operations to legalize. This is done before + // converting the entry block signature because unrealized_conversion_cast + // ops should not be included. + SmallVector ops; + for (Block &b : *r) + for (Operation &op : b) + ops.push_back(&op); + + // If the current pattern runs with a type converter, convert the entry block + // signature. + if (const TypeConverter *converter = impl->currentTypeConverter) { + std::optional conversion = + converter->convertBlockSignature(&r->front()); + if (!conversion) + return failure(); + applySignatureConversion(&r->front(), *conversion, converter); + } + + // Legalize all operations in the region. + for (Operation *op : ops) + if (failed(legalize(op))) + return failure(); + + return success(); +} + void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues) { @@ -3192,22 +3244,6 @@ static void reconcileUnrealizedCasts( // OperationConverter //===----------------------------------------------------------------------===// -namespace { -enum OpConversionMode { - /// In this mode, the conversion will ignore failed conversions to allow - /// illegal operations to co-exist in the IR. - Partial, - - /// In this mode, all operations must be legal for the given target for the - /// conversion to succeed. - Full, - - /// In this mode, operations are analyzed for legality. No actual rewrites are - /// applied to the operations on success. - Analysis, -}; -} // namespace - namespace mlir { // This class converts operations to a given conversion target via a set of // rewrite patterns. The conversion behaves differently depending on the @@ -3217,16 +3253,20 @@ struct OperationConverter { const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode) - : rewriter(ctx, config), opLegalizer(rewriter, target, patterns), + : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns), mode(mode) {} /// Converts the given operations to the conversion target. LogicalResult convertOperations(ArrayRef ops); -private: - /// Converts an operation with the given rewriter. - LogicalResult convert(Operation *op); + /// Converts a single operation. If `isRecursiveLegalization` is "true", the + /// conversion is a recursive legalization request, triggered from within a + /// pattern. In that case, do not emit errors because there will be another + /// attempt at legalizing the operation later (via the regular pre-order + /// legalization mechanism). + LogicalResult convert(Operation *op, bool isRecursiveLegalization = false); +private: /// The rewriter to use when converting operations. ConversionPatternRewriter rewriter; @@ -3238,32 +3278,42 @@ struct OperationConverter { }; } // namespace mlir -LogicalResult OperationConverter::convert(Operation *op) { +LogicalResult ConversionPatternRewriter::legalize(Operation *op) { + return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true); +} + +LogicalResult OperationConverter::convert(Operation *op, + bool isRecursiveLegalization) { const ConversionConfig &config = rewriter.getConfig(); // Legalize the given operation. if (failed(opLegalizer.legalize(op))) { // Handle the case of a failed conversion for each of the different modes. // Full conversions expect all operations to be converted. - if (mode == OpConversionMode::Full) - return op->emitError() - << "failed to legalize operation '" << op->getName() << "'"; + if (mode == OpConversionMode::Full) { + if (!isRecursiveLegalization) + op->emitError() << "failed to legalize operation '" << op->getName() + << "'"; + return failure(); + } // Partial conversions allow conversions to fail iff the operation was not // explicitly marked as illegal. If the user provided a `unlegalizedOps` // set, non-legalizable ops are added to that set. if (mode == OpConversionMode::Partial) { - if (opLegalizer.isIllegal(op)) - return op->emitError() - << "failed to legalize operation '" << op->getName() - << "' that was explicitly marked illegal"; - if (config.unlegalizedOps) + if (opLegalizer.isIllegal(op)) { + if (!isRecursiveLegalization) + op->emitError() << "failed to legalize operation '" << op->getName() + << "' that was explicitly marked illegal"; + return failure(); + } + if (config.unlegalizedOps && !isRecursiveLegalization) config.unlegalizedOps->insert(op); } } else if (mode == OpConversionMode::Analysis) { // Analysis conversions don't fail if any operations fail to legalize, // they are only interested in the operations that were successfully // legalized. - if (config.legalizableOps) + if (config.legalizableOps && !isRecursiveLegalization) config.legalizableOps->insert(op); } return success(); diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir index 42cec68b9fbbb..8da9109a32762 100644 --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -72,3 +72,21 @@ builtin.module { } } + +// ----- + +// The region of "test.post_order_legalization" is converted before the op. + +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { +func.func @test_preorder_legalization() { + // expected-error@+1 {{failed to legalize operation 'test.post_order_legalization'}} + "test.post_order_legalization"() ({ + ^bb0(%arg0: i64): + // Not-explicitly-legal ops are not allowed to survive. + "test.remaining_consumer"(%arg0) : (i64) -> () + "test.invalid"(%arg0) : (i64) -> () + }) : () -> () + return +} +} diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir index 71e11782e14b0..4bcca6b7e5228 100644 --- a/mlir/test/Transforms/test-legalizer-rollback.mlir +++ b/mlir/test/Transforms/test-legalizer-rollback.mlir @@ -163,3 +163,22 @@ func.func @create_unregistered_op_in_pattern() -> i32 { "test.return"(%0) : (i32) -> () } } + +// ----- + +// CHECK-LABEL: func @test_failed_preorder_legalization +// CHECK: "test.post_order_legalization"() ({ +// CHECK: %[[r:.*]] = "test.illegal_op_g"() : () -> i32 +// CHECK: "test.return"(%[[r]]) : (i32) -> () +// CHECK: }) : () -> () +// expected-remark @+1 {{applyPartialConversion failed}} +module { +func.func @test_failed_preorder_legalization() { + // expected-error @+1 {{failed to legalize operation 'test.post_order_legalization' that was explicitly marked illegal}} + "test.post_order_legalization"() ({ + %0 = "test.illegal_op_g"() : () -> (i32) + "test.return"(%0) : (i32) -> () + }) : () -> () + return +} +} diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 7c43bb7bface0..88a71cc26ab0c 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -448,3 +448,35 @@ func.func @test_working_1to1_pattern(%arg0: f16) { "test.type_consumer"(%arg0) : (f16) -> () "test.return"() : () -> () } + +// ----- + +// The region of "test.post_order_legalization" is converted before the op. + +// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked +// CHECK: notifyOperationInserted: test.invalid +// CHECK: notifyBlockErased +// CHECK: notifyOperationInserted: test.valid, was unlinked +// CHECK: notifyOperationReplaced: test.invalid +// CHECK: notifyOperationErased: test.invalid +// CHECK: notifyOperationModified: test.post_order_legalization + +// CHECK-LABEL: func @test_preorder_legalization +// CHECK: "test.post_order_legalization"() ({ +// CHECK: ^{{.*}}(%[[arg0:.*]]: f64): +// Note: The survival of a not-explicitly-invalid operation does *not* cause +// a conversion failure in when applying a partial conversion. +// CHECK: %[[cast:.*]] = "test.cast"(%[[arg0]]) : (f64) -> i64 +// CHECK: "test.remaining_consumer"(%[[cast]]) : (i64) -> () +// CHECK: "test.valid"(%[[arg0]]) : (f64) -> () +// CHECK: }) {is_legal} : () -> () +func.func @test_preorder_legalization() { + "test.post_order_legalization"() ({ + ^bb0(%arg0: i64): + // expected-remark @+1 {{'test.remaining_consumer' is not legalizable}} + "test.remaining_consumer"(%arg0) : (i64) -> () + "test.invalid"(%arg0) : (i64) -> () + }) : () -> () + // expected-remark @+1 {{'func.return' is not legalizable}} + return +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 12edecc113495..9b64bc691588d 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1418,6 +1418,22 @@ class TestTypeConsumerOpPattern } }; +class TestPostOrderLegalization : public ConversionPattern { +public: + TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter) + : ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {} + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + for (Region &r : op->getRegions()) + if (failed(rewriter.legalize(&r))) + return failure(); + rewriter.modifyOpInPlace( + op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); }); + return success(); + } +}; + /// Test unambiguous overload resolution of replaceOpWithMultiple. This /// function is just to trigger compiler errors. It is never executed. [[maybe_unused]] void testReplaceOpWithMultipleOverloads( @@ -1532,7 +1548,8 @@ struct TestLegalizePatternDriver patterns.add(&getContext(), converter); + TestTypeConsumerOpPattern, TestPostOrderLegalization>( + &getContext(), converter); patterns.add(converter, &getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1560,6 +1577,9 @@ struct TestLegalizePatternDriver target.addDynamicallyLegalOp( OperationName("test.value_replace", &getContext()), [](Operation *op) { return op->hasAttr("is_legal"); }); + target.addDynamicallyLegalOp( + OperationName("test.post_order_legalization", &getContext()), + [](Operation *op) { return op->hasAttr("is_legal"); }); // TestCreateUnregisteredOp creates `arith.constant` operation, // which was not added to target intentionally to test