Skip to content

Commit

Permalink
[mlir][bufferization] Update empty_tensor_elimination transform op (#…
Browse files Browse the repository at this point in the history
…68497)

The empty tensor elimination pass semantics have changed recently: when
applied to a module, the One-Shot Module Analysis is run. Otherwise, the
regular One-Shot Analysis is run. The latter one is slightly different
because it ignores function boundaries and treats function block
arguments as "read-only".

This commit updates the transform dialect op to behave in the same way.
  • Loading branch information
matthias-springer committed Oct 8, 2023
1 parent 32f7197 commit 8763343
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ struct OneShotBufferizationOptions;
/// In the above example, the subset op is "tensor.insert_slice". When tracing
/// back the reverse use-def chain of a the source, we end up at a
/// "tensor.empty" op.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op);

/// Try to eliminate "tensor.empty" ops inside `op`.
///
/// This function overload accepts an existing `OneShotAnalysisState`, which
/// contains in-place bufferization decisions. This overload is useful if an
/// existing analysis should be reused for empty tensor elimination.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
OneShotAnalysisState &state);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,10 @@ void transform::EliminateEmptyTensorsOp::getEffects(
DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
transform::TransformRewriter &rewriter, TransformResults &transformResults,
TransformState &state) {
OneShotBufferizationOptions options;
options.allowReturnAllocsFromLoops = true;

for (Operation *target : state.getPayloadOps(getTarget())) {
OneShotAnalysisState state(target, options);
if (failed(analyzeOp(target, state)))
return mlir::emitSilenceableFailure(target->getLoc())
<< "failed to analyze op";
if (failed(bufferization::eliminateEmptyTensors(rewriter, target, state)))
if (failed(bufferization::eliminateEmptyTensors(rewriter, target)))
return mlir::emitSilenceableFailure(target->getLoc())
<< "failed to eliminate insert_slice anchored tensor.empty ops";
<< "empty tensor elimination failed";
}
return DiagnosedSilenceableFailure::success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ struct EmptyTensorElimination
};
} // namespace

void EmptyTensorElimination::runOnOperation() {
Operation *op = getOperation();
LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
Operation *op) {
auto moduleOp = dyn_cast<ModuleOp>(op);
OneShotBufferizationOptions options;
options.allowReturnAllocsFromLoops = true;
Expand All @@ -193,21 +193,21 @@ void EmptyTensorElimination::runOnOperation() {
OneShotAnalysisState state(op, options);
if (moduleOp) {
// Module analysis takes into account function boundaries.
if (failed(analyzeModuleOp(moduleOp, state))) {
signalPassFailure();
return;
}
if (failed(analyzeModuleOp(moduleOp, state)))
return failure();
} else {
// Regular One-Shot Bufferize ignores func.func block arguments, func.call,
// func.return.
if (failed(analyzeOp(op, state))) {
signalPassFailure();
return;
}
if (failed(analyzeOp(op, state)))
return failure();
}

IRRewriter rewriter(op->getContext());
if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state)))
return bufferization::eliminateEmptyTensors(rewriter, op, state);
}

void EmptyTensorElimination::runOnOperation() {
IRRewriter rewriter(getOperation()->getContext());
if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
signalPassFailure();
}

Expand Down

0 comments on commit 8763343

Please sign in to comment.