Skip to content

Commit 9a96b0a

Browse files
committed
[flang][openacc] Generate pre/post alloc/dealloc function for in subroutine declare
Lowering was missing to generate the pre/post alloc/dealloc functions for the acc declare variables. This patch adds the generation. These functions have the descriptor as their unique argument. Reviewed By: razvanlupusoru Differential Revision: https://reviews.llvm.org/D158103
1 parent 0c3f51c commit 9a96b0a

File tree

2 files changed

+255
-44
lines changed

2 files changed

+255
-44
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 225 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,148 @@ static void addDeclareAttr(fir::FirOpBuilder &builder, mlir::Operation *op,
387387
builder.getContext(), clause)));
388388
}
389389

390+
static mlir::func::FuncOp
391+
createDeclareFunc(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder,
392+
mlir::Location loc, llvm::StringRef funcName,
393+
llvm::SmallVector<mlir::Type> argsTy = {},
394+
llvm::SmallVector<mlir::Location> locs = {}) {
395+
auto funcTy = mlir::FunctionType::get(modBuilder.getContext(), argsTy, {});
396+
auto funcOp = modBuilder.create<mlir::func::FuncOp>(loc, funcName, funcTy);
397+
funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
398+
builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
399+
locs);
400+
builder.setInsertionPointToEnd(&funcOp.getRegion().back());
401+
builder.create<mlir::func::ReturnOp>(loc);
402+
builder.setInsertionPointToStart(&funcOp.getRegion().back());
403+
return funcOp;
404+
}
405+
406+
template <typename Op>
407+
static Op
408+
createSimpleOp(fir::FirOpBuilder &builder, mlir::Location loc,
409+
const llvm::SmallVectorImpl<mlir::Value> &operands,
410+
const llvm::SmallVectorImpl<int32_t> &operandSegments) {
411+
llvm::ArrayRef<mlir::Type> argTy;
412+
Op op = builder.create<Op>(loc, argTy, operands);
413+
op->setAttr(Op::getOperandSegmentSizeAttr(),
414+
builder.getDenseI32ArrayAttr(operandSegments));
415+
return op;
416+
}
417+
418+
template <typename EntryOp>
419+
static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
420+
fir::FirOpBuilder &builder,
421+
mlir::Location loc, mlir::Type descTy,
422+
llvm::StringRef funcNamePrefix,
423+
std::stringstream &asFortran,
424+
mlir::acc::DataClause clause) {
425+
auto crtInsPt = builder.saveInsertionPoint();
426+
std::stringstream registerFuncName;
427+
registerFuncName << funcNamePrefix.str()
428+
<< Fortran::lower::declarePostAllocSuffix.str();
429+
430+
if (!mlir::isa<fir::ReferenceType>(descTy))
431+
descTy = fir::ReferenceType::get(descTy);
432+
auto registerFuncOp = createDeclareFunc(
433+
modBuilder, builder, loc, registerFuncName.str(), {descTy}, {loc});
434+
435+
mlir::Value desc =
436+
builder.create<fir::LoadOp>(loc, registerFuncOp.getArgument(0));
437+
fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, desc);
438+
addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
439+
440+
llvm::SmallVector<mlir::Value> bounds;
441+
EntryOp entryOp = createDataEntryOp<EntryOp>(
442+
builder, loc, boxAddrOp.getResult(), asFortran, bounds,
443+
/*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
444+
builder.create<mlir::acc::DeclareEnterOp>(
445+
loc, mlir::ValueRange(entryOp.getAccPtr()));
446+
447+
asFortran << "_desc";
448+
mlir::acc::UpdateDeviceOp updateDeviceOp =
449+
createDataEntryOp<mlir::acc::UpdateDeviceOp>(
450+
builder, loc, registerFuncOp.getArgument(0), asFortran, bounds,
451+
/*structured=*/false, /*implicit=*/true,
452+
mlir::acc::DataClause::acc_update_device, descTy);
453+
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
454+
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
455+
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
456+
modBuilder.setInsertionPointAfter(registerFuncOp);
457+
builder.restoreInsertionPoint(crtInsPt);
458+
}
459+
460+
template <typename ExitOp>
461+
static void createDeclareDeallocFuncWithArg(
462+
mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, mlir::Location loc,
463+
mlir::Type descTy, llvm::StringRef funcNamePrefix,
464+
std::stringstream &asFortran, mlir::acc::DataClause clause) {
465+
auto crtInsPt = builder.saveInsertionPoint();
466+
// Generate the pre dealloc function.
467+
std::stringstream preDeallocFuncName;
468+
preDeallocFuncName << funcNamePrefix.str()
469+
<< Fortran::lower::declarePreDeallocSuffix.str();
470+
if (!mlir::isa<fir::ReferenceType>(descTy))
471+
descTy = fir::ReferenceType::get(descTy);
472+
auto preDeallocOp = createDeclareFunc(
473+
modBuilder, builder, loc, preDeallocFuncName.str(), {descTy}, {loc});
474+
mlir::Value loadOp =
475+
builder.create<fir::LoadOp>(loc, preDeallocOp.getArgument(0));
476+
fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
477+
addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
478+
479+
llvm::SmallVector<mlir::Value> bounds;
480+
mlir::acc::GetDevicePtrOp entryOp =
481+
createDataEntryOp<mlir::acc::GetDevicePtrOp>(
482+
builder, loc, boxAddrOp.getResult(), asFortran, bounds,
483+
/*structured=*/false, /*implicit=*/false, clause,
484+
boxAddrOp.getType());
485+
builder.create<mlir::acc::DeclareExitOp>(
486+
loc, mlir::ValueRange(entryOp.getAccPtr()));
487+
488+
mlir::Value varPtr;
489+
if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
490+
std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
491+
varPtr = entryOp.getVarPtr();
492+
builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(), varPtr,
493+
entryOp.getBounds(), entryOp.getDataClause(),
494+
/*structured=*/false, /*implicit=*/false,
495+
builder.getStringAttr(*entryOp.getName()));
496+
497+
// Generate the post dealloc function.
498+
modBuilder.setInsertionPointAfter(preDeallocOp);
499+
std::stringstream postDeallocFuncName;
500+
postDeallocFuncName << funcNamePrefix.str()
501+
<< Fortran::lower::declarePostDeallocSuffix.str();
502+
auto postDeallocOp = createDeclareFunc(
503+
modBuilder, builder, loc, postDeallocFuncName.str(), {descTy}, {loc});
504+
loadOp = builder.create<fir::LoadOp>(loc, postDeallocOp.getArgument(0));
505+
asFortran << "_desc";
506+
mlir::acc::UpdateDeviceOp updateDeviceOp =
507+
createDataEntryOp<mlir::acc::UpdateDeviceOp>(
508+
builder, loc, loadOp, asFortran, bounds,
509+
/*structured=*/false, /*implicit=*/true,
510+
mlir::acc::DataClause::acc_update_device, loadOp.getType());
511+
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
512+
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
513+
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
514+
modBuilder.setInsertionPointAfter(postDeallocOp);
515+
builder.restoreInsertionPoint(crtInsPt);
516+
}
517+
518+
Fortran::semantics::Symbol &
519+
getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
520+
if (const auto *designator =
521+
std::get_if<Fortran::parser::Designator>(&accObject.u)) {
522+
if (const auto *name =
523+
Fortran::semantics::getDesignatorNameIfDataRef(*designator))
524+
return *name->symbol;
525+
} else if (const auto *name =
526+
std::get_if<Fortran::parser::Name>(&accObject.u)) {
527+
return *name->symbol;
528+
}
529+
llvm::report_fatal_error("Could not find symbol");
530+
}
531+
390532
template <typename Op>
391533
static void
392534
genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
@@ -408,11 +550,69 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
408550
bounds, structured, implicit, dataClause,
409551
baseAddr.getType());
410552
dataOperands.push_back(op.getAccPtr());
411-
if (setDeclareAttr)
412-
addDeclareAttr(builder, op.getVarPtr().getDefiningOp(), dataClause);
413553
}
414554
}
415555

556+
template <typename EntryOp, typename ExitOp>
557+
static void genDeclareDataOperandOperations(
558+
const Fortran::parser::AccObjectList &objectList,
559+
Fortran::lower::AbstractConverter &converter,
560+
Fortran::semantics::SemanticsContext &semanticsContext,
561+
Fortran::lower::StatementContext &stmtCtx,
562+
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
563+
mlir::acc::DataClause dataClause, bool structured, bool implicit) {
564+
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
565+
for (const auto &accObject : objectList.v) {
566+
llvm::SmallVector<mlir::Value> bounds;
567+
std::stringstream asFortran;
568+
mlir::Location operandLocation = genOperandLocation(converter, accObject);
569+
mlir::Value baseAddr = gatherDataOperandAddrAndBounds(
570+
converter, builder, semanticsContext, stmtCtx, accObject,
571+
operandLocation, asFortran, bounds);
572+
EntryOp op = createDataEntryOp<EntryOp>(
573+
builder, operandLocation, baseAddr, asFortran, bounds, structured,
574+
implicit, dataClause, baseAddr.getType());
575+
dataOperands.push_back(op.getAccPtr());
576+
addDeclareAttr(builder, op.getVarPtr().getDefiningOp(), dataClause);
577+
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseAddr.getType()))) {
578+
mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
579+
modBuilder.setInsertionPointAfter(builder.getFunction());
580+
std::string prefix =
581+
converter.mangleName(getSymbolFromAccObject(accObject));
582+
createDeclareAllocFuncWithArg<EntryOp>(
583+
modBuilder, builder, operandLocation, baseAddr.getType(), prefix,
584+
asFortran, dataClause);
585+
if constexpr (!std::is_same_v<EntryOp, ExitOp>)
586+
createDeclareDeallocFuncWithArg<ExitOp>(
587+
modBuilder, builder, operandLocation, baseAddr.getType(), prefix,
588+
asFortran, dataClause);
589+
}
590+
}
591+
}
592+
593+
template <typename EntryOp, typename ExitOp, typename Clause>
594+
static void genDeclareDataOperandOperationsWithModifier(
595+
const Clause *x, Fortran::lower::AbstractConverter &converter,
596+
Fortran::semantics::SemanticsContext &semanticsContext,
597+
Fortran::lower::StatementContext &stmtCtx,
598+
Fortran::parser::AccDataModifier::Modifier mod,
599+
llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
600+
const mlir::acc::DataClause clause,
601+
const mlir::acc::DataClause clauseWithModifier) {
602+
const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
603+
const auto &accObjectList =
604+
std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
605+
const auto &modifier =
606+
std::get<std::optional<Fortran::parser::AccDataModifier>>(
607+
listWithModifier.t);
608+
mlir::acc::DataClause dataClause =
609+
(modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
610+
genDeclareDataOperandOperations<EntryOp, ExitOp>(
611+
accObjectList, converter, semanticsContext, stmtCtx, dataClauseOperands,
612+
dataClause,
613+
/*structured=*/true, /*implicit=*/false);
614+
}
615+
416616
template <typename EntryOp, typename ExitOp>
417617
static void genDataExitOperations(fir::FirOpBuilder &builder,
418618
llvm::SmallVector<mlir::Value> operands,
@@ -1058,18 +1258,6 @@ createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
10581258
return op;
10591259
}
10601260

1061-
template <typename Op>
1062-
static Op
1063-
createSimpleOp(fir::FirOpBuilder &builder, mlir::Location loc,
1064-
const llvm::SmallVectorImpl<mlir::Value> &operands,
1065-
const llvm::SmallVectorImpl<int32_t> &operandSegments) {
1066-
llvm::ArrayRef<mlir::Type> argTy;
1067-
Op op = builder.create<Op>(loc, argTy, operands);
1068-
op->setAttr(Op::getOperandSegmentSizeAttr(),
1069-
builder.getDenseI32ArrayAttr(operandSegments));
1070-
return op;
1071-
}
1072-
10731261
static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
10741262
const Fortran::parser::AccClause::Async *asyncClause,
10751263
mlir::Value &async, bool &addAsyncAttr,
@@ -2349,20 +2537,6 @@ static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
23492537
modBuilder.setInsertionPointAfter(declareGlobalOp);
23502538
}
23512539

2352-
static mlir::func::FuncOp createDeclareFunc(mlir::OpBuilder &modBuilder,
2353-
fir::FirOpBuilder &builder,
2354-
mlir::Location loc,
2355-
llvm::StringRef funcName) {
2356-
auto funcTy = mlir::FunctionType::get(modBuilder.getContext(), {}, {});
2357-
auto funcOp = modBuilder.create<mlir::func::FuncOp>(loc, funcName, funcTy);
2358-
funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
2359-
builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), {}, {});
2360-
builder.setInsertionPointToEnd(&funcOp.getRegion().back());
2361-
builder.create<mlir::func::ReturnOp>(loc);
2362-
builder.setInsertionPointToStart(&funcOp.getRegion().back());
2363-
return funcOp;
2364-
}
2365-
23662540
template <typename EntryOp>
23672541
static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
23682542
fir::FirOpBuilder &builder,
@@ -2556,10 +2730,11 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
25562730
if (const auto *copyClause =
25572731
std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
25582732
auto crtDataStart = dataClauseOperands.size();
2559-
genDataOperandOperations<mlir::acc::CopyinOp>(
2733+
genDeclareDataOperandOperations<mlir::acc::CopyinOp,
2734+
mlir::acc::CopyoutOp>(
25602735
copyClause->v, converter, semanticsContext, stmtCtx,
25612736
dataClauseOperands, mlir::acc::DataClause::acc_copy,
2562-
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
2737+
/*structured=*/true, /*implicit=*/false);
25632738
copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
25642739
dataClauseOperands.end());
25652740
} else if (const auto *createClause =
@@ -2569,26 +2744,28 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
25692744
const auto &accObjectList =
25702745
std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
25712746
auto crtDataStart = dataClauseOperands.size();
2572-
genDataOperandOperations<mlir::acc::CreateOp>(
2747+
genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
25732748
accObjectList, converter, semanticsContext, stmtCtx,
25742749
dataClauseOperands, mlir::acc::DataClause::acc_create,
2575-
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
2750+
/*structured=*/true, /*implicit=*/false);
25762751
createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
25772752
dataClauseOperands.end());
25782753
} else if (const auto *presentClause =
25792754
std::get_if<Fortran::parser::AccClause::Present>(
25802755
&clause.u)) {
2581-
genDataOperandOperations<mlir::acc::PresentOp>(
2756+
genDeclareDataOperandOperations<mlir::acc::PresentOp,
2757+
mlir::acc::PresentOp>(
25822758
presentClause->v, converter, semanticsContext, stmtCtx,
25832759
dataClauseOperands, mlir::acc::DataClause::acc_present,
2584-
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
2760+
/*structured=*/true, /*implicit=*/false);
25852761
} else if (const auto *copyinClause =
25862762
std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
2587-
genDataOperandOperationsWithModifier<mlir::acc::CopyinOp>(
2763+
genDeclareDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
2764+
mlir::acc::DeleteOp>(
25882765
copyinClause, converter, semanticsContext, stmtCtx,
25892766
Fortran::parser::AccDataModifier::Modifier::ReadOnly,
25902767
dataClauseOperands, mlir::acc::DataClause::acc_copyin,
2591-
mlir::acc::DataClause::acc_copyin_readonly, /*setDeclareAttr=*/true);
2768+
mlir::acc::DataClause::acc_copyin_readonly);
25922769
} else if (const auto *copyoutClause =
25932770
std::get_if<Fortran::parser::AccClause::Copyout>(
25942771
&clause.u)) {
@@ -2597,34 +2774,38 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
25972774
const auto &accObjectList =
25982775
std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
25992776
auto crtDataStart = dataClauseOperands.size();
2600-
genDataOperandOperations<mlir::acc::CreateOp>(
2777+
genDeclareDataOperandOperations<mlir::acc::CreateOp,
2778+
mlir::acc::CopyoutOp>(
26012779
accObjectList, converter, semanticsContext, stmtCtx,
26022780
dataClauseOperands, mlir::acc::DataClause::acc_copyout,
2603-
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
2781+
/*structured=*/true, /*implicit=*/false);
26042782
copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
26052783
dataClauseOperands.end());
26062784
} else if (const auto *devicePtrClause =
26072785
std::get_if<Fortran::parser::AccClause::Deviceptr>(
26082786
&clause.u)) {
2609-
genDataOperandOperations<mlir::acc::DevicePtrOp>(
2787+
genDeclareDataOperandOperations<mlir::acc::DevicePtrOp,
2788+
mlir::acc::DevicePtrOp>(
26102789
devicePtrClause->v, converter, semanticsContext, stmtCtx,
26112790
dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
2612-
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
2791+
/*structured=*/true, /*implicit=*/false);
26132792
} else if (const auto *linkClause =
26142793
std::get_if<Fortran::parser::AccClause::Link>(&clause.u)) {
2615-
genDataOperandOperations<mlir::acc::DeclareLinkOp>(
2794+
genDeclareDataOperandOperations<mlir::acc::DeclareLinkOp,
2795+
mlir::acc::DeclareLinkOp>(
26162796
linkClause->v, converter, semanticsContext, stmtCtx,
26172797
dataClauseOperands, mlir::acc::DataClause::acc_declare_link,
2618-
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
2798+
/*structured=*/true, /*implicit=*/false);
26192799
} else if (const auto *deviceResidentClause =
26202800
std::get_if<Fortran::parser::AccClause::DeviceResident>(
26212801
&clause.u)) {
26222802
auto crtDataStart = dataClauseOperands.size();
2623-
genDataOperandOperations<mlir::acc::DeclareDeviceResidentOp>(
2803+
genDeclareDataOperandOperations<mlir::acc::DeclareDeviceResidentOp,
2804+
mlir::acc::DeleteOp>(
26242805
deviceResidentClause->v, converter, semanticsContext, stmtCtx,
26252806
dataClauseOperands,
26262807
mlir::acc::DataClause::acc_declare_device_resident,
2627-
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
2808+
/*structured=*/true, /*implicit=*/false);
26282809
deviceResidentEntryOperands.append(
26292810
dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end());
26302811
} else {

0 commit comments

Comments
 (0)