diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 78dcfe7f6fc3d..b8aeea0d23475 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -588,8 +588,7 @@ class RewriterBase : public OpBuilder { /// Unlink this operation from its current block and insert it right before /// `iterator` in the specified block. - virtual void moveOpBefore(Operation *op, Block *block, - Block::iterator iterator); + void moveOpBefore(Operation *op, Block *block, Block::iterator iterator); /// Unlink this operation from its current block and insert it right after /// `existingOp` which may be in the same or another block in the same @@ -598,8 +597,7 @@ class RewriterBase : public OpBuilder { /// Unlink this operation from its current block and insert it right after /// `iterator` in the specified block. - virtual void moveOpAfter(Operation *op, Block *block, - Block::iterator iterator); + void moveOpAfter(Operation *op, Block *block, Block::iterator iterator); /// Unlink this block and insert it right before `existingBlock`. void moveBlockBefore(Block *block, Block *anotherBlock); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 851d639ae68a7..15fa39bde104b 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -744,8 +744,8 @@ class ConversionPatternRewriter final : public PatternRewriter { /// PatternRewriter hook for updating the given operation in-place. /// Note: These methods only track updates to the given operation itself, - /// and not nested regions. Updates to regions will still require notification - /// through other more specific hooks above. + /// and not nested regions. Updates to regions will still require + /// notification through other more specific hooks above. void startOpModification(Operation *op) override; /// PatternRewriter hook for updating the given operation in-place. @@ -761,11 +761,6 @@ class ConversionPatternRewriter final : public PatternRewriter { // Hide unsupported pattern rewriter API. using OpBuilder::setListener; - void moveOpBefore(Operation *op, Block *block, - Block::iterator iterator) override; - void moveOpAfter(Operation *op, Block *block, - Block::iterator iterator) override; - std::unique_ptr impl; }; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 9875f8668b65a..84597fb7986b0 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -760,7 +760,8 @@ class IRRewrite { InlineBlock, MoveBlock, SplitBlock, - BlockTypeConversion + BlockTypeConversion, + MoveOperation }; virtual ~IRRewrite() = default; @@ -982,6 +983,54 @@ class BlockTypeConversionRewrite : public BlockRewrite { // `ArgConverter::applyRewrites`. This should be done in the "commit" method. void rollback() override; }; + +/// An operation rewrite. +class OperationRewrite : public IRRewrite { +public: + /// Return the operation that this rewrite operates on. + Operation *getOperation() const { return op; } + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() >= Kind::MoveOperation && + rewrite->getKind() <= Kind::MoveOperation; + } + +protected: + OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, + Operation *op) + : IRRewrite(kind, rewriterImpl), op(op) {} + + // The operation that this rewrite operates on. + Operation *op; +}; + +/// Moving of an operation. This rewrite is immediately reflected in the IR. +class MoveOperationRewrite : public OperationRewrite { +public: + MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Operation *op, Block *block, Operation *insertBeforeOp) + : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block), + insertBeforeOp(insertBeforeOp) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::MoveOperation; + } + + void rollback() override { + // Move the operation back to its original position. + Block::iterator before = + insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end(); + block->getOperations().splice(before, op->getBlock()->getOperations(), op); + } + +private: + // The block in which this operation was previously contained. + Block *block; + + // The original successor of this operation before it was moved. "nullptr" if + // this operation was the only operation in the region. + Operation *insertBeforeOp; +}; } // namespace //===----------------------------------------------------------------------===// @@ -1478,12 +1527,19 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( void ConversionPatternRewriterImpl::notifyOperationInserted( Operation *op, OpBuilder::InsertPoint previous) { - assert(!previous.isSet() && "expected newly created op"); LLVM_DEBUG({ logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; }); - createdOps.push_back(op); + if (!previous.isSet()) { + // This is a newly created op. + createdOps.push_back(op); + return; + } + Operation *prevOp = previous.getPoint() == previous.getBlock()->end() + ? nullptr + : &*previous.getPoint(); + appendRewrite(op, previous.getBlock(), prevOp); } void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, @@ -1722,18 +1778,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) { rootUpdates.erase(rootUpdates.begin() + updateIdx); } -void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block, - Block::iterator iterator) { - llvm_unreachable( - "moving single ops is not supported in a dialect conversion"); -} - -void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block, - Block::iterator iterator) { - llvm_unreachable( - "moving single ops is not supported in a dialect conversion"); -} - detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { return *impl; } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index d8cf6e4719ced..84fcc18ab7d37 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -320,3 +320,17 @@ module { return } } + +// ----- + +// CHECK-LABEL: func @test_move_op_before_rollback() +func.func @test_move_op_before_rollback() { + // CHECK: "test.one_region_op"() + // CHECK: "test.hoist_me"() + "test.one_region_op"() ({ + // expected-remark @below{{'test.hoist_me' is not legalizable}} + %0 = "test.hoist_me"() : () -> (i32) + "test.valid"(%0) : (i32) -> () + }) : () -> () + "test.return"() : () -> () +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index d7e5d6db50c1f..1c02232b8adbb 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -773,6 +773,22 @@ struct TestUndoBlockArgReplace : public ConversionPattern { } }; +/// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. +/// This is to test the rollback logic. +struct TestUndoMoveOpBefore : public ConversionPattern { + TestUndoMoveOpBefore(MLIRContext *ctx) + : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.moveOpBefore(op, op->getParentOp()); + // Replace with an illegal op to ensure the conversion fails. + rewriter.replaceOpWithNewOp(op, rewriter.getF32Type()); + return success(); + } +}; + /// A rewrite pattern that tests the undo mechanism when erasing a block. struct TestUndoBlockErase : public ConversionPattern { TestUndoBlockErase(MLIRContext *ctx) @@ -1069,7 +1085,7 @@ struct TestLegalizePatternDriver TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, TestNonRootReplacement, TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, - TestCreateUnregisteredOp>(&getContext()); + TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext()); patterns.add(&getContext(), converter); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1079,7 +1095,7 @@ struct TestLegalizePatternDriver ConversionTarget target(getContext()); target.addLegalOp(); target.addLegalOp(); + TerminatorOp, OneRegionOp>(); target .addIllegalOp(); target.addDynamicallyLegalOp([](TestReturnOp op) {