diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index fae54eefb02f7..59db5ab71b702 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) { case Fortran::parser::AccDeviceTypeExpr::Device::Multicore: return mlir::acc::DeviceType::Multicore; } - return mlir::acc::DeviceType::Default; + return mlir::acc::DeviceType::None; } static void gatherDeviceTypeAttrs( @@ -1781,26 +1781,25 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, bool outerCombined = false) { // Parallel operation operands - mlir::Value async; - mlir::Value numWorkers; - mlir::Value vectorLength; mlir::Value ifCond; mlir::Value selfCond; mlir::Value waitDevnum; llvm::SmallVector waitOperands, attachEntryOperands, copyEntryOperands, copyoutEntryOperands, createEntryOperands, - dataClauseOperands, numGangs; + dataClauseOperands, numGangs, numWorkers, vectorLength, async; + llvm::SmallVector numGangsDeviceTypes, numWorkersDeviceTypes, + vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes, + waitOperandsDeviceTypes, waitOnlyDeviceTypes; + llvm::SmallVector numGangsSegments, waitOperandsSegments; llvm::SmallVector reductionOperands, privateOperands, firstprivateOperands; llvm::SmallVector privatizations, firstPrivatizations, reductionRecipes; - // Async, wait and self clause have optional values but can be present with + // Self clause has optional values but can be present with // no value as well. When there is no value, the op has an attribute to // represent the clause. - bool addAsyncAttr = false; - bool addWaitAttr = false; bool addSelfAttr = false; bool hasDefaultNone = false; @@ -1808,6 +1807,11 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + // device_type attribute is set to `none` until a device_type clause is + // encountered. + auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( + builder.getContext(), mlir::acc::DeviceType::None); + // Lower clauses values mapped to operands. // Keep track of each group of operands separatly as clauses can appear // more than once. @@ -1815,27 +1819,52 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, mlir::Location clauseLocation = converter.genLocation(clause.source); if (const auto *asyncClause = std::get_if(&clause.u)) { - genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx); + const auto &asyncClauseValue = asyncClause->v; + if (asyncClauseValue) { // async has a value. + async.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx))); + asyncDeviceTypes.push_back(crtDeviceTypeAttr); + } else { + asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr); + } } else if (const auto *waitClause = std::get_if(&clause.u)) { - genWaitClause(converter, waitClause, waitOperands, waitDevnum, - addWaitAttr, stmtCtx); + const auto &waitClauseValue = waitClause->v; + if (waitClauseValue) { // wait has a value. + const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; + const auto &waitList = + std::get>(waitArg.t); + auto crtWaitOperands = waitOperands.size(); + for (const Fortran::parser::ScalarIntExpr &value : waitList) { + waitOperands.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(value), stmtCtx))); + } + waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr); + waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands); + } else { + waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr); + } } else if (const auto *numGangsClause = std::get_if( &clause.u)) { + auto crtNumGangs = numGangs.size(); for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v) numGangs.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(expr), stmtCtx))); + numGangsDeviceTypes.push_back(crtDeviceTypeAttr); + numGangsSegments.push_back(numGangs.size() - crtNumGangs); } else if (const auto *numWorkersClause = std::get_if( &clause.u)) { - numWorkers = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)); + numWorkers.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx))); + numWorkersDeviceTypes.push_back(crtDeviceTypeAttr); } else if (const auto *vectorLengthClause = std::get_if( &clause.u)) { - vectorLength = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)); + vectorLength.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx))); + vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr); } else if (const auto *ifClause = std::get_if(&clause.u)) { genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx); @@ -1986,18 +2015,27 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, else if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_present) hasDefaultPresent = true; + } else if (const auto *deviceTypeClause = + std::get_if( + &clause.u)) { + const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList = + deviceTypeClause->v; + assert(deviceTypeExprList.v.size() == 1 && + "expect only one device_type expr"); + crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( + builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v)); } } // Prepare the operand segment size attribute and the operands value range. llvm::SmallVector operands; llvm::SmallVector operandSegments; - addOperand(operands, operandSegments, async); + addOperands(operands, operandSegments, async); addOperands(operands, operandSegments, waitOperands); if constexpr (!std::is_same_v) { addOperands(operands, operandSegments, numGangs); - addOperand(operands, operandSegments, numWorkers); - addOperand(operands, operandSegments, vectorLength); + addOperands(operands, operandSegments, numWorkers); + addOperands(operands, operandSegments, vectorLength); } addOperand(operands, operandSegments, ifCond); addOperand(operands, operandSegments, selfCond); @@ -2018,10 +2056,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, builder, currentLocation, eval, operands, operandSegments, outerCombined); - if (addAsyncAttr) - computeOp.setAsyncAttrAttr(builder.getUnitAttr()); - if (addWaitAttr) - computeOp.setWaitAttrAttr(builder.getUnitAttr()); if (addSelfAttr) computeOp.setSelfAttrAttr(builder.getUnitAttr()); @@ -2030,6 +2064,34 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, if (hasDefaultPresent) computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present); + if constexpr (!std::is_same_v) { + if (!numWorkersDeviceTypes.empty()) + computeOp.setNumWorkersDeviceTypeAttr( + mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes)); + if (!vectorLengthDeviceTypes.empty()) + computeOp.setVectorLengthDeviceTypeAttr( + mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes)); + if (!numGangsDeviceTypes.empty()) + computeOp.setNumGangsDeviceTypeAttr( + mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes)); + if (!numGangsSegments.empty()) + computeOp.setNumGangsSegmentsAttr( + builder.getDenseI32ArrayAttr(numGangsSegments)); + } + if (!asyncDeviceTypes.empty()) + computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes)); + if (!asyncOnlyDeviceTypes.empty()) + computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes)); + + if (!waitOperandsDeviceTypes.empty()) + computeOp.setWaitOperandsDeviceTypeAttr( + builder.getArrayAttr(waitOperandsDeviceTypes)); + if (!waitOperandsSegments.empty()) + computeOp.setWaitOperandsSegmentsAttr( + builder.getDenseI32ArrayAttr(waitOperandsSegments)); + if (!waitOnlyDeviceTypes.empty()) + computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes)); + if constexpr (!std::is_same_v) { if (!privatizations.empty()) computeOp.setPrivatizationsAttr( diff --git a/flang/test/Lower/OpenACC/acc-device-type.f90 b/flang/test/Lower/OpenACC/acc-device-type.f90 new file mode 100644 index 0000000000000..871dbc95f60fc --- /dev/null +++ b/flang/test/Lower/OpenACC/acc-device-type.f90 @@ -0,0 +1,44 @@ +! This test checks lowering of OpenACC device_type clause on directive where its +! position and the clauses that follow have special semantic + +! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s + +subroutine sub1() + + !$acc parallel num_workers(16) + !$acc end parallel + +! CHECK: acc.parallel num_workers(%c16{{.*}} : i32) { + + !$acc parallel num_workers(1) device_type(nvidia) num_workers(16) + !$acc end parallel + +! CHECK: acc.parallel num_workers(%c1{{.*}} : i32, %c16{{.*}} : i32 [#acc.device_type]) + + !$acc parallel device_type(*) num_workers(1) device_type(nvidia) num_workers(16) + !$acc end parallel + +! CHECK: acc.parallel num_workers(%c1{{.*}} : i32 [#acc.device_type], %c16{{.*}} : i32 [#acc.device_type]) + + !$acc parallel vector_length(1) + !$acc end parallel + +! CHECK: acc.parallel vector_length(%c1{{.*}} : i32) + + !$acc parallel device_type(multicore) vector_length(1) + !$acc end parallel + +! CHECK: acc.parallel vector_length(%c1{{.*}} : i32 [#acc.device_type]) + + !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(4) + !$acc end parallel + +! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c4{{.*}} : i32} [#acc.device_type]) + + !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(1, 1, 1) + !$acc end parallel + +! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type]) + + +end subroutine diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90 index 34e7232697241..93bc699031d55 100644 --- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90 @@ -62,7 +62,7 @@ subroutine acc_kernels_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.terminator -! CHECK-NEXT: } attributes {asyncAttr} +! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type]} !$acc kernels loop async(1) DO i = 1, n @@ -103,7 +103,7 @@ subroutine acc_kernels_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.terminator -! CHECK-NEXT: } attributes {waitAttr} +! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} !$acc kernels loop wait(1) DO i = 1, n @@ -111,7 +111,7 @@ subroutine acc_kernels_loop END DO ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.kernels wait([[WAIT1]] : i32) { +! CHECK: acc.kernels wait({[[WAIT1]] : i32}) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -126,7 +126,7 @@ subroutine acc_kernels_loop ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32 -! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) { +! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -141,7 +141,7 @@ subroutine acc_kernels_loop ! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) { +! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -155,7 +155,7 @@ subroutine acc_kernels_loop END DO ! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) { +! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -169,7 +169,7 @@ subroutine acc_kernels_loop END DO ! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) { +! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90 index 1f882c6df5106..99629bb835172 100644 --- a/flang/test/Lower/OpenACC/acc-kernels.f90 +++ b/flang/test/Lower/OpenACC/acc-kernels.f90 @@ -40,7 +40,7 @@ subroutine acc_kernels ! CHECK: acc.kernels { ! CHECK: acc.terminator -! CHECK-NEXT: } attributes {asyncAttr} +! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type]} !$acc kernels async(1) !$acc end kernels @@ -63,13 +63,13 @@ subroutine acc_kernels ! CHECK: acc.kernels { ! CHECK: acc.terminator -! CHECK-NEXT: } attributes {waitAttr} +! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} !$acc kernels wait(1) !$acc end kernels ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.kernels wait([[WAIT1]] : i32) { +! CHECK: acc.kernels wait({[[WAIT1]] : i32}) { ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -78,7 +78,7 @@ subroutine acc_kernels ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32 -! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) { +! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) { ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -87,7 +87,7 @@ subroutine acc_kernels ! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) { +! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) { ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -95,7 +95,7 @@ subroutine acc_kernels !$acc end kernels ! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) { +! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) { ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -103,7 +103,7 @@ subroutine acc_kernels !$acc end kernels ! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) { +! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) { ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90 index 1856215ce59d1..deee7089033ea 100644 --- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90 @@ -64,7 +64,7 @@ subroutine acc_parallel_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.yield -! CHECK-NEXT: } attributes {asyncAttr} +! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type]} !$acc parallel loop async(1) DO i = 1, n @@ -105,7 +105,7 @@ subroutine acc_parallel_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.yield -! CHECK-NEXT: } attributes {waitAttr} +! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} !$acc parallel loop wait(1) DO i = 1, n @@ -113,7 +113,7 @@ subroutine acc_parallel_loop END DO ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.parallel wait([[WAIT1]] : i32) { +! CHECK: acc.parallel wait({[[WAIT1]] : i32}) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -128,7 +128,7 @@ subroutine acc_parallel_loop ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32 -! CHECK: acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) { +! CHECK: acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -143,7 +143,7 @@ subroutine acc_parallel_loop ! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) { +! CHECK: acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -157,7 +157,7 @@ subroutine acc_parallel_loop END DO ! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.parallel num_gangs([[NUMGANGS1]] : i32) { +! CHECK: acc.parallel num_gangs({[[NUMGANGS1]] : i32}) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -171,7 +171,7 @@ subroutine acc_parallel_loop END DO ! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.parallel num_gangs([[NUMGANGS2]] : i32) { +! CHECK: acc.parallel num_gangs({[[NUMGANGS2]] : i32}) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90 index bbf51ba36a7de..a369bf01f2599 100644 --- a/flang/test/Lower/OpenACC/acc-parallel.f90 +++ b/flang/test/Lower/OpenACC/acc-parallel.f90 @@ -62,7 +62,7 @@ subroutine acc_parallel ! CHECK: acc.parallel { ! CHECK: acc.yield -! CHECK-NEXT: } attributes {asyncAttr} +! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type]} !$acc parallel async(1) !$acc end parallel @@ -85,13 +85,13 @@ subroutine acc_parallel ! CHECK: acc.parallel { ! CHECK: acc.yield -! CHECK-NEXT: } attributes {waitAttr} +! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} !$acc parallel wait(1) !$acc end parallel ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.parallel wait([[WAIT1]] : i32) { +! CHECK: acc.parallel wait({[[WAIT1]] : i32}) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -100,7 +100,7 @@ subroutine acc_parallel ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32 -! CHECK: acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) { +! CHECK: acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -109,7 +109,7 @@ subroutine acc_parallel ! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) { +! CHECK: acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -117,7 +117,7 @@ subroutine acc_parallel !$acc end parallel ! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.parallel num_gangs([[NUMGANGS1]] : i32) { +! CHECK: acc.parallel num_gangs({[[NUMGANGS1]] : i32}) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -125,14 +125,14 @@ subroutine acc_parallel !$acc end parallel ! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.parallel num_gangs([[NUMGANGS2]] : i32) { +! CHECK: acc.parallel num_gangs({[[NUMGANGS2]] : i32}) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} !$acc parallel num_gangs(1, 1, 1) !$acc end parallel -! CHECK: acc.parallel num_gangs(%{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32) { +! CHECK: acc.parallel num_gangs({%{{.*}} : i32, %{{.*}} : i32, %{{.*}} : i32}) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90 index 4ed7bb8da29a1..712bfc80ce387 100644 --- a/flang/test/Lower/OpenACC/acc-serial-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90 @@ -83,7 +83,7 @@ subroutine acc_serial_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.yield -! CHECK-NEXT: } attributes {asyncAttr} +! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type]} !$acc serial loop async(1) DO i = 1, n @@ -124,7 +124,7 @@ subroutine acc_serial_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.yield -! CHECK-NEXT: } attributes {waitAttr} +! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} !$acc serial loop wait(1) DO i = 1, n @@ -132,7 +132,7 @@ subroutine acc_serial_loop END DO ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.serial wait([[WAIT1]] : i32) { +! CHECK: acc.serial wait({[[WAIT1]] : i32}) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -147,7 +147,7 @@ subroutine acc_serial_loop ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32 -! CHECK: acc.serial wait([[WAIT2]], [[WAIT3]] : i32, i32) { +! CHECK: acc.serial wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -162,7 +162,7 @@ subroutine acc_serial_loop ! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.serial wait([[WAIT4]], [[WAIT5]] : i32, i32) { +! CHECK: acc.serial wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield diff --git a/flang/test/Lower/OpenACC/acc-serial.f90 b/flang/test/Lower/OpenACC/acc-serial.f90 index ab3b0ccd54595..d05e51d3d274f 100644 --- a/flang/test/Lower/OpenACC/acc-serial.f90 +++ b/flang/test/Lower/OpenACC/acc-serial.f90 @@ -62,7 +62,7 @@ subroutine acc_serial ! CHECK: acc.serial { ! CHECK: acc.yield -! CHECK-NEXT: } attributes {asyncAttr} +! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type]} !$acc serial async(1) !$acc end serial @@ -85,13 +85,13 @@ subroutine acc_serial ! CHECK: acc.serial { ! CHECK: acc.yield -! CHECK-NEXT: } attributes {waitAttr} +! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} !$acc serial wait(1) !$acc end serial ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.serial wait([[WAIT1]] : i32) { +! CHECK: acc.serial wait({[[WAIT1]] : i32}) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -100,7 +100,7 @@ subroutine acc_serial ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32 -! CHECK: acc.serial wait([[WAIT2]], [[WAIT3]] : i32, i32) { +! CHECK: acc.serial wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -109,7 +109,7 @@ subroutine acc_serial ! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.serial wait([[WAIT4]], [[WAIT5]] : i32, i32) { +! CHECK: acc.serial wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index a78c3e98c9551..234c1076e14e3 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -156,29 +156,46 @@ def DeclareActionAttr : OpenACC_Attr<"DeclareAction", "declare_action"> { } // Device type enumeration. -def OpenACC_DeviceTypeStar : I32EnumAttrCase<"Star", 0, "star">; -def OpenACC_DeviceTypeDefault : I32EnumAttrCase<"Default", 1, "default">; -def OpenACC_DeviceTypeHost : I32EnumAttrCase<"Host", 2, "host">; -def OpenACC_DeviceTypeMulticore : I32EnumAttrCase<"Multicore", 3, "multicore">; -def OpenACC_DeviceTypeNvidia : I32EnumAttrCase<"Nvidia", 4, "nvidia">; -def OpenACC_DeviceTypeRadeon : I32EnumAttrCase<"Radeon", 5, "radeon">; - +def OpenACC_DeviceTypeNone : I32EnumAttrCase<"None", 0, "none">; +def OpenACC_DeviceTypeStar : I32EnumAttrCase<"Star", 1, "star">; +def OpenACC_DeviceTypeDefault : I32EnumAttrCase<"Default", 2, "default">; +def OpenACC_DeviceTypeHost : I32EnumAttrCase<"Host", 3, "host">; +def OpenACC_DeviceTypeMulticore : I32EnumAttrCase<"Multicore", 4, "multicore">; +def OpenACC_DeviceTypeNvidia : I32EnumAttrCase<"Nvidia", 5, "nvidia">; +def OpenACC_DeviceTypeRadeon : I32EnumAttrCase<"Radeon", 6, "radeon">; def OpenACC_DeviceType : I32EnumAttr<"DeviceType", "built-in device type supported by OpenACC", - [OpenACC_DeviceTypeStar, OpenACC_DeviceTypeDefault, + [OpenACC_DeviceTypeNone, OpenACC_DeviceTypeStar, OpenACC_DeviceTypeDefault, OpenACC_DeviceTypeHost, OpenACC_DeviceTypeMulticore, OpenACC_DeviceTypeNvidia, OpenACC_DeviceTypeRadeon ]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::acc"; } + +// Device type attribute is used to associate a value for for clauses that +// appear after a device_type clause. The list of clauses allowed after the +// device_type clause is defined per construct as follows: +// Loop construct: collapse, gang, worker, vector, seq, independent, auto, +// and tile +// Compute construct: async, wait, num_gangs, num_workers, and vector_length +// Data construct: async and wait +// Routine: gang, worker, vector, seq and bind +// +// The `none` means that the value appears before any device_type clause. +// def OpenACC_DeviceTypeAttr : EnumAttr { let assemblyFormat = [{ ```<` $value `>` }]; } +def DeviceTypeArrayAttr : + TypedArrayAttrBase { + let constBuilderCall = ?; +} + // Define a resource for the OpenACC runtime counters. def OpenACC_RuntimeCounters : Resource<"::mlir::acc::RuntimeCounters">; @@ -863,24 +880,32 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", ``` }]; - let arguments = (ins Optional:$async, - UnitAttr:$asyncAttr, - Variadic:$waitOperands, - UnitAttr:$waitAttr, - Variadic:$numGangs, - Optional:$numWorkers, - Optional:$vectorLength, - Optional:$ifCond, - Optional:$selfCond, - UnitAttr:$selfAttr, - Variadic:$reductionOperands, - OptionalAttr:$reductionRecipes, - Variadic:$gangPrivateOperands, - OptionalAttr:$privatizations, - Variadic:$gangFirstPrivateOperands, - OptionalAttr:$firstprivatizations, - Variadic:$dataClauseOperands, - OptionalAttr:$defaultAttr); + let arguments = (ins + Variadic:$async, + OptionalAttr:$asyncDeviceType, + OptionalAttr:$asyncOnly, + Variadic:$waitOperands, + OptionalAttr:$waitOperandsSegments, + OptionalAttr:$waitOperandsDeviceType, + OptionalAttr:$waitOnly, + Variadic:$numGangs, + OptionalAttr:$numGangsSegments, + OptionalAttr:$numGangsDeviceType, + Variadic:$numWorkers, + OptionalAttr:$numWorkersDeviceType, + Variadic:$vectorLength, + OptionalAttr:$vectorLengthDeviceType, + Optional:$ifCond, + Optional:$selfCond, + UnitAttr:$selfAttr, + Variadic:$reductionOperands, + OptionalAttr:$reductionRecipes, + Variadic:$gangPrivateOperands, + OptionalAttr:$privatizations, + Variadic:$gangFirstPrivateOperands, + OptionalAttr:$firstprivatizations, + Variadic:$dataClauseOperands, + OptionalAttr:$defaultAttr); let regions = (region AnyRegion:$region); @@ -890,22 +915,69 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", /// The i-th data operand passed. Value getDataOperand(unsigned i); + + /// Return true if the op has the async attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasAsyncOnly(); + /// Return true if the op has the async attribute for the given device_type. + bool hasAsyncOnly(mlir::acc::DeviceType deviceType); + /// Return the value of the async clause if present. + mlir::Value getAsyncValue(); + /// Return the value of the async clause for the given device_type if + /// present. + mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType); + + /// Return the value of the num_workers clause if present. + mlir::Value getNumWorkersValue(); + /// Return the value of the num_workers clause for the given device_type if + /// present. + mlir::Value getNumWorkersValue(mlir::acc::DeviceType deviceType); + + /// Return the value of the vector_length clause if present. + mlir::Value getVectorLengthValue(); + /// Return the value of the vector_length clause for the given device_type + /// if present. + mlir::Value getVectorLengthValue(mlir::acc::DeviceType deviceType); + + /// Return the values of the num_gangs clause if present. + mlir::Operation::operand_range getNumGangsValues(); + /// Return the values of the num_gangs clause for the given device_type if + /// present. + mlir::Operation::operand_range + getNumGangsValues(mlir::acc::DeviceType deviceType); + + /// Return true if the op has the wait attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasWaitOnly(); + /// Return true if the op has the wait attribute for the given device_type. + bool hasWaitOnly(mlir::acc::DeviceType deviceType); + /// Return the values of the wait clause if present. + mlir::Operation::operand_range getWaitValues(); + /// Return the values of the wait clause for the given device_type if + /// present. + mlir::Operation::operand_range + getWaitValues(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ oilist( `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` - | `async` `(` $async `:` type($async) `)` + | `async` `(` custom($async, + type($async), $asyncDeviceType) `)` | `firstprivate` `(` custom($gangFirstPrivateOperands, type($gangFirstPrivateOperands), $firstprivatizations) `)` - | `num_gangs` `(` $numGangs `:` type($numGangs) `)` - | `num_workers` `(` $numWorkers `:` type($numWorkers) `)` + | `num_gangs` `(` custom($numGangs, + type($numGangs), $numGangsDeviceType, $numGangsSegments) `)` + | `num_workers` `(` custom($numWorkers, + type($numWorkers), $numWorkersDeviceType) `)` | `private` `(` custom( $gangPrivateOperands, type($gangPrivateOperands), $privatizations) `)` - | `vector_length` `(` $vectorLength `:` type($vectorLength) `)` - | `wait` `(` $waitOperands `:` type($waitOperands) `)` + | `vector_length` `(` custom($vectorLength, + type($vectorLength), $vectorLengthDeviceType) `)` + | `wait` `(` custom($waitOperands, + type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)` | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` | `reduction` `(` custom( @@ -939,21 +1011,25 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", ``` }]; - let arguments = (ins Optional:$async, - UnitAttr:$asyncAttr, - Variadic:$waitOperands, - UnitAttr:$waitAttr, - Optional:$ifCond, - Optional:$selfCond, - UnitAttr:$selfAttr, - Variadic:$reductionOperands, - OptionalAttr:$reductionRecipes, - Variadic:$gangPrivateOperands, - OptionalAttr:$privatizations, - Variadic:$gangFirstPrivateOperands, - OptionalAttr:$firstprivatizations, - Variadic:$dataClauseOperands, - OptionalAttr:$defaultAttr); + let arguments = (ins + Variadic:$async, + OptionalAttr:$asyncDeviceType, + OptionalAttr:$asyncOnly, + Variadic:$waitOperands, + OptionalAttr:$waitOperandsSegments, + OptionalAttr:$waitOperandsDeviceType, + OptionalAttr:$waitOnly, + Optional:$ifCond, + Optional:$selfCond, + UnitAttr:$selfAttr, + Variadic:$reductionOperands, + OptionalAttr:$reductionRecipes, + Variadic:$gangPrivateOperands, + OptionalAttr:$privatizations, + Variadic:$gangFirstPrivateOperands, + OptionalAttr:$firstprivatizations, + Variadic:$dataClauseOperands, + OptionalAttr:$defaultAttr); let regions = (region AnyRegion:$region); @@ -963,19 +1039,44 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", /// The i-th data operand passed. Value getDataOperand(unsigned i); + + /// Return true if the op has the async attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasAsyncOnly(); + /// Return true if the op has the async attribute for the given device_type. + bool hasAsyncOnly(mlir::acc::DeviceType deviceType); + /// Return the value of the async clause if present. + mlir::Value getAsyncValue(); + /// Return the value of the async clause for the given device_type if + /// present. + mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType); + + /// Return true if the op has the wait attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasWaitOnly(); + /// Return true if the op has the wait attribute for the given device_type. + bool hasWaitOnly(mlir::acc::DeviceType deviceType); + /// Return the values of the wait clause if present. + mlir::Operation::operand_range getWaitValues(); + /// Return the values of the wait clause for the given device_type if + /// present. + mlir::Operation::operand_range + getWaitValues(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ oilist( `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` - | `async` `(` $async `:` type($async) `)` + | `async` `(` custom($async, + type($async), $asyncDeviceType) `)` | `firstprivate` `(` custom($gangFirstPrivateOperands, type($gangFirstPrivateOperands), $firstprivatizations) `)` | `private` `(` custom( $gangPrivateOperands, type($gangPrivateOperands), $privatizations) `)` - | `wait` `(` $waitOperands `:` type($waitOperands) `)` + | `wait` `(` custom($waitOperands, + type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)` | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` | `reduction` `(` custom( @@ -1011,18 +1112,26 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", ``` }]; - let arguments = (ins Optional:$async, - UnitAttr:$asyncAttr, - Variadic:$waitOperands, - UnitAttr:$waitAttr, - Variadic:$numGangs, - Optional:$numWorkers, - Optional:$vectorLength, - Optional:$ifCond, - Optional:$selfCond, - UnitAttr:$selfAttr, - Variadic:$dataClauseOperands, - OptionalAttr:$defaultAttr); + let arguments = (ins + Variadic:$async, + OptionalAttr:$asyncDeviceType, + OptionalAttr:$asyncOnly, + Variadic:$waitOperands, + OptionalAttr:$waitOperandsSegments, + OptionalAttr:$waitOperandsDeviceType, + OptionalAttr:$waitOnly, + Variadic:$numGangs, + OptionalAttr:$numGangsSegments, + OptionalAttr:$numGangsDeviceType, + Variadic:$numWorkers, + OptionalAttr:$numWorkersDeviceType, + Variadic:$vectorLength, + OptionalAttr:$vectorLengthDeviceType, + Optional:$ifCond, + Optional:$selfCond, + UnitAttr:$selfAttr, + Variadic:$dataClauseOperands, + OptionalAttr:$defaultAttr); let regions = (region AnyRegion:$region); @@ -1032,16 +1141,63 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", /// The i-th data operand passed. Value getDataOperand(unsigned i); + + /// Return true if the op has the async attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasAsyncOnly(); + /// Return true if the op has the async attribute for the given device_type. + bool hasAsyncOnly(mlir::acc::DeviceType deviceType); + /// Return the value of the async clause if present. + mlir::Value getAsyncValue(); + /// Return the value of the async clause for the given device_type if + /// present. + mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType); + + /// Return the value of the num_workers clause if present. + mlir::Value getNumWorkersValue(); + /// Return the value of the num_workers clause for the given device_type if + /// present. + mlir::Value getNumWorkersValue(mlir::acc::DeviceType deviceType); + + /// Return the value of the vector_length clause if present. + mlir::Value getVectorLengthValue(); + /// Return the value of the vector_length clause for the given device_type + /// if present. + mlir::Value getVectorLengthValue(mlir::acc::DeviceType deviceType); + + /// Return the values of the num_gangs clause if present. + mlir::Operation::operand_range getNumGangsValues(); + /// Return the values of the num_gangs clause for the given device_type if + /// present. + mlir::Operation::operand_range + getNumGangsValues(mlir::acc::DeviceType deviceType); + + /// Return true if the op has the wait attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasWaitOnly(); + /// Return true if the op has the wait attribute for the given device_type. + bool hasWaitOnly(mlir::acc::DeviceType deviceType); + /// Return the values of the wait clause if present. + mlir::Operation::operand_range getWaitValues(); + /// Return the values of the wait clause for the given device_type if + /// present. + mlir::Operation::operand_range + getWaitValues(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ oilist( `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` - | `async` `(` $async `:` type($async) `)` - | `num_gangs` `(` $numGangs `:` type($numGangs) `)` - | `num_workers` `(` $numWorkers `:` type($numWorkers) `)` - | `vector_length` `(` $vectorLength `:` type($vectorLength) `)` - | `wait` `(` $waitOperands `:` type($waitOperands) `)` + | `async` `(` custom($async, + type($async), $asyncDeviceType) `)` + | `num_gangs` `(` custom($numGangs, + type($numGangs), $numGangsDeviceType, $numGangsSegments) `)` + | `num_workers` `(` custom($numWorkers, + type($numWorkers), $numWorkersDeviceType) `)` + | `vector_length` `(` custom($vectorLength, + type($vectorLength), $vectorLengthDeviceType) `)` + | `wait` `(` custom($waitOperands, + type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)` | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` ) diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index df4f7825545c2..45e0632db5ef2 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -615,15 +615,49 @@ unsigned ParallelOp::getNumDataOperands() { } Value ParallelOp::getDataOperand(unsigned i) { - unsigned numOptional = getAsync() ? 1 : 0; + unsigned numOptional = getAsync().size(); numOptional += getNumGangs().size(); - numOptional += getNumWorkers() ? 1 : 0; - numOptional += getVectorLength() ? 1 : 0; + numOptional += getNumWorkers().size(); + numOptional += getVectorLength().size(); numOptional += getIfCond() ? 1 : 0; numOptional += getSelfCond() ? 1 : 0; return getOperand(getWaitOperands().size() + numOptional + i); } +template +static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, + ArrayAttr deviceTypes, + llvm::StringRef keyword) { + if (operands.size() > 0 && deviceTypes.getValue().size() != operands.size()) + return op.emitOpError() << keyword << " operands count must match " + << keyword << " device_type count"; + return success(); +} + +template +static LogicalResult verifyDeviceTypeAndSegmentCountMatch( + Op op, OperandRange operands, DenseI32ArrayAttr segments, + ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) { + std::size_t numOperandsInSegments = 0; + + if (!segments) + return success(); + + for (auto segCount : segments.asArrayRef()) { + if (maxInSegment != 0 && segCount > maxInSegment) + return op.emitOpError() << keyword << " expects a maximum of " + << maxInSegment << " values per segment"; + numOperandsInSegments += segCount; + } + if (numOperandsInSegments != operands.size()) + return op.emitOpError() + << keyword << " operand count does not match count in segments"; + if (deviceTypes.getValue().size() != (size_t)segments.size()) + return op.emitOpError() + << keyword << " segment count does not match device_type count"; + return success(); +} + LogicalResult acc::ParallelOp::verify() { if (failed(checkSymOperandList( *this, getPrivatizations(), getGangPrivateOperands(), "private", @@ -633,11 +667,322 @@ LogicalResult acc::ParallelOp::verify() { *this, getReductionRecipes(), getReductionOperands(), "reduction", "reductions", false))) return failure(); - if (getNumGangs().size() > 3) - return emitOpError() << "num_gangs expects a maximum of 3 values"; + + if (failed(verifyDeviceTypeAndSegmentCountMatch( + *this, getNumGangs(), getNumGangsSegmentsAttr(), + getNumGangsDeviceTypeAttr(), "num_gangs", 3))) + return failure(); + + if (failed(verifyDeviceTypeAndSegmentCountMatch( + *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), + getWaitOperandsDeviceTypeAttr(), "wait"))) + return failure(); + + if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(), + getNumWorkersDeviceTypeAttr(), + "num_workers"))) + return failure(); + + if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(), + getVectorLengthDeviceTypeAttr(), + "vector_length"))) + return failure(); + + if (failed(verifyDeviceTypeCountMatch(*this, getAsync(), + getAsyncDeviceTypeAttr(), "async"))) + return failure(); + return checkDataOperands(*this, getDataClauseOperands()); } +static std::optional findSegment(ArrayAttr segments, + mlir::acc::DeviceType deviceType) { + unsigned segmentIdx = 0; + for (auto attr : segments) { + auto deviceTypeAttr = mlir::dyn_cast(attr); + if (deviceTypeAttr.getValue() == deviceType) + return std::make_optional(segmentIdx); + ++segmentIdx; + } + return std::nullopt; +} + +static mlir::Value +getValueInDeviceTypeSegment(std::optional arrayAttr, + mlir::Operation::operand_range range, + mlir::acc::DeviceType deviceType) { + if (!arrayAttr) + return {}; + if (auto pos = findSegment(*arrayAttr, deviceType)) + return range[*pos]; + return {}; +} + +bool acc::ParallelOp::hasAsyncOnly() { + return hasAsyncOnly(mlir::acc::DeviceType::None); +} + +bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { + if (auto arrayAttr = getAsyncOnly()) { + if (findSegment(*arrayAttr, deviceType)) + return true; + } + return false; +} + +mlir::Value acc::ParallelOp::getAsyncValue() { + return getAsyncValue(mlir::acc::DeviceType::None); +} + +mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) { + return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), + deviceType); +} + +mlir::Value acc::ParallelOp::getNumWorkersValue() { + return getNumWorkersValue(mlir::acc::DeviceType::None); +} + +mlir::Value +acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) { + return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(), + deviceType); +} + +mlir::Value acc::ParallelOp::getVectorLengthValue() { + return getVectorLengthValue(mlir::acc::DeviceType::None); +} + +mlir::Value +acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) { + return getValueInDeviceTypeSegment(getVectorLengthDeviceType(), + getVectorLength(), deviceType); +} + +mlir::Operation::operand_range ParallelOp::getNumGangsValues() { + return getNumGangsValues(mlir::acc::DeviceType::None); +} + +static mlir::Operation::operand_range +getValuesFromSegments(std::optional arrayAttr, + mlir::Operation::operand_range range, + std::optional> segments, + mlir::acc::DeviceType deviceType) { + if (!arrayAttr) + return range.take_front(0); + if (auto pos = findSegment(*arrayAttr, deviceType)) { + int32_t nbOperandsBefore = 0; + for (unsigned i = 0; i < *pos; ++i) + nbOperandsBefore += (*segments)[i]; + return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]); + } + return range.take_front(0); +} + +mlir::Operation::operand_range +ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) { + return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(), + getNumGangsSegments(), deviceType); +} + +bool acc::ParallelOp::hasWaitOnly() { + return hasWaitOnly(mlir::acc::DeviceType::None); +} + +bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { + if (auto arrayAttr = getWaitOnly()) { + if (findSegment(*arrayAttr, deviceType)) + return true; + } + return false; +} + +mlir::Operation::operand_range ParallelOp::getWaitValues() { + return getWaitValues(mlir::acc::DeviceType::None); +} + +mlir::Operation::operand_range +ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) { + return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), + getWaitOperandsSegments(), deviceType); +} + +static ParseResult parseNumGangs( + mlir::OpAsmParser &parser, + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &types, mlir::ArrayAttr &deviceTypes, + mlir::DenseI32ArrayAttr &segments) { + llvm::SmallVector attributes; + llvm::SmallVector seg; + + do { + if (failed(parser.parseLBrace())) + return failure(); + + if (failed(parser.parseCommaSeparatedList( + mlir::AsmParser::Delimiter::None, [&]() { + if (parser.parseOperand(operands.emplace_back()) || + parser.parseColonType(types.emplace_back())) + return failure(); + return success(); + }))) + return failure(); + + seg.push_back(operands.size()); + + if (failed(parser.parseRBrace())) + return failure(); + + if (succeeded(parser.parseOptionalLSquare())) { + if (parser.parseAttribute(attributes.emplace_back()) || + parser.parseRSquare()) + return failure(); + } else { + attributes.push_back(mlir::acc::DeviceTypeAttr::get( + parser.getContext(), mlir::acc::DeviceType::None)); + } + } while (succeeded(parser.parseOptionalComma())); + + llvm::SmallVector arrayAttr(attributes.begin(), + attributes.end()); + deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); + segments = DenseI32ArrayAttr::get(parser.getContext(), seg); + + return success(); +} + +static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::OperandRange operands, mlir::TypeRange types, + std::optional deviceTypes, + std::optional segments) { + unsigned opIdx = 0; + for (unsigned i = 0; i < deviceTypes->size(); ++i) { + if (i != 0) + p << ", "; + p << "{"; + for (int32_t j = 0; j < (*segments)[i]; ++j) { + if (j != 0) + p << ", "; + p << operands[opIdx] << " : " << operands[opIdx].getType(); + ++opIdx; + } + p << "}"; + auto deviceTypeAttr = + mlir::dyn_cast((*deviceTypes)[i]); + if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None) + p << " [" << (*deviceTypes)[i] << "]"; + } +} + +static ParseResult parseWaitOperands( + mlir::OpAsmParser &parser, + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &types, mlir::ArrayAttr &deviceTypes, + mlir::DenseI32ArrayAttr &segments) { + llvm::SmallVector attributes; + llvm::SmallVector seg; + + do { + if (failed(parser.parseLBrace())) + return failure(); + + if (failed(parser.parseCommaSeparatedList( + mlir::AsmParser::Delimiter::None, [&]() { + if (parser.parseOperand(operands.emplace_back()) || + parser.parseColonType(types.emplace_back())) + return failure(); + return success(); + }))) + return failure(); + + seg.push_back(operands.size()); + + if (failed(parser.parseRBrace())) + return failure(); + + if (succeeded(parser.parseOptionalLSquare())) { + if (parser.parseAttribute(attributes.emplace_back()) || + parser.parseRSquare()) + return failure(); + } else { + attributes.push_back(mlir::acc::DeviceTypeAttr::get( + parser.getContext(), mlir::acc::DeviceType::None)); + } + } while (succeeded(parser.parseOptionalComma())); + + llvm::SmallVector arrayAttr(attributes.begin(), + attributes.end()); + deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); + segments = DenseI32ArrayAttr::get(parser.getContext(), seg); + + return success(); +} + +static void printWaitOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::OperandRange operands, + mlir::TypeRange types, + std::optional deviceTypes, + std::optional segments) { + unsigned opIdx = 0; + for (unsigned i = 0; i < deviceTypes->size(); ++i) { + if (i != 0) + p << ", "; + p << "{"; + for (int32_t j = 0; j < (*segments)[i]; ++j) { + if (j != 0) + p << ", "; + p << operands[opIdx] << " : " << operands[opIdx].getType(); + ++opIdx; + } + p << "}"; + auto deviceTypeAttr = + mlir::dyn_cast((*deviceTypes)[i]); + if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None) + p << " [" << (*deviceTypes)[i] << "]"; + } +} + +static ParseResult parseDeviceTypeOperands( + mlir::OpAsmParser &parser, + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &types, mlir::ArrayAttr &deviceTypes) { + llvm::SmallVector attributes; + if (failed(parser.parseCommaSeparatedList([&]() { + if (parser.parseOperand(operands.emplace_back()) || + parser.parseColonType(types.emplace_back())) + return failure(); + if (succeeded(parser.parseOptionalLSquare())) { + if (parser.parseAttribute(attributes.emplace_back()) || + parser.parseRSquare()) + return failure(); + } else { + attributes.push_back(mlir::acc::DeviceTypeAttr::get( + parser.getContext(), mlir::acc::DeviceType::None)); + } + return success(); + }))) + return failure(); + llvm::SmallVector arrayAttr(attributes.begin(), + attributes.end()); + deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); + return success(); +} + +static void +printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::OperandRange operands, mlir::TypeRange types, + std::optional deviceTypes) { + for (unsigned i = 0, e = deviceTypes->size(); i < e; ++i) { + if (i != 0) + p << ", "; + p << operands[i] << " : " << operands[i].getType(); + auto deviceTypeAttr = + mlir::dyn_cast((*deviceTypes)[i]); + if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None) + p << " [" << (*deviceTypes)[i] << "]"; + } +} + //===----------------------------------------------------------------------===// // SerialOp //===----------------------------------------------------------------------===// @@ -648,12 +993,55 @@ unsigned SerialOp::getNumDataOperands() { } Value SerialOp::getDataOperand(unsigned i) { - unsigned numOptional = getAsync() ? 1 : 0; + unsigned numOptional = getAsync().size(); numOptional += getIfCond() ? 1 : 0; numOptional += getSelfCond() ? 1 : 0; return getOperand(getWaitOperands().size() + numOptional + i); } +bool acc::SerialOp::hasAsyncOnly() { + return hasAsyncOnly(mlir::acc::DeviceType::None); +} + +bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { + if (auto arrayAttr = getAsyncOnly()) { + if (findSegment(*arrayAttr, deviceType)) + return true; + } + return false; +} + +mlir::Value acc::SerialOp::getAsyncValue() { + return getAsyncValue(mlir::acc::DeviceType::None); +} + +mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) { + return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), + deviceType); +} + +bool acc::SerialOp::hasWaitOnly() { + return hasWaitOnly(mlir::acc::DeviceType::None); +} + +bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { + if (auto arrayAttr = getWaitOnly()) { + if (findSegment(*arrayAttr, deviceType)) + return true; + } + return false; +} + +mlir::Operation::operand_range SerialOp::getWaitValues() { + return getWaitValues(mlir::acc::DeviceType::None); +} + +mlir::Operation::operand_range +SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) { + return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), + getWaitOperandsSegments(), deviceType); +} + LogicalResult acc::SerialOp::verify() { if (failed(checkSymOperandList( *this, getPrivatizations(), getGangPrivateOperands(), "private", @@ -663,6 +1051,16 @@ LogicalResult acc::SerialOp::verify() { *this, getReductionRecipes(), getReductionOperands(), "reduction", "reductions", false))) return failure(); + + if (failed(verifyDeviceTypeAndSegmentCountMatch( + *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), + getWaitOperandsDeviceTypeAttr(), "wait"))) + return failure(); + + if (failed(verifyDeviceTypeCountMatch(*this, getAsync(), + getAsyncDeviceTypeAttr(), "async"))) + return failure(); + return checkDataOperands(*this, getDataClauseOperands()); } @@ -675,19 +1073,114 @@ unsigned KernelsOp::getNumDataOperands() { } Value KernelsOp::getDataOperand(unsigned i) { - unsigned numOptional = getAsync() ? 1 : 0; + unsigned numOptional = getAsync().size(); numOptional += getWaitOperands().size(); numOptional += getNumGangs().size(); - numOptional += getNumWorkers() ? 1 : 0; - numOptional += getVectorLength() ? 1 : 0; + numOptional += getNumWorkers().size(); + numOptional += getVectorLength().size(); numOptional += getIfCond() ? 1 : 0; numOptional += getSelfCond() ? 1 : 0; return getOperand(numOptional + i); } +bool acc::KernelsOp::hasAsyncOnly() { + return hasAsyncOnly(mlir::acc::DeviceType::None); +} + +bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { + if (auto arrayAttr = getAsyncOnly()) { + if (findSegment(*arrayAttr, deviceType)) + return true; + } + return false; +} + +mlir::Value acc::KernelsOp::getAsyncValue() { + return getAsyncValue(mlir::acc::DeviceType::None); +} + +mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) { + return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), + deviceType); +} + +mlir::Value acc::KernelsOp::getNumWorkersValue() { + return getNumWorkersValue(mlir::acc::DeviceType::None); +} + +mlir::Value +acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) { + return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(), + deviceType); +} + +mlir::Value acc::KernelsOp::getVectorLengthValue() { + return getVectorLengthValue(mlir::acc::DeviceType::None); +} + +mlir::Value +acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) { + return getValueInDeviceTypeSegment(getVectorLengthDeviceType(), + getVectorLength(), deviceType); +} + +mlir::Operation::operand_range KernelsOp::getNumGangsValues() { + return getNumGangsValues(mlir::acc::DeviceType::None); +} + +mlir::Operation::operand_range +KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) { + return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(), + getNumGangsSegments(), deviceType); +} + +bool acc::KernelsOp::hasWaitOnly() { + return hasWaitOnly(mlir::acc::DeviceType::None); +} + +bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { + if (auto arrayAttr = getWaitOnly()) { + if (findSegment(*arrayAttr, deviceType)) + return true; + } + return false; +} + +mlir::Operation::operand_range KernelsOp::getWaitValues() { + return getWaitValues(mlir::acc::DeviceType::None); +} + +mlir::Operation::operand_range +KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) { + return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), + getWaitOperandsSegments(), deviceType); +} + LogicalResult acc::KernelsOp::verify() { - if (getNumGangs().size() > 3) - return emitOpError() << "num_gangs expects a maximum of 3 values"; + if (failed(verifyDeviceTypeAndSegmentCountMatch( + *this, getNumGangs(), getNumGangsSegmentsAttr(), + getNumGangsDeviceTypeAttr(), "num_gangs", 3))) + return failure(); + + if (failed(verifyDeviceTypeAndSegmentCountMatch( + *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), + getWaitOperandsDeviceTypeAttr(), "wait"))) + return failure(); + + if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(), + getNumWorkersDeviceTypeAttr(), + "num_workers"))) + return failure(); + + if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(), + getVectorLengthDeviceTypeAttr(), + "vector_length"))) + return failure(); + + if (failed(verifyDeviceTypeCountMatch(*this, getAsync(), + getAsyncDeviceTypeAttr(), "async"))) + return failure(); + return checkDataOperands(*this, getDataClauseOperands()); } diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir index b9ac68d0592c8..c18d964b370f2 100644 --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -462,8 +462,8 @@ acc.loop gang() { // ----- %i64value = arith.constant 1 : i64 -// expected-error@+1 {{num_gangs expects a maximum of 3 values}} -acc.parallel num_gangs(%i64value, %i64value, %i64value, %i64value : i64, i64, i64, i64) { +// expected-error@+1 {{num_gangs expects a maximum of 3 values per segment}} +acc.parallel num_gangs({%i64value: i64, %i64value : i64, %i64value : i64, %i64value : i64}) { } // ----- diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index 05b0450c7fb91..5a95811685f84 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -137,7 +137,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x %pd = acc.present varPtr(%d : memref<10xf32>) -> memref<10xf32> acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) { %private = acc.private varPtr(%c : memref<10xf32>) -> memref<10xf32> - acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) { + acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type]) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) { acc.loop gang { scf.for %x = %lb to %c10 step %st { acc.loop worker { @@ -180,7 +180,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x // CHECK-NEXT: [[NUMWORKERS:%.*]] = arith.constant 10 : i64 // CHECK: acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) { // CHECK-NEXT: %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) -> memref<10xf32> -// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) { +// CHECK-NEXT: acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type]) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) { // CHECK-NEXT: acc.loop gang { // CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] { // CHECK-NEXT: acc.loop worker { @@ -439,25 +439,25 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x } acc.parallel async(%idxValue: index) { } - acc.parallel wait(%i64value: i64) { + acc.parallel wait({%i64value: i64}) { } - acc.parallel wait(%i32value: i32) { + acc.parallel wait({%i32value: i32}) { } - acc.parallel wait(%idxValue: index) { + acc.parallel wait({%idxValue: index}) { } - acc.parallel wait(%i64value, %i32value, %idxValue : i64, i32, index) { + acc.parallel wait({%i64value : i64, %i32value : i32, %idxValue : index}) { } - acc.parallel num_gangs(%i64value: i64) { + acc.parallel num_gangs({%i64value: i64}) { } - acc.parallel num_gangs(%i32value: i32) { + acc.parallel num_gangs({%i32value: i32}) { } - acc.parallel num_gangs(%idxValue: index) { + acc.parallel num_gangs({%idxValue: index}) { } - acc.parallel num_gangs(%i64value, %i64value, %idxValue : i64, i64, index) { + acc.parallel num_gangs({%i64value: i64, %i64value: i64, %idxValue: index}) { } - acc.parallel num_workers(%i64value: i64) { + acc.parallel num_workers(%i64value: i64 [#acc.device_type]) { } - acc.parallel num_workers(%i32value: i32) { + acc.parallel num_workers(%i32value: i32 [#acc.device_type]) { } acc.parallel num_workers(%idxValue: index) { } @@ -492,25 +492,25 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x // CHECK-NEXT: } // CHECK: acc.parallel async([[IDXVALUE]] : index) { // CHECK-NEXT: } -// CHECK: acc.parallel wait([[I64VALUE]] : i64) { +// CHECK: acc.parallel wait({[[I64VALUE]] : i64}) { // CHECK-NEXT: } -// CHECK: acc.parallel wait([[I32VALUE]] : i32) { +// CHECK: acc.parallel wait({[[I32VALUE]] : i32}) { // CHECK-NEXT: } -// CHECK: acc.parallel wait([[IDXVALUE]] : index) { +// CHECK: acc.parallel wait({[[IDXVALUE]] : index}) { // CHECK-NEXT: } -// CHECK: acc.parallel wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) { +// CHECK: acc.parallel wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) { // CHECK-NEXT: } -// CHECK: acc.parallel num_gangs([[I64VALUE]] : i64) { +// CHECK: acc.parallel num_gangs({[[I64VALUE]] : i64}) { // CHECK-NEXT: } -// CHECK: acc.parallel num_gangs([[I32VALUE]] : i32) { +// CHECK: acc.parallel num_gangs({[[I32VALUE]] : i32}) { // CHECK-NEXT: } -// CHECK: acc.parallel num_gangs([[IDXVALUE]] : index) { +// CHECK: acc.parallel num_gangs({[[IDXVALUE]] : index}) { // CHECK-NEXT: } -// CHECK: acc.parallel num_gangs([[I64VALUE]], [[I64VALUE]], [[IDXVALUE]] : i64, i64, index) { +// CHECK: acc.parallel num_gangs({[[I64VALUE]] : i64, [[I64VALUE]] : i64, [[IDXVALUE]] : index}) { // CHECK-NEXT: } -// CHECK: acc.parallel num_workers([[I64VALUE]] : i64) { +// CHECK: acc.parallel num_workers([[I64VALUE]] : i64 [#acc.device_type]) { // CHECK-NEXT: } -// CHECK: acc.parallel num_workers([[I32VALUE]] : i32) { +// CHECK: acc.parallel num_workers([[I32VALUE]] : i32 [#acc.device_type]) { // CHECK-NEXT: } // CHECK: acc.parallel num_workers([[IDXVALUE]] : index) { // CHECK-NEXT: } @@ -590,13 +590,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 } acc.serial async(%idxValue: index) { } - acc.serial wait(%i64value: i64) { + acc.serial wait({%i64value: i64}) { } - acc.serial wait(%i32value: i32) { + acc.serial wait({%i32value: i32}) { } - acc.serial wait(%idxValue: index) { + acc.serial wait({%idxValue: index}) { } - acc.serial wait(%i64value, %i32value, %idxValue : i64, i32, index) { + acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) { } %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) -> memref<10xf32> acc.serial private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(@firstprivatization_memref_10xf32 -> %firstprivate : memref<10xf32>) { @@ -627,13 +627,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 // CHECK-NEXT: } // CHECK: acc.serial async([[IDXVALUE]] : index) { // CHECK-NEXT: } -// CHECK: acc.serial wait([[I64VALUE]] : i64) { +// CHECK: acc.serial wait({[[I64VALUE]] : i64}) { // CHECK-NEXT: } -// CHECK: acc.serial wait([[I32VALUE]] : i32) { +// CHECK: acc.serial wait({[[I32VALUE]] : i32}) { // CHECK-NEXT: } -// CHECK: acc.serial wait([[IDXVALUE]] : index) { +// CHECK: acc.serial wait({[[IDXVALUE]] : index}) { // CHECK-NEXT: } -// CHECK: acc.serial wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) { +// CHECK: acc.serial wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) { // CHECK-NEXT: } // CHECK: %[[FIRSTP:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) -> memref<10xf32> // CHECK: acc.serial firstprivate(@firstprivatization_memref_10xf32 -> %[[FIRSTP]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) { @@ -665,13 +665,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 } acc.kernels async(%idxValue: index) { } - acc.kernels wait(%i64value: i64) { + acc.kernels wait({%i64value: i64}) { } - acc.kernels wait(%i32value: i32) { + acc.kernels wait({%i32value: i32}) { } - acc.kernels wait(%idxValue: index) { + acc.kernels wait({%idxValue: index}) { } - acc.kernels wait(%i64value, %i32value, %idxValue : i64, i32, index) { + acc.kernels wait({%i64value : i64, %i32value : i32, %idxValue : index}) { } acc.kernels { } attributes {defaultAttr = #acc} @@ -699,13 +699,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 // CHECK-NEXT: } // CHECK: acc.kernels async([[IDXVALUE]] : index) { // CHECK-NEXT: } -// CHECK: acc.kernels wait([[I64VALUE]] : i64) { +// CHECK: acc.kernels wait({[[I64VALUE]] : i64}) { // CHECK-NEXT: } -// CHECK: acc.kernels wait([[I32VALUE]] : i32) { +// CHECK: acc.kernels wait({[[I32VALUE]] : i32}) { // CHECK-NEXT: } -// CHECK: acc.kernels wait([[IDXVALUE]] : index) { +// CHECK: acc.kernels wait({[[IDXVALUE]] : index}) { // CHECK-NEXT: } -// CHECK: acc.kernels wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) { +// CHECK: acc.kernels wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) { // CHECK-NEXT: } // CHECK: acc.kernels { // CHECK-NEXT: } attributes {defaultAttr = #acc} diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 2dec4ba3c001e..13393569f36fe 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -10,6 +10,7 @@ add_subdirectory(ArmSME) add_subdirectory(Index) add_subdirectory(LLVMIR) add_subdirectory(MemRef) +add_subdirectory(OpenACC) add_subdirectory(SCF) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt new file mode 100644 index 0000000000000..5133d7fc38296 --- /dev/null +++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_unittest(MLIROpenACCTests + OpenACCOpsTest.cpp +) +target_link_libraries(MLIROpenACCTests + PRIVATE + MLIRIR + MLIROpenACCDialect +) diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp new file mode 100644 index 0000000000000..dcf6c1240c55d --- /dev/null +++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp @@ -0,0 +1,275 @@ +//===- OpenACCOpsTest.cpp - OpenACC ops extra functiosn Tests -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::acc; + +//===----------------------------------------------------------------------===// +// Test Fixture +//===----------------------------------------------------------------------===// + +class OpenACCOpsTest : public ::testing::Test { +protected: + OpenACCOpsTest() : b(&context), loc(UnknownLoc::get(&context)) { + context.loadDialect(); + } + + MLIRContext context; + OpBuilder b; + Location loc; + llvm::SmallVector dtypes = { + DeviceType::None, DeviceType::Star, DeviceType::Multicore, + DeviceType::Default, DeviceType::Host, DeviceType::Nvidia, + DeviceType::Radeon}; + llvm::SmallVector dtypesWithoutNone = { + DeviceType::Star, DeviceType::Multicore, DeviceType::Default, + DeviceType::Host, DeviceType::Nvidia, DeviceType::Radeon}; +}; + +template +void testAsyncOnly(OpBuilder &b, MLIRContext &context, Location loc, + llvm::SmallVector &dtypes) { + Op op = b.create(loc, TypeRange{}, ValueRange{}); + EXPECT_FALSE(op.hasAsyncOnly()); + for (auto d : dtypes) + EXPECT_FALSE(op.hasAsyncOnly(d)); + + auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None); + op.setAsyncOnlyAttr(b.getArrayAttr({dtypeNone})); + EXPECT_TRUE(op.hasAsyncOnly()); + EXPECT_TRUE(op.hasAsyncOnly(DeviceType::None)); + op.removeAsyncOnlyAttr(); + + auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host); + op.setAsyncOnlyAttr(b.getArrayAttr({dtypeHost})); + EXPECT_TRUE(op.hasAsyncOnly(DeviceType::Host)); + EXPECT_FALSE(op.hasAsyncOnly()); + op.removeAsyncOnlyAttr(); + + auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star); + op.setAsyncOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar})); + EXPECT_TRUE(op.hasAsyncOnly(DeviceType::Star)); + EXPECT_TRUE(op.hasAsyncOnly(DeviceType::Host)); + EXPECT_FALSE(op.hasAsyncOnly()); +} + +TEST_F(OpenACCOpsTest, asyncOnlyTest) { + testAsyncOnly(b, context, loc, dtypes); + testAsyncOnly(b, context, loc, dtypes); + testAsyncOnly(b, context, loc, dtypes); +} + +template +void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc, + llvm::SmallVector &dtypes) { + Op op = b.create(loc, TypeRange{}, ValueRange{}); + + mlir::Value empty; + EXPECT_EQ(op.getAsyncValue(), empty); + for (auto d : dtypes) + EXPECT_EQ(op.getAsyncValue(d), empty); + + mlir::Value val = b.create(loc, b.getI32IntegerAttr(1)); + auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia); + op.setAsyncDeviceTypeAttr(b.getArrayAttr({dtypeNvidia})); + op.getAsyncMutable().assign(val); + EXPECT_EQ(op.getAsyncValue(), empty); + EXPECT_EQ(op.getAsyncValue(DeviceType::Nvidia), val); +} + +TEST_F(OpenACCOpsTest, asyncValueTest) { + testAsyncValue(b, context, loc, dtypes); + testAsyncValue(b, context, loc, dtypes); + testAsyncValue(b, context, loc, dtypes); +} + +template +void testNumGangsValues(OpBuilder &b, MLIRContext &context, Location loc, + llvm::SmallVector &dtypes, + llvm::SmallVector &dtypesWithoutNone) { + Op op = b.create(loc, TypeRange{}, ValueRange{}); + EXPECT_EQ(op.getNumGangsValues().begin(), op.getNumGangsValues().end()); + + mlir::Value val1 = b.create(loc, b.getI32IntegerAttr(1)); + mlir::Value val2 = b.create(loc, b.getI32IntegerAttr(4)); + auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None); + op.getNumGangsMutable().assign(val1); + op.setNumGangsDeviceTypeAttr(b.getArrayAttr({dtypeNone})); + op.setNumGangsSegments(b.getDenseI32ArrayAttr({1})); + EXPECT_EQ(op.getNumGangsValues().front(), val1); + for (auto d : dtypesWithoutNone) + EXPECT_EQ(op.getNumGangsValues(d).begin(), op.getNumGangsValues(d).end()); + + op.getNumGangsMutable().clear(); + op.removeNumGangsDeviceTypeAttr(); + op.removeNumGangsSegmentsAttr(); + for (auto d : dtypes) + EXPECT_EQ(op.getNumGangsValues(d).begin(), op.getNumGangsValues(d).end()); + + op.getNumGangsMutable().append(val1); + op.getNumGangsMutable().append(val2); + op.setNumGangsDeviceTypeAttr( + b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host), + DeviceTypeAttr::get(&context, DeviceType::Star)})); + op.setNumGangsSegments(b.getDenseI32ArrayAttr({1, 1})); + EXPECT_EQ(op.getNumGangsValues(DeviceType::None).begin(), + op.getNumGangsValues(DeviceType::None).end()); + EXPECT_EQ(op.getNumGangsValues(DeviceType::Host).front(), val1); + EXPECT_EQ(op.getNumGangsValues(DeviceType::Star).front(), val2); + + op.getNumGangsMutable().clear(); + op.removeNumGangsDeviceTypeAttr(); + op.removeNumGangsSegmentsAttr(); + for (auto d : dtypes) + EXPECT_EQ(op.getNumGangsValues(d).begin(), op.getNumGangsValues(d).end()); + + op.getNumGangsMutable().append(val1); + op.getNumGangsMutable().append(val2); + op.getNumGangsMutable().append(val1); + op.setNumGangsDeviceTypeAttr( + b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Default), + DeviceTypeAttr::get(&context, DeviceType::Multicore)})); + op.setNumGangsSegments(b.getDenseI32ArrayAttr({2, 1})); + EXPECT_EQ(op.getNumGangsValues(DeviceType::None).begin(), + op.getNumGangsValues(DeviceType::None).end()); + EXPECT_EQ(op.getNumGangsValues(DeviceType::Default).front(), val1); + EXPECT_EQ(op.getNumGangsValues(DeviceType::Default).drop_front().front(), + val2); + EXPECT_EQ(op.getNumGangsValues(DeviceType::Multicore).front(), val1); +} + +TEST_F(OpenACCOpsTest, numGangsValuesTest) { + testNumGangsValues(b, context, loc, dtypes, dtypesWithoutNone); + testNumGangsValues(b, context, loc, dtypes, dtypesWithoutNone); +} + +template +void testVectorLength(OpBuilder &b, MLIRContext &context, Location loc, + llvm::SmallVector &dtypes) { + auto op = b.create(loc, TypeRange{}, ValueRange{}); + + mlir::Value empty; + EXPECT_EQ(op.getVectorLengthValue(), empty); + for (auto d : dtypes) + EXPECT_EQ(op.getVectorLengthValue(d), empty); + + mlir::Value val = b.create(loc, b.getI32IntegerAttr(1)); + auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia); + op.setVectorLengthDeviceTypeAttr(b.getArrayAttr({dtypeNvidia})); + op.getVectorLengthMutable().assign(val); + EXPECT_EQ(op.getVectorLengthValue(), empty); + EXPECT_EQ(op.getVectorLengthValue(DeviceType::Nvidia), val); +} + +TEST_F(OpenACCOpsTest, vectorLengthTest) { + testVectorLength(b, context, loc, dtypes); + testVectorLength(b, context, loc, dtypes); +} + +template +void testWaitOnly(OpBuilder &b, MLIRContext &context, Location loc, + llvm::SmallVector &dtypes, + llvm::SmallVector &dtypesWithoutNone) { + Op op = b.create(loc, TypeRange{}, ValueRange{}); + EXPECT_FALSE(op.hasWaitOnly()); + for (auto d : dtypes) + EXPECT_FALSE(op.hasWaitOnly(d)); + + auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None); + op.setWaitOnlyAttr(b.getArrayAttr({dtypeNone})); + EXPECT_TRUE(op.hasWaitOnly()); + EXPECT_TRUE(op.hasWaitOnly(DeviceType::None)); + for (auto d : dtypesWithoutNone) + EXPECT_FALSE(op.hasWaitOnly(d)); + op.removeWaitOnlyAttr(); + + auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host); + op.setWaitOnlyAttr(b.getArrayAttr({dtypeHost})); + EXPECT_TRUE(op.hasWaitOnly(DeviceType::Host)); + EXPECT_FALSE(op.hasWaitOnly()); + op.removeWaitOnlyAttr(); + + auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star); + op.setWaitOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar})); + EXPECT_TRUE(op.hasWaitOnly(DeviceType::Star)); + EXPECT_TRUE(op.hasWaitOnly(DeviceType::Host)); + EXPECT_FALSE(op.hasWaitOnly()); +} + +TEST_F(OpenACCOpsTest, waitOnlyTest) { + testWaitOnly(b, context, loc, dtypes, dtypesWithoutNone); + testWaitOnly(b, context, loc, dtypes, dtypesWithoutNone); + testWaitOnly(b, context, loc, dtypes, dtypesWithoutNone); +} + +template +void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc, + llvm::SmallVector &dtypes, + llvm::SmallVector &dtypesWithoutNone) { + Op op = b.create(loc, TypeRange{}, ValueRange{}); + EXPECT_EQ(op.getWaitValues().begin(), op.getWaitValues().end()); + + mlir::Value val1 = b.create(loc, b.getI32IntegerAttr(1)); + mlir::Value val2 = b.create(loc, b.getI32IntegerAttr(4)); + auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None); + op.getWaitOperandsMutable().assign(val1); + op.setWaitOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNone})); + op.setWaitOperandsSegments(b.getDenseI32ArrayAttr({1})); + EXPECT_EQ(op.getWaitValues().front(), val1); + for (auto d : dtypesWithoutNone) + EXPECT_EQ(op.getWaitValues(d).begin(), op.getWaitValues(d).end()); + + op.getWaitOperandsMutable().clear(); + op.removeWaitOperandsDeviceTypeAttr(); + op.removeWaitOperandsSegmentsAttr(); + for (auto d : dtypes) + EXPECT_EQ(op.getWaitValues(d).begin(), op.getWaitValues(d).end()); + + op.getWaitOperandsMutable().append(val1); + op.getWaitOperandsMutable().append(val2); + op.setWaitOperandsDeviceTypeAttr( + b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host), + DeviceTypeAttr::get(&context, DeviceType::Star)})); + op.setWaitOperandsSegments(b.getDenseI32ArrayAttr({1, 1})); + EXPECT_EQ(op.getWaitValues(DeviceType::None).begin(), + op.getWaitValues(DeviceType::None).end()); + EXPECT_EQ(op.getWaitValues(DeviceType::Host).front(), val1); + EXPECT_EQ(op.getWaitValues(DeviceType::Star).front(), val2); + + op.getWaitOperandsMutable().clear(); + op.removeWaitOperandsDeviceTypeAttr(); + op.removeWaitOperandsSegmentsAttr(); + for (auto d : dtypes) + EXPECT_EQ(op.getWaitValues(d).begin(), op.getWaitValues(d).end()); + + op.getWaitOperandsMutable().append(val1); + op.getWaitOperandsMutable().append(val2); + op.getWaitOperandsMutable().append(val1); + op.setWaitOperandsDeviceTypeAttr( + b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Default), + DeviceTypeAttr::get(&context, DeviceType::Multicore)})); + op.setWaitOperandsSegments(b.getDenseI32ArrayAttr({2, 1})); + EXPECT_EQ(op.getWaitValues(DeviceType::None).begin(), + op.getWaitValues(DeviceType::None).end()); + EXPECT_EQ(op.getWaitValues(DeviceType::Default).front(), val1); + EXPECT_EQ(op.getWaitValues(DeviceType::Default).drop_front().front(), val2); + EXPECT_EQ(op.getWaitValues(DeviceType::Multicore).front(), val1); +} + +TEST_F(OpenACCOpsTest, waitValuesTest) { + testWaitValues(b, context, loc, dtypes, dtypesWithoutNone); + testWaitValues(b, context, loc, dtypes, dtypesWithoutNone); + testWaitValues(b, context, loc, dtypes, dtypesWithoutNone); +}