diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index ea158914e445b..90857358437a1 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -182,8 +182,15 @@ LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOpts &options); +/// Options for dropping equivalent memref buffer results. +struct DropBufferResultsOpts { + /// If true, signatures of public functions are modified. + bool modifyPublicFunctions = false; +}; + /// Drop all memref function results that are equivalent to a function argument. -LogicalResult dropEquivalentBufferResults(ModuleOp module); +LogicalResult dropEquivalentBufferResults( + ModuleOp module, DropBufferResultsOpts options = DropBufferResultsOpts()); /// Creates a pass that promotes heap-based allocations to stack-based ones. /// Only buffers smaller with `isSmallAlloc(alloc) == true` are promoted. diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td index 1eb692586bcfc..cd28bd6cf73a5 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -276,6 +276,11 @@ def DropEquivalentBufferResultsPass Note: If a bbArg buffer is not returned directly but casted to beforehand, the buffer is still considered equivalent. }]; + let options = [ + Option<"modifyPublicFunctions", "modify-public-functions", "bool", + /*default=*/"false", "Modify function signatures of public " + "functions.">, + ]; let dependentDialects = ["memref::MemRefDialect"]; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp index bc1799099de31..e724af312652d 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp @@ -74,8 +74,8 @@ static bool operandsEqualFuncArgument(ArrayRef operands, return true; } -LogicalResult -mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { +LogicalResult mlir::bufferization::dropEquivalentBufferResults( + ModuleOp module, DropBufferResultsOpts options) { IRRewriter rewriter(module.getContext()); DenseMap> callerMap; @@ -83,13 +83,18 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { module.walk([&](func::CallOp callOp) { if (func::FuncOp calledFunc = dyn_cast_or_null(callOp.resolveCallable())) { - if (!calledFunc.isPublic() && !calledFunc.isExternal()) + if (calledFunc.isPublic() && !options.modifyPublicFunctions) + return WalkResult::advance(); + if (!calledFunc.isExternal()) callerMap[calledFunc].insert(callOp); } + return WalkResult::advance(); }); for (auto funcOp : module.getOps()) { - if (funcOp.isExternal() || funcOp.isPublic()) + if (funcOp.isPublic() && !options.modifyPublicFunctions) + continue; + if (funcOp.isExternal()) continue; SmallVector returnOps = getReturnOps(funcOp); if (returnOps.empty()) @@ -166,9 +171,18 @@ namespace { struct DropEquivalentBufferResultsPass : bufferization::impl::DropEquivalentBufferResultsPassBase< DropEquivalentBufferResultsPass> { + using Base::Base; + void runOnOperation() override { - if (failed(bufferization::dropEquivalentBufferResults(getOperation()))) + // Convert pass options. + options.modifyPublicFunctions = modifyPublicFunctions; + + if (failed(bufferization::dropEquivalentBufferResults(getOperation(), + options))) return signalPassFailure(); } + +private: + bufferization::DropBufferResultsOpts options; }; } // namespace diff --git a/mlir/test/Dialect/Bufferization/Transforms/drop-equivalent-buffer-results.mlir b/mlir/test/Dialect/Bufferization/Transforms/drop-equivalent-buffer-results.mlir new file mode 100644 index 0000000000000..b20188af43bf5 --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/drop-equivalent-buffer-results.mlir @@ -0,0 +1,106 @@ +// RUN: mlir-opt -drop-equivalent-buffer-results -split-input-file %s | FileCheck %s +// RUN: mlir-opt -drop-equivalent-buffer-results=modify-public-functions=1 -split-input-file %s | \ +// RUN: FileCheck %s --check-prefix=MODIFY-PUBLIC + + +// CHECK-LABEL: func private @single_buffer_return({{.*}}) { +// CHECK: return + +!type = memref> +func.func private @single_buffer_return(%buf: !type, %val: f32, %idx: index) -> !type { + memref.store %val, %buf[%idx] : !type + return %buf : !type +} + +// CHECK-LABEL: func @caller( +// CHECK-SAME: %[[BUF:.+]]: memref>, +// CHECK: call @single_buffer_return(%[[BUF]]{{.*}}-> () +// CHECK: %[[LOADED:.+]] = memref.load %[[BUF]] +// CHECK: return %[[LOADED]] + +func.func @caller(%buf: !type, %val: f32, %idx: index) -> f32 { + %0 = call @single_buffer_return(%buf, %val, %idx) : (!type, f32, index) -> (!type) + %1 = memref.load %0[%idx] : !type + return %1 : f32 +} + +// ----- + +// CHECK-LABEL: func private @multiple_buffer_returns({{.*}}) { +// CHECK: return + +!type = memref> +!type1 = memref +func.func private @multiple_buffer_returns( + %buf: !type, %buf1: !type1, %val: f32, %idx: index) -> (!type1, !type) { + memref.store %val, %buf[%idx] : !type + memref.store %val, %buf1[%idx, %idx] : !type1 + return %buf1, %buf : !type1, !type +} + +// ----- + +// CHECK-LABEL: func private @multiple_mixed_returns({{.*}}) -> i32 { +// CHECK: %[[CST:.+]] = arith.constant 1 : i32 +// CHECK: return %[[CST]] : i32 + +!type = memref> +!type1 = memref +func.func private @multiple_mixed_returns( + %buf: !type, %buf1: !type1, %val: f32, %idx: index) -> (!type1, i32, !type) { + memref.store %val, %buf[%idx] : !type + memref.store %val, %buf1[%idx, %idx] : !type1 + %cst = arith.constant 1 : i32 + return %buf1, %cst, %buf : !type1, i32, !type +} + +// ----- + +// Ensure public functions remain unchanged by default. +// CHECK-LABEL: func @public_function( +// CHECK-SAME: %[[BUF:.+]]: memref>, +// CHECK-SAME: ) -> memref> { +// CHECK: return %[[BUF]] + +// When explicitly requested, public functions can be modified. +// MODIFY-PUBLIC-LABEL: func @public_function( +// MODIFY-PUBLIC-SAME: %[[BUF:.+]]: memref>, +// MODIFY-PUBLIC-SAME: ) { +// MODIFY-PUBLIC: return + +!type = memref> +func.func @public_function( + %buf: !type, %val: f32, %idx: index) -> !type { + memref.store %val, %buf[%idx] : !type + return %buf : !type +} + +// CHECK-LABEL: func @caller( +// CHECK-SAME: %[[IN_BUF:.+]]: memref>, +// CHECK: %[[RET_VAL:.+]] = call @public_function(%[[IN_BUF]]{{.*}}-> memref +// CHECK: %[[LOADED:.+]] = memref.load %[[RET_VAL]] +// CHECK: return %[[LOADED]] + +// MODIFY-PUBLIC-LABEL: func @caller( +// MODIFY-PUBLIC-SAME: %[[IN_BUF:.+]]: memref>, +// MODIFY-PUBLIC: call @public_function(%[[IN_BUF]]{{.*}}-> () +// MODIFY-PUBLIC: %[[LOADED:.*]] = memref.load %[[IN_BUF]] +// MODIFY-PUBLIC: return %[[LOADED]] + +func.func @caller(%buf: !type, %val: f32, %idx: index) -> f32 { + %0 = call @public_function(%buf, %val, %idx) : (!type, f32, index) -> (!type) + %1 = memref.load %0[%idx] : !type + return %1 : f32 +} + +// ----- + +// CHECK-LABEL: func private @negative_external_function( +// CHECK-SAME: -> memref> + +// Ensure external function remains unchanged. +// MODIFY-PUBLIC-LABEL: func private @negative_external_function( +// MODIFY-PUBLIC-SAME: -> memref> + +!type = memref> +func.func private @negative_external_function(%arg0: !type) -> !type