diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index c3dca148b7f94..4bb4ad61cf8da 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -578,8 +578,10 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(), replacementOp.getRegion().begin()); // 5a. Replace `linalg.index` operations that refer to the dropped unit - // dimensions. - IRRewriter rewriter(b); + // dimensions. Use a fresh IRRewriter to avoid inheriting any listener + // from the builder (e.g., WalkPatternRewriter's erasure listener), + // since the ops being erased here are newly cloned, not the matched op. + IRRewriter rewriter(b.getContext()); replaceUnitDimIndexOps(replacementOp, droppedDims, rewriter); return replacementOp; diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp index 1382550e0f7e6..40fcb351ee079 100644 --- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/Verifier.h" #include "mlir/IR/Visitors.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" @@ -61,16 +62,38 @@ struct WalkAndApplyPatternsAction final // ops/blocks. Because we use walk-based pattern application, erasing the // op/block from the *next* iteration (e.g., a user of the visited op) is not // valid. Note that this is only used with expensive pattern API checks. +// +// Ops and blocks that were *created* during the current pattern application are +// exempt: they were not in the walk schedule before the pattern ran, so erasing +// them cannot disrupt the walk. struct ErasedOpsListener final : RewriterBase::ForwardingListener { using RewriterBase::ForwardingListener::ForwardingListener; + void notifyOperationInserted(Operation *op, + OpBuilder::InsertPoint previous) override { + if (visitedOp) + newlyCreatedOps.insert(op); + ForwardingListener::notifyOperationInserted(op, previous); + } + + void notifyBlockInserted(Block *block, Region *previous, + Region::iterator previousIt) override { + if (visitedOp) + newlyCreatedBlocks.insert(block); + ForwardingListener::notifyBlockInserted(block, previous, previousIt); + } + void notifyOperationErased(Operation *op) override { - checkErasure(op); + if (!newlyCreatedOps.contains(op)) + checkErasure(op); + newlyCreatedOps.erase(op); ForwardingListener::notifyOperationErased(op); } void notifyBlockErased(Block *block) override { - checkErasure(block->getParentOp()); + if (!newlyCreatedBlocks.contains(block)) + checkErasure(block->getParentOp()); + newlyCreatedBlocks.erase(block); ForwardingListener::notifyBlockErased(block); } @@ -86,6 +109,9 @@ struct ErasedOpsListener final : RewriterBase::ForwardingListener { } Operation *visitedOp = nullptr; + // Ops and blocks inserted since visitedOp was last set; may be freely erased. + DenseSet newlyCreatedOps; + DenseSet newlyCreatedBlocks; }; #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS } // namespace @@ -204,6 +230,8 @@ void walkAndApplyPatterns(Operation *op, << OpWithFlags(op, OpPrintingFlags().skipRegions()); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS erasedListener.visitedOp = op; + erasedListener.newlyCreatedOps.clear(); + erasedListener.newlyCreatedBlocks.clear(); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS if (succeeded(applicator.matchAndRewrite(op, rewriter))) LDBG() << "\tOp matched and rewritten"; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 6c44ace831e96..7e8ed94630b61 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -383,7 +383,8 @@ struct CloneRegionBeforeOp : public RewritePattern { return failure(); for (Region &r : op->getRegions()) rewriter.cloneRegionBefore(r, op->getBlock()); - op->setAttr("was_cloned", rewriter.getUnitAttr()); + rewriter.modifyOpInPlace( + op, [&]() { op->setAttr("was_cloned", rewriter.getUnitAttr()); }); return success(); } };