Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] BufferResultsToOutParams: Allow to configure memCpyFn #83389

Closed
wants to merge 0 commits into from

Conversation

mgehre-amd
Copy link
Contributor

@mgehre-amd mgehre-amd commented Feb 29, 2024

This allows us to configure the pass to emit linalg.copy instead of memref.copy.

This is consistent with one-shot-bufferize, which also allows to configure the memCpyFn, see https://discord.com/channels/636084430946959380/642426447167881246/1211698722438783087

There is no easy way to add a test for this new option, because
a) the pass doesn't define it's current options in the the tablegen file to make them available to mlir-opt
b) even if they were defined there, and if we could have a string option for memCpyFn, turning that string into the std::function would mean that the pass would now need to depend on the linalg dialect, which seems wrong.

@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Feb 29, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 29, 2024

@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir

Author: Matthias Gehre (mgehre-amd)

Changes

This allows us to configure the pass to emit linalg.copy instead of memref.copy.

This is consistent with one-shot-bufferize, which also allows to configure the memCpyFn, see https://discord.com/channels/636084430946959380/642426447167881246/1211698722438783087


Full diff: https://github.com/llvm/llvm-project/pull/83389.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h (+8)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp (+21-7)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index bb4b5221981638..809f03407258a8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -149,11 +149,19 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
 // Options struct for BufferResultsToOutParams pass.
 // Note: defined only here, not in tablegen.
 struct BufferResultsToOutParamsOptions {
+  /// Memcpy function: Generate a memcpy between two memrefs.
+  using MemCpyFn =
+      std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+
   // Filter function; returns true if the function should be converted.
   // Defaults to true, i.e. all functions are converted.
   llvm::function_ref<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
     return true;
   };
+
+  /// Memcpy function; used to create a copy between two memrefs.
+  /// If this is empty, memref.copy is used.
+  std::optional<MemCpyFn> memCpyFn;
 };
 
 /// Creates a pass that converts memref function results to out-params.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index dd359c2dcca5dd..930f035339c1d3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -21,6 +21,7 @@ namespace bufferization {
 } // namespace mlir
 
 using namespace mlir;
+using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
 
 /// Return `true` if the given MemRef type has a fully dynamic layout.
 static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -97,9 +98,10 @@ updateFuncOp(func::FuncOp func,
 // Updates all ReturnOps in the scope of the given func::FuncOp by either
 // keeping them as return values or copying the associated buffer contents into
 // the given out-params.
-static void updateReturnOps(func::FuncOp func,
-                            ArrayRef<BlockArgument> appendedEntryArgs) {
-  func.walk([&](func::ReturnOp op) {
+static LogicalResult updateReturnOps(func::FuncOp func,
+                                     ArrayRef<BlockArgument> appendedEntryArgs,
+                                     MemCpyFn memCpyFn) {
+  auto res = func.walk([&](func::ReturnOp op) {
     SmallVector<Value, 6> copyIntoOutParams;
     SmallVector<Value, 6> keepAsReturnOperands;
     for (Value operand : op.getOperands()) {
@@ -109,12 +111,16 @@ static void updateReturnOps(func::FuncOp func,
         keepAsReturnOperands.push_back(operand);
     }
     OpBuilder builder(op);
-    for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
-      builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
-                                     std::get<1>(t));
+    for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
+      if (failed(
+              memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
+        return WalkResult::interrupt();
+    }
     builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
     op.erase();
+    return WalkResult::advance();
   });
+  return failure(res.wasInterrupted());
 }
 
 // Updates all CallOps in the scope of the given ModuleOp by allocating
@@ -192,7 +198,15 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
       return failure();
     if (func.isExternal())
       continue;
-    updateReturnOps(func, appendedEntryArgs);
+    auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
+                              Value to) {
+      builder.create<memref::CopyOp>(loc, from, to);
+      return success();
+    };
+    if (failed(updateReturnOps(func, appendedEntryArgs,
+                               options.memCpyFn.value_or(defaultMemCpyFn)))) {
+      return failure();
+    }
   }
   if (failed(updateCalls(module, options)))
     return failure();

Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

@mgehre-amd
Copy link
Contributor Author

Friendly ping, @matthias-springer. Thanks!

@mgehre-amd
Copy link
Contributor Author

This is actually merged in afac64c, but github seems to be confused about it.

@mgehre-amd mgehre-amd deleted the matthias.memcpyfn branch March 7, 2024 10:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants