Skip to content

Commit

Permalink
[flang][openacc] Lower data clause on compute construct to data opera…
Browse files Browse the repository at this point in the history
…nd ops

This patch lowers the data clause on the OpenACC compute construct
to their corresponding acc data operand operation.
The decomposition is the same as in D149673.

Note that `private` and `firstprivate` are not lowered to data operand operation as they do not have one and will likely have dedicated design/process.

Depends on D149673

Reviewed By: razvanlupusoru, jeanPerier

Differential Revision: https://reviews.llvm.org/D149785
  • Loading branch information
clementval committed May 4, 2023
1 parent c2bef38 commit ac8c032
Show file tree
Hide file tree
Showing 7 changed files with 1,524 additions and 1,278 deletions.
159 changes: 93 additions & 66 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,28 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
}
}

template <typename Op, typename Clause>
static void genDataOperandOperationsWithModifier(
const Clause *x, Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
Fortran::parser::AccDataModifier::Modifier mod,
llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
const mlir::acc::DataClause clause,
const mlir::acc::DataClause clauseWithModifier) {
const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
const auto &accObjectList =
std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
const auto &modifier =
std::get<std::optional<Fortran::parser::AccDataModifier>>(
listWithModifier.t);
mlir::acc::DataClause dataClause =
(modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
genDataOperandOperations<Op>(accObjectList, converter, semanticsContext,
stmtCtx, dataClauseOperands, dataClause,
/*structured=*/true);
}

template <typename Op>
static Op
createComputeOp(Fortran::lower::AbstractConverter &converter,
Expand All @@ -799,11 +821,13 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
mlir::Value ifCond;
mlir::Value selfCond;
mlir::Value waitDevnum;
llvm::SmallVector<mlir::Value, 2> waitOperands, reductionOperands,
copyOperands, copyinOperands, copyinReadonlyOperands, copyoutOperands,
copyoutZeroOperands, createOperands, createZeroOperands, noCreateOperands,
presentOperands, devicePtrOperands, attachOperands, firstprivateOperands,
privateOperands, dataClauseOperands;
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
copyEntryOperands, copyoutEntryOperands, createEntryOperands,
dataClauseOperands;

// TODO: need to more work/design.
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
firstprivateOperands;

// Async, wait and self clause have optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
Expand All @@ -812,7 +836,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
bool addWaitAttr = false;
bool addSelfAttr = false;

fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
fir::FirOpBuilder &builder = converter.getFirOpBuilder();

// Lower clauses values mapped to operands.
// Keep track of each group of operands separatly as clauses can appear
Expand Down Expand Up @@ -855,8 +879,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
if (*optCondition) {
mlir::Value cond = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*optCondition), stmtCtx));
selfCond = firOpBuilder.createConvert(
clauseLocation, firOpBuilder.getI1Type(), cond);
selfCond = builder.createConvert(clauseLocation,
builder.getI1Type(), cond);
}
} else if (const auto *accClauseList =
std::get_if<Fortran::parser::AccObjectList>(
Expand All @@ -868,8 +892,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
std::get_if<Fortran::parser::Designator>(&accObject.u)) {
if (const auto *name = getDesignatorNameIfDataRef(*designator)) {
auto cond = converter.getSymbolAddress(*name->symbol);
selfCond = firOpBuilder.createConvert(
clauseLocation, firOpBuilder.getI1Type(), cond);
selfCond = builder.createConvert(clauseLocation,
builder.getI1Type(), cond);
}
}
}
Expand All @@ -879,46 +903,62 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
}
} else if (const auto *copyClause =
std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
genObjectList(copyClause->v, converter, semanticsContext, stmtCtx,
copyOperands);
genDataOperandOperations<mlir::acc::CopyinOp>(
copyClause->v, converter, semanticsContext, stmtCtx,
copyEntryOperands, mlir::acc::DataClause::acc_copy,
/*structured=*/true);
} else if (const auto *copyinClause =
std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
genObjectListWithModifier<Fortran::parser::AccClause::Copyin>(
genDataOperandOperationsWithModifier<mlir::acc::CopyinOp,
Fortran::parser::AccClause::Copyin>(
copyinClause, converter, semanticsContext, stmtCtx,
Fortran::parser::AccDataModifier::Modifier::ReadOnly,
copyinReadonlyOperands, copyinOperands);
dataClauseOperands, mlir::acc::DataClause::acc_copyin,
mlir::acc::DataClause::acc_copyin_readonly);
} else if (const auto *copyoutClause =
std::get_if<Fortran::parser::AccClause::Copyout>(
&clause.u)) {
genObjectListWithModifier<Fortran::parser::AccClause::Copyout>(
genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
Fortran::parser::AccClause::Copyout>(
copyoutClause, converter, semanticsContext, stmtCtx,
Fortran::parser::AccDataModifier::Modifier::Zero, copyoutZeroOperands,
copyoutOperands);
Fortran::parser::AccDataModifier::Modifier::ReadOnly,
copyoutEntryOperands, mlir::acc::DataClause::acc_copyout,
mlir::acc::DataClause::acc_copyout_zero);
} else if (const auto *createClause =
std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
genObjectListWithModifier<Fortran::parser::AccClause::Create>(
genDataOperandOperationsWithModifier<mlir::acc::CreateOp,
Fortran::parser::AccClause::Create>(
createClause, converter, semanticsContext, stmtCtx,
Fortran::parser::AccDataModifier::Modifier::Zero, createZeroOperands,
createOperands);
Fortran::parser::AccDataModifier::Modifier::Zero, createEntryOperands,
mlir::acc::DataClause::acc_create,
mlir::acc::DataClause::acc_create_zero);
} else if (const auto *noCreateClause =
std::get_if<Fortran::parser::AccClause::NoCreate>(
&clause.u)) {
genObjectList(noCreateClause->v, converter, semanticsContext, stmtCtx,
noCreateOperands);
genDataOperandOperations<mlir::acc::NoCreateOp>(
noCreateClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_no_create,
/*structured=*/true);
} else if (const auto *presentClause =
std::get_if<Fortran::parser::AccClause::Present>(
&clause.u)) {
genObjectList(presentClause->v, converter, semanticsContext, stmtCtx,
presentOperands);
genDataOperandOperations<mlir::acc::PresentOp>(
presentClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_present,
/*structured=*/true);
} else if (const auto *devicePtrClause =
std::get_if<Fortran::parser::AccClause::Deviceptr>(
&clause.u)) {
genObjectList(devicePtrClause->v, converter, semanticsContext, stmtCtx,
devicePtrOperands);
genDataOperandOperations<mlir::acc::DevicePtrOp>(
devicePtrClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
/*structured=*/true);
} else if (const auto *attachClause =
std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
genObjectList(attachClause->v, converter, semanticsContext, stmtCtx,
attachOperands);
genDataOperandOperations<mlir::acc::AttachOp>(
attachClause->v, converter, semanticsContext, stmtCtx,
attachEntryOperands, mlir::acc::DataClause::acc_attach,
/*structured=*/true);
} else if (const auto *privateClause =
std::get_if<Fortran::parser::AccClause::Private>(
&clause.u)) {
Expand All @@ -934,6 +974,11 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
}
}

dataClauseOperands.append(attachEntryOperands);
dataClauseOperands.append(copyEntryOperands);
dataClauseOperands.append(copyoutEntryOperands);
dataClauseOperands.append(createEntryOperands);

// Prepare the operand segment size attribute and the operands value range.
llvm::SmallVector<mlir::Value, 8> operands;
llvm::SmallVector<int32_t, 8> operandSegments;
Expand All @@ -948,17 +993,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
addOperand(operands, operandSegments, selfCond);
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>)
addOperands(operands, operandSegments, reductionOperands);
addOperands(operands, operandSegments, copyOperands);
addOperands(operands, operandSegments, copyinOperands);
addOperands(operands, operandSegments, copyinReadonlyOperands);
addOperands(operands, operandSegments, copyoutOperands);
addOperands(operands, operandSegments, copyoutZeroOperands);
addOperands(operands, operandSegments, createOperands);
addOperands(operands, operandSegments, createZeroOperands);
addOperands(operands, operandSegments, noCreateOperands);
addOperands(operands, operandSegments, presentOperands);
addOperands(operands, operandSegments, devicePtrOperands);
addOperands(operands, operandSegments, attachOperands);
operandSegments.append({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
addOperands(operands, operandSegments, privateOperands);
addOperands(operands, operandSegments, firstprivateOperands);
Expand All @@ -968,41 +1003,33 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
Op computeOp;
if constexpr (std::is_same_v<Op, mlir::acc::KernelsOp>)
computeOp = createRegionOp<Op, mlir::acc::TerminatorOp>(
firOpBuilder, currentLocation, operands, operandSegments);
builder, currentLocation, operands, operandSegments);
else
computeOp = createRegionOp<Op, mlir::acc::YieldOp>(
firOpBuilder, currentLocation, operands, operandSegments);
builder, currentLocation, operands, operandSegments);

if (addAsyncAttr)
computeOp.setAsyncAttrAttr(firOpBuilder.getUnitAttr());
computeOp.setAsyncAttrAttr(builder.getUnitAttr());
if (addWaitAttr)
computeOp.setWaitAttrAttr(firOpBuilder.getUnitAttr());
computeOp.setWaitAttrAttr(builder.getUnitAttr());
if (addSelfAttr)
computeOp.setSelfAttrAttr(firOpBuilder.getUnitAttr());
computeOp.setSelfAttrAttr(builder.getUnitAttr());

return computeOp;
}
auto insPt = builder.saveInsertionPoint();
builder.setInsertionPointAfter(computeOp);

template <typename Op, typename Clause>
static void genDataOperandOperationsWithModifier(
const Clause *x, Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
Fortran::parser::AccDataModifier::Modifier mod,
llvm::SmallVectorImpl<mlir::Value> &dataClauseOperands,
const mlir::acc::DataClause clause,
const mlir::acc::DataClause clauseWithModifier) {
const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
const auto &accObjectList =
std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
const auto &modifier =
std::get<std::optional<Fortran::parser::AccDataModifier>>(
listWithModifier.t);
mlir::acc::DataClause dataClause =
(modifier && (*modifier).v == mod) ? clauseWithModifier : clause;
genDataOperandOperations<Op>(accObjectList, converter, semanticsContext,
stmtCtx, dataClauseOperands, dataClause,
/*structured=*/true);
// Create the exit operations after the region.
genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
builder, copyEntryOperands, /*structured=*/true, /*implicit=*/false);
genDataExitOperations<mlir::acc::CreateOp, mlir::acc::CopyoutOp>(
builder, copyoutEntryOperands, /*structured=*/true, /*implicit=*/false);
genDataExitOperations<mlir::acc::AttachOp, mlir::acc::DetachOp>(
builder, attachEntryOperands, /*structured=*/true, /*implicit=*/false);
genDataExitOperations<mlir::acc::CreateOp, mlir::acc::DeleteOp>(
builder, createEntryOperands, /*structured=*/true, /*implicit=*/false);

builder.restoreInsertionPoint(insPt);
return computeOp;
}

static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
Expand Down

0 comments on commit ac8c032

Please sign in to comment.