From 36fae81e926e951fd9f0aeb48205afd2bb4b3ede Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Tue, 21 Oct 2025 16:09:32 -0700 Subject: [PATCH] [flang] Update target rewrite to support workgroup and private attributions --- flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 24 +++++++-- flang/test/Fir/CUDA/cuda-target-rewrite.mlir | 53 +++++++++++++++++++ flang/tools/fir-opt/fir-opt.cpp | 1 + 3 files changed, 74 insertions(+), 4 deletions(-) diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index ac285b5d403df..0776346870c72 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -872,6 +872,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase { } } + // Count the number of arguments that have to stay in place at the end of + // the argument list. + unsigned trailingArgs = 0; + if constexpr (std::is_same_v) { + trailingArgs = + func.getNumWorkgroupAttributions() + func.getNumPrivateAttributions(); + } + // Convert return value(s) for (auto ty : funcTy.getResults()) llvm::TypeSwitch(ty) @@ -981,6 +989,16 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase { } } + // Add the argument at the end if the number of trailing arguments is 0, + // otherwise insert the argument at the appropriate index. + auto addOrInsertArgument = [&](mlir::Type ty, mlir::Location loc) { + unsigned inputIndex = func.front().getArguments().size() - trailingArgs; + auto newArg = trailingArgs == 0 + ? func.front().addArgument(ty, loc) + : func.front().insertArgument(inputIndex, ty, loc); + return newArg; + }; + if (!func.empty()) { // If the function has a body, then apply the fixups to the arguments and // return ops as required. These fixups are done in place. @@ -1117,8 +1135,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase { // original arguments. (Boxchar arguments.) auto newBufArg = func.front().insertArgument(fixup.index, fixupType, loc); - auto newLenArg = - func.front().addArgument(trailingTys[fixup.second], loc); + auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc); auto boxTy = oldArgTys[fixup.index - offset]; rewriter->setInsertionPointToStart(&func.front()); auto box = fir::EmboxCharOp::create(*rewriter, loc, boxTy, newBufArg, @@ -1133,8 +1150,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase { // appended after all the original arguments. auto newProcPointerArg = func.front().insertArgument(fixup.index, fixupType, loc); - auto newLenArg = - func.front().addArgument(trailingTys[fixup.second], loc); + auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc); auto tupleType = oldArgTys[fixup.index - offset]; rewriter->setInsertionPointToStart(&func.front()); fir::FirOpBuilder builder(*rewriter, getModule()); diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir index a334934f31723..48fee10f3db97 100644 --- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir +++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir @@ -55,3 +55,56 @@ func.func @main(%arg0: complex) { // CHECK-SAME: (%arg0: f64, %arg1: f64) kernel { // CHECK: gpu.return // CHECK: gpu.launch_func @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64) {cuf.proc_attr = #cuf.cuda_proc} + +// ----- + +module attributes {gpu.container_module, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} { + gpu.module @testmod { + gpu.func @_QMbarPfoo(%arg0: f32, %arg1: !fir.ref>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) { + %c0 = arith.constant 0 : index + memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space> + gpu.return + } +// CHECK-LABEL: gpu.func @_QMbarPfoo( +// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref>, %[[CHAR:.*]]: !fir.ref>, %[[LENGTH:.*]]: i64) workgroup(%[[WORKGROUP:.*]] : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) { +// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref>, i64) -> !fir.boxchar<1> +// CHECK: memref.store %{{.*}}, %[[WORKGROUP]][%{{.*}}] : memref<1xf32, #gpu.address_space> + + gpu.func @_QMbarPfoo2(%arg0: f32, %arg1: !fir.ref>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}, %arg4 : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) { + %c0 = arith.constant 0 : index + memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space> + memref.store %arg0, %arg4[%c0] : memref<1xf32, #gpu.address_space> + gpu.return + } +// CHECK-LABEL: gpu.func @_QMbarPfoo2( +// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref>, %[[CHAR:.*]]: !fir.ref>, %[[LENGTH:.*]]: i64) workgroup(%[[WG1:.*]] : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}, %[[WG2:.*]] : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) { +// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref>, i64) -> !fir.boxchar<1> +// CHECK: memref.store %{{.*}}, %[[WG1]][%{{.*}}] : memref<1xf32, #gpu.address_space> +// CHECK: memref.store %{{.*}}, %[[WG2]][%{{.*}}] : memref<1xf32, #gpu.address_space> + + gpu.func @_QMbarPprivate(%arg0: f32, %arg1: !fir.boxchar<1>) workgroup(%arg2 : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) private(%arg3 : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) { + %c0 = arith.constant 0 : index + memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space> + memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space> + gpu.return + } +// CHECK-LABEL: gpu.func @_QMbarPprivate( +// CHECK-SAME: %{{.*}}: f32, %[[CHAR:.*]]: !fir.ref>, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) private(%[[PRIVATE:.*]] : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) { +// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref>, i64) -> !fir.boxchar<1> +// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space> +// CHECK: memref.store %{{.*}}, %[[PRIVATE]][%{{.*}}] : memref<1xf32, #gpu.address_space> + + gpu.func @test_with_char_proc(%arg0: f32, %arg1: tuple<() -> (), i64> {fir.char_proc}) workgroup(%arg2 : memref<1xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space> + gpu.return + } +// CHECK-LABEL: gpu.func @test_with_char_proc( +// CHECK-SAME: %{{.*}}: f32, %[[CHARPROC:.*]]: () -> () {fir.char_proc}, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space>) { +// CHECK: %{{.*}} = fir.undefined tuple<() -> (), i64> +// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[CHARPROC]], [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64> +// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[LENGTH]], [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64> +// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space> + } +} + diff --git a/flang/tools/fir-opt/fir-opt.cpp b/flang/tools/fir-opt/fir-opt.cpp index 32b0a1dfa5c7a..67d07eee1f4fc 100644 --- a/flang/tools/fir-opt/fir-opt.cpp +++ b/flang/tools/fir-opt/fir-opt.cpp @@ -50,6 +50,7 @@ int main(int argc, char **argv) { #endif DialectRegistry registry; fir::support::registerDialects(registry); + registry.insert(); fir::support::addFIRExtensions(registry); return failed(MlirOptMain(argc, argv, "FIR modular optimizer driver\n", registry));