-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IR] Add rewriter API for moving operations #78988
[mlir][IR] Add rewriter API for moving operations #78988
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir-tensor Author: Matthias Springer (matthias-springer) ChangesThe pattern rewriter documentation states that "all IR mutations [...] are required to be performed via the After an operation was moved, the This commit narrows the discrepancy between the kind of IR modification that can be performed and the kind of IR modifications that can be listened to. Full diff: https://github.com/llvm/llvm-project/pull/78988.diff 13 Files Affected:
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53b..948c8a045e341e4 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
// Make sure to allocate at the beginning of the block.
auto *parentBlock = alloc->getBlock();
- alloc->moveBefore(&parentBlock->front());
+ rewriter.moveOpBefore(alloc, &parentBlock->front());
// Make sure to deallocate this alloc at the end of the block. This is fine
// as toy functions have no control flow.
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
- dealloc->moveBefore(&parentBlock->back());
+ rewriter.moveOpBefore(dealloc, &parentBlock->back());
return alloc;
}
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53b..948c8a045e341e4 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
// Make sure to allocate at the beginning of the block.
auto *parentBlock = alloc->getBlock();
- alloc->moveBefore(&parentBlock->front());
+ rewriter.moveOpBefore(alloc, &parentBlock->front());
// Make sure to deallocate this alloc at the end of the block. This is fine
// as toy functions have no control flow.
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
- dealloc->moveBefore(&parentBlock->back());
+ rewriter.moveOpBefore(dealloc, &parentBlock->back());
return alloc;
}
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53b..948c8a045e341e4 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
// Make sure to allocate at the beginning of the block.
auto *parentBlock = alloc->getBlock();
- alloc->moveBefore(&parentBlock->front());
+ rewriter.moveOpBefore(alloc, &parentBlock->front());
// Make sure to deallocate this alloc at the end of the block. This is fine
// as toy functions have no control flow.
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
- dealloc->moveBefore(&parentBlock->back());
+ rewriter.moveOpBefore(dealloc, &parentBlock->back());
return alloc;
}
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 13fbc3fb928c399..7b9e40e245c713a 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -285,12 +285,18 @@ class OpBuilder : public Builder {
virtual ~Listener() = default;
- /// Notification handler for when an operation is inserted into the builder.
- /// `op` is the operation that was inserted.
+ /// Notify the listener that the specified operation was inserted.
+ ///
+ /// Note: Creating an (unlinked) op does not trigger this notification.
+ /// Only when the op is inserted, this notification is triggered. This
+ /// notification is also triggered when moving an operation to a different
+ /// location.
+ // TODO: If needed, the previous location of the operation could be passed
+ // as a parameter. This would also allow listeners to distinguish between
+ // "newly created op was inserted" and "existing op was moved".
virtual void notifyOperationInserted(Operation *op) {}
- /// Notification handler for when a block is created using the builder.
- /// `block` is the block that was created.
+ /// Notify the listener that the specified block was inserted.
virtual void notifyBlockCreated(Block *block) {}
protected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 815340c91850935..db95f7243e178c1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -428,6 +428,8 @@ class RewriterBase : public OpBuilder {
/// Notify the listener that the specified operation is about to be erased.
/// At this point, the operation has zero uses.
+ ///
+ /// Note: This notification is not triggered when unlinking an operation.
virtual void notifyOperationRemoved(Operation *op) {}
/// Notify the listener that the pattern failed to match the given
@@ -591,6 +593,26 @@ class RewriterBase : public OpBuilder {
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before);
+ /// Unlink this operation from its current block and insert it right before
+ /// `existingOp` which may be in the same or another block in the same
+ /// function.
+ void moveOpBefore(Operation *op, Operation *existingOp);
+
+ /// Unlink this operation from its current block and insert it right before
+ /// `iterator` in the specified block.
+ virtual void moveOpBefore(Operation *op, Block *block,
+ Block::iterator iterator);
+
+ /// Unlink this operation from its current block and insert it right after
+ /// `existingOp` which may be in the same or another block in the same
+ /// function.
+ void moveOpAfter(Operation *op, Operation *existingOp);
+
+ /// Unlink this operation from its current block and insert it right after
+ /// `iterator` in the specified block.
+ virtual void moveOpAfter(Operation *op, Block *block,
+ Block::iterator iterator);
+
/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeOpModification` or
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9568540789df3f6..7dc07e5b05e61a1 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -761,9 +761,15 @@ class ConversionPatternRewriter final : public PatternRewriter,
detail::ConversionPatternRewriterImpl &getImpl();
private:
+ // Hide unsupported pattern rewriter API.
using OpBuilder::getListener;
using OpBuilder::setListener;
+ void moveOpBefore(Operation *op, Block *block,
+ Block::iterator iterator) override;
+ void moveOpAfter(Operation *op, Block *block,
+ Block::iterator iterator) override;
+
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index cda561b1d1054d9..9f8189ae15e6de2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -83,7 +83,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
// Inline for-loop body operations into 'after' region.
for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
- arg.moveBefore(afterBlock, afterBlock->end());
+ rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
// Add incremented IV to yield operations
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 2cd57e7324b4dc5..678b7c099fa3692 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -983,7 +983,7 @@ struct ParallelInsertSliceOpInterface
for (Operation *user : srcBuffer->getUsers()) {
if (hasEffect<MemoryEffects::Free>(user)) {
if (user->getBlock() == parallelCombiningParent->getBlock())
- user->moveBefore(user->getBlock()->getTerminator());
+ rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
break;
}
}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index ba0516e0539b6ca..2acc1629ddac0a6 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -366,3 +366,25 @@ void RewriterBase::cloneRegionBefore(Region ®ion, Region &parent,
void RewriterBase::cloneRegionBefore(Region ®ion, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
+
+void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
+ moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
+}
+
+void RewriterBase::moveOpBefore(Operation *op, Block *block,
+ Block::iterator iterator) {
+ op->moveBefore(block, iterator);
+ if (listener)
+ listener->notifyOperationInserted(op);
+}
+
+void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
+ moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
+}
+
+void RewriterBase::moveOpAfter(Operation *op, Block *block,
+ Block::iterator iterator) {
+ op->moveAfter(block, iterator);
+ if (listener)
+ listener->notifyOperationInserted(op);
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ef6a49455d18605..0187436f700fd73 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1651,6 +1651,18 @@ LogicalResult ConversionPatternRewriter::notifyMatchFailure(
return impl->notifyMatchFailure(loc, reasonCallback);
}
+void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
+ Block::iterator iterator) {
+ llvm_unreachable(
+ "moving single ops is not supported in a dialect conversion");
+}
+
+void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
+ Block::iterator iterator) {
+ llvm_unreachable(
+ "moving single ops is not supported in a dialect conversion");
+}
+
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
}
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 8f97fd3d9ddf84e..66ce6067963f838 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -365,8 +365,8 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
OpResult newLoopResult = loopLike.getLoopResults()->back();
- extractionOp->moveBefore(loopLike);
- insertionOp->moveAfter(loopLike);
+ rewriter.moveOpBefore(extractionOp, loopLike);
+ rewriter.moveOpAfter(insertionOp, loopLike);
rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
insertionOp.getDestinationOperand().get());
extractionOp.getSourceOperand().set(
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index a8a808424b690f6..8a92d840ad13026 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -159,8 +159,8 @@ struct TestSCFPipeliningPass
auto ifOp =
rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
// True branch.
- op->moveBefore(&ifOp.getThenRegion().front(),
- ifOp.getThenRegion().front().begin());
+ rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
+ ifOp.getThenRegion().front().begin());
rewriter.setInsertionPointAfter(op);
if (op->getNumResults() > 0)
rewriter.create<scf::YieldOp>(loc, op->getResults());
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d1ac5e81e75a695..89b9d1ce78a52b6 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -193,9 +193,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
return failure();
if (!toBeHoisted->hasAttr("eligible"))
return failure();
- // Hoisting means removing an op from the enclosing op. I.e., the enclosing
- // op is modified.
- rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
+ rewriter.moveOpBefore(toBeHoisted, op);
return success();
}
};
|
@@ -193,9 +193,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { | |||
return failure(); | |||
if (!toBeHoisted->hasAttr("eligible")) | |||
return failure(); | |||
// Hoisting means removing an op from the enclosing op. I.e., the enclosing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was a workaround around the fact that there were no notifications for moveBefore
. This used to trigger a failed "expensive pattern check" (IR changed but rewriter was not notified).
8a2015b
to
5eae529
Compare
/// is empty. | ||
/// | ||
/// Note: Creating an (unlinked) op does not trigger this notification. | ||
virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: This is similar to what was proposed here (IR listeners), but the callback is fired after the IR was modified.
// This method is called when an operation is inserted into a block. The oldBlock is nullptr is the operation wasn't previously in a block.
virtual void notifyOpInserted(Operation *op, Block *oldBlock,
Block *newBlock) {}
5eae529
to
c53cdd7
Compare
op->moveBefore(block, iterator); | ||
if (listener) | ||
listener->notifyOperationInserted( | ||
op, /*previous=*/InsertPoint(currentBlock, currentIterator)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not clear to me that this is enough to implement a "moveOpBefore" safely.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you referring to the fact that "moving ops" is not supported by the dialect conversion? The notifyOperationInserted
callback is triggered after moving the op. The current location can be queried from the op itself. The previous location is passed as a parameter. That should be enough information to implement the rollback mechanism in the future.
"Moving an op" is a form of "inserting an op". Until now, we used notifyOperationInserted
only for newly created ops. But in both cases we are inserting an op, only the previous location is different (moving an op: has a previous location, inserting a newly created op: was previously unlinked). (The callback is called notifyOperationInserted
, not notifyOperationCreated
.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not clear to me what operations should be considered "changed" when we move an op.
For example sinking:
A;
for (...) {
if (...) {
...
}
}
to
for (...) {
if (...) {
A
...
}
}
Is the "if" modified? The "for"?
Is notifyOperationInserted
really meant to handle arbitrary moves and the handler must handle all this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We actually have the same issue with other kinds of IR changes:
- When erasing an operation, should we fire
notifyOperationModified
for the parent op? And for the parent's parent op? Etc. - When inserting a newly created operation, should we fire
notifyOperationModified
for the parent op into which we are inserting? And for the parent's parent op? Etc. - Same for erasing a block.
My conclusion was that we should not trigger a notifyOperationModified
for things that have a separate notification. Otherwise, to be consistent, we would have to call notifyOperationModified
for pretty much every IR change. That would potentially make it more difficult for listeners, because they would get duplicate notifications about the same IR change. (E.g., when inserting an op there would be two notifications.)
I think we should trigger notifyOperationModified
only for "attribute changed", "property changed", "result type changed", "operand changed". Maybe also for "region entry block argument changed". And have separate callbacks for everything else.
Another thing that we could consider is giving notifyOperationInsertion
, notificationOperationRemoved
, notifyBlockCreated
, etc. a default implementation that calls notifyOperationModified
. Listeners could then decide what kind of granularity of notifications they would like to receive. (We already do something similar for notifyOperationReplaced(Operation *, Operation *)
.)
Is
notifyOperationInserted
really meant to handle arbitrary moves and the handler must handle all this?
I'd say that notifyOperationInserted
should be called for all op insertions. Whether the inserted op is an already existing op or a newly created op is irrelevant. (Note the function is called notifyOperationInserted
not notifyOperationCreated
.)
But it raises the question whether notifyOperationRemoved
should also be triggered. In my current implementation it is not, because the documentation of the callback says that it is triggered for "op erasure", not "op removal from a block". (I think the callback should be renamed to notifyOperationErased
.)
/// Notify the listener that the specified operation is about to be erased.
/// At this point, the operation has zero uses.
virtual void notifyOperationRemoved(Operation *op) {}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK!
Please check the pre-merge! |
c53cdd7
to
6161d33
Compare
I did not have any immediate plans to look at the dialect conversion. It looks like a non-trivial change to implement the rollback logic etc. and I'm not that familiar with the dialect conversion overall. Maybe as a next side project when I'm happy with the greedy rewriter driver, listeners, etc... |
6161d33
to
553ea29
Compare
The pattern rewriter documentation states that "*all* IR mutations [...] are required to be performed via the `PatternRewriter`." This commit adds two functions that were missing from the rewriter API: `moveOpBefore` and `moveOpAfter`. After an operation was moved, the `notifyOperationInserted` callback is triggered. This may cause listeners such as the greedy pattern rewrite driver to put the op back on the worklist.
553ea29
to
f2bf7ed
Compare
This change makes the callback consistent with `notifyOperationInserted`: both now notify about IR insertion, not IR creation. See also llvm#78988. This change also simplifies the dialect conversion: it is no longer necessary to override the `inlineRegionBefore` method. All information that is necessary for rollback is provided with the `notifyBlockInserted` callback.
This change makes the callback consistent with `notifyOperationInserted`: both now notify about IR insertion, not IR creation. See also llvm#78988. This change also simplifies the dialect conversion: it is no longer necessary to override the `inlineRegionBefore` method. All information that is necessary for rollback is provided with the `notifyBlockInserted` callback.
This change makes the callback consistent with `notifyOperationInserted`: both now notify about IR insertion, not IR creation. See also #78988. This change also simplifies the dialect conversion: it is no longer necessary to override the `inlineRegionBefore` method. All information that is necessary for rollback is provided with the `notifyBlockInserted` callback.
The pattern rewriter documentation states that "all IR mutations [...] are required to be performed via the
PatternRewriter
." This commit adds two functions that were missing from the rewriter API:moveOpBefore
andmoveOpAfter
.After an operation was moved, the
notifyOperationInserted
callback is triggered. This allows listeners such as the greedy pattern rewrite driver to react to IR changes.This commit narrows the discrepancy between the kind of IR modification that can be performed and the kind of IR modifications that can be listened to.