Skip to content

Commit

Permalink
[mlir] introduce transform.collect_matching (#76724)
Browse files Browse the repository at this point in the history
Introduce a new match combinator into the transform dialect. This
operation collects all operations that are yielded by a satisfactory
match into its results. This is a simpler version of `foreach_match`
that can be inserted directly into existing transform scripts.
  • Loading branch information
ftynse committed Jan 9, 2024
1 parent 4f7c402 commit 633d918
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 18 deletions.
35 changes: 34 additions & 1 deletion mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Expand Up @@ -460,6 +460,39 @@ def NumAssociationsOp : TransformDialectOp<"num_associations",
let hasVerifier = 1;
}

def CollectMatchingOp : TransformDialectOp<"collect_matching", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let summary = "Collects all payload ops that match the given named matcher";
let description = [{
Collects operations or other payload IR objects nested under `root`
(inclusive) that match the given matcher expressed as a named sequence. The
matcher sequence must accept exactly one argument that it is not allowed to
modify. It must yield as many values as this op has results. Each of the
yielded values must be associated with exactly one payload object. If any
operation in the matcher sequence produces a silenceable failure, the
matcher advances to the next payload operation in the walk order without
finishing the sequence.

The i-th result of this operation is constructed by concatenating the i-th
yielded payload IR objects of all successful matcher sequence applications.
All results are guaranteed to be mapped to the same number of payload IR
objects.

The operation succeeds unless the matcher sequence produced a definite
failure for any invocation.
}];

let arguments = (ins TransformHandleTypeInterface:$root,
SymbolRefAttr:$matcher);
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);

let assemblyFormat = [{
$matcher `in` $root attr-dict `:` functional-type($root, $results)
}];
}

def ForeachMatchOp : TransformDialectOp<"foreach_match", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
Expand Down Expand Up @@ -674,7 +707,7 @@ def GetParentOp : TransformDialectOp<"get_parent_op",

def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
let summary = "Get handle to the producer of this operation's operand number";
let description = [{
The handle defined by this Transform op corresponds to operation that
Expand Down
150 changes: 133 additions & 17 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Expand Up @@ -22,6 +22,7 @@
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
Expand Down Expand Up @@ -783,7 +784,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}

//===----------------------------------------------------------------------===//
// ForeachMatchOp
// CollectMatchingOp
//===----------------------------------------------------------------------===//

/// Applies matcher operations from the given `block` assigning `op` as the
Expand Down Expand Up @@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
return DiagnosedSilenceableFailure::success();
}

/// Returns `true` if both types implement one of the interfaces provided as
/// template parameters.
template <typename... Tys>
static bool implementSameInterface(Type t1, Type t2) {
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
}

/// Returns `true` if both types implement one of the transform dialect
/// interfaces.
static bool implementSameTransformInterface(Type t1, Type t2) {
return implementSameInterface<transform::TransformHandleTypeInterface,
transform::TransformParamTypeInterface,
transform::TransformValueHandleTypeInterface>(
t1, t2);
}

//===----------------------------------------------------------------------===//
// CollectMatchingOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
getOperation(), getMatcher());
if (matcher.isExternal()) {
return emitDefiniteFailure()
<< "unresolved external symbol " << getMatcher();
}

SmallVector<SmallVector<MappedValue>, 2> rawResults;
rawResults.resize(getOperation()->getNumResults());
std::optional<DiagnosedSilenceableFailure> maybeFailure;
for (Operation *root : state.getPayloadOps(getRoot())) {
WalkResult walkResult = root->walk([&](Operation *op) {
DEBUG_MATCHER({
DBGS_MATCHER() << "matching ";
op->print(llvm::dbgs(),
OpPrintingFlags().assumeVerified().skipRegions());
llvm::dbgs() << " @" << op << "\n";
});

// Try matching.
SmallVector<SmallVector<MappedValue>> mappings;
DiagnosedSilenceableFailure diag =
matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
if (diag.isDefiniteFailure())
return WalkResult::interrupt();
if (diag.isSilenceableFailure()) {
DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
<< " failed: " << diag.getMessage());
return WalkResult::advance();
}

// If succeeded, collect results.
for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
if (mapping.size() != 1) {
maybeFailure.emplace(emitSilenceableError()
<< "result #" << i << ", associated with "
<< mapping.size()
<< " payload objects, expected 1");
return WalkResult::interrupt();
}
rawResults[i].push_back(mapping[0]);
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
return std::move(*maybeFailure);
assert(!maybeFailure && "failure set but the walk was not interrupted");

for (auto &&[opResult, rawResult] :
llvm::zip_equal(getOperation()->getResults(), rawResults)) {
results.setMappedValues(opResult, rawResult);
}
}
return DiagnosedSilenceableFailure::success();
}

void transform::CollectMatchingOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getRoot(), effects);
producesHandle(getResults(), effects);
onlyReadsPayload(effects);
}

LogicalResult transform::CollectMatchingOp::verifySymbolUses(
SymbolTableCollection &symbolTable) {
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
if (!matcherSymbol ||
!isa<TransformOpInterface>(matcherSymbol.getOperation()))
return emitError() << "unresolved matcher symbol " << getMatcher();

ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
if (argumentTypes.size() != 1 ||
!isa<TransformHandleTypeInterface>(argumentTypes[0])) {
return emitError()
<< "expected the matcher to take one operation handle argument";
}
if (!matcherSymbol.getArgAttr(
0, transform::TransformDialect::kArgReadOnlyAttrName)) {
return emitError() << "expected the matcher argument to be marked readonly";
}

ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
if (resultTypes.size() != getOperation()->getNumResults()) {
return emitError()
<< "expected the matcher to yield as many values as op has results ("
<< getOperation()->getNumResults() << "), got "
<< resultTypes.size();
}

for (auto &&[i, matcherType, resultType] :
llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
if (implementSameTransformInterface(matcherType, resultType))
continue;

return emitError()
<< "mismatching type interfaces for matcher result and op result #"
<< i;
}

return success();
}

//===----------------------------------------------------------------------===//
// ForeachMatchOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
Expand Down Expand Up @@ -978,22 +1110,6 @@ LogicalResult transform::ForeachMatchOp::verify() {
return success();
}

/// Returns `true` if both types implement one of the interfaces provided as
/// template parameters.
template <typename... Tys>
static bool implementSameInterface(Type t1, Type t2) {
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
}

/// Returns `true` if both types implement one of the transform dialect
/// interfaces.
static bool implementSameTransformInterface(Type t1, Type t2) {
return implementSameInterface<transform::TransformHandleTypeInterface,
transform::TransformParamTypeInterface,
transform::TransformValueHandleTypeInterface>(
t1, t2);
}

/// Checks that the attributes of the function-like operation have correct
/// consumption effect annotations. If `alsoVerifyInternal`, checks for
/// annotations being present even if they can be inferred from the body.
Expand Down
68 changes: 68 additions & 0 deletions mlir/test/Dialect/Transform/ops-invalid.mlir
Expand Up @@ -704,3 +704,71 @@ transform.sequence failures(propagate) {
// expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}}
transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32>
}

// -----

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{unresolved matcher symbol @missing_symbol}}
transform.collect_matching @missing_symbol in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{expected the matcher to take one operation handle argument}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}

transform.named_sequence @matcher() {
transform.yield
}
}

// -----


module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{expected the matcher argument to be marked readonly}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}

transform.named_sequence @matcher(%arg0: !transform.any_op) {
transform.yield
}
}


// -----

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{expected the matcher to yield as many values as op has results (1), got 0}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}

transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) {
transform.yield
}
}

// -----

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{mismatching type interfaces for matcher result and op result #0}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_value
transform.yield
}

transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
transform.yield %arg0 : !transform.any_op
}
}
44 changes: 44 additions & 0 deletions mlir/test/Dialect/Transform/test-interpreter.mlir
Expand Up @@ -2380,3 +2380,47 @@ module @named_inclusion attributes { transform.with_named_sequence } {
transform.yield
}
}

// -----

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{result #0, associated with 2 payload objects, expected 1}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}

transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
%0 = transform.merge_handles %arg0, %arg0 : !transform.any_op
transform.yield %0 : !transform.any_op
}
}

// -----

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{unresolved external symbol @matcher}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}

transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
}

// -----

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-remark @below {{matched}}
%0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
// expected-remark @below {{matched}}
transform.test_print_remark_at_operand %0, "matched" : !transform.any_op
transform.yield
}

transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
transform.match.operation_name %arg0 ["transform.test_print_remark_at_operand", "transform.collect_matching"] : !transform.any_op
transform.yield %arg0 : !transform.any_op
}
}

0 comments on commit 633d918

Please sign in to comment.