@@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
3184
3184
Fortran::lower::pft::Evaluation &eval,
3185
3185
Fortran::semantics::SemanticsContext &semanticsContext,
3186
3186
Fortran::lower::StatementContext &stmtCtx,
3187
- const Fortran::parser::AccClauseList &accClauseList) {
3187
+ const Fortran::parser::AccClauseList &accClauseList,
3188
+ Fortran::lower::SymMap &localSymbols) {
3188
3189
mlir::Value ifCond;
3189
3190
llvm::SmallVector<mlir::Value> dataOperands;
3190
3191
bool addIfPresentAttr = false ;
@@ -3199,6 +3200,19 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
3199
3200
} else if (const auto *useDevice =
3200
3201
std::get_if<Fortran::parser::AccClause::UseDevice>(
3201
3202
&clause.u )) {
3203
+ // When CUDA Fotran is enabled, extra symbols are used in the host_data
3204
+ // region. Look for them and bind their values with the symbols in the
3205
+ // outer scope.
3206
+ if (semanticsContext.IsEnabled (Fortran::common::LanguageFeature::CUDA)) {
3207
+ const Fortran::parser::AccObjectList &objectList{useDevice->v };
3208
+ for (const auto &accObject : objectList.v ) {
3209
+ Fortran::semantics::Symbol &symbol =
3210
+ getSymbolFromAccObject (accObject);
3211
+ const Fortran::semantics::Symbol *baseSym =
3212
+ localSymbols.lookupSymbolByName (symbol.name ().ToString ());
3213
+ localSymbols.copySymbolBinding (*baseSym, symbol);
3214
+ }
3215
+ }
3202
3216
genDataOperandOperations<mlir::acc::UseDeviceOp>(
3203
3217
useDevice->v , converter, semanticsContext, stmtCtx, dataOperands,
3204
3218
mlir::acc::DataClause::acc_use_device,
@@ -3239,11 +3253,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
3239
3253
hostDataOp.setIfPresentAttr (builder.getUnitAttr ());
3240
3254
}
3241
3255
3242
- static void
3243
- genACC ( Fortran::lower::AbstractConverter &converter ,
3244
- Fortran::semantics::SemanticsContext &semanticsContext ,
3245
- Fortran::lower::pft::Evaluation &eval ,
3246
- const Fortran::parser::OpenACCBlockConstruct &blockConstruct ) {
3256
+ static void genACC (Fortran::lower::AbstractConverter &converter,
3257
+ Fortran::semantics::SemanticsContext &semanticsContext ,
3258
+ Fortran::lower::pft::Evaluation &eval ,
3259
+ const Fortran::parser::OpenACCBlockConstruct &blockConstruct ,
3260
+ Fortran::lower::SymMap &localSymbols ) {
3247
3261
const auto &beginBlockDirective =
3248
3262
std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t );
3249
3263
const auto &blockDirective =
@@ -3273,7 +3287,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
3273
3287
accClauseList);
3274
3288
} else if (blockDirective.v == llvm::acc::ACCD_host_data) {
3275
3289
genACCHostDataOp (converter, currentLocation, eval, semanticsContext,
3276
- stmtCtx, accClauseList);
3290
+ stmtCtx, accClauseList, localSymbols );
3277
3291
}
3278
3292
}
3279
3293
@@ -4647,13 +4661,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct(
4647
4661
Fortran::lower::AbstractConverter &converter,
4648
4662
Fortran::semantics::SemanticsContext &semanticsContext,
4649
4663
Fortran::lower::pft::Evaluation &eval,
4650
- const Fortran::parser::OpenACCConstruct &accConstruct) {
4664
+ const Fortran::parser::OpenACCConstruct &accConstruct,
4665
+ Fortran::lower::SymMap &localSymbols) {
4651
4666
4652
4667
mlir::Value exitCond;
4653
4668
Fortran::common::visit (
4654
4669
common::visitors{
4655
4670
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
4656
- genACC (converter, semanticsContext, eval, blockConstruct);
4671
+ genACC (converter, semanticsContext, eval, blockConstruct,
4672
+ localSymbols);
4657
4673
},
4658
4674
[&](const Fortran::parser::OpenACCCombinedConstruct
4659
4675
&combinedConstruct) {
0 commit comments