-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][cuda] Lower attribute for module variables #81226
Conversation
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: Propagate the CUDA attribute to the fir.global operation for simple module variables. Full diff: https://github.com/llvm/llvm-project/pull/81226.diff 5 Files Affected:
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index 5384f6e8121ec6..f50dacd327a7c0 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -230,12 +230,14 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
llvm::StringRef name,
mlir::StringAttr linkage = {},
mlir::Attribute value = {}, bool isConst = false,
- bool isTarget = false);
+ bool isTarget = false,
+ fir::CUDAAttributeAttr cudaAttr = {});
fir::GlobalOp createGlobal(mlir::Location loc, mlir::Type type,
llvm::StringRef name, bool isConst, bool isTarget,
std::function<void(FirOpBuilder &)> bodyBuilder,
- mlir::StringAttr linkage = {});
+ mlir::StringAttr linkage = {},
+ fir::CUDAAttributeAttr cudaAttr = {});
/// Create a global constant (read-only) value.
fir::GlobalOp createGlobalConstant(mlir::Location loc, mlir::Type type,
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index b954a0cc74d0e1..d505fedd6e6415 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2737,7 +2737,8 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
OptionalAttr<AnyAttr>:$initVal,
OptionalAttr<UnitAttr>:$constant,
OptionalAttr<UnitAttr>:$target,
- OptionalAttr<StrAttr>:$linkName
+ OptionalAttr<StrAttr>:$linkName,
+ OptionalAttr<fir_CUDAAttributeAttr>:$cuda_attr
);
let regions = (region AtMostRegion<1>:$region);
diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index f14267f1234217..2f23757f497ea5 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -138,7 +138,8 @@ static bool isConstant(const Fortran::semantics::Symbol &sym) {
static fir::GlobalOp defineGlobal(Fortran::lower::AbstractConverter &converter,
const Fortran::lower::pft::Variable &var,
llvm::StringRef globalName,
- mlir::StringAttr linkage);
+ mlir::StringAttr linkage,
+ fir::CUDAAttributeAttr cudaAttr = {});
static mlir::Location genLocation(Fortran::lower::AbstractConverter &converter,
const Fortran::semantics::Symbol &sym) {
@@ -462,7 +463,8 @@ void Fortran::lower::createGlobalInitialization(
static fir::GlobalOp defineGlobal(Fortran::lower::AbstractConverter &converter,
const Fortran::lower::pft::Variable &var,
llvm::StringRef globalName,
- mlir::StringAttr linkage) {
+ mlir::StringAttr linkage,
+ fir::CUDAAttributeAttr cudaAttr) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
const Fortran::semantics::Symbol &sym = var.getSymbol();
mlir::Location loc = genLocation(converter, sym);
@@ -500,8 +502,9 @@ static fir::GlobalOp defineGlobal(Fortran::lower::AbstractConverter &converter,
}
}
if (!global)
- global = builder.createGlobal(loc, symTy, globalName, linkage,
- mlir::Attribute{}, isConst, var.isTarget());
+ global =
+ builder.createGlobal(loc, symTy, globalName, linkage, mlir::Attribute{},
+ isConst, var.isTarget(), cudaAttr);
if (Fortran::semantics::IsAllocatableOrPointer(sym) &&
!Fortran::semantics::IsProcedure(sym)) {
const auto *details =
@@ -2219,7 +2222,10 @@ void Fortran::lower::defineModuleVariable(
// Do nothing. Mapping will be done on user side.
} else {
std::string globalName = converter.mangleName(sym);
- defineGlobal(converter, var, globalName, linkage);
+ fir::CUDAAttributeAttr cudaAttr =
+ Fortran::lower::translateSymbolCUDAAttribute(
+ converter.getFirOpBuilder().getContext(), sym);
+ defineGlobal(converter, var, globalName, linkage, cudaAttr);
}
}
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 141f8fcd3ab5fc..cce120c1f872c4 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -271,19 +271,21 @@ mlir::Value fir::FirOpBuilder::createHeapTemporary(
/// Create a global variable in the (read-only) data section. A global variable
/// must have a unique name to identify and reference it.
-fir::GlobalOp fir::FirOpBuilder::createGlobal(mlir::Location loc,
- mlir::Type type,
- llvm::StringRef name,
- mlir::StringAttr linkage,
- mlir::Attribute value,
- bool isConst, bool isTarget) {
+fir::GlobalOp fir::FirOpBuilder::createGlobal(
+ mlir::Location loc, mlir::Type type, llvm::StringRef name,
+ mlir::StringAttr linkage, mlir::Attribute value, bool isConst,
+ bool isTarget, fir::CUDAAttributeAttr cudaAttr) {
auto module = getModule();
auto insertPt = saveInsertionPoint();
if (auto glob = module.lookupSymbol<fir::GlobalOp>(name))
return glob;
setInsertionPoint(module.getBody(), module.getBody()->end());
- auto glob =
- create<fir::GlobalOp>(loc, name, isConst, isTarget, type, value, linkage);
+ llvm::SmallVector<mlir::NamedAttribute> attrs;
+ if (cudaAttr)
+ attrs.push_back(mlir::NamedAttribute(
+ mlir::StringAttr::get(module.getContext(), "cuda_attr"), cudaAttr));
+ auto glob = create<fir::GlobalOp>(loc, name, isConst, isTarget, type, value,
+ linkage, attrs);
restoreInsertionPoint(insertPt);
return glob;
}
@@ -291,7 +293,7 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(mlir::Location loc,
fir::GlobalOp fir::FirOpBuilder::createGlobal(
mlir::Location loc, mlir::Type type, llvm::StringRef name, bool isConst,
bool isTarget, std::function<void(FirOpBuilder &)> bodyBuilder,
- mlir::StringAttr linkage) {
+ mlir::StringAttr linkage, fir::CUDAAttributeAttr cudaAttr) {
auto module = getModule();
auto insertPt = saveInsertionPoint();
if (auto glob = module.lookupSymbol<fir::GlobalOp>(name))
diff --git a/flang/test/Lower/CUDA/cuda-data-attribute.cuf b/flang/test/Lower/CUDA/cuda-data-attribute.cuf
index b02701bf3aea5a..7596c6b21efb0d 100644
--- a/flang/test/Lower/CUDA/cuda-data-attribute.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-attribute.cuf
@@ -3,6 +3,18 @@
! Test lowering of CUDA attribute on variables.
+module cuda_var
+ real, constant :: mod_a_rc
+! CHECK: fir.global @_QMcuda_varEmod_a_rc {cuda_attr = #fir.cuda<constant>} : f32
+ real, device :: mod_b_ra
+! CHECK: fir.global @_QMcuda_varEmod_b_ra {cuda_attr = #fir.cuda<device>} : f32
+ real, allocatable, managed :: mod_c_rm
+! CHECK: fir.global @_QMcuda_varEmod_c_rm {cuda_attr = #fir.cuda<managed>} : !fir.box<!fir.heap<f32>>
+ real, allocatable, pinned :: mod_d_rp
+! CHECK: fir.global @_QMcuda_varEmod_d_rp {cuda_attr = #fir.cuda<pinned>} : !fir.box<!fir.heap<f32>>
+
+contains
+
subroutine local_var_attrs
real, constant :: rc
real, device :: rd
@@ -10,46 +22,43 @@ subroutine local_var_attrs
real, allocatable, pinned :: rp
end subroutine
-! CHECK-LABEL: func.func @_QPlocal_var_attrs()
-! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFlocal_var_attrsErc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
-! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFlocal_var_attrsErd"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
-! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
-! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
+! CHECK-LABEL: func.func @_QMcuda_varPlocal_var_attrs()
+! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QMcuda_varFlocal_var_attrsErc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QMcuda_varFlocal_var_attrsErd"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMcuda_varFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
+! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMcuda_varFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
-! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFlocal_var_attrsErc"} : (!fir.ref<f32>) -> !fir.ref<f32>
-! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFlocal_var_attrsErd"} : (!fir.ref<f32>) -> !fir.ref<f32>
-! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
-! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
+! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QMcuda_varFlocal_var_attrsErc"} : (!fir.ref<f32>) -> !fir.ref<f32>
+! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QMcuda_varFlocal_var_attrsErd"} : (!fir.ref<f32>) -> !fir.ref<f32>
+! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMcuda_varFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
+! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMcuda_varFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
subroutine dummy_arg_constant(dc)
real, constant :: dc
end subroutine
-! CHECK-LABEL: func.func @_QPdummy_arg_constant(
+! CHECK-LABEL: func.func @_QMcuda_varPdummy_arg_constant(
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<f32> {fir.bindc_name = "dc", fir.cuda_attr = #fir.cuda<constant>}
-! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFdummy_arg_constantEdc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<constant>, uniq_name = "_QMcuda_varFdummy_arg_constantEdc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
subroutine dummy_arg_device(dd)
real, device :: dd
end subroutine
-! CHECK-LABEL: func.func @_QPdummy_arg_device(
+! CHECK-LABEL: func.func @_QMcuda_varPdummy_arg_device(
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<f32> {fir.bindc_name = "dd", fir.cuda_attr = #fir.cuda<device>}) {
-! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<device>, uniq_name = "_QFdummy_arg_deviceEdd"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<device>, uniq_name = "_QMcuda_varFdummy_arg_deviceEdd"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
subroutine dummy_arg_managed(dm)
real, allocatable, managed :: dm
end subroutine
-! CHECK-LABEL: func.func @_QPdummy_arg_managed(
+! CHECK-LABEL: func.func @_QMcuda_varPdummy_arg_managed(
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<f32>>> {fir.bindc_name = "dm", fir.cuda_attr = #fir.cuda<managed>}) {
-! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFdummy_arg_managedEdm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
+! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMcuda_varFdummy_arg_managedEdm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
subroutine dummy_arg_pinned(dp)
real, allocatable, pinned :: dp
end subroutine
-! CHECK-LABEL: func.func @_QPdummy_arg_pinned(
+! CHECK-LABEL: func.func @_QMcuda_varPdummy_arg_pinned(
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<f32>>> {fir.bindc_name = "dp", fir.cuda_attr = #fir.cuda<pinned>}) {
-! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFdummy_arg_pinnedEdp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
-
-
-
-
+! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMcuda_varFdummy_arg_pinnedEdp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
+end module
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, Valentin!
Thanks for the pointer to the other PR. I had been looking to avoid this constexpr in the extra declaration for a long time, and it was indeed pretty easy :-)
Propagate the CUDA attribute to the fir.global operation for simple module variables.