Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ struct BufferResultsToOutParamsOpts {
/// Allocator function: Generate a memref allocation with the given type.
/// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
/// results, we don't allow passing a range of values for dynamic dims.
using AllocationFn =
std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;
using AllocationFn = std::function<FailureOr<Value>(OpBuilder &, Location,
MemRefType, ValueRange)>;

/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
Expand All @@ -147,8 +147,9 @@ struct BufferResultsToOutParamsOpts {
/// Allocation function; used to allocate a memref.
/// Default memref.alloc is used
AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
MemRefType type) {
return memref::AllocOp::create(builder, loc, type).getResult();
MemRefType type, ValueRange dynamicSizes) {
return memref::AllocOp::create(builder, loc, type, dynamicSizes)
.getResult();
};

/// Memcpy function; used to create a copy between two memrefs.
Expand All @@ -166,6 +167,10 @@ struct BufferResultsToOutParamsOpts {
/// If true, the pass eliminates the memref.alloc and memcpy if the returned
/// memref is allocated in the current function.
bool hoistStaticAllocs = false;

/// If true, the pass eliminates the memref.alloc and memcpy if the returned
/// memref is allocated in the current function and has dynamic shape.
bool hoistDynamicAllocs = false;
};

/// Replace buffers that are returned from a function with an out parameter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ def BufferResultsToOutParamsPass
"Add the attribute 'bufferize.result' to all output parameters.">,
Option<"hoistStaticAllocs", "hoist-static-allocs", "bool",
/*default=*/"false", "Hoist static allocations to call sites.">,
Option<"hoistDynamicAllocs", "hoist-dynamic-allocs", "bool",
/*default=*/"false", "Hoist dynamic allocations to call sites.">,
];
let dependentDialects = ["memref::MemRefDialect"];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ namespace bufferization {
using namespace mlir;
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
using AllocDynamicSizesMap =
llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>>;

/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
Expand All @@ -43,6 +45,50 @@ static bool hasStaticIdentityLayout(MemRefType type) {
return type.getLayout().isIdentity();
}

/// Return the dynamic shapes of the `memref` based on the defining op. If the
/// complete dynamic shape fails to be captured, return an empty value.
/// Currently, only function block arguments are supported for capturing.
static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
Operation *defOp = memref.getDefiningOp();
if (!defOp)
return {};
auto operands = defOp->getOperands();
SmallVector<Value> dynamicSizes;
for (Value size : operands) {
if (!isa<IndexType>(size.getType()))
continue;

BlockArgument sizeSrc = dyn_cast<BlockArgument>(size);
if (!sizeSrc)
return {};
auto arguments = funcOp.getArguments();
auto iter = llvm::find(arguments, sizeSrc);
if (iter == arguments.end())
return {};
dynamicSizes.push_back(*iter);
}
return dynamicSizes;
}

/// Returns the dynamic sizes at the callee, through the call relationship
/// between the caller and callee.
static SmallVector<Value> mapDynamicSizeAtCaller(func::CallOp call,
func::FuncOp callee,
ValueRange dynamicSizes) {
SmallVector<Value> mappedDynamicSizes;
for (Value size : dynamicSizes) {
for (auto [src, dst] :
llvm::zip_first(call.getOperands(), callee.getArguments())) {
if (size != dst)
continue;
mappedDynamicSizes.push_back(src);
}
}
assert(mappedDynamicSizes.size() == dynamicSizes.size() &&
"could not find all dynamic sizes");
return mappedDynamicSizes;
}

// Updates the func op and entry block.
//
// Any args appended to the entry block are added to `appendedEntryArgs`.
Expand Down Expand Up @@ -109,6 +155,7 @@ updateFuncOp(func::FuncOp func,
// the given out-params.
static LogicalResult
updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
AllocDynamicSizesMap &map,
const bufferization::BufferResultsToOutParamsOpts &options) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
Expand All @@ -120,12 +167,22 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
keepAsReturnOperands.push_back(operand);
}
OpBuilder builder(op);
SmallVector<SmallVector<Value>> dynamicSizes;
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
if (options.hoistStaticAllocs &&
bool hoistStaticAllocs =
options.hoistStaticAllocs &&
cast<MemRefType>(orig.getType()).hasStaticShape();
bool hoistDynamicAllocs =
options.hoistDynamicAllocs &&
!cast<MemRefType>(orig.getType()).hasStaticShape();
if ((hoistStaticAllocs || hoistDynamicAllocs) &&
isa_and_nonnull<bufferization::AllocationOpInterface>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check if your implementation also works with memref.realloc, which implements this interface. The first operand is not a size, but I think it does not matter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

func.func private @realloc(%memref : memref<?xf32>, %d:index) -> memref<?xf32> {
  %alloc = memref.realloc %memref(%d) : memref<?xf32> to memref<?xf32>
  return %alloc : memref<?xf32>
}

func.func private @main(%d:index) {
  %c1 = arith.constant 1 : index
  %alloc = memref.alloc(%c1) : memref<?xf32>
  func.call @realloc(%alloc, %d) : (memref<?xf32>, index) -> (memref<?xf32>)
  return
}

run pass

a.mlir:9:3: error: 'memref.alloc' op operand #0 must be variadic of index, but got 'memref<?xf32>'
  func.call @realloc(%alloc, %d) : (memref<?xf32>, index) -> (memref<?xf32>)
  ^
a.mlir:9:3: note: see current operation: %2 = "memref.alloc"(%1, %arg0) <{operandSegmentSizes = array<i32: 2, 0>}> : (memref<?xf32>, index) -> memref<?xf32>

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

func.func private @realloc(%memref : memref<?xf32>, %d:index) -> memref<?xf32> {
  %alloc = memref.realloc %memref(%d) : memref<?xf32> to memref<?xf32>
  return %alloc : memref<?xf32>
}

func.func private @main(%d:index) {
  %c1 = arith.constant 1 : index
  %alloc = memref.alloc(%c1) : memref<?xf32>
  func.call @realloc(%alloc, %d) : (memref<?xf32>, index) -> (memref<?xf32>)
  return
}

run pass

module {
  func.func private @realloc(%arg0: memref<?xf32>, %arg1: index, %arg2: memref<?xf32>) {
    return
  }
  func.func private @main(%arg0: index) {
    %c1 = arith.constant 1 : index
    %alloc = memref.alloc(%c1) : memref<?xf32>
    %alloc_0 = memref.alloc(%arg0) : memref<?xf32>
    call @realloc(%alloc, %arg0, %alloc_0) : (memref<?xf32>, index, memref<?xf32>) -> ()
    return
  }
}

Do I need to add realloc test for realloc?

orig.getDefiningOp()) &&
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
orig.getDefiningOp())) {
orig.replaceAllUsesWith(arg);
if (hoistDynamicAllocs) {
SmallVector<Value> dynamicSize = getDynamicSize(orig, func);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when the sizes could not be captured? You already performed the replaceAllUsesWith. Is that correct? Should the hoisting of the value be skipped instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

func.func  @foo() -> memref<?xf32> {
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %f1 = arith.constant 1.0 : f32
  %alloc = memref.alloc(%c1) : memref<?xf32>
  memref.store %f1, %alloc[%c0] : memref<?xf32>
  return %alloc : memref<?xf32>
}

run the pass

module {
  func.func private @foo(%arg0: memref<?xf32>) {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 1.000000e+00 : f32
    memref.store %cst, %arg0[%c0] : memref<?xf32>
    return
  }
}

can't find dynamic size case.

func.func private @foo() -> memref<?xf32> {
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %f1 = arith.constant 1.0 : f32
  %alloc = memref.alloc(%c1) : memref<?xf32>
  memref.store %f1, %alloc[%c0] : memref<?xf32>
  return %alloc : memref<?xf32>
}

func.func @main() {
  func.call @foo() : () -> memref<?xf32>
  return
}

run the pass.
a.mlir:11:3: error: cannot create out param for dynamically shaped result
  func.call @foo() : () -> memref<?xf32>
  ^
a.mlir:11:3: note: see current operation: %0 = "func.call"() <{callee = @foo}> : () -> memref<?xf32>

dynamicSizes.push_back(dynamicSize);
}
orig.getDefiningOp()->erase();
} else {
if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
Expand All @@ -134,6 +191,10 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
}
func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands);
op.erase();
auto dynamicSizePair =
std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
dynamicSizes);
map.insert(dynamicSizePair);
return WalkResult::advance();
});
return failure(res.wasInterrupted());
Expand All @@ -142,7 +203,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
// Updates all CallOps in the scope of the given ModuleOp by allocating
// temporary buffers for newly introduced out params.
static LogicalResult
updateCalls(ModuleOp module,
updateCalls(ModuleOp module, const AllocDynamicSizesMap &map,
const bufferization::BufferResultsToOutParamsOpts &options) {
bool didFail = false;
SymbolTable symtab(module);
Expand All @@ -166,8 +227,15 @@ updateCalls(ModuleOp module,
}
SmallVector<Value, 6> outParams;
OpBuilder builder(op);
SmallVector<SmallVector<Value>> dynamicSizes = map.lookup(callee);
size_t dynamicSizesIndex = 0;
for (Value memref : replaceWithOutParams) {
if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
SmallVector<Value> dynamicSize = dynamicSizes.size() > dynamicSizesIndex
? dynamicSizes[dynamicSizesIndex]
: SmallVector<Value>();
bool memrefStaticShape =
cast<MemRefType>(memref.getType()).hasStaticShape();
if (!memrefStaticShape && dynamicSize.empty()) {
op.emitError()
<< "cannot create out param for dynamically shaped result";
didFail = true;
Expand All @@ -177,8 +245,15 @@ updateCalls(ModuleOp module,
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
AffineMap(), memrefType.getMemorySpace());

if (memrefStaticShape) {
dynamicSize = {};
} else {
++dynamicSizesIndex;
dynamicSize = mapDynamicSizeAtCaller(op, callee, dynamicSize);
}
auto maybeOutParam =
options.allocationFn(builder, op.getLoc(), allocType);
options.allocationFn(builder, op.getLoc(), allocType, dynamicSize);
if (failed(maybeOutParam)) {
op.emitError() << "failed to create allocation op";
didFail = true;
Expand Down Expand Up @@ -213,6 +288,9 @@ updateCalls(ModuleOp module,
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
ModuleOp module,
const bufferization::BufferResultsToOutParamsOpts &options) {
// It maps the shape source of the dynamic shape memref returned by each
// function.
AllocDynamicSizesMap map;
for (auto func : module.getOps<func::FuncOp>()) {
if (!options.filterFn(&func))
continue;
Expand All @@ -222,11 +300,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
if (func.isExternal())
continue;
if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) {
return failure();
}
}
if (failed(updateCalls(module, options)))
if (failed(updateCalls(module, map, options)))
return failure();
return success();
}
Expand All @@ -243,6 +321,8 @@ struct BufferResultsToOutParamsPass
options.addResultAttribute = true;
if (hoistStaticAllocs)
options.hoistStaticAllocs = true;
if (hoistDynamicAllocs)
options.hoistDynamicAllocs = true;

if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
options)))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// RUN: mlir-opt -allow-unregistered-dialect -p 'builtin.module(buffer-results-to-out-params{hoist-dynamic-allocs})' %s -split-input-file | FileCheck %s

func.func private @single_alloc(%size : index) -> (memref<?xf32>) {
%alloc = memref.alloc(%size) : memref<?xf32>
return %alloc : memref<?xf32>
}

func.func @single_alloc_test(%size : index) {
%alloc = call @single_alloc(%size) : (index) -> (memref<?xf32>)
"test.sink"(%alloc) : (memref<?xf32>) -> ()
}

// CHECK-LABEL: func.func private @single_alloc(
// CHECK-SAME: %{{.*}}: index,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should size args be removed from callee when they aren't used after hoisting?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have already considered this issue. This issue has been resolved. #160755

// CHECK-SAME: %{{.*}}: memref<?xf32>) {

// CHECK-LABEL: func.func @single_alloc_test(
// CHECK-SAME: %[[size:.*]]: index) {
// CHECK: %[[alloc:.*]] = memref.alloc(%[[size]]) : memref<?xf32>
// CHECK: call @single_alloc(%[[size]], %[[alloc]]) : (index, memref<?xf32>) -> ()
// CHECK: "test.sink"(%[[alloc]]) : (memref<?xf32>) -> ()
// CHECK: }

// -----

func.func private @mult_alloc(%size0 : index, %size1 : index) -> (memref<?x?xf32>, memref<?xf32>) {
%alloc0 = memref.alloc(%size0, %size1) : memref<?x?xf32>
%alloc1 = memref.alloc(%size1) : memref<?xf32>
return %alloc0, %alloc1 : memref<?x?xf32>, memref<?xf32>
}

func.func @mult_alloc_test(%size0 : index, %size1: index) {
%alloc0, %alloc1 = call @mult_alloc(%size0, %size1) : (index, index) -> (memref<?x?xf32>, memref<?xf32>)
"test.sink"(%alloc0, %alloc1) : (memref<?x?xf32>, memref<?xf32>) -> ()
}

// CHECK-LABEL: func private @mult_alloc(
// CHECK-SAME: %{{.*}}: index, %{{.*}}: index,
// CHECK-SAME: %{{.*}}: memref<?x?xf32>, %{{.*}}: memref<?xf32>) {

// CHECK-LABEL: func @mult_alloc_test(
// CHECK-SAME: %[[size0:.*]]: index,
// CHECK-SAME: %[[size1:.*]]: index) {
// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xf32>
// CHECK: %[[alloc1:.*]] = memref.alloc(%[[size1]]) : memref<?xf32>
// CHECK: call @mult_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]]) : (index, index, memref<?x?xf32>, memref<?xf32>) -> ()
// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]]) : (memref<?x?xf32>, memref<?xf32>) -> ()
// CHECK: }


// -----

func.func private @complex_alloc(%size0 : index, %size1 : index) -> (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) {
%alloc0 = memref.alloc(%size0, %size1) : memref<?x?xf32>
%alloc1 = memref.alloc() : memref<4xf32>
%alloc2 = memref.alloc(%size1) : memref<?xf32>
return %alloc0, %alloc1, %alloc2 : memref<?x?xf32>, memref<4xf32>, memref<?xf32>
}

func.func @complex_alloc_test(%size0 : index, %size1: index) {
%alloc0, %alloc1, %alloc2 = call @complex_alloc(%size0, %size1) : (index, index) -> (memref<?x?xf32>, memref<4xf32>, memref<?xf32>)
"test.sink"(%alloc0, %alloc1, %alloc2) : (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
}

// CHECK-LABEL: func private @complex_alloc(
// CHECK-SAME: %{{.*}}: index, %{{.*}}: index,
// CHECK-SAME: %{{.*}}: memref<?x?xf32>,
// CHECK-SAME: %{{.*}}: memref<4xf32>,
// CHECK-SAME: %{{.*}}: memref<?xf32>) {

// CHECK-LABEL: func @complex_alloc_test(
// CHECK-SAME: %[[size0:.*]]: index,
// CHECK-SAME: %[[size1:.*]]: index) {
// CHECK: %[[alloc0:.*]] = memref.alloc(%[[size0]], %[[size1]]) : memref<?x?xf32>
// CHECK: %[[alloc1:.*]] = memref.alloc() : memref<4xf32>
// CHECK: %[[alloc2:.*]] = memref.alloc(%[[size1]]) : memref<?xf32>
// CHECK: call @complex_alloc(%[[size0]], %[[size1]], %[[alloc0]], %[[alloc1]], %[[alloc2]]) : (index, index, memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
// CHECK: "test.sink"(%[[alloc0]], %[[alloc1]], %[[alloc2]]) : (memref<?x?xf32>, memref<4xf32>, memref<?xf32>) -> ()
// CHECK: }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add a "no-op" case where it's impossible to hoist? like when a dynamic size is defined inside the callee func.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we could introduce such examples in subsequent PRs, such as support for conatsant Op.I believe it is also acceptable to introduce such an example at this point.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

although i guess in some cases it would be possible, but non-trivial. but it doesn't look like this option handles that case so would be good to track that behavior in a test

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we could introduce such examples in subsequent PRs, such as support for constant Op.I believe it is also acceptable to introduce such an example at this point.

sure. i'm not suggesting handling this case now. but more to show this pass won't break in that case in the meantime. maybe i'm just being too cautious though