-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][openacc] Support multiple device_type when lowering #78634
Conversation
@llvm/pr-subscribers-mlir-openacc @llvm/pr-subscribers-openacc Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: The routine, data, parallel, serial, kernels and loop constructs all support the device_type clause. This clause takes a list of device types. Previously, the lowering code assumed that the list contained a single item. This PR updates the lowering to handle any number of device_types. Patch is 30.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78634.diff 5 Files Affected:
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index fd89d27db74dc05..682ca06cabd6f6b 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1470,15 +1470,19 @@ genAsyncClause(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> &async,
llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes,
llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes,
- mlir::acc::DeviceTypeAttr deviceTypeAttr,
+ llvm::SmallVector<mlir::Attribute> &deviceTypeAttrs,
Fortran::lower::StatementContext &stmtCtx) {
const auto &asyncClauseValue = asyncClause->v;
if (asyncClauseValue) { // async has a value.
- async.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
- asyncDeviceTypes.push_back(deviceTypeAttr);
+ mlir::Value asyncValue = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
+ for (auto deviceTypeAttr : deviceTypeAttrs) {
+ async.push_back(asyncValue);
+ asyncDeviceTypes.push_back(deviceTypeAttr);
+ }
} else {
- asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
+ for (auto deviceTypeAttr : deviceTypeAttrs)
+ asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
}
}
@@ -1504,10 +1508,9 @@ getDeviceType(Fortran::common::OpenACCDeviceType device) {
}
static void gatherDeviceTypeAttrs(
- fir::FirOpBuilder &builder, mlir::Location clauseLocation,
+ fir::FirOpBuilder &builder,
const Fortran::parser::AccClause::DeviceType *deviceTypeClause,
- llvm::SmallVector<mlir::Attribute> &deviceTypes,
- Fortran::lower::StatementContext &stmtCtx) {
+ llvm::SmallVector<mlir::Attribute> &deviceTypes) {
const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
deviceTypeClause->v;
for (const auto &deviceTypeExpr : deviceTypeExprList.v)
@@ -1560,20 +1563,25 @@ genWaitClause(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
llvm::SmallVector<int32_t> &waitOperandsSegments,
- mlir::Value &waitDevnum, mlir::acc::DeviceTypeAttr deviceTypeAttr,
+ mlir::Value &waitDevnum,
+ llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
Fortran::lower::StatementContext &stmtCtx) {
const auto &waitClauseValue = waitClause->v;
if (waitClauseValue) { // wait has a value.
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
const auto &waitList =
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
- auto crtWaitOperands = waitOperands.size();
+ llvm::SmallVector<mlir::Value> waitValues;
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
- waitOperands.push_back(fir::getBase(converter.genExprValue(
+ waitValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(value), stmtCtx)));
}
- waitOperandsDeviceTypes.push_back(deviceTypeAttr);
- waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
+ for (auto deviceTypeAttr : deviceTypeAttrs) {
+ for (auto value : waitValues)
+ waitOperands.push_back(value);
+ waitOperandsDeviceTypes.push_back(deviceTypeAttr);
+ waitOperandsSegments.push_back(waitValues.size());
+ }
// TODO: move to device_type model.
const auto &waitDevnumValue =
@@ -1582,7 +1590,8 @@ genWaitClause(Fortran::lower::AbstractConverter &converter,
waitDevnum = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
} else {
- waitOnlyDeviceTypes.push_back(deviceTypeAttr);
+ for (auto deviceTypeAttr : deviceTypeAttrs)
+ waitOnlyDeviceTypes.push_back(deviceTypeAttr);
}
}
@@ -1610,91 +1619,112 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
// device_type attribute is set to `none` until a device_type clause is
// encountered.
- auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
- builder.getContext(), mlir::acc::DeviceType::None);
+ llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+ crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), mlir::acc::DeviceType::None));
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
if (const auto *gangClause =
std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
if (gangClause->v) {
- auto crtGangOperands = gangOperands.size();
const Fortran::parser::AccGangArgList &x = *gangClause->v;
+ mlir::SmallVector<mlir::Value> gangValues;
+ mlir::SmallVector<mlir::Attribute> gangArgs;
for (const Fortran::parser::AccGangArg &gangArg : x.v) {
if (const auto *num =
std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u)) {
- gangOperands.push_back(fir::getBase(converter.genExprValue(
+ gangValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(num->v), stmtCtx)));
- gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+ gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
builder.getContext(), mlir::acc::GangArgType::Num));
} else if (const auto *staticArg =
std::get_if<Fortran::parser::AccGangArg::Static>(
&gangArg.u)) {
const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v;
if (sizeExpr.v) {
- gangOperands.push_back(fir::getBase(converter.genExprValue(
+ gangValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx)));
} else {
// * was passed as value and will be represented as a special
// constant.
- gangOperands.push_back(builder.createIntegerConstant(
+ gangValues.push_back(builder.createIntegerConstant(
clauseLocation, builder.getIndexType(), starCst));
}
- gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+ gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
builder.getContext(), mlir::acc::GangArgType::Static));
} else if (const auto *dim =
std::get_if<Fortran::parser::AccGangArg::Dim>(
&gangArg.u)) {
- gangOperands.push_back(fir::getBase(converter.genExprValue(
+ gangValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(dim->v), stmtCtx)));
- gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+ gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
builder.getContext(), mlir::acc::GangArgType::Dim));
}
}
- gangOperandsSegments.push_back(gangOperands.size() - crtGangOperands);
- gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ for (const auto &pair : llvm::zip(gangValues, gangArgs)) {
+ gangOperands.push_back(std::get<0>(pair));
+ gangArgTypes.push_back(std::get<1>(pair));
+ }
+ gangOperandsSegments.push_back(gangValues.size());
+ gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
} else {
- gangDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ gangDeviceTypes.push_back(crtDeviceTypeAttr);
}
} else if (const auto *workerClause =
std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
if (workerClause->v) {
- workerNumOperands.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx)));
- workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ mlir::Value workerNumValue = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx));
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ workerNumOperands.push_back(workerNumValue);
+ workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
} else {
- workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
}
} else if (const auto *vectorClause =
std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
if (vectorClause->v) {
- vectorOperands.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx)));
- vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ mlir::Value vectorValue = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx));
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ vectorOperands.push_back(vectorValue);
+ vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
} else {
- vectorDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ vectorDeviceTypes.push_back(crtDeviceTypeAttr);
}
} else if (const auto *tileClause =
std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v;
- auto crtTileOperands = tileOperands.size();
+ llvm::SmallVector<mlir::Value> tileValues;
for (const auto &accTileExpr : accTileExprList.v) {
const auto &expr =
std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
accTileExpr.t);
if (expr) {
- tileOperands.push_back(fir::getBase(converter.genExprValue(
+ tileValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*expr), stmtCtx)));
} else {
// * was passed as value and will be represented as a special
// constant.
mlir::Value tileStar = builder.createIntegerConstant(
clauseLocation, builder.getIntegerType(32), starCst);
- tileOperands.push_back(tileStar);
+ tileValues.push_back(tileStar);
}
}
- tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
- tileOperandsSegments.push_back(tileOperands.size() - crtTileOperands);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ for (auto value : tileValues)
+ tileOperands.push_back(value);
+ tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+ tileOperandsSegments.push_back(tileValues.size());
+ }
} else if (const auto *privateClause =
std::get_if<Fortran::parser::AccClause::Private>(
&clause.u)) {
@@ -1707,21 +1737,20 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
reductionOperands, reductionRecipes);
} else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
- seqDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ seqDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (std::get_if<Fortran::parser::AccClause::Independent>(
&clause.u)) {
- independentDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ independentDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
- autoDeviceTypes.push_back(crtDeviceTypeAttr);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes)
+ autoDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
- deviceTypeClause->v;
- assert(deviceTypeExprList.v.size() == 1 &&
- "expect only one device_type expr");
- crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
- builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
+ crtDeviceTypes.clear();
+ gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
} else if (const auto *collapseClause =
std::get_if<Fortran::parser::AccClause::Collapse>(
&clause.u)) {
@@ -1729,14 +1758,18 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
const auto &force = std::get<bool>(arg.t);
if (force)
TODO(clauseLocation, "OpenACC collapse force modifier");
+
const auto &intExpr =
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
const auto *expr = Fortran::semantics::GetExpr(intExpr);
const std::optional<int64_t> collapseValue =
Fortran::evaluate::ToInt64(*expr);
assert(collapseValue && "expect integer value for the collapse clause");
- collapseValues.push_back(*collapseValue);
- collapseDeviceTypes.push_back(crtDeviceTypeAttr);
+
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ collapseValues.push_back(*collapseValue);
+ collapseDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
}
}
@@ -1923,45 +1956,56 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
// device_type attribute is set to `none` until a device_type clause is
// encountered.
+ llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
builder.getContext(), mlir::acc::DeviceType::None);
+ crtDeviceTypes.push_back(crtDeviceTypeAttr);
- // Lower clauses values mapped to operands.
- // Keep track of each group of operands separatly as clauses can appear
+ // Lower clauses values mapped to operands and array attributes.
+ // Keep track of each group of operands separately as clauses can appear
// more than once.
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
+ asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
genWaitClause(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
- waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
- stmtCtx);
+ waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
} else if (const auto *numGangsClause =
std::get_if<Fortran::parser::AccClause::NumGangs>(
&clause.u)) {
- auto crtNumGangs = numGangs.size();
+ llvm::SmallVector<mlir::Value> numGangValues;
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
- numGangs.push_back(fir::getBase(converter.genExprValue(
+ numGangValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(expr), stmtCtx)));
- numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
- numGangsSegments.push_back(numGangs.size() - crtNumGangs);
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ for (auto value : numGangValues)
+ numGangs.push_back(value);
+ numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
+ numGangsSegments.push_back(numGangValues.size());
+ }
} else if (const auto *numWorkersClause =
std::get_if<Fortran::parser::AccClause::NumWorkers>(
&clause.u)) {
- numWorkers.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
- numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
+ mlir::Value numWorkerValue = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ numWorkers.push_back(numWorkerValue);
+ numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
} else if (const auto *vectorLengthClause =
std::get_if<Fortran::parser::AccClause::VectorLength>(
&clause.u)) {
- vectorLength.push_back(fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
- vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
+ mlir::Value vectorLengthValue = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
+ for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+ vectorLength.push_back(vectorLengthValue);
+ vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
+ }
} else if (const auto *ifClause =
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -2115,12 +2159,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
- deviceTypeClause->v;
- assert(deviceTypeExprList.v.size() == 1 &&
- "expect only one device_type expr");
- crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
- builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
+ crtDeviceTypes.clear();
+ gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
}
}
@@ -2239,10 +2279,11 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
// device_type attribute is set to `none` until a device_type clause is
// encountered.
- auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
- builder.getContext(), mlir::acc::DeviceType::None);
+ llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+ crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), mlir::acc::DeviceType::None));
- // Lower clauses values mapped to operands.
+ // Lower clauses values mapped to operands and array attributes.
// Keep track of each group of operands separately as clauses can appear
// more than once.
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -2323,19 +2364,23 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
+ asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
genWaitClause(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
- waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
- stmtCtx);
+ waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
} else if(const auto *defaultClause =
std::get_if<...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with minor remark for the LIT test.
The routine, data, parallel, serial, kernels and loop constructs all support the device_type clause. This clause takes a list of device types. Previously, the lowering code assumed that the list contained a single item. This PR updates the lowering to handle any number of device_types.