-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][openacc][flang] Support wait devnum and clean async/wait IR #79525
Conversation
clementval
commented
Jan 25, 2024
- Support wait(devnum: ) with device_type support on all operations that require it
- devnum value is stored as the first value of waitOperands in its device_type sub-segment. The hasWaitDevnum attribute inform which sub-segment has a wait(devnum) value.
- Make async/wait information homogenous on compute ops, data and update op.
- Unify operands/attributes names across operations and use the same custom parser/printer
@llvm/pr-subscribers-openacc @llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes
Patch is 55.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79525.diff 14 Files Affected:
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index ecfdaa5be993584..427b36a12a2df01 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -171,7 +171,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, descTy);
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
@@ -245,7 +245,7 @@ static void createDeclareDeallocFuncWithArg(
builder, loc, loadOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, loadOp.getType());
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(postDeallocOp);
@@ -1559,39 +1559,44 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
}
}
-static void
-genWaitClause(Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::AccClause::Wait *waitClause,
- llvm::SmallVector<mlir::Value> &waitOperands,
- llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
- llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
- llvm::SmallVector<int32_t> &waitOperandsSegments,
- mlir::Value &waitDevnum,
- llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
- Fortran::lower::StatementContext &stmtCtx) {
+static void genWaitClauseWithDeviceType(
+ Fortran::lower::AbstractConverter &converter,
+ const Fortran::parser::AccClause::Wait *waitClause,
+ llvm::SmallVector<mlir::Value> &waitOperands,
+ llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
+ llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
+ llvm::SmallVector<bool> &hasDevnums,
+ llvm::SmallVector<int32_t> &waitOperandsSegments,
+ llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
+ Fortran::lower::StatementContext &stmtCtx) {
const auto &waitClauseValue = waitClause->v;
if (waitClauseValue) { // wait has a value.
+ llvm::SmallVector<mlir::Value> waitValues;
+
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+ const auto &waitDevnumValue =
+ std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+ bool hasDevnum = false;
+ if (waitDevnumValue) {
+ waitValues.push_back(fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)));
+ hasDevnum = true;
+ }
+
const auto &waitList =
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
- llvm::SmallVector<mlir::Value> waitValues;
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
waitValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(value), stmtCtx)));
}
+
for (auto deviceTypeAttr : deviceTypeAttrs) {
for (auto value : waitValues)
waitOperands.push_back(value);
waitOperandsDeviceTypes.push_back(deviceTypeAttr);
waitOperandsSegments.push_back(waitValues.size());
+ hasDevnums.push_back(hasDevnum);
}
-
- // TODO: move to device_type model.
- const auto &waitDevnumValue =
- std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
- if (waitDevnumValue)
- waitDevnum = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
} else {
for (auto deviceTypeAttr : deviceTypeAttrs)
waitOnlyDeviceTypes.push_back(deviceTypeAttr);
@@ -2093,12 +2098,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
+ llvm::SmallVector<bool> hasWaitDevnums;
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
firstprivateOperands;
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
reductionRecipes;
- mlir::Value waitDevnum; // TODO not yet implemented on compute op.
// Self clause has optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
@@ -2128,9 +2133,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands,
- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+ genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ hasWaitDevnums, waitOperandsSegments,
+ crtDeviceTypes, stmtCtx);
} else if (const auto *numGangsClause =
std::get_if<Fortran::parser::AccClause::NumGangs>(
&clause.u)) {
@@ -2372,7 +2378,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
builder.getDenseI32ArrayAttr(numGangsSegments));
}
if (!asyncDeviceTypes.empty())
- computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+ computeOp.setAsyncOperandsDeviceTypeAttr(
+ builder.getArrayAttr(asyncDeviceTypes));
if (!asyncOnlyDeviceTypes.empty())
computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
@@ -2382,6 +2389,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
if (!waitOperandsSegments.empty())
computeOp.setWaitOperandsSegmentsAttr(
builder.getDenseI32ArrayAttr(waitOperandsSegments));
+ if (!hasWaitDevnums.empty())
+ computeOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
if (!waitOnlyDeviceTypes.empty())
computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
@@ -2427,6 +2436,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<int32_t> waitOperandsSegments;
+ llvm::SmallVector<bool> hasWaitDevnums;
bool hasDefaultNone = false;
bool hasDefaultPresent = false;
@@ -2523,9 +2533,10 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands,
- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+ genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ hasWaitDevnums, waitOperandsSegments,
+ crtDeviceTypes, stmtCtx);
} else if(const auto *defaultClause =
std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
@@ -2545,7 +2556,6 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<int32_t> operandSegments;
addOperand(operands, operandSegments, ifCond);
addOperands(operands, operandSegments, async);
- addOperand(operands, operandSegments, waitDevnum);
addOperands(operands, operandSegments, waitOperands);
addOperands(operands, operandSegments, dataClauseOperands);
@@ -2557,7 +2567,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
operandSegments);
if (!asyncDeviceTypes.empty())
- dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
+ dataOp.setAsyncOperandsDeviceTypeAttr(
+ builder.getArrayAttr(asyncDeviceTypes));
if (!asyncOnlyDeviceTypes.empty())
dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
if (!waitOperandsDeviceTypes.empty())
@@ -2566,6 +2577,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
if (!waitOperandsSegments.empty())
dataOp.setWaitOperandsSegmentsAttr(
builder.getDenseI32ArrayAttr(waitOperandsSegments));
+ if (!hasWaitDevnums.empty())
+ dataOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
if (!waitOnlyDeviceTypes.empty())
dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
@@ -3007,6 +3020,11 @@ getArrayAttr(fir::FirOpBuilder &b,
return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
}
+static inline mlir::ArrayAttr
+getBoolArrayAttr(fir::FirOpBuilder &b, llvm::SmallVector<bool> &values) {
+ return values.empty() ? nullptr : b.getBoolArrayAttr(values);
+}
+
static inline mlir::DenseI32ArrayAttr
getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
llvm::SmallVector<int32_t> &values) {
@@ -3024,6 +3042,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
waitOperands, deviceTypeOperands, asyncOperands;
llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+ llvm::SmallVector<bool> hasWaitDevnums;
llvm::SmallVector<int32_t> waitOperandsSegments;
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
@@ -3051,9 +3070,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands,
- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
+ genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ hasWaitDevnums, waitOperandsSegments,
+ crtDeviceTypes, stmtCtx);
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
@@ -3092,9 +3112,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
builder.create<mlir::acc::UpdateOp>(
currentLocation, ifCond, asyncOperands,
getArrayAttr(builder, asyncOperandsDeviceTypes),
- getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
+ getArrayAttr(builder, asyncOnlyDeviceTypes), waitOperands,
getDenseI32ArrayAttr(builder, waitOperandsSegments),
getArrayAttr(builder, waitOperandsDeviceTypes),
+ getBoolArrayAttr(builder, hasWaitDevnums),
getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
ifPresent);
@@ -3268,7 +3289,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
builder, loc, addrOp, asFortranDesc, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, addrOp.getType());
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
@@ -3349,7 +3370,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
builder, loc, addrOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, addrOp.getType());
- llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
+ llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(postDeallocOp);
diff --git a/flang/test/Lower/OpenACC/acc-data.f90 b/flang/test/Lower/OpenACC/acc-data.f90
index 75ffd1fc3fcab2f..5b4ab5a65ee6bd2 100644
--- a/flang/test/Lower/OpenACC/acc-data.f90
+++ b/flang/test/Lower/OpenACC/acc-data.f90
@@ -164,8 +164,8 @@ subroutine acc_data
!$acc data present(a) wait
!$acc end data
-! CHECK: acc.data dataOperands(%{{.*}}) {
-! CHECK: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK: acc.data dataOperands(%{{.*}}) wait {
+! CHECK: }
!$acc data present(a) wait(1)
!$acc end data
@@ -176,7 +176,7 @@ subroutine acc_data
!$acc data present(a) wait(devnum: 0: 1)
!$acc end data
-! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait({%{{.*}} : i32}) {
+! CHECK: acc.data dataOperands(%{{.*}}) wait({devnum: %{{.*}} : i32, %{{.*}} : i32}) {
! CHECK: }{{$}}
!$acc data default(none)
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 21660d5c3a13163..d2134e8d2337ce6 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -93,12 +93,12 @@ subroutine acc_kernels_loop
a(i) = b(i)
END DO
-! CHECK: acc.kernels {
+! CHECK: acc.kernels wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc kernels loop wait(1)
DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90
index 99629bb8351723b..06194edbe165498 100644
--- a/flang/test/Lower/OpenACC/acc-kernels.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels.f90
@@ -61,9 +61,9 @@ subroutine acc_kernels
!$acc kernels wait
!$acc end kernels
-! CHECK: acc.kernels {
+! CHECK: acc.kernels wait {
! CHECK: acc.terminator
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc kernels wait(1)
!$acc end kernels
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index 614d201f98e26c4..24e443a20c895d1 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -95,12 +95,12 @@ subroutine acc_parallel_loop
a(i) = b(i)
END DO
-! CHECK: acc.parallel {
+! CHECK: acc.parallel wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc parallel loop wait(1)
DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index a369bf01f259955..6b37ecb5fab9aa6 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -83,9 +83,9 @@ subroutine acc_parallel
!$acc parallel wait
!$acc end parallel
-! CHECK: acc.parallel {
+! CHECK: acc.parallel wait {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc parallel wait(1)
!$acc end parallel
diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 4134f9ff0ccf577..9c0dbff0d7dac16 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -114,12 +114,12 @@ subroutine acc_serial_loop
a(i) = b(i)
END DO
-! CHECK: acc.serial {
+! CHECK: acc.serial wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc serial loop wait(1)
DO i = 1, n
diff --git a/flang/test/Lower/OpenACC/acc-serial.f90 b/flang/test/Lower/OpenACC/acc-serial.f90
index d05e51d3d274f45..d0fa9436be14a14 100644
--- a/flang/test/Lower/OpenACC/acc-serial.f90
+++ b/flang/test/Lower/OpenACC/acc-serial.f90
@@ -83,9 +83,9 @@ subroutine acc_serial
!$acc serial wait
!$acc end serial
-! CHECK: acc.serial {
+! CHECK: acc.serial wait {
! CHECK: acc.yield
-! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
+! CHECK-NEXT: }
!$acc serial wait(1)
!$acc end serial
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index ba036ac92811826..f42ae1356664b67 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -101,10 +101,7 @@ subroutine acc_update
!$acc update host(a) wait(devnum: 1: queues: 1, 2)
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
-! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
-! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({devnum: %c1{{.*}} : i32, %c1{{.*}} : i32, %c2{{.*}} : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) device_type(host, nvidia) async
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 87fd587782e7c35..9398cbfdacee469 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -903,12 +903,13 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
}];
let arguments = (ins
- Variadic<IntOrIndex>:$async,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
+ Variadic<IntOrIndex>:$asyncOperands,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
Variadic<IntOrIndex>:$waitOperands,
OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
Variadic<IntOrIndex>:$numGangs,
OptionalAttr<DenseI32ArrayAttr>:$numGangsSegments,
@@ -979,13 +980,18 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
/// present.
mlir::Operation::operand_range
getWaitValues(mlir::acc::DeviceType deviceType);
+ /// Return the wait devnum value clause if present;
+ mlir::Value getWaitDevnum();
+ /// Return the wait devnum value clause for the given device_type if
+ /// present.
+ mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
- | `async` `(` custom<DeviceTypeOperands>($async,
- type($async), $asyncDeviceType) `)`
+ | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
+ type($asyncOp...
[truncated]
|
196d4d2
to
170935d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Thank you!