diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index abac3ff585768..f6db839000690 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -898,16 +898,16 @@ static void genACC(Fortran::lower::AbstractConverter &converter, const auto &accClauseList = std::get(waitConstruct.t); - mlir::Value ifCond, waitDevnum, async; - SmallVector waitOperands; + mlir::Value ifCond, asyncOperand, waitDevnum, async; + SmallVector waitOperands; // Async clause have optional values but can be present with // no value as well. When there is no value, the op has an attribute to // represent the clause. bool addAsyncAttr = false; - auto &firOpBuilder = converter.getFirOpBuilder(); - auto currentLocation = converter.getCurrentLocation(); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); Fortran::lower::StatementContext stmtCtx; if (waitArgument) { // wait has a value. @@ -930,35 +930,26 @@ static void genACC(Fortran::lower::AbstractConverter &converter, // Lower clauses values mapped to operands. // Keep track of each group of operands separatly as clauses can appear // more than once. - for (const auto &clause : accClauseList.v) { + for (const Fortran::parser::AccClause &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - mlir::Value cond = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(ifClause->v), stmtCtx)); - ifCond = firOpBuilder.createConvert(currentLocation, - firOpBuilder.getI1Type(), cond); + genIfClause(converter, ifClause, ifCond, stmtCtx); } else if (const auto *asyncClause = std::get_if(&clause.u)) { - const auto &asyncClauseValue = asyncClause->v; - if (asyncClauseValue) { // async has a value. - async = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)); - } else { - addAsyncAttr = true; - } + genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx); } } // Prepare the operand segement size attribute and the operands value range. - SmallVector operands; - SmallVector operandSegments; + SmallVector operands; + SmallVector operandSegments; addOperands(operands, operandSegments, waitOperands); addOperand(operands, operandSegments, async); addOperand(operands, operandSegments, waitDevnum); addOperand(operands, operandSegments, ifCond); - auto waitOp = createSimpleOp(firOpBuilder, currentLocation, - operands, operandSegments); + mlir::acc::WaitOp waitOp = createSimpleOp( + firOpBuilder, currentLocation, operands, operandSegments); if (addAsyncAttr) waitOp.asyncAttr(firOpBuilder.getUnitAttr()); diff --git a/flang/test/Lower/OpenACC/acc-wait.f90 b/flang/test/Lower/OpenACC/acc-wait.f90 new file mode 100644 index 0000000000000..70285999895f7 --- /dev/null +++ b/flang/test/Lower/OpenACC/acc-wait.f90 @@ -0,0 +1,41 @@ +! This test checks lowering of OpenACC wait directive. + +! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s + +subroutine acc_update + integer :: async = 1 + logical :: ifCondition = .TRUE. + + !$acc wait +!CHECK: acc.wait{{$}} + + !$acc wait if(.true.) +!CHECK: [[IF1:%.*]] = arith.constant true +!CHECK: acc.wait if([[IF1]]){{$}} + + !$acc wait if(ifCondition) +!CHECK: [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref> +!CHECK: [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1 +!CHECK: acc.wait if([[IF2]]){{$}} + + !$acc wait(1, 2) +!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 +!CHECK: [[WAIT2:%.*]] = arith.constant 2 : i32 +!CHECK: acc.wait([[WAIT1]], [[WAIT2]] : i32, i32){{$}} + + !$acc wait(1) async +!CHECK: [[WAIT3:%.*]] = arith.constant 1 : i32 +!CHECK: acc.wait([[WAIT3]] : i32) attributes {async} + + !$acc wait(1) async(async) +!CHECK: [[WAIT3:%.*]] = arith.constant 1 : i32 +!CHECK: [[ASYNC1:%.*]] = fir.load %{{.*}} : !fir.ref +!CHECK: acc.wait([[WAIT3]] : i32) async([[ASYNC1]] : i32){{$}} + + !$acc wait(devnum: 3: queues: 1, 2) +!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 +!CHECK: [[WAIT2:%.*]] = arith.constant 2 : i32 +!CHECK: [[DEVNUM:%.*]] = arith.constant 3 : i32 +!CHECK: acc.wait([[WAIT1]], [[WAIT2]] : i32, i32) wait_devnum([[DEVNUM]] : i32){{$}} + +end subroutine acc_update