Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][cuda] Lower attribute for module variables #81226

Merged
merged 3 commits into from
Feb 9, 2024

Conversation

clementval
Copy link
Contributor

Propagate the CUDA attribute to fir.global operation for simple module variables.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Feb 9, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 9, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Propagate the CUDA attribute to fir.global operation for simple module variables.


Full diff: https://github.com/llvm/llvm-project/pull/81226.diff

5 Files Affected:

  • (modified) flang/include/flang/Optimizer/Builder/FIRBuilder.h (+4-2)
  • (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+2-1)
  • (modified) flang/lib/Lower/ConvertVariable.cpp (+11-5)
  • (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+11-9)
  • (modified) flang/test/Lower/CUDA/cuda-data-attribute.cuf (+30-21)
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index 5384f6e8121ec6..f50dacd327a7c0 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -230,12 +230,14 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
                              llvm::StringRef name,
                              mlir::StringAttr linkage = {},
                              mlir::Attribute value = {}, bool isConst = false,
-                             bool isTarget = false);
+                             bool isTarget = false,
+                             fir::CUDAAttributeAttr cudaAttr = {});
 
   fir::GlobalOp createGlobal(mlir::Location loc, mlir::Type type,
                              llvm::StringRef name, bool isConst, bool isTarget,
                              std::function<void(FirOpBuilder &)> bodyBuilder,
-                             mlir::StringAttr linkage = {});
+                             mlir::StringAttr linkage = {},
+                             fir::CUDAAttributeAttr cudaAttr = {});
 
   /// Create a global constant (read-only) value.
   fir::GlobalOp createGlobalConstant(mlir::Location loc, mlir::Type type,
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index b954a0cc74d0e1..d505fedd6e6415 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2737,7 +2737,8 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
     OptionalAttr<AnyAttr>:$initVal,
     OptionalAttr<UnitAttr>:$constant,
     OptionalAttr<UnitAttr>:$target,
-    OptionalAttr<StrAttr>:$linkName
+    OptionalAttr<StrAttr>:$linkName,
+    OptionalAttr<fir_CUDAAttributeAttr>:$cuda_attr
   );
 
   let regions = (region AtMostRegion<1>:$region);
diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index f14267f1234217..2f23757f497ea5 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -138,7 +138,8 @@ static bool isConstant(const Fortran::semantics::Symbol &sym) {
 static fir::GlobalOp defineGlobal(Fortran::lower::AbstractConverter &converter,
                                   const Fortran::lower::pft::Variable &var,
                                   llvm::StringRef globalName,
-                                  mlir::StringAttr linkage);
+                                  mlir::StringAttr linkage,
+                                  fir::CUDAAttributeAttr cudaAttr = {});
 
 static mlir::Location genLocation(Fortran::lower::AbstractConverter &converter,
                                   const Fortran::semantics::Symbol &sym) {
@@ -462,7 +463,8 @@ void Fortran::lower::createGlobalInitialization(
 static fir::GlobalOp defineGlobal(Fortran::lower::AbstractConverter &converter,
                                   const Fortran::lower::pft::Variable &var,
                                   llvm::StringRef globalName,
-                                  mlir::StringAttr linkage) {
+                                  mlir::StringAttr linkage,
+                                  fir::CUDAAttributeAttr cudaAttr) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   const Fortran::semantics::Symbol &sym = var.getSymbol();
   mlir::Location loc = genLocation(converter, sym);
@@ -500,8 +502,9 @@ static fir::GlobalOp defineGlobal(Fortran::lower::AbstractConverter &converter,
     }
   }
   if (!global)
-    global = builder.createGlobal(loc, symTy, globalName, linkage,
-                                  mlir::Attribute{}, isConst, var.isTarget());
+    global =
+        builder.createGlobal(loc, symTy, globalName, linkage, mlir::Attribute{},
+                             isConst, var.isTarget(), cudaAttr);
   if (Fortran::semantics::IsAllocatableOrPointer(sym) &&
       !Fortran::semantics::IsProcedure(sym)) {
     const auto *details =
@@ -2219,7 +2222,10 @@ void Fortran::lower::defineModuleVariable(
     // Do nothing. Mapping will be done on user side.
   } else {
     std::string globalName = converter.mangleName(sym);
-    defineGlobal(converter, var, globalName, linkage);
+    fir::CUDAAttributeAttr cudaAttr =
+        Fortran::lower::translateSymbolCUDAAttribute(
+            converter.getFirOpBuilder().getContext(), sym);
+    defineGlobal(converter, var, globalName, linkage, cudaAttr);
   }
 }
 
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 141f8fcd3ab5fc..cce120c1f872c4 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -271,19 +271,21 @@ mlir::Value fir::FirOpBuilder::createHeapTemporary(
 
 /// Create a global variable in the (read-only) data section. A global variable
 /// must have a unique name to identify and reference it.
-fir::GlobalOp fir::FirOpBuilder::createGlobal(mlir::Location loc,
-                                              mlir::Type type,
-                                              llvm::StringRef name,
-                                              mlir::StringAttr linkage,
-                                              mlir::Attribute value,
-                                              bool isConst, bool isTarget) {
+fir::GlobalOp fir::FirOpBuilder::createGlobal(
+    mlir::Location loc, mlir::Type type, llvm::StringRef name,
+    mlir::StringAttr linkage, mlir::Attribute value, bool isConst,
+    bool isTarget, fir::CUDAAttributeAttr cudaAttr) {
   auto module = getModule();
   auto insertPt = saveInsertionPoint();
   if (auto glob = module.lookupSymbol<fir::GlobalOp>(name))
     return glob;
   setInsertionPoint(module.getBody(), module.getBody()->end());
-  auto glob =
-      create<fir::GlobalOp>(loc, name, isConst, isTarget, type, value, linkage);
+  llvm::SmallVector<mlir::NamedAttribute> attrs;
+  if (cudaAttr)
+    attrs.push_back(mlir::NamedAttribute(
+        mlir::StringAttr::get(module.getContext(), "cuda_attr"), cudaAttr));
+  auto glob = create<fir::GlobalOp>(loc, name, isConst, isTarget, type, value,
+                                    linkage, attrs);
   restoreInsertionPoint(insertPt);
   return glob;
 }
@@ -291,7 +293,7 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(mlir::Location loc,
 fir::GlobalOp fir::FirOpBuilder::createGlobal(
     mlir::Location loc, mlir::Type type, llvm::StringRef name, bool isConst,
     bool isTarget, std::function<void(FirOpBuilder &)> bodyBuilder,
-    mlir::StringAttr linkage) {
+    mlir::StringAttr linkage, fir::CUDAAttributeAttr cudaAttr) {
   auto module = getModule();
   auto insertPt = saveInsertionPoint();
   if (auto glob = module.lookupSymbol<fir::GlobalOp>(name))
diff --git a/flang/test/Lower/CUDA/cuda-data-attribute.cuf b/flang/test/Lower/CUDA/cuda-data-attribute.cuf
index b02701bf3aea5a..7596c6b21efb0d 100644
--- a/flang/test/Lower/CUDA/cuda-data-attribute.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-attribute.cuf
@@ -3,6 +3,18 @@
 
 ! Test lowering of CUDA attribute on variables.
 
+module cuda_var
+  real, constant :: mod_a_rc
+! CHECK: fir.global @_QMcuda_varEmod_a_rc {cuda_attr = #fir.cuda<constant>} : f32 
+  real, device :: mod_b_ra
+! CHECK: fir.global @_QMcuda_varEmod_b_ra {cuda_attr = #fir.cuda<device>} : f32
+  real, allocatable, managed :: mod_c_rm
+! CHECK: fir.global @_QMcuda_varEmod_c_rm {cuda_attr = #fir.cuda<managed>} : !fir.box<!fir.heap<f32>>
+  real, allocatable, pinned :: mod_d_rp
+! CHECK: fir.global @_QMcuda_varEmod_d_rp {cuda_attr = #fir.cuda<pinned>} : !fir.box<!fir.heap<f32>>
+
+contains
+
 subroutine local_var_attrs
   real, constant :: rc
   real, device :: rd
@@ -10,46 +22,43 @@ subroutine local_var_attrs
   real, allocatable, pinned :: rp
 end subroutine
 
-! CHECK-LABEL: func.func @_QPlocal_var_attrs()
-! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFlocal_var_attrsErc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
-! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFlocal_var_attrsErd"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
-! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
-! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
+! CHECK-LABEL: func.func @_QMcuda_varPlocal_var_attrs()
+! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QMcuda_varFlocal_var_attrsErc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QMcuda_varFlocal_var_attrsErd"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMcuda_varFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
+! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMcuda_varFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
 
-! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFlocal_var_attrsErc"} : (!fir.ref<f32>) -> !fir.ref<f32>
-! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFlocal_var_attrsErd"} : (!fir.ref<f32>) -> !fir.ref<f32>
-! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
-! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
+! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QMcuda_varFlocal_var_attrsErc"} : (!fir.ref<f32>) -> !fir.ref<f32>
+! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QMcuda_varFlocal_var_attrsErd"} : (!fir.ref<f32>) -> !fir.ref<f32>
+! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMcuda_varFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
+! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMcuda_varFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
 
 subroutine dummy_arg_constant(dc)
   real, constant :: dc
 end subroutine
-! CHECK-LABEL: func.func @_QPdummy_arg_constant(
+! CHECK-LABEL: func.func @_QMcuda_varPdummy_arg_constant(
 ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<f32> {fir.bindc_name = "dc", fir.cuda_attr = #fir.cuda<constant>}
-! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFdummy_arg_constantEdc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<constant>, uniq_name = "_QMcuda_varFdummy_arg_constantEdc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 
 subroutine dummy_arg_device(dd)
   real, device :: dd
 end subroutine
-! CHECK-LABEL: func.func @_QPdummy_arg_device(
+! CHECK-LABEL: func.func @_QMcuda_varPdummy_arg_device(
 ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<f32> {fir.bindc_name = "dd", fir.cuda_attr = #fir.cuda<device>}) {
-! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<device>, uniq_name = "_QFdummy_arg_deviceEdd"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<device>, uniq_name = "_QMcuda_varFdummy_arg_deviceEdd"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 
 subroutine dummy_arg_managed(dm)
   real, allocatable, managed :: dm
 end subroutine
-! CHECK-LABEL: func.func @_QPdummy_arg_managed(
+! CHECK-LABEL: func.func @_QMcuda_varPdummy_arg_managed(
 ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<f32>>> {fir.bindc_name = "dm", fir.cuda_attr = #fir.cuda<managed>}) {
-! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFdummy_arg_managedEdm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
+! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMcuda_varFdummy_arg_managedEdm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
 
 subroutine dummy_arg_pinned(dp)
   real, allocatable, pinned :: dp
 end subroutine
-! CHECK-LABEL: func.func @_QPdummy_arg_pinned(
+! CHECK-LABEL: func.func @_QMcuda_varPdummy_arg_pinned(
 ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<f32>>> {fir.bindc_name = "dp", fir.cuda_attr = #fir.cuda<pinned>}) {
-! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFdummy_arg_pinnedEdp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
-
-
-
-
+! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMcuda_varFdummy_arg_pinnedEdp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
 
+end module

Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

flang/lib/Optimizer/Builder/FIRBuilder.cpp Outdated Show resolved Hide resolved
Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, Valentin!

@clementval
Copy link
Contributor Author

Thank you, Valentin!

Thanks for the pointer to the other PR. I was looking to avoid this constexpr in the extra declaration for a long time and it was indeed pretty easy :-)

@clementval clementval merged commit 314ef96 into llvm:main Feb 9, 2024
4 of 5 checks passed
@clementval clementval deleted the cuda_global_attr branch February 9, 2024 18:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants