diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h index 19d759479abaf..4622dbc8ccf64 100644 --- a/flang/include/flang/Lower/OpenACC.h +++ b/flang/include/flang/Lower/OpenACC.h @@ -77,7 +77,8 @@ static constexpr llvm::StringRef privatizationRecipePrefix = "privatization"; mlir::Value genOpenACCConstruct(AbstractConverter &, Fortran::semantics::SemanticsContext &, pft::Evaluation &, - const parser::OpenACCConstruct &); + const parser::OpenACCConstruct &, + Fortran::lower::SymMap &localSymbols); void genOpenACCDeclarativeConstruct( AbstractConverter &, Fortran::semantics::SemanticsContext &, StatementContext &, const parser::OpenACCDeclarativeConstruct &); diff --git a/flang/include/flang/Lower/SymbolMap.h b/flang/include/flang/Lower/SymbolMap.h index 813df777d7a39..e57b6a42d3cc1 100644 --- a/flang/include/flang/Lower/SymbolMap.h +++ b/flang/include/flang/Lower/SymbolMap.h @@ -260,6 +260,10 @@ class SymMap { return lookupSymbol(*sym); } + /// Find a symbol by name and return its value if it appears in the current + /// mappings. This lookup is more expensive as it iterates over the map. + const semantics::Symbol *lookupSymbolByName(llvm::StringRef symName); + /// Find `symbol` and return its value if it appears in the inner-most level /// map. SymbolBox shallowLookupSymbol(semantics::SymbolRef sym); diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h index e90e9c617805d..a0d5ae7176141 100644 --- a/flang/include/flang/Semantics/symbol.h +++ b/flang/include/flang/Semantics/symbol.h @@ -801,7 +801,7 @@ class Symbol { AccPrivate, AccFirstPrivate, AccShared, // OpenACC data-mapping attribute AccCopy, AccCopyIn, AccCopyInReadOnly, AccCopyOut, AccCreate, AccDelete, - AccPresent, AccLink, AccDeviceResident, AccDevicePtr, + AccPresent, AccLink, AccDeviceResident, AccDevicePtr, AccUseDevice, // OpenACC declare AccDeclare, // OpenACC data-movement attribute diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 149e51b501a82..780d56f085f69 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3182,7 +3182,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); localSymbols.pushScope(); mlir::Value exitCond = genOpenACCConstruct( - *this, bridge.getSemanticsContext(), getEval(), acc); + *this, bridge.getSemanticsContext(), getEval(), acc, localSymbols); const Fortran::parser::OpenACCLoopConstruct *accLoop = std::get_if(&acc.u); diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 95d0adae02670..f9b9b850ad839 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter, Fortran::lower::pft::Evaluation &eval, Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::AccClauseList &accClauseList) { + const Fortran::parser::AccClauseList &accClauseList, + Fortran::lower::SymMap &localSymbols) { mlir::Value ifCond; llvm::SmallVector dataOperands; bool addIfPresentAttr = false; @@ -3199,6 +3200,19 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter, } else if (const auto *useDevice = std::get_if( &clause.u)) { + // When CUDA Fotran is enabled, extra symbols are used in the host_data + // region. Look for them and bind their values with the symbols in the + // outer scope. + if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) { + const Fortran::parser::AccObjectList &objectList{useDevice->v}; + for (const auto &accObject : objectList.v) { + Fortran::semantics::Symbol &symbol = + getSymbolFromAccObject(accObject); + const Fortran::semantics::Symbol *baseSym = + localSymbols.lookupSymbolByName(symbol.name().ToString()); + localSymbols.copySymbolBinding(*baseSym, symbol); + } + } genDataOperandOperations( useDevice->v, converter, semanticsContext, stmtCtx, dataOperands, mlir::acc::DataClause::acc_use_device, @@ -3239,11 +3253,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter, hostDataOp.setIfPresentAttr(builder.getUnitAttr()); } -static void -genACC(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semanticsContext, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { +static void genACC(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenACCBlockConstruct &blockConstruct, + Fortran::lower::SymMap &localSymbols) { const auto &beginBlockDirective = std::get(blockConstruct.t); const auto &blockDirective = @@ -3273,7 +3287,7 @@ genACC(Fortran::lower::AbstractConverter &converter, accClauseList); } else if (blockDirective.v == llvm::acc::ACCD_host_data) { genACCHostDataOp(converter, currentLocation, eval, semanticsContext, - stmtCtx, accClauseList); + stmtCtx, accClauseList, localSymbols); } } @@ -4647,13 +4661,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenACCConstruct &accConstruct) { + const Fortran::parser::OpenACCConstruct &accConstruct, + Fortran::lower::SymMap &localSymbols) { mlir::Value exitCond; Fortran::common::visit( common::visitors{ [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { - genACC(converter, semanticsContext, eval, blockConstruct); + genACC(converter, semanticsContext, eval, blockConstruct, + localSymbols); }, [&](const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) { diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp index 080f21ec67400..78529e0d539fb 100644 --- a/flang/lib/Lower/SymbolMap.cpp +++ b/flang/lib/Lower/SymbolMap.cpp @@ -45,6 +45,16 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) { return SymbolBox::None{}; } +const Fortran::semantics::Symbol * +Fortran::lower::SymMap::lookupSymbolByName(llvm::StringRef symName) { + for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend(); + jmap != jend; ++jmap) + for (auto const &[sym, symBox] : *jmap) + if (sym->name().ToString() == symName) + return sym; + return nullptr; +} + Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol( Fortran::semantics::SymbolRef symRef) { auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate(); diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp index 1049a6d2c1b2e..7b881008219df 100644 --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -1189,7 +1189,8 @@ void CheckHelper::CheckObjectEntity( } } else if (!subpDetails && symbol.owner().kind() != Scope::Kind::Module && symbol.owner().kind() != Scope::Kind::MainProgram && - symbol.owner().kind() != Scope::Kind::BlockConstruct) { + symbol.owner().kind() != Scope::Kind::BlockConstruct && + symbol.owner().kind() != Scope::Kind::OpenACCConstruct) { messages_.Say( "ATTRIBUTES(%s) may apply only to module, host subprogram, block, or device subprogram data"_err_en_US, parser::ToUpperCaseLetters(common::EnumToString(attr))); diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index bd7b8ac552fab..366ebba1b7263 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -328,6 +328,11 @@ class AccAttributeVisitor : DirectiveAttributeVisitor { return false; } + bool Pre(const parser::AccClause::UseDevice &x) { + ResolveAccObjectList(x.v, Symbol::Flag::AccUseDevice); + return false; + } + void Post(const parser::Name &); private: diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index d1150a9eb67f4..5041a6a08fc3c 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -1387,6 +1387,8 @@ class ConstructVisitor : public virtual DeclarationVisitor { // Create scopes for OpenACC constructs class AccVisitor : public virtual DeclarationVisitor { public: + explicit AccVisitor(SemanticsContext &context) : context_{context} {} + void AddAccSourceRange(const parser::CharBlock &); static bool NeedsScope(const parser::OpenACCBlockConstruct &); @@ -1395,6 +1397,7 @@ class AccVisitor : public virtual DeclarationVisitor { void Post(const parser::OpenACCBlockConstruct &); bool Pre(const parser::OpenACCCombinedConstruct &); void Post(const parser::OpenACCCombinedConstruct &); + bool Pre(const parser::AccClause::UseDevice &x); bool Pre(const parser::AccBeginBlockDirective &x) { AddAccSourceRange(x.source); return true; @@ -1430,6 +1433,11 @@ class AccVisitor : public virtual DeclarationVisitor { void Post(const parser::AccBeginLoopDirective &x) { messageHandler().set_currStmtSource(std::nullopt); } + + void CopySymbolWithDevice(const parser::Name *name); + +private: + SemanticsContext &context_; }; bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) { @@ -1459,6 +1467,60 @@ bool AccVisitor::Pre(const parser::OpenACCBlockConstruct &x) { return true; } +void AccVisitor::CopySymbolWithDevice(const parser::Name *name) { + // When CUDA Fortran is enabled together with OpenACC, new + // symbols are created for the one appearing in the use_device + // clause. These new symbols have the CUDA Fortran device + // attribute. + if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) { + name->symbol = currScope().CopySymbol(*name->symbol); + if (auto *object{name->symbol->detailsIf()}) { + object->set_cudaDataAttr(common::CUDADataAttr::Device); + } + } +} + +bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) { + for (const auto &accObject : x.v.v) { + common::visit( + common::visitors{ + [&](const parser::Designator &designator) { + if (const auto *name{ + semantics::getDesignatorNameIfDataRef(designator)}) { + Symbol *prev{currScope().FindSymbol(name->source)}; + if (prev != name->symbol) { + name->symbol = prev; + } + CopySymbolWithDevice(name); + } else { + if (const auto *dataRef{ + std::get_if(&designator.u)}) { + using ElementIndirection = + common::Indirection; + if (auto *ind{std::get_if(&dataRef->u)}) { + const parser::ArrayElement &arrayElement{ind->value()}; + Walk(arrayElement.subscripts); + const parser::DataRef &base{arrayElement.base}; + if (auto *name{std::get_if(&base.u)}) { + Symbol *prev{currScope().FindSymbol(name->source)}; + if (prev != name->symbol) { + name->symbol = prev; + } + CopySymbolWithDevice(name); + } + } + } + } + }, + [&](const parser::Name &name) { + // TODO: common block in use_device? + }, + }, + accObject.u); + } + return false; +} + void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) { if (NeedsScope(x)) { PopScope(); @@ -2038,7 +2100,8 @@ class ResolveNamesVisitor : public virtual ScopeHandler, ResolveNamesVisitor( SemanticsContext &context, ImplicitRulesMap &rules, Scope &top) - : BaseVisitor{context, *this, rules}, topScope_{top} { + : BaseVisitor{context, *this, rules}, AccVisitor(context), + topScope_{top} { PushScope(top); } diff --git a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 new file mode 100644 index 0000000000000..da034adec39c4 --- /dev/null +++ b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 @@ -0,0 +1,43 @@ + +! RUN: bbc -fopenacc -fcuda -emit-hlfir %s -o - | FileCheck %s + +module m + +interface doit +subroutine __device_sub(a) + real(4), device, intent(in) :: a(:,:,:) + !dir$ ignore_tkr(c) a +end +subroutine __host_sub(a) + real(4), intent(in) :: a(:,:,:) + !dir$ ignore_tkr(c) a +end +end interface +end module + +program testex1 +integer, parameter :: ntimes = 10 +integer, parameter :: ni=128 +integer, parameter :: nj=256 +integer, parameter :: nk=64 +real(4), dimension(ni,nj,nk) :: a + +!$acc enter data copyin(a) + +block; use m +!$acc host_data use_device(a) +do nt = 1, ntimes + call doit(a) +end do +!$acc end host_data +end block + +block; use m +do nt = 1, ntimes + call doit(a) +end do +end block +end + +! CHECK: fir.call @_QP__device_sub +! CHECK: fir.call @_QP__host_sub