-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[flang] Update target rewrite to support workgroup and private attributions #164515
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-codegen Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: Some operations, like gpu.func, have arguments that need to stay in place while the signature is rewritten. This is the case for the workgroup and private attributions. Full diff: https://github.com/llvm/llvm-project/pull/164515.diff 3 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index ac285b5d403df..0776346870c72 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -872,6 +872,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}
+ // Count the number of arguments that have to stay in place at the end of
+ // the argument list.
+ unsigned trailingArgs = 0;
+ if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>) {
+ trailingArgs =
+ func.getNumWorkgroupAttributions() + func.getNumPrivateAttributions();
+ }
+
// Convert return value(s)
for (auto ty : funcTy.getResults())
llvm::TypeSwitch<mlir::Type>(ty)
@@ -981,6 +989,16 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}
+ // Add the argument at the end if the number of trailing arguments is 0,
+ // otherwise insert the argument at the appropriate index.
+ auto addOrInsertArgument = [&](mlir::Type ty, mlir::Location loc) {
+ unsigned inputIndex = func.front().getArguments().size() - trailingArgs;
+ auto newArg = trailingArgs == 0
+ ? func.front().addArgument(ty, loc)
+ : func.front().insertArgument(inputIndex, ty, loc);
+ return newArg;
+ };
+
if (!func.empty()) {
// If the function has a body, then apply the fixups to the arguments and
// return ops as required. These fixups are done in place.
@@ -1117,8 +1135,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// original arguments. (Boxchar arguments.)
auto newBufArg =
func.front().insertArgument(fixup.index, fixupType, loc);
- auto newLenArg =
- func.front().addArgument(trailingTys[fixup.second], loc);
+ auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
auto boxTy = oldArgTys[fixup.index - offset];
rewriter->setInsertionPointToStart(&func.front());
auto box = fir::EmboxCharOp::create(*rewriter, loc, boxTy, newBufArg,
@@ -1133,8 +1150,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// appended after all the original arguments.
auto newProcPointerArg =
func.front().insertArgument(fixup.index, fixupType, loc);
- auto newLenArg =
- func.front().addArgument(trailingTys[fixup.second], loc);
+ auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
auto tupleType = oldArgTys[fixup.index - offset];
rewriter->setInsertionPointToStart(&func.front());
fir::FirOpBuilder builder(*rewriter, getModule());
diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
index a334934f31723..48fee10f3db97 100644
--- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
+++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
@@ -55,3 +55,56 @@ func.func @main(%arg0: complex<f64>) {
// CHECK-SAME: (%arg0: f64, %arg1: f64) kernel {
// CHECK: gpu.return
// CHECK: gpu.launch_func @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64) {cuf.proc_attr = #cuf.cuda_proc<global>}
+
+// -----
+
+module attributes {gpu.container_module, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+ gpu.module @testmod {
+ gpu.func @_QMbarPfoo(%arg0: f32, %arg1: !fir.ref<!fir.array<100xf32>>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
+ %c0 = arith.constant 0 : index
+ memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+ gpu.return
+ }
+// CHECK-LABEL: gpu.func @_QMbarPfoo(
+// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref<!fir.array<100xf32>>, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WORKGROUP:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
+// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
+// CHECK: memref.store %{{.*}}, %[[WORKGROUP]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+
+ gpu.func @_QMbarPfoo2(%arg0: f32, %arg1: !fir.ref<!fir.array<100xf32>>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}, %arg4 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
+ %c0 = arith.constant 0 : index
+ memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+ memref.store %arg0, %arg4[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+ gpu.return
+ }
+// CHECK-LABEL: gpu.func @_QMbarPfoo2(
+// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref<!fir.array<100xf32>>, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WG1:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}, %[[WG2:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
+// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
+// CHECK: memref.store %{{.*}}, %[[WG1]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+// CHECK: memref.store %{{.*}}, %[[WG2]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+
+ gpu.func @_QMbarPprivate(%arg0: f32, %arg1: !fir.boxchar<1>) workgroup(%arg2 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) private(%arg3 : memref<1xf32, #gpu.address_space<private>> {llvm.align = 16 : i32}) {
+ %c0 = arith.constant 0 : index
+ memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+ memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<private>>
+ gpu.return
+ }
+// CHECK-LABEL: gpu.func @_QMbarPprivate(
+// CHECK-SAME: %{{.*}}: f32, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) private(%[[PRIVATE:.*]] : memref<1xf32, #gpu.address_space<private>> {llvm.align = 16 : i32}) {
+// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
+// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+// CHECK: memref.store %{{.*}}, %[[PRIVATE]][%{{.*}}] : memref<1xf32, #gpu.address_space<private>>
+
+ gpu.func @test_with_char_proc(%arg0: f32, %arg1: tuple<() -> (), i64> {fir.char_proc}) workgroup(%arg2 : memref<1xf32, #gpu.address_space<workgroup>>) {
+ %c0 = arith.constant 0 : index
+ memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+ gpu.return
+ }
+// CHECK-LABEL: gpu.func @test_with_char_proc(
+// CHECK-SAME: %{{.*}}: f32, %[[CHARPROC:.*]]: () -> () {fir.char_proc}, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space<workgroup>>) {
+// CHECK: %{{.*}} = fir.undefined tuple<() -> (), i64>
+// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[CHARPROC]], [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64>
+// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[LENGTH]], [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64>
+// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+ }
+}
+
diff --git a/flang/tools/fir-opt/fir-opt.cpp b/flang/tools/fir-opt/fir-opt.cpp
index 32b0a1dfa5c7a..67d07eee1f4fc 100644
--- a/flang/tools/fir-opt/fir-opt.cpp
+++ b/flang/tools/fir-opt/fir-opt.cpp
@@ -50,6 +50,7 @@ int main(int argc, char **argv) {
#endif
DialectRegistry registry;
fir::support::registerDialects(registry);
+ registry.insert<mlir::memref::MemRefDialect>();
fir::support::addFIRExtensions(registry);
return failed(MlirOptMain(argc, argv, "FIR modular optimizer driver\n",
registry));
|
razvanlupusoru
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
jeanPerier
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
[flang] Update target rewrite to support workgroup and private attributions (llvm#164515) Some operations, like gpu.func, have arguments that need to stay in place while the signature is rewritten. This is the case for the workgroup and private attributions. Update the target rewrite pass to be aware of that when adding an argument at the end of the function signature. If any trailing arguments are present, the new argument is inserted just before them.
[flang] Update target rewrite to support workgroup and private attributions (llvm#164515) Some operations, like gpu.func, have arguments that need to stay in place while the signature is rewritten. This is the case for the workgroup and private attributions. Update the target rewrite pass to be aware of that when adding an argument at the end of the function signature. If any trailing arguments are present, the new argument is inserted just before them.
[flang] Update target rewrite to support workgroup and private attributions (llvm#164515) Some operations, like gpu.func, have arguments that need to stay in place while the signature is rewritten. This is the case for the workgroup and private attributions. Update the target rewrite pass to be aware of that when adding an argument at the end of the function signature. If any trailing arguments are present, the new argument is inserted just before them.
Some operations, like gpu.func, have arguments that need to stay in place while the signature is rewritten. This is the case for the workgroup and private attributions.
Update the target rewrite pass to be aware of that when adding an argument at the end of the function signature. If any trailing arguments are present, the new argument is inserted just before them.