-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][bufferize] Add hoist-dynamic-allocs-option to buffer-results-to-out-params #160985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c1879c5
1fe33cf
3484969
100dfcc
c44b91e
bb63a6d
6e4abeb
72ce790
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,8 @@ namespace bufferization { | |
using namespace mlir; | ||
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn; | ||
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn; | ||
using AllocDynamicSizesMap = | ||
llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>>; | ||
|
||
/// Return `true` if the given MemRef type has a fully dynamic layout. | ||
static bool hasFullyDynamicLayoutMap(MemRefType type) { | ||
|
@@ -43,6 +45,50 @@ static bool hasStaticIdentityLayout(MemRefType type) { | |
return type.getLayout().isIdentity(); | ||
} | ||
|
||
/// Collect the dynamic sizes of `memref` from the operands of its defining op, | ||
/// expressed as arguments of `funcOp`. | ||
/// | ||
/// Every operand of the defining op whose type is `IndexType` is treated as a | ||
/// size, and each one must itself be an argument of `funcOp`. An empty vector | ||
/// is returned when `memref` has no defining op, or as soon as one index | ||
/// operand is not a block argument found among `funcOp`'s arguments. | ||
/// | ||
/// NOTE(review): an empty vector is also the result when the defining op has | ||
/// no index operands at all, so callers cannot tell "nothing to capture" apart | ||
/// from "capture failed" by looking at the return value alone. | ||
/// NOTE(review): index operands that are not sizes (if the defining op has | ||
/// any) would be collected as well -- confirm against the ops that can reach | ||
/// this function. | ||
static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) { | ||
Operation *defOp = memref.getDefiningOp(); | ||
// A value without a defining op (e.g. a block argument) has no operands to | ||
// inspect. | ||
if (!defOp) | ||
return {}; | ||
auto operands = defOp->getOperands(); | ||
SmallVector<Value> dynamicSizes; | ||
for (Value size : operands) { | ||
// Operands that are not of index type are ignored rather than rejected. | ||
if (!isa<IndexType>(size.getType())) | ||
continue; | ||
|
||
// Only sizes that come straight from a block argument can be captured; | ||
// anything computed by another op aborts the whole capture. | ||
BlockArgument sizeSrc = dyn_cast<BlockArgument>(size); | ||
if (!sizeSrc) | ||
return {}; | ||
// The block argument must be one of `funcOp`'s own arguments, not an | ||
// argument of some other block. | ||
auto arguments = funcOp.getArguments(); | ||
auto iter = llvm::find(arguments, sizeSrc); | ||
if (iter == arguments.end()) | ||
return {}; | ||
dynamicSizes.push_back(*iter); | ||
} | ||
return dynamicSizes; | ||
} | ||
|
||
/// Translate dynamic sizes from the callee's point of view to the caller's. | ||
/// | ||
/// Each entry of `dynamicSizes` is expected to be an argument of `callee`. It | ||
/// is replaced by the operand of `call` that is paired with that argument when | ||
/// the call operands and the callee arguments are zipped together in order. | ||
/// The result holds the caller-side values in the order of `dynamicSizes`. | ||
/// | ||
/// The trailing assert checks that the number of mapped values equals the | ||
/// number of requested sizes. | ||
/// NOTE(review): in builds where `assert` is compiled out, a size with no | ||
/// matching callee argument would silently yield a shorter result -- confirm | ||
/// that callers only pass sizes produced for this `callee`. | ||
static SmallVector<Value> mapDynamicSizeAtCaller(func::CallOp call, | ||
func::FuncOp callee, | ||
ValueRange dynamicSizes) { | ||
SmallVector<Value> mappedDynamicSizes; | ||
for (Value size : dynamicSizes) { | ||
// Walk call operands and callee arguments in lockstep; `src` is the value | ||
// at the call site, `dst` the corresponding callee argument. | ||
for (auto [src, dst] : | ||
llvm::zip_first(call.getOperands(), callee.getArguments())) { | ||
if (size != dst) | ||
continue; | ||
mappedDynamicSizes.push_back(src); | ||
} | ||
} | ||
assert(mappedDynamicSizes.size() == dynamicSizes.size() && | ||
"could not find all dynamic sizes"); | ||
return mappedDynamicSizes; | ||
} | ||
|
||
// Updates the func op and entry block. | ||
// | ||
// Any args appended to the entry block are added to `appendedEntryArgs`. | ||
|
@@ -109,6 +155,7 @@ updateFuncOp(func::FuncOp func, | |
// the given out-params. | ||
static LogicalResult | ||
updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs, | ||
AllocDynamicSizesMap &map, | ||
const bufferization::BufferResultsToOutParamsOpts &options) { | ||
auto res = func.walk([&](func::ReturnOp op) { | ||
SmallVector<Value, 6> copyIntoOutParams; | ||
|
@@ -120,12 +167,22 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs, | |
keepAsReturnOperands.push_back(operand); | ||
} | ||
OpBuilder builder(op); | ||
SmallVector<SmallVector<Value>> dynamicSizes; | ||
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) { | ||
if (options.hoistStaticAllocs && | ||
bool hoistStaticAllocs = | ||
options.hoistStaticAllocs && | ||
cast<MemRefType>(orig.getType()).hasStaticShape(); | ||
bool hoistDynamicAllocs = | ||
options.hoistDynamicAllocs && | ||
!cast<MemRefType>(orig.getType()).hasStaticShape(); | ||
if ((hoistStaticAllocs || hoistDynamicAllocs) && | ||
isa_and_nonnull<bufferization::AllocationOpInterface>( | ||
orig.getDefiningOp()) && | ||
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) { | ||
orig.getDefiningOp())) { | ||
orig.replaceAllUsesWith(arg); | ||
if (hoistDynamicAllocs) { | ||
SmallVector<Value> dynamicSize = getDynamicSize(orig, func); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens when the sizes could not be captured? You already performed the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
can't find dynamic size case.
|
||
dynamicSizes.push_back(dynamicSize); | ||
} | ||
orig.getDefiningOp()->erase(); | ||
} else { | ||
if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg))) | ||
|
@@ -134,6 +191,10 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs, | |
} | ||
func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands); | ||
op.erase(); | ||
auto dynamicSizePair = | ||
std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func, | ||
dynamicSizes); | ||
map.insert(dynamicSizePair); | ||
return WalkResult::advance(); | ||
}); | ||
return failure(res.wasInterrupted()); | ||
|
@@ -142,7 +203,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs, | |
// Updates all CallOps in the scope of the given ModuleOp by allocating | ||
// temporary buffers for newly introduced out params. | ||
static LogicalResult | ||
updateCalls(ModuleOp module, | ||
updateCalls(ModuleOp module, const AllocDynamicSizesMap &map, | ||
const bufferization::BufferResultsToOutParamsOpts &options) { | ||
bool didFail = false; | ||
SymbolTable symtab(module); | ||
|
@@ -166,8 +227,15 @@ updateCalls(ModuleOp module, | |
} | ||
SmallVector<Value, 6> outParams; | ||
OpBuilder builder(op); | ||
SmallVector<SmallVector<Value>> dynamicSizes = map.lookup(callee); | ||
size_t dynamicSizesIndex = 0; | ||
for (Value memref : replaceWithOutParams) { | ||
if (!cast<MemRefType>(memref.getType()).hasStaticShape()) { | ||
SmallVector<Value> dynamicSize = dynamicSizes.size() > dynamicSizesIndex | ||
? dynamicSizes[dynamicSizesIndex] | ||
: SmallVector<Value>(); | ||
bool memrefStaticShape = | ||
cast<MemRefType>(memref.getType()).hasStaticShape(); | ||
if (!memrefStaticShape && dynamicSize.empty()) { | ||
op.emitError() | ||
<< "cannot create out param for dynamically shaped result"; | ||
didFail = true; | ||
|
@@ -177,8 +245,15 @@ updateCalls(ModuleOp module, | |
auto allocType = | ||
MemRefType::get(memrefType.getShape(), memrefType.getElementType(), | ||
AffineMap(), memrefType.getMemorySpace()); | ||
|
||
if (memrefStaticShape) { | ||
dynamicSize = {}; | ||
} else { | ||
++dynamicSizesIndex; | ||
dynamicSize = mapDynamicSizeAtCaller(op, callee, dynamicSize); | ||
} | ||
auto maybeOutParam = | ||
options.allocationFn(builder, op.getLoc(), allocType); | ||
options.allocationFn(builder, op.getLoc(), allocType, dynamicSize); | ||
if (failed(maybeOutParam)) { | ||
op.emitError() << "failed to create allocation op"; | ||
didFail = true; | ||
|
@@ -213,6 +288,9 @@ updateCalls(ModuleOp module, | |
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( | ||
ModuleOp module, | ||
const bufferization::BufferResultsToOutParamsOpts &options) { | ||
// It maps the shape source of the dynamic shape memref returned by each | ||
// function. | ||
AllocDynamicSizesMap map; | ||
for (auto func : module.getOps<func::FuncOp>()) { | ||
if (!options.filterFn(&func)) | ||
continue; | ||
|
@@ -222,11 +300,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( | |
return failure(); | ||
if (func.isExternal()) | ||
continue; | ||
if (failed(updateReturnOps(func, appendedEntryArgs, options))) { | ||
if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) { | ||
return failure(); | ||
} | ||
} | ||
if (failed(updateCalls(module, options))) | ||
if (failed(updateCalls(module, map, options))) | ||
return failure(); | ||
return success(); | ||
} | ||
|
@@ -243,6 +321,8 @@ struct BufferResultsToOutParamsPass | |
options.addResultAttribute = true; | ||
if (hoistStaticAllocs) | ||
options.hoistStaticAllocs = true; | ||
if (hoistDynamicAllocs) | ||
options.hoistDynamicAllocs = true; | ||
|
||
if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(), | ||
options))) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{hoist-dynamic-allocs})' %s -split-input-file | FileCheck %s | ||
|
||
func.func private @single_alloc(%size : index) -> (memref<?xf32>) { | ||
%alloc = memref.alloc(%size) : memref<?xf32> | ||
return %alloc : memref<?xf32> | ||
} | ||
|
||
func.func @single_alloc_test(%size : index) { | ||
%alloc = call @single_alloc(%size) : (index) -> (memref<?xf32>) | ||
"test.sink"(%alloc) : (memref<?xf32>) -> () | ||
} | ||
|
||
// CHECK-LABEL: func.func private @single_alloc( | ||
// CHECK-SAME: %{{.*}}: index, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should size args be removed from callee when they aren't used after hoisting? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have already considered this issue. This issue has been resolved. #160755 |
||
// CHECK-SAME: %{{.*}}: memref<?xf32>) { | ||
|
||
// CHECK-LABEL: func.func @single_alloc_test( | ||
// CHECK-SAME: %[[size:.*]]: index) { | ||
// CHECK: %[[alloc:.*]] = memref.alloc(%[[size]]) : memref<?xf32> | ||
// CHECK: call @single_alloc(%[[size]], %[[alloc]]) : (index, memref<?xf32>) -> () | ||
// CHECK: "test.sink"(%[[alloc]]) : (memref<?xf32>) -> () | ||
// CHECK: } | ||
|
||
// ----- | ||
|
||
func.func private @mult_alloc(%size0 : index, %size1 : index) -> (memref<?x?xf32>, memref<?xf32>) { | ||
%alloc0 = memref.alloc(%size0, %size1) : memref<?x?xf32> | ||
%alloc1 = memref.alloc(%size1) : memref<?xf32> | ||
return %alloc0, %alloc1 : memref<?x?xf32>, memref<?xf32> | ||
} | ||
|
||
func.func @mult_alloc_test(%size0 : index, %size1: index) { | ||
%alloc0, %alloc1 = call @mult_alloc(%size0, %size1) : (index, index) -> (memref<?x?xf32>, memref<?xf32>) | ||
"test.sink"(%alloc0, %alloc1) : (memref<?x?xf32>, memref<?xf32>) -> () | ||
} | ||
|
||
// CHECK-LABEL: func private @mult_alloc( | ||
// CHECK-SAME: %{{.*}}: index, %{{.*}}: index, | ||
// CHECK-SAME: %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?xf32>) { | ||
|
||
// CHECK-LABEL: func @mult_alloc_test( | ||
// CHECK-SAME: %[[size0:.*]]: index, | ||
// CHECK-SAME: %[[size1:.*]]: index) { | ||
// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xf32> | ||
// CHECK: %[[alloc1:.*]] = memref.alloc(%[[size1]]) : memref<?xf32> | ||
// CHECK: call @mult_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]]) : (index, index, memref<?x?xf32>, memref<?xf32>) -> () | ||
// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]]) : (memref<?x?xf32>, memref<?xf32>) -> () | ||
// CHECK: } | ||
|
||
|
||
// ----- | ||
|
||
func.func private @complex_alloc(%size0 : index, %size1 : index) -> (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) { | ||
%alloc0 = memref.alloc(%size0, %size1) : memref<?x?xf32> | ||
%alloc1 = memref.alloc() : memref<4xf32> | ||
%alloc2 = memref.alloc(%size1) : memref<?xf32> | ||
return %alloc0, %alloc1, %alloc2 : memref<?x?xf32>, memref<4xf32>, memref<?xf32> | ||
} | ||
|
||
func.func @complex_alloc_test(%size0 : index, %size1: index) { | ||
%alloc0, %alloc1, %alloc2 = call @complex_alloc(%size0, %size1) : (index, index) -> (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) | ||
"test.sink"(%alloc0, %alloc1, %alloc2) : (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> () | ||
} | ||
|
||
// CHECK-LABEL: func private @complex_alloc( | ||
// CHECK-SAME: %{{.*}}: index, %{{.*}}: index, | ||
// CHECK-SAME: %{{.*}}: memref<?x?xf32>, | ||
// CHECK-SAME: %{{.*}}: memref<4xf32>, | ||
// CHECK-SAME: %{{.*}}: memref<?xf32>) { | ||
|
||
// CHECK-LABEL: func @complex_alloc_test( | ||
// CHECK-SAME: %[[size0:.*]]: index, | ||
// CHECK-SAME: %[[size1:.*]]: index) { | ||
// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xf32> | ||
// CHECK: %[[alloc1:.*]] = memref.alloc() : memref<4xf32> | ||
// CHECK: %[[alloc2:.*]] = memref.alloc(%[[size1]]) : memref<?xf32> | ||
// CHECK: call @complex_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]], %[[alloc2]]) : (index, index, memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> () | ||
// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]], %[[alloc2]]) : (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> () | ||
// CHECK: } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe add a "no-op" case where it's impossible to hoist? like when a dynamic size is defined inside the callee func. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps we could introduce such examples in subsequent PRs, such as support for constant ops. I believe it is also acceptable to introduce such an example at this point. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. although i guess in some cases it would be possible, but non-trivial. but it doesn't look like this option handles that case so would be good to track that behavior in a test There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
sure. i'm not suggesting handling this case now. but more to show this pass won't break in that case in the meantime. maybe i'm just being too cautious though |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you check if your implementation also works with `memref.realloc`, which implements this interface? The first operand is not a size, but I think it does not matter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
run pass
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
run pass
Do I need to add a test for memref.realloc?