Skip to content

Commit

Permalink
[MLIR][Bufferization] BufferResultsToOutParams: Add an option to elim…
Browse files Browse the repository at this point in the history
…inate AllocOp and avoid Copy (llvm#90011)

Add an option hoist-static-allocs to remove the unnecessary memref.alloc
and memref.copy after this pass, when the memref in ReturnOp is
allocated by memref.alloc and is statically shaped. Instead, it replaces
the uses of the allocated memref with the memref in the out argument.
By default, BufferResultsToOutParams will result in a memcpy operation
to copy the originally returned memref to the output argument memref.
This is inefficient when the source of memcpy (the returned memref in
the original ReturnOp) is from a local AllocOp. The pass can use the
output argument memref to replace the locally allocated memref for
better performance.hoist-static-allocs avoids dynamic allocation and
memory movement.
This option will be critical for performance-sensivtive applications,
which require BufferResultsToOutParams pass for a caller-owned output
buffer calling convension.
  • Loading branch information
Menooker committed May 8, 2024
1 parent bb01b89 commit 0af448b
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 6 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
Expand Up @@ -166,6 +166,10 @@ struct BufferResultsToOutParamsOpts {
/// If true, the pass adds a "bufferize.result" attribute to each output
/// parameter.
bool addResultAttribute = false;

/// If true, the pass eliminates the memref.alloc and memcpy if the returned
/// memref is allocated in the current function.
bool hoistStaticAllocs = false;
};

/// Creates a pass that converts memref function results to out-params.
Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
Expand Up @@ -315,11 +315,20 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
The main issue with this pass (and the out-param calling convention) is that
buffers for results need to be allocated in the caller. This currently only
works for static shaped memrefs.

If the hoist-static-allocs option is on, the pass tries to eliminate the
allocation for the returned memref and avoid the memory-copy if possible.
This optimization applies on the returned memref which has static shape and
is allocated by memref.alloc in the function. It will use the memref given
in function argument to replace the allocated memref.
}];
let options = [
Option<"addResultAttribute", "add-result-attr", "bool",
/*default=*/"false",
"Add the attribute 'bufferize.result' to all output parameters.">,
Option<"hoistStaticAllocs", "hoist-static-allocs",
"bool", /*default=*/"false",
"Hoist static allocations to call sites.">,
];
let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()";
let dependentDialects = ["memref::MemRefDialect"];
Expand Down
Expand Up @@ -107,7 +107,8 @@ updateFuncOp(func::FuncOp func,
// the given out-params.
static LogicalResult updateReturnOps(func::FuncOp func,
ArrayRef<BlockArgument> appendedEntryArgs,
MemCpyFn memCpyFn) {
MemCpyFn memCpyFn,
bool hoistStaticAllocs) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
Expand All @@ -118,10 +119,15 @@ static LogicalResult updateReturnOps(func::FuncOp func,
keepAsReturnOperands.push_back(operand);
}
OpBuilder builder(op);
for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
if (failed(
memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
return WalkResult::interrupt();
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
if (hoistStaticAllocs && isa<memref::AllocOp>(orig.getDefiningOp()) &&
orig.getType().cast<MemRefType>().hasStaticShape()) {
orig.replaceAllUsesWith(arg);
orig.getDefiningOp()->erase();
} else {
if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
return WalkResult::interrupt();
}
}
builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
op.erase();
Expand Down Expand Up @@ -212,7 +218,8 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return success();
};
if (failed(updateReturnOps(func, appendedEntryArgs,
options.memCpyFn.value_or(defaultMemCpyFn)))) {
options.memCpyFn.value_or(defaultMemCpyFn),
options.hoistStaticAllocs))) {
return failure();
}
}
Expand All @@ -233,6 +240,8 @@ struct BufferResultsToOutParamsPass
// Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
if (addResultAttribute)
options.addResultAttribute = true;
if (hoistStaticAllocs)
options.hoistStaticAllocs = true;

if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
Expand Down
37 changes: 37 additions & 0 deletions mlir/test/Transforms/buffer-results-to-out-params-elim.mlir
@@ -0,0 +1,37 @@
// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{hoist-static-allocs})' %s | FileCheck %s

// CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG:.*]]: memref<8x64xf32>) {
// CHECK-NOT: memref.alloc()
// CHECK: "test.source"(%[[ARG]]) : (memref<8x64xf32>) -> ()
// CHECK: return
// CHECK: }
func.func @basic() -> (memref<8x64xf32>) {
%b = memref.alloc() : memref<8x64xf32>
"test.source"(%b) : (memref<8x64xf32>) -> ()
return %b : memref<8x64xf32>
}

// CHECK-LABEL: func @basic_no_change(
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) {
// CHECK: %[[RESULT:.*]] = "test.source"() : () -> memref<f32>
// CHECK: memref.copy %[[RESULT]], %[[ARG]] : memref<f32> to memref<f32>
// CHECK: return
// CHECK: }
func.func @basic_no_change() -> (memref<f32>) {
%0 = "test.source"() : () -> (memref<f32>)
return %0 : memref<f32>
}

// CHECK-LABEL: func @basic_dynamic(
// CHECK-SAME: %[[D:.*]]: index, %[[ARG:.*]]: memref<?xf32>) {
// CHECK: %[[RESULT:.*]] = memref.alloc(%[[D]]) : memref<?xf32>
// CHECK: "test.source"(%[[RESULT]]) : (memref<?xf32>) -> ()
// CHECK: memref.copy %[[RESULT]], %[[ARG]]
// CHECK: return
// CHECK: }
func.func @basic_dynamic(%d: index) -> (memref<?xf32>) {
%b = memref.alloc(%d) : memref<?xf32>
"test.source"(%b) : (memref<?xf32>) -> ()
return %b : memref<?xf32>
}

0 comments on commit 0af448b

Please sign in to comment.