diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index fcdb21d21503a..fe2c28f45aea0 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -460,6 +460,39 @@ def NumAssociationsOp : TransformDialectOp<"num_associations", let hasVerifier = 1; } +def CollectMatchingOp : TransformDialectOp<"collect_matching", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Collects all payload ops that match the given named matcher"; + let description = [{ + Collects operations or other payload IR objects nested under `root` + (inclusive) that match the given matcher expressed as a named sequence. The + matcher sequence must accept exactly one argument that it is not allowed to + modify. It must yield as many values as this op has results. Each of the + yielded values must be associated with exactly one payload object. If any + operation in the matcher sequence produces a silenceable failure, the + matcher advances to the next payload operation in the walk order without + finishing the sequence. + + The i-th result of this operation is constructed by concatenating the i-th + yielded payload IR objects of all successful matcher sequence applications. + All results are guaranteed to be mapped to the same number of payload IR + objects. + + The operation succeeds unless the matcher sequence produced a definite + failure for any invocation. + }]; + + let arguments = (ins TransformHandleTypeInterface:$root, + SymbolRefAttr:$matcher); + let results = (outs Variadic:$results); + + let assemblyFormat = [{ + $matcher `in` $root attr-dict `:` functional-type($root, $results) + }]; +} + def ForeachMatchOp : TransformDialectOp<"foreach_match", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, @@ -674,7 +707,7 @@ def GetParentOp : TransformDialectOp<"get_parent_op", def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand", [DeclareOpInterfaceMethods, - NavigationTransformOpTrait, MemoryEffectsOpInterface]> { + NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> { let summary = "Get handle to the producer of this operation's operand number"; let description = [{ The handle defined by this Transform op corresponds to operation that diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index aa4694c88d3b2..b80fc09751d2a 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" @@ -783,7 +784,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { } //===----------------------------------------------------------------------===// -// ForeachMatchOp +// CollectMatchingOp //===----------------------------------------------------------------------===// /// Applies matcher operations from the given `block` assigning `op` as the @@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state, return DiagnosedSilenceableFailure::success(); } +/// Returns `true` if both types implement one of the interfaces provided as +/// template parameters. +template +static bool implementSameInterface(Type t1, Type t2) { + return ((isa(t1) && isa(t2)) || ... || false); +} + +/// Returns `true` if both types implement one of the transform dialect +/// interfaces. +static bool implementSameTransformInterface(Type t1, Type t2) { + return implementSameInterface( + t1, t2); +} + +//===----------------------------------------------------------------------===// +// CollectMatchingOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto matcher = SymbolTable::lookupNearestSymbolFrom( + getOperation(), getMatcher()); + if (matcher.isExternal()) { + return emitDefiniteFailure() + << "unresolved external symbol " << getMatcher(); + } + + SmallVector, 2> rawResults; + rawResults.resize(getOperation()->getNumResults()); + std::optional maybeFailure; + for (Operation *root : state.getPayloadOps(getRoot())) { + WalkResult walkResult = root->walk([&](Operation *op) { + DEBUG_MATCHER({ + DBGS_MATCHER() << "matching "; + op->print(llvm::dbgs(), + OpPrintingFlags().assumeVerified().skipRegions()); + llvm::dbgs() << " @" << op << "\n"; + }); + + // Try matching. + SmallVector> mappings; + DiagnosedSilenceableFailure diag = + matchBlock(matcher.getFunctionBody().front(), op, state, mappings); + if (diag.isDefiniteFailure()) + return WalkResult::interrupt(); + if (diag.isSilenceableFailure()) { + DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() + << " failed: " << diag.getMessage()); + return WalkResult::advance(); + } + + // If succeeded, collect results. + for (auto &&[i, mapping] : llvm::enumerate(mappings)) { + if (mapping.size() != 1) { + maybeFailure.emplace(emitSilenceableError() + << "result #" << i << ", associated with " + << mapping.size() + << " payload objects, expected 1"); + return WalkResult::interrupt(); + } + rawResults[i].push_back(mapping[0]); + } + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) + return std::move(*maybeFailure); + assert(!maybeFailure && "failure set but the walk was not interrupted"); + + for (auto &&[opResult, rawResult] : + llvm::zip_equal(getOperation()->getResults(), rawResults)) { + results.setMappedValues(opResult, rawResult); + } + } + return DiagnosedSilenceableFailure::success(); +} + +void transform::CollectMatchingOp::getEffects( + SmallVectorImpl &effects) { + onlyReadsHandle(getRoot(), effects); + producesHandle(getResults(), effects); + onlyReadsPayload(effects); +} + +LogicalResult transform::CollectMatchingOp::verifySymbolUses( + SymbolTableCollection &symbolTable) { + auto matcherSymbol = dyn_cast_or_null( + symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher())); + if (!matcherSymbol || + !isa(matcherSymbol.getOperation())) + return emitError() << "unresolved matcher symbol " << getMatcher(); + + ArrayRef argumentTypes = matcherSymbol.getArgumentTypes(); + if (argumentTypes.size() != 1 || + !isa(argumentTypes[0])) { + return emitError() + << "expected the matcher to take one operation handle argument"; + } + if (!matcherSymbol.getArgAttr( + 0, transform::TransformDialect::kArgReadOnlyAttrName)) { + return emitError() << "expected the matcher argument to be marked readonly"; + } + + ArrayRef resultTypes = matcherSymbol.getResultTypes(); + if (resultTypes.size() != getOperation()->getNumResults()) { + return emitError() + << "expected the matcher to yield as many values as op has results (" + << getOperation()->getNumResults() << "), got " + << resultTypes.size(); + } + + for (auto &&[i, matcherType, resultType] : + llvm::enumerate(resultTypes, getOperation()->getResultTypes())) { + if (implementSameTransformInterface(matcherType, resultType)) + continue; + + return emitError() + << "mismatching type interfaces for matcher result and op result #" + << i; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ForeachMatchOp +//===----------------------------------------------------------------------===// + DiagnosedSilenceableFailure transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, @@ -978,22 +1110,6 @@ LogicalResult transform::ForeachMatchOp::verify() { return success(); } -/// Returns `true` if both types implement one of the interfaces provided as -/// template parameters. -template -static bool implementSameInterface(Type t1, Type t2) { - return ((isa(t1) && isa(t2)) || ... || false); -} - -/// Returns `true` if both types implement one of the transform dialect -/// interfaces. -static bool implementSameTransformInterface(Type t1, Type t2) { - return implementSameInterface( - t1, t2); -} - /// Checks that the attributes of the function-like operation have correct /// consumption effect annotations. If `alsoVerifyInternal`, checks for /// annotations being present even if they can be inferred from the body. diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir index 5123958b02bfb..233dbbcb6804c 100644 --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -704,3 +704,71 @@ transform.sequence failures(propagate) { // expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}} transform.num_associations %arg0 : (!transform.any_op) -> !transform.param } + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{unresolved matcher symbol @missing_symbol}} + transform.collect_matching @missing_symbol in %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{expected the matcher to take one operation handle argument}} + transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield + } + + transform.named_sequence @matcher() { + transform.yield + } +} + +// ----- + + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{expected the matcher argument to be marked readonly}} + transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield + } + + transform.named_sequence @matcher(%arg0: !transform.any_op) { + transform.yield + } +} + + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{expected the matcher to yield as many values as op has results (1), got 0}} + transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield + } + + transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) { + transform.yield + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{mismatching type interfaces for matcher result and op result #0}} + transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_value + transform.yield + } + + transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.yield %arg0 : !transform.any_op + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index 3bbf875ef309e..4ecd731ce4178 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -2380,3 +2380,47 @@ module @named_inclusion attributes { transform.with_named_sequence } { transform.yield } } + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{result #0, associated with 2 payload objects, expected 1}} + transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield + } + + transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + %0 = transform.merge_handles %arg0, %arg0 : !transform.any_op + transform.yield %0 : !transform.any_op + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-error @below {{unresolved external symbol @matcher}} + transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield + } + + transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + // expected-remark @below {{matched}} + %0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-remark @below {{matched}} + transform.test_print_remark_at_operand %0, "matched" : !transform.any_op + transform.yield + } + + transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.match.operation_name %arg0 ["transform.test_print_remark_at_operand", "transform.collect_matching"] : !transform.any_op + transform.yield %arg0 : !transform.any_op + } +}