Skip to content

Conversation

@clementval
Copy link
Contributor

Some operations like the gpu.func have arguments that need to stay in place while rewriting the signature. This is the case for the workgroup and private attribution.
Update the target rewrite pass to be aware of that when adding argument at the end of the function signature. If any trailing arguments are present, the new argument will be inserted just before them.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Oct 21, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 21, 2025

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-codegen

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Some operations like the gpu.func have arguments that need to stay in place while rewriting the signature. This is the case for the workgroup and private attribution.
Update the target rewrite pass to be aware of that when adding argument at the end of the function signature. If any trailing arguments are present, the new argument will be inserted just before them.


Full diff: https://github.com/llvm/llvm-project/pull/164515.diff

3 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/TargetRewrite.cpp (+20-4)
  • (modified) flang/test/Fir/CUDA/cuda-target-rewrite.mlir (+53)
  • (modified) flang/tools/fir-opt/fir-opt.cpp (+1)
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index ac285b5d403df..0776346870c72 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -872,6 +872,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       }
     }
 
+    // Count the number of arguments that have to stay in place at the end of
+    // the argument list.
+    unsigned trailingArgs = 0;
+    if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>) {
+      trailingArgs =
+          func.getNumWorkgroupAttributions() + func.getNumPrivateAttributions();
+    }
+
     // Convert return value(s)
     for (auto ty : funcTy.getResults())
       llvm::TypeSwitch<mlir::Type>(ty)
@@ -981,6 +989,16 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       }
     }
 
+    // Add the argument at the end if the number of trailing arguments is 0,
+    // otherwise insert the argument at the appropriate index.
+    auto addOrInsertArgument = [&](mlir::Type ty, mlir::Location loc) {
+      unsigned inputIndex = func.front().getArguments().size() - trailingArgs;
+      auto newArg = trailingArgs == 0
+                        ? func.front().addArgument(ty, loc)
+                        : func.front().insertArgument(inputIndex, ty, loc);
+      return newArg;
+    };
+
     if (!func.empty()) {
       // If the function has a body, then apply the fixups to the arguments and
       // return ops as required. These fixups are done in place.
@@ -1117,8 +1135,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
           // original arguments. (Boxchar arguments.)
           auto newBufArg =
               func.front().insertArgument(fixup.index, fixupType, loc);
-          auto newLenArg =
-              func.front().addArgument(trailingTys[fixup.second], loc);
+          auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
           auto boxTy = oldArgTys[fixup.index - offset];
           rewriter->setInsertionPointToStart(&func.front());
           auto box = fir::EmboxCharOp::create(*rewriter, loc, boxTy, newBufArg,
@@ -1133,8 +1150,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
           // appended after all the original arguments.
           auto newProcPointerArg =
               func.front().insertArgument(fixup.index, fixupType, loc);
-          auto newLenArg =
-              func.front().addArgument(trailingTys[fixup.second], loc);
+          auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
           auto tupleType = oldArgTys[fixup.index - offset];
           rewriter->setInsertionPointToStart(&func.front());
           fir::FirOpBuilder builder(*rewriter, getModule());
diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
index a334934f31723..48fee10f3db97 100644
--- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
+++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
@@ -55,3 +55,56 @@ func.func @main(%arg0: complex<f64>) {
 // CHECK-SAME: (%arg0: f64, %arg1: f64) kernel {
 // CHECK: gpu.return
 // CHECK: gpu.launch_func  @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64) {cuf.proc_attr = #cuf.cuda_proc<global>}
+
+// -----
+
+module attributes {gpu.container_module, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+  gpu.module @testmod {
+    gpu.func @_QMbarPfoo(%arg0: f32, %arg1: !fir.ref<!fir.array<100xf32>>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
+      %c0 = arith.constant 0 : index
+      memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+      gpu.return
+    }
+// CHECK-LABEL: gpu.func @_QMbarPfoo(
+// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref<!fir.array<100xf32>>, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WORKGROUP:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
+// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
+// CHECK: memref.store %{{.*}}, %[[WORKGROUP]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+
+    gpu.func @_QMbarPfoo2(%arg0: f32, %arg1: !fir.ref<!fir.array<100xf32>>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}, %arg4 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
+      %c0 = arith.constant 0 : index
+      memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+      memref.store %arg0, %arg4[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+      gpu.return
+    }
+// CHECK-LABEL: gpu.func @_QMbarPfoo2(
+// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref<!fir.array<100xf32>>, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WG1:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}, %[[WG2:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
+// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
+// CHECK: memref.store %{{.*}}, %[[WG1]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+// CHECK: memref.store %{{.*}}, %[[WG2]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+
+    gpu.func @_QMbarPprivate(%arg0: f32, %arg1: !fir.boxchar<1>) workgroup(%arg2 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) private(%arg3 : memref<1xf32, #gpu.address_space<private>> {llvm.align = 16 : i32}) {
+      %c0 = arith.constant 0 : index
+      memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+      memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<private>>
+      gpu.return
+    }
+// CHECK-LABEL: gpu.func @_QMbarPprivate(
+// CHECK-SAME: %{{.*}}: f32, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) private(%[[PRIVATE:.*]] : memref<1xf32, #gpu.address_space<private>> {llvm.align = 16 : i32}) {
+// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
+// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+// CHECK: memref.store %{{.*}}, %[[PRIVATE]][%{{.*}}] : memref<1xf32, #gpu.address_space<private>>
+    
+    gpu.func @test_with_char_proc(%arg0: f32, %arg1: tuple<() -> (), i64> {fir.char_proc}) workgroup(%arg2 : memref<1xf32, #gpu.address_space<workgroup>>) {
+      %c0 = arith.constant 0 : index
+      memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+      gpu.return
+    }
+// CHECK-LABEL: gpu.func @test_with_char_proc(
+// CHECK-SAME: %{{.*}}: f32, %[[CHARPROC:.*]]: () -> () {fir.char_proc}, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space<workgroup>>) {
+// CHECK: %{{.*}} = fir.undefined tuple<() -> (), i64>
+// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[CHARPROC]], [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64>
+// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[LENGTH]], [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64>
+// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+  }
+}
+
diff --git a/flang/tools/fir-opt/fir-opt.cpp b/flang/tools/fir-opt/fir-opt.cpp
index 32b0a1dfa5c7a..67d07eee1f4fc 100644
--- a/flang/tools/fir-opt/fir-opt.cpp
+++ b/flang/tools/fir-opt/fir-opt.cpp
@@ -50,6 +50,7 @@ int main(int argc, char **argv) {
 #endif
   DialectRegistry registry;
   fir::support::registerDialects(registry);
+  registry.insert<mlir::memref::MemRefDialect>();
   fir::support::addFIRExtensions(registry);
   return failed(MlirOptMain(argc, argv, "FIR modular optimizer driver\n",
       registry));

Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@clementval clementval merged commit 47ea854 into llvm:main Oct 22, 2025
14 checks passed
@clementval clementval deleted the gpu_workgroup_rewrite branch October 22, 2025 16:49
dvbuka pushed a commit to dvbuka/llvm-project that referenced this pull request Oct 27, 2025
…utions (llvm#164515)

Some operations like the gpu.func have arguments that need to stay in
place while rewriting the signature. This is the case for the workgroup
and private attribution.
Update the target rewrite pass to be aware of that when adding argument
at the end of the function signature. If any trailing arguments are
present, the new argument will be inserted just before them.
Lukacma pushed a commit to Lukacma/llvm-project that referenced this pull request Oct 29, 2025
…utions (llvm#164515)

Some operations like the gpu.func have arguments that need to stay in
place while rewriting the signature. This is the case for the workgroup
and private attribution.
Update the target rewrite pass to be aware of that when adding argument
at the end of the function signature. If any trailing arguments are
present, the new argument will be inserted just before them.
aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 30, 2025
…utions (llvm#164515)

Some operations like the gpu.func have arguments that need to stay in
place while rewriting the signature. This is the case for the workgroup
and private attribution.
Update the target rewrite pass to be aware of that when adding argument
at the end of the function signature. If any trailing arguments are
present, the new argument will be inserted just before them.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants