-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IR] Trigger notifyOperationReplaced
on replaceAllOpUsesWith
#84721
[mlir][IR] Trigger notifyOperationReplaced
on replaceAllOpUsesWith
#84721
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesBefore this change: Until now, every Note: It is still possible to write semantically equivalent code that does not trigger a Full diff: https://github.com/llvm/llvm-project/pull/84721.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8d84ab6100007e..c1408c3f90a53b 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -409,9 +409,9 @@ class RewriterBase : public OpBuilder {
/// Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationModified(Operation *op) {}
- /// Notify the listener that the specified operation is about to be replaced
- /// with another operation. This is called before the uses of the old
- /// operation have been changed.
+ /// Notify the listener that all uses of the specified operation's results
+ /// are about to be replaced with the results of another operation. This is
+ /// called before the uses of the old operation have been changed.
///
/// By default, this function calls the "operation replaced with values"
/// notification.
@@ -420,9 +420,10 @@ class RewriterBase : public OpBuilder {
notifyOperationReplaced(op, replacement->getResults());
}
- /// Notify the listener that the specified operation is about to be replaced
- /// with the a range of values, potentially produced by other operations.
- /// This is called before the uses of the operation have been changed.
+ /// Notify the listener that all uses of the specified operation's results
+ /// are about to be replaced with the a range of values, potentially
+ /// produced by other operations. This is called before the uses of the
+ /// operation have been changed.
virtual void notifyOperationReplaced(Operation *op,
ValueRange replacement) {}
@@ -628,12 +629,16 @@ class RewriterBase : public OpBuilder {
for (auto it : llvm::zip(from, to))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
}
- // Note: This function cannot be called `replaceAllUsesWith` because the
- // overload resolution, when called with an op that can be implicitly
- // converted to a Value, would be ambiguous.
- void replaceAllOpUsesWith(Operation *from, ValueRange to) {
- replaceAllUsesWith(from->getResults(), to);
- }
+
+ /// Find uses of `from` and replace them with `to`. Also notify the listener
+ /// about every in-place op modification (for every use that was replaced)
+ /// and that the `from` operation is about to be replaced.
+ ///
+ /// Note: This function cannot be called `replaceAllUsesWith` because the
+ /// overload resolution, when called with an op that can be implicitly
+ /// converted to a Value, would be ambiguous.
+ void replaceAllOpUsesWith(Operation *from, ValueRange to);
+ void replaceAllOpUsesWith(Operation *from, Operation *to);
/// Find uses of `from` and replace them with `to` if the `functor` returns
/// true. Also notify the listener about every in-place op modification (for
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 4079ccc7567256..5944a0ea46a143 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -110,6 +110,22 @@ RewriterBase::~RewriterBase() {
// Out of line to provide a vtable anchor for the class.
}
+void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
+ // Notify the listener that we're about to replace this op.
+ if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+ rewriteListener->notifyOperationReplaced(from, to);
+
+ replaceAllUsesWith(from->getResults(), to);
+}
+
+void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
+ // Notify the listener that we're about to replace this op.
+ if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+ rewriteListener->notifyOperationReplaced(from, to);
+
+ replaceAllUsesWith(from->getResults(), to->getResults());
+}
+
/// This method replaces the results of the operation with the specified list of
/// values. The number of provided values must match the number of results of
/// the operation. The replaced op is erased.
@@ -117,10 +133,6 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");
- // Notify the listener that we're about to replace this op.
- if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
- rewriteListener->notifyOperationReplaced(op, newValues);
-
// Replace all result uses. Also notifies the listener of modifications.
replaceAllOpUsesWith(op, newValues);
@@ -136,10 +148,6 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
assert(op->getNumResults() == newOp->getNumResults() &&
"ops have different number of results");
- // Notify the listener that we're about to replace this op.
- if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
- rewriteListener->notifyOperationReplaced(op, newOp);
-
// Replace all result uses. Also notifies the listener of modifications.
replaceAllOpUsesWith(op, newOp->getResults());
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 2da184bc3d85ba..76dc825fe44515 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -489,7 +489,10 @@ struct TestStrictPatternDriver
OperationName("test.new_op", op->getContext()).getIdentifier(),
op->getOperands(), op->getResultTypes());
}
- rewriter.replaceOp(op, newOp->getResults());
+ // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
+ // A "notifyOperationReplaced" callback is triggered in either case.
+ rewriter.replaceAllOpUsesWith(op, newOp->getResults());
+ rewriter.eraseOp(op);
return success();
}
};
|
Before this change: `notifyOperationReplaced` was triggered when calling `RewriteBase::replaceOp`. After this change: `notifyOperationReplaced` is triggered when `RewriterBase::replaceAllOpUsesWith` or `RewriterBase::replaceOp` is called. Until now, every `notifyOperationReplaced` was always sent together with a `notifyOperationErased`, which made that `notifyOperationErased` callback irrelevant. More importantly, when a user called `RewriterBase::replaceAllOpUsesWith`+`RewriterBase::eraseOp` instead of `RewriterBase::replaceOp`, no `notifyOperationReplaced` callback was sent, even though the two notations are semantically equivalent. As an example, this can be a problem when applying patterns with the transform dialect because the `TrackingListener` will only see the `notifyOperationErased` callback and the payload op is dropped from the mappings. Note: It is still possible to write semantically equivalent code that does not trigger a `notifyOperationReplaced` (e.g., when op results are replaced one-by-one), but this commit already improves the situation a lot.
7b93ec3
to
9d461b5
Compare
Before this change:
notifyOperationReplaced
was triggered when callingRewriteBase::replaceOp
.After this change:
notifyOperationReplaced
is triggered whenRewriterBase::replaceAllOpUsesWith
orRewriterBase::replaceOp
is called.Until now, every
notifyOperationReplaced
was always sent together with anotifyOperationErased
, which made thatnotifyOperationErased
callback irrelevant. More importantly, when a user calledRewriterBase::replaceAllOpUsesWith
+RewriterBase::eraseOp
instead ofRewriterBase::replaceOp
, nonotifyOperationReplaced
callback was sent, even though the two notations are semantically equivalent. As an example, this can be a problem when applying patterns with the transform dialect because theTrackingListener
will only see thenotifyOperationErased
callback and the payload op is dropped from the mappings.Note: It is still possible to write semantically equivalent code that does not trigger a
notifyOperationReplaced
(e.g., when op results are replaced one-by-one), but this commit already improves the situation a lot.