diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 32724ff4b98e8..5db1a2c28fd41 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -1026,7 +1026,7 @@ class TrackingListener : public RewriterBase::Listener, /// Return the transform op in which this TrackingListener is used. TransformOpInterface getTransformOp() const { return transformOp; } -private: +protected: friend class TransformRewriter; void notifyOperationErased(Operation *op) override; @@ -1034,6 +1034,7 @@ class TrackingListener : public RewriterBase::Listener, void notifyOperationReplaced(Operation *op, ValueRange newValues) override; using Listener::notifyOperationReplaced; +private: /// The transform op in which this TrackingListener is used. TransformOpInterface transformOp; @@ -1047,23 +1048,48 @@ class TrackingListener : public RewriterBase::Listener, /// A specialized listener that keeps track of cases in which no replacement /// payload could be found. The error state of this listener must be checked /// before the end of its lifetime. -class ErrorCheckingTrackingListener : public TrackingListener { +template +class ErrorCheckingTrackingListener : public TrackingListenerTy { public: - using transform::TrackingListener::TrackingListener; + using TrackingListenerTy::TrackingListenerTy; - ~ErrorCheckingTrackingListener() override; + ~ErrorCheckingTrackingListener() override { + // The state of the ErrorCheckingTrackingListener must be checked and reset + // if there was an error. This is to prevent errors from accidentally being + // missed. + assert(status.succeeded() && "listener state was not checked"); + } /// Check and return the current error state of this listener. Afterwards, /// resets the error state to "success". - DiagnosedSilenceableFailure checkAndResetError(); + DiagnosedSilenceableFailure checkAndResetError() { + DiagnosedSilenceableFailure s = std::move(status); + status = DiagnosedSilenceableFailure::success(); + errorCounter = 0; + return s; + } /// Return "true" if this tracking listener had a failure. - bool failed() const; + bool failed() const { return !status.succeeded(); } protected: - void - notifyPayloadReplacementNotFound(Operation *op, ValueRange values, - DiagnosedSilenceableFailure &&diag) override; + void notifyPayloadReplacementNotFound( + Operation *op, ValueRange values, + DiagnosedSilenceableFailure &&diag) override { + // Merge potentially existing diags and store the result in the listener. + SmallVector diags; + diag.takeDiagnostics(diags); + if (!status.succeeded()) + status.takeDiagnostics(diags); + status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags)); + + // Report more details. + status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op"; + for (auto &&[index, value] : llvm::enumerate(values)) + status.attachNote(value.getLoc()) + << "[" << errorCounter << "] replacement value " << index; + ++errorCounter; + } private: /// The error state of this listener. "Success" indicates that no error @@ -1082,8 +1108,9 @@ class TransformRewriter : public RewriterBase { friend class TransformState; /// Create a new TransformRewriter. - explicit TransformRewriter(MLIRContext *ctx, - ErrorCheckingTrackingListener *listener); + explicit TransformRewriter( + MLIRContext *ctx, + ErrorCheckingTrackingListener *listener); public: /// Return "true" if the tracking listener had failures. @@ -1106,7 +1133,7 @@ class TransformRewriter : public RewriterBase { Operation *replacement); private: - ErrorCheckingTrackingListener *const listener; + ErrorCheckingTrackingListener *const listener; }; /// This trait is supposed to be attached to Transform dialect operations that diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 1766e4bb875f3..686a51bf7f9d3 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -203,6 +203,16 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns", lower ops to different ops (from a different dialect). More details can be found at the documentation site of `TrackingListener`. + The way op handles are updated can be customized with `find_replacements`. + If `find_replacements` is set, replacement ops are *not* deduced from the + replacement SSA values. The `find_replacements` dictionary attribute + specifies the kind of op that should be considered as a replacement for a + replaced tracked op. E.g., "arith.mulf => llvm.fmul" specifies that the + replacement op for a tracked "arith.mulf" must be an "llvm.fmul" op that was + created in the same pattern that replaced the "arith.mulf" op. If there is + no such op or if there are multiple such ops, a tracking listener failure + is produced. + This transform produces a silenceable failure if the dialect conversion was unsuccessful or the tracking listener failed to find a replacement op. }]; @@ -212,6 +222,7 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns", OptionalAttr:$illegal_ops, OptionalAttr:$legal_dialects, OptionalAttr:$illegal_dialects, + OptionalAttr:$find_replacements, UnitAttr:$partial_conversion, UnitAttr:$preserve_handles); let results = (outs); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index 71a9d61198e3f..92f59c47018f6 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -935,8 +935,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { } return true; }; - transform::ErrorCheckingTrackingListener trackingListener(*this, transform, - config); + transform::ErrorCheckingTrackingListener + trackingListener(*this, transform, config); transform::TransformRewriter rewriter(transform->getContext(), &trackingListener); @@ -1214,11 +1214,10 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp( Operation *&result, Operation *op, ValueRange newValues) const { assert(op->getNumResults() == newValues.size() && "invalid number of replacement values"); - SmallVector values(newValues.begin(), newValues.end()); - DiagnosedSilenceableFailure diag = emitSilenceableFailure( getTransformOp(), "tracking listener failed to find replacement op " "during application of this transform op"); + SmallVector values(newValues.begin(), newValues.end()); do { // If the replacement values belong to different ops, drop the mapping. @@ -1278,14 +1277,11 @@ void transform::TrackingListener::notifyMatchFailure( } void transform::TrackingListener::notifyOperationErased(Operation *op) { - // TODO: Walk can be removed when D144193 has landed. - op->walk([&](Operation *op) { - // Remove mappings for result values. - for (OpResult value : op->getResults()) - (void)replacePayloadValue(value, nullptr); - // Remove mapping for op. - (void)replacePayloadOp(op, nullptr); - }); + // Remove mappings for result values. + for (OpResult value : op->getResults()) + (void)replacePayloadValue(value, nullptr); + // Remove mapping for op. + (void)replacePayloadOp(op, nullptr); } void transform::TrackingListener::notifyOperationReplaced( @@ -1352,49 +1348,12 @@ void transform::TrackingListener::notifyOperationReplaced( (void)replacePayloadOp(op, replacement); } -transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() { - // The state of the ErrorCheckingTrackingListener must be checked and reset - // if there was an error. This is to prevent errors from accidentally being - // missed. - assert(status.succeeded() && "listener state was not checked"); -} - -DiagnosedSilenceableFailure -transform::ErrorCheckingTrackingListener::checkAndResetError() { - DiagnosedSilenceableFailure s = std::move(status); - status = DiagnosedSilenceableFailure::success(); - errorCounter = 0; - return s; -} - -bool transform::ErrorCheckingTrackingListener::failed() const { - return !status.succeeded(); -} - -void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( - Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) { - - // Merge potentially existing diags and store the result in the listener. - SmallVector diags; - diag.takeDiagnostics(diags); - if (!status.succeeded()) - status.takeDiagnostics(diags); - status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags)); - - // Report more details. - status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op"; - for (auto &&[index, value] : llvm::enumerate(values)) - status.attachNote(value.getLoc()) - << "[" << errorCounter << "] replacement value " << index; - ++errorCounter; -} - //===----------------------------------------------------------------------===// // TransformRewriter //===----------------------------------------------------------------------===// transform::TransformRewriter::TransformRewriter( - MLIRContext *ctx, ErrorCheckingTrackingListener *listener) + MLIRContext *ctx, ErrorCheckingTrackingListener *listener) : RewriterBase(ctx), listener(listener) { setListener(listener); } diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index ca80899ab0734..b73fceee7aba5 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -493,6 +493,125 @@ void transform::ApplyCanonicalizationPatternsOp::populatePatterns( // ApplyConversionPatternsOp //===----------------------------------------------------------------------===// +namespace { +/// A specialized tracking listener for dialect conversions. It can be +/// configured with a "replacement mapping" that specifies how replacement ops +/// for replaced tracked operations should be determined. +class ConversionTrackingListener : public transform::TrackingListener { +public: + ConversionTrackingListener( + transform::TransformState &state, transform::TransformOpInterface op, + transform::TrackingListenerConfig config, + const DenseMap *replacementMapping) + : transform::TrackingListener(state, op, config), + replacementMapping(replacementMapping) {} + + /// Instead of deducing the replacement op from the replacement values, the + /// replacement op is chosen among all ops that were created during the + /// current pattern application. E.g., a mapping of "arith.mulsi_extended -> + /// llvm.mul" indicates that tracked arith.mulsi_extended ops should be + /// updated to llvm.mul ops, assuming that an llvm.mul op was created in the + /// same pattern that replaced the arith.mulsi_extended op. If no such op or + /// multiple such ops were created, "nullptr" replacement op is returned. + /// + /// If no replacement mapping is set, fall back to the original mechanism of + /// `TrackingListener`. + DiagnosedSilenceableFailure + findReplacementOp(Operation *&result, Operation *op, + ValueRange newValues) const override; + +protected: + void notifyOperationErased(Operation *op) override; + + void notifyOperationInserted(Operation *op, + OpBuilder::InsertPoint previous) override; + + void notifyPatternBegin(const Pattern &pattern, Operation *op) override; + + void notifyPatternEnd(const Pattern &pattern, LogicalResult status) override; + + /// The root op of the pattern that is currently being applied or "nullptr" if + /// no pattern application is running. + Operation *rootOp = nullptr; + + /// All ops that have been created during the current pattern application. + /// This set is maintained only if "config.replacementMapping" is set. + SmallVector createdOps; + + /// A mapping that specifies how replacement ops should be + /// determined when a mapped op is replaced. If set to "nullptr", the default + /// lookup mechanism (i.e., op deduced from the replacement values) is used. + const DenseMap *replacementMapping = nullptr; +}; +} // namespace + +void ConversionTrackingListener::notifyOperationErased(Operation *op) { + TrackingListener::notifyOperationErased(op); + + // Remove from created ops. + auto it = llvm::find(createdOps, op); + if (it != createdOps.end()) + createdOps.erase(it); +} + +void ConversionTrackingListener::notifyOperationInserted( + Operation *op, OpBuilder::InsertPoint previous) { + if (replacementMapping) + createdOps.push_back(op); +} + +void ConversionTrackingListener::notifyPatternBegin(const Pattern &pattern, + Operation *op) { + assert(!rootOp && "expected that no other pattern is in progress"); + rootOp = op; +} + +void ConversionTrackingListener::notifyPatternEnd(const Pattern &pattern, + LogicalResult status) { + rootOp = nullptr; + createdOps.clear(); +} + +DiagnosedSilenceableFailure +ConversionTrackingListener::findReplacementOp(Operation *&result, Operation *op, + ValueRange newValues) const { + if (!replacementMapping) + return TrackingListener::findReplacementOp(result, op, newValues); + + DiagnosedSilenceableFailure diag = emitSilenceableFailure( + getTransformOp(), + "conversion tracking listener failed to find replacement op during " + "application of this transform op"); + + auto it = replacementMapping->find(op->getName().getStringRef()); + if (it == replacementMapping->end()) { + diag.attachNote(op->getLoc()) + << "no mapping specified for '" << op->getName().getStringRef() << "'"; + return diag; + } + StringRef replacementOpName = it->second; + Operation *replacementOp = nullptr; + for (Operation *op : createdOps) { + if (op->getName().getStringRef() == replacementOpName) { + if (replacementOp) { + diag.attachNote(op->getLoc()) << "multiple '" << replacementOpName + << "' replacement candidates found for '" + << op->getName().getStringRef() << "'"; + return diag; + } + replacementOp = op; + } + } + if (!replacementOp) { + diag.attachNote(op->getLoc()) + << "no replacement found for '" << op->getName().getStringRef() + << "', expected '" << replacementOpName << "'"; + return diag; + } + result = replacementOp; + return DiagnosedSilenceableFailure::success(); +} + DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { @@ -523,6 +642,15 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( for (Attribute attr : cast(*getIllegalDialects())) conversionTarget.addIllegalDialect(cast(attr).getValue()); + // Extract op replacement rules from attribute. + DenseMap replacementMapping; + if (getFindReplacements()) { + DictionaryAttr mappingAttr = cast(*getFindReplacements()); + for (auto it : mappingAttr) + replacementMapping[it.getName()] = + cast(it.getValue()).getValue(); + } + // Gather all specified patterns. RewritePatternSet patterns(ctx); // Need to keep the converters alive until after pattern application because @@ -569,7 +697,9 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( // name. TrackingListenerConfig trackingConfig; trackingConfig.requireMatchingReplacementOpName = false; - ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig); + ErrorCheckingTrackingListener trackingListener( + state, *this, trackingConfig, + replacementMapping.empty() ? nullptr : &replacementMapping); ConversionConfig conversionConfig; if (getPreserveHandles()) conversionConfig.listener = &trackingListener; @@ -658,6 +788,16 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() { } } } + if (getFindReplacements()) { + if (!getPreserveHandles()) + return emitOpError() << "find_replacements requires preserve_handles"; + auto mapping = cast(*getFindReplacements()); + for (auto it : mapping) { + if (!isa(it.getValue())) + return emitOpError() << "expected find_replacements to contain only " + "StringAttr values"; + } + } return success(); } diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir index 73a5f36af9295..729645aca2f91 100644 --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -771,3 +771,25 @@ module attributes { transform.with_named_sequence } { transform.yield %arg0 : !transform.any_op } } + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{expected find_replacements to contain only StringAttr values}} + transform.apply_conversion_patterns to %arg0 { + } {legal_dialects = ["func", "llvm"], preserve_handles, + find_replacements = {"arith.muli" = 3}} : !transform.any_op + transform.yield +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{find_replacements requires preserve_handles}} + transform.apply_conversion_patterns to %arg0 { + } {legal_dialects = ["func", "llvm"], + find_replacements = {"arith.muli" = 3}} : !transform.any_op + transform.yield +} diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir index fa8a555af9218..7ac2838bb95cc 100644 --- a/mlir/test/Dialect/Transform/test-pattern-application.mlir +++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir @@ -447,3 +447,42 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// "arith.mulsi_extended" is tracked and replaced with "llvm.mul" (and other +// ops) during a dialect conversion. Make sure that the handle is updated +// accordingly. + +// CHECK-LABEL: func @dialect_conversion_find_replacements( +// CHECK-SAME: %[[arg0:.*]]: vector<4xi32>, %[[arg1:.*]]: vector<4xi32>) +// CHECK: %[[VAL0:.*]] = llvm.sext %[[arg0]] : vector<4xi32> to vector<4xi64> +// CHECK: %[[VAL1:.*]] = llvm.sext %[[arg1]] : vector<4xi32> to vector<4xi64> +// CHECK: %[[VAL2:.*]] = llvm.mul %[[VAL0]], %[[VAL1]] {annotated} : vector<4xi64> +// CHECK: %[[VAL3:.*]] = llvm.trunc %[[VAL2]] : vector<4xi64> to vector<4xi32> +// CHECK: %[[VAL4:.*]] = llvm.mlir.constant(dense<32> : vector<4xi64>) : vector<4xi64> +// CHECK: %[[VAL5:.*]] = llvm.lshr %[[VAL2]], %[[VAL4]] : vector<4xi64> +// CHECK: %[[VAL6:.*]] = llvm.trunc %[[VAL5]] : vector<4xi64> to vector<4xi32> +// CHECK: return %[[VAL3]], %[[VAL6]] : vector<4xi32>, vector<4xi32> +func.func @dialect_conversion_find_replacements(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { + %c:2 = arith.mulsi_extended %arg0, %arg1 : vector<4xi32> + return %c#0, %c#1 : vector<4xi32>, vector<4xi32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["arith.mulsi_extended"]} in %0 : (!transform.any_op) -> !transform.any_op + // arith.mulsi_extended handles are updated to llvm.mul. + transform.apply_conversion_patterns to %0 { + transform.apply_conversion_patterns.dialect_to_llvm "arith" + } with type_converter { + transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter + } {legal_dialects = ["func", "llvm"], preserve_handles, + find_replacements = {"arith.mulsi_extended" = "llvm.mul"}} + : !transform.any_op + // Add an attribute to %1, which is now mapped to a new op. + transform.annotate %1 "annotated" : !transform.any_op + transform.yield + } +}