From 876334321f842edadcc0cd4241c76b59bb888b9e Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 8 Oct 2023 08:46:43 -0700 Subject: [PATCH] [mlir][bufferization] Update empty_tensor_elimination transform op (#68497) The empty tensor elimination pass semantics have changed recently: when applied to a module, the One-Shot Module Analysis is run. Otherwise, the regular One-Shot Analysis is run. The latter one is slightly different because it ignores function boundaries and treats function block arguments as "read-only". This commit updates the transform dialect op to behave in the same way. --- .../Bufferization/Transforms/Transforms.h | 7 ++++++ .../BufferizationTransformOps.cpp | 11 ++------- .../Transforms/EmptyTensorElimination.cpp | 24 +++++++++---------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h index df866daf1ab1f..892675954493b 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h @@ -32,6 +32,13 @@ struct OneShotBufferizationOptions; /// In the above example, the subset op is "tensor.insert_slice". When tracing /// back the reverse use-def chain of a the source, we end up at a /// "tensor.empty" op. +LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op); + +/// Try to eliminate "tensor.empty" ops inside `op`. +/// +/// This function overload accepts an existing `OneShotAnalysisState`, which +/// contains in-place bufferization decisions. This overload is useful if an +/// existing analysis should be reused for empty tensor elimination. LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state); diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index b7db4917a4138..cbb36639a3383 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -117,17 +117,10 @@ void transform::EliminateEmptyTensorsOp::getEffects( DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( transform::TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { - OneShotBufferizationOptions options; - options.allowReturnAllocsFromLoops = true; - for (Operation *target : state.getPayloadOps(getTarget())) { - OneShotAnalysisState state(target, options); - if (failed(analyzeOp(target, state))) - return mlir::emitSilenceableFailure(target->getLoc()) - << "failed to analyze op"; - if (failed(bufferization::eliminateEmptyTensors(rewriter, target, state))) + if (failed(bufferization::eliminateEmptyTensors(rewriter, target))) return mlir::emitSilenceableFailure(target->getLoc()) - << "failed to eliminate insert_slice anchored tensor.empty ops"; + << "empty tensor elimination failed"; } return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index 2ddc51357448a..1a5a65bfac132 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -183,8 +183,8 @@ struct EmptyTensorElimination }; } // namespace -void EmptyTensorElimination::runOnOperation() { - Operation *op = getOperation(); +LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter, + Operation *op) { auto moduleOp = dyn_cast(op); OneShotBufferizationOptions options; options.allowReturnAllocsFromLoops = true; @@ -193,21 +193,21 @@ void EmptyTensorElimination::runOnOperation() { OneShotAnalysisState state(op, options); if (moduleOp) { // Module analysis takes into account function boundaries. - if (failed(analyzeModuleOp(moduleOp, state))) { - signalPassFailure(); - return; - } + if (failed(analyzeModuleOp(moduleOp, state))) + return failure(); } else { // Regular One-Shot Bufferize ignores func.func block arguments, func.call, // func.return. - if (failed(analyzeOp(op, state))) { - signalPassFailure(); - return; - } + if (failed(analyzeOp(op, state))) + return failure(); } - IRRewriter rewriter(op->getContext()); - if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state))) + return bufferization::eliminateEmptyTensors(rewriter, op, state); +} + +void EmptyTensorElimination::runOnOperation() { + IRRewriter rewriter(getOperation()->getContext()); + if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation()))) signalPassFailure(); }