Skip to content

Conversation

clementval
Copy link
Contributor

When there is a host_data region with use_device clause, the symbol inside the region should be treated like if they have the device attribute when CUDA Fortran is enable. This allow generic resolution to pick up dedicated functions and subroutine written for device data.
This patch update OpenACC name resolution to treat use_device clause before generic resolution is done. It creates new symbol in the host_data scope with the device attribute set.
In lowering, the newly create symbols are binded to the outer scope symbol since the lowering is expected to be the same.

@clementval clementval requested a review from wangzpgi October 2, 2025 00:49
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp openacc flang:semantics labels Oct 2, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 2, 2025

@llvm/pr-subscribers-openacc
@llvm/pr-subscribers-flang-semantics

@llvm/pr-subscribers-flang-openmp

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

When there is a host_data region with use_device clause, the symbol inside the region should be treated like if they have the device attribute when CUDA Fortran is enable. This allow generic resolution to pick up dedicated functions and subroutine written for device data.
This patch update OpenACC name resolution to treat use_device clause before generic resolution is done. It creates new symbol in the host_data scope with the device attribute set.
In lowering, the newly create symbols are binded to the outer scope symbol since the lowering is expected to be the same.


Full diff: https://github.com/llvm/llvm-project/pull/161613.diff

10 Files Affected:

  • (modified) flang/include/flang/Lower/OpenACC.h (+2-1)
  • (modified) flang/include/flang/Lower/SymbolMap.h (+4)
  • (modified) flang/include/flang/Semantics/symbol.h (+1-1)
  • (modified) flang/lib/Lower/Bridge.cpp (+1-1)
  • (modified) flang/lib/Lower/OpenACC.cpp (+24-9)
  • (modified) flang/lib/Lower/SymbolMap.cpp (+10)
  • (modified) flang/lib/Semantics/check-declarations.cpp (+2-1)
  • (modified) flang/lib/Semantics/resolve-directives.cpp (+5)
  • (modified) flang/lib/Semantics/resolve-names.cpp (+64-1)
  • (added) flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 (+43)
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 19d759479abaf..4622dbc8ccf64 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -77,7 +77,8 @@ static constexpr llvm::StringRef privatizationRecipePrefix = "privatization";
 mlir::Value genOpenACCConstruct(AbstractConverter &,
                                 Fortran::semantics::SemanticsContext &,
                                 pft::Evaluation &,
-                                const parser::OpenACCConstruct &);
+                                const parser::OpenACCConstruct &,
+                                Fortran::lower::SymMap &localSymbols);
 void genOpenACCDeclarativeConstruct(
     AbstractConverter &, Fortran::semantics::SemanticsContext &,
     StatementContext &, const parser::OpenACCDeclarativeConstruct &);
diff --git a/flang/include/flang/Lower/SymbolMap.h b/flang/include/flang/Lower/SymbolMap.h
index 813df777d7a39..e57b6a42d3cc1 100644
--- a/flang/include/flang/Lower/SymbolMap.h
+++ b/flang/include/flang/Lower/SymbolMap.h
@@ -260,6 +260,10 @@ class SymMap {
     return lookupSymbol(*sym);
   }
 
+  /// Find a symbol by name and return its value if it appears in the current
+  /// mappings. This lookup is more expensive as it iterates over the map.
+  const semantics::Symbol *lookupSymbolByName(llvm::StringRef symName);
+
   /// Find `symbol` and return its value if it appears in the inner-most level
   /// map.
   SymbolBox shallowLookupSymbol(semantics::SymbolRef sym);
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index e90e9c617805d..a0d5ae7176141 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -801,7 +801,7 @@ class Symbol {
       AccPrivate, AccFirstPrivate, AccShared,
       // OpenACC data-mapping attribute
       AccCopy, AccCopyIn, AccCopyInReadOnly, AccCopyOut, AccCreate, AccDelete,
-      AccPresent, AccLink, AccDeviceResident, AccDevicePtr,
+      AccPresent, AccLink, AccDeviceResident, AccDevicePtr, AccUseDevice,
       // OpenACC declare
       AccDeclare,
       // OpenACC data-movement attribute
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 149e51b501a82..780d56f085f69 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3182,7 +3182,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
     localSymbols.pushScope();
     mlir::Value exitCond = genOpenACCConstruct(
-        *this, bridge.getSemanticsContext(), getEval(), acc);
+        *this, bridge.getSemanticsContext(), getEval(), acc, localSymbols);
 
     const Fortran::parser::OpenACCLoopConstruct *accLoop =
         std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 95d0adae02670..66f45451b3eeb 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
                  Fortran::lower::pft::Evaluation &eval,
                  Fortran::semantics::SemanticsContext &semanticsContext,
                  Fortran::lower::StatementContext &stmtCtx,
-                 const Fortran::parser::AccClauseList &accClauseList) {
+                 const Fortran::parser::AccClauseList &accClauseList,
+                 Fortran::lower::SymMap &localSymbols) {
   mlir::Value ifCond;
   llvm::SmallVector<mlir::Value> dataOperands;
   bool addIfPresentAttr = false;
@@ -3199,6 +3200,18 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *useDevice =
                    std::get_if<Fortran::parser::AccClause::UseDevice>(
                        &clause.u)) {
+      // When CUDA Fotran is en
+      if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
+
+        const Fortran::parser::AccObjectList &objectList{useDevice->v};
+        for (const auto &accObject : objectList.v) {
+          Fortran::semantics::Symbol &symbol =
+              getSymbolFromAccObject(accObject);
+          const Fortran::semantics::Symbol *baseSym =
+              localSymbols.lookupSymbolByName(symbol.name().ToString());
+          localSymbols.copySymbolBinding(*baseSym, symbol);
+        }
+      }
       genDataOperandOperations<mlir::acc::UseDeviceOp>(
           useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
           mlir::acc::DataClause::acc_use_device,
@@ -3239,11 +3252,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
     hostDataOp.setIfPresentAttr(builder.getUnitAttr());
 }
 
-static void
-genACC(Fortran::lower::AbstractConverter &converter,
-       Fortran::semantics::SemanticsContext &semanticsContext,
-       Fortran::lower::pft::Evaluation &eval,
-       const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
+static void genACC(Fortran::lower::AbstractConverter &converter,
+                   Fortran::semantics::SemanticsContext &semanticsContext,
+                   Fortran::lower::pft::Evaluation &eval,
+                   const Fortran::parser::OpenACCBlockConstruct &blockConstruct,
+                   Fortran::lower::SymMap &localSymbols) {
   const auto &beginBlockDirective =
       std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
   const auto &blockDirective =
@@ -3273,7 +3286,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
                                           accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_host_data) {
     genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
-                     stmtCtx, accClauseList);
+                     stmtCtx, accClauseList, localSymbols);
   }
 }
 
@@ -4647,13 +4660,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct(
     Fortran::lower::AbstractConverter &converter,
     Fortran::semantics::SemanticsContext &semanticsContext,
     Fortran::lower::pft::Evaluation &eval,
-    const Fortran::parser::OpenACCConstruct &accConstruct) {
+    const Fortran::parser::OpenACCConstruct &accConstruct,
+    Fortran::lower::SymMap &localSymbols) {
 
   mlir::Value exitCond;
   Fortran::common::visit(
       common::visitors{
           [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
-            genACC(converter, semanticsContext, eval, blockConstruct);
+            genACC(converter, semanticsContext, eval, blockConstruct,
+                   localSymbols);
           },
           [&](const Fortran::parser::OpenACCCombinedConstruct
                   &combinedConstruct) {
diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp
index 080f21ec67400..78529e0d539fb 100644
--- a/flang/lib/Lower/SymbolMap.cpp
+++ b/flang/lib/Lower/SymbolMap.cpp
@@ -45,6 +45,16 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
   return SymbolBox::None{};
 }
 
+const Fortran::semantics::Symbol *
+Fortran::lower::SymMap::lookupSymbolByName(llvm::StringRef symName) {
+  for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend();
+       jmap != jend; ++jmap)
+    for (auto const &[sym, symBox] : *jmap)
+      if (sym->name().ToString() == symName)
+        return sym;
+  return nullptr;
+}
+
 Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
     Fortran::semantics::SymbolRef symRef) {
   auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index 1049a6d2c1b2e..7b881008219df 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -1189,7 +1189,8 @@ void CheckHelper::CheckObjectEntity(
       }
     } else if (!subpDetails && symbol.owner().kind() != Scope::Kind::Module &&
         symbol.owner().kind() != Scope::Kind::MainProgram &&
-        symbol.owner().kind() != Scope::Kind::BlockConstruct) {
+        symbol.owner().kind() != Scope::Kind::BlockConstruct &&
+        symbol.owner().kind() != Scope::Kind::OpenACCConstruct) {
       messages_.Say(
           "ATTRIBUTES(%s) may apply only to module, host subprogram, block, or device subprogram data"_err_en_US,
           parser::ToUpperCaseLetters(common::EnumToString(attr)));
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index bd7b8ac552fab..366ebba1b7263 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -328,6 +328,11 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
     return false;
   }
 
+  bool Pre(const parser::AccClause::UseDevice &x) {
+    ResolveAccObjectList(x.v, Symbol::Flag::AccUseDevice);
+    return false;
+  }
+
   void Post(const parser::Name &);
 
 private:
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index d1150a9eb67f4..7e55eb005dbb2 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1387,6 +1387,8 @@ class ConstructVisitor : public virtual DeclarationVisitor {
 // Create scopes for OpenACC constructs
 class AccVisitor : public virtual DeclarationVisitor {
 public:
+  explicit AccVisitor(SemanticsContext &context) : context_{context} {}
+
   void AddAccSourceRange(const parser::CharBlock &);
 
   static bool NeedsScope(const parser::OpenACCBlockConstruct &);
@@ -1395,6 +1397,7 @@ class AccVisitor : public virtual DeclarationVisitor {
   void Post(const parser::OpenACCBlockConstruct &);
   bool Pre(const parser::OpenACCCombinedConstruct &);
   void Post(const parser::OpenACCCombinedConstruct &);
+  bool Pre(const parser::AccClause::UseDevice &x);
   bool Pre(const parser::AccBeginBlockDirective &x) {
     AddAccSourceRange(x.source);
     return true;
@@ -1430,6 +1433,11 @@ class AccVisitor : public virtual DeclarationVisitor {
   void Post(const parser::AccBeginLoopDirective &x) {
     messageHandler().set_currStmtSource(std::nullopt);
   }
+
+  void CopySymbolWithDevice(const parser::Name *name);
+
+private:
+  SemanticsContext &context_;
 };
 
 bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) {
@@ -1459,6 +1467,60 @@ bool AccVisitor::Pre(const parser::OpenACCBlockConstruct &x) {
   return true;
 }
 
+void AccVisitor::CopySymbolWithDevice(const parser::Name *name) {
+  // When CUDA Fortran is enabled together with OpenACC, new
+  // symbols are created for the one appearing in the use_device
+  // clause. These new symbols have the CUDA Fortran device
+  // attribute.
+  if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) {
+    name->symbol = currScope().CopySymbol(*name->symbol);
+    if (auto *object{name->symbol->detailsIf<ObjectEntityDetails>()}) {
+      object->set_cudaDataAttr(common::CUDADataAttr::Device);
+    }
+  }
+}
+
+bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) {
+  for (const auto &accObject : x.v.v) {
+    common::visit(
+        common::visitors{
+            [&](const parser::Designator &designator) {
+              if (const auto *name{
+                      semantics::getDesignatorNameIfDataRef(designator)}) {
+                Symbol *prev{currScope().FindSymbol(name->source)};
+                if (prev != name->symbol) {
+                  name->symbol = prev;
+                }
+                CopySymbolWithDevice(name);
+              } else {
+                if (const auto *dataRef{
+                        std::get_if<parser::DataRef>(&designator.u)}) {
+                  using ElementIndirection =
+                      common::Indirection<parser::ArrayElement>;
+                  if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
+                    const parser::ArrayElement &arrayElement{ind->value()};
+                    Walk(arrayElement.subscripts);
+                    const parser::DataRef &base{arrayElement.base};
+                    if (auto *name{std::get_if<parser::Name>(&base.u)}) {
+                      Symbol *prev{currScope().FindSymbol(name->source)};
+                      if (prev != name->symbol) {
+                        name->symbol = prev;
+                      }
+                      CopySymbolWithDevice(name);
+                    }
+                  }
+                }
+              }
+            },
+            [&](const parser::Name &name) {
+              // TODO: common block in use_device?
+            },
+        },
+        accObject.u);
+  }
+  return false;
+}
+
 void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) {
   if (NeedsScope(x)) {
     PopScope();
@@ -2038,7 +2100,8 @@ class ResolveNamesVisitor : public virtual ScopeHandler,
 
   ResolveNamesVisitor(
       SemanticsContext &context, ImplicitRulesMap &rules, Scope &top)
-      : BaseVisitor{context, *this, rules}, topScope_{top} {
+      : BaseVisitor{context, *this, rules}, topScope_{top},
+        AccVisitor(context) {
     PushScope(top);
   }
 
diff --git a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
new file mode 100644
index 0000000000000..da034adec39c4
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
@@ -0,0 +1,43 @@
+
+! RUN: bbc -fopenacc -fcuda -emit-hlfir %s -o - | FileCheck %s
+
+module m
+
+interface doit
+subroutine __device_sub(a)
+    real(4), device, intent(in) :: a(:,:,:)
+    !dir$ ignore_tkr(c) a
+end
+subroutine __host_sub(a)
+    real(4), intent(in) :: a(:,:,:)
+    !dir$ ignore_tkr(c) a
+end
+end interface
+end module
+
+program testex1
+integer, parameter :: ntimes = 10
+integer, parameter :: ni=128
+integer, parameter :: nj=256
+integer, parameter :: nk=64
+real(4), dimension(ni,nj,nk) :: a
+
+!$acc enter data copyin(a)
+
+block; use m
+!$acc host_data use_device(a)
+do nt = 1, ntimes
+  call doit(a)
+end do
+!$acc end host_data
+end block
+
+block; use m
+do nt = 1, ntimes
+  call doit(a)
+end do
+end block
+end
+
+! CHECK: fir.call @_QP__device_sub
+! CHECK: fir.call @_QP__host_sub

@llvmbot
Copy link
Member

llvmbot commented Oct 2, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

When there is a host_data region with use_device clause, the symbol inside the region should be treated like if they have the device attribute when CUDA Fortran is enable. This allow generic resolution to pick up dedicated functions and subroutine written for device data.
This patch update OpenACC name resolution to treat use_device clause before generic resolution is done. It creates new symbol in the host_data scope with the device attribute set.
In lowering, the newly create symbols are binded to the outer scope symbol since the lowering is expected to be the same.


Full diff: https://github.com/llvm/llvm-project/pull/161613.diff

10 Files Affected:

  • (modified) flang/include/flang/Lower/OpenACC.h (+2-1)
  • (modified) flang/include/flang/Lower/SymbolMap.h (+4)
  • (modified) flang/include/flang/Semantics/symbol.h (+1-1)
  • (modified) flang/lib/Lower/Bridge.cpp (+1-1)
  • (modified) flang/lib/Lower/OpenACC.cpp (+24-9)
  • (modified) flang/lib/Lower/SymbolMap.cpp (+10)
  • (modified) flang/lib/Semantics/check-declarations.cpp (+2-1)
  • (modified) flang/lib/Semantics/resolve-directives.cpp (+5)
  • (modified) flang/lib/Semantics/resolve-names.cpp (+64-1)
  • (added) flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 (+43)
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 19d759479abaf..4622dbc8ccf64 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -77,7 +77,8 @@ static constexpr llvm::StringRef privatizationRecipePrefix = "privatization";
 mlir::Value genOpenACCConstruct(AbstractConverter &,
                                 Fortran::semantics::SemanticsContext &,
                                 pft::Evaluation &,
-                                const parser::OpenACCConstruct &);
+                                const parser::OpenACCConstruct &,
+                                Fortran::lower::SymMap &localSymbols);
 void genOpenACCDeclarativeConstruct(
     AbstractConverter &, Fortran::semantics::SemanticsContext &,
     StatementContext &, const parser::OpenACCDeclarativeConstruct &);
diff --git a/flang/include/flang/Lower/SymbolMap.h b/flang/include/flang/Lower/SymbolMap.h
index 813df777d7a39..e57b6a42d3cc1 100644
--- a/flang/include/flang/Lower/SymbolMap.h
+++ b/flang/include/flang/Lower/SymbolMap.h
@@ -260,6 +260,10 @@ class SymMap {
     return lookupSymbol(*sym);
   }
 
+  /// Find a symbol by name and return its value if it appears in the current
+  /// mappings. This lookup is more expensive as it iterates over the map.
+  const semantics::Symbol *lookupSymbolByName(llvm::StringRef symName);
+
   /// Find `symbol` and return its value if it appears in the inner-most level
   /// map.
   SymbolBox shallowLookupSymbol(semantics::SymbolRef sym);
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index e90e9c617805d..a0d5ae7176141 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -801,7 +801,7 @@ class Symbol {
       AccPrivate, AccFirstPrivate, AccShared,
       // OpenACC data-mapping attribute
       AccCopy, AccCopyIn, AccCopyInReadOnly, AccCopyOut, AccCreate, AccDelete,
-      AccPresent, AccLink, AccDeviceResident, AccDevicePtr,
+      AccPresent, AccLink, AccDeviceResident, AccDevicePtr, AccUseDevice,
       // OpenACC declare
       AccDeclare,
       // OpenACC data-movement attribute
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 149e51b501a82..780d56f085f69 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3182,7 +3182,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
     localSymbols.pushScope();
     mlir::Value exitCond = genOpenACCConstruct(
-        *this, bridge.getSemanticsContext(), getEval(), acc);
+        *this, bridge.getSemanticsContext(), getEval(), acc, localSymbols);
 
     const Fortran::parser::OpenACCLoopConstruct *accLoop =
         std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 95d0adae02670..66f45451b3eeb 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
                  Fortran::lower::pft::Evaluation &eval,
                  Fortran::semantics::SemanticsContext &semanticsContext,
                  Fortran::lower::StatementContext &stmtCtx,
-                 const Fortran::parser::AccClauseList &accClauseList) {
+                 const Fortran::parser::AccClauseList &accClauseList,
+                 Fortran::lower::SymMap &localSymbols) {
   mlir::Value ifCond;
   llvm::SmallVector<mlir::Value> dataOperands;
   bool addIfPresentAttr = false;
@@ -3199,6 +3200,18 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *useDevice =
                    std::get_if<Fortran::parser::AccClause::UseDevice>(
                        &clause.u)) {
+      // When CUDA Fotran is en
+      if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
+
+        const Fortran::parser::AccObjectList &objectList{useDevice->v};
+        for (const auto &accObject : objectList.v) {
+          Fortran::semantics::Symbol &symbol =
+              getSymbolFromAccObject(accObject);
+          const Fortran::semantics::Symbol *baseSym =
+              localSymbols.lookupSymbolByName(symbol.name().ToString());
+          localSymbols.copySymbolBinding(*baseSym, symbol);
+        }
+      }
       genDataOperandOperations<mlir::acc::UseDeviceOp>(
           useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
           mlir::acc::DataClause::acc_use_device,
@@ -3239,11 +3252,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
     hostDataOp.setIfPresentAttr(builder.getUnitAttr());
 }
 
-static void
-genACC(Fortran::lower::AbstractConverter &converter,
-       Fortran::semantics::SemanticsContext &semanticsContext,
-       Fortran::lower::pft::Evaluation &eval,
-       const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
+static void genACC(Fortran::lower::AbstractConverter &converter,
+                   Fortran::semantics::SemanticsContext &semanticsContext,
+                   Fortran::lower::pft::Evaluation &eval,
+                   const Fortran::parser::OpenACCBlockConstruct &blockConstruct,
+                   Fortran::lower::SymMap &localSymbols) {
   const auto &beginBlockDirective =
       std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
   const auto &blockDirective =
@@ -3273,7 +3286,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
                                           accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_host_data) {
     genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
-                     stmtCtx, accClauseList);
+                     stmtCtx, accClauseList, localSymbols);
   }
 }
 
@@ -4647,13 +4660,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct(
     Fortran::lower::AbstractConverter &converter,
     Fortran::semantics::SemanticsContext &semanticsContext,
     Fortran::lower::pft::Evaluation &eval,
-    const Fortran::parser::OpenACCConstruct &accConstruct) {
+    const Fortran::parser::OpenACCConstruct &accConstruct,
+    Fortran::lower::SymMap &localSymbols) {
 
   mlir::Value exitCond;
   Fortran::common::visit(
       common::visitors{
           [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
-            genACC(converter, semanticsContext, eval, blockConstruct);
+            genACC(converter, semanticsContext, eval, blockConstruct,
+                   localSymbols);
           },
           [&](const Fortran::parser::OpenACCCombinedConstruct
                   &combinedConstruct) {
diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp
index 080f21ec67400..78529e0d539fb 100644
--- a/flang/lib/Lower/SymbolMap.cpp
+++ b/flang/lib/Lower/SymbolMap.cpp
@@ -45,6 +45,16 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
   return SymbolBox::None{};
 }
 
+const Fortran::semantics::Symbol *
+Fortran::lower::SymMap::lookupSymbolByName(llvm::StringRef symName) {
+  for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend();
+       jmap != jend; ++jmap)
+    for (auto const &[sym, symBox] : *jmap)
+      if (sym->name().ToString() == symName)
+        return sym;
+  return nullptr;
+}
+
 Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
     Fortran::semantics::SymbolRef symRef) {
   auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index 1049a6d2c1b2e..7b881008219df 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -1189,7 +1189,8 @@ void CheckHelper::CheckObjectEntity(
       }
     } else if (!subpDetails && symbol.owner().kind() != Scope::Kind::Module &&
         symbol.owner().kind() != Scope::Kind::MainProgram &&
-        symbol.owner().kind() != Scope::Kind::BlockConstruct) {
+        symbol.owner().kind() != Scope::Kind::BlockConstruct &&
+        symbol.owner().kind() != Scope::Kind::OpenACCConstruct) {
       messages_.Say(
           "ATTRIBUTES(%s) may apply only to module, host subprogram, block, or device subprogram data"_err_en_US,
           parser::ToUpperCaseLetters(common::EnumToString(attr)));
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index bd7b8ac552fab..366ebba1b7263 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -328,6 +328,11 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
     return false;
   }
 
+  bool Pre(const parser::AccClause::UseDevice &x) {
+    ResolveAccObjectList(x.v, Symbol::Flag::AccUseDevice);
+    return false;
+  }
+
   void Post(const parser::Name &);
 
 private:
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index d1150a9eb67f4..7e55eb005dbb2 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1387,6 +1387,8 @@ class ConstructVisitor : public virtual DeclarationVisitor {
 // Create scopes for OpenACC constructs
 class AccVisitor : public virtual DeclarationVisitor {
 public:
+  explicit AccVisitor(SemanticsContext &context) : context_{context} {}
+
   void AddAccSourceRange(const parser::CharBlock &);
 
   static bool NeedsScope(const parser::OpenACCBlockConstruct &);
@@ -1395,6 +1397,7 @@ class AccVisitor : public virtual DeclarationVisitor {
   void Post(const parser::OpenACCBlockConstruct &);
   bool Pre(const parser::OpenACCCombinedConstruct &);
   void Post(const parser::OpenACCCombinedConstruct &);
+  bool Pre(const parser::AccClause::UseDevice &x);
   bool Pre(const parser::AccBeginBlockDirective &x) {
     AddAccSourceRange(x.source);
     return true;
@@ -1430,6 +1433,11 @@ class AccVisitor : public virtual DeclarationVisitor {
   void Post(const parser::AccBeginLoopDirective &x) {
     messageHandler().set_currStmtSource(std::nullopt);
   }
+
+  void CopySymbolWithDevice(const parser::Name *name);
+
+private:
+  SemanticsContext &context_;
 };
 
 bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) {
@@ -1459,6 +1467,60 @@ bool AccVisitor::Pre(const parser::OpenACCBlockConstruct &x) {
   return true;
 }
 
+void AccVisitor::CopySymbolWithDevice(const parser::Name *name) {
+  // When CUDA Fortran is enabled together with OpenACC, new
+  // symbols are created for the one appearing in the use_device
+  // clause. These new symbols have the CUDA Fortran device
+  // attribute.
+  if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) {
+    name->symbol = currScope().CopySymbol(*name->symbol);
+    if (auto *object{name->symbol->detailsIf<ObjectEntityDetails>()}) {
+      object->set_cudaDataAttr(common::CUDADataAttr::Device);
+    }
+  }
+}
+
+bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) {
+  for (const auto &accObject : x.v.v) {
+    common::visit(
+        common::visitors{
+            [&](const parser::Designator &designator) {
+              if (const auto *name{
+                      semantics::getDesignatorNameIfDataRef(designator)}) {
+                Symbol *prev{currScope().FindSymbol(name->source)};
+                if (prev != name->symbol) {
+                  name->symbol = prev;
+                }
+                CopySymbolWithDevice(name);
+              } else {
+                if (const auto *dataRef{
+                        std::get_if<parser::DataRef>(&designator.u)}) {
+                  using ElementIndirection =
+                      common::Indirection<parser::ArrayElement>;
+                  if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
+                    const parser::ArrayElement &arrayElement{ind->value()};
+                    Walk(arrayElement.subscripts);
+                    const parser::DataRef &base{arrayElement.base};
+                    if (auto *name{std::get_if<parser::Name>(&base.u)}) {
+                      Symbol *prev{currScope().FindSymbol(name->source)};
+                      if (prev != name->symbol) {
+                        name->symbol = prev;
+                      }
+                      CopySymbolWithDevice(name);
+                    }
+                  }
+                }
+              }
+            },
+            [&](const parser::Name &name) {
+              // TODO: common block in use_device?
+            },
+        },
+        accObject.u);
+  }
+  return false;
+}
+
 void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) {
   if (NeedsScope(x)) {
     PopScope();
@@ -2038,7 +2100,8 @@ class ResolveNamesVisitor : public virtual ScopeHandler,
 
   ResolveNamesVisitor(
       SemanticsContext &context, ImplicitRulesMap &rules, Scope &top)
-      : BaseVisitor{context, *this, rules}, topScope_{top} {
+      : BaseVisitor{context, *this, rules}, topScope_{top},
+        AccVisitor(context) {
     PushScope(top);
   }
 
diff --git a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
new file mode 100644
index 0000000000000..da034adec39c4
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
@@ -0,0 +1,43 @@
+
+! RUN: bbc -fopenacc -fcuda -emit-hlfir %s -o - | FileCheck %s
+
+module m
+
+interface doit
+subroutine __device_sub(a)
+    real(4), device, intent(in) :: a(:,:,:)
+    !dir$ ignore_tkr(c) a
+end
+subroutine __host_sub(a)
+    real(4), intent(in) :: a(:,:,:)
+    !dir$ ignore_tkr(c) a
+end
+end interface
+end module
+
+program testex1
+integer, parameter :: ntimes = 10
+integer, parameter :: ni=128
+integer, parameter :: nj=256
+integer, parameter :: nk=64
+real(4), dimension(ni,nj,nk) :: a
+
+!$acc enter data copyin(a)
+
+block; use m
+!$acc host_data use_device(a)
+do nt = 1, ntimes
+  call doit(a)
+end do
+!$acc end host_data
+end block
+
+block; use m
+do nt = 1, ntimes
+  call doit(a)
+end do
+end block
+end
+
+! CHECK: fir.call @_QP__device_sub
+! CHECK: fir.call @_QP__host_sub

Copy link

github-actions bot commented Oct 2, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@clementval clementval merged commit c242aff into llvm:main Oct 2, 2025
9 checks passed
@clementval clementval deleted the resolve2 branch October 2, 2025 16:01
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang:semantics flang Flang issues not falling into any other category openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants