diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index ecfdaa5be9935..427b36a12a2df 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -171,7 +171,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder, builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds, /*structured=*/false, /*implicit=*/true, mlir::acc::DataClause::acc_update_device, descTy); - llvm::SmallVector operandSegments{0, 0, 0, 0, 1}; + llvm::SmallVector operandSegments{0, 0, 0, 1}; llvm::SmallVector operands{updateDeviceOp.getResult()}; createSimpleOp(builder, loc, operands, operandSegments); @@ -245,7 +245,7 @@ static void createDeclareDeallocFuncWithArg( builder, loc, loadOp, asFortran, bounds, /*structured=*/false, /*implicit=*/true, mlir::acc::DataClause::acc_update_device, loadOp.getType()); - llvm::SmallVector operandSegments{0, 0, 0, 0, 1}; + llvm::SmallVector operandSegments{0, 0, 0, 1}; llvm::SmallVector operands{updateDeviceOp.getResult()}; createSimpleOp(builder, loc, operands, operandSegments); modBuilder.setInsertionPointAfter(postDeallocOp); @@ -1559,39 +1559,44 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter, } } -static void -genWaitClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::AccClause::Wait *waitClause, - llvm::SmallVector &waitOperands, - llvm::SmallVector &waitOperandsDeviceTypes, - llvm::SmallVector &waitOnlyDeviceTypes, - llvm::SmallVector &waitOperandsSegments, - mlir::Value &waitDevnum, - llvm::SmallVector deviceTypeAttrs, - Fortran::lower::StatementContext &stmtCtx) { +static void genWaitClauseWithDeviceType( + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccClause::Wait *waitClause, + llvm::SmallVector &waitOperands, + llvm::SmallVector &waitOperandsDeviceTypes, + llvm::SmallVector &waitOnlyDeviceTypes, + llvm::SmallVector &hasDevnums, + llvm::SmallVector &waitOperandsSegments, + llvm::SmallVector deviceTypeAttrs, + Fortran::lower::StatementContext &stmtCtx) { const auto &waitClauseValue = waitClause->v; if (waitClauseValue) { // wait has a value. + llvm::SmallVector waitValues; + const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; + const auto &waitDevnumValue = + std::get>(waitArg.t); + bool hasDevnum = false; + if (waitDevnumValue) { + waitValues.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx))); + hasDevnum = true; + } + const auto &waitList = std::get>(waitArg.t); - llvm::SmallVector waitValues; for (const Fortran::parser::ScalarIntExpr &value : waitList) { waitValues.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(value), stmtCtx))); } + for (auto deviceTypeAttr : deviceTypeAttrs) { for (auto value : waitValues) waitOperands.push_back(value); waitOperandsDeviceTypes.push_back(deviceTypeAttr); waitOperandsSegments.push_back(waitValues.size()); + hasDevnums.push_back(hasDevnum); } - - // TODO: move to device_type model. - const auto &waitDevnumValue = - std::get>(waitArg.t); - if (waitDevnumValue) - waitDevnum = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)); } else { for (auto deviceTypeAttr : deviceTypeAttrs) waitOnlyDeviceTypes.push_back(deviceTypeAttr); @@ -2093,12 +2098,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes; llvm::SmallVector numGangsSegments, waitOperandsSegments; + llvm::SmallVector hasWaitDevnums; llvm::SmallVector reductionOperands, privateOperands, firstprivateOperands; llvm::SmallVector privatizations, firstPrivatizations, reductionRecipes; - mlir::Value waitDevnum; // TODO not yet implemented on compute op. // Self clause has optional values but can be present with // no value as well. When there is no value, the op has an attribute to @@ -2128,9 +2133,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx); } else if (const auto *waitClause = std::get_if(&clause.u)) { - genWaitClause(converter, waitClause, waitOperands, - waitOperandsDeviceTypes, waitOnlyDeviceTypes, - waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx); + genWaitClauseWithDeviceType(converter, waitClause, waitOperands, + waitOperandsDeviceTypes, waitOnlyDeviceTypes, + hasWaitDevnums, waitOperandsSegments, + crtDeviceTypes, stmtCtx); } else if (const auto *numGangsClause = std::get_if( &clause.u)) { @@ -2372,7 +2378,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, builder.getDenseI32ArrayAttr(numGangsSegments)); } if (!asyncDeviceTypes.empty()) - computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes)); + computeOp.setAsyncOperandsDeviceTypeAttr( + builder.getArrayAttr(asyncDeviceTypes)); if (!asyncOnlyDeviceTypes.empty()) computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes)); @@ -2382,6 +2389,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, if (!waitOperandsSegments.empty()) computeOp.setWaitOperandsSegmentsAttr( builder.getDenseI32ArrayAttr(waitOperandsSegments)); + if (!hasWaitDevnums.empty()) + computeOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums)); if (!waitOnlyDeviceTypes.empty()) computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes)); @@ -2427,6 +2436,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector asyncDeviceTypes, asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes; llvm::SmallVector waitOperandsSegments; + llvm::SmallVector hasWaitDevnums; bool hasDefaultNone = false; bool hasDefaultPresent = false; @@ -2523,9 +2533,10 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter, asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx); } else if (const auto *waitClause = std::get_if(&clause.u)) { - genWaitClause(converter, waitClause, waitOperands, - waitOperandsDeviceTypes, waitOnlyDeviceTypes, - waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx); + genWaitClauseWithDeviceType(converter, waitClause, waitOperands, + waitOperandsDeviceTypes, waitOnlyDeviceTypes, + hasWaitDevnums, waitOperandsSegments, + crtDeviceTypes, stmtCtx); } else if(const auto *defaultClause = std::get_if(&clause.u)) { if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none) @@ -2545,7 +2556,6 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector operandSegments; addOperand(operands, operandSegments, ifCond); addOperands(operands, operandSegments, async); - addOperand(operands, operandSegments, waitDevnum); addOperands(operands, operandSegments, waitOperands); addOperands(operands, operandSegments, dataClauseOperands); @@ -2557,7 +2567,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter, operandSegments); if (!asyncDeviceTypes.empty()) - dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes)); + dataOp.setAsyncOperandsDeviceTypeAttr( + builder.getArrayAttr(asyncDeviceTypes)); if (!asyncOnlyDeviceTypes.empty()) dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes)); if (!waitOperandsDeviceTypes.empty()) @@ -2566,6 +2577,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter, if (!waitOperandsSegments.empty()) dataOp.setWaitOperandsSegmentsAttr( builder.getDenseI32ArrayAttr(waitOperandsSegments)); + if (!hasWaitDevnums.empty()) + dataOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums)); if (!waitOnlyDeviceTypes.empty()) dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes)); @@ -3007,6 +3020,11 @@ getArrayAttr(fir::FirOpBuilder &b, return attributes.empty() ? nullptr : b.getArrayAttr(attributes); } +static inline mlir::ArrayAttr +getBoolArrayAttr(fir::FirOpBuilder &b, llvm::SmallVector &values) { + return values.empty() ? nullptr : b.getBoolArrayAttr(values); +} + static inline mlir::DenseI32ArrayAttr getDenseI32ArrayAttr(fir::FirOpBuilder &builder, llvm::SmallVector &values) { @@ -3024,6 +3042,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter, waitOperands, deviceTypeOperands, asyncOperands; llvm::SmallVector asyncOperandsDeviceTypes, asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes; + llvm::SmallVector hasWaitDevnums; llvm::SmallVector waitOperandsSegments; fir::FirOpBuilder &builder = converter.getFirOpBuilder(); @@ -3051,9 +3070,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter, crtDeviceTypes, stmtCtx); } else if (const auto *waitClause = std::get_if(&clause.u)) { - genWaitClause(converter, waitClause, waitOperands, - waitOperandsDeviceTypes, waitOnlyDeviceTypes, - waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx); + genWaitClauseWithDeviceType(converter, waitClause, waitOperands, + waitOperandsDeviceTypes, waitOnlyDeviceTypes, + hasWaitDevnums, waitOperandsSegments, + crtDeviceTypes, stmtCtx); } else if (const auto *deviceTypeClause = std::get_if( &clause.u)) { @@ -3092,9 +3112,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter, builder.create( currentLocation, ifCond, asyncOperands, getArrayAttr(builder, asyncOperandsDeviceTypes), - getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands, + getArrayAttr(builder, asyncOnlyDeviceTypes), waitOperands, getDenseI32ArrayAttr(builder, waitOperandsSegments), getArrayAttr(builder, waitOperandsDeviceTypes), + getBoolArrayAttr(builder, hasWaitDevnums), getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands, ifPresent); @@ -3268,7 +3289,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder, builder, loc, addrOp, asFortranDesc, bounds, /*structured=*/false, /*implicit=*/true, mlir::acc::DataClause::acc_update_device, addrOp.getType()); - llvm::SmallVector operandSegments{0, 0, 0, 0, 1}; + llvm::SmallVector operandSegments{0, 0, 0, 1}; llvm::SmallVector operands{updateDeviceOp.getResult()}; createSimpleOp(builder, loc, operands, operandSegments); @@ -3349,7 +3370,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder, builder, loc, addrOp, asFortran, bounds, /*structured=*/false, /*implicit=*/true, mlir::acc::DataClause::acc_update_device, addrOp.getType()); - llvm::SmallVector operandSegments{0, 0, 0, 0, 1}; + llvm::SmallVector operandSegments{0, 0, 0, 1}; llvm::SmallVector operands{updateDeviceOp.getResult()}; createSimpleOp(builder, loc, operands, operandSegments); modBuilder.setInsertionPointAfter(postDeallocOp); diff --git a/flang/test/Lower/OpenACC/acc-data.f90 b/flang/test/Lower/OpenACC/acc-data.f90 index 75ffd1fc3fcab..5b4ab5a65ee6b 100644 --- a/flang/test/Lower/OpenACC/acc-data.f90 +++ b/flang/test/Lower/OpenACC/acc-data.f90 @@ -164,8 +164,8 @@ subroutine acc_data !$acc data present(a) wait !$acc end data -! CHECK: acc.data dataOperands(%{{.*}}) { -! CHECK: } attributes {waitOnly = [#acc.device_type]} +! CHECK: acc.data dataOperands(%{{.*}}) wait { +! CHECK: } !$acc data present(a) wait(1) !$acc end data @@ -176,7 +176,7 @@ subroutine acc_data !$acc data present(a) wait(devnum: 0: 1) !$acc end data -! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait({%{{.*}} : i32}) { +! CHECK: acc.data dataOperands(%{{.*}}) wait({devnum: %{{.*}} : i32, %{{.*}} : i32}) { ! CHECK: }{{$}} !$acc data default(none) diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90 index 21660d5c3a131..d2134e8d2337c 100644 --- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90 @@ -93,12 +93,12 @@ subroutine acc_kernels_loop a(i) = b(i) END DO -! CHECK: acc.kernels { +! CHECK: acc.kernels wait { ! CHECK: acc.loop {{.*}} { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.terminator -! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} +! CHECK-NEXT: } !$acc kernels loop wait(1) DO i = 1, n diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90 index 99629bb835172..06194edbe1654 100644 --- a/flang/test/Lower/OpenACC/acc-kernels.f90 +++ b/flang/test/Lower/OpenACC/acc-kernels.f90 @@ -61,9 +61,9 @@ subroutine acc_kernels !$acc kernels wait !$acc end kernels -! CHECK: acc.kernels { +! CHECK: acc.kernels wait { ! CHECK: acc.terminator -! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} +! CHECK-NEXT: } !$acc kernels wait(1) !$acc end kernels diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90 index 614d201f98e26..24e443a20c895 100644 --- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90 @@ -95,12 +95,12 @@ subroutine acc_parallel_loop a(i) = b(i) END DO -! CHECK: acc.parallel { +! CHECK: acc.parallel wait { ! CHECK: acc.loop {{.*}} { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.yield -! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} +! CHECK-NEXT: } !$acc parallel loop wait(1) DO i = 1, n diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90 index a369bf01f2599..6b37ecb5fab9a 100644 --- a/flang/test/Lower/OpenACC/acc-parallel.f90 +++ b/flang/test/Lower/OpenACC/acc-parallel.f90 @@ -83,9 +83,9 @@ subroutine acc_parallel !$acc parallel wait !$acc end parallel -! CHECK: acc.parallel { +! CHECK: acc.parallel wait { ! CHECK: acc.yield -! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} +! CHECK-NEXT: } !$acc parallel wait(1) !$acc end parallel diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90 index 4134f9ff0ccf5..9c0dbff0d7dac 100644 --- a/flang/test/Lower/OpenACC/acc-serial-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90 @@ -114,12 +114,12 @@ subroutine acc_serial_loop a(i) = b(i) END DO -! CHECK: acc.serial { +! CHECK: acc.serial wait { ! CHECK: acc.loop {{.*}} { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.yield -! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} +! CHECK-NEXT: } !$acc serial loop wait(1) DO i = 1, n diff --git a/flang/test/Lower/OpenACC/acc-serial.f90 b/flang/test/Lower/OpenACC/acc-serial.f90 index d05e51d3d274f..d0fa9436be14a 100644 --- a/flang/test/Lower/OpenACC/acc-serial.f90 +++ b/flang/test/Lower/OpenACC/acc-serial.f90 @@ -83,9 +83,9 @@ subroutine acc_serial !$acc serial wait !$acc end serial -! CHECK: acc.serial { +! CHECK: acc.serial wait { ! CHECK: acc.yield -! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} +! CHECK-NEXT: } !$acc serial wait(1) !$acc end serial diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90 index ba036ac928118..f42ae1356664b 100644 --- a/flang/test/Lower/OpenACC/acc-update.f90 +++ b/flang/test/Lower/OpenACC/acc-update.f90 @@ -101,10 +101,7 @@ subroutine acc_update !$acc update host(a) wait(devnum: 1: queues: 1, 2) ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref> {dataClause = #acc, name = "a", structured = false} -! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32 -! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32 -! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32 -! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref>) +! CHECK: acc.update wait({devnum: %c1{{.*}} : i32, %c1{{.*}} : i32, %c2{{.*}} : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref>) ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref>) {name = "a", structured = false} !$acc update host(a) device_type(host, nvidia) async diff --git a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td index 4269a966a988d..31578710365b2 100644 --- a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td +++ b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td @@ -141,7 +141,7 @@ class Directive { // allowedClauses and requiredClauses lists. // List of allowed clauses for the directive. - list allowedClauses = []; + list allowedClauses = []; // List of clauses that are allowed to appear only once. list allowedOnceClauses = []; diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 87fd587782e7c..9398cbfdacee4 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -903,12 +903,13 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", }]; let arguments = (ins - Variadic:$async, - OptionalAttr:$asyncDeviceType, + Variadic:$asyncOperands, + OptionalAttr:$asyncOperandsDeviceType, OptionalAttr:$asyncOnly, Variadic:$waitOperands, OptionalAttr:$waitOperandsSegments, OptionalAttr:$waitOperandsDeviceType, + OptionalAttr:$hasWaitDevnum, OptionalAttr:$waitOnly, Variadic:$numGangs, OptionalAttr:$numGangsSegments, @@ -979,13 +980,18 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", /// present. mlir::Operation::operand_range getWaitValues(mlir::acc::DeviceType deviceType); + /// Return the wait devnum value clause if present; + mlir::Value getWaitDevnum(); + /// Return the wait devnum value clause for the given device_type if + /// present. + mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ oilist( `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` - | `async` `(` custom($async, - type($async), $asyncDeviceType) `)` + | `async` `(` custom($asyncOperands, + type($asyncOperands), $asyncOperandsDeviceType) `)` | `firstprivate` `(` custom($gangFirstPrivateOperands, type($gangFirstPrivateOperands), $firstprivatizations) `)` @@ -998,8 +1004,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", `)` | `vector_length` `(` custom($vectorLength, type($vectorLength), $vectorLengthDeviceType) `)` - | `wait` `(` custom($waitOperands, - type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)` + | `wait` `` custom($waitOperands, type($waitOperands), + $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum, + $waitOnly) | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` | `reduction` `(` custom( @@ -1034,12 +1041,13 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", }]; let arguments = (ins - Variadic:$async, - OptionalAttr:$asyncDeviceType, + Variadic:$asyncOperands, + OptionalAttr:$asyncOperandsDeviceType, OptionalAttr:$asyncOnly, Variadic:$waitOperands, OptionalAttr:$waitOperandsSegments, OptionalAttr:$waitOperandsDeviceType, + OptionalAttr:$hasWaitDevnum, OptionalAttr:$waitOnly, Optional:$ifCond, Optional:$selfCond, @@ -1084,21 +1092,27 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", /// present. mlir::Operation::operand_range getWaitValues(mlir::acc::DeviceType deviceType); + /// Return the wait devnum value clause if present; + mlir::Value getWaitDevnum(); + /// Return the wait devnum value clause for the given device_type if + /// present. + mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ oilist( `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` - | `async` `(` custom($async, - type($async), $asyncDeviceType) `)` + | `async` `(` custom($asyncOperands, + type($asyncOperands), $asyncOperandsDeviceType) `)` | `firstprivate` `(` custom($gangFirstPrivateOperands, type($gangFirstPrivateOperands), $firstprivatizations) `)` | `private` `(` custom( $gangPrivateOperands, type($gangPrivateOperands), $privatizations) `)` - | `wait` `(` custom($waitOperands, - type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)` + | `wait` `` custom($waitOperands, type($waitOperands), + $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum, + $waitOnly) | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` | `reduction` `(` custom( @@ -1135,12 +1149,13 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", }]; let arguments = (ins - Variadic:$async, - OptionalAttr:$asyncDeviceType, + Variadic:$asyncOperands, + OptionalAttr:$asyncOperandsDeviceType, OptionalAttr:$asyncOnly, Variadic:$waitOperands, OptionalAttr:$waitOperandsSegments, OptionalAttr:$waitOperandsDeviceType, + OptionalAttr:$hasWaitDevnum, OptionalAttr:$waitOnly, Variadic:$numGangs, OptionalAttr:$numGangsSegments, @@ -1205,22 +1220,27 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", /// present. mlir::Operation::operand_range getWaitValues(mlir::acc::DeviceType deviceType); + /// Return the wait devnum value clause if present; + mlir::Value getWaitDevnum(); + /// Return the wait devnum value clause for the given device_type if + /// present. + mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ oilist( `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` - | `async` `(` custom($async, - type($async), $asyncDeviceType) `)` + | `async` `(` custom($asyncOperands, + type($asyncOperands), $asyncOperandsDeviceType) `)` | `num_gangs` `(` custom($numGangs, type($numGangs), $numGangsDeviceType, $numGangsSegments) `)` | `num_workers` `(` custom($numWorkers, type($numWorkers), $numWorkersDeviceType) `)` | `vector_length` `(` custom($vectorLength, type($vectorLength), $vectorLengthDeviceType) `)` - | `wait` `(` custom($waitOperands, - type($waitOperands), $waitOperandsDeviceType, - $waitOperandsSegments) `)` + | `wait` `` custom($waitOperands, type($waitOperands), + $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum, + $waitOnly) | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` ) @@ -1258,13 +1278,13 @@ def OpenACC_DataOp : OpenACC_Op<"data", let arguments = (ins Optional:$ifCond, - Variadic:$async, - OptionalAttr:$asyncDeviceType, + Variadic:$asyncOperands, + OptionalAttr:$asyncOperandsDeviceType, OptionalAttr:$asyncOnly, - Optional:$waitDevnum, Variadic:$waitOperands, OptionalAttr:$waitOperandsSegments, OptionalAttr:$waitOperandsDeviceType, + OptionalAttr:$hasWaitDevnum, OptionalAttr:$waitOnly, Variadic:$dataClauseOperands, OptionalAttr:$defaultAttr); @@ -1300,18 +1320,22 @@ def OpenACC_DataOp : OpenACC_Op<"data", /// present. mlir::Operation::operand_range getWaitValues(mlir::acc::DeviceType deviceType); + /// Return the wait devnum value clause if present; + mlir::Value getWaitDevnum(); + /// Return the wait devnum value clause for the given device_type if + /// present. + mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ oilist( `if` `(` $ifCond `)` - | `async` `(` custom($async, - type($async), $asyncDeviceType) `)` + | `async` `(` custom($asyncOperands, + type($asyncOperands), $asyncOperandsDeviceType) `)` | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` - | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)` - | `wait` `(` custom($waitOperands, - type($waitOperands), $waitOperandsDeviceType, - $waitOperandsSegments) `)` + | `wait` `` custom($waitOperands, type($waitOperands), + $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum, + $waitOnly) ) $region attr-dict-with-keyword }]; @@ -2199,11 +2223,11 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", Variadic:$asyncOperands, OptionalAttr:$asyncOperandsDeviceType, OptionalAttr:$async, - Optional:$waitDevnum, Variadic:$waitOperands, OptionalAttr:$waitOperandsSegments, OptionalAttr:$waitOperandsDeviceType, - OptionalAttr:$wait, + OptionalAttr:$hasWaitDevnum, + OptionalAttr:$waitOnly, Variadic:$dataClauseOperands, UnitAttr:$ifPresent); @@ -2236,6 +2260,11 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", /// present. mlir::Operation::operand_range getWaitValues(mlir::acc::DeviceType deviceType); + /// Return the wait devnum value clause if present; + mlir::Value getWaitDevnum(); + /// Return the wait devnum value clause for the given device_type if + /// present. + mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ @@ -2244,10 +2273,9 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", | `async` `` custom( $asyncOperands, type($asyncOperands), $asyncOperandsDeviceType, $async) - | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)` - | `wait` `` custom($waitOperands, - type($waitOperands), $waitOperandsDeviceType, - $waitOperandsSegments, $wait) + | `wait` `` custom($waitOperands, type($waitOperands), + $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum, + $waitOnly) | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` ) attr-dict-with-keyword diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index e1e69113bca16..ae5da686f8595 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -104,6 +104,91 @@ static void printDeviceTypes(mlir::OpAsmPrinter &p, p << "]"; } +static std::optional findSegment(ArrayAttr segments, + mlir::acc::DeviceType deviceType) { + unsigned segmentIdx = 0; + for (auto attr : segments) { + auto deviceTypeAttr = mlir::dyn_cast(attr); + if (deviceTypeAttr.getValue() == deviceType) + return std::make_optional(segmentIdx); + ++segmentIdx; + } + return std::nullopt; +} + +static mlir::Operation::operand_range +getValuesFromSegments(std::optional arrayAttr, + mlir::Operation::operand_range range, + std::optional> segments, + mlir::acc::DeviceType deviceType) { + if (!arrayAttr) + return range.take_front(0); + if (auto pos = findSegment(*arrayAttr, deviceType)) { + int32_t nbOperandsBefore = 0; + for (unsigned i = 0; i < *pos; ++i) + nbOperandsBefore += (*segments)[i]; + return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]); + } + return range.take_front(0); +} + +static mlir::Value +getWaitDevnumValue(std::optional deviceTypeAttr, + mlir::Operation::operand_range operands, + std::optional> segments, + std::optional hasWaitDevnum, + mlir::acc::DeviceType deviceType) { + if (!hasDeviceTypeValues(deviceTypeAttr)) + return {}; + if (auto pos = findSegment(*deviceTypeAttr, deviceType)) + if (hasWaitDevnum->getValue()[*pos]) + return getValuesFromSegments(deviceTypeAttr, operands, segments, + deviceType) + .front(); + return {}; +} + +static mlir::Operation::operand_range +getWaitValuesWithoutDevnum(std::optional deviceTypeAttr, + mlir::Operation::operand_range operands, + std::optional> segments, + std::optional hasWaitDevnum, + mlir::acc::DeviceType deviceType) { + auto range = + getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType); + if (range.empty()) + return range; + if (auto pos = findSegment(*deviceTypeAttr, deviceType)) { + if (hasWaitDevnum && *hasWaitDevnum) { + auto boolAttr = mlir::dyn_cast((*hasWaitDevnum)[*pos]); + if (boolAttr.getValue()) + return range.drop_front(1); // first value is devnum + } + } + return range; +} + +template +static LogicalResult checkWaitAndAsyncConflict(Op op) { + for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType(); + ++dtypeInt) { + auto dtype = static_cast(dtypeInt); + + // The async attribute represent the async clause without value. Therefore + // the attribute and operand cannot appear at the same time. + if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) && + op.hasAsyncOnly(dtype)) + return op.emitError("async attribute cannot appear with asyncOperand"); + + // The wait attribute represent the wait clause without values. Therefore + // the attribute and operands cannot appear at the same time. + if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) && + op.hasWaitOnly(dtype)) + return op.emitError("wait attribute cannot appear with waitOperands"); + } + return success(); +} + //===----------------------------------------------------------------------===// // DataBoundsOp //===----------------------------------------------------------------------===// @@ -649,7 +734,7 @@ unsigned ParallelOp::getNumDataOperands() { } Value ParallelOp::getDataOperand(unsigned i) { - unsigned numOptional = getAsync().size(); + unsigned numOptional = getAsyncOperands().size(); numOptional += getNumGangs().size(); numOptional += getNumWorkers().size(); numOptional += getVectorLength().size(); @@ -722,25 +807,17 @@ LogicalResult acc::ParallelOp::verify() { "vector_length"))) return failure(); - if (failed(verifyDeviceTypeCountMatch(*this, getAsync(), - getAsyncDeviceTypeAttr(), "async"))) + if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(), + getAsyncOperandsDeviceTypeAttr(), + "async"))) + return failure(); + + if (failed(checkWaitAndAsyncConflict(*this))) return failure(); return checkDataOperands(*this, getDataClauseOperands()); } -static std::optional findSegment(ArrayAttr segments, - mlir::acc::DeviceType deviceType) { - unsigned segmentIdx = 0; - for (auto attr : segments) { - auto deviceTypeAttr = mlir::dyn_cast(attr); - if (deviceTypeAttr.getValue() == deviceType) - return std::make_optional(segmentIdx); - ++segmentIdx; - } - return std::nullopt; -} - static mlir::Value getValueInDeviceTypeSegment(std::optional arrayAttr, mlir::Operation::operand_range range, @@ -765,8 +842,8 @@ mlir::Value acc::ParallelOp::getAsyncValue() { } mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) { - return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), - deviceType); + return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(), + getAsyncOperands(), deviceType); } mlir::Value acc::ParallelOp::getNumWorkersValue() { @@ -793,22 +870,6 @@ mlir::Operation::operand_range ParallelOp::getNumGangsValues() { return getNumGangsValues(mlir::acc::DeviceType::None); } -static mlir::Operation::operand_range -getValuesFromSegments(std::optional arrayAttr, - mlir::Operation::operand_range range, - std::optional> segments, - mlir::acc::DeviceType deviceType) { - if (!arrayAttr) - return range.take_front(0); - if (auto pos = findSegment(*arrayAttr, deviceType)) { - int32_t nbOperandsBefore = 0; - for (unsigned i = 0; i < *pos; ++i) - nbOperandsBefore += (*segments)[i]; - return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]); - } - return range.take_front(0); -} - mlir::Operation::operand_range ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) { return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(), @@ -829,8 +890,19 @@ mlir::Operation::operand_range ParallelOp::getWaitValues() { mlir::Operation::operand_range ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) { - return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), - getWaitOperandsSegments(), deviceType); + return getWaitValuesWithoutDevnum( + getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), + getHasWaitDevnum(), deviceType); +} + +mlir::Value ParallelOp::getWaitDevnum() { + return getWaitDevnum(mlir::acc::DeviceType::None); +} + +mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { + return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), + getWaitOperandsSegments(), getHasWaitDevnum(), + deviceType); } static ParseResult parseNumGangs( @@ -967,8 +1039,9 @@ static ParseResult parseWaitClause( mlir::OpAsmParser &parser, llvm::SmallVectorImpl &operands, llvm::SmallVectorImpl &types, mlir::ArrayAttr &deviceTypes, - mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &keywordOnly) { - llvm::SmallVector deviceTypeAttrs, keywordAttrs; + mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, + mlir::ArrayAttr &keywordOnly) { + llvm::SmallVector deviceTypeAttrs, keywordAttrs, devnum; llvm::SmallVector seg; bool needCommaBeforeOperands = false; @@ -1003,6 +1076,14 @@ static ParseResult parseWaitClause( int32_t crtOperandsSize = operands.size(); + if (succeeded(parser.parseOptionalKeyword("devnum"))) { + if (failed(parser.parseColon())) + return failure(); + devnum.push_back(BoolAttr::get(parser.getContext(), true)); + } else { + devnum.push_back(BoolAttr::get(parser.getContext(), false)); + } + if (failed(parser.parseCommaSeparatedList( mlir::AsmParser::Delimiter::None, [&]() { if (parser.parseOperand(operands.emplace_back()) || @@ -1033,6 +1114,7 @@ static ParseResult parseWaitClause( deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs); keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs); segments = DenseI32ArrayAttr::get(parser.getContext(), seg); + hasDevNum = ArrayAttr::get(parser.getContext(), devnum); return success(); } @@ -1052,6 +1134,7 @@ static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional deviceTypes, std::optional segments, + std::optional hasDevNum, std::optional keywordOnly) { if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly)) @@ -1066,6 +1149,9 @@ static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, unsigned opIdx = 0; llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { p << "{"; + auto boolAttr = mlir::dyn_cast((*hasDevNum)[it.index()]); + if (boolAttr && boolAttr.getValue()) + p << "devnum: "; llvm::interleaveComma( llvm::seq(0, (*segments)[it.index()]), p, [&](auto it) { p << operands[opIdx] << " : " << operands[opIdx].getType(); @@ -1209,7 +1295,7 @@ unsigned SerialOp::getNumDataOperands() { } Value SerialOp::getDataOperand(unsigned i) { - unsigned numOptional = getAsync().size(); + unsigned numOptional = getAsyncOperands().size(); numOptional += getIfCond() ? 1 : 0; numOptional += getSelfCond() ? 1 : 0; return getOperand(getWaitOperands().size() + numOptional + i); @@ -1228,8 +1314,8 @@ mlir::Value acc::SerialOp::getAsyncValue() { } mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) { - return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), - deviceType); + return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(), + getAsyncOperands(), deviceType); } bool acc::SerialOp::hasWaitOnly() { @@ -1246,8 +1332,19 @@ mlir::Operation::operand_range SerialOp::getWaitValues() { mlir::Operation::operand_range SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) { - return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), - getWaitOperandsSegments(), deviceType); + return getWaitValuesWithoutDevnum( + getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), + getHasWaitDevnum(), deviceType); +} + +mlir::Value SerialOp::getWaitDevnum() { + return getWaitDevnum(mlir::acc::DeviceType::None); +} + +mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { + return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), + getWaitOperandsSegments(), getHasWaitDevnum(), + deviceType); } LogicalResult acc::SerialOp::verify() { @@ -1265,8 +1362,12 @@ LogicalResult acc::SerialOp::verify() { getWaitOperandsDeviceTypeAttr(), "wait"))) return failure(); - if (failed(verifyDeviceTypeCountMatch(*this, getAsync(), - getAsyncDeviceTypeAttr(), "async"))) + if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(), + getAsyncOperandsDeviceTypeAttr(), + "async"))) + return failure(); + + if (failed(checkWaitAndAsyncConflict(*this))) return failure(); return checkDataOperands(*this, getDataClauseOperands()); @@ -1281,7 +1382,7 @@ unsigned KernelsOp::getNumDataOperands() { } Value KernelsOp::getDataOperand(unsigned i) { - unsigned numOptional = getAsync().size(); + unsigned numOptional = getAsyncOperands().size(); numOptional += getWaitOperands().size(); numOptional += getNumGangs().size(); numOptional += getNumWorkers().size(); @@ -1304,8 +1405,8 @@ mlir::Value acc::KernelsOp::getAsyncValue() { } mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) { - return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), - deviceType); + return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(), + getAsyncOperands(), deviceType); } mlir::Value acc::KernelsOp::getNumWorkersValue() { @@ -1352,8 +1453,19 @@ mlir::Operation::operand_range KernelsOp::getWaitValues() { mlir::Operation::operand_range KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) { - return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), - getWaitOperandsSegments(), deviceType); + return getWaitValuesWithoutDevnum( + getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), + getHasWaitDevnum(), deviceType); +} + +mlir::Value KernelsOp::getWaitDevnum() { + return getWaitDevnum(mlir::acc::DeviceType::None); +} + +mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { + return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), + getWaitOperandsSegments(), getHasWaitDevnum(), + deviceType); } LogicalResult acc::KernelsOp::verify() { @@ -1377,8 +1489,12 @@ LogicalResult acc::KernelsOp::verify() { "vector_length"))) return failure(); - if (failed(verifyDeviceTypeCountMatch(*this, getAsync(), - getAsyncDeviceTypeAttr(), "async"))) + if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(), + getAsyncOperandsDeviceTypeAttr(), + "async"))) + return failure(); + + if (failed(checkWaitAndAsyncConflict(*this))) return failure(); return checkDataOperands(*this, getDataClauseOperands()); @@ -1943,6 +2059,9 @@ LogicalResult acc::DataOp::verify() { return emitError("expect data entry/exit operation or acc.getdeviceptr " "as defining op"); + if (failed(checkWaitAndAsyncConflict(*this))) + return failure(); + return success(); } @@ -1950,7 +2069,7 @@ unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); } Value DataOp::getDataOperand(unsigned i) { unsigned numOptional = getIfCond() ? 1 : 0; - numOptional += getAsync().size() ? 1 : 0; + numOptional += getAsyncOperands().size() ? 1 : 0; numOptional += getWaitOperands().size(); return getOperand(numOptional + i); } @@ -1968,8 +2087,8 @@ mlir::Value DataOp::getAsyncValue() { } mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) { - return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), - deviceType); + return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(), + getAsyncOperands(), deviceType); } bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); } @@ -1984,8 +2103,19 @@ mlir::Operation::operand_range DataOp::getWaitValues() { mlir::Operation::operand_range DataOp::getWaitValues(mlir::acc::DeviceType deviceType) { - return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), - getWaitOperandsSegments(), deviceType); + return getWaitValuesWithoutDevnum( + getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), + getHasWaitDevnum(), deviceType); +} + +mlir::Value DataOp::getWaitDevnum() { + return getWaitDevnum(mlir::acc::DeviceType::None); +} + +mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { + return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), + getWaitOperandsSegments(), getHasWaitDevnum(), + deviceType); } //===----------------------------------------------------------------------===// @@ -2549,23 +2679,8 @@ LogicalResult acc::UpdateOp::verify() { getWaitOperandsDeviceTypeAttr(), "wait"))) return failure(); - for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType(); - ++dtypeInt) { - auto dtype = static_cast(dtypeInt); - - // The async attribute represent the async clause without value. Therefore - // the attribute and operand cannot appear at the same time. - if (getAsyncValue(dtype) && hasAsyncOnly(dtype)) - return emitError("async attribute cannot appear with asyncOperand"); - - // The wait attribute represent the wait clause without values. Therefore - // the attribute and operands cannot appear at the same time. - if (!getWaitValues(dtype).empty() && hasWaitOnly(dtype)) - return emitError("wait attribute cannot appear with waitOperands"); - } - - if (getWaitDevnum() && getWaitOperands().empty()) - return emitError("wait_devnum cannot appear without waitOperands"); + if (failed(checkWaitAndAsyncConflict(*this))) + return failure(); for (mlir::Value operand : getDataClauseOperands()) if (!mlir::isa( @@ -2582,7 +2697,6 @@ unsigned UpdateOp::getNumDataOperands() { Value UpdateOp::getDataOperand(unsigned i) { unsigned numOptional = getAsyncOperands().size(); - numOptional += getWaitDevnum() ? 1 : 0; numOptional += getIfCond() ? 1 : 0; return getOperand(getWaitOperands().size() + numOptional + i); } @@ -2619,7 +2733,7 @@ bool UpdateOp::hasWaitOnly() { } bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { - return hasDeviceType(getWait(), deviceType); + return hasDeviceType(getWaitOnly(), deviceType); } mlir::Operation::operand_range UpdateOp::getWaitValues() { @@ -2628,8 +2742,19 @@ mlir::Operation::operand_range UpdateOp::getWaitValues() { mlir::Operation::operand_range UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) { - return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), - getWaitOperandsSegments(), deviceType); + return getWaitValuesWithoutDevnum( + getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), + getHasWaitDevnum(), deviceType); +} + +mlir::Value UpdateOp::getWaitDevnum() { + return getWaitDevnum(mlir::acc::DeviceType::None); +} + +mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { + return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), + getWaitOperandsSegments(), getHasWaitDevnum(), + deviceType); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir index 80d439f19d9f4..16df33eec642c 100644 --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -126,14 +126,6 @@ acc.update // ----- -%cst = arith.constant 1 : index -%value = memref.alloc() : memref -%0 = acc.update_device varPtr(%value : memref) -> memref -// expected-error@+1 {{wait_devnum cannot appear without waitOperands}} -acc.update wait_devnum(%cst: index) dataOperands(%0: memref) - -// ----- - %cst = arith.constant 1 : index %value = memref.alloc() : memref %0 = acc.update_device varPtr(%value : memref) -> memref @@ -146,7 +138,7 @@ acc.update async(%cst: index) dataOperands(%0 : memref) attributes {async = %value = memref.alloc() : memref %0 = acc.update_device varPtr(%value : memref) -> memref // expected-error@+1 {{wait attribute cannot appear with waitOperands}} -acc.update wait({%cst: index}) dataOperands(%0: memref) attributes {wait = [#acc.device_type]} +acc.update wait({%cst: index}) dataOperands(%0: memref) attributes {waitOnly = [#acc.device_type]} // ----- diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index 45b41f1a77225..4e6ed8645cdbc 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -802,7 +802,7 @@ func.func @testdataop(%a: memref, %b: memref, %c: memref) -> () { } attributes { defaultAttr = #acc, wait } %wd1 = arith.constant 1 : i64 - acc.data wait_devnum(%wd1 : i64) wait({%w1 : i64}) { + acc.data wait({devnum: %wd1 : i64, %w1 : i64}) { } attributes { defaultAttr = #acc, wait } return @@ -916,7 +916,7 @@ func.func @testdataop(%a: memref, %b: memref, %c: memref) -> () { // CHECK: acc.data wait({%{{.*}} : i64}) { // CHECK-NEXT: } attributes {defaultAttr = #acc, wait} -// CHECK: acc.data wait_devnum(%{{.*}} : i64) wait({%{{.*}} : i64}) { +// CHECK: acc.data wait({devnum: %{{.*}} : i64, %{{.*}} : i64}) { // CHECK-NEXT: } attributes {defaultAttr = #acc, wait} // ----- @@ -934,7 +934,7 @@ func.func @testupdateop(%a: memref, %b: memref, %c: memref) -> () acc.update async(%i32Value: i32) dataOperands(%0: memref) acc.update async(%i32Value: i32) dataOperands(%0: memref) acc.update async(%idxValue: index) dataOperands(%0: memref) - acc.update wait_devnum(%i64Value: i64) wait({%i32Value : i32, %idxValue : index}) dataOperands(%0: memref) + acc.update wait({devnum: %i64Value: i64, %i32Value : i32, %idxValue : index}) dataOperands(%0: memref) acc.update if(%ifCond) dataOperands(%0: memref) acc.update dataOperands(%0: memref) acc.update dataOperands(%0, %1, %2 : memref, memref, memref) @@ -953,7 +953,7 @@ func.func @testupdateop(%a: memref, %b: memref, %c: memref) -> () // CHECK: acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref) // CHECK: acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref) // CHECK: acc.update async([[IDXVALUE]] : index) dataOperands(%{{.*}} : memref) -// CHECK: acc.update wait_devnum([[I64VALUE]] : i64) wait({[[I32VALUE]] : i32, [[IDXVALUE]] : index}) dataOperands(%{{.*}} : memref) +// CHECK: acc.update wait({devnum: [[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) dataOperands(%{{.*}} : memref) // CHECK: acc.update if([[IFCOND]]) dataOperands(%{{.*}} : memref) // CHECK: acc.update dataOperands(%{{.*}} : memref) // CHECK: acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref, memref, memref) diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp index 474f887928992..452f39d8cae9f 100644 --- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp +++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp @@ -86,13 +86,13 @@ void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc, OwningOpRef val = b.create(loc, 1); auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia); - op->setAsyncDeviceTypeAttr(b.getArrayAttr({dtypeNvidia})); - op->getAsyncMutable().assign(val->getResult()); + op->setAsyncOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNvidia})); + op->getAsyncOperandsMutable().assign(val->getResult()); EXPECT_EQ(op->getAsyncValue(), empty); EXPECT_EQ(op->getAsyncValue(DeviceType::Nvidia), val->getResult()); - op->getAsyncMutable().clear(); - op->removeAsyncDeviceTypeAttr(); + op->getAsyncOperandsMutable().clear(); + op->removeAsyncOperandsDeviceTypeAttr(); } TEST_F(OpenACCOpsTest, asyncValueTest) { @@ -232,6 +232,8 @@ TEST_F(OpenACCOpsTest, waitOnlyTest) { testWaitOnly(b, context, loc, dtypes, dtypesWithoutNone); testWaitOnly(b, context, loc, dtypes, dtypesWithoutNone); testWaitOnly(b, context, loc, dtypes, dtypesWithoutNone); + testWaitOnly(b, context, loc, dtypes, dtypesWithoutNone); + testWaitOnly(b, context, loc, dtypes, dtypesWithoutNone); } template @@ -245,19 +247,23 @@ void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc, b.create(loc, 1); OwningOpRef val2 = b.create(loc, 4); + OwningOpRef val3 = + b.create(loc, 5); auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None); op->getWaitOperandsMutable().assign(val1->getResult()); op->setWaitOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNone})); op->setWaitOperandsSegments(b.getDenseI32ArrayAttr({1})); + op->setHasWaitDevnumAttr(b.getBoolArrayAttr({false})); EXPECT_EQ(op->getWaitValues().front(), val1->getResult()); for (auto d : dtypesWithoutNone) - EXPECT_EQ(op->getWaitValues(d).begin(), op->getWaitValues(d).end()); + EXPECT_TRUE(op->getWaitValues(d).empty()); op->getWaitOperandsMutable().clear(); op->removeWaitOperandsDeviceTypeAttr(); op->removeWaitOperandsSegmentsAttr(); + op->removeHasWaitDevnumAttr(); for (auto d : dtypes) - EXPECT_EQ(op->getWaitValues(d).begin(), op->getWaitValues(d).end()); + EXPECT_TRUE(op->getWaitValues(d).empty()); op->getWaitOperandsMutable().append(val1->getResult()); op->getWaitOperandsMutable().append(val2->getResult()); @@ -265,6 +271,7 @@ void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc, b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host), DeviceTypeAttr::get(&context, DeviceType::Star)})); op->setWaitOperandsSegments(b.getDenseI32ArrayAttr({1, 1})); + op->setHasWaitDevnumAttr(b.getBoolArrayAttr({false, false})); EXPECT_EQ(op->getWaitValues(DeviceType::None).begin(), op->getWaitValues(DeviceType::None).end()); EXPECT_EQ(op->getWaitValues(DeviceType::Host).front(), val1->getResult()); @@ -273,8 +280,9 @@ void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc, op->getWaitOperandsMutable().clear(); op->removeWaitOperandsDeviceTypeAttr(); op->removeWaitOperandsSegmentsAttr(); + op->removeHasWaitDevnumAttr(); for (auto d : dtypes) - EXPECT_EQ(op->getWaitValues(d).begin(), op->getWaitValues(d).end()); + EXPECT_TRUE(op->getWaitValues(d).empty()); op->getWaitOperandsMutable().append(val1->getResult()); op->getWaitOperandsMutable().append(val2->getResult()); @@ -283,6 +291,7 @@ void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc, b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Default), DeviceTypeAttr::get(&context, DeviceType::Multicore)})); op->setWaitOperandsSegments(b.getDenseI32ArrayAttr({2, 1})); + op->setHasWaitDevnumAttr(b.getBoolArrayAttr({false, false})); EXPECT_EQ(op->getWaitValues(DeviceType::None).begin(), op->getWaitValues(DeviceType::None).end()); EXPECT_EQ(op->getWaitValues(DeviceType::Default).front(), val1->getResult()); @@ -294,6 +303,28 @@ void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc, op->getWaitOperandsMutable().clear(); op->removeWaitOperandsDeviceTypeAttr(); op->removeWaitOperandsSegmentsAttr(); + + op->getWaitOperandsMutable().append(val3->getResult()); + op->getWaitOperandsMutable().append(val2->getResult()); + op->getWaitOperandsMutable().append(val1->getResult()); + op->setWaitOperandsDeviceTypeAttr( + b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Multicore)})); + op->setHasWaitDevnumAttr(b.getBoolArrayAttr({true})); + op->setWaitOperandsSegments(b.getDenseI32ArrayAttr({3})); + EXPECT_EQ(op->getWaitValues(DeviceType::None).begin(), + op->getWaitValues(DeviceType::None).end()); + EXPECT_FALSE(op->getWaitDevnum()); + + EXPECT_EQ(op->getWaitDevnum(DeviceType::Multicore), val3->getResult()); + EXPECT_EQ(op->getWaitValues(DeviceType::Multicore).front(), + val2->getResult()); + EXPECT_EQ(op->getWaitValues(DeviceType::Multicore).drop_front().front(), + val1->getResult()); + + op->getWaitOperandsMutable().clear(); + op->removeWaitOperandsDeviceTypeAttr(); + op->removeWaitOperandsSegmentsAttr(); + op->removeHasWaitDevnumAttr(); } TEST_F(OpenACCOpsTest, waitValuesTest) {