Skip to content

Commit

Permalink
[flang][openacc] Generate pre/post alloc/dealloc function for in subr…
Browse files Browse the repository at this point in the history
…outine declare

Lowering was missing to generate the pre/post alloc/dealloc
functions for the acc declare variables. This patch adds the generation.
These functions have the descriptor as their unique argument.

Reviewed By: razvanlupusoru

Differential Revision: https://reviews.llvm.org/D158103
  • Loading branch information
clementval committed Aug 16, 2023
1 parent 0c3f51c commit 9a96b0a
Show file tree
Hide file tree
Showing 2 changed files with 255 additions and 44 deletions.
269 changes: 225 additions & 44 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,148 @@ static void addDeclareAttr(fir::FirOpBuilder &builder, mlir::Operation *op,
builder.getContext(), clause)));
}

static mlir::func::FuncOp
createDeclareFunc(mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder,
mlir::Location loc, llvm::StringRef funcName,
llvm::SmallVector<mlir::Type> argsTy = {},
llvm::SmallVector<mlir::Location> locs = {}) {
auto funcTy = mlir::FunctionType::get(modBuilder.getContext(), argsTy, {});
auto funcOp = modBuilder.create<mlir::func::FuncOp>(loc, funcName, funcTy);
funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
locs);
builder.setInsertionPointToEnd(&funcOp.getRegion().back());
builder.create<mlir::func::ReturnOp>(loc);
builder.setInsertionPointToStart(&funcOp.getRegion().back());
return funcOp;
}

template <typename Op>
static Op
createSimpleOp(fir::FirOpBuilder &builder, mlir::Location loc,
const llvm::SmallVectorImpl<mlir::Value> &operands,
const llvm::SmallVectorImpl<int32_t> &operandSegments) {
llvm::ArrayRef<mlir::Type> argTy;
Op op = builder.create<Op>(loc, argTy, operands);
op->setAttr(Op::getOperandSegmentSizeAttr(),
builder.getDenseI32ArrayAttr(operandSegments));
return op;
}

template <typename EntryOp>
static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
fir::FirOpBuilder &builder,
mlir::Location loc, mlir::Type descTy,
llvm::StringRef funcNamePrefix,
std::stringstream &asFortran,
mlir::acc::DataClause clause) {
auto crtInsPt = builder.saveInsertionPoint();
std::stringstream registerFuncName;
registerFuncName << funcNamePrefix.str()
<< Fortran::lower::declarePostAllocSuffix.str();

if (!mlir::isa<fir::ReferenceType>(descTy))
descTy = fir::ReferenceType::get(descTy);
auto registerFuncOp = createDeclareFunc(
modBuilder, builder, loc, registerFuncName.str(), {descTy}, {loc});

mlir::Value desc =
builder.create<fir::LoadOp>(loc, registerFuncOp.getArgument(0));
fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, desc);
addDeclareAttr(builder, boxAddrOp.getOperation(), clause);

llvm::SmallVector<mlir::Value> bounds;
EntryOp entryOp = createDataEntryOp<EntryOp>(
builder, loc, boxAddrOp.getResult(), asFortran, bounds,
/*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType());
builder.create<mlir::acc::DeclareEnterOp>(
loc, mlir::ValueRange(entryOp.getAccPtr()));

asFortran << "_desc";
mlir::acc::UpdateDeviceOp updateDeviceOp =
createDataEntryOp<mlir::acc::UpdateDeviceOp>(
builder, loc, registerFuncOp.getArgument(0), asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, descTy);
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(registerFuncOp);
builder.restoreInsertionPoint(crtInsPt);
}

template <typename ExitOp>
static void createDeclareDeallocFuncWithArg(
mlir::OpBuilder &modBuilder, fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type descTy, llvm::StringRef funcNamePrefix,
std::stringstream &asFortran, mlir::acc::DataClause clause) {
auto crtInsPt = builder.saveInsertionPoint();
// Generate the pre dealloc function.
std::stringstream preDeallocFuncName;
preDeallocFuncName << funcNamePrefix.str()
<< Fortran::lower::declarePreDeallocSuffix.str();
if (!mlir::isa<fir::ReferenceType>(descTy))
descTy = fir::ReferenceType::get(descTy);
auto preDeallocOp = createDeclareFunc(
modBuilder, builder, loc, preDeallocFuncName.str(), {descTy}, {loc});
mlir::Value loadOp =
builder.create<fir::LoadOp>(loc, preDeallocOp.getArgument(0));
fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
addDeclareAttr(builder, boxAddrOp.getOperation(), clause);

llvm::SmallVector<mlir::Value> bounds;
mlir::acc::GetDevicePtrOp entryOp =
createDataEntryOp<mlir::acc::GetDevicePtrOp>(
builder, loc, boxAddrOp.getResult(), asFortran, bounds,
/*structured=*/false, /*implicit=*/false, clause,
boxAddrOp.getType());
builder.create<mlir::acc::DeclareExitOp>(
loc, mlir::ValueRange(entryOp.getAccPtr()));

mlir::Value varPtr;
if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
varPtr = entryOp.getVarPtr();
builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(), varPtr,
entryOp.getBounds(), entryOp.getDataClause(),
/*structured=*/false, /*implicit=*/false,
builder.getStringAttr(*entryOp.getName()));

// Generate the post dealloc function.
modBuilder.setInsertionPointAfter(preDeallocOp);
std::stringstream postDeallocFuncName;
postDeallocFuncName << funcNamePrefix.str()
<< Fortran::lower::declarePostDeallocSuffix.str();
auto postDeallocOp = createDeclareFunc(
modBuilder, builder, loc, postDeallocFuncName.str(), {descTy}, {loc});
loadOp = builder.create<fir::LoadOp>(loc, postDeallocOp.getArgument(0));
asFortran << "_desc";
mlir::acc::UpdateDeviceOp updateDeviceOp =
createDataEntryOp<mlir::acc::UpdateDeviceOp>(
builder, loc, loadOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, loadOp.getType());
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(postDeallocOp);
builder.restoreInsertionPoint(crtInsPt);
}

Fortran::semantics::Symbol &
getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
if (const auto *designator =
std::get_if<Fortran::parser::Designator>(&accObject.u)) {
if (const auto *name =
Fortran::semantics::getDesignatorNameIfDataRef(*designator))
return *name->symbol;
} else if (const auto *name =
std::get_if<Fortran::parser::Name>(&accObject.u)) {
return *name->symbol;
}
llvm::report_fatal_error("Could not find symbol");
}

template <typename Op>
static void
genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
Expand All @@ -408,11 +550,69 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
bounds, structured, implicit, dataClause,
baseAddr.getType());
dataOperands.push_back(op.getAccPtr());
if (setDeclareAttr)
addDeclareAttr(builder, op.getVarPtr().getDefiningOp(), dataClause);
}
}

template <typename EntryOp, typename ExitOp>
static void genDeclareDataOperandOperations(
const Fortran::parser::AccObjectList &objectList,
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
mlir::acc::DataClause dataClause, bool structured, bool implicit) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
mlir::Value baseAddr = gatherDataOperandAddrAndBounds(
converter, builder, semanticsContext, stmtCtx, accObject,
operandLocation, asFortran, bounds);
EntryOp op = createDataEntryOp<EntryOp>(
builder, operandLocation, baseAddr, asFortran, bounds, structured,
implicit, dataClause, baseAddr.getType());
dataOperands.push_back(op.getAccPtr());
addDeclareAttr(builder, op.getVarPtr().getDefiningOp(), dataClause);
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseAddr.getType()))) {
mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
modBuilder.setInsertionPointAfter(builder.getFunction());
std::string prefix =
converter.mangleName(getSymbolFromAccObject(accObject));
createDeclareAllocFuncWithArg<EntryOp>(
modBuilder, builder, operandLocation, baseAddr.getType(), prefix,
asFortran, dataClause);
if constexpr (!std::is_same_v<EntryOp, ExitOp>)
createDeclareDeallocFuncWithArg<ExitOp>(
modBuilder, builder, operandLocation, baseAddr.getType(), prefix,
asFortran, dataClause);
}
}
}

template <typename EntryOp, typename ExitOp, typename Clause>
static void genDeclareDataOperandOperationsWithModifier(
const Clause *x, Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
Fortran::parser::AccDataModifier::Modifier mod,
llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
const mlir::acc::DataClause clause,
const mlir::acc::DataClause clauseWithModifier) {
const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
const auto &accObjectList =
std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
const auto &modifier =
std::get<std::optional<Fortran::parser::AccDataModifier>>(
listWithModifier.t);
mlir::acc::DataClause dataClause =
(modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
genDeclareDataOperandOperations<EntryOp, ExitOp>(
accObjectList, converter, semanticsContext, stmtCtx, dataClauseOperands,
dataClause,
/*structured=*/true, /*implicit=*/false);
}

template <typename EntryOp, typename ExitOp>
static void genDataExitOperations(fir::FirOpBuilder &builder,
llvm::SmallVector<mlir::Value> operands,
Expand Down Expand Up @@ -1058,18 +1258,6 @@ createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
return op;
}

template <typename Op>
static Op
createSimpleOp(fir::FirOpBuilder &builder, mlir::Location loc,
const llvm::SmallVectorImpl<mlir::Value> &operands,
const llvm::SmallVectorImpl<int32_t> &operandSegments) {
llvm::ArrayRef<mlir::Type> argTy;
Op op = builder.create<Op>(loc, argTy, operands);
op->setAttr(Op::getOperandSegmentSizeAttr(),
builder.getDenseI32ArrayAttr(operandSegments));
return op;
}

static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::AccClause::Async *asyncClause,
mlir::Value &async, bool &addAsyncAttr,
Expand Down Expand Up @@ -2349,20 +2537,6 @@ static void createDeclareGlobalOp(mlir::OpBuilder &modBuilder,
modBuilder.setInsertionPointAfter(declareGlobalOp);
}

static mlir::func::FuncOp createDeclareFunc(mlir::OpBuilder &modBuilder,
fir::FirOpBuilder &builder,
mlir::Location loc,
llvm::StringRef funcName) {
auto funcTy = mlir::FunctionType::get(modBuilder.getContext(), {}, {});
auto funcOp = modBuilder.create<mlir::func::FuncOp>(loc, funcName, funcTy);
funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), {}, {});
builder.setInsertionPointToEnd(&funcOp.getRegion().back());
builder.create<mlir::func::ReturnOp>(loc);
builder.setInsertionPointToStart(&funcOp.getRegion().back());
return funcOp;
}

template <typename EntryOp>
static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
fir::FirOpBuilder &builder,
Expand Down Expand Up @@ -2556,10 +2730,11 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
if (const auto *copyClause =
std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
auto crtDataStart = dataClauseOperands.size();
genDataOperandOperations<mlir::acc::CopyinOp>(
genDeclareDataOperandOperations<mlir::acc::CopyinOp,
mlir::acc::CopyoutOp>(
copyClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_copy,
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
/*structured=*/true, /*implicit=*/false);
copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
dataClauseOperands.end());
} else if (const auto *createClause =
Expand All @@ -2569,26 +2744,28 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
const auto &accObjectList =
std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
auto crtDataStart = dataClauseOperands.size();
genDataOperandOperations<mlir::acc::CreateOp>(
genDeclareDataOperandOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
accObjectList, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_create,
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
/*structured=*/true, /*implicit=*/false);
createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
dataClauseOperands.end());
} else if (const auto *presentClause =
std::get_if<Fortran::parser::AccClause::Present>(
&clause.u)) {
genDataOperandOperations<mlir::acc::PresentOp>(
genDeclareDataOperandOperations<mlir::acc::PresentOp,
mlir::acc::PresentOp>(
presentClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_present,
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
/*structured=*/true, /*implicit=*/false);
} else if (const auto *copyinClause =
std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
genDataOperandOperationsWithModifier<mlir::acc::CopyinOp>(
genDeclareDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
mlir::acc::DeleteOp>(
copyinClause, converter, semanticsContext, stmtCtx,
Fortran::parser::AccDataModifier::Modifier::ReadOnly,
dataClauseOperands, mlir::acc::DataClause::acc_copyin,
mlir::acc::DataClause::acc_copyin_readonly, /*setDeclareAttr=*/true);
mlir::acc::DataClause::acc_copyin_readonly);
} else if (const auto *copyoutClause =
std::get_if<Fortran::parser::AccClause::Copyout>(
&clause.u)) {
Expand All @@ -2597,34 +2774,38 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
const auto &accObjectList =
std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
auto crtDataStart = dataClauseOperands.size();
genDataOperandOperations<mlir::acc::CreateOp>(
genDeclareDataOperandOperations<mlir::acc::CreateOp,
mlir::acc::CopyoutOp>(
accObjectList, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_copyout,
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
/*structured=*/true, /*implicit=*/false);
copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
dataClauseOperands.end());
} else if (const auto *devicePtrClause =
std::get_if<Fortran::parser::AccClause::Deviceptr>(
&clause.u)) {
genDataOperandOperations<mlir::acc::DevicePtrOp>(
genDeclareDataOperandOperations<mlir::acc::DevicePtrOp,
mlir::acc::DevicePtrOp>(
devicePtrClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
/*structured=*/true, /*implicit=*/false);
} else if (const auto *linkClause =
std::get_if<Fortran::parser::AccClause::Link>(&clause.u)) {
genDataOperandOperations<mlir::acc::DeclareLinkOp>(
genDeclareDataOperandOperations<mlir::acc::DeclareLinkOp,
mlir::acc::DeclareLinkOp>(
linkClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_declare_link,
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
/*structured=*/true, /*implicit=*/false);
} else if (const auto *deviceResidentClause =
std::get_if<Fortran::parser::AccClause::DeviceResident>(
&clause.u)) {
auto crtDataStart = dataClauseOperands.size();
genDataOperandOperations<mlir::acc::DeclareDeviceResidentOp>(
genDeclareDataOperandOperations<mlir::acc::DeclareDeviceResidentOp,
mlir::acc::DeleteOp>(
deviceResidentClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands,
mlir::acc::DataClause::acc_declare_device_resident,
/*structured=*/true, /*implicit=*/false, /*setDeclareAttr=*/true);
/*structured=*/true, /*implicit=*/false);
deviceResidentEntryOperands.append(
dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end());
} else {
Expand Down

0 comments on commit 9a96b0a

Please sign in to comment.