Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][transform] Improve error when merging of modules fails. #69331

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ TransformOpInterface findTransformEntryPoint(
// function to clone (or move) `other` in order to improve efficiency.
// This might primarily make sense if we can also prune the symbols that
// are merged to a subset (such as those that are actually used).
LogicalResult mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other);
InFlightDiagnostic mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other);
} // namespace detail

/// Standalone util to apply the named sequence `entryPoint` to the payload.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,11 +335,14 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
diag.attachNote(target->getLoc()) << "pass anchor op";
return diag;
}
if (failed(detail::mergeSymbolsInto(
SymbolTable::getNearestSymbolTable(transformRoot),
transformLibraryModule->get()->clone())))
return emitError(transformRoot->getLoc(),
"failed to merge library symbols into transform root");
InFlightDiagnostic diag = detail::mergeSymbolsInto(
SymbolTable::getNearestSymbolTable(transformRoot),
transformLibraryModule->get()->clone());
if (failed(diag)) {
diag.attachNote(transformRoot->getLoc())
<< "failed to merge library symbols into transform root";
return diag;
}
}

// Step 4
Expand Down
38 changes: 21 additions & 17 deletions mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ LogicalResult transform::detail::assembleTransformLibraryFromPaths(
for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
if (failed(transform::detail::mergeSymbolsInto(
mergedParsedLibraries.get(), std::move(parsedLibrary))))
return mergedParsedLibraries->emitError()
<< "failed to verify merged transform module";
return parsedLibrary->emitError()
<< "failed to merge symbols into shared library module";
}
}

Expand All @@ -196,8 +196,8 @@ static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
/// Merge `func1` into `func2`. The two ops must be inside the same parent op
/// and mergable according to `canMergeInto`. The function erases `func1` such
/// that only `func2` exists when the function returns.
static LogicalResult mergeInto(FunctionOpInterface func1,
FunctionOpInterface func2) {
static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
FunctionOpInterface func2) {
assert(canMergeInto(func1, func2));
assert(func1->getParentOp() == func2->getParentOp() &&
"expected func1 and func2 to be in the same parent op");
Expand Down Expand Up @@ -240,10 +240,10 @@ static LogicalResult mergeInto(FunctionOpInterface func1,
assert(func1.isExternal());
func1->erase();

return success();
return InFlightDiagnostic();
}

LogicalResult
InFlightDiagnostic
transform::detail::mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other) {
assert(target->hasTrait<OpTrait::SymbolTable>() &&
Expand Down Expand Up @@ -300,7 +300,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
auto renameToUnique =
[&](SymbolOpInterface op, SymbolOpInterface otherOp,
SymbolTable &symbolTable,
SymbolTable &otherSymbolTable) -> LogicalResult {
SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
FailureOr<StringAttr> maybeNewName =
symbolTable.renameToUnique(op, {&otherSymbolTable});
Expand All @@ -312,19 +312,21 @@ transform::detail::mergeSymbolsInto(Operation *target,
}
LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
<< "\n");
return success();
return InFlightDiagnostic();
};

if (symbolOp.isPrivate()) {
if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
*otherSymbolTable)))
return failure();
InFlightDiagnostic diag = renameToUnique(
symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
if (failed(diag))
return diag;
continue;
}
if (collidingOp.isPrivate()) {
if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
*symbolTable)))
return failure();
InFlightDiagnostic diag = renameToUnique(
collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
if (failed(diag))
return diag;
continue;
}
LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
Expand Down Expand Up @@ -393,8 +395,10 @@ transform::detail::mergeSymbolsInto(Operation *target,
assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);

// Do the actual merging.
if (failed(mergeInto(funcOp, collidingFuncOp))) {
return failure();
{
InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
if (failed(diag))
return diag;
}
}
}
Expand All @@ -404,7 +408,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
<< "failed to verify target op after merging symbols";

LLVM_DEBUG(DBGS() << "done merging ops\n");
return success();
return InFlightDiagnostic();
}

LogicalResult transform::applyTransformNamedSequence(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has a mismatching signature}}
transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})

// expected-error @below {{failed to merge library symbols into transform root}}
// expected-note @below {{failed to merge library symbols into transform root}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.op<"builtin.module">):
include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
Expand All @@ -33,7 +33,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})

// expected-error @below {{failed to merge library symbols into transform root}}
// expected-note @below {{failed to merge library symbols into transform root}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
Expand All @@ -49,7 +49,7 @@ module attributes {transform.with_named_sequence} {
transform.yield
}

// expected-error @below {{failed to merge library symbols into transform root}}
// expected-note @below {{failed to merge library symbols into transform root}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
Expand Down