Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][IR] Trigger notifyOperationReplaced on replaceAllOpUsesWith #84721

Conversation

matthias-springer
Copy link
Member

Before this change: notifyOperationReplaced was triggered when calling RewriteBase::replaceOp.
After this change: notifyOperationReplaced is triggered when RewriterBase::replaceAllOpUsesWith or RewriterBase::replaceOp is called.

Until now, every notifyOperationReplaced was always sent together with a notifyOperationErased, which made that notifyOperationErased callback irrelevant. More importantly, when a user called RewriterBase::replaceAllOpUsesWith+RewriterBase::eraseOp instead of RewriterBase::replaceOp, no notifyOperationReplaced callback was sent, even though the two notations are semantically equivalent. As an example, this can be a problem when applying patterns with the transform dialect because the TrackingListener will only see the notifyOperationErased callback and the payload op is dropped from the mappings.

Note: It is still possible to write semantically equivalent code that does not trigger a notifyOperationReplaced (e.g., when op results are replaced one-by-one), but this commit already improves the situation a lot.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Mar 11, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 11, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Before this change: notifyOperationReplaced was triggered when calling RewriteBase::replaceOp.
After this change: notifyOperationReplaced is triggered when RewriterBase::replaceAllOpUsesWith or RewriterBase::replaceOp is called.

Until now, every notifyOperationReplaced was always sent together with a notifyOperationErased, which made that notifyOperationErased callback irrelevant. More importantly, when a user called RewriterBase::replaceAllOpUsesWith+RewriterBase::eraseOp instead of RewriterBase::replaceOp, no notifyOperationReplaced callback was sent, even though the two notations are semantically equivalent. As an example, this can be a problem when applying patterns with the transform dialect because the TrackingListener will only see the notifyOperationErased callback and the payload op is dropped from the mappings.

Note: It is still possible to write semantically equivalent code that does not trigger a notifyOperationReplaced (e.g., when op results are replaced one-by-one), but this commit already improves the situation a lot.


Full diff: https://github.com/llvm/llvm-project/pull/84721.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/PatternMatch.h (+17-12)
  • (modified) mlir/lib/IR/PatternMatch.cpp (+16-8)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+4-1)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8d84ab6100007e..c1408c3f90a53b 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -409,9 +409,9 @@ class RewriterBase : public OpBuilder {
     /// Notify the listener that the specified operation was modified in-place.
     virtual void notifyOperationModified(Operation *op) {}
 
-    /// Notify the listener that the specified operation is about to be replaced
-    /// with another operation. This is called before the uses of the old
-    /// operation have been changed.
+    /// Notify the listener that all uses of the specified operation's results
+    /// are about to be replaced with the results of another operation. This is
+    /// called before the uses of the old operation have been changed.
     ///
     /// By default, this function calls the "operation replaced with values"
     /// notification.
@@ -420,9 +420,10 @@ class RewriterBase : public OpBuilder {
       notifyOperationReplaced(op, replacement->getResults());
     }
 
-    /// Notify the listener that the specified operation is about to be replaced
-    /// with the a range of values, potentially produced by other operations.
-    /// This is called before the uses of the operation have been changed.
+    /// Notify the listener that all uses of the specified operation's results
+    /// are about to be replaced with the a range of values, potentially
+    /// produced by other operations. This is called before the uses of the
+    /// operation have been changed.
     virtual void notifyOperationReplaced(Operation *op,
                                          ValueRange replacement) {}
 
@@ -628,12 +629,16 @@ class RewriterBase : public OpBuilder {
     for (auto it : llvm::zip(from, to))
       replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
   }
-  // Note: This function cannot be called `replaceAllUsesWith` because the
-  // overload resolution, when called with an op that can be implicitly
-  // converted to a Value, would be ambiguous.
-  void replaceAllOpUsesWith(Operation *from, ValueRange to) {
-    replaceAllUsesWith(from->getResults(), to);
-  }
+
+  /// Find uses of `from` and replace them with `to`. Also notify the listener
+  /// about every in-place op modification (for every use that was replaced)
+  /// and that the `from` operation is about to be replaced.
+  ///
+  /// Note: This function cannot be called `replaceAllUsesWith` because the
+  /// overload resolution, when called with an op that can be implicitly
+  /// converted to a Value, would be ambiguous.
+  void replaceAllOpUsesWith(Operation *from, ValueRange to);
+  void replaceAllOpUsesWith(Operation *from, Operation *to);
 
   /// Find uses of `from` and replace them with `to` if the `functor` returns
   /// true. Also notify the listener about every in-place op modification (for
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 4079ccc7567256..5944a0ea46a143 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -110,6 +110,22 @@ RewriterBase::~RewriterBase() {
   // Out of line to provide a vtable anchor for the class.
 }
 
+void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
+  // Notify the listener that we're about to replace this op.
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyOperationReplaced(from, to);
+
+  replaceAllUsesWith(from->getResults(), to);
+}
+
+void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
+  // Notify the listener that we're about to replace this op.
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyOperationReplaced(from, to);
+
+  replaceAllUsesWith(from->getResults(), to->getResults());
+}
+
 /// This method replaces the results of the operation with the specified list of
 /// values. The number of provided values must match the number of results of
 /// the operation. The replaced op is erased.
@@ -117,10 +133,6 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   assert(op->getNumResults() == newValues.size() &&
          "incorrect # of replacement values");
 
-  // Notify the listener that we're about to replace this op.
-  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-    rewriteListener->notifyOperationReplaced(op, newValues);
-
   // Replace all result uses. Also notifies the listener of modifications.
   replaceAllOpUsesWith(op, newValues);
 
@@ -136,10 +148,6 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
   assert(op->getNumResults() == newOp->getNumResults() &&
          "ops have different number of results");
 
-  // Notify the listener that we're about to replace this op.
-  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-    rewriteListener->notifyOperationReplaced(op, newOp);
-
   // Replace all result uses. Also notifies the listener of modifications.
   replaceAllOpUsesWith(op, newOp->getResults());
 
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 2da184bc3d85ba..76dc825fe44515 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -489,7 +489,10 @@ struct TestStrictPatternDriver
             OperationName("test.new_op", op->getContext()).getIdentifier(),
             op->getOperands(), op->getResultTypes());
       }
-      rewriter.replaceOp(op, newOp->getResults());
+      // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
+      // A "notifyOperationReplaced" callback is triggered in either case.
+      rewriter.replaceAllOpUsesWith(op, newOp->getResults());
+      rewriter.eraseOp(op);
       return success();
     }
   };

Base automatically changed from users/matthias-springer/msvc_overloads to main March 11, 2024 08:36
Before this change: `notifyOperationReplaced` was triggered when calling `RewriteBase::replaceOp`.
After this change: `notifyOperationReplaced` is triggered when `RewriterBase::replaceAllOpUsesWith` or `RewriterBase::replaceOp` is called.

Until now, every `notifyOperationReplaced` was always sent together with a `notifyOperationErased`, which made that `notifyOperationErased` callback irrelevant. More importantly, when a user called `RewriterBase::replaceAllOpUsesWith`+`RewriterBase::eraseOp` instead of `RewriterBase::replaceOp`, no `notifyOperationReplaced` callback was sent, even though the two notations are semantically equivalent. As an example, this can be a problem when applying patterns with the transform dialect because the `TrackingListener` will only see the `notifyOperationErased` callback and the payload op is dropped from the mappings.

Note: It is still possible to write semantically equivalent code that does not trigger a `notifyOperationReplaced` (e.g., when op results are replaced one-by-one), but this commit already improves the situation a lot.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/replace_all_op_uses_notification branch from 7b93ec3 to 9d461b5 Compare April 2, 2024 01:45
@matthias-springer matthias-springer merged commit 38113a0 into main Apr 2, 2024
3 of 4 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/replace_all_op_uses_notification branch April 2, 2024 01:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants