@@ -387,6 +387,148 @@ static void addDeclareAttr(fir::FirOpBuilder &builder, mlir::Operation *op,
387
387
builder.getContext (), clause)));
388
388
}
389
389
390
+ static mlir::func::FuncOp
391
+ createDeclareFunc (mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder,
392
+ mlir::Location loc, llvm::StringRef funcName,
393
+ llvm::SmallVector<mlir::Type> argsTy = {},
394
+ llvm::SmallVector<mlir::Location> locs = {}) {
395
+ auto funcTy = mlir::FunctionType::get (modBuilder.getContext (), argsTy, {});
396
+ auto funcOp = modBuilder.create <mlir::func::FuncOp>(loc, funcName, funcTy);
397
+ funcOp.setVisibility (mlir::SymbolTable::Visibility::Private);
398
+ builder.createBlock (&funcOp.getRegion (), funcOp.getRegion ().end (), argsTy,
399
+ locs);
400
+ builder.setInsertionPointToEnd (&funcOp.getRegion ().back ());
401
+ builder.create <mlir::func::ReturnOp>(loc);
402
+ builder.setInsertionPointToStart (&funcOp.getRegion ().back ());
403
+ return funcOp;
404
+ }
405
+
406
+ template <typename Op>
407
+ static Op
408
+ createSimpleOp (fir::FirOpBuilder &builder, mlir::Location loc,
409
+ const llvm::SmallVectorImpl<mlir::Value> &operands,
410
+ const llvm::SmallVectorImpl<int32_t > &operandSegments) {
411
+ llvm::ArrayRef<mlir::Type> argTy;
412
+ Op op = builder.create <Op>(loc, argTy, operands);
413
+ op->setAttr (Op::getOperandSegmentSizeAttr (),
414
+ builder.getDenseI32ArrayAttr (operandSegments));
415
+ return op;
416
+ }
417
+
418
+ template <typename EntryOp>
419
+ static void createDeclareAllocFuncWithArg (mlir::OpBuilder &modBuilder,
420
+ fir::FirOpBuilder &builder,
421
+ mlir::Location loc, mlir::Type descTy,
422
+ llvm::StringRef funcNamePrefix,
423
+ std::stringstream &asFortran,
424
+ mlir::acc::DataClause clause) {
425
+ auto crtInsPt = builder.saveInsertionPoint ();
426
+ std::stringstream registerFuncName;
427
+ registerFuncName << funcNamePrefix.str ()
428
+ << Fortran::lower::declarePostAllocSuffix.str ();
429
+
430
+ if (!mlir::isa<fir::ReferenceType>(descTy))
431
+ descTy = fir::ReferenceType::get (descTy);
432
+ auto registerFuncOp = createDeclareFunc (
433
+ modBuilder, builder, loc, registerFuncName.str (), {descTy}, {loc});
434
+
435
+ mlir::Value desc =
436
+ builder.create <fir::LoadOp>(loc, registerFuncOp.getArgument (0 ));
437
+ fir::BoxAddrOp boxAddrOp = builder.create <fir::BoxAddrOp>(loc, desc);
438
+ addDeclareAttr (builder, boxAddrOp.getOperation (), clause);
439
+
440
+ llvm::SmallVector<mlir::Value> bounds;
441
+ EntryOp entryOp = createDataEntryOp<EntryOp>(
442
+ builder, loc, boxAddrOp.getResult (), asFortran, bounds,
443
+ /* structured=*/ false , /* implicit=*/ false , clause, boxAddrOp.getType ());
444
+ builder.create <mlir::acc::DeclareEnterOp>(
445
+ loc, mlir::ValueRange (entryOp.getAccPtr ()));
446
+
447
+ asFortran << " _desc" ;
448
+ mlir::acc::UpdateDeviceOp updateDeviceOp =
449
+ createDataEntryOp<mlir::acc::UpdateDeviceOp>(
450
+ builder, loc, registerFuncOp.getArgument (0 ), asFortran, bounds,
451
+ /* structured=*/ false , /* implicit=*/ true ,
452
+ mlir::acc::DataClause::acc_update_device, descTy);
453
+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 0 , 1 };
454
+ llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
455
+ createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
456
+ modBuilder.setInsertionPointAfter (registerFuncOp);
457
+ builder.restoreInsertionPoint (crtInsPt);
458
+ }
459
+
460
+ template <typename ExitOp>
461
+ static void createDeclareDeallocFuncWithArg (
462
+ mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, mlir::Location loc,
463
+ mlir::Type descTy, llvm::StringRef funcNamePrefix,
464
+ std::stringstream &asFortran, mlir::acc::DataClause clause) {
465
+ auto crtInsPt = builder.saveInsertionPoint ();
466
+ // Generate the pre dealloc function.
467
+ std::stringstream preDeallocFuncName;
468
+ preDeallocFuncName << funcNamePrefix.str ()
469
+ << Fortran::lower::declarePreDeallocSuffix.str ();
470
+ if (!mlir::isa<fir::ReferenceType>(descTy))
471
+ descTy = fir::ReferenceType::get (descTy);
472
+ auto preDeallocOp = createDeclareFunc (
473
+ modBuilder, builder, loc, preDeallocFuncName.str (), {descTy}, {loc});
474
+ mlir::Value loadOp =
475
+ builder.create <fir::LoadOp>(loc, preDeallocOp.getArgument (0 ));
476
+ fir::BoxAddrOp boxAddrOp = builder.create <fir::BoxAddrOp>(loc, loadOp);
477
+ addDeclareAttr (builder, boxAddrOp.getOperation (), clause);
478
+
479
+ llvm::SmallVector<mlir::Value> bounds;
480
+ mlir::acc::GetDevicePtrOp entryOp =
481
+ createDataEntryOp<mlir::acc::GetDevicePtrOp>(
482
+ builder, loc, boxAddrOp.getResult (), asFortran, bounds,
483
+ /* structured=*/ false , /* implicit=*/ false , clause,
484
+ boxAddrOp.getType ());
485
+ builder.create <mlir::acc::DeclareExitOp>(
486
+ loc, mlir::ValueRange (entryOp.getAccPtr ()));
487
+
488
+ mlir::Value varPtr;
489
+ if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
490
+ std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
491
+ varPtr = entryOp.getVarPtr ();
492
+ builder.create <ExitOp>(entryOp.getLoc (), entryOp.getAccPtr (), varPtr,
493
+ entryOp.getBounds (), entryOp.getDataClause (),
494
+ /* structured=*/ false , /* implicit=*/ false ,
495
+ builder.getStringAttr (*entryOp.getName ()));
496
+
497
+ // Generate the post dealloc function.
498
+ modBuilder.setInsertionPointAfter (preDeallocOp);
499
+ std::stringstream postDeallocFuncName;
500
+ postDeallocFuncName << funcNamePrefix.str ()
501
+ << Fortran::lower::declarePostDeallocSuffix.str ();
502
+ auto postDeallocOp = createDeclareFunc (
503
+ modBuilder, builder, loc, postDeallocFuncName.str (), {descTy}, {loc});
504
+ loadOp = builder.create <fir::LoadOp>(loc, postDeallocOp.getArgument (0 ));
505
+ asFortran << " _desc" ;
506
+ mlir::acc::UpdateDeviceOp updateDeviceOp =
507
+ createDataEntryOp<mlir::acc::UpdateDeviceOp>(
508
+ builder, loc, loadOp, asFortran, bounds,
509
+ /* structured=*/ false , /* implicit=*/ true ,
510
+ mlir::acc::DataClause::acc_update_device, loadOp.getType ());
511
+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 0 , 1 };
512
+ llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
513
+ createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
514
+ modBuilder.setInsertionPointAfter (postDeallocOp);
515
+ builder.restoreInsertionPoint (crtInsPt);
516
+ }
517
+
518
+ Fortran::semantics::Symbol &
519
+ getSymbolFromAccObject (const Fortran::parser::AccObject &accObject) {
520
+ if (const auto *designator =
521
+ std::get_if<Fortran::parser::Designator>(&accObject.u )) {
522
+ if (const auto *name =
523
+ Fortran::semantics::getDesignatorNameIfDataRef (*designator))
524
+ return *name->symbol ;
525
+ } else if (const auto *name =
526
+ std::get_if<Fortran::parser::Name>(&accObject.u )) {
527
+ return *name->symbol ;
528
+ }
529
+ llvm::report_fatal_error (" Could not find symbol" );
530
+ }
531
+
390
532
template <typename Op>
391
533
static void
392
534
genDataOperandOperations (const Fortran::parser::AccObjectList &objectList,
@@ -408,11 +550,69 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
408
550
bounds, structured, implicit, dataClause,
409
551
baseAddr.getType ());
410
552
dataOperands.push_back (op.getAccPtr ());
411
- if (setDeclareAttr)
412
- addDeclareAttr (builder, op.getVarPtr ().getDefiningOp (), dataClause);
413
553
}
414
554
}
415
555
556
+ template <typename EntryOp, typename ExitOp>
557
+ static void genDeclareDataOperandOperations (
558
+ const Fortran::parser::AccObjectList &objectList,
559
+ Fortran::lower::AbstractConverter &converter,
560
+ Fortran::semantics::SemanticsContext &semanticsContext,
561
+ Fortran::lower::StatementContext &stmtCtx,
562
+ llvm::SmallVectorImpl<mlir::Value> &dataOperands,
563
+ mlir::acc::DataClause dataClause, bool structured, bool implicit) {
564
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
565
+ for (const auto &accObject : objectList.v ) {
566
+ llvm::SmallVector<mlir::Value> bounds;
567
+ std::stringstream asFortran;
568
+ mlir::Location operandLocation = genOperandLocation (converter, accObject);
569
+ mlir::Value baseAddr = gatherDataOperandAddrAndBounds (
570
+ converter, builder, semanticsContext, stmtCtx, accObject,
571
+ operandLocation, asFortran, bounds);
572
+ EntryOp op = createDataEntryOp<EntryOp>(
573
+ builder, operandLocation, baseAddr, asFortran, bounds, structured,
574
+ implicit, dataClause, baseAddr.getType ());
575
+ dataOperands.push_back (op.getAccPtr ());
576
+ addDeclareAttr (builder, op.getVarPtr ().getDefiningOp (), dataClause);
577
+ if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType (baseAddr.getType ()))) {
578
+ mlir::OpBuilder modBuilder (builder.getModule ().getBodyRegion ());
579
+ modBuilder.setInsertionPointAfter (builder.getFunction ());
580
+ std::string prefix =
581
+ converter.mangleName (getSymbolFromAccObject (accObject));
582
+ createDeclareAllocFuncWithArg<EntryOp>(
583
+ modBuilder, builder, operandLocation, baseAddr.getType (), prefix,
584
+ asFortran, dataClause);
585
+ if constexpr (!std::is_same_v<EntryOp, ExitOp>)
586
+ createDeclareDeallocFuncWithArg<ExitOp>(
587
+ modBuilder, builder, operandLocation, baseAddr.getType (), prefix,
588
+ asFortran, dataClause);
589
+ }
590
+ }
591
+ }
592
+
593
+ template <typename EntryOp, typename ExitOp, typename Clause>
594
+ static void genDeclareDataOperandOperationsWithModifier (
595
+ const Clause *x, Fortran::lower::AbstractConverter &converter,
596
+ Fortran::semantics::SemanticsContext &semanticsContext,
597
+ Fortran::lower::StatementContext &stmtCtx,
598
+ Fortran::parser::AccDataModifier::Modifier mod,
599
+ llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
600
+ const mlir::acc::DataClause clause,
601
+ const mlir::acc::DataClause clauseWithModifier) {
602
+ const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v ;
603
+ const auto &accObjectList =
604
+ std::get<Fortran::parser::AccObjectList>(listWithModifier.t );
605
+ const auto &modifier =
606
+ std::get<std::optional<Fortran::parser::AccDataModifier>>(
607
+ listWithModifier.t );
608
+ mlir::acc::DataClause dataClause =
609
+ (modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
610
+ genDeclareDataOperandOperations<EntryOp, ExitOp>(
611
+ accObjectList, converter, semanticsContext, stmtCtx, dataClauseOperands,
612
+ dataClause,
613
+ /* structured=*/ true , /* implicit=*/ false );
614
+ }
615
+
416
616
template <typename EntryOp, typename ExitOp>
417
617
static void genDataExitOperations (fir::FirOpBuilder &builder,
418
618
llvm::SmallVector<mlir::Value> operands,
@@ -1058,18 +1258,6 @@ createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
1058
1258
return op;
1059
1259
}
1060
1260
1061
- template <typename Op>
1062
- static Op
1063
- createSimpleOp (fir::FirOpBuilder &builder, mlir::Location loc,
1064
- const llvm::SmallVectorImpl<mlir::Value> &operands,
1065
- const llvm::SmallVectorImpl<int32_t > &operandSegments) {
1066
- llvm::ArrayRef<mlir::Type> argTy;
1067
- Op op = builder.create <Op>(loc, argTy, operands);
1068
- op->setAttr (Op::getOperandSegmentSizeAttr (),
1069
- builder.getDenseI32ArrayAttr (operandSegments));
1070
- return op;
1071
- }
1072
-
1073
1261
static void genAsyncClause (Fortran::lower::AbstractConverter &converter,
1074
1262
const Fortran::parser::AccClause::Async *asyncClause,
1075
1263
mlir::Value &async, bool &addAsyncAttr,
@@ -2349,20 +2537,6 @@ static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
2349
2537
modBuilder.setInsertionPointAfter (declareGlobalOp);
2350
2538
}
2351
2539
2352
- static mlir::func::FuncOp createDeclareFunc (mlir::OpBuilder &modBuilder,
2353
- fir::FirOpBuilder &builder,
2354
- mlir::Location loc,
2355
- llvm::StringRef funcName) {
2356
- auto funcTy = mlir::FunctionType::get (modBuilder.getContext (), {}, {});
2357
- auto funcOp = modBuilder.create <mlir::func::FuncOp>(loc, funcName, funcTy);
2358
- funcOp.setVisibility (mlir::SymbolTable::Visibility::Private);
2359
- builder.createBlock (&funcOp.getRegion (), funcOp.getRegion ().end (), {}, {});
2360
- builder.setInsertionPointToEnd (&funcOp.getRegion ().back ());
2361
- builder.create <mlir::func::ReturnOp>(loc);
2362
- builder.setInsertionPointToStart (&funcOp.getRegion ().back ());
2363
- return funcOp;
2364
- }
2365
-
2366
2540
template <typename EntryOp>
2367
2541
static void createDeclareAllocFunc (mlir::OpBuilder &modBuilder,
2368
2542
fir::FirOpBuilder &builder,
@@ -2556,10 +2730,11 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
2556
2730
if (const auto *copyClause =
2557
2731
std::get_if<Fortran::parser::AccClause::Copy>(&clause.u )) {
2558
2732
auto crtDataStart = dataClauseOperands.size ();
2559
- genDataOperandOperations<mlir::acc::CopyinOp>(
2733
+ genDeclareDataOperandOperations<mlir::acc::CopyinOp,
2734
+ mlir::acc::CopyoutOp>(
2560
2735
copyClause->v , converter, semanticsContext, stmtCtx,
2561
2736
dataClauseOperands, mlir::acc::DataClause::acc_copy,
2562
- /* structured=*/ true , /* implicit=*/ false , /* setDeclareAttr= */ true );
2737
+ /* structured=*/ true , /* implicit=*/ false );
2563
2738
copyEntryOperands.append (dataClauseOperands.begin () + crtDataStart,
2564
2739
dataClauseOperands.end ());
2565
2740
} else if (const auto *createClause =
@@ -2569,26 +2744,28 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
2569
2744
const auto &accObjectList =
2570
2745
std::get<Fortran::parser::AccObjectList>(listWithModifier.t );
2571
2746
auto crtDataStart = dataClauseOperands.size ();
2572
- genDataOperandOperations <mlir::acc::CreateOp>(
2747
+ genDeclareDataOperandOperations <mlir::acc::CreateOp, mlir::acc::DeleteOp >(
2573
2748
accObjectList, converter, semanticsContext, stmtCtx,
2574
2749
dataClauseOperands, mlir::acc::DataClause::acc_create,
2575
- /* structured=*/ true , /* implicit=*/ false , /* setDeclareAttr= */ true );
2750
+ /* structured=*/ true , /* implicit=*/ false );
2576
2751
createEntryOperands.append (dataClauseOperands.begin () + crtDataStart,
2577
2752
dataClauseOperands.end ());
2578
2753
} else if (const auto *presentClause =
2579
2754
std::get_if<Fortran::parser::AccClause::Present>(
2580
2755
&clause.u )) {
2581
- genDataOperandOperations<mlir::acc::PresentOp>(
2756
+ genDeclareDataOperandOperations<mlir::acc::PresentOp,
2757
+ mlir::acc::PresentOp>(
2582
2758
presentClause->v , converter, semanticsContext, stmtCtx,
2583
2759
dataClauseOperands, mlir::acc::DataClause::acc_present,
2584
- /* structured=*/ true , /* implicit=*/ false , /* setDeclareAttr= */ true );
2760
+ /* structured=*/ true , /* implicit=*/ false );
2585
2761
} else if (const auto *copyinClause =
2586
2762
std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u )) {
2587
- genDataOperandOperationsWithModifier<mlir::acc::CopyinOp>(
2763
+ genDeclareDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
2764
+ mlir::acc::DeleteOp>(
2588
2765
copyinClause, converter, semanticsContext, stmtCtx,
2589
2766
Fortran::parser::AccDataModifier::Modifier::ReadOnly,
2590
2767
dataClauseOperands, mlir::acc::DataClause::acc_copyin,
2591
- mlir::acc::DataClause::acc_copyin_readonly, /* setDeclareAttr= */ true );
2768
+ mlir::acc::DataClause::acc_copyin_readonly);
2592
2769
} else if (const auto *copyoutClause =
2593
2770
std::get_if<Fortran::parser::AccClause::Copyout>(
2594
2771
&clause.u )) {
@@ -2597,34 +2774,38 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
2597
2774
const auto &accObjectList =
2598
2775
std::get<Fortran::parser::AccObjectList>(listWithModifier.t );
2599
2776
auto crtDataStart = dataClauseOperands.size ();
2600
- genDataOperandOperations<mlir::acc::CreateOp>(
2777
+ genDeclareDataOperandOperations<mlir::acc::CreateOp,
2778
+ mlir::acc::CopyoutOp>(
2601
2779
accObjectList, converter, semanticsContext, stmtCtx,
2602
2780
dataClauseOperands, mlir::acc::DataClause::acc_copyout,
2603
- /* structured=*/ true , /* implicit=*/ false , /* setDeclareAttr= */ true );
2781
+ /* structured=*/ true , /* implicit=*/ false );
2604
2782
copyoutEntryOperands.append (dataClauseOperands.begin () + crtDataStart,
2605
2783
dataClauseOperands.end ());
2606
2784
} else if (const auto *devicePtrClause =
2607
2785
std::get_if<Fortran::parser::AccClause::Deviceptr>(
2608
2786
&clause.u )) {
2609
- genDataOperandOperations<mlir::acc::DevicePtrOp>(
2787
+ genDeclareDataOperandOperations<mlir::acc::DevicePtrOp,
2788
+ mlir::acc::DevicePtrOp>(
2610
2789
devicePtrClause->v , converter, semanticsContext, stmtCtx,
2611
2790
dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
2612
- /* structured=*/ true , /* implicit=*/ false , /* setDeclareAttr= */ true );
2791
+ /* structured=*/ true , /* implicit=*/ false );
2613
2792
} else if (const auto *linkClause =
2614
2793
std::get_if<Fortran::parser::AccClause::Link>(&clause.u )) {
2615
- genDataOperandOperations<mlir::acc::DeclareLinkOp>(
2794
+ genDeclareDataOperandOperations<mlir::acc::DeclareLinkOp,
2795
+ mlir::acc::DeclareLinkOp>(
2616
2796
linkClause->v , converter, semanticsContext, stmtCtx,
2617
2797
dataClauseOperands, mlir::acc::DataClause::acc_declare_link,
2618
- /* structured=*/ true , /* implicit=*/ false , /* setDeclareAttr= */ true );
2798
+ /* structured=*/ true , /* implicit=*/ false );
2619
2799
} else if (const auto *deviceResidentClause =
2620
2800
std::get_if<Fortran::parser::AccClause::DeviceResident>(
2621
2801
&clause.u )) {
2622
2802
auto crtDataStart = dataClauseOperands.size ();
2623
- genDataOperandOperations<mlir::acc::DeclareDeviceResidentOp>(
2803
+ genDeclareDataOperandOperations<mlir::acc::DeclareDeviceResidentOp,
2804
+ mlir::acc::DeleteOp>(
2624
2805
deviceResidentClause->v , converter, semanticsContext, stmtCtx,
2625
2806
dataClauseOperands,
2626
2807
mlir::acc::DataClause::acc_declare_device_resident,
2627
- /* structured=*/ true , /* implicit=*/ false , /* setDeclareAttr= */ true );
2808
+ /* structured=*/ true , /* implicit=*/ false );
2628
2809
deviceResidentEntryOperands.append (
2629
2810
dataClauseOperands.begin () + crtDataStart, dataClauseOperands.end ());
2630
2811
} else {
0 commit comments