diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index 1c0db205920cb..4c77ad49c6c01 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -668,15 +668,14 @@ createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter, } } -static void -createTargetDataOp(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpClauseList &opClauseList, - const llvm::omp::Directive &directive, - Fortran::lower::pft::Evaluation *eval = nullptr) { +static void createTargetOp(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpClauseList &opClauseList, + const llvm::omp::Directive &directive, + Fortran::lower::pft::Evaluation *eval = nullptr) { Fortran::lower::StatementContext stmtCtx; fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Value ifClauseOperand, deviceOperand; + mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand; mlir::UnitAttr nowaitAttr; llvm::SmallVector useDevicePtrOperand, useDeviceAddrOperand, mapOperands; @@ -777,6 +776,11 @@ createTargetDataOp(Fortran::lower::AbstractConverter &converter, } else if (std::get_if( &clause.u)) { TODO(currentLocation, "OMPD_target Use Device Addr"); + } else if (const auto &threadLmtClause = + std::get_if( + &clause.u)) { + threadLmtOperand = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx)); } else if (std::get_if(&clause.u)) { nowaitAttr = firOpBuilder.getUnitAttr(); } else if (const auto &mapClause = @@ -793,7 +797,12 @@ createTargetDataOp(Fortran::lower::AbstractConverter &converter, ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr); mlir::Location currentLocation = converter.getCurrentLocation(); - if (directive == llvm::omp::Directive::OMPD_target_data) { + if (directive == llvm::omp::Directive::OMPD_target) { + auto targetOp = firOpBuilder.create( + currentLocation, ifClauseOperand, deviceOperand, threadLmtOperand, + nowaitAttr, mapOperands, mapTypesArrayAttr); + createBodyOfOp(targetOp, converter, currentLocation, *eval, &opClauseList); + } else if (directive == llvm::omp::Directive::OMPD_target_data) { auto dataOp = firOpBuilder.create( currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand, useDeviceAddrOperand, mapOperands, mapTypesArrayAttr); @@ -837,7 +846,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_target_data: case llvm::omp::Directive::OMPD_target_enter_data: case llvm::omp::Directive::OMPD_target_exit_data: - createTargetDataOp(converter, opClauseList, directive.v); + createTargetOp(converter, opClauseList, directive.v); break; case llvm::omp::Directive::OMPD_target_update: TODO(converter.getCurrentLocation(), "OMPD_target_update"); @@ -1053,6 +1062,10 @@ genOMP(Fortran::lower::AbstractConverter &converter, // Map clause is exclusive to Target Data directives. It is handled // as part of the DataOp creation. continue; + } else if (std::get_if( + &clause.u)) { + // Handled as part of TargetOp creation. + continue; } else if (const auto &finalClause = std::get_if(&clause.u)) { mlir::Value finalVal = fir::getBase(converter.genExprValue( @@ -1120,8 +1133,10 @@ genOMP(Fortran::lower::AbstractConverter &converter, /*task_reductions=*/nullptr, allocateOperands, allocatorOperands); createBodyOfOp(taskGroupOp, converter, currentLocation, eval, &opClauseList); + } else if (blockDirective.v == llvm::omp::OMPD_target) { + createTargetOp(converter, opClauseList, blockDirective.v, &eval); } else if (blockDirective.v == llvm::omp::OMPD_target_data) { - createTargetDataOp(converter, opClauseList, blockDirective.v, &eval); + createTargetOp(converter, opClauseList, blockDirective.v, &eval); } else { TODO(converter.getCurrentLocation(), "Unhandled block directive"); } diff --git a/flang/test/Lower/OpenMP/target_data.f90 b/flang/test/Lower/OpenMP/target.f90 similarity index 78% rename from flang/test/Lower/OpenMP/target_data.f90 rename to flang/test/Lower/OpenMP/target.f90 index 77f2aae71b87a..0e8574a821940 100644 --- a/flang/test/Lower/OpenMP/target_data.f90 +++ b/flang/test/Lower/OpenMP/target.f90 @@ -124,3 +124,41 @@ subroutine omp_target_data !$omp end target data !CHECK: } end subroutine omp_target_data + +!=============================================================================== +! Target with region +!=============================================================================== + +!CHECK-LABEL: func.func @_QPomp_target() { +subroutine omp_target + !CHECK: %[[VAL_0:.*]] = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_targetEa"} + integer :: a(1024) + !CHECK: omp.target map((tofrom -> %[[VAL_0]] : !fir.ref>)) { + !$omp target map(tofrom: a) + !CHECK: %[[VAL_1:.*]] = arith.constant 10 : i32 + !CHECK: %[[VAL_2:.*]] = arith.constant 1 : i64 + !CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64 + !CHECK: %[[VAL_4:.*]] = arith.subi %[[VAL_2]], %[[VAL_3]] : i64 + !CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref>, i64) -> !fir.ref + !CHECK: fir.store %[[VAL_1]] to %[[VAL_5]] : !fir.ref + a(1) = 10 + !CHECK: omp.terminator + !$omp end target + !CHECK: } +end subroutine omp_target + +!=============================================================================== +! Target `thread_limit` clause +!=============================================================================== + +!CHECK-LABEL: func.func @_QPomp_target_thread_limit() { +subroutine omp_target_thread_limit + integer :: a + !CHECK: %[[VAL_1:.*]] = arith.constant 64 : i32 + !CHECK: omp.target thread_limit(%[[VAL_1]] : i32) map((tofrom -> %[[VAL_0]] : !fir.ref)) { + !$omp target map(tofrom: a) thread_limit(64) + a = 10 + !CHECK: omp.terminator + !$omp end target + !CHECK: } +end subroutine omp_target_thread_limit