diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h index b5b2c99810b15..30f20ccfa1543 100644 --- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h +++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h @@ -490,7 +490,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener { LLVM_DUMP_METHOD void dumpFunc(); /// FirOpBuilder hook for creating new operation. - void notifyOperationInserted(mlir::Operation *op) override { + void notifyOperationInserted(mlir::Operation *op, + mlir::OpBuilder::InsertPoint previous) override { + // We only care about newly created operations. + if (previous.isSet()) + return; setCommonAttributes(op); } diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp index 641854bd201f0..5fe78b7408026 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -730,9 +730,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener { HLFIRListener(fir::FirOpBuilder &builder, mlir::ConversionPatternRewriter &rewriter) : builder{builder}, rewriter{rewriter} {} - void notifyOperationInserted(mlir::Operation *op) override { - builder.notifyOperationInserted(op); - rewriter.notifyOperationInserted(op); + void notifyOperationInserted(mlir::Operation *op, + mlir::OpBuilder::InsertPoint previous) override { + builder.notifyOperationInserted(op, previous); + rewriter.notifyOperationInserted(op, previous); } virtual void notifyBlockCreated(mlir::Block *block) override { builder.notifyBlockCreated(block); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 13fbc3fb928c3..6b95be7c6d372 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -205,6 +205,7 @@ class Builder { /// automatically inserted at an insertion point. The builder is copyable. class OpBuilder : public Builder { public: + class InsertPoint; struct Listener; /// Create a builder with the given context. @@ -285,12 +286,17 @@ class OpBuilder : public Builder { virtual ~Listener() = default; - /// Notification handler for when an operation is inserted into the builder. - /// `op` is the operation that was inserted. - virtual void notifyOperationInserted(Operation *op) {} - - /// Notification handler for when a block is created using the builder. - /// `block` is the block that was created. + /// Notify the listener that the specified operation was inserted. + /// + /// * If the operation was moved, then `previous` is the previous location + /// of the op. + /// * If the operation was unlinked before it was inserted, then `previous` + /// is empty. + /// + /// Note: Creating an (unlinked) op does not trigger this notification. + virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {} + + /// Notify the listener that the specified block was inserted. virtual void notifyBlockCreated(Block *block) {} protected: @@ -517,7 +523,7 @@ class OpBuilder : public Builder { if (succeeded(tryFold(op, results))) op->erase(); else if (listener) - listener->notifyOperationInserted(op); + listener->notifyOperationInserted(op, /*previous=*/{}); } /// Overload to create or fold a single result operation. diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 815340c918509..7f233cd3f4d4b 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -428,6 +428,8 @@ class RewriterBase : public OpBuilder { /// Notify the listener that the specified operation is about to be erased. /// At this point, the operation has zero uses. + /// + /// Note: This notification is not triggered when unlinking an operation. virtual void notifyOperationRemoved(Operation *op) {} /// Notify the listener that the pattern failed to match the given @@ -450,8 +452,8 @@ class RewriterBase : public OpBuilder { struct ForwardingListener : public RewriterBase::Listener { ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {} - void notifyOperationInserted(Operation *op) override { - listener->notifyOperationInserted(op); + void notifyOperationInserted(Operation *op, InsertPoint previous) override { + listener->notifyOperationInserted(op, previous); } void notifyBlockCreated(Block *block) override { listener->notifyBlockCreated(block); @@ -591,6 +593,26 @@ class RewriterBase : public OpBuilder { /// block into a new block, and return it. virtual Block *splitBlock(Block *block, Block::iterator before); + /// Unlink this operation from its current block and insert it right before + /// `existingOp` which may be in the same or another block in the same + /// function. + void moveOpBefore(Operation *op, Operation *existingOp); + + /// Unlink this operation from its current block and insert it right before + /// `iterator` in the specified block. + virtual void moveOpBefore(Operation *op, Block *block, + Block::iterator iterator); + + /// Unlink this operation from its current block and insert it right after + /// `existingOp` which may be in the same or another block in the same + /// function. + void moveOpAfter(Operation *op, Operation *existingOp); + + /// Unlink this operation from its current block and insert it right after + /// `iterator` in the specified block. + virtual void moveOpAfter(Operation *op, Block *block, + Block::iterator iterator); + /// This method is used to notify the rewriter that an in-place operation /// modification is about to happen. A call to this function *must* be /// followed by a call to either `finalizeOpModification` or diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 9568540789df3..32c5937d014e9 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -737,7 +737,7 @@ class ConversionPatternRewriter final : public PatternRewriter, using PatternRewriter::cloneRegionBefore; /// PatternRewriter hook for inserting a new operation. - void notifyOperationInserted(Operation *op) override; + void notifyOperationInserted(Operation *op, InsertPoint previous) override; /// PatternRewriter hook for updating the given operation in-place. /// Note: These methods only track updates to the given operation itself, @@ -761,9 +761,15 @@ class ConversionPatternRewriter final : public PatternRewriter, detail::ConversionPatternRewriterImpl &getImpl(); private: + // Hide unsupported pattern rewriter API. using OpBuilder::getListener; using OpBuilder::setListener; + void moveOpBefore(Operation *op, Block *block, + Block::iterator iterator) override; + void moveOpAfter(Operation *op, Block *block, + Block::iterator iterator) override; + std::unique_ptr impl; }; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index c260e68d509e9..b802ae33edacc 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1206,7 +1206,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc, if (failed(applyOp->fold(constOperands, foldResults)) || foldResults.empty()) { if (OpBuilder::Listener *listener = b.getListener()) - listener->notifyOperationInserted(applyOp); + listener->notifyOperationInserted(applyOp, /*previous=*/{}); return applyOp.getResult(); } @@ -1274,7 +1274,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, if (failed(minMaxOp->fold(constOperands, foldResults)) || foldResults.empty()) { if (OpBuilder::Listener *listener = b.getListener()) - listener->notifyOperationInserted(minMaxOp); + listener->notifyOperationInserted(minMaxOp, /*previous=*/{}); return minMaxOp.getResult(); } diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index 428a3c945581b..8c3e25355f608 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -273,7 +273,7 @@ static ParallelComputeFunction createParallelComputeFunction( // Insert function into the module symbol table and assign it unique name. SymbolTable symbolTable(module); symbolTable.insert(func); - rewriter.getListener()->notifyOperationInserted(func); + rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{}); // Create function entry block. Block *block = @@ -489,7 +489,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, // Insert function into the module symbol table and assign it unique name. SymbolTable symbolTable(module); symbolTable.insert(func); - rewriter.getListener()->notifyOperationInserted(func); + rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{}); // Create function entry block. Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(), diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 3f1626a6af34d..2758d554712b9 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -371,7 +371,11 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener { toMemrefOps.erase(op); } - void notifyOperationInserted(Operation *op) override { + void notifyOperationInserted(Operation *op, InsertPoint previous) override { + // We only care about newly created ops. + if (previous.isSet()) + return; + erasedOps.erase(op); // Gather statistics about allocs. diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 140bdd1f2db36..803c5691a0403 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -214,8 +214,12 @@ class NewOpsListener : public RewriterBase::ForwardingListener { } private: - void notifyOperationInserted(Operation *op) override { - ForwardingListener::notifyOperationInserted(op); + void notifyOperationInserted(Operation *op, + OpBuilder::InsertPoint previous) override { + ForwardingListener::notifyOperationInserted(op, previous); + // We only care about newly created ops. + if (previous.isSet()) + return; auto inserted = newOps.insert(op); (void)inserted; assert(inserted.second && "expected newly created op"); diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index cda561b1d1054..9f8189ae15e6d 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -83,7 +83,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern { // Inline for-loop body operations into 'after' region. for (auto &arg : llvm::make_early_inc_range(*forOp.getBody())) - arg.moveBefore(afterBlock, afterBlock->end()); + rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end()); // Add incremented IV to yield operations for (auto yieldOp : afterBlock->getOps()) { diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 2cd57e7324b4d..678b7c099fa36 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -983,7 +983,7 @@ struct ParallelInsertSliceOpInterface for (Operation *user : srcBuffer->getUsers()) { if (hasEffect(user)) { if (user->getBlock() == parallelCombiningParent->getBlock()) - user->moveBefore(user->getBlock()->getTerminator()); + rewriter.moveOpBefore(user, user->getBlock()->getTerminator()); break; } } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index d156504765877..a319afcdc6a9a 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -412,7 +412,7 @@ Operation *OpBuilder::insert(Operation *op) { block->getOperations().insert(insertPoint, op); if (listener) - listener->notifyOperationInserted(op); + listener->notifyOperationInserted(op, /*previous=*/{}); return op; } @@ -530,7 +530,7 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) { // about any ops that got inserted inside those regions as part of cloning. if (listener) { auto walkFn = [&](Operation *walkedOp) { - listener->notifyOperationInserted(walkedOp); + listener->notifyOperationInserted(walkedOp, /*previous=*/{}); }; for (Region ®ion : newOp->getRegions()) region.walk(walkFn); diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index ba0516e0539b6..affb8898fa075 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -366,3 +366,31 @@ void RewriterBase::cloneRegionBefore(Region ®ion, Region &parent, void RewriterBase::cloneRegionBefore(Region ®ion, Block *before) { cloneRegionBefore(region, *before->getParent(), before->getIterator()); } + +void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) { + moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator()); +} + +void RewriterBase::moveOpBefore(Operation *op, Block *block, + Block::iterator iterator) { + Block *currentBlock = op->getBlock(); + Block::iterator currentIterator = op->getIterator(); + op->moveBefore(block, iterator); + if (listener) + listener->notifyOperationInserted( + op, /*previous=*/InsertPoint(currentBlock, currentIterator)); +} + +void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) { + moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator()); +} + +void RewriterBase::moveOpAfter(Operation *op, Block *block, + Block::iterator iterator) { + Block *currentBlock = op->getBlock(); + Block::iterator currentIterator = op->getIterator(); + op->moveAfter(block, iterator); + if (listener) + listener->notifyOperationInserted( + op, /*previous=*/InsertPoint(currentBlock, currentIterator)); +} diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index ef6a49455d186..f5bede2b94f9c 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1602,11 +1602,13 @@ void ConversionPatternRewriter::cloneRegionBefore(Region ®ion, Block *cloned = mapping.lookup(&b); impl->notifyCreatedBlock(cloned); cloned->walk>( - [&](Operation *op) { notifyOperationInserted(op); }); + [&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); }); } } -void ConversionPatternRewriter::notifyOperationInserted(Operation *op) { +void ConversionPatternRewriter::notifyOperationInserted(Operation *op, + InsertPoint previous) { + assert(!previous.isSet() && "expected newly created op"); LLVM_DEBUG({ impl->logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; @@ -1651,6 +1653,18 @@ LogicalResult ConversionPatternRewriter::notifyMatchFailure( return impl->notifyMatchFailure(loc, reasonCallback); } +void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block, + Block::iterator iterator) { + llvm_unreachable( + "moving single ops is not supported in a dialect conversion"); +} + +void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block, + Block::iterator iterator) { + llvm_unreachable( + "moving single ops is not supported in a dialect conversion"); +} + detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { return *impl; } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index ac73e82bfe92a..c27fee7a738eb 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -133,9 +133,16 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener { } } - void notifyOperationInserted(Operation *op) override { - RewriterBase::ForwardingListener::notifyOperationInserted(op); + void notifyOperationInserted(Operation *op, InsertPoint previous) override { + RewriterBase::ForwardingListener::notifyOperationInserted(op, previous); + // Invalidate the finger print of the op that owns the block into which the + // op was inserted into. invalidateFingerPrint(op->getParentOp()); + + // Also invalidate the finger print of the op that owns the block from which + // the op was moved from. (Only applicable if the op was moved.) + if (previous.isSet()) + invalidateFingerPrint(previous.getBlock()->getParentOp()); } void notifyOperationModified(Operation *op) override { @@ -331,7 +338,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter, /// Notify the driver that the specified operation was inserted. Update the /// worklist as needed: The operation is enqueued depending on scope and /// strict mode. - void notifyOperationInserted(Operation *op) override; + void notifyOperationInserted(Operation *op, InsertPoint previous) override; /// Notify the driver that the specified operation was removed. Update the /// worklist as needed: The operation and its children are removed from the @@ -641,13 +648,14 @@ void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) { config.listener->notifyBlockRemoved(block); } -void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) { +void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op, + InsertPoint previous) { LLVM_DEBUG({ logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; }); if (config.listener) - config.listener->notifyOperationInserted(op); + config.listener->notifyOperationInserted(op, previous); if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps) strictModeFilteredOps.insert(op); addToWorklist(op); diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp index 8f97fd3d9ddf8..66ce6067963f8 100644 --- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp @@ -365,8 +365,8 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, iterArg = loopLike.getRegionIterArgs()[iterArgIdx]; OpResult loopResult = loopLike.getTiedLoopResult(iterArg); OpResult newLoopResult = loopLike.getLoopResults()->back(); - extractionOp->moveBefore(loopLike); - insertionOp->moveAfter(loopLike); + rewriter.moveOpBefore(extractionOp, loopLike); + rewriter.moveOpAfter(insertionOp, loopLike); rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(), insertionOp.getDestinationOperand().get()); extractionOp.getSourceOperand().set( diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp index a8a808424b690..8a92d840ad130 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -159,8 +159,8 @@ struct TestSCFPipeliningPass auto ifOp = rewriter.create(loc, op->getResultTypes(), pred, true); // True branch. - op->moveBefore(&ifOp.getThenRegion().front(), - ifOp.getThenRegion().front().begin()); + rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(), + ifOp.getThenRegion().front().begin()); rewriter.setInsertionPointAfter(op); if (op->getNumResults() > 0) rewriter.create(loc, op->getResults()); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index d1ac5e81e75a6..89b9d1ce78a52 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -193,9 +193,7 @@ struct HoistEligibleOps : public OpRewritePattern { return failure(); if (!toBeHoisted->hasAttr("eligible")) return failure(); - // Hoisting means removing an op from the enclosing op. I.e., the enclosing - // op is modified. - rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); }); + rewriter.moveOpBefore(toBeHoisted, op); return success(); } }; diff --git a/mlir/test/lib/IR/TestClone.cpp b/mlir/test/lib/IR/TestClone.cpp index 13a0cfeb402a9..7b18f219b915f 100644 --- a/mlir/test/lib/IR/TestClone.cpp +++ b/mlir/test/lib/IR/TestClone.cpp @@ -15,7 +15,8 @@ using namespace mlir; namespace { struct DumpNotifications : public OpBuilder::Listener { - void notifyOperationInserted(Operation *op) override { + void notifyOperationInserted(Operation *op, + OpBuilder::InsertPoint previous) override { llvm::outs() << "notifyOperationInserted: " << op->getName() << "\n"; } }; diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp index aa67e0a78d43b..81f634d8d3ef7 100644 --- a/mlir/test/lib/Transforms/TestConstantFold.cpp +++ b/mlir/test/lib/Transforms/TestConstantFold.cpp @@ -27,7 +27,8 @@ struct TestConstantFold : public PassWrapper>, void foldOperation(Operation *op, OperationFolder &helper); void runOnOperation() override; - void notifyOperationInserted(Operation *op) override { + void notifyOperationInserted(Operation *op, + OpBuilder::InsertPoint previous) override { existingConstants.push_back(op); } void notifyOperationRemoved(Operation *op) override {