Skip to content

Commit

Permalink
[mlir][transform] Improve error when merging of modules fails. (#69331)
Browse files Browse the repository at this point in the history
This resolved #69112.
  • Loading branch information
ingomueller-net committed Oct 24, 2023
1 parent 1a7061c commit f07718b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ TransformOpInterface findTransformEntryPoint(
// function to clone (or move) `other` in order to improve efficiency.
// This might primarily make sense if we can also prune the symbols that
// are merged to a subset (such as those that are actually used).
LogicalResult mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other);
InFlightDiagnostic mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other);
} // namespace detail

/// Standalone util to apply the named sequence `transformRoot` to `payload` IR.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,14 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
diag.attachNote(target->getLoc()) << "pass anchor op";
return diag;
}
if (failed(detail::mergeSymbolsInto(
SymbolTable::getNearestSymbolTable(transformRoot),
transformLibraryModule->get()->clone())))
return emitError(transformRoot->getLoc(),
"failed to merge library symbols into transform root");
InFlightDiagnostic diag = detail::mergeSymbolsInto(
SymbolTable::getNearestSymbolTable(transformRoot),
transformLibraryModule->get()->clone());
if (failed(diag)) {
diag.attachNote(transformRoot->getLoc())
<< "failed to merge library symbols into transform root";
return diag;
}
}

// Step 4
Expand Down
38 changes: 21 additions & 17 deletions mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ LogicalResult transform::detail::assembleTransformLibraryFromPaths(
for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
if (failed(transform::detail::mergeSymbolsInto(
mergedParsedLibraries.get(), std::move(parsedLibrary))))
return mergedParsedLibraries->emitError()
<< "failed to verify merged transform module";
return parsedLibrary->emitError()
<< "failed to merge symbols into shared library module";
}
}

Expand All @@ -197,8 +197,8 @@ static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
/// Merge `func1` into `func2`. The two ops must be inside the same parent op
/// and mergable according to `canMergeInto`. The function erases `func1` such
/// that only `func2` exists when the function returns.
static LogicalResult mergeInto(FunctionOpInterface func1,
FunctionOpInterface func2) {
static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
FunctionOpInterface func2) {
assert(canMergeInto(func1, func2));
assert(func1->getParentOp() == func2->getParentOp() &&
"expected func1 and func2 to be in the same parent op");
Expand Down Expand Up @@ -241,10 +241,10 @@ static LogicalResult mergeInto(FunctionOpInterface func1,
assert(func1.isExternal());
func1->erase();

return success();
return InFlightDiagnostic();
}

LogicalResult
InFlightDiagnostic
transform::detail::mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other) {
assert(target->hasTrait<OpTrait::SymbolTable>() &&
Expand Down Expand Up @@ -301,7 +301,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
auto renameToUnique =
[&](SymbolOpInterface op, SymbolOpInterface otherOp,
SymbolTable &symbolTable,
SymbolTable &otherSymbolTable) -> LogicalResult {
SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
FailureOr<StringAttr> maybeNewName =
symbolTable.renameToUnique(op, {&otherSymbolTable});
Expand All @@ -313,19 +313,21 @@ transform::detail::mergeSymbolsInto(Operation *target,
}
LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
<< "\n");
return success();
return InFlightDiagnostic();
};

if (symbolOp.isPrivate()) {
if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
*otherSymbolTable)))
return failure();
InFlightDiagnostic diag = renameToUnique(
symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
if (failed(diag))
return diag;
continue;
}
if (collidingOp.isPrivate()) {
if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
*symbolTable)))
return failure();
InFlightDiagnostic diag = renameToUnique(
collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
if (failed(diag))
return diag;
continue;
}
LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
Expand Down Expand Up @@ -394,8 +396,10 @@ transform::detail::mergeSymbolsInto(Operation *target,
assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);

// Do the actual merging.
if (failed(mergeInto(funcOp, collidingFuncOp))) {
return failure();
{
InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
if (failed(diag))
return diag;
}
}
}
Expand All @@ -405,7 +409,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
<< "failed to verify target op after merging symbols";

LLVM_DEBUG(DBGS() << "done merging ops\n");
return success();
return InFlightDiagnostic();
}

LogicalResult transform::applyTransformNamedSequence(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has a mismatching signature}}
transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})

// expected-error @below {{failed to merge library symbols into transform root}}
// expected-note @below {{failed to merge library symbols into transform root}}
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.op<"builtin.module">):
include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
Expand All @@ -33,7 +33,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})

// expected-error @below {{failed to merge library symbols into transform root}}
// expected-note @below {{failed to merge library symbols into transform root}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
Expand All @@ -49,7 +49,7 @@ module attributes {transform.with_named_sequence} {
transform.yield
}

// expected-error @below {{failed to merge library symbols into transform root}}
// expected-note @below {{failed to merge library symbols into transform root}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
Expand Down

0 comments on commit f07718b

Please sign in to comment.