diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 59db5ab71b702..fae54eefb02f7 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) { case Fortran::parser::AccDeviceTypeExpr::Device::Multicore: return mlir::acc::DeviceType::Multicore; } - return mlir::acc::DeviceType::None; + return mlir::acc::DeviceType::Default; } static void gatherDeviceTypeAttrs( @@ -1781,25 +1781,26 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, bool outerCombined = false) { // Parallel operation operands + mlir::Value async; + mlir::Value numWorkers; + mlir::Value vectorLength; mlir::Value ifCond; mlir::Value selfCond; mlir::Value waitDevnum; llvm::SmallVector waitOperands, attachEntryOperands, copyEntryOperands, copyoutEntryOperands, createEntryOperands, - dataClauseOperands, numGangs, numWorkers, vectorLength, async; - llvm::SmallVector numGangsDeviceTypes, numWorkersDeviceTypes, - vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes, - waitOperandsDeviceTypes, waitOnlyDeviceTypes; - llvm::SmallVector numGangsSegments, waitOperandsSegments; + dataClauseOperands, numGangs; llvm::SmallVector reductionOperands, privateOperands, firstprivateOperands; llvm::SmallVector privatizations, firstPrivatizations, reductionRecipes; - // Self clause has optional values but can be present with + // Async, wait and self clause have optional values but can be present with // no value as well. When there is no value, the op has an attribute to // represent the clause. + bool addAsyncAttr = false; + bool addWaitAttr = false; bool addSelfAttr = false; bool hasDefaultNone = false; @@ -1807,11 +1808,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder = converter.getFirOpBuilder(); - // device_type attribute is set to `none` until a device_type clause is - // encountered. - auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( - builder.getContext(), mlir::acc::DeviceType::None); - // Lower clauses values mapped to operands. // Keep track of each group of operands separatly as clauses can appear // more than once. @@ -1819,52 +1815,27 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, mlir::Location clauseLocation = converter.genLocation(clause.source); if (const auto *asyncClause = std::get_if(&clause.u)) { - const auto &asyncClauseValue = asyncClause->v; - if (asyncClauseValue) { // async has a value. - async.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx))); - asyncDeviceTypes.push_back(crtDeviceTypeAttr); - } else { - asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr); - } + genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx); } else if (const auto *waitClause = std::get_if(&clause.u)) { - const auto &waitClauseValue = waitClause->v; - if (waitClauseValue) { // wait has a value. - const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; - const auto &waitList = - std::get>(waitArg.t); - auto crtWaitOperands = waitOperands.size(); - for (const Fortran::parser::ScalarIntExpr &value : waitList) { - waitOperands.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(value), stmtCtx))); - } - waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr); - waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands); - } else { - waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr); - } + genWaitClause(converter, waitClause, waitOperands, waitDevnum, + addWaitAttr, stmtCtx); } else if (const auto *numGangsClause = std::get_if( &clause.u)) { - auto crtNumGangs = numGangs.size(); for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v) numGangs.push_back(fir::getBase(converter.genExprValue( *Fortran::semantics::GetExpr(expr), stmtCtx))); - numGangsDeviceTypes.push_back(crtDeviceTypeAttr); - numGangsSegments.push_back(numGangs.size() - crtNumGangs); } else if (const auto *numWorkersClause = std::get_if( &clause.u)) { - numWorkers.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx))); - numWorkersDeviceTypes.push_back(crtDeviceTypeAttr); + numWorkers = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)); } else if (const auto *vectorLengthClause = std::get_if( &clause.u)) { - vectorLength.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx))); - vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr); + vectorLength = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)); } else if (const auto *ifClause = std::get_if(&clause.u)) { genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx); @@ -2015,27 +1986,18 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, else if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_present) hasDefaultPresent = true; - } else if (const auto *deviceTypeClause = - std::get_if( - &clause.u)) { - const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList = - deviceTypeClause->v; - assert(deviceTypeExprList.v.size() == 1 && - "expect only one device_type expr"); - crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( - builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v)); } } // Prepare the operand segment size attribute and the operands value range. llvm::SmallVector operands; llvm::SmallVector operandSegments; - addOperands(operands, operandSegments, async); + addOperand(operands, operandSegments, async); addOperands(operands, operandSegments, waitOperands); if constexpr (!std::is_same_v) { addOperands(operands, operandSegments, numGangs); - addOperands(operands, operandSegments, numWorkers); - addOperands(operands, operandSegments, vectorLength); + addOperand(operands, operandSegments, numWorkers); + addOperand(operands, operandSegments, vectorLength); } addOperand(operands, operandSegments, ifCond); addOperand(operands, operandSegments, selfCond); @@ -2056,6 +2018,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, builder, currentLocation, eval, operands, operandSegments, outerCombined); + if (addAsyncAttr) + computeOp.setAsyncAttrAttr(builder.getUnitAttr()); + if (addWaitAttr) + computeOp.setWaitAttrAttr(builder.getUnitAttr()); if (addSelfAttr) computeOp.setSelfAttrAttr(builder.getUnitAttr()); @@ -2064,34 +2030,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter, if (hasDefaultPresent) computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present); - if constexpr (!std::is_same_v) { - if (!numWorkersDeviceTypes.empty()) - computeOp.setNumWorkersDeviceTypeAttr( - mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes)); - if (!vectorLengthDeviceTypes.empty()) - computeOp.setVectorLengthDeviceTypeAttr( - mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes)); - if (!numGangsDeviceTypes.empty()) - computeOp.setNumGangsDeviceTypeAttr( - mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes)); - if (!numGangsSegments.empty()) - computeOp.setNumGangsSegmentsAttr( - builder.getDenseI32ArrayAttr(numGangsSegments)); - } - if (!asyncDeviceTypes.empty()) - computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes)); - if (!asyncOnlyDeviceTypes.empty()) - computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes)); - - if (!waitOperandsDeviceTypes.empty()) - computeOp.setWaitOperandsDeviceTypeAttr( - builder.getArrayAttr(waitOperandsDeviceTypes)); - if (!waitOperandsSegments.empty()) - computeOp.setWaitOperandsSegmentsAttr( - builder.getDenseI32ArrayAttr(waitOperandsSegments)); - if (!waitOnlyDeviceTypes.empty()) - computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes)); - if constexpr (!std::is_same_v) { if (!privatizations.empty()) computeOp.setPrivatizationsAttr( diff --git a/flang/test/Lower/OpenACC/acc-device-type.f90 b/flang/test/Lower/OpenACC/acc-device-type.f90 deleted file mode 100644 index 871dbc95f60fc..0000000000000 --- a/flang/test/Lower/OpenACC/acc-device-type.f90 +++ /dev/null @@ -1,44 +0,0 @@ -! This test checks lowering of OpenACC device_type clause on directive where its -! position and the clauses that follow have special semantic - -! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s - -subroutine sub1() - - !$acc parallel num_workers(16) - !$acc end parallel - -! CHECK: acc.parallel num_workers(%c16{{.*}} : i32) { - - !$acc parallel num_workers(1) device_type(nvidia) num_workers(16) - !$acc end parallel - -! CHECK: acc.parallel num_workers(%c1{{.*}} : i32, %c16{{.*}} : i32 [#acc.device_type]) - - !$acc parallel device_type(*) num_workers(1) device_type(nvidia) num_workers(16) - !$acc end parallel - -! CHECK: acc.parallel num_workers(%c1{{.*}} : i32 [#acc.device_type], %c16{{.*}} : i32 [#acc.device_type]) - - !$acc parallel vector_length(1) - !$acc end parallel - -! CHECK: acc.parallel vector_length(%c1{{.*}} : i32) - - !$acc parallel device_type(multicore) vector_length(1) - !$acc end parallel - -! CHECK: acc.parallel vector_length(%c1{{.*}} : i32 [#acc.device_type]) - - !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(4) - !$acc end parallel - -! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c4{{.*}} : i32} [#acc.device_type]) - - !$acc parallel num_gangs(2) device_type(nvidia) num_gangs(1, 1, 1) - !$acc end parallel - -! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type]) - - -end subroutine diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90 index 93bc699031d55..34e7232697241 100644 --- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90 @@ -62,7 +62,7 @@ subroutine acc_kernels_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.terminator -! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type]} +! CHECK-NEXT: } attributes {asyncAttr} !$acc kernels loop async(1) DO i = 1, n @@ -103,7 +103,7 @@ subroutine acc_kernels_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.terminator -! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} +! CHECK-NEXT: } attributes {waitAttr} !$acc kernels loop wait(1) DO i = 1, n @@ -111,7 +111,7 @@ subroutine acc_kernels_loop END DO ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.kernels wait({[[WAIT1]] : i32}) { +! CHECK: acc.kernels wait([[WAIT1]] : i32) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -126,7 +126,7 @@ subroutine acc_kernels_loop ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32 -! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) { +! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -141,7 +141,7 @@ subroutine acc_kernels_loop ! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) { +! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -155,7 +155,7 @@ subroutine acc_kernels_loop END DO ! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) { +! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -169,7 +169,7 @@ subroutine acc_kernels_loop END DO ! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) { +! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield diff --git a/flang/test/Lower/OpenACC/acc-kernels.f90 b/flang/test/Lower/OpenACC/acc-kernels.f90 index 99629bb835172..1f882c6df5106 100644 --- a/flang/test/Lower/OpenACC/acc-kernels.f90 +++ b/flang/test/Lower/OpenACC/acc-kernels.f90 @@ -40,7 +40,7 @@ subroutine acc_kernels ! CHECK: acc.kernels { ! CHECK: acc.terminator -! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type]} +! CHECK-NEXT: } attributes {asyncAttr} !$acc kernels async(1) !$acc end kernels @@ -63,13 +63,13 @@ subroutine acc_kernels ! CHECK: acc.kernels { ! CHECK: acc.terminator -! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} +! CHECK-NEXT: } attributes {waitAttr} !$acc kernels wait(1) !$acc end kernels ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.kernels wait({[[WAIT1]] : i32}) { +! CHECK: acc.kernels wait([[WAIT1]] : i32) { ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -78,7 +78,7 @@ subroutine acc_kernels ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32 -! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) { +! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) { ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -87,7 +87,7 @@ subroutine acc_kernels ! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) { +! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) { ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -95,7 +95,7 @@ subroutine acc_kernels !$acc end kernels ! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) { +! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) { ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} @@ -103,7 +103,7 @@ subroutine acc_kernels !$acc end kernels ! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) { +! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) { ! CHECK: acc.terminator ! CHECK-NEXT: }{{$}} diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90 index deee7089033ea..1856215ce59d1 100644 --- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90 @@ -64,7 +64,7 @@ subroutine acc_parallel_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.yield -! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type]} +! CHECK-NEXT: } attributes {asyncAttr} !$acc parallel loop async(1) DO i = 1, n @@ -105,7 +105,7 @@ subroutine acc_parallel_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.yield -! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} +! CHECK-NEXT: } attributes {waitAttr} !$acc parallel loop wait(1) DO i = 1, n @@ -113,7 +113,7 @@ subroutine acc_parallel_loop END DO ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.parallel wait({[[WAIT1]] : i32}) { +! CHECK: acc.parallel wait([[WAIT1]] : i32) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -128,7 +128,7 @@ subroutine acc_parallel_loop ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32 -! CHECK: acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) { +! CHECK: acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -143,7 +143,7 @@ subroutine acc_parallel_loop ! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) { +! CHECK: acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -157,7 +157,7 @@ subroutine acc_parallel_loop END DO ! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.parallel num_gangs({[[NUMGANGS1]] : i32}) { +! CHECK: acc.parallel num_gangs([[NUMGANGS1]] : i32) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -171,7 +171,7 @@ subroutine acc_parallel_loop END DO ! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.parallel num_gangs({[[NUMGANGS2]] : i32}) { +! CHECK: acc.parallel num_gangs([[NUMGANGS2]] : i32) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90 index a369bf01f2599..bbf51ba36a7de 100644 --- a/flang/test/Lower/OpenACC/acc-parallel.f90 +++ b/flang/test/Lower/OpenACC/acc-parallel.f90 @@ -62,7 +62,7 @@ subroutine acc_parallel ! CHECK: acc.parallel { ! CHECK: acc.yield -! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type]} +! CHECK-NEXT: } attributes {asyncAttr} !$acc parallel async(1) !$acc end parallel @@ -85,13 +85,13 @@ subroutine acc_parallel ! CHECK: acc.parallel { ! CHECK: acc.yield -! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} +! CHECK-NEXT: } attributes {waitAttr} !$acc parallel wait(1) !$acc end parallel ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.parallel wait({[[WAIT1]] : i32}) { +! CHECK: acc.parallel wait([[WAIT1]] : i32) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -100,7 +100,7 @@ subroutine acc_parallel ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32 -! CHECK: acc.parallel wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) { +! CHECK: acc.parallel wait([[WAIT2]], [[WAIT3]] : i32, i32) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -109,7 +109,7 @@ subroutine acc_parallel ! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.parallel wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) { +! CHECK: acc.parallel wait([[WAIT4]], [[WAIT5]] : i32, i32) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -117,7 +117,7 @@ subroutine acc_parallel !$acc end parallel ! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.parallel num_gangs({[[NUMGANGS1]] : i32}) { +! CHECK: acc.parallel num_gangs([[NUMGANGS1]] : i32) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -125,14 +125,14 @@ subroutine acc_parallel !$acc end parallel ! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.parallel num_gangs({[[NUMGANGS2]] : i32}) { +! CHECK: acc.parallel num_gangs([[NUMGANGS2]] : i32) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} !$acc parallel num_gangs(1, 1, 1) !$acc end parallel -! CHECK: acc.parallel num_gangs({%{{.*}} : i32, %{{.*}} : i32, %{{.*}} : i32}) { +! CHECK: acc.parallel num_gangs(%{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90 index 712bfc80ce387..4ed7bb8da29a1 100644 --- a/flang/test/Lower/OpenACC/acc-serial-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90 @@ -83,7 +83,7 @@ subroutine acc_serial_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.yield -! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type]} +! CHECK-NEXT: } attributes {asyncAttr} !$acc serial loop async(1) DO i = 1, n @@ -124,7 +124,7 @@ subroutine acc_serial_loop ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} ! CHECK: acc.yield -! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} +! CHECK-NEXT: } attributes {waitAttr} !$acc serial loop wait(1) DO i = 1, n @@ -132,7 +132,7 @@ subroutine acc_serial_loop END DO ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.serial wait({[[WAIT1]] : i32}) { +! CHECK: acc.serial wait([[WAIT1]] : i32) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -147,7 +147,7 @@ subroutine acc_serial_loop ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32 -! CHECK: acc.serial wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) { +! CHECK: acc.serial wait([[WAIT2]], [[WAIT3]] : i32, i32) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield @@ -162,7 +162,7 @@ subroutine acc_serial_loop ! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.serial wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) { +! CHECK: acc.serial wait([[WAIT4]], [[WAIT5]] : i32, i32) { ! CHECK: acc.loop { ! CHECK: fir.do_loop ! CHECK: acc.yield diff --git a/flang/test/Lower/OpenACC/acc-serial.f90 b/flang/test/Lower/OpenACC/acc-serial.f90 index d05e51d3d274f..ab3b0ccd54595 100644 --- a/flang/test/Lower/OpenACC/acc-serial.f90 +++ b/flang/test/Lower/OpenACC/acc-serial.f90 @@ -62,7 +62,7 @@ subroutine acc_serial ! CHECK: acc.serial { ! CHECK: acc.yield -! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type]} +! CHECK-NEXT: } attributes {asyncAttr} !$acc serial async(1) !$acc end serial @@ -85,13 +85,13 @@ subroutine acc_serial ! CHECK: acc.serial { ! CHECK: acc.yield -! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type]} +! CHECK-NEXT: } attributes {waitAttr} !$acc serial wait(1) !$acc end serial ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.serial wait({[[WAIT1]] : i32}) { +! CHECK: acc.serial wait([[WAIT1]] : i32) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -100,7 +100,7 @@ subroutine acc_serial ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32 -! CHECK: acc.serial wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) { +! CHECK: acc.serial wait([[WAIT2]], [[WAIT3]] : i32, i32) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} @@ -109,7 +109,7 @@ subroutine acc_serial ! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref -! CHECK: acc.serial wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) { +! CHECK: acc.serial wait([[WAIT4]], [[WAIT5]] : i32, i32) { ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 234c1076e14e3..a78c3e98c9551 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -156,46 +156,29 @@ def DeclareActionAttr : OpenACC_Attr<"DeclareAction", "declare_action"> { } // Device type enumeration. -def OpenACC_DeviceTypeNone : I32EnumAttrCase<"None", 0, "none">; -def OpenACC_DeviceTypeStar : I32EnumAttrCase<"Star", 1, "star">; -def OpenACC_DeviceTypeDefault : I32EnumAttrCase<"Default", 2, "default">; -def OpenACC_DeviceTypeHost : I32EnumAttrCase<"Host", 3, "host">; -def OpenACC_DeviceTypeMulticore : I32EnumAttrCase<"Multicore", 4, "multicore">; -def OpenACC_DeviceTypeNvidia : I32EnumAttrCase<"Nvidia", 5, "nvidia">; -def OpenACC_DeviceTypeRadeon : I32EnumAttrCase<"Radeon", 6, "radeon">; +def OpenACC_DeviceTypeStar : I32EnumAttrCase<"Star", 0, "star">; +def OpenACC_DeviceTypeDefault : I32EnumAttrCase<"Default", 1, "default">; +def OpenACC_DeviceTypeHost : I32EnumAttrCase<"Host", 2, "host">; +def OpenACC_DeviceTypeMulticore : I32EnumAttrCase<"Multicore", 3, "multicore">; +def OpenACC_DeviceTypeNvidia : I32EnumAttrCase<"Nvidia", 4, "nvidia">; +def OpenACC_DeviceTypeRadeon : I32EnumAttrCase<"Radeon", 5, "radeon">; + def OpenACC_DeviceType : I32EnumAttr<"DeviceType", "built-in device type supported by OpenACC", - [OpenACC_DeviceTypeNone, OpenACC_DeviceTypeStar, OpenACC_DeviceTypeDefault, + [OpenACC_DeviceTypeStar, OpenACC_DeviceTypeDefault, OpenACC_DeviceTypeHost, OpenACC_DeviceTypeMulticore, OpenACC_DeviceTypeNvidia, OpenACC_DeviceTypeRadeon ]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::acc"; } - -// Device type attribute is used to associate a value for for clauses that -// appear after a device_type clause. The list of clauses allowed after the -// device_type clause is defined per construct as follows: -// Loop construct: collapse, gang, worker, vector, seq, independent, auto, -// and tile -// Compute construct: async, wait, num_gangs, num_workers, and vector_length -// Data construct: async and wait -// Routine: gang, worker, vector, seq and bind -// -// The `none` means that the value appears before any device_type clause. -// def OpenACC_DeviceTypeAttr : EnumAttr { let assemblyFormat = [{ ```<` $value `>` }]; } -def DeviceTypeArrayAttr : - TypedArrayAttrBase { - let constBuilderCall = ?; -} - // Define a resource for the OpenACC runtime counters. def OpenACC_RuntimeCounters : Resource<"::mlir::acc::RuntimeCounters">; @@ -880,32 +863,24 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", ``` }]; - let arguments = (ins - Variadic:$async, - OptionalAttr:$asyncDeviceType, - OptionalAttr:$asyncOnly, - Variadic:$waitOperands, - OptionalAttr:$waitOperandsSegments, - OptionalAttr:$waitOperandsDeviceType, - OptionalAttr:$waitOnly, - Variadic:$numGangs, - OptionalAttr:$numGangsSegments, - OptionalAttr:$numGangsDeviceType, - Variadic:$numWorkers, - OptionalAttr:$numWorkersDeviceType, - Variadic:$vectorLength, - OptionalAttr:$vectorLengthDeviceType, - Optional:$ifCond, - Optional:$selfCond, - UnitAttr:$selfAttr, - Variadic:$reductionOperands, - OptionalAttr:$reductionRecipes, - Variadic:$gangPrivateOperands, - OptionalAttr:$privatizations, - Variadic:$gangFirstPrivateOperands, - OptionalAttr:$firstprivatizations, - Variadic:$dataClauseOperands, - OptionalAttr:$defaultAttr); + let arguments = (ins Optional:$async, + UnitAttr:$asyncAttr, + Variadic:$waitOperands, + UnitAttr:$waitAttr, + Variadic:$numGangs, + Optional:$numWorkers, + Optional:$vectorLength, + Optional:$ifCond, + Optional:$selfCond, + UnitAttr:$selfAttr, + Variadic:$reductionOperands, + OptionalAttr:$reductionRecipes, + Variadic:$gangPrivateOperands, + OptionalAttr:$privatizations, + Variadic:$gangFirstPrivateOperands, + OptionalAttr:$firstprivatizations, + Variadic:$dataClauseOperands, + OptionalAttr:$defaultAttr); let regions = (region AnyRegion:$region); @@ -915,69 +890,22 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", /// The i-th data operand passed. Value getDataOperand(unsigned i); - - /// Return true if the op has the async attribute for the - /// mlir::acc::DeviceType::None device_type. - bool hasAsyncOnly(); - /// Return true if the op has the async attribute for the given device_type. - bool hasAsyncOnly(mlir::acc::DeviceType deviceType); - /// Return the value of the async clause if present. - mlir::Value getAsyncValue(); - /// Return the value of the async clause for the given device_type if - /// present. - mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType); - - /// Return the value of the num_workers clause if present. - mlir::Value getNumWorkersValue(); - /// Return the value of the num_workers clause for the given device_type if - /// present. - mlir::Value getNumWorkersValue(mlir::acc::DeviceType deviceType); - - /// Return the value of the vector_length clause if present. - mlir::Value getVectorLengthValue(); - /// Return the value of the vector_length clause for the given device_type - /// if present. - mlir::Value getVectorLengthValue(mlir::acc::DeviceType deviceType); - - /// Return the values of the num_gangs clause if present. - mlir::Operation::operand_range getNumGangsValues(); - /// Return the values of the num_gangs clause for the given device_type if - /// present. - mlir::Operation::operand_range - getNumGangsValues(mlir::acc::DeviceType deviceType); - - /// Return true if the op has the wait attribute for the - /// mlir::acc::DeviceType::None device_type. - bool hasWaitOnly(); - /// Return true if the op has the wait attribute for the given device_type. - bool hasWaitOnly(mlir::acc::DeviceType deviceType); - /// Return the values of the wait clause if present. - mlir::Operation::operand_range getWaitValues(); - /// Return the values of the wait clause for the given device_type if - /// present. - mlir::Operation::operand_range - getWaitValues(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ oilist( `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` - | `async` `(` custom($async, - type($async), $asyncDeviceType) `)` + | `async` `(` $async `:` type($async) `)` | `firstprivate` `(` custom($gangFirstPrivateOperands, type($gangFirstPrivateOperands), $firstprivatizations) `)` - | `num_gangs` `(` custom($numGangs, - type($numGangs), $numGangsDeviceType, $numGangsSegments) `)` - | `num_workers` `(` custom($numWorkers, - type($numWorkers), $numWorkersDeviceType) `)` + | `num_gangs` `(` $numGangs `:` type($numGangs) `)` + | `num_workers` `(` $numWorkers `:` type($numWorkers) `)` | `private` `(` custom( $gangPrivateOperands, type($gangPrivateOperands), $privatizations) `)` - | `vector_length` `(` custom($vectorLength, - type($vectorLength), $vectorLengthDeviceType) `)` - | `wait` `(` custom($waitOperands, - type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)` + | `vector_length` `(` $vectorLength `:` type($vectorLength) `)` + | `wait` `(` $waitOperands `:` type($waitOperands) `)` | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` | `reduction` `(` custom( @@ -1011,25 +939,21 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", ``` }]; - let arguments = (ins - Variadic:$async, - OptionalAttr:$asyncDeviceType, - OptionalAttr:$asyncOnly, - Variadic:$waitOperands, - OptionalAttr:$waitOperandsSegments, - OptionalAttr:$waitOperandsDeviceType, - OptionalAttr:$waitOnly, - Optional:$ifCond, - Optional:$selfCond, - UnitAttr:$selfAttr, - Variadic:$reductionOperands, - OptionalAttr:$reductionRecipes, - Variadic:$gangPrivateOperands, - OptionalAttr:$privatizations, - Variadic:$gangFirstPrivateOperands, - OptionalAttr:$firstprivatizations, - Variadic:$dataClauseOperands, - OptionalAttr:$defaultAttr); + let arguments = (ins Optional:$async, + UnitAttr:$asyncAttr, + Variadic:$waitOperands, + UnitAttr:$waitAttr, + Optional:$ifCond, + Optional:$selfCond, + UnitAttr:$selfAttr, + Variadic:$reductionOperands, + OptionalAttr:$reductionRecipes, + Variadic:$gangPrivateOperands, + OptionalAttr:$privatizations, + Variadic:$gangFirstPrivateOperands, + OptionalAttr:$firstprivatizations, + Variadic:$dataClauseOperands, + OptionalAttr:$defaultAttr); let regions = (region AnyRegion:$region); @@ -1039,44 +963,19 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", /// The i-th data operand passed. Value getDataOperand(unsigned i); - - /// Return true if the op has the async attribute for the - /// mlir::acc::DeviceType::None device_type. - bool hasAsyncOnly(); - /// Return true if the op has the async attribute for the given device_type. - bool hasAsyncOnly(mlir::acc::DeviceType deviceType); - /// Return the value of the async clause if present. - mlir::Value getAsyncValue(); - /// Return the value of the async clause for the given device_type if - /// present. - mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType); - - /// Return true if the op has the wait attribute for the - /// mlir::acc::DeviceType::None device_type. - bool hasWaitOnly(); - /// Return true if the op has the wait attribute for the given device_type. - bool hasWaitOnly(mlir::acc::DeviceType deviceType); - /// Return the values of the wait clause if present. - mlir::Operation::operand_range getWaitValues(); - /// Return the values of the wait clause for the given device_type if - /// present. - mlir::Operation::operand_range - getWaitValues(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ oilist( `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` - | `async` `(` custom($async, - type($async), $asyncDeviceType) `)` + | `async` `(` $async `:` type($async) `)` | `firstprivate` `(` custom($gangFirstPrivateOperands, type($gangFirstPrivateOperands), $firstprivatizations) `)` | `private` `(` custom( $gangPrivateOperands, type($gangPrivateOperands), $privatizations) `)` - | `wait` `(` custom($waitOperands, - type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)` + | `wait` `(` $waitOperands `:` type($waitOperands) `)` | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` | `reduction` `(` custom( @@ -1112,26 +1011,18 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", ``` }]; - let arguments = (ins - Variadic:$async, - OptionalAttr:$asyncDeviceType, - OptionalAttr:$asyncOnly, - Variadic:$waitOperands, - OptionalAttr:$waitOperandsSegments, - OptionalAttr:$waitOperandsDeviceType, - OptionalAttr:$waitOnly, - Variadic:$numGangs, - OptionalAttr:$numGangsSegments, - OptionalAttr:$numGangsDeviceType, - Variadic:$numWorkers, - OptionalAttr:$numWorkersDeviceType, - Variadic:$vectorLength, - OptionalAttr:$vectorLengthDeviceType, - Optional:$ifCond, - Optional:$selfCond, - UnitAttr:$selfAttr, - Variadic:$dataClauseOperands, - OptionalAttr:$defaultAttr); + let arguments = (ins Optional:$async, + UnitAttr:$asyncAttr, + Variadic:$waitOperands, + UnitAttr:$waitAttr, + Variadic:$numGangs, + Optional:$numWorkers, + Optional:$vectorLength, + Optional:$ifCond, + Optional:$selfCond, + UnitAttr:$selfAttr, + Variadic:$dataClauseOperands, + OptionalAttr:$defaultAttr); let regions = (region AnyRegion:$region); @@ -1141,63 +1032,16 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", /// The i-th data operand passed. Value getDataOperand(unsigned i); - - /// Return true if the op has the async attribute for the - /// mlir::acc::DeviceType::None device_type. - bool hasAsyncOnly(); - /// Return true if the op has the async attribute for the given device_type. - bool hasAsyncOnly(mlir::acc::DeviceType deviceType); - /// Return the value of the async clause if present. - mlir::Value getAsyncValue(); - /// Return the value of the async clause for the given device_type if - /// present. - mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType); - - /// Return the value of the num_workers clause if present. - mlir::Value getNumWorkersValue(); - /// Return the value of the num_workers clause for the given device_type if - /// present. - mlir::Value getNumWorkersValue(mlir::acc::DeviceType deviceType); - - /// Return the value of the vector_length clause if present. - mlir::Value getVectorLengthValue(); - /// Return the value of the vector_length clause for the given device_type - /// if present. - mlir::Value getVectorLengthValue(mlir::acc::DeviceType deviceType); - - /// Return the values of the num_gangs clause if present. - mlir::Operation::operand_range getNumGangsValues(); - /// Return the values of the num_gangs clause for the given device_type if - /// present. - mlir::Operation::operand_range - getNumGangsValues(mlir::acc::DeviceType deviceType); - - /// Return true if the op has the wait attribute for the - /// mlir::acc::DeviceType::None device_type. - bool hasWaitOnly(); - /// Return true if the op has the wait attribute for the given device_type. - bool hasWaitOnly(mlir::acc::DeviceType deviceType); - /// Return the values of the wait clause if present. - mlir::Operation::operand_range getWaitValues(); - /// Return the values of the wait clause for the given device_type if - /// present. - mlir::Operation::operand_range - getWaitValues(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ oilist( `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` - | `async` `(` custom($async, - type($async), $asyncDeviceType) `)` - | `num_gangs` `(` custom($numGangs, - type($numGangs), $numGangsDeviceType, $numGangsSegments) `)` - | `num_workers` `(` custom($numWorkers, - type($numWorkers), $numWorkersDeviceType) `)` - | `vector_length` `(` custom($vectorLength, - type($vectorLength), $vectorLengthDeviceType) `)` - | `wait` `(` custom($waitOperands, - type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments) `)` + | `async` `(` $async `:` type($async) `)` + | `num_gangs` `(` $numGangs `:` type($numGangs) `)` + | `num_workers` `(` $numWorkers `:` type($numWorkers) `)` + | `vector_length` `(` $vectorLength `:` type($vectorLength) `)` + | `wait` `(` $waitOperands `:` type($waitOperands) `)` | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` ) diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 45e0632db5ef2..df4f7825545c2 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -615,49 +615,15 @@ unsigned ParallelOp::getNumDataOperands() { } Value ParallelOp::getDataOperand(unsigned i) { - unsigned numOptional = getAsync().size(); + unsigned numOptional = getAsync() ? 1 : 0; numOptional += getNumGangs().size(); - numOptional += getNumWorkers().size(); - numOptional += getVectorLength().size(); + numOptional += getNumWorkers() ? 1 : 0; + numOptional += getVectorLength() ? 1 : 0; numOptional += getIfCond() ? 1 : 0; numOptional += getSelfCond() ? 1 : 0; return getOperand(getWaitOperands().size() + numOptional + i); } -template -static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, - ArrayAttr deviceTypes, - llvm::StringRef keyword) { - if (operands.size() > 0 && deviceTypes.getValue().size() != operands.size()) - return op.emitOpError() << keyword << " operands count must match " - << keyword << " device_type count"; - return success(); -} - -template -static LogicalResult verifyDeviceTypeAndSegmentCountMatch( - Op op, OperandRange operands, DenseI32ArrayAttr segments, - ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) { - std::size_t numOperandsInSegments = 0; - - if (!segments) - return success(); - - for (auto segCount : segments.asArrayRef()) { - if (maxInSegment != 0 && segCount > maxInSegment) - return op.emitOpError() << keyword << " expects a maximum of " - << maxInSegment << " values per segment"; - numOperandsInSegments += segCount; - } - if (numOperandsInSegments != operands.size()) - return op.emitOpError() - << keyword << " operand count does not match count in segments"; - if (deviceTypes.getValue().size() != (size_t)segments.size()) - return op.emitOpError() - << keyword << " segment count does not match device_type count"; - return success(); -} - LogicalResult acc::ParallelOp::verify() { if (failed(checkSymOperandList( *this, getPrivatizations(), getGangPrivateOperands(), "private", @@ -667,322 +633,11 @@ LogicalResult acc::ParallelOp::verify() { *this, getReductionRecipes(), getReductionOperands(), "reduction", "reductions", false))) return failure(); - - if (failed(verifyDeviceTypeAndSegmentCountMatch( - *this, getNumGangs(), getNumGangsSegmentsAttr(), - getNumGangsDeviceTypeAttr(), "num_gangs", 3))) - return failure(); - - if (failed(verifyDeviceTypeAndSegmentCountMatch( - *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), - getWaitOperandsDeviceTypeAttr(), "wait"))) - return failure(); - - if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(), - getNumWorkersDeviceTypeAttr(), - "num_workers"))) - return failure(); - - if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(), - getVectorLengthDeviceTypeAttr(), - "vector_length"))) - return failure(); - - if (failed(verifyDeviceTypeCountMatch(*this, getAsync(), - getAsyncDeviceTypeAttr(), "async"))) - return failure(); - + if (getNumGangs().size() > 3) + return emitOpError() << "num_gangs expects a maximum of 3 values"; return checkDataOperands(*this, getDataClauseOperands()); } -static std::optional findSegment(ArrayAttr segments, - mlir::acc::DeviceType deviceType) { - unsigned segmentIdx = 0; - for (auto attr : segments) { - auto deviceTypeAttr = mlir::dyn_cast(attr); - if (deviceTypeAttr.getValue() == deviceType) - return std::make_optional(segmentIdx); - ++segmentIdx; - } - return std::nullopt; -} - -static mlir::Value -getValueInDeviceTypeSegment(std::optional arrayAttr, - mlir::Operation::operand_range range, - mlir::acc::DeviceType deviceType) { - if (!arrayAttr) - return {}; - if (auto pos = findSegment(*arrayAttr, deviceType)) - return range[*pos]; - return {}; -} - -bool acc::ParallelOp::hasAsyncOnly() { - return hasAsyncOnly(mlir::acc::DeviceType::None); -} - -bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getAsyncOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; -} - -mlir::Value acc::ParallelOp::getAsyncValue() { - return getAsyncValue(mlir::acc::DeviceType::None); -} - -mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) { - return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), - deviceType); -} - -mlir::Value acc::ParallelOp::getNumWorkersValue() { - return getNumWorkersValue(mlir::acc::DeviceType::None); -} - -mlir::Value -acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) { - return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(), - deviceType); -} - -mlir::Value acc::ParallelOp::getVectorLengthValue() { - return getVectorLengthValue(mlir::acc::DeviceType::None); -} - -mlir::Value -acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) { - return getValueInDeviceTypeSegment(getVectorLengthDeviceType(), - getVectorLength(), deviceType); -} - -mlir::Operation::operand_range ParallelOp::getNumGangsValues() { - return getNumGangsValues(mlir::acc::DeviceType::None); -} - -static mlir::Operation::operand_range -getValuesFromSegments(std::optional arrayAttr, - mlir::Operation::operand_range range, - std::optional> segments, - mlir::acc::DeviceType deviceType) { - if (!arrayAttr) - return range.take_front(0); - if (auto pos = findSegment(*arrayAttr, deviceType)) { - int32_t nbOperandsBefore = 0; - for (unsigned i = 0; i < *pos; ++i) - nbOperandsBefore += (*segments)[i]; - return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]); - } - return range.take_front(0); -} - -mlir::Operation::operand_range -ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) { - return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(), - getNumGangsSegments(), deviceType); -} - -bool acc::ParallelOp::hasWaitOnly() { - return hasWaitOnly(mlir::acc::DeviceType::None); -} - -bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getWaitOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; -} - -mlir::Operation::operand_range ParallelOp::getWaitValues() { - return getWaitValues(mlir::acc::DeviceType::None); -} - -mlir::Operation::operand_range -ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) { - return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), - getWaitOperandsSegments(), deviceType); -} - -static ParseResult parseNumGangs( - mlir::OpAsmParser &parser, - llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &types, mlir::ArrayAttr &deviceTypes, - mlir::DenseI32ArrayAttr &segments) { - llvm::SmallVector attributes; - llvm::SmallVector seg; - - do { - if (failed(parser.parseLBrace())) - return failure(); - - if (failed(parser.parseCommaSeparatedList( - mlir::AsmParser::Delimiter::None, [&]() { - if (parser.parseOperand(operands.emplace_back()) || - parser.parseColonType(types.emplace_back())) - return failure(); - return success(); - }))) - return failure(); - - seg.push_back(operands.size()); - - if (failed(parser.parseRBrace())) - return failure(); - - if (succeeded(parser.parseOptionalLSquare())) { - if (parser.parseAttribute(attributes.emplace_back()) || - parser.parseRSquare()) - return failure(); - } else { - attributes.push_back(mlir::acc::DeviceTypeAttr::get( - parser.getContext(), mlir::acc::DeviceType::None)); - } - } while (succeeded(parser.parseOptionalComma())); - - llvm::SmallVector arrayAttr(attributes.begin(), - attributes.end()); - deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); - segments = DenseI32ArrayAttr::get(parser.getContext(), seg); - - return success(); -} - -static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, - mlir::OperandRange operands, mlir::TypeRange types, - std::optional deviceTypes, - std::optional segments) { - unsigned opIdx = 0; - for (unsigned i = 0; i < deviceTypes->size(); ++i) { - if (i != 0) - p << ", "; - p << "{"; - for (int32_t j = 0; j < (*segments)[i]; ++j) { - if (j != 0) - p << ", "; - p << operands[opIdx] << " : " << operands[opIdx].getType(); - ++opIdx; - } - p << "}"; - auto deviceTypeAttr = - mlir::dyn_cast((*deviceTypes)[i]); - if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None) - p << " [" << (*deviceTypes)[i] << "]"; - } -} - -static ParseResult parseWaitOperands( - mlir::OpAsmParser &parser, - llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &types, mlir::ArrayAttr &deviceTypes, - mlir::DenseI32ArrayAttr &segments) { - llvm::SmallVector attributes; - llvm::SmallVector seg; - - do { - if (failed(parser.parseLBrace())) - return failure(); - - if (failed(parser.parseCommaSeparatedList( - mlir::AsmParser::Delimiter::None, [&]() { - if (parser.parseOperand(operands.emplace_back()) || - parser.parseColonType(types.emplace_back())) - return failure(); - return success(); - }))) - return failure(); - - seg.push_back(operands.size()); - - if (failed(parser.parseRBrace())) - return failure(); - - if (succeeded(parser.parseOptionalLSquare())) { - if (parser.parseAttribute(attributes.emplace_back()) || - parser.parseRSquare()) - return failure(); - } else { - attributes.push_back(mlir::acc::DeviceTypeAttr::get( - parser.getContext(), mlir::acc::DeviceType::None)); - } - } while (succeeded(parser.parseOptionalComma())); - - llvm::SmallVector arrayAttr(attributes.begin(), - attributes.end()); - deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); - segments = DenseI32ArrayAttr::get(parser.getContext(), seg); - - return success(); -} - -static void printWaitOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, - mlir::OperandRange operands, - mlir::TypeRange types, - std::optional deviceTypes, - std::optional segments) { - unsigned opIdx = 0; - for (unsigned i = 0; i < deviceTypes->size(); ++i) { - if (i != 0) - p << ", "; - p << "{"; - for (int32_t j = 0; j < (*segments)[i]; ++j) { - if (j != 0) - p << ", "; - p << operands[opIdx] << " : " << operands[opIdx].getType(); - ++opIdx; - } - p << "}"; - auto deviceTypeAttr = - mlir::dyn_cast((*deviceTypes)[i]); - if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None) - p << " [" << (*deviceTypes)[i] << "]"; - } -} - -static ParseResult parseDeviceTypeOperands( - mlir::OpAsmParser &parser, - llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &types, mlir::ArrayAttr &deviceTypes) { - llvm::SmallVector attributes; - if (failed(parser.parseCommaSeparatedList([&]() { - if (parser.parseOperand(operands.emplace_back()) || - parser.parseColonType(types.emplace_back())) - return failure(); - if (succeeded(parser.parseOptionalLSquare())) { - if (parser.parseAttribute(attributes.emplace_back()) || - parser.parseRSquare()) - return failure(); - } else { - attributes.push_back(mlir::acc::DeviceTypeAttr::get( - parser.getContext(), mlir::acc::DeviceType::None)); - } - return success(); - }))) - return failure(); - llvm::SmallVector arrayAttr(attributes.begin(), - attributes.end()); - deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); - return success(); -} - -static void -printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, - mlir::OperandRange operands, mlir::TypeRange types, - std::optional deviceTypes) { - for (unsigned i = 0, e = deviceTypes->size(); i < e; ++i) { - if (i != 0) - p << ", "; - p << operands[i] << " : " << operands[i].getType(); - auto deviceTypeAttr = - mlir::dyn_cast((*deviceTypes)[i]); - if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None) - p << " [" << (*deviceTypes)[i] << "]"; - } -} - //===----------------------------------------------------------------------===// // SerialOp //===----------------------------------------------------------------------===// @@ -993,55 +648,12 @@ unsigned SerialOp::getNumDataOperands() { } Value SerialOp::getDataOperand(unsigned i) { - unsigned numOptional = getAsync().size(); + unsigned numOptional = getAsync() ? 1 : 0; numOptional += getIfCond() ? 1 : 0; numOptional += getSelfCond() ? 1 : 0; return getOperand(getWaitOperands().size() + numOptional + i); } -bool acc::SerialOp::hasAsyncOnly() { - return hasAsyncOnly(mlir::acc::DeviceType::None); -} - -bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getAsyncOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; -} - -mlir::Value acc::SerialOp::getAsyncValue() { - return getAsyncValue(mlir::acc::DeviceType::None); -} - -mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) { - return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), - deviceType); -} - -bool acc::SerialOp::hasWaitOnly() { - return hasWaitOnly(mlir::acc::DeviceType::None); -} - -bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getWaitOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; -} - -mlir::Operation::operand_range SerialOp::getWaitValues() { - return getWaitValues(mlir::acc::DeviceType::None); -} - -mlir::Operation::operand_range -SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) { - return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), - getWaitOperandsSegments(), deviceType); -} - LogicalResult acc::SerialOp::verify() { if (failed(checkSymOperandList( *this, getPrivatizations(), getGangPrivateOperands(), "private", @@ -1051,16 +663,6 @@ LogicalResult acc::SerialOp::verify() { *this, getReductionRecipes(), getReductionOperands(), "reduction", "reductions", false))) return failure(); - - if (failed(verifyDeviceTypeAndSegmentCountMatch( - *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), - getWaitOperandsDeviceTypeAttr(), "wait"))) - return failure(); - - if (failed(verifyDeviceTypeCountMatch(*this, getAsync(), - getAsyncDeviceTypeAttr(), "async"))) - return failure(); - return checkDataOperands(*this, getDataClauseOperands()); } @@ -1073,114 +675,19 @@ unsigned KernelsOp::getNumDataOperands() { } Value KernelsOp::getDataOperand(unsigned i) { - unsigned numOptional = getAsync().size(); + unsigned numOptional = getAsync() ? 1 : 0; numOptional += getWaitOperands().size(); numOptional += getNumGangs().size(); - numOptional += getNumWorkers().size(); - numOptional += getVectorLength().size(); + numOptional += getNumWorkers() ? 1 : 0; + numOptional += getVectorLength() ? 1 : 0; numOptional += getIfCond() ? 1 : 0; numOptional += getSelfCond() ? 1 : 0; return getOperand(numOptional + i); } -bool acc::KernelsOp::hasAsyncOnly() { - return hasAsyncOnly(mlir::acc::DeviceType::None); -} - -bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getAsyncOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; -} - -mlir::Value acc::KernelsOp::getAsyncValue() { - return getAsyncValue(mlir::acc::DeviceType::None); -} - -mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) { - return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), - deviceType); -} - -mlir::Value acc::KernelsOp::getNumWorkersValue() { - return getNumWorkersValue(mlir::acc::DeviceType::None); -} - -mlir::Value -acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) { - return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(), - deviceType); -} - -mlir::Value acc::KernelsOp::getVectorLengthValue() { - return getVectorLengthValue(mlir::acc::DeviceType::None); -} - -mlir::Value -acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) { - return getValueInDeviceTypeSegment(getVectorLengthDeviceType(), - getVectorLength(), deviceType); -} - -mlir::Operation::operand_range KernelsOp::getNumGangsValues() { - return getNumGangsValues(mlir::acc::DeviceType::None); -} - -mlir::Operation::operand_range -KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) { - return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(), - getNumGangsSegments(), deviceType); -} - -bool acc::KernelsOp::hasWaitOnly() { - return hasWaitOnly(mlir::acc::DeviceType::None); -} - -bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { - if (auto arrayAttr = getWaitOnly()) { - if (findSegment(*arrayAttr, deviceType)) - return true; - } - return false; -} - -mlir::Operation::operand_range KernelsOp::getWaitValues() { - return getWaitValues(mlir::acc::DeviceType::None); -} - -mlir::Operation::operand_range -KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) { - return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), - getWaitOperandsSegments(), deviceType); -} - LogicalResult acc::KernelsOp::verify() { - if (failed(verifyDeviceTypeAndSegmentCountMatch( - *this, getNumGangs(), getNumGangsSegmentsAttr(), - getNumGangsDeviceTypeAttr(), "num_gangs", 3))) - return failure(); - - if (failed(verifyDeviceTypeAndSegmentCountMatch( - *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), - getWaitOperandsDeviceTypeAttr(), "wait"))) - return failure(); - - if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(), - getNumWorkersDeviceTypeAttr(), - "num_workers"))) - return failure(); - - if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(), - getVectorLengthDeviceTypeAttr(), - "vector_length"))) - return failure(); - - if (failed(verifyDeviceTypeCountMatch(*this, getAsync(), - getAsyncDeviceTypeAttr(), "async"))) - return failure(); - + if (getNumGangs().size() > 3) + return emitOpError() << "num_gangs expects a maximum of 3 values"; return checkDataOperands(*this, getDataClauseOperands()); } diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir index c18d964b370f2..b9ac68d0592c8 100644 --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -462,8 +462,8 @@ acc.loop gang() { // ----- %i64value = arith.constant 1 : i64 -// expected-error@+1 {{num_gangs expects a maximum of 3 values per segment}} -acc.parallel num_gangs({%i64value: i64, %i64value : i64, %i64value : i64, %i64value : i64}) { +// expected-error@+1 {{num_gangs expects a maximum of 3 values}} +acc.parallel num_gangs(%i64value, %i64value, %i64value, %i64value : i64, i64, i64, i64) { } // ----- diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index 5a95811685f84..05b0450c7fb91 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -137,7 +137,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x %pd = acc.present varPtr(%d : memref<10xf32>) -> memref<10xf32> acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) { %private = acc.private varPtr(%c : memref<10xf32>) -> memref<10xf32> - acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type]) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) { + acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) { acc.loop gang { scf.for %x = %lb to %c10 step %st { acc.loop worker { @@ -180,7 +180,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x // CHECK-NEXT: [[NUMWORKERS:%.*]] = arith.constant 10 : i64 // CHECK: acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) { // CHECK-NEXT: %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) -> memref<10xf32> -// CHECK-NEXT: acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type]) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) { +// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) { // CHECK-NEXT: acc.loop gang { // CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] { // CHECK-NEXT: acc.loop worker { @@ -439,25 +439,25 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x } acc.parallel async(%idxValue: index) { } - acc.parallel wait({%i64value: i64}) { + acc.parallel wait(%i64value: i64) { } - acc.parallel wait({%i32value: i32}) { + acc.parallel wait(%i32value: i32) { } - acc.parallel wait({%idxValue: index}) { + acc.parallel wait(%idxValue: index) { } - acc.parallel wait({%i64value : i64, %i32value : i32, %idxValue : index}) { + acc.parallel wait(%i64value, %i32value, %idxValue : i64, i32, index) { } - acc.parallel num_gangs({%i64value: i64}) { + acc.parallel num_gangs(%i64value: i64) { } - acc.parallel num_gangs({%i32value: i32}) { + acc.parallel num_gangs(%i32value: i32) { } - acc.parallel num_gangs({%idxValue: index}) { + acc.parallel num_gangs(%idxValue: index) { } - acc.parallel num_gangs({%i64value: i64, %i64value: i64, %idxValue: index}) { + acc.parallel num_gangs(%i64value, %i64value, %idxValue : i64, i64, index) { } - acc.parallel num_workers(%i64value: i64 [#acc.device_type]) { + acc.parallel num_workers(%i64value: i64) { } - acc.parallel num_workers(%i32value: i32 [#acc.device_type]) { + acc.parallel num_workers(%i32value: i32) { } acc.parallel num_workers(%idxValue: index) { } @@ -492,25 +492,25 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x // CHECK-NEXT: } // CHECK: acc.parallel async([[IDXVALUE]] : index) { // CHECK-NEXT: } -// CHECK: acc.parallel wait({[[I64VALUE]] : i64}) { +// CHECK: acc.parallel wait([[I64VALUE]] : i64) { // CHECK-NEXT: } -// CHECK: acc.parallel wait({[[I32VALUE]] : i32}) { +// CHECK: acc.parallel wait([[I32VALUE]] : i32) { // CHECK-NEXT: } -// CHECK: acc.parallel wait({[[IDXVALUE]] : index}) { +// CHECK: acc.parallel wait([[IDXVALUE]] : index) { // CHECK-NEXT: } -// CHECK: acc.parallel wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) { +// CHECK: acc.parallel wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) { // CHECK-NEXT: } -// CHECK: acc.parallel num_gangs({[[I64VALUE]] : i64}) { +// CHECK: acc.parallel num_gangs([[I64VALUE]] : i64) { // CHECK-NEXT: } -// CHECK: acc.parallel num_gangs({[[I32VALUE]] : i32}) { +// CHECK: acc.parallel num_gangs([[I32VALUE]] : i32) { // CHECK-NEXT: } -// CHECK: acc.parallel num_gangs({[[IDXVALUE]] : index}) { +// CHECK: acc.parallel num_gangs([[IDXVALUE]] : index) { // CHECK-NEXT: } -// CHECK: acc.parallel num_gangs({[[I64VALUE]] : i64, [[I64VALUE]] : i64, [[IDXVALUE]] : index}) { +// CHECK: acc.parallel num_gangs([[I64VALUE]], [[I64VALUE]], [[IDXVALUE]] : i64, i64, index) { // CHECK-NEXT: } -// CHECK: acc.parallel num_workers([[I64VALUE]] : i64 [#acc.device_type]) { +// CHECK: acc.parallel num_workers([[I64VALUE]] : i64) { // CHECK-NEXT: } -// CHECK: acc.parallel num_workers([[I32VALUE]] : i32 [#acc.device_type]) { +// CHECK: acc.parallel num_workers([[I32VALUE]] : i32) { // CHECK-NEXT: } // CHECK: acc.parallel num_workers([[IDXVALUE]] : index) { // CHECK-NEXT: } @@ -590,13 +590,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 } acc.serial async(%idxValue: index) { } - acc.serial wait({%i64value: i64}) { + acc.serial wait(%i64value: i64) { } - acc.serial wait({%i32value: i32}) { + acc.serial wait(%i32value: i32) { } - acc.serial wait({%idxValue: index}) { + acc.serial wait(%idxValue: index) { } - acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) { + acc.serial wait(%i64value, %i32value, %idxValue : i64, i32, index) { } %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) -> memref<10xf32> acc.serial private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(@firstprivatization_memref_10xf32 -> %firstprivate : memref<10xf32>) { @@ -627,13 +627,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 // CHECK-NEXT: } // CHECK: acc.serial async([[IDXVALUE]] : index) { // CHECK-NEXT: } -// CHECK: acc.serial wait({[[I64VALUE]] : i64}) { +// CHECK: acc.serial wait([[I64VALUE]] : i64) { // CHECK-NEXT: } -// CHECK: acc.serial wait({[[I32VALUE]] : i32}) { +// CHECK: acc.serial wait([[I32VALUE]] : i32) { // CHECK-NEXT: } -// CHECK: acc.serial wait({[[IDXVALUE]] : index}) { +// CHECK: acc.serial wait([[IDXVALUE]] : index) { // CHECK-NEXT: } -// CHECK: acc.serial wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) { +// CHECK: acc.serial wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) { // CHECK-NEXT: } // CHECK: %[[FIRSTP:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) -> memref<10xf32> // CHECK: acc.serial firstprivate(@firstprivatization_memref_10xf32 -> %[[FIRSTP]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) { @@ -665,13 +665,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 } acc.kernels async(%idxValue: index) { } - acc.kernels wait({%i64value: i64}) { + acc.kernels wait(%i64value: i64) { } - acc.kernels wait({%i32value: i32}) { + acc.kernels wait(%i32value: i32) { } - acc.kernels wait({%idxValue: index}) { + acc.kernels wait(%idxValue: index) { } - acc.kernels wait({%i64value : i64, %i32value : i32, %idxValue : index}) { + acc.kernels wait(%i64value, %i32value, %idxValue : i64, i32, index) { } acc.kernels { } attributes {defaultAttr = #acc} @@ -699,13 +699,13 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 // CHECK-NEXT: } // CHECK: acc.kernels async([[IDXVALUE]] : index) { // CHECK-NEXT: } -// CHECK: acc.kernels wait({[[I64VALUE]] : i64}) { +// CHECK: acc.kernels wait([[I64VALUE]] : i64) { // CHECK-NEXT: } -// CHECK: acc.kernels wait({[[I32VALUE]] : i32}) { +// CHECK: acc.kernels wait([[I32VALUE]] : i32) { // CHECK-NEXT: } -// CHECK: acc.kernels wait({[[IDXVALUE]] : index}) { +// CHECK: acc.kernels wait([[IDXVALUE]] : index) { // CHECK-NEXT: } -// CHECK: acc.kernels wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) { +// CHECK: acc.kernels wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) { // CHECK-NEXT: } // CHECK: acc.kernels { // CHECK-NEXT: } attributes {defaultAttr = #acc} diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 13393569f36fe..2dec4ba3c001e 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -10,7 +10,6 @@ add_subdirectory(ArmSME) add_subdirectory(Index) add_subdirectory(LLVMIR) add_subdirectory(MemRef) -add_subdirectory(OpenACC) add_subdirectory(SCF) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt deleted file mode 100644 index 5133d7fc38296..0000000000000 --- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -add_mlir_unittest(MLIROpenACCTests - OpenACCOpsTest.cpp -) -target_link_libraries(MLIROpenACCTests - PRIVATE - MLIRIR - MLIROpenACCDialect -) diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp deleted file mode 100644 index dcf6c1240c55d..0000000000000 --- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp +++ /dev/null @@ -1,275 +0,0 @@ -//===- OpenACCOpsTest.cpp - OpenACC ops extra functiosn Tests -------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/OpenACC/OpenACC.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OwningOpRef.h" -#include "gtest/gtest.h" - -using namespace mlir; -using namespace mlir::acc; - -//===----------------------------------------------------------------------===// -// Test Fixture -//===----------------------------------------------------------------------===// - -class OpenACCOpsTest : public ::testing::Test { -protected: - OpenACCOpsTest() : b(&context), loc(UnknownLoc::get(&context)) { - context.loadDialect(); - } - - MLIRContext context; - OpBuilder b; - Location loc; - llvm::SmallVector dtypes = { - DeviceType::None, DeviceType::Star, DeviceType::Multicore, - DeviceType::Default, DeviceType::Host, DeviceType::Nvidia, - DeviceType::Radeon}; - llvm::SmallVector dtypesWithoutNone = { - DeviceType::Star, DeviceType::Multicore, DeviceType::Default, - DeviceType::Host, DeviceType::Nvidia, DeviceType::Radeon}; -}; - -template -void testAsyncOnly(OpBuilder &b, MLIRContext &context, Location loc, - llvm::SmallVector &dtypes) { - Op op = b.create(loc, TypeRange{}, ValueRange{}); - EXPECT_FALSE(op.hasAsyncOnly()); - for (auto d : dtypes) - EXPECT_FALSE(op.hasAsyncOnly(d)); - - auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None); - op.setAsyncOnlyAttr(b.getArrayAttr({dtypeNone})); - EXPECT_TRUE(op.hasAsyncOnly()); - EXPECT_TRUE(op.hasAsyncOnly(DeviceType::None)); - op.removeAsyncOnlyAttr(); - - auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host); - op.setAsyncOnlyAttr(b.getArrayAttr({dtypeHost})); - EXPECT_TRUE(op.hasAsyncOnly(DeviceType::Host)); - EXPECT_FALSE(op.hasAsyncOnly()); - op.removeAsyncOnlyAttr(); - - auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star); - op.setAsyncOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar})); - EXPECT_TRUE(op.hasAsyncOnly(DeviceType::Star)); - EXPECT_TRUE(op.hasAsyncOnly(DeviceType::Host)); - EXPECT_FALSE(op.hasAsyncOnly()); -} - -TEST_F(OpenACCOpsTest, asyncOnlyTest) { - testAsyncOnly(b, context, loc, dtypes); - testAsyncOnly(b, context, loc, dtypes); - testAsyncOnly(b, context, loc, dtypes); -} - -template -void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc, - llvm::SmallVector &dtypes) { - Op op = b.create(loc, TypeRange{}, ValueRange{}); - - mlir::Value empty; - EXPECT_EQ(op.getAsyncValue(), empty); - for (auto d : dtypes) - EXPECT_EQ(op.getAsyncValue(d), empty); - - mlir::Value val = b.create(loc, b.getI32IntegerAttr(1)); - auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia); - op.setAsyncDeviceTypeAttr(b.getArrayAttr({dtypeNvidia})); - op.getAsyncMutable().assign(val); - EXPECT_EQ(op.getAsyncValue(), empty); - EXPECT_EQ(op.getAsyncValue(DeviceType::Nvidia), val); -} - -TEST_F(OpenACCOpsTest, asyncValueTest) { - testAsyncValue(b, context, loc, dtypes); - testAsyncValue(b, context, loc, dtypes); - testAsyncValue(b, context, loc, dtypes); -} - -template -void testNumGangsValues(OpBuilder &b, MLIRContext &context, Location loc, - llvm::SmallVector &dtypes, - llvm::SmallVector &dtypesWithoutNone) { - Op op = b.create(loc, TypeRange{}, ValueRange{}); - EXPECT_EQ(op.getNumGangsValues().begin(), op.getNumGangsValues().end()); - - mlir::Value val1 = b.create(loc, b.getI32IntegerAttr(1)); - mlir::Value val2 = b.create(loc, b.getI32IntegerAttr(4)); - auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None); - op.getNumGangsMutable().assign(val1); - op.setNumGangsDeviceTypeAttr(b.getArrayAttr({dtypeNone})); - op.setNumGangsSegments(b.getDenseI32ArrayAttr({1})); - EXPECT_EQ(op.getNumGangsValues().front(), val1); - for (auto d : dtypesWithoutNone) - EXPECT_EQ(op.getNumGangsValues(d).begin(), op.getNumGangsValues(d).end()); - - op.getNumGangsMutable().clear(); - op.removeNumGangsDeviceTypeAttr(); - op.removeNumGangsSegmentsAttr(); - for (auto d : dtypes) - EXPECT_EQ(op.getNumGangsValues(d).begin(), op.getNumGangsValues(d).end()); - - op.getNumGangsMutable().append(val1); - op.getNumGangsMutable().append(val2); - op.setNumGangsDeviceTypeAttr( - b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host), - DeviceTypeAttr::get(&context, DeviceType::Star)})); - op.setNumGangsSegments(b.getDenseI32ArrayAttr({1, 1})); - EXPECT_EQ(op.getNumGangsValues(DeviceType::None).begin(), - op.getNumGangsValues(DeviceType::None).end()); - EXPECT_EQ(op.getNumGangsValues(DeviceType::Host).front(), val1); - EXPECT_EQ(op.getNumGangsValues(DeviceType::Star).front(), val2); - - op.getNumGangsMutable().clear(); - op.removeNumGangsDeviceTypeAttr(); - op.removeNumGangsSegmentsAttr(); - for (auto d : dtypes) - EXPECT_EQ(op.getNumGangsValues(d).begin(), op.getNumGangsValues(d).end()); - - op.getNumGangsMutable().append(val1); - op.getNumGangsMutable().append(val2); - op.getNumGangsMutable().append(val1); - op.setNumGangsDeviceTypeAttr( - b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Default), - DeviceTypeAttr::get(&context, DeviceType::Multicore)})); - op.setNumGangsSegments(b.getDenseI32ArrayAttr({2, 1})); - EXPECT_EQ(op.getNumGangsValues(DeviceType::None).begin(), - op.getNumGangsValues(DeviceType::None).end()); - EXPECT_EQ(op.getNumGangsValues(DeviceType::Default).front(), val1); - EXPECT_EQ(op.getNumGangsValues(DeviceType::Default).drop_front().front(), - val2); - EXPECT_EQ(op.getNumGangsValues(DeviceType::Multicore).front(), val1); -} - -TEST_F(OpenACCOpsTest, numGangsValuesTest) { - testNumGangsValues(b, context, loc, dtypes, dtypesWithoutNone); - testNumGangsValues(b, context, loc, dtypes, dtypesWithoutNone); -} - -template -void testVectorLength(OpBuilder &b, MLIRContext &context, Location loc, - llvm::SmallVector &dtypes) { - auto op = b.create(loc, TypeRange{}, ValueRange{}); - - mlir::Value empty; - EXPECT_EQ(op.getVectorLengthValue(), empty); - for (auto d : dtypes) - EXPECT_EQ(op.getVectorLengthValue(d), empty); - - mlir::Value val = b.create(loc, b.getI32IntegerAttr(1)); - auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia); - op.setVectorLengthDeviceTypeAttr(b.getArrayAttr({dtypeNvidia})); - op.getVectorLengthMutable().assign(val); - EXPECT_EQ(op.getVectorLengthValue(), empty); - EXPECT_EQ(op.getVectorLengthValue(DeviceType::Nvidia), val); -} - -TEST_F(OpenACCOpsTest, vectorLengthTest) { - testVectorLength(b, context, loc, dtypes); - testVectorLength(b, context, loc, dtypes); -} - -template -void testWaitOnly(OpBuilder &b, MLIRContext &context, Location loc, - llvm::SmallVector &dtypes, - llvm::SmallVector &dtypesWithoutNone) { - Op op = b.create(loc, TypeRange{}, ValueRange{}); - EXPECT_FALSE(op.hasWaitOnly()); - for (auto d : dtypes) - EXPECT_FALSE(op.hasWaitOnly(d)); - - auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None); - op.setWaitOnlyAttr(b.getArrayAttr({dtypeNone})); - EXPECT_TRUE(op.hasWaitOnly()); - EXPECT_TRUE(op.hasWaitOnly(DeviceType::None)); - for (auto d : dtypesWithoutNone) - EXPECT_FALSE(op.hasWaitOnly(d)); - op.removeWaitOnlyAttr(); - - auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host); - op.setWaitOnlyAttr(b.getArrayAttr({dtypeHost})); - EXPECT_TRUE(op.hasWaitOnly(DeviceType::Host)); - EXPECT_FALSE(op.hasWaitOnly()); - op.removeWaitOnlyAttr(); - - auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star); - op.setWaitOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar})); - EXPECT_TRUE(op.hasWaitOnly(DeviceType::Star)); - EXPECT_TRUE(op.hasWaitOnly(DeviceType::Host)); - EXPECT_FALSE(op.hasWaitOnly()); -} - -TEST_F(OpenACCOpsTest, waitOnlyTest) { - testWaitOnly(b, context, loc, dtypes, dtypesWithoutNone); - testWaitOnly(b, context, loc, dtypes, dtypesWithoutNone); - testWaitOnly(b, context, loc, dtypes, dtypesWithoutNone); -} - -template -void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc, - llvm::SmallVector &dtypes, - llvm::SmallVector &dtypesWithoutNone) { - Op op = b.create(loc, TypeRange{}, ValueRange{}); - EXPECT_EQ(op.getWaitValues().begin(), op.getWaitValues().end()); - - mlir::Value val1 = b.create(loc, b.getI32IntegerAttr(1)); - mlir::Value val2 = b.create(loc, b.getI32IntegerAttr(4)); - auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None); - op.getWaitOperandsMutable().assign(val1); - op.setWaitOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNone})); - op.setWaitOperandsSegments(b.getDenseI32ArrayAttr({1})); - EXPECT_EQ(op.getWaitValues().front(), val1); - for (auto d : dtypesWithoutNone) - EXPECT_EQ(op.getWaitValues(d).begin(), op.getWaitValues(d).end()); - - op.getWaitOperandsMutable().clear(); - op.removeWaitOperandsDeviceTypeAttr(); - op.removeWaitOperandsSegmentsAttr(); - for (auto d : dtypes) - EXPECT_EQ(op.getWaitValues(d).begin(), op.getWaitValues(d).end()); - - op.getWaitOperandsMutable().append(val1); - op.getWaitOperandsMutable().append(val2); - op.setWaitOperandsDeviceTypeAttr( - b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host), - DeviceTypeAttr::get(&context, DeviceType::Star)})); - op.setWaitOperandsSegments(b.getDenseI32ArrayAttr({1, 1})); - EXPECT_EQ(op.getWaitValues(DeviceType::None).begin(), - op.getWaitValues(DeviceType::None).end()); - EXPECT_EQ(op.getWaitValues(DeviceType::Host).front(), val1); - EXPECT_EQ(op.getWaitValues(DeviceType::Star).front(), val2); - - op.getWaitOperandsMutable().clear(); - op.removeWaitOperandsDeviceTypeAttr(); - op.removeWaitOperandsSegmentsAttr(); - for (auto d : dtypes) - EXPECT_EQ(op.getWaitValues(d).begin(), op.getWaitValues(d).end()); - - op.getWaitOperandsMutable().append(val1); - op.getWaitOperandsMutable().append(val2); - op.getWaitOperandsMutable().append(val1); - op.setWaitOperandsDeviceTypeAttr( - b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Default), - DeviceTypeAttr::get(&context, DeviceType::Multicore)})); - op.setWaitOperandsSegments(b.getDenseI32ArrayAttr({2, 1})); - EXPECT_EQ(op.getWaitValues(DeviceType::None).begin(), - op.getWaitValues(DeviceType::None).end()); - EXPECT_EQ(op.getWaitValues(DeviceType::Default).front(), val1); - EXPECT_EQ(op.getWaitValues(DeviceType::Default).drop_front().front(), val2); - EXPECT_EQ(op.getWaitValues(DeviceType::Multicore).front(), val1); -} - -TEST_F(OpenACCOpsTest, waitValuesTest) { - testWaitValues(b, context, loc, dtypes, dtypesWithoutNone); - testWaitValues(b, context, loc, dtypes, dtypesWithoutNone); - testWaitValues(b, context, loc, dtypes, dtypesWithoutNone); -}