Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] introduce transform.collect_matching #76724

Merged
merged 4 commits into from Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 34 additions & 1 deletion mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Expand Up @@ -460,6 +460,39 @@ def NumAssociationsOp : TransformDialectOp<"num_associations",
let hasVerifier = 1;
}

def CollectMatchingOp : TransformDialectOp<"collect_matching", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let summary = "Collects all payload ops that match the given named matcher";
let description = [{
Collects operations or other payload IR objects nested under `root`
(inclusive) that match the given matcher expressed as a named sequence. The
matcher sequence must accept exactly one argument that it is not allowed to
modify. It must yield as many values as this op has results. Each of the
yielded values must be associated with exactly one payload object. If any
operation in the matcher sequence produces a silenceable failure, the
matcher advances to the next payload operation in the walk order without
finishing the sequence.

The i-th result of this operation is constructed by concatenating the i-th
yielded payload IR objects of all successful matcher sequence applications.
All results are guaranteed to be mapped to the same number of payload IR
objects.

The operation succeeds unless the matcher sequence produced a definite
failure for any invocation.
}];

let arguments = (ins TransformHandleTypeInterface:$root,
SymbolRefAttr:$matcher);
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);

let assemblyFormat = [{
$matcher `in` $root attr-dict `:` functional-type($root, $results)
}];
}

def ForeachMatchOp : TransformDialectOp<"foreach_match", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
Expand Down Expand Up @@ -674,7 +707,7 @@ def GetParentOp : TransformDialectOp<"get_parent_op",

def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
let summary = "Get handle to the producer of this operation's operand number";
let description = [{
The handle defined by this Transform op corresponds to operation that
Expand Down
150 changes: 133 additions & 17 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Expand Up @@ -22,6 +22,7 @@
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
Expand Down Expand Up @@ -783,7 +784,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}

//===----------------------------------------------------------------------===//
// ForeachMatchOp
// CollectMatchingOp
//===----------------------------------------------------------------------===//

/// Applies matcher operations from the given `block` assigning `op` as the
Expand Down Expand Up @@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
return DiagnosedSilenceableFailure::success();
}

/// Returns `true` if both types implement one of the interfaces provided as
/// template parameters.
template <typename... Tys>
static bool implementSameInterface(Type t1, Type t2) {
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
}

/// Returns `true` if both types implement one of the transform dialect
/// interfaces.
static bool implementSameTransformInterface(Type t1, Type t2) {
return implementSameInterface<transform::TransformHandleTypeInterface,
transform::TransformParamTypeInterface,
transform::TransformValueHandleTypeInterface>(
t1, t2);
}

//===----------------------------------------------------------------------===//
// CollectMatchingOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
getOperation(), getMatcher());
if (matcher.isExternal()) {
return emitDefiniteFailure()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be verified in the op verifier?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, at verification time we allow external symbols. They may be "resolved" later by merging modules or by providing a library module to the interpreter.

<< "unresolved external symbol " << getMatcher();
}

SmallVector<SmallVector<MappedValue>, 2> rawResults;
rawResults.resize(getOperation()->getNumResults());
std::optional<DiagnosedSilenceableFailure> maybeFailure;
for (Operation *root : state.getPayloadOps(getRoot())) {
WalkResult walkResult = root->walk([&](Operation *op) {
DEBUG_MATCHER({
DBGS_MATCHER() << "matching ";
op->print(llvm::dbgs(),
OpPrintingFlags().assumeVerified().skipRegions());
llvm::dbgs() << " @" << op << "\n";
});

// Try matching.
SmallVector<SmallVector<MappedValue>> mappings;
DiagnosedSilenceableFailure diag =
matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FMI: what is the use case for matchers returning definite failures?
Can they modify IR and leave it in an undefined state?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally, match operations should not. But we don't internally differentiate match ops from other transform ops internally right now.

if (diag.isDefiniteFailure())
return WalkResult::interrupt();
if (diag.isSilenceableFailure()) {
DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
<< " failed: " << diag.getMessage());
return WalkResult::advance();
}

// If succeeded, collect results.
for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
if (mapping.size() != 1) {
maybeFailure.emplace(emitSilenceableError()
<< "result #" << i << ", associated with "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds like incorrect usage of the transform op. Why is this not a definite error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me, definite errors are a graceful equivalent of abort(). They mean that the payload IR was modified in an irrecoverable way and we should immediately stop the interpreter, and the parent pass. This is fully recoverable. The surrounding op (usually sequence) may choose to suppress the error or report it.

<< mapping.size()
<< " payload objects, expected 1");
return WalkResult::interrupt();
}
rawResults[i].push_back(mapping[0]);
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
return std::move(*maybeFailure);
assert(!maybeFailure && "failure set but the walk was not interrupted");

for (auto &&[opResult, rawResult] :
llvm::zip_equal(getOperation()->getResults(), rawResults)) {
results.setMappedValues(opResult, rawResult);
}
}
return DiagnosedSilenceableFailure::success();
}

void transform::CollectMatchingOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getRoot(), effects);
producesHandle(getResults(), effects);
onlyReadsPayload(effects);
}

LogicalResult transform::CollectMatchingOp::verifySymbolUses(
SymbolTableCollection &symbolTable) {
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
if (!matcherSymbol ||
!isa<TransformOpInterface>(matcherSymbol.getOperation()))
return emitError() << "unresolved matcher symbol " << getMatcher();

ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
if (argumentTypes.size() != 1 ||
!isa<TransformHandleTypeInterface>(argumentTypes[0])) {
return emitError()
<< "expected the matcher to take one operation handle argument";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value handles are also allowed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. TransformHandleTypeInterface has always referred to operation handles. Maybe we should rename it eventually.

}
if (!matcherSymbol.getArgAttr(
0, transform::TransformDialect::kArgReadOnlyAttrName)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an easy way to check that matcher does not modify the payload?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The presence of this attribute triggers the verifier on named sequence to check that ops nested in the named sequence don't have a write effect on the payload IR. The effect specification may be wrong, but there is little we can do in that case, other than cloning the payload before and comparing that indeed nothing has changed under an -enable-even-more-expensive checks flag.

return emitError() << "expected the matcher argument to be marked readonly";
}

ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
if (resultTypes.size() != getOperation()->getNumResults()) {
return emitError()
<< "expected the matcher to yield as many values as op has results ("
<< getOperation()->getNumResults() << "), got "
<< resultTypes.size();
}

for (auto &&[i, matcherType, resultType] :
llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
if (implementSameTransformInterface(matcherType, resultType))
continue;

return emitError()
<< "mismatching type interfaces for matcher result and op result #"
<< i;
}

return success();
}

//===----------------------------------------------------------------------===//
// ForeachMatchOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
Expand Down Expand Up @@ -978,22 +1110,6 @@ LogicalResult transform::ForeachMatchOp::verify() {
return success();
}

/// Returns `true` if both types implement one of the interfaces provided as
/// template parameters.
template <typename... Tys>
static bool implementSameInterface(Type t1, Type t2) {
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
}

/// Returns `true` if both types implement one of the transform dialect
/// interfaces.
static bool implementSameTransformInterface(Type t1, Type t2) {
return implementSameInterface<transform::TransformHandleTypeInterface,
transform::TransformParamTypeInterface,
transform::TransformValueHandleTypeInterface>(
t1, t2);
}

/// Checks that the attributes of the function-like operation have correct
/// consumption effect annotations. If `alsoVerifyInternal`, checks for
/// annotations being present even if they can be inferred from the body.
Expand Down
68 changes: 68 additions & 0 deletions mlir/test/Dialect/Transform/ops-invalid.mlir
Expand Up @@ -704,3 +704,71 @@ transform.sequence failures(propagate) {
// expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}}
transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32>
}

// -----

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{unresolved matcher symbol @missing_symbol}}
transform.collect_matching @missing_symbol in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{expected the matcher to take one operation handle argument}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}

transform.named_sequence @matcher() {
transform.yield
}
}

// -----


module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{expected the matcher argument to be marked readonly}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}

transform.named_sequence @matcher(%arg0: !transform.any_op) {
transform.yield
}
}


// -----

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{expected the matcher to yield as many values as op has results (1), got 0}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}

transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) {
transform.yield
}
}

// -----

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{mismatching type interfaces for matcher result and op result #0}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_value
transform.yield
}

transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
transform.yield %arg0 : !transform.any_op
}
}
44 changes: 44 additions & 0 deletions mlir/test/Dialect/Transform/test-interpreter.mlir
Expand Up @@ -2380,3 +2380,47 @@ module @named_inclusion attributes { transform.with_named_sequence } {
transform.yield
}
}

// -----
ftynse marked this conversation as resolved.
Show resolved Hide resolved

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{result #0, associated with 2 payload objects, expected 1}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}

transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
%0 = transform.merge_handles %arg0, %arg0 : !transform.any_op
transform.yield %0 : !transform.any_op
}
}

// -----

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{unresolved external symbol @matcher}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}

transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
}

// -----

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-remark @below {{matched}}
%0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
// expected-remark @below {{matched}}
transform.test_print_remark_at_operand %0, "matched" : !transform.any_op
transform.yield
}

transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
transform.match.operation_name %arg0 ["transform.test_print_remark_at_operand", "transform.collect_matching"] : !transform.any_op
transform.yield %arg0 : !transform.any_op
}
}