From c1879c5ec1e78d24d5be3b4a3287b5f752c8c1af Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle <2020382038@qq.com> Date: Sat, 27 Sep 2025 06:06:33 +0000 Subject: [PATCH 1/7] Add hoist-static-allocs-option to buffer-results-to-out-params. --- .../Dialect/Bufferization/Transforms/Passes.h | 22 +++-- .../Bufferization/Transforms/Passes.td | 2 + .../Transforms/BufferResultsToOutParams.cpp | 92 +++++++++++++++++-- ...ts-to-out-params-hosit-dynamic-allocs.mlir | 79 ++++++++++++++++ ...ts-to-out-params-hosit-static-allocs.mlir} | 0 5 files changed, 181 insertions(+), 14 deletions(-) create mode 100644 mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir rename mlir/test/Transforms/{buffer-results-to-out-params-elim.mlir => buffer-results-to-out-params-hosit-static-allocs.mlir} (100%) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index a2409f2796b94..e413a5ede5d64 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -5,6 +5,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/MapVector.h" namespace mlir { class FunctionOpInterface; @@ -131,8 +132,8 @@ struct BufferResultsToOutParamsOpts { /// Allocator function: Generate a memref allocation with the given type. /// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped /// results, we don't allow passing a range of values for dynamic dims. - using AllocationFn = - std::function(OpBuilder &, Location, MemRefType)>; + using AllocationFn = std::function(OpBuilder &, Location, + MemRefType, ValueRange)>; /// Memcpy function: Generate a memcpy between two memrefs. using MemCpyFn = @@ -147,8 +148,9 @@ struct BufferResultsToOutParamsOpts { /// Allocation function; used to allocate a memref. /// Default memref.alloc is used AllocationFn allocationFn = [](OpBuilder &builder, Location loc, - MemRefType type) { - return memref::AllocOp::create(builder, loc, type).getResult(); + MemRefType type, ValueRange dynamicSizes) { + return memref::AllocOp::create(builder, loc, type, dynamicSizes) + .getResult(); }; /// Memcpy function; used to create a copy between two memrefs. @@ -164,15 +166,23 @@ struct BufferResultsToOutParamsOpts { bool addResultAttribute = false; /// If true, the pass eliminates the memref.alloc and memcpy if the returned - /// memref is allocated in the current function. + /// memref is static allocated in the current function. bool hoistStaticAllocs = false; + + /// If true, the pass eliminates the memref.alloc and memcpy if the returned + /// memref is dynamic allocated in the current function. + bool hoistDynamicAllocs = false; + + /// It maps the shape source of the dynamic shape memref returned by each + /// function. + llvm::DenseMap>> dynamicSizesMap; }; /// Replace buffers that are returned from a function with an out parameter. /// Also update all call sites. LogicalResult promoteBufferResultsToOutParams(ModuleOp module, - const BufferResultsToOutParamsOpts &options); + BufferResultsToOutParamsOpts &options); /// Drop all memref function results that are equivalent to a function argument. LogicalResult dropEquivalentBufferResults(ModuleOp module); diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td index a0d113c150c5e..cad44cb15f479 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -256,6 +256,8 @@ def BufferResultsToOutParamsPass "Add the attribute 'bufferize.result' to all output parameters.">, Option<"hoistStaticAllocs", "hoist-static-allocs", "bool", /*default=*/"false", "Hoist static allocations to call sites.">, + Option<"hoistDynamicAllocs", "hoist-dynamic-allocs", "bool", + /*default=*/"false", "Hoist dynamic allocations to call sites.">, ]; let dependentDialects = ["memref::MemRefDialect"]; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index e30e094c28467..ae68477f57a0d 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -43,6 +43,52 @@ static bool hasStaticIdentityLayout(MemRefType type) { return type.getLayout().isIdentity(); } +/// Return the dynamic shapes of the `memref` based on the define op. If the +/// complete dynamic shape fails to be captured, return an empty value. +/// Currently, only function parameters are supported for capturing. +static ValueRange getDynamicSize(Value memref, func::FuncOp funcOp) { + auto *defOp = memref.getDefiningOp(); + if (!defOp) + return {}; + auto operands = defOp->getOperands(); + SmallVector dynamicSizes; + for (Value size : operands) { + BlockArgument sizeSrc = mlir::dyn_cast(size); + if (!sizeSrc) + return {}; + + bool finded = false; + for (BlockArgument argument : funcOp.getArguments()) { + if (argument == sizeSrc) { + dynamicSizes.push_back(argument); + finded = true; + break; + } + } + if (!finded) + return {}; + } + return dynamicSizes; +} + +/// Returns the dynamic sizes at the callee, through the call relationship +/// between the caller and callee. +static ValueRange mapDynamicSizeAtCaller(func::CallOp call, func::FuncOp callee, + ValueRange dynamicSizes) { + SmallVector mapedDynamicSizes; + for (Value size : dynamicSizes) { + auto callOperands = call.getOperands(); + for (size_t i = 0, e = callOperands.size(); i < e; ++i) { + Value src = callOperands[i]; + BlockArgument dst = callee.getArgument(i); + if (size != dst) + continue; + mapedDynamicSizes.push_back(src); + } + } + return mapedDynamicSizes; +} + // Updates the func op and entry block. // // Any args appended to the entry block are added to `appendedEntryArgs`. @@ -109,7 +155,7 @@ updateFuncOp(func::FuncOp func, // the given out-params. static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, - const bufferization::BufferResultsToOutParamsOpts &options) { + bufferization::BufferResultsToOutParamsOpts &options) { auto res = func.walk([&](func::ReturnOp op) { SmallVector copyIntoOutParams; SmallVector keepAsReturnOperands; @@ -120,12 +166,22 @@ updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, keepAsReturnOperands.push_back(operand); } OpBuilder builder(op); + SmallVector> dynamicSizes; for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) { - if (options.hoistStaticAllocs && + bool hoistStaticAllocs = + options.hoistStaticAllocs && + mlir::cast(orig.getType()).hasStaticShape(); + bool hoistDynamicAllocs = + options.hoistDynamicAllocs && + !mlir::cast(orig.getType()).hasStaticShape(); + if ((hoistStaticAllocs || hoistDynamicAllocs) && isa_and_nonnull( - orig.getDefiningOp()) && - mlir::cast(orig.getType()).hasStaticShape()) { + orig.getDefiningOp())) { orig.replaceAllUsesWith(arg); + if (hoistDynamicAllocs) { + SmallVector dynamicSize = getDynamicSize(orig, func); + dynamicSizes.push_back(dynamicSize); + } orig.getDefiningOp()->erase(); } else { if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg))) @@ -134,6 +190,10 @@ updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, } func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands); op.erase(); + auto dynamicSizePair = + std::pair>>(func, + dynamicSizes); + options.dynamicSizesMap.insert(dynamicSizePair); return WalkResult::advance(); }); return failure(res.wasInterrupted()); @@ -166,8 +226,16 @@ updateCalls(ModuleOp module, } SmallVector outParams; OpBuilder builder(op); + SmallVector> dynamicSizes = + options.dynamicSizesMap.lookup(callee); + size_t dynamicSizesIndex = 0; for (Value memref : replaceWithOutParams) { - if (!cast(memref.getType()).hasStaticShape()) { + ValueRange dynamicSize = dynamicSizes.size() > dynamicSizesIndex + ? dynamicSizes[dynamicSizesIndex] + : SmallVector(); + bool memrefStaticShape = + cast(memref.getType()).hasStaticShape(); + if (!memrefStaticShape && dynamicSize.empty()) { op.emitError() << "cannot create out param for dynamically shaped result"; didFail = true; @@ -177,8 +245,15 @@ updateCalls(ModuleOp module, auto allocType = MemRefType::get(memrefType.getShape(), memrefType.getElementType(), AffineMap(), memrefType.getMemorySpace()); + + if (memrefStaticShape) { + dynamicSize = {}; + } else { + ++dynamicSizesIndex; + dynamicSize = mapDynamicSizeAtCaller(op, callee, dynamicSize); + } auto maybeOutParam = - options.allocationFn(builder, op.getLoc(), allocType); + options.allocationFn(builder, op.getLoc(), allocType, dynamicSize); if (failed(maybeOutParam)) { op.emitError() << "failed to create allocation op"; didFail = true; @@ -211,8 +286,7 @@ updateCalls(ModuleOp module, } LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( - ModuleOp module, - const bufferization::BufferResultsToOutParamsOpts &options) { + ModuleOp module, bufferization::BufferResultsToOutParamsOpts &options) { for (auto func : module.getOps()) { if (!options.filterFn(&func)) continue; @@ -243,6 +317,8 @@ struct BufferResultsToOutParamsPass options.addResultAttribute = true; if (hoistStaticAllocs) options.hoistStaticAllocs = true; + if (hoistDynamicAllocs) + options.hoistDynamicAllocs = true; if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(), options))) diff --git a/mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir b/mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir new file mode 100644 index 0000000000000..f33eb8e26fbce --- /dev/null +++ b/mlir/test/Transforms/buffer-results-to-out-params-hosit-dynamic-allocs.mlir @@ -0,0 +1,79 @@ +// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{hoist-dynamic-allocs})' %s -split-input-file | FileCheck %s + +func.func private @single_alloc(%size : index) -> (memref) { + %alloc = memref.alloc(%size) : memref + return %alloc : memref +} + +func.func @single_alloc_test(%size : index) { + %alloc = call @single_alloc(%size) : (index) -> (memref) + "test.sink"(%alloc) : (memref) -> () +} + +// CHECK-LABEL: func.func private @single_alloc( +// CHECK-SAME: %{{.*}}: index, +// CHECK-SAME: %{{.*}}: memref) { + +// CHECK-LABEL: func.func @single_alloc_test( +// CHECK-SAME: %[[size:.*]]: index) { +// CHECK: %[[alloc:.*]] = memref.alloc(%[[size]]) : memref +// CHECK: call @single_alloc(%[[size]], %[[alloc]]) : (index, memref) -> () +// CHECK: "test.sink"(%[[alloc]]) : (memref) -> () +// CHECK: } + +// ----- + +func.func private @mult_alloc(%size0 : index, %size1 : index) -> (memref, memref) { + %alloc0 = memref.alloc(%size0, %size1) : memref + %alloc1 = memref.alloc(%size1) : memref + return %alloc0, %alloc1 : memref, memref +} + +func.func @mult_alloc_test(%size0 : index, %size1: index) { + %alloc0, %alloc1 = call @mult_alloc(%size0, %size1) : (index, index) -> (memref, memref) + "test.sink"(%alloc0, %alloc1) : (memref, memref) -> () +} + +// CHECK-LABEL: func private @mult_alloc( +// CHECK-SAME: %{{.*}}: index, %{{.*}}: index, +// CHECK-SAME: %{{.*}}: memref, %{{.*}}: memref) { + +// CHECK-LABEL: func @mult_alloc_test( +// CHECK-SAME: %[[size0:.*]]: index, +// CHECK-SAME: %[[size1:.*]]: index) { +// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref +// CHECK: %[[alloc1:.*]] = memref.alloc(%[[size1]]) : memref +// CHECK: call @mult_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]]) : (index, index, memref, memref) -> () +// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]]) : (memref, memref) -> () +// CHECK: } + + +// ----- + +func.func private @complex_alloc(%size0 : index, %size1 : index) -> (memref, memref<4xf32>, memref) { + %alloc0 = memref.alloc(%size0, %size1) : memref + %alloc1 = memref.alloc() : memref<4xf32> + %alloc2 = memref.alloc(%size1) : memref + return %alloc0, %alloc1, %alloc2 : memref, memref<4xf32>, memref +} + +func.func @complex_alloc_test(%size0 : index, %size1: index) { + %alloc0, %alloc1, %alloc2 = call @complex_alloc(%size0, %size1) : (index, index) -> (memref, memref<4xf32>, memref) + "test.sink"(%alloc0, %alloc1, %alloc2) : (memref, memref<4xf32>, memref) -> () +} + +// CHECK-LABEL: func private @complex_alloc( +// CHECK-SAME: %{{.*}}: index, %{{.*}}: index, +// CHECK-SAME: %{{.*}}: memref, +// CHECK-SAME: %{{.*}}: memref<4xf32>, +// CHECK-SAME: %{{.*}}: memref) { + +// CHECK-LABEL: func @complex_alloc_test( +// CHECK-SAME: %[[size0:.*]]: index, +// CHECK-SAME: %[[size1:.*]]: index) { +// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref +// CHECK: %[[alloc1:.*]] = memref.alloc() : memref<4xf32> +// CHECK: %[[alloc2:.*]] = memref.alloc(%[[size1]]) : memref +// CHECK: call @complex_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]], %[[alloc2]]) : (index, index, memref, memref<4xf32>, memref) -> () +// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]], %[[alloc2]]) : (memref, memref<4xf32>, memref) -> () +// CHECK: } diff --git a/mlir/test/Transforms/buffer-results-to-out-params-elim.mlir b/mlir/test/Transforms/buffer-results-to-out-params-hosit-static-allocs.mlir similarity index 100% rename from mlir/test/Transforms/buffer-results-to-out-params-elim.mlir rename to mlir/test/Transforms/buffer-results-to-out-params-hosit-static-allocs.mlir From 1fe33cf8d4d3218cc7cb042005255fac17eb3d4e Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle <2020382038@qq.com> Date: Sat, 27 Sep 2025 06:15:19 +0000 Subject: [PATCH 2/7] clearup Passes.h --- mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index e413a5ede5d64..6ded148ce9d84 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -5,7 +5,6 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" -#include "llvm/ADT/MapVector.h" namespace mlir { class FunctionOpInterface; From 348496966fee2eb876fc1ef658cc7b2966c72cad Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle <2020382038@qq.com> Date: Sat, 27 Sep 2025 06:30:58 +0000 Subject: [PATCH 3/7] fix build problem --- .../Transforms/BufferResultsToOutParams.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index ae68477f57a0d..1160f4232172e 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -46,7 +46,7 @@ static bool hasStaticIdentityLayout(MemRefType type) { /// Return the dynamic shapes of the `memref` based on the define op. If the /// complete dynamic shape fails to be captured, return an empty value. /// Currently, only function parameters are supported for capturing. -static ValueRange getDynamicSize(Value memref, func::FuncOp funcOp) { +static SmallVector getDynamicSize(Value memref, func::FuncOp funcOp) { auto *defOp = memref.getDefiningOp(); if (!defOp) return {}; @@ -73,8 +73,9 @@ static ValueRange getDynamicSize(Value memref, func::FuncOp funcOp) { /// Returns the dynamic sizes at the callee, through the call relationship /// between the caller and callee. -static ValueRange mapDynamicSizeAtCaller(func::CallOp call, func::FuncOp callee, - ValueRange dynamicSizes) { +static SmallVector mapDynamicSizeAtCaller(func::CallOp call, + func::FuncOp callee, + ValueRange dynamicSizes) { SmallVector mapedDynamicSizes; for (Value size : dynamicSizes) { auto callOperands = call.getOperands(); @@ -230,9 +231,9 @@ updateCalls(ModuleOp module, options.dynamicSizesMap.lookup(callee); size_t dynamicSizesIndex = 0; for (Value memref : replaceWithOutParams) { - ValueRange dynamicSize = dynamicSizes.size() > dynamicSizesIndex - ? dynamicSizes[dynamicSizesIndex] - : SmallVector(); + SmallVector dynamicSize = dynamicSizes.size() > dynamicSizesIndex + ? dynamicSizes[dynamicSizesIndex] + : SmallVector(); bool memrefStaticShape = cast(memref.getType()).hasStaticShape(); if (!memrefStaticShape && dynamicSize.empty()) { From 100dfcc773b7b82014d27f78b84d4ef5827297b1 Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle <2020382038@qq.com> Date: Sat, 27 Sep 2025 11:41:39 +0000 Subject: [PATCH 4/7] update. --- .../Dialect/Bufferization/Transforms/Passes.h | 6 +-- .../Transforms/BufferResultsToOutParams.cpp | 54 ++++++++++--------- 2 files changed, 29 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index 6ded148ce9d84..78bd33ff619ce 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -171,17 +171,13 @@ struct BufferResultsToOutParamsOpts { /// If true, the pass eliminates the memref.alloc and memcpy if the returned /// memref is dynamic allocated in the current function. bool hoistDynamicAllocs = false; - - /// It maps the shape source of the dynamic shape memref returned by each - /// function. - llvm::DenseMap>> dynamicSizesMap; }; /// Replace buffers that are returned from a function with an out parameter. /// Also update all call sites. LogicalResult promoteBufferResultsToOutParams(ModuleOp module, - BufferResultsToOutParamsOpts &options); + const BufferResultsToOutParamsOpts &options); /// Drop all memref function results that are equivalent to a function argument. LogicalResult dropEquivalentBufferResults(ModuleOp module); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index 1160f4232172e..a5a7b6222125d 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -23,6 +23,8 @@ namespace bufferization { using namespace mlir; using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn; using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn; +using AllocDynamicSizesMap = + llvm::DenseMap>>; /// Return `true` if the given MemRef type has a fully dynamic layout. static bool hasFullyDynamicLayoutMap(MemRefType type) { @@ -43,30 +45,24 @@ static bool hasStaticIdentityLayout(MemRefType type) { return type.getLayout().isIdentity(); } -/// Return the dynamic shapes of the `memref` based on the define op. If the +/// Return the dynamic shapes of the `memref` based on the defining op. If the /// complete dynamic shape fails to be captured, return an empty value. -/// Currently, only function parameters are supported for capturing. +/// Currently, only function block arguments are supported for capturing. static SmallVector getDynamicSize(Value memref, func::FuncOp funcOp) { - auto *defOp = memref.getDefiningOp(); + Operation *defOp = memref.getDefiningOp(); if (!defOp) return {}; auto operands = defOp->getOperands(); SmallVector dynamicSizes; for (Value size : operands) { - BlockArgument sizeSrc = mlir::dyn_cast(size); + BlockArgument sizeSrc = dyn_cast(size); if (!sizeSrc) return {}; - bool finded = false; - for (BlockArgument argument : funcOp.getArguments()) { - if (argument == sizeSrc) { - dynamicSizes.push_back(argument); - finded = true; - break; - } - } - if (!finded) + auto iter = llvm::find(funcOp.getArguments(), sizeSrc); + if (!iter) return {}; + dynamicSizes.push_back(*iter); } return dynamicSizes; } @@ -76,7 +72,7 @@ static SmallVector getDynamicSize(Value memref, func::FuncOp funcOp) { static SmallVector mapDynamicSizeAtCaller(func::CallOp call, func::FuncOp callee, ValueRange dynamicSizes) { - SmallVector mapedDynamicSizes; + SmallVector mappedDynamicSizes; for (Value size : dynamicSizes) { auto callOperands = call.getOperands(); for (size_t i = 0, e = callOperands.size(); i < e; ++i) { @@ -84,10 +80,12 @@ static SmallVector mapDynamicSizeAtCaller(func::CallOp call, BlockArgument dst = callee.getArgument(i); if (size != dst) continue; - mapedDynamicSizes.push_back(src); + mappedDynamicSizes.push_back(src); } } - return mapedDynamicSizes; + assert(mappedDynamicSizes.size() == dynamicSizes.size() && + "could not find all dynamic sizes"); + return mappedDynamicSizes; } // Updates the func op and entry block. @@ -156,7 +154,8 @@ updateFuncOp(func::FuncOp func, // the given out-params. static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, - bufferization::BufferResultsToOutParamsOpts &options) { + AllocDynamicSizesMap &map, + const bufferization::BufferResultsToOutParamsOpts &options) { auto res = func.walk([&](func::ReturnOp op) { SmallVector copyIntoOutParams; SmallVector keepAsReturnOperands; @@ -171,10 +170,10 @@ updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) { bool hoistStaticAllocs = options.hoistStaticAllocs && - mlir::cast(orig.getType()).hasStaticShape(); + cast(orig.getType()).hasStaticShape(); bool hoistDynamicAllocs = options.hoistDynamicAllocs && - !mlir::cast(orig.getType()).hasStaticShape(); + !cast(orig.getType()).hasStaticShape(); if ((hoistStaticAllocs || hoistDynamicAllocs) && isa_and_nonnull( orig.getDefiningOp())) { @@ -194,7 +193,7 @@ updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, auto dynamicSizePair = std::pair>>(func, dynamicSizes); - options.dynamicSizesMap.insert(dynamicSizePair); + map.insert(dynamicSizePair); return WalkResult::advance(); }); return failure(res.wasInterrupted()); @@ -203,7 +202,7 @@ updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, // Updates all CallOps in the scope of the given ModuleOp by allocating // temporary buffers for newly introduced out params. static LogicalResult -updateCalls(ModuleOp module, +updateCalls(ModuleOp module, AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options) { bool didFail = false; SymbolTable symtab(module); @@ -227,8 +226,7 @@ updateCalls(ModuleOp module, } SmallVector outParams; OpBuilder builder(op); - SmallVector> dynamicSizes = - options.dynamicSizesMap.lookup(callee); + SmallVector> dynamicSizes = map.lookup(callee); size_t dynamicSizesIndex = 0; for (Value memref : replaceWithOutParams) { SmallVector dynamicSize = dynamicSizes.size() > dynamicSizesIndex @@ -287,7 +285,11 @@ updateCalls(ModuleOp module, } LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( - ModuleOp module, bufferization::BufferResultsToOutParamsOpts &options) { + ModuleOp module, + const bufferization::BufferResultsToOutParamsOpts &options) { + /// It maps the shape source of the dynamic shape memref returned by each + /// function. + AllocDynamicSizesMap map; for (auto func : module.getOps()) { if (!options.filterFn(&func)) continue; @@ -297,11 +299,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( return failure(); if (func.isExternal()) continue; - if (failed(updateReturnOps(func, appendedEntryArgs, options))) { + if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) { return failure(); } } - if (failed(updateCalls(module, options))) + if (failed(updateCalls(module, map, options))) return failure(); return success(); } From c44b91ec4161d7a3a5f92193e8e2c5017ad26d45 Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle <2020382038@qq.com> Date: Sat, 27 Sep 2025 11:42:36 +0000 Subject: [PATCH 5/7] fix nit. --- .../Bufferization/Transforms/BufferResultsToOutParams.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index a5a7b6222125d..06f6acd0febc8 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -287,8 +287,8 @@ updateCalls(ModuleOp module, AllocDynamicSizesMap &map, LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( ModuleOp module, const bufferization::BufferResultsToOutParamsOpts &options) { - /// It maps the shape source of the dynamic shape memref returned by each - /// function. + // It maps the shape source of the dynamic shape memref returned by each + // function. AllocDynamicSizesMap map; for (auto func : module.getOps()) { if (!options.filterFn(&func)) From bb63a6d0186daa2013a58bd7db6a8436939fef29 Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle <2020382038@qq.com> Date: Sun, 28 Sep 2025 02:50:25 +0000 Subject: [PATCH 6/7] fix nit. --- .../mlir/Dialect/Bufferization/Transforms/Passes.h | 4 ++-- .../Transforms/BufferResultsToOutParams.cpp | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index 78bd33ff619ce..67ac487d8226d 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -165,11 +165,11 @@ struct BufferResultsToOutParamsOpts { bool addResultAttribute = false; /// If true, the pass eliminates the memref.alloc and memcpy if the returned - /// memref is static allocated in the current function. + /// memref is allocated in the current function. bool hoistStaticAllocs = false; /// If true, the pass eliminates the memref.alloc and memcpy if the returned - /// memref is dynamic allocated in the current function. + /// memref is allocated in the current function and has dynamic shape. bool hoistDynamicAllocs = false; }; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index 06f6acd0febc8..aec54dfe7ceab 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -59,8 +59,9 @@ static SmallVector getDynamicSize(Value memref, func::FuncOp funcOp) { if (!sizeSrc) return {}; - auto iter = llvm::find(funcOp.getArguments(), sizeSrc); - if (!iter) + auto arguments = funcOp.getArguments(); + auto iter = llvm::find(arguments, sizeSrc); + if (iter == arguments.end()) return {}; dynamicSizes.push_back(*iter); } @@ -74,10 +75,8 @@ static SmallVector mapDynamicSizeAtCaller(func::CallOp call, ValueRange dynamicSizes) { SmallVector mappedDynamicSizes; for (Value size : dynamicSizes) { - auto callOperands = call.getOperands(); - for (size_t i = 0, e = callOperands.size(); i < e; ++i) { - Value src = callOperands[i]; - BlockArgument dst = callee.getArgument(i); + for (auto [src, dst] : + llvm::zip_first(call.getOperands(), callee.getArguments())) { if (size != dst) continue; mappedDynamicSizes.push_back(src); @@ -202,7 +201,7 @@ updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, // Updates all CallOps in the scope of the given ModuleOp by allocating // temporary buffers for newly introduced out params. static LogicalResult -updateCalls(ModuleOp module, AllocDynamicSizesMap &map, +updateCalls(ModuleOp module, const AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options) { bool didFail = false; SymbolTable symtab(module); From 6e4abeb2bf1020d88566ef8c12fe9f93a4e5250a Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle <2020382038@qq.com> Date: Sun, 28 Sep 2025 03:10:26 +0000 Subject: [PATCH 7/7] supoort realloc. --- .../Bufferization/Transforms/BufferResultsToOutParams.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index aec54dfe7ceab..25f941dc16516 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -55,10 +55,12 @@ static SmallVector getDynamicSize(Value memref, func::FuncOp funcOp) { auto operands = defOp->getOperands(); SmallVector dynamicSizes; for (Value size : operands) { + if (!isa(size.getType())) + continue; + BlockArgument sizeSrc = dyn_cast(size); if (!sizeSrc) return {}; - auto arguments = funcOp.getArguments(); auto iter = llvm::find(arguments, sizeSrc); if (iter == arguments.end())