diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index bbcfabe754aa3..99e12a1067dd8 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -34,7 +34,6 @@ namespace mlir { /// failures as their diagnostics have been already reported to the user. class [[nodiscard]] DiagnosedSilenceableFailure { public: - explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {} DiagnosedSilenceableFailure(const DiagnosedSilenceableFailure &) = delete; DiagnosedSilenceableFailure & operator=(const DiagnosedSilenceableFailure &) = delete; @@ -156,6 +155,7 @@ class [[nodiscard]] DiagnosedSilenceableFailure { } private: + explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {} explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic) : result(failure()) { diagnostics.emplace_back(std::move(diagnostic)); diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td index fe29f303a630a..33781536239e9 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -51,23 +51,12 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> { ]; let extraSharedClassDeclaration = [{ - /// Emits a generic transform error for the current transform operation - /// targeting the given Payload IR operation and returns failure. Should - /// be only used as a last resort when the transformation itself provides - /// no further indication as to the reason of the failure. - ::mlir::LogicalResult reportUnknownTransformError( - ::mlir::Operation *target) { - ::mlir::InFlightDiagnostic diag = $_op->emitError() << "failed to apply"; - diag.attachNote(target->getLoc()) << "attempted to apply to this op"; - return diag; - } - /// Creates the silenceable failure object with a diagnostic located at the /// current operation. Silenceable failure must be suppressed or reported /// explicitly at some later time. DiagnosedSilenceableFailure emitSilenceableError(const ::llvm::Twine &message = {}) { - return ::mlir::emitSilenceableFailure($_op); + return ::mlir::emitSilenceableFailure($_op, message); } /// Creates the definite failure object with a diagnostic located at the @@ -78,6 +67,17 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> { return ::mlir::emitDefiniteFailure($_op, message); } + /// Emits a generic definite failure for the current transform operation + /// targeting the given Payload IR operation and returns failure. Should + /// be only used as a last resort when the transformation itself provides + /// no further indication as to the reason of the failure. + DiagnosedDefiniteFailure emitDefaultDefiniteFailure( + ::mlir::Operation *target) { + auto diag = ::mlir::emitDefiniteFailure($_op, "failed to apply"); + diag.attachNote(target->getLoc()) << "attempted to apply to this op"; + return diag; + } + /// Creates the default silenceable failure for a transform op that failed /// to properly apply to a target. DiagnosedSilenceableFailure emitDefaultSilenceableFailure( diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index baf18dc4d329b..57cce1942803f 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -119,7 +119,7 @@ createGpuLaunch(RewriterBase &rewriter, Location loc, blkSizeX, blkSizeY, blkSizeZ); rewriter.setInsertionPointToEnd(&launchOp.getBody().front()); rewriter.create(loc); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } /// Alter kernel configuration of the given kernel. diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 3ae4163b5cc6a..c8dd269029026 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -79,20 +79,20 @@ transform::DecomposeOp::applyToOne(linalg::LinalgOp target, Conv1DNwcWcfOp>>(target); if (succeeded(windowedNhwc)) { results.push_back(*windowedNhwc); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } FailureOr windowedNchw = tryApply>(target); if (succeeded(windowedNchw)) { results.push_back(*windowedNchw); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } FailureOr depthwise = tryApply(target); if (succeeded(depthwise)) { results.push_back(*depthwise); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); @@ -206,7 +206,8 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, return tileConsumerAndFuseProducerGreedilyUsingSCFForOp( rewriter, tilingInterfaceOp, tileAndFuseOptions); }); - return DiagnosedSilenceableFailure(result); + return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() + : DiagnosedSilenceableFailure::success(); } ParseResult transform::FuseOp::parse(OpAsmParser &parser, @@ -568,12 +569,12 @@ transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, // Exit early if no transformation is needed. if (isa(target)) { results.push_back(target); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } FailureOr generic = tryApply(target); if (succeeded(generic)) { results.push_back(generic->getOperation()); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); @@ -592,7 +593,7 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target, // Exit early if no transformation is needed. if (interchangeVector.empty()) { results.push_back(target); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } TrivialPatternRewriter rewriter(target->getContext()); FailureOr res = @@ -600,7 +601,7 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target, if (failed(res)) return DiagnosedSilenceableFailure::definiteFailure(); results.push_back(res->getOperation()); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } LogicalResult transform::InterchangeOp::verify() { @@ -639,8 +640,7 @@ transform::MatchOp::apply(transform::TransformResults &results, ArrayRef payloadOps = state.getPayloadOps(getTarget()); if (payloadOps.size() != 1) { results.set(getResult().cast(), {}); - return DiagnosedSilenceableFailure( - this->emitOpError("requires exactly one target handle")); + return emitDefiniteFailure("requires exactly one target handle"); } SmallVector res; @@ -687,7 +687,7 @@ transform::MatchOp::apply(transform::TransformResults &results, payloadOps.front()->walk(matchFun); results.set(getResult().cast(), res); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===---------------------------------------------------------------------===// @@ -792,7 +792,7 @@ transform::PadOp::applyToOne(linalg::LinalgOp target, tryApply(target, paddingOptions); if (succeeded(result)) { results.push_back(result->getOperation()); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } results.assign(1, nullptr); @@ -866,15 +866,15 @@ transform::PromoteOp::applyToOne(linalg::LinalgOp target, promotionOptions = promotionOptions.setAlignment(*getAlignment()); if (failed(promoteSubviewsPrecondition(target, promotionOptions))) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + return emitDefaultDefiniteFailure(target); TrivialPatternRewriter rewriter(target->getContext()); rewriter.setInsertionPoint(target); FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + return emitDefaultDefiniteFailure(target); results.push_back(target); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// @@ -909,7 +909,7 @@ transform::ReplaceOp::apply(TransformResults &transformResults, replacements.push_back(replacement); } transformResults.set(getReplacement().cast(), replacements); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } void transform::ReplaceOp::getEffects( @@ -972,10 +972,10 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(target.getOperation()), tilingOptions); if (failed(maybeTilingResult)) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + return emitDefaultDefiniteFailure(target); results.append(maybeTilingResult->tiledOps); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// @@ -1171,13 +1171,13 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) : splitReduction(rewriter, target, splitFn, getUseAlloc()); if (failed(splitResult)) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + return emitDefaultDefiniteFailure(target); results.push_back(splitResult->initOrAlloc); results.push_back(splitResult->fillOp); results.push_back(splitResult->splitLinalgOp); results.push_back(splitResult->resultCombiningLinalgOp); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// @@ -1200,12 +1200,12 @@ DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( sizes); if (failed(result)) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + return emitDefaultSilenceableFailure(target); results.push_back(result->loops.front()); results.push_back(result->initialOp); results.push_back(result->parallelTiledOp); results.push_back(result->mergeOp); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// @@ -1235,7 +1235,7 @@ transform::TileReductionUsingForeachThreadOp::applyToOne( results.push_back(result->initialOp); results.push_back(result->parallelTiledOp); results.push_back(result->mergeOp); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// @@ -1523,7 +1523,7 @@ static DiagnosedSilenceableFailure unpackPDLOperations( } } - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( @@ -1533,7 +1533,7 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( ArrayRef mixedTileSizes, Optional mapping, SmallVector &tileOps, SmallVector &tiledOps) { if (targets.empty()) - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); // getMixedNumThreads are OpFoldResults[index attributes or PDL operation]. // Convert to OpFoldResults[index attributes or payload op]. @@ -1577,7 +1577,7 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( tileOps.push_back(tilingResult->tileOp); tiledOps.push_back(tilingResult->tiledOp); } - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply( @@ -1604,7 +1604,7 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply( transformResults.set(getForeachThreadOp().cast(), tileOps); transformResults.set(getTiledOp().cast(), tiledOps); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } void transform::TileToForeachThreadOp::getEffects( @@ -1852,10 +1852,10 @@ transform::VectorizeOp::applyToOne(Operation *target, linalg::populatePadOpVectorizationPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + return emitDefaultDefiniteFailure(target); results.push_back(target); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index 1d7a8b74ebe56..391164c76a07f 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -33,7 +33,7 @@ transform::MemRefMultiBufferOp::applyToOne(memref::AllocOp target, } results.push_back(newBuffer.value()); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 8777662ec3902..21deab6bc2a06 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -103,10 +103,8 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results, FailureOr outlined = outlineSingleBlockRegion( rewriter, location, exec.getRegion(), getFuncName(), &call); - if (failed(outlined)) { - (void)reportUnknownTransformError(target); - return DiagnosedSilenceableFailure::definiteFailure(); - } + if (failed(outlined)) + return emitDefaultDefiniteFailure(target); if (symbolTableOp) { SymbolTable &symbolTable = @@ -139,7 +137,7 @@ transform::LoopPeelOp::applyToOne(scf::ForOp target, scf::peelAndCanonicalizeForLoop(rewriter, target, result); // TODO: Return both the peeled loop and the remainder loop. results.push_back(failed(status) ? target : result); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// @@ -200,7 +198,7 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target, pattern.returningMatchAndRewrite(target, rewriter); if (succeeded(patternResult)) { results.push_back(*patternResult); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); @@ -225,7 +223,7 @@ transform::LoopUnrollOp::applyToOne(Operation *op, diag << "Op failed to unroll"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===//