diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index fca8bbdea1acd..ed20ee2288c43 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -25,6 +25,9 @@ using namespace mlir; +// Special value for * passed in device_type or gang clauses. +static constexpr std::int64_t starCst{-1}; + static const Fortran::parser::Name * getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) { const auto *dataRef{std::get_if(&designator.u)}; @@ -133,6 +136,27 @@ static void genAsyncClause(Fortran::lower::AbstractConverter &converter, } } +static void genDeviceTypeClause( + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccClause::DeviceType *deviceTypeClause, + SmallVectorImpl &operands, + Fortran::lower::StatementContext &stmtCtx) { + const auto &deviceTypeValue = deviceTypeClause->v; + if (deviceTypeValue) { + for (const auto &scalarIntExpr : *deviceTypeValue) { + mlir::Value expr = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(scalarIntExpr), stmtCtx)); + operands.push_back(expr); + } + } else { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + // * was passed as value and will be represented as a special constant. + mlir::Value star = firOpBuilder.createIntegerConstant( + converter.getCurrentLocation(), firOpBuilder.getIndexType(), starCst); + operands.push_back(star); + } +} + static void genIfClause(Fortran::lower::AbstractConverter &converter, const Fortran::parser::AccClause::If *ifClause, mlir::Value &ifCond, @@ -738,22 +762,19 @@ static void genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter, const Fortran::parser::AccClauseList &accClauseList) { mlir::Value ifCond, deviceNum; - SmallVector deviceTypeOperands; + SmallVector deviceTypeOperands; - auto &firOpBuilder = converter.getFirOpBuilder(); - auto currentLocation = converter.getCurrentLocation(); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); Fortran::lower::StatementContext stmtCtx; // Lower clauses values mapped to operands. // Keep track of each group of operands separatly as clauses can appear // more than once. - for (const auto &clause : accClauseList.v) { + for (const Fortran::parser::AccClause &clause : accClauseList.v) { if (const auto *ifClause = std::get_if(&clause.u)) { - mlir::Value cond = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(ifClause->v), stmtCtx)); - ifCond = firOpBuilder.createConvert(currentLocation, - firOpBuilder.getI1Type(), cond); + genIfClause(converter, ifClause, ifCond, stmtCtx); } else if (const auto *deviceNumClause = std::get_if( &clause.u)) { @@ -762,21 +783,8 @@ genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter, } else if (const auto *deviceTypeClause = std::get_if( &clause.u)) { - - const auto &deviceTypeValue = deviceTypeClause->v; - if (deviceTypeValue) { - for (const auto &scalarIntExpr : *deviceTypeValue) { - mlir::Value expr = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(scalarIntExpr), stmtCtx)); - deviceTypeOperands.push_back(expr); - } - } else { - // * was passed as value and will be represented as a -1 constant - // integer. - mlir::Value star = firOpBuilder.createIntegerConstant( - currentLocation, firOpBuilder.getIntegerType(32), /* STAR */ -1); - deviceTypeOperands.push_back(star); - } + genDeviceTypeClause(converter, deviceTypeClause, deviceTypeOperands, + stmtCtx); } } diff --git a/flang/test/Lower/OpenACC/acc-init.f90 b/flang/test/Lower/OpenACC/acc-init.f90 new file mode 100644 index 0000000000000..d20e6f8d5204c --- /dev/null +++ b/flang/test/Lower/OpenACC/acc-init.f90 @@ -0,0 +1,30 @@ +! This test checks lowering of OpenACC init directive. + +! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s + +subroutine acc_init + logical :: ifCondition = .TRUE. + + !$acc init +!CHECK: acc.init{{$}} + + !$acc init if(.true.) +!CHECK: [[IF1:%.*]] = arith.constant true +!CHECK: acc.init if([[IF1]]){{$}} + + !$acc init if(ifCondition) +!CHECK: [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref> +!CHECK: [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1 +!CHECK: acc.init if([[IF2]]){{$}} + + !$acc init device_num(1) +!CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32 +!CHECK: acc.init device_num([[DEVNUM]] : i32){{$}} + + !$acc init device_num(1) device_type(1, 2) +!CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32 +!CHECK: [[DEVTYPE1:%.*]] = arith.constant 1 : i32 +!CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32 +!CHECK: acc.init device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) device_num([[DEVNUM]] : i32){{$}} + +end subroutine acc_init \ No newline at end of file diff --git a/flang/test/Lower/OpenACC/acc-shutdown.f90 b/flang/test/Lower/OpenACC/acc-shutdown.f90 new file mode 100644 index 0000000000000..6750d8685905c --- /dev/null +++ b/flang/test/Lower/OpenACC/acc-shutdown.f90 @@ -0,0 +1,30 @@ +! This test checks lowering of OpenACC shutdown directive. + +! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s + +subroutine acc_shutdown + logical :: ifCondition = .TRUE. + + !$acc shutdown +!CHECK: acc.shutdown{{$}} + + !$acc shutdown if(.true.) +!CHECK: [[IF1:%.*]] = arith.constant true +!CHECK: acc.shutdown if([[IF1]]){{$}} + + !$acc shutdown if(ifCondition) +!CHECK: [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref> +!CHECK: [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1 +!CHECK: acc.shutdown if([[IF2]]){{$}} + + !$acc shutdown device_num(1) +!CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32 +!CHECK: acc.shutdown device_num([[DEVNUM]] : i32){{$}} + + !$acc shutdown device_num(1) device_type(1, 2) +!CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32 +!CHECK: [[DEVTYPE1:%.*]] = arith.constant 1 : i32 +!CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32 +!CHECK: acc.shutdown device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) device_num([[DEVNUM]] : i32){{$}} + +end subroutine acc_shutdown \ No newline at end of file