diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp index 7ab3c43016bd9..44f534e7d569a 100644 --- a/flang/lib/Lower/ConvertVariable.cpp +++ b/flang/lib/Lower/ConvertVariable.cpp @@ -702,6 +702,29 @@ static void instantiateGlobal(Fortran::lower::AbstractConverter &converter, mapSymbolAttributes(converter, var, symMap, stmtCtx, cast); } +bool needCUDAAlloc(const Fortran::semantics::Symbol &sym) { + if (Fortran::semantics::IsDummy(sym)) + return false; + if (const auto *details{ + sym.GetUltimate() + .detailsIf()}) { + if (details->cudaDataAttr() && + (*details->cudaDataAttr() == Fortran::common::CUDADataAttr::Device || + *details->cudaDataAttr() == Fortran::common::CUDADataAttr::Managed || + *details->cudaDataAttr() == Fortran::common::CUDADataAttr::Unified || + *details->cudaDataAttr() == Fortran::common::CUDADataAttr::Shared || + *details->cudaDataAttr() == Fortran::common::CUDADataAttr::Pinned)) + return true; + const Fortran::semantics::DeclTypeSpec *type{details->type()}; + const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived() + : nullptr}; + if (derived) + if (FindCUDADeviceAllocatableUltimateComponent(*derived)) + return true; + } + return false; +} + //===----------------------------------------------------------------===// // Local variables instantiation (not for alias) //===----------------------------------------------------------------===// @@ -732,7 +755,7 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter, if (ultimateSymbol.test(Fortran::semantics::Symbol::Flag::CrayPointee)) return builder.create(loc, fir::ReferenceType::get(ty)); - if (Fortran::semantics::NeedCUDAAlloc(ultimateSymbol)) { + if (needCUDAAlloc(ultimateSymbol)) { cuf::DataAttributeAttr dataAttr = Fortran::lower::translateSymbolCUFDataAttribute(builder.getContext(), ultimateSymbol); @@ -1087,7 +1110,7 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter, Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(), symMap); auto *builder = &converter.getFirOpBuilder(); - if (Fortran::semantics::NeedCUDAAlloc(var.getSymbol()) && + if (needCUDAAlloc(var.getSymbol()) && !cuf::isCUDADeviceContext(builder->getRegion())) { cuf::DataAttributeAttr dataAttr = Fortran::lower::translateSymbolCUFDataAttribute(builder->getContext(), diff --git a/flang/lib/Semantics/tools.cpp b/flang/lib/Semantics/tools.cpp index 498bbc18709ab..ea37316e78273 100644 --- a/flang/lib/Semantics/tools.cpp +++ b/flang/lib/Semantics/tools.cpp @@ -1095,9 +1095,20 @@ bool IsDeviceAllocatable(const Symbol &symbol) { } std::optional GetCUDADataAttr(const Symbol *symbol) { - const auto *object{ + const auto *details{ symbol ? symbol->detailsIf() : nullptr}; - return object ? object->cudaDataAttr() : std::nullopt; + if (details) { + const Fortran::semantics::DeclTypeSpec *type{details->type()}; + const Fortran::semantics::DerivedTypeSpec *derived{ + type ? type->AsDerived() : nullptr}; + if (derived) { + if (FindCUDADeviceAllocatableUltimateComponent(*derived)) { + return common::CUDADataAttr::Managed; + } + } + return details->cudaDataAttr(); + } + return std::nullopt; } bool IsAccessible(const Symbol &original, const Scope &scope) { diff --git a/flang/test/Lower/CUDA/cuda-derived.cuf b/flang/test/Lower/CUDA/cuda-derived.cuf index d280ac722d08f..96250d88d81c4 100644 --- a/flang/test/Lower/CUDA/cuda-derived.cuf +++ b/flang/test/Lower/CUDA/cuda-derived.cuf @@ -7,6 +7,16 @@ module m1 type t1; real, device, allocatable :: a(:); end type type t2; type(t1) :: b; end type +contains + subroutine sub1() + type(ty_device) :: a + end subroutine + +! CHECK-LABEL: func.func @_QMm1Psub1() +! CHECK: %[[ALLOC:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box>>}> {bindc_name = "a", data_attr = #cuf.cuda, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref>>}>> +! CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[ALLOC]] {data_attr = #cuf.cuda, uniq_name = "_QMm1Fsub1Ea"} : (!fir.ref>>}>>) -> (!fir.ref>>}>>, !fir.ref>>}>>) +! CHECK: cuf.free %[[DECL]]#0 : !fir.ref>>}>> {data_attr = #cuf.cuda} + end module program main @@ -16,5 +26,5 @@ program main end ! CHECK-LABEL: func.func @_QQmain() attributes {fir.bindc_name = "main"} -! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tty_device{x:!fir.box>>}> {bindc_name = "a", uniq_name = "_QFEa"} -! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tt2{b:!fir.type<_QMm1Tt1{a:!fir.box>>}>}> {bindc_name = "b", uniq_name = "_QFEb"} +! CHECK: %{{.*}} = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box>>}> {bindc_name = "a", data_attr = #cuf.cuda, uniq_name = "_QFEa"} +! CHECK: %{{.*}} = cuf.alloc !fir.type<_QMm1Tt2{b:!fir.type<_QMm1Tt1{a:!fir.box>>}>}> {bindc_name = "b", data_attr = #cuf.cuda, uniq_name = "_QFEb"}