New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] introduce transform.collect_matching #76724
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
#include "mlir/IR/Verifier.h" | ||
#include "mlir/Interfaces/ControlFlowInterfaces.h" | ||
#include "mlir/Interfaces/FunctionImplementation.h" | ||
#include "mlir/Interfaces/FunctionInterfaces.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Pass/PassManager.h" | ||
#include "mlir/Pass/PassRegistry.h" | ||
|
@@ -783,7 +784,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { | |
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ForeachMatchOp | ||
// CollectMatchingOp | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// Applies matcher operations from the given `block` assigning `op` as the | ||
|
@@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state, | |
return DiagnosedSilenceableFailure::success(); | ||
} | ||
|
||
/// Returns `true` if both types implement one of the interfaces provided as | ||
/// template parameters. | ||
template <typename... Tys> | ||
static bool implementSameInterface(Type t1, Type t2) { | ||
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false); | ||
} | ||
|
||
/// Returns `true` if both types implement one of the transform dialect | ||
/// interfaces. | ||
static bool implementSameTransformInterface(Type t1, Type t2) { | ||
return implementSameInterface<transform::TransformHandleTypeInterface, | ||
transform::TransformParamTypeInterface, | ||
transform::TransformValueHandleTypeInterface>( | ||
t1, t2); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// CollectMatchingOp | ||
//===----------------------------------------------------------------------===// | ||
|
||
DiagnosedSilenceableFailure | ||
transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, | ||
transform::TransformResults &results, | ||
transform::TransformState &state) { | ||
auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>( | ||
getOperation(), getMatcher()); | ||
if (matcher.isExternal()) { | ||
return emitDefiniteFailure() | ||
<< "unresolved external symbol " << getMatcher(); | ||
} | ||
|
||
SmallVector<SmallVector<MappedValue>, 2> rawResults; | ||
rawResults.resize(getOperation()->getNumResults()); | ||
std::optional<DiagnosedSilenceableFailure> maybeFailure; | ||
for (Operation *root : state.getPayloadOps(getRoot())) { | ||
WalkResult walkResult = root->walk([&](Operation *op) { | ||
DEBUG_MATCHER({ | ||
DBGS_MATCHER() << "matching "; | ||
op->print(llvm::dbgs(), | ||
OpPrintingFlags().assumeVerified().skipRegions()); | ||
llvm::dbgs() << " @" << op << "\n"; | ||
}); | ||
|
||
// Try matching. | ||
SmallVector<SmallVector<MappedValue>> mappings; | ||
DiagnosedSilenceableFailure diag = | ||
matchBlock(matcher.getFunctionBody().front(), op, state, mappings); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FMI: what is the use case for matchers returning definite failures? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Normally, match operations should not. But we don't internally differentiate match ops from other transform ops internally right now. |
||
if (diag.isDefiniteFailure()) | ||
return WalkResult::interrupt(); | ||
if (diag.isSilenceableFailure()) { | ||
DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() | ||
<< " failed: " << diag.getMessage()); | ||
return WalkResult::advance(); | ||
} | ||
|
||
// If succeeded, collect results. | ||
for (auto &&[i, mapping] : llvm::enumerate(mappings)) { | ||
if (mapping.size() != 1) { | ||
maybeFailure.emplace(emitSilenceableError() | ||
<< "result #" << i << ", associated with " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This sounds like incorrect usage of the transform op. Why is this not a definite error? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For me, definite errors are a graceful equivalent of |
||
<< mapping.size() | ||
<< " payload objects, expected 1"); | ||
return WalkResult::interrupt(); | ||
} | ||
rawResults[i].push_back(mapping[0]); | ||
} | ||
return WalkResult::advance(); | ||
}); | ||
if (walkResult.wasInterrupted()) | ||
return std::move(*maybeFailure); | ||
assert(!maybeFailure && "failure set but the walk was not interrupted"); | ||
|
||
for (auto &&[opResult, rawResult] : | ||
llvm::zip_equal(getOperation()->getResults(), rawResults)) { | ||
results.setMappedValues(opResult, rawResult); | ||
} | ||
} | ||
return DiagnosedSilenceableFailure::success(); | ||
} | ||
|
||
void transform::CollectMatchingOp::getEffects( | ||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | ||
onlyReadsHandle(getRoot(), effects); | ||
producesHandle(getResults(), effects); | ||
onlyReadsPayload(effects); | ||
} | ||
|
||
LogicalResult transform::CollectMatchingOp::verifySymbolUses( | ||
SymbolTableCollection &symbolTable) { | ||
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( | ||
symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher())); | ||
if (!matcherSymbol || | ||
!isa<TransformOpInterface>(matcherSymbol.getOperation())) | ||
return emitError() << "unresolved matcher symbol " << getMatcher(); | ||
|
||
ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes(); | ||
if (argumentTypes.size() != 1 || | ||
!isa<TransformHandleTypeInterface>(argumentTypes[0])) { | ||
return emitError() | ||
<< "expected the matcher to take one operation handle argument"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. value handles are also allowed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No. |
||
} | ||
if (!matcherSymbol.getArgAttr( | ||
0, transform::TransformDialect::kArgReadOnlyAttrName)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there an easy way to check that matcher does not modify the payload? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The presence of this attribute triggers the verifier on named sequence to check that ops nested in the named sequence don't have a write effect on the payload IR. The effect specification may be wrong, but there is little we can do in that case, other than cloning the payload before and comparing that indeed nothing has changed under an -enable-even-more-expensive checks flag. |
||
return emitError() << "expected the matcher argument to be marked readonly"; | ||
} | ||
|
||
ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes(); | ||
if (resultTypes.size() != getOperation()->getNumResults()) { | ||
return emitError() | ||
<< "expected the matcher to yield as many values as op has results (" | ||
<< getOperation()->getNumResults() << "), got " | ||
<< resultTypes.size(); | ||
} | ||
|
||
for (auto &&[i, matcherType, resultType] : | ||
llvm::enumerate(resultTypes, getOperation()->getResultTypes())) { | ||
if (implementSameTransformInterface(matcherType, resultType)) | ||
continue; | ||
|
||
return emitError() | ||
<< "mismatching type interfaces for matcher result and op result #" | ||
<< i; | ||
} | ||
|
||
return success(); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ForeachMatchOp | ||
//===----------------------------------------------------------------------===// | ||
|
||
DiagnosedSilenceableFailure | ||
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, | ||
transform::TransformResults &results, | ||
|
@@ -978,22 +1110,6 @@ LogicalResult transform::ForeachMatchOp::verify() { | |
return success(); | ||
} | ||
|
||
/// Returns `true` if both types implement one of the interfaces provided as | ||
/// template parameters. | ||
template <typename... Tys> | ||
static bool implementSameInterface(Type t1, Type t2) { | ||
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false); | ||
} | ||
|
||
/// Returns `true` if both types implement one of the transform dialect | ||
/// interfaces. | ||
static bool implementSameTransformInterface(Type t1, Type t2) { | ||
return implementSameInterface<transform::TransformHandleTypeInterface, | ||
transform::TransformParamTypeInterface, | ||
transform::TransformValueHandleTypeInterface>( | ||
t1, t2); | ||
} | ||
|
||
/// Checks that the attributes of the function-like operation have correct | ||
/// consumption effect annotations. If `alsoVerifyInternal`, checks for | ||
/// annotations being present even if they can be inferred from the body. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this be verified in the op verifier?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, at verification time we allow external symbols. They may be "resolved" later by merging modules or by providing a library module to the interpreter.