Skip to content

Commit

Permalink
[mlir] make DiagnosedSilenceableError(LogicalResult) ctor private
Browse files Browse the repository at this point in the history
Now we have more convenient functions to construct silenceable errors
while emitting diagnostics, and the constructor is ambiguous as it
doesn't tell whether the logical error is silencebale or definite.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D137257
  • Loading branch information
ftynse committed Dec 12, 2022
1 parent 843be73 commit 7d5bef7
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 51 deletions.
Expand Up @@ -34,7 +34,6 @@ namespace mlir {
/// failures as their diagnostics have been already reported to the user.
class [[nodiscard]] DiagnosedSilenceableFailure {
public:
explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
DiagnosedSilenceableFailure(const DiagnosedSilenceableFailure &) = delete;
DiagnosedSilenceableFailure &
operator=(const DiagnosedSilenceableFailure &) = delete;
Expand Down Expand Up @@ -156,6 +155,7 @@ class [[nodiscard]] DiagnosedSilenceableFailure {
}

private:
explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic)
: result(failure()) {
diagnostics.emplace_back(std::move(diagnostic));
Expand Down
24 changes: 12 additions & 12 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
Expand Up @@ -51,23 +51,12 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
];

let extraSharedClassDeclaration = [{
/// Emits a generic transform error for the current transform operation
/// targeting the given Payload IR operation and returns failure. Should
/// be only used as a last resort when the transformation itself provides
/// no further indication as to the reason of the failure.
::mlir::LogicalResult reportUnknownTransformError(
::mlir::Operation *target) {
::mlir::InFlightDiagnostic diag = $_op->emitError() << "failed to apply";
diag.attachNote(target->getLoc()) << "attempted to apply to this op";
return diag;
}

/// Creates the silenceable failure object with a diagnostic located at the
/// current operation. Silenceable failure must be suppressed or reported
/// explicitly at some later time.
DiagnosedSilenceableFailure
emitSilenceableError(const ::llvm::Twine &message = {}) {
return ::mlir::emitSilenceableFailure($_op);
return ::mlir::emitSilenceableFailure($_op, message);
}

/// Creates the definite failure object with a diagnostic located at the
Expand All @@ -78,6 +67,17 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
return ::mlir::emitDefiniteFailure($_op, message);
}

/// Emits a generic definite failure for the current transform operation
/// targeting the given Payload IR operation and returns failure. Should
/// be only used as a last resort when the transformation itself provides
/// no further indication as to the reason of the failure.
DiagnosedDefiniteFailure emitDefaultDefiniteFailure(
::mlir::Operation *target) {
auto diag = ::mlir::emitDefiniteFailure($_op, "failed to apply");
diag.attachNote(target->getLoc()) << "attempted to apply to this op";
return diag;
}

/// Creates the default silenceable failure for a transform op that failed
/// to properly apply to a target.
DiagnosedSilenceableFailure emitDefaultSilenceableFailure(
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
Expand Up @@ -119,7 +119,7 @@ createGpuLaunch(RewriterBase &rewriter, Location loc,
blkSizeX, blkSizeY, blkSizeZ);
rewriter.setInsertionPointToEnd(&launchOp.getBody().front());
rewriter.create<TerminatorOp>(loc);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

/// Alter kernel configuration of the given kernel.
Expand Down
58 changes: 29 additions & 29 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Expand Up @@ -79,20 +79,20 @@ transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
Conv1DNwcWcfOp>>(target);
if (succeeded(windowedNhwc)) {
results.push_back(*windowedNhwc);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}
FailureOr<LinalgOp> windowedNchw =
tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
Conv1DNcwFcwOp>>(target);
if (succeeded(windowedNchw)) {
results.push_back(*windowedNchw);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}
FailureOr<LinalgOp> depthwise =
tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
if (succeeded(depthwise)) {
results.push_back(*depthwise);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target);
Expand Down Expand Up @@ -206,7 +206,8 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
rewriter, tilingInterfaceOp, tileAndFuseOptions);
});
return DiagnosedSilenceableFailure(result);
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
}

ParseResult transform::FuseOp::parse(OpAsmParser &parser,
Expand Down Expand Up @@ -568,12 +569,12 @@ transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
// Exit early if no transformation is needed.
if (isa<GenericOp>(target)) {
results.push_back(target);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}
FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
if (succeeded(generic)) {
results.push_back(generic->getOperation());
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target);
Expand All @@ -592,15 +593,15 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target,
// Exit early if no transformation is needed.
if (interchangeVector.empty()) {
results.push_back(target);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}
TrivialPatternRewriter rewriter(target->getContext());
FailureOr<GenericOp> res =
interchangeGenericOp(rewriter, target, interchangeVector);
if (failed(res))
return DiagnosedSilenceableFailure::definiteFailure();
results.push_back(res->getOperation());
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::InterchangeOp::verify() {
Expand Down Expand Up @@ -639,8 +640,7 @@ transform::MatchOp::apply(transform::TransformResults &results,
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
if (payloadOps.size() != 1) {
results.set(getResult().cast<OpResult>(), {});
return DiagnosedSilenceableFailure(
this->emitOpError("requires exactly one target handle"));
return emitDefiniteFailure("requires exactly one target handle");
}

SmallVector<Operation *> res;
Expand Down Expand Up @@ -687,7 +687,7 @@ transform::MatchOp::apply(transform::TransformResults &results,

payloadOps.front()->walk(matchFun);
results.set(getResult().cast<OpResult>(), res);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

//===---------------------------------------------------------------------===//
Expand Down Expand Up @@ -792,7 +792,7 @@ transform::PadOp::applyToOne(linalg::LinalgOp target,
tryApply<LinalgPaddingPattern>(target, paddingOptions);
if (succeeded(result)) {
results.push_back(result->getOperation());
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

results.assign(1, nullptr);
Expand Down Expand Up @@ -866,15 +866,15 @@ transform::PromoteOp::applyToOne(linalg::LinalgOp target,
promotionOptions = promotionOptions.setAlignment(*getAlignment());

if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
return emitDefaultDefiniteFailure(target);

TrivialPatternRewriter rewriter(target->getContext());
rewriter.setInsertionPoint(target);
FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
if (failed(res))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
return emitDefaultDefiniteFailure(target);
results.push_back(target);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -909,7 +909,7 @@ transform::ReplaceOp::apply(TransformResults &transformResults,
replacements.push_back(replacement);
}
transformResults.set(getReplacement().cast<OpResult>(), replacements);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

void transform::ReplaceOp::getEffects(
Expand Down Expand Up @@ -972,10 +972,10 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
if (failed(maybeTilingResult))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
return emitDefaultDefiniteFailure(target);

results.append(maybeTilingResult->tiledOps);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1171,13 +1171,13 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
: splitReduction(rewriter, target, splitFn, getUseAlloc());
if (failed(splitResult))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
return emitDefaultDefiniteFailure(target);

results.push_back(splitResult->initOrAlloc);
results.push_back(splitResult->fillOp);
results.push_back(splitResult->splitLinalgOp);
results.push_back(splitResult->resultCombiningLinalgOp);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
Expand All @@ -1200,12 +1200,12 @@ DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
sizes);

if (failed(result))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
return emitDefaultSilenceableFailure(target);
results.push_back(result->loops.front());
results.push_back(result->initialOp);
results.push_back(result->parallelTiledOp);
results.push_back(result->mergeOp);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1235,7 +1235,7 @@ transform::TileReductionUsingForeachThreadOp::applyToOne(
results.push_back(result->initialOp);
results.push_back(result->parallelTiledOp);
results.push_back(result->mergeOp);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1523,7 +1523,7 @@ static DiagnosedSilenceableFailure unpackPDLOperations(
}
}

return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
Expand All @@ -1533,7 +1533,7 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> mapping,
SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
if (targets.empty())
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();

// getMixedNumThreads are OpFoldResults[index attributes or PDL operation].
// Convert to OpFoldResults[index attributes or payload op].
Expand Down Expand Up @@ -1577,7 +1577,7 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
tileOps.push_back(tilingResult->tileOp);
tiledOps.push_back(tilingResult->tiledOp);
}
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
Expand All @@ -1604,7 +1604,7 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);

return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

void transform::TileToForeachThreadOp::getEffects(
Expand Down Expand Up @@ -1852,10 +1852,10 @@ transform::VectorizeOp::applyToOne(Operation *target,
linalg::populatePadOpVectorizationPatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
return emitDefaultDefiniteFailure(target);

results.push_back(target);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
Expand Down
Expand Up @@ -33,7 +33,7 @@ transform::MemRefMultiBufferOp::applyToOne(memref::AllocOp target,
}

results.push_back(newBuffer.value());
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
Expand Down
12 changes: 5 additions & 7 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Expand Up @@ -103,10 +103,8 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
rewriter, location, exec.getRegion(), getFuncName(), &call);

if (failed(outlined)) {
(void)reportUnknownTransformError(target);
return DiagnosedSilenceableFailure::definiteFailure();
}
if (failed(outlined))
return emitDefaultDefiniteFailure(target);

if (symbolTableOp) {
SymbolTable &symbolTable =
Expand Down Expand Up @@ -139,7 +137,7 @@ transform::LoopPeelOp::applyToOne(scf::ForOp target,
scf::peelAndCanonicalizeForLoop(rewriter, target, result);
// TODO: Return both the peeled loop and the remainder loop.
results.push_back(failed(status) ? target : result);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -200,7 +198,7 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target,
pattern.returningMatchAndRewrite(target, rewriter);
if (succeeded(patternResult)) {
results.push_back(*patternResult);
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target);
Expand All @@ -225,7 +223,7 @@ transform::LoopUnrollOp::applyToOne(Operation *op,
diag << "Op failed to unroll";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
return DiagnosedSilenceableFailure(success());
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 7d5bef7

Please sign in to comment.