Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transform] apply_conversion_patterns: Update handles #83950

Conversation

matthias-springer
Copy link
Member

Until now, transform.apply_conversion_patterns consumed the target handle and potentially invalidated handles. This commit adds tracking functionality similar to transform.apply_patterns, such that handles are no longer invalidated, but updated based on op replacements performed by the dialect conversion.

This new functionality is hidden behind a preserve_handles attribute for now.

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 5, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Until now, transform.apply_conversion_patterns consumed the target handle and potentially invalidated handles. This commit adds tracking functionality similar to transform.apply_patterns, such that handles are no longer invalidated, but updated based on op replacements performed by the dialect conversion.

This new functionality is hidden behind a preserve_handles attribute for now.


Full diff: https://github.com/llvm/llvm-project/pull/83950.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+24-8)
  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+14-4)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+21-18)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+40-5)
  • (modified) mlir/test/Dialect/Transform/test-pattern-application.mlir (+30)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 313cdc27f780a7..32724ff4b98e8e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -921,20 +921,36 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
   return RegionScope(*this, region);
 }
 
+/// A configuration object for customizing a `TrackingListener`.
+struct TrackingListenerConfig {
+  using SkipHandleFn = std::function<bool(Value)>;
+
+  /// An optional function that returns "true" for handles that do not have to
+  /// be updated. These are typically dead or consumed handles.
+  SkipHandleFn skipHandleFn = nullptr;
+
+  /// If set to "true", the name of a replacement op must match the name of the
+  /// original op. If set to "false", the names of the payload ops tracked in a
+  /// handle may change as the tracking listener updates the transform state.
+  bool requireMatchingReplacementOpName = true;
+
+  /// If set to "true", cast ops (that implement the CastOpInterface) are
+  /// skipped and the replacement op search continues with the operands of the
+  /// cast op.
+  bool skipCastOps = true;
+};
+
 /// A listener that updates a TransformState based on IR modifications. This
 /// listener can be used during a greedy pattern rewrite to keep the transform
 /// state up-to-date.
 class TrackingListener : public RewriterBase::Listener,
                          public TransformState::Extension {
 public:
-  /// A function that returns "true" for handles that do not have to be updated.
-  using SkipHandleFn = std::function<bool(Value)>;
-
   /// Create a new TrackingListener for usage in the specified transform op.
   /// Optionally, a function can be specified to identify handles that should
   /// do not have to be updated.
   TrackingListener(TransformState &state, TransformOpInterface op,
-                   SkipHandleFn skipHandleFn = nullptr);
+                   TrackingListenerConfig config = TrackingListenerConfig());
 
 protected:
   /// Return a replacement payload op for the given op, which is going to be
@@ -959,7 +975,8 @@ class TrackingListener : public RewriterBase::Listener,
   /// same computation; e.g., there may be tiled "linalg.generic" inside the
   /// loop body that represents the original computation. Therefore, the
   /// TrackingListener is conservative by default: it drops the mapping and
-  /// triggers the "payload replacement not found" notification.
+  /// triggers the "payload replacement not found" notification. This default
+  /// behavior can be customized in `TrackingListenerConfig`.
   ///
   /// If no replacement op could be found according to the rules mentioned
   /// above, this function tries to skip over cast-like ops that implement
@@ -1023,9 +1040,8 @@ class TrackingListener : public RewriterBase::Listener,
   /// The handles that are consumed by the transform op.
   DenseSet<Value> consumedHandles;
 
-  /// Handles for which this function evaluates to "true" do not have to be
-  /// updated. These are typically dead or consumed handles.
-  SkipHandleFn skipHandleFn;
+  /// Tracking listener configuration.
+  TrackingListenerConfig config;
 };
 
 /// A specialized listener that keeps track of cases in which no replacement
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 9f513822ed0a4e..0e42d12a69a400 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -190,11 +190,20 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
     The `legal_ops`, `illegal_ops`, `legal_dialects`, `illegal_dialects`
     attributes specify the conversion target.
 
-    This transform consumes the `target` handle and modifies the payload. It
-    does not produce any handles.
+    This transform modifies the payload. By default, it consumes the `target`
+    handle. It does not produce any handles.
+
+    If the `preserve_handles` attribute is set, this transform does not consume
+    the `target` handle and instead updates handles based on notifications from
+    a tracking listener that is attached to the dialect conversion, similar to
+    `transform.apply_patterns`. Only replacements via `RewriterBase::replaceOp`
+    or `replaceOpWithNewOp` are considered "payload op replacements". In
+    contrast to `transform.apply_patterns`, we allow replacement ops even if the
+    op name has changed. More details can be found at the documentation site of
+    `TrackingListener`.
 
     This transform produces a silenceable failure if the dialect conversion was
-    unsuccessful.
+    unsuccessful or the tracking listener failed to find a replacement op.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
@@ -202,7 +211,8 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
                        OptionalAttr<StrArrayAttr>:$illegal_ops,
                        OptionalAttr<StrArrayAttr>:$legal_dialects,
                        OptionalAttr<StrArrayAttr>:$illegal_dialects,
-                       UnitAttr:$partial_conversion);
+                       UnitAttr:$partial_conversion,
+                       UnitAttr:$preserve_handles);
   let results = (outs);
   let regions = (region
       MaxSizedRegion<1>:$patterns,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index bb9f6fec452986..71a9d61198e3fb 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -918,7 +918,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   }
 
   // Prepare rewriter and listener.
-  TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
+  TrackingListenerConfig config;
+  config.skipHandleFn = [&](Value handle) {
     // Skip handle if it is dead.
     auto scopeIt =
         llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
@@ -935,7 +936,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
     return true;
   };
   transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
-                                                            skipHandleFn);
+                                                            config);
   transform::TransformRewriter rewriter(transform->getContext(),
                                         &trackingListener);
 
@@ -1184,9 +1185,8 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
 
 transform::TrackingListener::TrackingListener(TransformState &state,
                                               TransformOpInterface op,
-                                              SkipHandleFn skipHandleFn)
-    : TransformState::Extension(state), transformOp(op),
-      skipHandleFn(skipHandleFn) {
+                                              TrackingListenerConfig config)
+    : TransformState::Extension(state), transformOp(op), config(config) {
   if (op) {
     for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
       consumedHandles.insert(opOperand->get());
@@ -1228,8 +1228,19 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
       return diag;
     }
 
-    // If the defining op has the same type, we take it as a replacement.
-    if (op->getName() == defOp->getName()) {
+    // Skip through ops that implement CastOpInterface.
+    if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
+      values.clear();
+      values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
+      diag.attachNote(defOp->getLoc())
+          << "using output of 'CastOpInterface' op";
+      continue;
+    }
+
+    // If the defining op has the same name or we do not care about the name of
+    // op replacements at all, we take it as a replacement.
+    if (!config.requireMatchingReplacementOpName ||
+        op->getName() == defOp->getName()) {
       result = defOp;
       return DiagnosedSilenceableFailure::success();
     }
@@ -1251,14 +1262,6 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
                                           "'FindPayloadReplacementOpInterface'";
       continue;
     }
-
-    // Skip through ops that implement CastOpInterface.
-    if (isa<CastOpInterface>(defOp)) {
-      values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
-      diag.attachNote(defOp->getLoc())
-          << "using output of 'CastOpInterface' op";
-      continue;
-    }
   } while (!values.empty());
 
   diag.attachNote() << "ran out of suitable replacement values";
@@ -1318,9 +1321,9 @@ void transform::TrackingListener::notifyOperationReplaced(
 
   // Check if there are any handles that must be updated.
   Value aliveHandle;
-  if (skipHandleFn) {
-    auto it =
-        llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); });
+  if (config.skipHandleFn) {
+    auto it = llvm::find_if(opHandles,
+                            [&](Value v) { return !config.skipHandleFn(v); });
     if (it != opHandles.end())
       aliveHandle = *it;
   } else if (!opHandles.empty()) {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 180d11c30e65de..ca80899ab07341 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -563,6 +563,17 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
     }
   }
 
+  // Attach a tracking listener if handles should be preserved. We configure the
+  // listener to allow op replacements with different names, as conversion
+  // patterns typically replace ops with replacement ops that have a different
+  // name.
+  TrackingListenerConfig trackingConfig;
+  trackingConfig.requireMatchingReplacementOpName = false;
+  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
+  ConversionConfig conversionConfig;
+  if (getPreserveHandles())
+    conversionConfig.listener = &trackingListener;
+
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
   for (Operation *target : state.getPayloadOps(getTarget())) {
     // Make sure that this transform is not applied to itself. Modifying the
@@ -574,16 +585,36 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
 
     LogicalResult status = failure();
     if (getPartialConversion()) {
-      status = applyPartialConversion(target, conversionTarget, frozenPatterns);
+      status = applyPartialConversion(target, conversionTarget, frozenPatterns,
+                                      conversionConfig);
     } else {
-      status = applyFullConversion(target, conversionTarget, frozenPatterns);
+      status = applyFullConversion(target, conversionTarget, frozenPatterns,
+                                   conversionConfig);
     }
 
+    // Check dialect conversion state.
+    DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
     if (failed(status)) {
-      auto diag = emitSilenceableError() << "dialect conversion failed";
+      diag = emitSilenceableError() << "dialect conversion failed";
       diag.attachNote(target->getLoc()) << "target op";
-      return diag;
     }
+
+    // Check tracking listener error state.
+    DiagnosedSilenceableFailure trackingFailure =
+        trackingListener.checkAndResetError();
+    if (!trackingFailure.succeeded()) {
+      if (diag.succeeded()) {
+        // Tracking failure is the only failure.
+        return trackingFailure;
+      } else {
+        diag.attachNote() << "tracking listener also failed: "
+                          << trackingFailure.getMessage();
+        (void)trackingFailure.silence();
+      }
+    }
+
+    if (!diag.succeeded())
+      return diag;
   }
 
   return DiagnosedSilenceableFailure::success();
@@ -632,7 +663,11 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() {
 
 void transform::ApplyConversionPatternsOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::consumesHandle(getTarget(), effects);
+  if (!getPreserveHandles()) {
+    transform::consumesHandle(getTarget(), effects);
+  } else {
+    transform::onlyReadsHandle(getTarget(), effects);
+  }
   transform::modifiesPayload(effects);
 }
 
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index 0c41e81b17b522..fa8a555af92188 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -417,3 +417,33 @@ module attributes { transform.with_named_sequence } {
     transform.yield
   }
 }
+
+// -----
+
+// "test.foo" is tracked and replaced with "test.new_op" during a dialect
+// conversion. Make sure that the handle is updated accordingly.
+
+// CHECK-LABEL: func @dialect_conversion_tracking
+//  CHECK-NEXT:   %[[m:.*]] = "test.new_op"() {annotated} : () -> memref<5xf32>
+//  CHECK-NEXT:   %[[cast:.*]] = builtin.unrealized_conversion_cast %0 : memref<5xf32> to tensor<5xf32>
+//  CHECK-NEXT:   return %[[cast]]
+func.func @dialect_conversion_tracking() -> tensor<5xf32> {
+  %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["test.foo"]} in %0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_conversion_patterns to %0 {
+      transform.apply_conversion_patterns.transform.test_conversion_patterns
+    } with type_converter {
+      transform.apply_conversion_patterns.transform.test_type_converter
+    } {legal_ops = ["func.func", "func.return", "test.new_op"], preserve_handles}
+        : !transform.any_op
+    // Add an attribute to %1, which is now mapped to a new op.
+    transform.annotate %1 "annotated" : !transform.any_op
+    transform.yield
+  }
+}

Comment on lines +666 to +670
if (!getPreserveHandles()) {
transform::consumesHandle(getTarget(), effects);
} else {
transform::onlyReadsHandle(getTarget(), effects);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I don't recall if dialect conversion could rewrite the top-level op or not. If it can, it may need to still consume the handle...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it actually matter? If the top-level op is rewritten, it will be done through the rewriter, so the transform dialect state will be updated accordingly. (Same as for non-top-level ops.)

@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_listener branch from 9b4aa86 to d91a0fb Compare March 8, 2024 01:13
Base automatically changed from users/matthias-springer/dialect_conversion_listener to main March 8, 2024 01:34
Until now, `transform.apply_conversion_patterns` consumed the target handle and potentially invalidated handles. This commit adds tracking functionality similar to `transform.apply_patterns`, such that handles are no longer invalidated, but updated based on op replacements performed by the dialect conversion.

This new functionality is hidden behind a `preserve_handles` attribute for now.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/apply_conversion_patterns_listener branch from 270ade8 to c698363 Compare March 8, 2024 01:59
@matthias-springer matthias-springer merged commit c1029b6 into main Mar 10, 2024
4 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/apply_conversion_patterns_listener branch March 10, 2024 03:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants