Skip to content

Commit c242aff

Browse files
authored
[flang][cuda][openacc] Create new symbol in host_data region for CUDA Fortran interop (#161613)
1 parent daa4e57 commit c242aff

File tree

10 files changed

+157
-14
lines changed

10 files changed

+157
-14
lines changed

flang/include/flang/Lower/OpenACC.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ static constexpr llvm::StringRef privatizationRecipePrefix = "privatization";
7777
mlir::Value genOpenACCConstruct(AbstractConverter &,
7878
Fortran::semantics::SemanticsContext &,
7979
pft::Evaluation &,
80-
const parser::OpenACCConstruct &);
80+
const parser::OpenACCConstruct &,
81+
Fortran::lower::SymMap &localSymbols);
8182
void genOpenACCDeclarativeConstruct(
8283
AbstractConverter &, Fortran::semantics::SemanticsContext &,
8384
StatementContext &, const parser::OpenACCDeclarativeConstruct &);

flang/include/flang/Lower/SymbolMap.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,10 @@ class SymMap {
260260
return lookupSymbol(*sym);
261261
}
262262

263+
/// Find a symbol by name and return its value if it appears in the current
264+
/// mappings. This lookup is more expensive as it iterates over the map.
265+
const semantics::Symbol *lookupSymbolByName(llvm::StringRef symName);
266+
263267
/// Find `symbol` and return its value if it appears in the inner-most level
264268
/// map.
265269
SymbolBox shallowLookupSymbol(semantics::SymbolRef sym);

flang/include/flang/Semantics/symbol.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@ class Symbol {
801801
AccPrivate, AccFirstPrivate, AccShared,
802802
// OpenACC data-mapping attribute
803803
AccCopy, AccCopyIn, AccCopyInReadOnly, AccCopyOut, AccCreate, AccDelete,
804-
AccPresent, AccLink, AccDeviceResident, AccDevicePtr,
804+
AccPresent, AccLink, AccDeviceResident, AccDevicePtr, AccUseDevice,
805805
// OpenACC declare
806806
AccDeclare,
807807
// OpenACC data-movement attribute

flang/lib/Lower/Bridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3182,7 +3182,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
31823182
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
31833183
localSymbols.pushScope();
31843184
mlir::Value exitCond = genOpenACCConstruct(
3185-
*this, bridge.getSemanticsContext(), getEval(), acc);
3185+
*this, bridge.getSemanticsContext(), getEval(), acc, localSymbols);
31863186

31873187
const Fortran::parser::OpenACCLoopConstruct *accLoop =
31883188
std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);

flang/lib/Lower/OpenACC.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
31843184
Fortran::lower::pft::Evaluation &eval,
31853185
Fortran::semantics::SemanticsContext &semanticsContext,
31863186
Fortran::lower::StatementContext &stmtCtx,
3187-
const Fortran::parser::AccClauseList &accClauseList) {
3187+
const Fortran::parser::AccClauseList &accClauseList,
3188+
Fortran::lower::SymMap &localSymbols) {
31883189
mlir::Value ifCond;
31893190
llvm::SmallVector<mlir::Value> dataOperands;
31903191
bool addIfPresentAttr = false;
@@ -3199,6 +3200,19 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
31993200
} else if (const auto *useDevice =
32003201
std::get_if<Fortran::parser::AccClause::UseDevice>(
32013202
&clause.u)) {
3203+
// When CUDA Fotran is enabled, extra symbols are used in the host_data
3204+
// region. Look for them and bind their values with the symbols in the
3205+
// outer scope.
3206+
if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
3207+
const Fortran::parser::AccObjectList &objectList{useDevice->v};
3208+
for (const auto &accObject : objectList.v) {
3209+
Fortran::semantics::Symbol &symbol =
3210+
getSymbolFromAccObject(accObject);
3211+
const Fortran::semantics::Symbol *baseSym =
3212+
localSymbols.lookupSymbolByName(symbol.name().ToString());
3213+
localSymbols.copySymbolBinding(*baseSym, symbol);
3214+
}
3215+
}
32023216
genDataOperandOperations<mlir::acc::UseDeviceOp>(
32033217
useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
32043218
mlir::acc::DataClause::acc_use_device,
@@ -3239,11 +3253,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
32393253
hostDataOp.setIfPresentAttr(builder.getUnitAttr());
32403254
}
32413255

3242-
static void
3243-
genACC(Fortran::lower::AbstractConverter &converter,
3244-
Fortran::semantics::SemanticsContext &semanticsContext,
3245-
Fortran::lower::pft::Evaluation &eval,
3246-
const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
3256+
static void genACC(Fortran::lower::AbstractConverter &converter,
3257+
Fortran::semantics::SemanticsContext &semanticsContext,
3258+
Fortran::lower::pft::Evaluation &eval,
3259+
const Fortran::parser::OpenACCBlockConstruct &blockConstruct,
3260+
Fortran::lower::SymMap &localSymbols) {
32473261
const auto &beginBlockDirective =
32483262
std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
32493263
const auto &blockDirective =
@@ -3273,7 +3287,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
32733287
accClauseList);
32743288
} else if (blockDirective.v == llvm::acc::ACCD_host_data) {
32753289
genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
3276-
stmtCtx, accClauseList);
3290+
stmtCtx, accClauseList, localSymbols);
32773291
}
32783292
}
32793293

@@ -4647,13 +4661,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct(
46474661
Fortran::lower::AbstractConverter &converter,
46484662
Fortran::semantics::SemanticsContext &semanticsContext,
46494663
Fortran::lower::pft::Evaluation &eval,
4650-
const Fortran::parser::OpenACCConstruct &accConstruct) {
4664+
const Fortran::parser::OpenACCConstruct &accConstruct,
4665+
Fortran::lower::SymMap &localSymbols) {
46514666

46524667
mlir::Value exitCond;
46534668
Fortran::common::visit(
46544669
common::visitors{
46554670
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
4656-
genACC(converter, semanticsContext, eval, blockConstruct);
4671+
genACC(converter, semanticsContext, eval, blockConstruct,
4672+
localSymbols);
46574673
},
46584674
[&](const Fortran::parser::OpenACCCombinedConstruct
46594675
&combinedConstruct) {

flang/lib/Lower/SymbolMap.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
4545
return SymbolBox::None{};
4646
}
4747

48+
const Fortran::semantics::Symbol *
49+
Fortran::lower::SymMap::lookupSymbolByName(llvm::StringRef symName) {
50+
for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend();
51+
jmap != jend; ++jmap)
52+
for (auto const &[sym, symBox] : *jmap)
53+
if (sym->name().ToString() == symName)
54+
return sym;
55+
return nullptr;
56+
}
57+
4858
Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
4959
Fortran::semantics::SymbolRef symRef) {
5060
auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();

flang/lib/Semantics/check-declarations.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1189,7 +1189,8 @@ void CheckHelper::CheckObjectEntity(
11891189
}
11901190
} else if (!subpDetails && symbol.owner().kind() != Scope::Kind::Module &&
11911191
symbol.owner().kind() != Scope::Kind::MainProgram &&
1192-
symbol.owner().kind() != Scope::Kind::BlockConstruct) {
1192+
symbol.owner().kind() != Scope::Kind::BlockConstruct &&
1193+
symbol.owner().kind() != Scope::Kind::OpenACCConstruct) {
11931194
messages_.Say(
11941195
"ATTRIBUTES(%s) may apply only to module, host subprogram, block, or device subprogram data"_err_en_US,
11951196
parser::ToUpperCaseLetters(common::EnumToString(attr)));

flang/lib/Semantics/resolve-directives.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,11 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
328328
return false;
329329
}
330330

331+
bool Pre(const parser::AccClause::UseDevice &x) {
332+
ResolveAccObjectList(x.v, Symbol::Flag::AccUseDevice);
333+
return false;
334+
}
335+
331336
void Post(const parser::Name &);
332337

333338
private:

flang/lib/Semantics/resolve-names.cpp

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,8 @@ class ConstructVisitor : public virtual DeclarationVisitor {
13871387
// Create scopes for OpenACC constructs
13881388
class AccVisitor : public virtual DeclarationVisitor {
13891389
public:
1390+
explicit AccVisitor(SemanticsContext &context) : context_{context} {}
1391+
13901392
void AddAccSourceRange(const parser::CharBlock &);
13911393

13921394
static bool NeedsScope(const parser::OpenACCBlockConstruct &);
@@ -1395,6 +1397,7 @@ class AccVisitor : public virtual DeclarationVisitor {
13951397
void Post(const parser::OpenACCBlockConstruct &);
13961398
bool Pre(const parser::OpenACCCombinedConstruct &);
13971399
void Post(const parser::OpenACCCombinedConstruct &);
1400+
bool Pre(const parser::AccClause::UseDevice &x);
13981401
bool Pre(const parser::AccBeginBlockDirective &x) {
13991402
AddAccSourceRange(x.source);
14001403
return true;
@@ -1430,6 +1433,11 @@ class AccVisitor : public virtual DeclarationVisitor {
14301433
void Post(const parser::AccBeginLoopDirective &x) {
14311434
messageHandler().set_currStmtSource(std::nullopt);
14321435
}
1436+
1437+
void CopySymbolWithDevice(const parser::Name *name);
1438+
1439+
private:
1440+
SemanticsContext &context_;
14331441
};
14341442

14351443
bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) {
@@ -1459,6 +1467,60 @@ bool AccVisitor::Pre(const parser::OpenACCBlockConstruct &x) {
14591467
return true;
14601468
}
14611469

1470+
void AccVisitor::CopySymbolWithDevice(const parser::Name *name) {
1471+
// When CUDA Fortran is enabled together with OpenACC, new
1472+
// symbols are created for the one appearing in the use_device
1473+
// clause. These new symbols have the CUDA Fortran device
1474+
// attribute.
1475+
if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) {
1476+
name->symbol = currScope().CopySymbol(*name->symbol);
1477+
if (auto *object{name->symbol->detailsIf<ObjectEntityDetails>()}) {
1478+
object->set_cudaDataAttr(common::CUDADataAttr::Device);
1479+
}
1480+
}
1481+
}
1482+
1483+
bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) {
1484+
for (const auto &accObject : x.v.v) {
1485+
common::visit(
1486+
common::visitors{
1487+
[&](const parser::Designator &designator) {
1488+
if (const auto *name{
1489+
semantics::getDesignatorNameIfDataRef(designator)}) {
1490+
Symbol *prev{currScope().FindSymbol(name->source)};
1491+
if (prev != name->symbol) {
1492+
name->symbol = prev;
1493+
}
1494+
CopySymbolWithDevice(name);
1495+
} else {
1496+
if (const auto *dataRef{
1497+
std::get_if<parser::DataRef>(&designator.u)}) {
1498+
using ElementIndirection =
1499+
common::Indirection<parser::ArrayElement>;
1500+
if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
1501+
const parser::ArrayElement &arrayElement{ind->value()};
1502+
Walk(arrayElement.subscripts);
1503+
const parser::DataRef &base{arrayElement.base};
1504+
if (auto *name{std::get_if<parser::Name>(&base.u)}) {
1505+
Symbol *prev{currScope().FindSymbol(name->source)};
1506+
if (prev != name->symbol) {
1507+
name->symbol = prev;
1508+
}
1509+
CopySymbolWithDevice(name);
1510+
}
1511+
}
1512+
}
1513+
}
1514+
},
1515+
[&](const parser::Name &name) {
1516+
// TODO: common block in use_device?
1517+
},
1518+
},
1519+
accObject.u);
1520+
}
1521+
return false;
1522+
}
1523+
14621524
void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) {
14631525
if (NeedsScope(x)) {
14641526
PopScope();
@@ -2038,7 +2100,8 @@ class ResolveNamesVisitor : public virtual ScopeHandler,
20382100

20392101
ResolveNamesVisitor(
20402102
SemanticsContext &context, ImplicitRulesMap &rules, Scope &top)
2041-
: BaseVisitor{context, *this, rules}, topScope_{top} {
2103+
: BaseVisitor{context, *this, rules}, AccVisitor(context),
2104+
topScope_{top} {
20422105
PushScope(top);
20432106
}
20442107

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
2+
! RUN: bbc -fopenacc -fcuda -emit-hlfir %s -o - | FileCheck %s
3+
4+
module m
5+
6+
interface doit
7+
subroutine __device_sub(a)
8+
real(4), device, intent(in) :: a(:,:,:)
9+
!dir$ ignore_tkr(c) a
10+
end
11+
subroutine __host_sub(a)
12+
real(4), intent(in) :: a(:,:,:)
13+
!dir$ ignore_tkr(c) a
14+
end
15+
end interface
16+
end module
17+
18+
program testex1
19+
integer, parameter :: ntimes = 10
20+
integer, parameter :: ni=128
21+
integer, parameter :: nj=256
22+
integer, parameter :: nk=64
23+
real(4), dimension(ni,nj,nk) :: a
24+
25+
!$acc enter data copyin(a)
26+
27+
block; use m
28+
!$acc host_data use_device(a)
29+
do nt = 1, ntimes
30+
call doit(a)
31+
end do
32+
!$acc end host_data
33+
end block
34+
35+
block; use m
36+
do nt = 1, ntimes
37+
call doit(a)
38+
end do
39+
end block
40+
end
41+
42+
! CHECK: fir.call @_QP__device_sub
43+
! CHECK: fir.call @_QP__host_sub

0 commit comments

Comments
 (0)