diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index 67ac487d8226d..ea158914e445b 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -171,6 +171,9 @@ struct BufferResultsToOutParamsOpts { /// If true, the pass eliminates the memref.alloc and memcpy if the returned /// memref is allocated in the current function and has dynamic shape. bool hoistDynamicAllocs = false; + + /// If true, the pass modifies the function signatures of public functions. + bool modifyPublicFunctions = false; }; /// Replace buffers that are returned from a function with an out parameter. diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td index cad44cb15f479..1eb692586bcfc 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -258,6 +258,9 @@ def BufferResultsToOutParamsPass /*default=*/"false", "Hoist static allocations to call sites.">, Option<"hoistDynamicAllocs", "hoist-dynamic-allocs", "bool", /*default=*/"false", "Hoist dynamic allocations to call sites.">, + Option<"modifyPublicFunctions", "modify-public-functions", "bool", + /*default=*/"false", "Modify function signatures of public " + "functions.">, ]; let dependentDialects = ["memref::MemRefDialect"]; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index b9ee0a4d401f3..d0742ec27ed60 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -217,7 +217,9 @@ updateCalls(ModuleOp module, const AllocDynamicSizesMap &map, } if (!options.filterFn(&callee)) return; - if (callee.isExternal() || callee.isPublic()) + if (callee.isPublic() && !options.modifyPublicFunctions) + return; + if (callee.isExternal()) return; SmallVector replaceWithNewCallResults; @@ -295,7 +297,9 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( // function. AllocDynamicSizesMap map; for (auto func : module.getOps()) { - if (func.isExternal() || func.isPublic()) + if (func.isPublic() && !options.modifyPublicFunctions) + continue; + if (func.isExternal()) continue; if (!options.filterFn(&func)) continue; @@ -326,6 +330,8 @@ struct BufferResultsToOutParamsPass options.hoistStaticAllocs = true; if (hoistDynamicAllocs) options.hoistDynamicAllocs = true; + if (modifyPublicFunctions) + options.modifyPublicFunctions = true; if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(), options))) diff --git a/mlir/test/Transforms/buffer-results-to-out-params-modify-public-functions.mlir b/mlir/test/Transforms/buffer-results-to-out-params-modify-public-functions.mlir new file mode 100644 index 0000000000000..c99bde3f34986 --- /dev/null +++ b/mlir/test/Transforms/buffer-results-to-out-params-modify-public-functions.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{modify-public-functions})' %s | FileCheck %s + +// Test if `public` functions' return values are transformed into out parameters +// when `buffer-results-to-out-params` is invoked with `modifyPublicFunctions`. + +// CHECK-LABEL: func.func @basic( +// CHECK-SAME: %[[ARG0:.*]]: memref) { +// CHECK: %[[VAL_0:.*]] = "test.source"() : () -> memref +// CHECK: memref.copy %[[VAL_0]], %[[ARG0]] : memref to memref +// CHECK: return +// CHECK: } +func.func @basic() -> (memref) { + %0 = "test.source"() : () -> (memref) + return %0 : memref +} + +// CHECK-LABEL: func.func @presence_of_existing_arguments( +// CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>, +// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) { +// CHECK: %[[VAL_0:.*]] = "test.source"() : () -> memref<2xf32> +// CHECK: memref.copy %[[VAL_0]], %[[ARG1]] : memref<2xf32> to memref<2xf32> +// CHECK: return +// CHECK: } +func.func @presence_of_existing_arguments(%arg0: memref<1xf32>) -> (memref<2xf32>) { + %0 = "test.source"() : () -> (memref<2xf32>) + return %0 : memref<2xf32> +} + +// CHECK-LABEL: func.func @multiple_results( +// CHECK-SAME: %[[ARG0:.*]]: memref<1xf32>, +// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>) { +// CHECK: %[[VAL_0:.*]]:2 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>) +// CHECK: memref.copy %[[VAL_0]]#0, %[[ARG0]] : memref<1xf32> to memref<1xf32> +// CHECK: memref.copy %[[VAL_0]]#1, %[[ARG1]] : memref<2xf32> to memref<2xf32> +// CHECK: return +// CHECK: } +func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) { + %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>) + return %0, %1 : memref<1xf32>, memref<2xf32> +}