-
Notifications
You must be signed in to change notification settings - Fork 11.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] BufferResultsToOutParams: Allow to configure memCpyFn #83389
Conversation
@llvm/pr-subscribers-mlir-bufferization @llvm/pr-subscribers-mlir Author: Matthias Gehre (mgehre-amd) ChangesThis allows us to configure the pass to emit This is consistent with Full diff: https://github.com/llvm/llvm-project/pull/83389.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index bb4b5221981638..809f03407258a8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -149,11 +149,19 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
struct BufferResultsToOutParamsOptions {
+ /// Memcpy function: Generate a memcpy between two memrefs.
+ using MemCpyFn =
+ std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+
// Filter function; returns true if the function should be converted.
// Defaults to true, i.e. all functions are converted.
llvm::function_ref<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
return true;
};
+
+ /// Memcpy function; used to create a copy between two memrefs.
+ /// If this is empty, memref.copy is used.
+ std::optional<MemCpyFn> memCpyFn;
};
/// Creates a pass that converts memref function results to out-params.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index dd359c2dcca5dd..930f035339c1d3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,6 +21,7 @@ namespace bufferization {
} // namespace mlir
using namespace mlir;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -97,9 +98,10 @@ updateFuncOp(func::FuncOp func,
// Updates all ReturnOps in the scope of the given func::FuncOp by either
// keeping them as return values or copying the associated buffer contents into
// the given out-params.
-static void updateReturnOps(func::FuncOp func,
- ArrayRef<BlockArgument> appendedEntryArgs) {
- func.walk([&](func::ReturnOp op) {
+static LogicalResult updateReturnOps(func::FuncOp func,
+ ArrayRef<BlockArgument> appendedEntryArgs,
+ MemCpyFn memCpyFn) {
+ auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
for (Value operand : op.getOperands()) {
@@ -109,12 +111,16 @@ static void updateReturnOps(func::FuncOp func,
keepAsReturnOperands.push_back(operand);
}
OpBuilder builder(op);
- for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
- builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
- std::get<1>(t));
+ for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
+ if (failed(
+ memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
+ return WalkResult::interrupt();
+ }
builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
op.erase();
+ return WalkResult::advance();
});
+ return failure(res.wasInterrupted());
}
// Updates all CallOps in the scope of the given ModuleOp by allocating
@@ -192,7 +198,15 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
if (func.isExternal())
continue;
- updateReturnOps(func, appendedEntryArgs);
+ auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
+ Value to) {
+ builder.create<memref::CopyOp>(loc, from, to);
+ return success();
+ };
+ if (failed(updateReturnOps(func, appendedEntryArgs,
+ options.memCpyFn.value_or(defaultMemCpyFn)))) {
+ return failure();
+ }
}
if (failed(updateCalls(module, options)))
return failure();
|
|
0e3952c
to
fda142f
Compare
Friendly ping, @matthias-springer. Thanks! |
fda142f
to
afac64c
Compare
This is actually merged in afac64c, but github seems to be confused about it. |
This allows us to configure the pass to emit
linalg.copy
instead ofmemref.copy
.This is consistent with
one-shot-bufferize
, which also allows to configure thememCpyFn
, see https://discord.com/channels/636084430946959380/642426447167881246/1211698722438783087There is no easy way to add a test for this new option, because
a) the pass doesn't define it's current options in the the tablegen file to make them available to mlir-opt
b) even if they were defined there, and if we could have a string option for memCpyFn, turning that string into the
std::function
would mean that the pass would now need to depend on the linalg dialect, which seems wrong.