-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[flang][cuda][openacc] Create new symbol in host_data region for CUDA Fortran interop #161613
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-openacc @llvm/pr-subscribers-flang-openmp Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesWhen there is a host_data region with use_device clause, the symbol inside the region should be treated like if they have the device attribute when CUDA Fortran is enable. This allow generic resolution to pick up dedicated functions and subroutine written for device data. Full diff: https://github.com/llvm/llvm-project/pull/161613.diff 10 Files Affected:
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 19d759479abaf..4622dbc8ccf64 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -77,7 +77,8 @@ static constexpr llvm::StringRef privatizationRecipePrefix = "privatization";
mlir::Value genOpenACCConstruct(AbstractConverter &,
Fortran::semantics::SemanticsContext &,
pft::Evaluation &,
- const parser::OpenACCConstruct &);
+ const parser::OpenACCConstruct &,
+ Fortran::lower::SymMap &localSymbols);
void genOpenACCDeclarativeConstruct(
AbstractConverter &, Fortran::semantics::SemanticsContext &,
StatementContext &, const parser::OpenACCDeclarativeConstruct &);
diff --git a/flang/include/flang/Lower/SymbolMap.h b/flang/include/flang/Lower/SymbolMap.h
index 813df777d7a39..e57b6a42d3cc1 100644
--- a/flang/include/flang/Lower/SymbolMap.h
+++ b/flang/include/flang/Lower/SymbolMap.h
@@ -260,6 +260,10 @@ class SymMap {
return lookupSymbol(*sym);
}
+ /// Find a symbol by name and return its value if it appears in the current
+ /// mappings. This lookup is more expensive as it iterates over the map.
+ const semantics::Symbol *lookupSymbolByName(llvm::StringRef symName);
+
/// Find `symbol` and return its value if it appears in the inner-most level
/// map.
SymbolBox shallowLookupSymbol(semantics::SymbolRef sym);
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index e90e9c617805d..a0d5ae7176141 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -801,7 +801,7 @@ class Symbol {
AccPrivate, AccFirstPrivate, AccShared,
// OpenACC data-mapping attribute
AccCopy, AccCopyIn, AccCopyInReadOnly, AccCopyOut, AccCreate, AccDelete,
- AccPresent, AccLink, AccDeviceResident, AccDevicePtr,
+ AccPresent, AccLink, AccDeviceResident, AccDevicePtr, AccUseDevice,
// OpenACC declare
AccDeclare,
// OpenACC data-movement attribute
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 149e51b501a82..780d56f085f69 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3182,7 +3182,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
localSymbols.pushScope();
mlir::Value exitCond = genOpenACCConstruct(
- *this, bridge.getSemanticsContext(), getEval(), acc);
+ *this, bridge.getSemanticsContext(), getEval(), acc, localSymbols);
const Fortran::parser::OpenACCLoopConstruct *accLoop =
std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 95d0adae02670..66f45451b3eeb 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::AccClauseList &accClauseList) {
+ const Fortran::parser::AccClauseList &accClauseList,
+ Fortran::lower::SymMap &localSymbols) {
mlir::Value ifCond;
llvm::SmallVector<mlir::Value> dataOperands;
bool addIfPresentAttr = false;
@@ -3199,6 +3200,18 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *useDevice =
std::get_if<Fortran::parser::AccClause::UseDevice>(
&clause.u)) {
+ // When CUDA Fotran is en
+ if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
+
+ const Fortran::parser::AccObjectList &objectList{useDevice->v};
+ for (const auto &accObject : objectList.v) {
+ Fortran::semantics::Symbol &symbol =
+ getSymbolFromAccObject(accObject);
+ const Fortran::semantics::Symbol *baseSym =
+ localSymbols.lookupSymbolByName(symbol.name().ToString());
+ localSymbols.copySymbolBinding(*baseSym, symbol);
+ }
+ }
genDataOperandOperations<mlir::acc::UseDeviceOp>(
useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
mlir::acc::DataClause::acc_use_device,
@@ -3239,11 +3252,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
hostDataOp.setIfPresentAttr(builder.getUnitAttr());
}
-static void
-genACC(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semanticsContext,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
+static void genACC(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semanticsContext,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenACCBlockConstruct &blockConstruct,
+ Fortran::lower::SymMap &localSymbols) {
const auto &beginBlockDirective =
std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
const auto &blockDirective =
@@ -3273,7 +3286,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
accClauseList);
} else if (blockDirective.v == llvm::acc::ACCD_host_data) {
genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
- stmtCtx, accClauseList);
+ stmtCtx, accClauseList, localSymbols);
}
}
@@ -4647,13 +4660,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCConstruct &accConstruct) {
+ const Fortran::parser::OpenACCConstruct &accConstruct,
+ Fortran::lower::SymMap &localSymbols) {
mlir::Value exitCond;
Fortran::common::visit(
common::visitors{
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
- genACC(converter, semanticsContext, eval, blockConstruct);
+ genACC(converter, semanticsContext, eval, blockConstruct,
+ localSymbols);
},
[&](const Fortran::parser::OpenACCCombinedConstruct
&combinedConstruct) {
diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp
index 080f21ec67400..78529e0d539fb 100644
--- a/flang/lib/Lower/SymbolMap.cpp
+++ b/flang/lib/Lower/SymbolMap.cpp
@@ -45,6 +45,16 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
return SymbolBox::None{};
}
+const Fortran::semantics::Symbol *
+Fortran::lower::SymMap::lookupSymbolByName(llvm::StringRef symName) {
+ for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend();
+ jmap != jend; ++jmap)
+ for (auto const &[sym, symBox] : *jmap)
+ if (sym->name().ToString() == symName)
+ return sym;
+ return nullptr;
+}
+
Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
Fortran::semantics::SymbolRef symRef) {
auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index 1049a6d2c1b2e..7b881008219df 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -1189,7 +1189,8 @@ void CheckHelper::CheckObjectEntity(
}
} else if (!subpDetails && symbol.owner().kind() != Scope::Kind::Module &&
symbol.owner().kind() != Scope::Kind::MainProgram &&
- symbol.owner().kind() != Scope::Kind::BlockConstruct) {
+ symbol.owner().kind() != Scope::Kind::BlockConstruct &&
+ symbol.owner().kind() != Scope::Kind::OpenACCConstruct) {
messages_.Say(
"ATTRIBUTES(%s) may apply only to module, host subprogram, block, or device subprogram data"_err_en_US,
parser::ToUpperCaseLetters(common::EnumToString(attr)));
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index bd7b8ac552fab..366ebba1b7263 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -328,6 +328,11 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
return false;
}
+ bool Pre(const parser::AccClause::UseDevice &x) {
+ ResolveAccObjectList(x.v, Symbol::Flag::AccUseDevice);
+ return false;
+ }
+
void Post(const parser::Name &);
private:
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index d1150a9eb67f4..7e55eb005dbb2 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1387,6 +1387,8 @@ class ConstructVisitor : public virtual DeclarationVisitor {
// Create scopes for OpenACC constructs
class AccVisitor : public virtual DeclarationVisitor {
public:
+ explicit AccVisitor(SemanticsContext &context) : context_{context} {}
+
void AddAccSourceRange(const parser::CharBlock &);
static bool NeedsScope(const parser::OpenACCBlockConstruct &);
@@ -1395,6 +1397,7 @@ class AccVisitor : public virtual DeclarationVisitor {
void Post(const parser::OpenACCBlockConstruct &);
bool Pre(const parser::OpenACCCombinedConstruct &);
void Post(const parser::OpenACCCombinedConstruct &);
+ bool Pre(const parser::AccClause::UseDevice &x);
bool Pre(const parser::AccBeginBlockDirective &x) {
AddAccSourceRange(x.source);
return true;
@@ -1430,6 +1433,11 @@ class AccVisitor : public virtual DeclarationVisitor {
void Post(const parser::AccBeginLoopDirective &x) {
messageHandler().set_currStmtSource(std::nullopt);
}
+
+ void CopySymbolWithDevice(const parser::Name *name);
+
+private:
+ SemanticsContext &context_;
};
bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) {
@@ -1459,6 +1467,60 @@ bool AccVisitor::Pre(const parser::OpenACCBlockConstruct &x) {
return true;
}
+void AccVisitor::CopySymbolWithDevice(const parser::Name *name) {
+ // When CUDA Fortran is enabled together with OpenACC, new
+ // symbols are created for the one appearing in the use_device
+ // clause. These new symbols have the CUDA Fortran device
+ // attribute.
+ if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) {
+ name->symbol = currScope().CopySymbol(*name->symbol);
+ if (auto *object{name->symbol->detailsIf<ObjectEntityDetails>()}) {
+ object->set_cudaDataAttr(common::CUDADataAttr::Device);
+ }
+ }
+}
+
+bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) {
+ for (const auto &accObject : x.v.v) {
+ common::visit(
+ common::visitors{
+ [&](const parser::Designator &designator) {
+ if (const auto *name{
+ semantics::getDesignatorNameIfDataRef(designator)}) {
+ Symbol *prev{currScope().FindSymbol(name->source)};
+ if (prev != name->symbol) {
+ name->symbol = prev;
+ }
+ CopySymbolWithDevice(name);
+ } else {
+ if (const auto *dataRef{
+ std::get_if<parser::DataRef>(&designator.u)}) {
+ using ElementIndirection =
+ common::Indirection<parser::ArrayElement>;
+ if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
+ const parser::ArrayElement &arrayElement{ind->value()};
+ Walk(arrayElement.subscripts);
+ const parser::DataRef &base{arrayElement.base};
+ if (auto *name{std::get_if<parser::Name>(&base.u)}) {
+ Symbol *prev{currScope().FindSymbol(name->source)};
+ if (prev != name->symbol) {
+ name->symbol = prev;
+ }
+ CopySymbolWithDevice(name);
+ }
+ }
+ }
+ }
+ },
+ [&](const parser::Name &name) {
+ // TODO: common block in use_device?
+ },
+ },
+ accObject.u);
+ }
+ return false;
+}
+
void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) {
if (NeedsScope(x)) {
PopScope();
@@ -2038,7 +2100,8 @@ class ResolveNamesVisitor : public virtual ScopeHandler,
ResolveNamesVisitor(
SemanticsContext &context, ImplicitRulesMap &rules, Scope &top)
- : BaseVisitor{context, *this, rules}, topScope_{top} {
+ : BaseVisitor{context, *this, rules}, topScope_{top},
+ AccVisitor(context) {
PushScope(top);
}
diff --git a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
new file mode 100644
index 0000000000000..da034adec39c4
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
@@ -0,0 +1,43 @@
+
+! RUN: bbc -fopenacc -fcuda -emit-hlfir %s -o - | FileCheck %s
+
+module m
+
+interface doit
+subroutine __device_sub(a)
+ real(4), device, intent(in) :: a(:,:,:)
+ !dir$ ignore_tkr(c) a
+end
+subroutine __host_sub(a)
+ real(4), intent(in) :: a(:,:,:)
+ !dir$ ignore_tkr(c) a
+end
+end interface
+end module
+
+program testex1
+integer, parameter :: ntimes = 10
+integer, parameter :: ni=128
+integer, parameter :: nj=256
+integer, parameter :: nk=64
+real(4), dimension(ni,nj,nk) :: a
+
+!$acc enter data copyin(a)
+
+block; use m
+!$acc host_data use_device(a)
+do nt = 1, ntimes
+ call doit(a)
+end do
+!$acc end host_data
+end block
+
+block; use m
+do nt = 1, ntimes
+ call doit(a)
+end do
+end block
+end
+
+! CHECK: fir.call @_QP__device_sub
+! CHECK: fir.call @_QP__host_sub
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesWhen there is a host_data region with use_device clause, the symbol inside the region should be treated like if they have the device attribute when CUDA Fortran is enable. This allow generic resolution to pick up dedicated functions and subroutine written for device data. Full diff: https://github.com/llvm/llvm-project/pull/161613.diff 10 Files Affected:
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 19d759479abaf..4622dbc8ccf64 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -77,7 +77,8 @@ static constexpr llvm::StringRef privatizationRecipePrefix = "privatization";
mlir::Value genOpenACCConstruct(AbstractConverter &,
Fortran::semantics::SemanticsContext &,
pft::Evaluation &,
- const parser::OpenACCConstruct &);
+ const parser::OpenACCConstruct &,
+ Fortran::lower::SymMap &localSymbols);
void genOpenACCDeclarativeConstruct(
AbstractConverter &, Fortran::semantics::SemanticsContext &,
StatementContext &, const parser::OpenACCDeclarativeConstruct &);
diff --git a/flang/include/flang/Lower/SymbolMap.h b/flang/include/flang/Lower/SymbolMap.h
index 813df777d7a39..e57b6a42d3cc1 100644
--- a/flang/include/flang/Lower/SymbolMap.h
+++ b/flang/include/flang/Lower/SymbolMap.h
@@ -260,6 +260,10 @@ class SymMap {
return lookupSymbol(*sym);
}
+ /// Find a symbol by name and return its value if it appears in the current
+ /// mappings. This lookup is more expensive as it iterates over the map.
+ const semantics::Symbol *lookupSymbolByName(llvm::StringRef symName);
+
/// Find `symbol` and return its value if it appears in the inner-most level
/// map.
SymbolBox shallowLookupSymbol(semantics::SymbolRef sym);
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index e90e9c617805d..a0d5ae7176141 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -801,7 +801,7 @@ class Symbol {
AccPrivate, AccFirstPrivate, AccShared,
// OpenACC data-mapping attribute
AccCopy, AccCopyIn, AccCopyInReadOnly, AccCopyOut, AccCreate, AccDelete,
- AccPresent, AccLink, AccDeviceResident, AccDevicePtr,
+ AccPresent, AccLink, AccDeviceResident, AccDevicePtr, AccUseDevice,
// OpenACC declare
AccDeclare,
// OpenACC data-movement attribute
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 149e51b501a82..780d56f085f69 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3182,7 +3182,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
localSymbols.pushScope();
mlir::Value exitCond = genOpenACCConstruct(
- *this, bridge.getSemanticsContext(), getEval(), acc);
+ *this, bridge.getSemanticsContext(), getEval(), acc, localSymbols);
const Fortran::parser::OpenACCLoopConstruct *accLoop =
std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 95d0adae02670..66f45451b3eeb 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::AccClauseList &accClauseList) {
+ const Fortran::parser::AccClauseList &accClauseList,
+ Fortran::lower::SymMap &localSymbols) {
mlir::Value ifCond;
llvm::SmallVector<mlir::Value> dataOperands;
bool addIfPresentAttr = false;
@@ -3199,6 +3200,18 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *useDevice =
std::get_if<Fortran::parser::AccClause::UseDevice>(
&clause.u)) {
+ // When CUDA Fotran is en
+ if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
+
+ const Fortran::parser::AccObjectList &objectList{useDevice->v};
+ for (const auto &accObject : objectList.v) {
+ Fortran::semantics::Symbol &symbol =
+ getSymbolFromAccObject(accObject);
+ const Fortran::semantics::Symbol *baseSym =
+ localSymbols.lookupSymbolByName(symbol.name().ToString());
+ localSymbols.copySymbolBinding(*baseSym, symbol);
+ }
+ }
genDataOperandOperations<mlir::acc::UseDeviceOp>(
useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
mlir::acc::DataClause::acc_use_device,
@@ -3239,11 +3252,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
hostDataOp.setIfPresentAttr(builder.getUnitAttr());
}
-static void
-genACC(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semanticsContext,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
+static void genACC(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semanticsContext,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenACCBlockConstruct &blockConstruct,
+ Fortran::lower::SymMap &localSymbols) {
const auto &beginBlockDirective =
std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
const auto &blockDirective =
@@ -3273,7 +3286,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
accClauseList);
} else if (blockDirective.v == llvm::acc::ACCD_host_data) {
genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
- stmtCtx, accClauseList);
+ stmtCtx, accClauseList, localSymbols);
}
}
@@ -4647,13 +4660,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCConstruct &accConstruct) {
+ const Fortran::parser::OpenACCConstruct &accConstruct,
+ Fortran::lower::SymMap &localSymbols) {
mlir::Value exitCond;
Fortran::common::visit(
common::visitors{
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
- genACC(converter, semanticsContext, eval, blockConstruct);
+ genACC(converter, semanticsContext, eval, blockConstruct,
+ localSymbols);
},
[&](const Fortran::parser::OpenACCCombinedConstruct
&combinedConstruct) {
diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp
index 080f21ec67400..78529e0d539fb 100644
--- a/flang/lib/Lower/SymbolMap.cpp
+++ b/flang/lib/Lower/SymbolMap.cpp
@@ -45,6 +45,16 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
return SymbolBox::None{};
}
+const Fortran::semantics::Symbol *
+Fortran::lower::SymMap::lookupSymbolByName(llvm::StringRef symName) {
+ for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend();
+ jmap != jend; ++jmap)
+ for (auto const &[sym, symBox] : *jmap)
+ if (sym->name().ToString() == symName)
+ return sym;
+ return nullptr;
+}
+
Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
Fortran::semantics::SymbolRef symRef) {
auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index 1049a6d2c1b2e..7b881008219df 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -1189,7 +1189,8 @@ void CheckHelper::CheckObjectEntity(
}
} else if (!subpDetails && symbol.owner().kind() != Scope::Kind::Module &&
symbol.owner().kind() != Scope::Kind::MainProgram &&
- symbol.owner().kind() != Scope::Kind::BlockConstruct) {
+ symbol.owner().kind() != Scope::Kind::BlockConstruct &&
+ symbol.owner().kind() != Scope::Kind::OpenACCConstruct) {
messages_.Say(
"ATTRIBUTES(%s) may apply only to module, host subprogram, block, or device subprogram data"_err_en_US,
parser::ToUpperCaseLetters(common::EnumToString(attr)));
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index bd7b8ac552fab..366ebba1b7263 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -328,6 +328,11 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
return false;
}
+ bool Pre(const parser::AccClause::UseDevice &x) {
+ ResolveAccObjectList(x.v, Symbol::Flag::AccUseDevice);
+ return false;
+ }
+
void Post(const parser::Name &);
private:
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index d1150a9eb67f4..7e55eb005dbb2 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1387,6 +1387,8 @@ class ConstructVisitor : public virtual DeclarationVisitor {
// Create scopes for OpenACC constructs
class AccVisitor : public virtual DeclarationVisitor {
public:
+ explicit AccVisitor(SemanticsContext &context) : context_{context} {}
+
void AddAccSourceRange(const parser::CharBlock &);
static bool NeedsScope(const parser::OpenACCBlockConstruct &);
@@ -1395,6 +1397,7 @@ class AccVisitor : public virtual DeclarationVisitor {
void Post(const parser::OpenACCBlockConstruct &);
bool Pre(const parser::OpenACCCombinedConstruct &);
void Post(const parser::OpenACCCombinedConstruct &);
+ bool Pre(const parser::AccClause::UseDevice &x);
bool Pre(const parser::AccBeginBlockDirective &x) {
AddAccSourceRange(x.source);
return true;
@@ -1430,6 +1433,11 @@ class AccVisitor : public virtual DeclarationVisitor {
void Post(const parser::AccBeginLoopDirective &x) {
messageHandler().set_currStmtSource(std::nullopt);
}
+
+ void CopySymbolWithDevice(const parser::Name *name);
+
+private:
+ SemanticsContext &context_;
};
bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) {
@@ -1459,6 +1467,60 @@ bool AccVisitor::Pre(const parser::OpenACCBlockConstruct &x) {
return true;
}
+void AccVisitor::CopySymbolWithDevice(const parser::Name *name) {
+ // When CUDA Fortran is enabled together with OpenACC, new
+ // symbols are created for the one appearing in the use_device
+ // clause. These new symbols have the CUDA Fortran device
+ // attribute.
+ if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) {
+ name->symbol = currScope().CopySymbol(*name->symbol);
+ if (auto *object{name->symbol->detailsIf<ObjectEntityDetails>()}) {
+ object->set_cudaDataAttr(common::CUDADataAttr::Device);
+ }
+ }
+}
+
+bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) {
+ for (const auto &accObject : x.v.v) {
+ common::visit(
+ common::visitors{
+ [&](const parser::Designator &designator) {
+ if (const auto *name{
+ semantics::getDesignatorNameIfDataRef(designator)}) {
+ Symbol *prev{currScope().FindSymbol(name->source)};
+ if (prev != name->symbol) {
+ name->symbol = prev;
+ }
+ CopySymbolWithDevice(name);
+ } else {
+ if (const auto *dataRef{
+ std::get_if<parser::DataRef>(&designator.u)}) {
+ using ElementIndirection =
+ common::Indirection<parser::ArrayElement>;
+ if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
+ const parser::ArrayElement &arrayElement{ind->value()};
+ Walk(arrayElement.subscripts);
+ const parser::DataRef &base{arrayElement.base};
+ if (auto *name{std::get_if<parser::Name>(&base.u)}) {
+ Symbol *prev{currScope().FindSymbol(name->source)};
+ if (prev != name->symbol) {
+ name->symbol = prev;
+ }
+ CopySymbolWithDevice(name);
+ }
+ }
+ }
+ }
+ },
+ [&](const parser::Name &name) {
+ // TODO: common block in use_device?
+ },
+ },
+ accObject.u);
+ }
+ return false;
+}
+
void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) {
if (NeedsScope(x)) {
PopScope();
@@ -2038,7 +2100,8 @@ class ResolveNamesVisitor : public virtual ScopeHandler,
ResolveNamesVisitor(
SemanticsContext &context, ImplicitRulesMap &rules, Scope &top)
- : BaseVisitor{context, *this, rules}, topScope_{top} {
+ : BaseVisitor{context, *this, rules}, topScope_{top},
+ AccVisitor(context) {
PushScope(top);
}
diff --git a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
new file mode 100644
index 0000000000000..da034adec39c4
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
@@ -0,0 +1,43 @@
+
+! RUN: bbc -fopenacc -fcuda -emit-hlfir %s -o - | FileCheck %s
+
+module m
+
+interface doit
+subroutine __device_sub(a)
+ real(4), device, intent(in) :: a(:,:,:)
+ !dir$ ignore_tkr(c) a
+end
+subroutine __host_sub(a)
+ real(4), intent(in) :: a(:,:,:)
+ !dir$ ignore_tkr(c) a
+end
+end interface
+end module
+
+program testex1
+integer, parameter :: ntimes = 10
+integer, parameter :: ni=128
+integer, parameter :: nj=256
+integer, parameter :: nk=64
+real(4), dimension(ni,nj,nk) :: a
+
+!$acc enter data copyin(a)
+
+block; use m
+!$acc host_data use_device(a)
+do nt = 1, ntimes
+ call doit(a)
+end do
+!$acc end host_data
+end block
+
+block; use m
+do nt = 1, ntimes
+ call doit(a)
+end do
+end block
+end
+
+! CHECK: fir.call @_QP__device_sub
+! CHECK: fir.call @_QP__host_sub
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
… Fortran interop (llvm#161613)
When there is a host_data region with use_device clause, the symbol inside the region should be treated like if they have the device attribute when CUDA Fortran is enable. This allow generic resolution to pick up dedicated functions and subroutine written for device data.
This patch update OpenACC name resolution to treat use_device clause before generic resolution is done. It creates new symbol in the host_data scope with the device attribute set.
In lowering, the newly create symbols are binded to the outer scope symbol since the lowering is expected to be the same.