Skip to content

Commit

Permalink
[flang][openacc] Lower loop directive to the new acc.loop op design (#…
Browse files Browse the repository at this point in the history
…65417)

acc.loop was redesigned in https://reviews.llvm.org/D159229. This patch
updates the lowering to match the new op.

Support for the DO CONCURRENT construct will be added in a follow-up patch.

Note that the pre-commit CI will fail until D159229 is merged.

Depends on #67355
  • Loading branch information
clementval committed Jan 22, 2024
1 parent 3eb4178 commit 5062a17
Show file tree
Hide file tree
Showing 11 changed files with 511 additions and 514 deletions.
5 changes: 5 additions & 0 deletions flang/include/flang/Lower/OpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class FirOpBuilder;

namespace Fortran {
namespace parser {
struct AccClauseList;
struct OpenACCConstruct;
struct OpenACCDeclarativeConstruct;
struct OpenACCRoutineConstruct;
Expand Down Expand Up @@ -64,6 +65,8 @@ static constexpr llvm::StringRef declarePreDeallocSuffix =
static constexpr llvm::StringRef declarePostDeallocSuffix =
"_acc_declare_update_desc_post_dealloc";

static constexpr llvm::StringRef privatizationRecipePrefix = "privatization";

mlir::Value genOpenACCConstruct(AbstractConverter &,
Fortran::semantics::SemanticsContext &,
pft::Evaluation &,
Expand Down Expand Up @@ -113,6 +116,8 @@ void attachDeclarePostDeallocAction(AbstractConverter &, fir::FirOpBuilder &,
void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
mlir::Location);

int64_t getCollapseValue(const Fortran::parser::AccClauseList &);

bool isInOpenACCLoop(fir::FirOpBuilder &);

void setInsertionPointAfterOpenACCLoopIfInside(fir::FirOpBuilder &);
Expand Down
36 changes: 33 additions & 3 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2391,13 +2391,43 @@ class FirConverter : public Fortran::lower::AbstractConverter {
localSymbols.pushScope();
mlir::Value exitCond = genOpenACCConstruct(
*this, bridge.getSemanticsContext(), getEval(), acc);
for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())

const Fortran::parser::OpenACCLoopConstruct *accLoop =
std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
const Fortran::parser::OpenACCCombinedConstruct *accCombined =
std::get_if<Fortran::parser::OpenACCCombinedConstruct>(&acc.u);

Fortran::lower::pft::Evaluation *curEval = &getEval();

if (accLoop || accCombined) {
int64_t collapseValue;
if (accLoop) {
const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
collapseValue = Fortran::lower::getCollapseValue(clauseList);
} else if (accCombined) {
const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir =
std::get<Fortran::parser::AccBeginCombinedDirective>(
accCombined->t);
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t);
collapseValue = Fortran::lower::getCollapseValue(clauseList);
}

if (curEval->lowerAsStructured()) {
curEval = &curEval->getFirstNestedEvaluation();
for (int64_t i = 1; i < collapseValue; i++)
curEval = &*std::next(curEval->getNestedEvaluations().begin());
}
}

for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
genFIR(e);
localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);

const Fortran::parser::OpenACCLoopConstruct *accLoop =
std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
if (accLoop && exitCond) {
Fortran::lower::pft::FunctionLikeUnit *funit =
getEval().getOwningProcedure();
Expand Down
189 changes: 151 additions & 38 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,8 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
std::string recipeName =
fir::getTypeAsString(retTy, converter.getKindMap(), "privatization");
fir::getTypeAsString(retTy, converter.getKindMap(),
Fortran::lower::privatizationRecipePrefix);
recipe = Fortran::lower::createOrGetPrivateRecipe(builder, recipeName,
operandLocation, retTy);
auto op = createDataEntryOp<mlir::acc::PrivateOp>(
Expand Down Expand Up @@ -1412,15 +1413,17 @@ static void addOperand(llvm::SmallVectorImpl<mlir::Value> &operands,
}

template <typename Op, typename Terminator>
static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
Fortran::lower::pft::Evaluation &eval,
const llvm::SmallVectorImpl<mlir::Value> &operands,
const llvm::SmallVectorImpl<int32_t> &operandSegments,
bool outerCombined = false,
llvm::SmallVector<mlir::Type> retTy = {},
mlir::Value yieldValue = {}) {
static Op
createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Location returnLoc, Fortran::lower::pft::Evaluation &eval,
const llvm::SmallVectorImpl<mlir::Value> &operands,
const llvm::SmallVectorImpl<int32_t> &operandSegments,
bool outerCombined = false,
llvm::SmallVector<mlir::Type> retTy = {},
mlir::Value yieldValue = {}, mlir::TypeRange argsTy = {},
llvm::SmallVector<mlir::Location> locs = {}) {
Op op = builder.create<Op>(loc, retTy, operands);
builder.createBlock(&op.getRegion());
builder.createBlock(&op.getRegion(), op.getRegion().end(), argsTy, locs);
mlir::Block &block = op.getRegion().back();
builder.setInsertionPointToStart(&block);

Expand All @@ -1439,13 +1442,13 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,

if (yieldValue) {
if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) {
Terminator yieldOp = builder.create<Terminator>(loc, yieldValue);
Terminator yieldOp = builder.create<Terminator>(returnLoc, yieldValue);
yieldValue.getDefiningOp()->moveBefore(yieldOp);
} else {
builder.create<Terminator>(loc);
builder.create<Terminator>(returnLoc);
}
} else {
builder.create<Terminator>(loc);
builder.create<Terminator>(returnLoc);
}
builder.setInsertionPointToStart(&block);
return op;
Expand Down Expand Up @@ -1595,18 +1598,28 @@ genWaitClause(Fortran::lower::AbstractConverter &converter,
}
}

/// Derive the integer type of an induction variable from the size reported by
/// its symbol.
///
/// \param builder builder used to create the integer type.
/// \param ivSym   symbol of the induction variable; its size() is read.
/// \returns the integer type whose bit width is eight times ivSym.size().
///
/// Aborts through llvm::report_fatal_error when the symbol reports a size of
/// zero.
mlir::Type getTypeFromIvTypeSize(fir::FirOpBuilder &builder,
                                 const Fortran::semantics::Symbol &ivSym) {
  const std::size_t sizeInBytes = ivSym.size();
  if (!sizeInBytes)
    llvm::report_fatal_error("unexpected induction variable size");

  // The symbol size is expressed in bytes, whereas the integer type is
  // requested by its width in bits.
  constexpr std::size_t bitsPerByte = 8;
  const std::size_t widthInBits = sizeInBytes * bitsPerByte;
  return builder.getIntegerType(widthInBits);
}

static mlir::acc::LoopOp
createLoopOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
const Fortran::parser::DoConstruct &outerDoConstruct,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::AccClauseList &accClauseList,
bool needEarlyReturnHandling = false) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
llvm::SmallVector<mlir::Value> tileOperands, privateOperands,
llvm::SmallVector<mlir::Value> tileOperands, privateOperands, ivPrivate,
reductionOperands, cacheOperands, vectorOperands, workerNumOperands,
gangOperands;
gangOperands, lowerbounds, upperbounds, steps;
llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
llvm::SmallVector<int32_t> tileOperandsSegments, gangOperandsSegments;
llvm::SmallVector<int64_t> collapseValues;
Expand All @@ -1623,6 +1636,74 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
builder.getContext(), mlir::acc::DeviceType::None));

llvm::SmallVector<mlir::Type> ivTypes;
llvm::SmallVector<mlir::Location> ivLocs;
llvm::SmallVector<bool> inclusiveBounds;

if (outerDoConstruct.IsDoConcurrent())
TODO(currentLocation, "OpenACC loop with DO CONCURRENT");

llvm::SmallVector<mlir::Location> locs;
locs.push_back(currentLocation); // Location of the directive

int64_t collapseValue = Fortran::lower::getCollapseValue(accClauseList);
Fortran::lower::pft::Evaluation *crtEval = &eval.getFirstNestedEvaluation();
for (unsigned i = 0; i < collapseValue; ++i) {
const Fortran::parser::LoopControl *loopControl;
if (i == 0) {
loopControl = &*outerDoConstruct.GetLoopControl();
locs.push_back(converter.genLocation(
Fortran::parser::FindSourceLocation(outerDoConstruct)));
} else {
auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
assert(doCons && "expect do construct");
loopControl = &*doCons->GetLoopControl();
locs.push_back(
converter.genLocation(Fortran::parser::FindSourceLocation(*doCons)));
}

const Fortran::parser::LoopControl::Bounds *bounds =
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
assert(bounds && "Expected bounds on the loop construct");
lowerbounds.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
upperbounds.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
if (bounds->step)
steps.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
else // If `step` is not present, assume it as `1`.
steps.push_back(builder.createIntegerConstant(
currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1));

Fortran::semantics::Symbol &ivSym =
bounds->name.thing.symbol->GetUltimate();

mlir::Type ivTy = getTypeFromIvTypeSize(builder, ivSym);
mlir::Value ivValue = converter.getSymbolAddress(ivSym);
ivTypes.push_back(ivTy);
ivLocs.push_back(currentLocation);
std::string recipeName =
fir::getTypeAsString(ivValue.getType(), converter.getKindMap(),
Fortran::lower::privatizationRecipePrefix);
auto recipe = Fortran::lower::createOrGetPrivateRecipe(
builder, recipeName, currentLocation, ivValue.getType());
std::stringstream asFortran;
auto op = createDataEntryOp<mlir::acc::PrivateOp>(
builder, currentLocation, ivValue, asFortran, {}, true,
/*implicit=*/true, mlir::acc::DataClause::acc_private,
ivValue.getType());

privateOperands.push_back(op.getAccPtr());
ivPrivate.push_back(op.getAccPtr());
privatizations.push_back(mlir::SymbolRefAttr::get(
builder.getContext(), recipe.getSymName().str()));
inclusiveBounds.push_back(true);
converter.bindSymbol(ivSym, op.getAccPtr());
if (i < collapseValue - 1)
crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
}

for (const Fortran::parser::AccClause &clause : accClauseList.v) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
if (const auto *gangClause =
Expand Down Expand Up @@ -1776,6 +1857,9 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
// Prepare the operand segment size attribute and the operands value range.
llvm::SmallVector<mlir::Value> operands;
llvm::SmallVector<int32_t> operandSegments;
addOperands(operands, operandSegments, lowerbounds);
addOperands(operands, operandSegments, upperbounds);
addOperands(operands, operandSegments, steps);
addOperands(operands, operandSegments, gangOperands);
addOperands(operands, operandSegments, workerNumOperands);
addOperands(operands, operandSegments, vectorOperands);
Expand All @@ -1793,8 +1877,15 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
}

auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
builder, currentLocation, eval, operands, operandSegments,
/*outerCombined=*/false, retTy, yieldValue);
builder, builder.getFusedLoc(locs), currentLocation, eval, operands,
operandSegments, /*outerCombined=*/false, retTy, yieldValue, ivTypes,
ivLocs);

for (auto [arg, value] : llvm::zip(
loopOp.getLoopRegions().front()->front().getArguments(), ivPrivate))
builder.create<fir::StoreOp>(currentLocation, arg, value);

loopOp.setInclusiveUpperbound(inclusiveBounds);

if (!gangDeviceTypes.empty())
loopOp.setGangAttr(builder.getArrayAttr(gangDeviceTypes));
Expand Down Expand Up @@ -1881,15 +1972,19 @@ genACC(Fortran::lower::AbstractConverter &converter,
converter.genLocation(beginLoopDirective.source);
Fortran::lower::StatementContext stmtCtx;

if (loopDirective.v == llvm::acc::ACCD_loop) {
const auto &accClauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
auto loopOp =
createLoopOp(converter, currentLocation, eval, semanticsContext,
stmtCtx, accClauseList, needEarlyExitHandling);
if (needEarlyExitHandling)
return loopOp.getResult(0);
}
assert(loopDirective.v == llvm::acc::ACCD_loop &&
"Unsupported OpenACC loop construct");

const auto &accClauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
const auto &outerDoConstruct =
std::get<std::optional<Fortran::parser::DoConstruct>>(loopConstruct.t);
auto loopOp = createLoopOp(converter, currentLocation, semanticsContext,
stmtCtx, *outerDoConstruct, eval, accClauseList,
needEarlyExitHandling);
if (needEarlyExitHandling)
return loopOp.getResult(0);

return mlir::Value{};
}

Expand Down Expand Up @@ -2186,12 +2281,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
Op computeOp;
if constexpr (std::is_same_v<Op, mlir::acc::KernelsOp>)
computeOp = createRegionOp<Op, mlir::acc::TerminatorOp>(
builder, currentLocation, eval, operands, operandSegments,
outerCombined);
builder, currentLocation, currentLocation, eval, operands,
operandSegments, outerCombined);
else
computeOp = createRegionOp<Op, mlir::acc::YieldOp>(
builder, currentLocation, eval, operands, operandSegments,
outerCombined);
builder, currentLocation, currentLocation, eval, operands,
operandSegments, outerCombined);

if (addSelfAttr)
computeOp.setSelfAttrAttr(builder.getUnitAttr());
Expand Down Expand Up @@ -2397,7 +2492,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
return;

auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
builder, currentLocation, eval, operands, operandSegments);
builder, currentLocation, currentLocation, eval, operands,
operandSegments);

if (!asyncDeviceTypes.empty())
dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
Expand Down Expand Up @@ -2486,7 +2582,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,

auto hostDataOp =
createRegionOp<mlir::acc::HostDataOp, mlir::acc::TerminatorOp>(
builder, currentLocation, eval, operands, operandSegments);
builder, currentLocation, currentLocation, eval, operands,
operandSegments);

if (addIfPresentAttr)
hostDataOp.setIfPresentAttr(builder.getUnitAttr());
Expand Down Expand Up @@ -2539,6 +2636,9 @@ genACC(Fortran::lower::AbstractConverter &converter,
std::get<Fortran::parser::AccCombinedDirective>(beginCombinedDirective.t);
const auto &accClauseList =
std::get<Fortran::parser::AccClauseList>(beginCombinedDirective.t);
const auto &outerDoConstruct =
std::get<std::optional<Fortran::parser::DoConstruct>>(
combinedConstruct.t);

mlir::Location currentLocation =
converter.genLocation(beginCombinedDirective.source);
Expand All @@ -2548,20 +2648,20 @@ genACC(Fortran::lower::AbstractConverter &converter,
createComputeOp<mlir::acc::KernelsOp>(
converter, currentLocation, eval, semanticsContext, stmtCtx,
accClauseList, /*outerCombined=*/true);
createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
accClauseList);
createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
*outerDoConstruct, eval, accClauseList);
} else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
createComputeOp<mlir::acc::ParallelOp>(
converter, currentLocation, eval, semanticsContext, stmtCtx,
accClauseList, /*outerCombined=*/true);
createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
accClauseList);
createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
*outerDoConstruct, eval, accClauseList);
} else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval,
semanticsContext, stmtCtx,
accClauseList, /*outerCombined=*/true);
createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
accClauseList);
createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
*outerDoConstruct, eval, accClauseList);
} else {
llvm::report_fatal_error("Unknown combined construct encountered");
}
Expand Down Expand Up @@ -3985,3 +4085,16 @@ void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
builder.createIntegerConstant(loc, builder.getI1Type(), 1);
builder.create<mlir::acc::YieldOp>(loc, yieldValue);
}

/// Look up the collapse count requested by an OpenACC clause list.
///
/// \param clauseList clauses attached to the directive.
/// \returns the integer value carried by the first `collapse` clause found in
///          \p clauseList, or 1 when the list contains no such clause.
int64_t Fortran::lower::getCollapseValue(
    const Fortran::parser::AccClauseList &clauseList) {
  for (const Fortran::parser::AccClause &accClause : clauseList.v) {
    const auto *collapse =
        std::get_if<Fortran::parser::AccClause::Collapse>(&accClause.u);
    if (!collapse)
      continue; // Not a collapse clause; keep scanning.

    const parser::AccCollapseArg &collapseArg = collapse->v;
    const auto &constantExpr{
        std::get<parser::ScalarIntConstantExpr>(collapseArg.t)};
    // NOTE(review): the result of GetIntValue is dereferenced without a check;
    // presumably earlier semantic checks guarantee a constant here -- confirm.
    return *Fortran::semantics::GetIntValue(constantExpr);
  }

  // No collapse clause present: a single loop is associated.
  return 1;
}

0 comments on commit 5062a17

Please sign in to comment.