-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transform] apply_conversion_patterns
: Update handles
#83950
[mlir][Transform] apply_conversion_patterns
: Update handles
#83950
Conversation
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesUntil now, This new functionality is hidden behind a Full diff: https://github.com/llvm/llvm-project/pull/83950.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 313cdc27f780a7..32724ff4b98e8e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -921,20 +921,36 @@ TransformState::RegionScope TransformState::make_region_scope(Region ®ion) {
return RegionScope(*this, region);
}
+/// A configuration object for customizing a `TrackingListener`.
+struct TrackingListenerConfig {
+ using SkipHandleFn = std::function<bool(Value)>;
+
+ /// An optional function that returns "true" for handles that do not have to
+ /// be updated. These are typically dead or consumed handles.
+ SkipHandleFn skipHandleFn = nullptr;
+
+ /// If set to "true", the name of a replacement op must match the name of the
+ /// original op. If set to "false", the names of the payload ops tracked in a
+ /// handle may change as the tracking listener updates the transform state.
+ bool requireMatchingReplacementOpName = true;
+
+ /// If set to "true", cast ops (that implement the CastOpInterface) are
+ /// skipped and the replacement op search continues with the operands of the
+ /// cast op.
+ bool skipCastOps = true;
+};
+
/// A listener that updates a TransformState based on IR modifications. This
/// listener can be used during a greedy pattern rewrite to keep the transform
/// state up-to-date.
class TrackingListener : public RewriterBase::Listener,
public TransformState::Extension {
public:
- /// A function that returns "true" for handles that do not have to be updated.
- using SkipHandleFn = std::function<bool(Value)>;
-
/// Create a new TrackingListener for usage in the specified transform op.
/// Optionally, a function can be specified to identify handles that should
/// do not have to be updated.
TrackingListener(TransformState &state, TransformOpInterface op,
- SkipHandleFn skipHandleFn = nullptr);
+ TrackingListenerConfig config = TrackingListenerConfig());
protected:
/// Return a replacement payload op for the given op, which is going to be
@@ -959,7 +975,8 @@ class TrackingListener : public RewriterBase::Listener,
/// same computation; e.g., there may be tiled "linalg.generic" inside the
/// loop body that represents the original computation. Therefore, the
/// TrackingListener is conservative by default: it drops the mapping and
- /// triggers the "payload replacement not found" notification.
+ /// triggers the "payload replacement not found" notification. This default
+ /// behavior can be customized in `TrackingListenerConfig`.
///
/// If no replacement op could be found according to the rules mentioned
/// above, this function tries to skip over cast-like ops that implement
@@ -1023,9 +1040,8 @@ class TrackingListener : public RewriterBase::Listener,
/// The handles that are consumed by the transform op.
DenseSet<Value> consumedHandles;
- /// Handles for which this function evaluates to "true" do not have to be
- /// updated. These are typically dead or consumed handles.
- SkipHandleFn skipHandleFn;
+ /// Tracking listener configuration.
+ TrackingListenerConfig config;
};
/// A specialized listener that keeps track of cases in which no replacement
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 9f513822ed0a4e..0e42d12a69a400 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -190,11 +190,20 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
The `legal_ops`, `illegal_ops`, `legal_dialects`, `illegal_dialects`
attributes specify the conversion target.
- This transform consumes the `target` handle and modifies the payload. It
- does not produce any handles.
+ This transform modifies the payload. By default, it consumes the `target`
+ handle. It does not produce any handles.
+
+ If the `preserve_handles` attribute is set, this transform does not consume
+ the `target` handle and instead updates handles based on notifications from
+ a tracking listener that is attached to the dialect conversion, similar to
+ `transform.apply_patterns`. Only replacements via `RewriterBase::replaceOp`
+ or `replaceOpWithNewOp` are considered "payload op replacements". In
+ contrast to `transform.apply_patterns`, we allow replacement ops even if the
+ op name has changed. More details can be found at the documentation site of
+ `TrackingListener`.
This transform produces a silenceable failure if the dialect conversion was
- unsuccessful.
+ unsuccessful or the tracking listener failed to find a replacement op.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
@@ -202,7 +211,8 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
OptionalAttr<StrArrayAttr>:$illegal_ops,
OptionalAttr<StrArrayAttr>:$legal_dialects,
OptionalAttr<StrArrayAttr>:$illegal_dialects,
- UnitAttr:$partial_conversion);
+ UnitAttr:$partial_conversion,
+ UnitAttr:$preserve_handles);
let results = (outs);
let regions = (region
MaxSizedRegion<1>:$patterns,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index bb9f6fec452986..71a9d61198e3fb 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -918,7 +918,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}
// Prepare rewriter and listener.
- TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
+ TrackingListenerConfig config;
+ config.skipHandleFn = [&](Value handle) {
// Skip handle if it is dead.
auto scopeIt =
llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
@@ -935,7 +936,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
return true;
};
transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
- skipHandleFn);
+ config);
transform::TransformRewriter rewriter(transform->getContext(),
&trackingListener);
@@ -1184,9 +1185,8 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
transform::TrackingListener::TrackingListener(TransformState &state,
TransformOpInterface op,
- SkipHandleFn skipHandleFn)
- : TransformState::Extension(state), transformOp(op),
- skipHandleFn(skipHandleFn) {
+ TrackingListenerConfig config)
+ : TransformState::Extension(state), transformOp(op), config(config) {
if (op) {
for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
consumedHandles.insert(opOperand->get());
@@ -1228,8 +1228,19 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
return diag;
}
- // If the defining op has the same type, we take it as a replacement.
- if (op->getName() == defOp->getName()) {
+ // Skip through ops that implement CastOpInterface.
+ if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
+ values.clear();
+ values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
+ diag.attachNote(defOp->getLoc())
+ << "using output of 'CastOpInterface' op";
+ continue;
+ }
+
+ // If the defining op has the same name or we do not care about the name of
+ // op replacements at all, we take it as a replacement.
+ if (!config.requireMatchingReplacementOpName ||
+ op->getName() == defOp->getName()) {
result = defOp;
return DiagnosedSilenceableFailure::success();
}
@@ -1251,14 +1262,6 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
"'FindPayloadReplacementOpInterface'";
continue;
}
-
- // Skip through ops that implement CastOpInterface.
- if (isa<CastOpInterface>(defOp)) {
- values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
- diag.attachNote(defOp->getLoc())
- << "using output of 'CastOpInterface' op";
- continue;
- }
} while (!values.empty());
diag.attachNote() << "ran out of suitable replacement values";
@@ -1318,9 +1321,9 @@ void transform::TrackingListener::notifyOperationReplaced(
// Check if there are any handles that must be updated.
Value aliveHandle;
- if (skipHandleFn) {
- auto it =
- llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); });
+ if (config.skipHandleFn) {
+ auto it = llvm::find_if(opHandles,
+ [&](Value v) { return !config.skipHandleFn(v); });
if (it != opHandles.end())
aliveHandle = *it;
} else if (!opHandles.empty()) {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 180d11c30e65de..ca80899ab07341 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -563,6 +563,17 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
}
}
+ // Attach a tracking listener if handles should be preserved. We configure the
+ // listener to allow op replacements with different names, as conversion
+ // patterns typically replace ops with replacement ops that have a different
+ // name.
+ TrackingListenerConfig trackingConfig;
+ trackingConfig.requireMatchingReplacementOpName = false;
+ ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
+ ConversionConfig conversionConfig;
+ if (getPreserveHandles())
+ conversionConfig.listener = &trackingListener;
+
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
for (Operation *target : state.getPayloadOps(getTarget())) {
// Make sure that this transform is not applied to itself. Modifying the
@@ -574,16 +585,36 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
LogicalResult status = failure();
if (getPartialConversion()) {
- status = applyPartialConversion(target, conversionTarget, frozenPatterns);
+ status = applyPartialConversion(target, conversionTarget, frozenPatterns,
+ conversionConfig);
} else {
- status = applyFullConversion(target, conversionTarget, frozenPatterns);
+ status = applyFullConversion(target, conversionTarget, frozenPatterns,
+ conversionConfig);
}
+ // Check dialect conversion state.
+ DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
if (failed(status)) {
- auto diag = emitSilenceableError() << "dialect conversion failed";
+ diag = emitSilenceableError() << "dialect conversion failed";
diag.attachNote(target->getLoc()) << "target op";
- return diag;
}
+
+ // Check tracking listener error state.
+ DiagnosedSilenceableFailure trackingFailure =
+ trackingListener.checkAndResetError();
+ if (!trackingFailure.succeeded()) {
+ if (diag.succeeded()) {
+ // Tracking failure is the only failure.
+ return trackingFailure;
+ } else {
+ diag.attachNote() << "tracking listener also failed: "
+ << trackingFailure.getMessage();
+ (void)trackingFailure.silence();
+ }
+ }
+
+ if (!diag.succeeded())
+ return diag;
}
return DiagnosedSilenceableFailure::success();
@@ -632,7 +663,11 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() {
void transform::ApplyConversionPatternsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::consumesHandle(getTarget(), effects);
+ if (!getPreserveHandles()) {
+ transform::consumesHandle(getTarget(), effects);
+ } else {
+ transform::onlyReadsHandle(getTarget(), effects);
+ }
transform::modifiesPayload(effects);
}
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index 0c41e81b17b522..fa8a555af92188 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -417,3 +417,33 @@ module attributes { transform.with_named_sequence } {
transform.yield
}
}
+
+// -----
+
+// "test.foo" is tracked and replaced with "test.new_op" during a dialect
+// conversion. Make sure that the handle is updated accordingly.
+
+// CHECK-LABEL: func @dialect_conversion_tracking
+// CHECK-NEXT: %[[m:.*]] = "test.new_op"() {annotated} : () -> memref<5xf32>
+// CHECK-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast %0 : memref<5xf32> to tensor<5xf32>
+// CHECK-NEXT: return %[[cast]]
+func.func @dialect_conversion_tracking() -> tensor<5xf32> {
+ %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (tensor<5xf32>)
+ return %0 : tensor<5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["test.foo"]} in %0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_conversion_patterns to %0 {
+ transform.apply_conversion_patterns.transform.test_conversion_patterns
+ } with type_converter {
+ transform.apply_conversion_patterns.transform.test_type_converter
+ } {legal_ops = ["func.func", "func.return", "test.new_op"], preserve_handles}
+ : !transform.any_op
+ // Add an attribute to %1, which is now mapped to a new op.
+ transform.annotate %1 "annotated" : !transform.any_op
+ transform.yield
+ }
+}
|
if (!getPreserveHandles()) { | ||
transform::consumesHandle(getTarget(), effects); | ||
} else { | ||
transform::onlyReadsHandle(getTarget(), effects); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I don't recall if dialect conversion could rewrite the top-level op or not. If it can, it may need to still consume the handle...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it actually matter? If the top-level op is rewritten, it will be done through the rewriter, so the transform dialect state will be updated accordingly. (Same as for non-top-level ops.)
9b4aa86
to
d91a0fb
Compare
Until now, `transform.apply_conversion_patterns` consumed the target handle and potentially invalidated handles. This commit adds tracking functionality similar to `transform.apply_patterns`, such that handles are no longer invalidated, but updated based on op replacements performed by the dialect conversion. This new functionality is hidden behind a `preserve_handles` attribute for now.
270ade8
to
c698363
Compare
Until now,
transform.apply_conversion_patterns
consumed the target handle and potentially invalidated handles. This commit adds tracking functionality similar totransform.apply_patterns
, such that handles are no longer invalidated, but updated based on op replacements performed by the dialect conversion.This new functionality is hidden behind a
preserve_handles
attribute for now.