diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h index f4903f607a2da..90216d3aafcd5 100644 --- a/flang/lib/Lower/DirectivesCommon.h +++ b/flang/lib/Lower/DirectivesCommon.h @@ -198,7 +198,7 @@ static inline void genOmpAccAtomicUpdateStatement( fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::Location currentLocation = converter.getCurrentLocation(); - // Create the omp.atomic.update or acc.atmoic.update operation + // Create the omp.atomic.update or acc.atomic.update operation // // func.func @_QPsb() { // %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"} @@ -206,40 +206,81 @@ static inline void genOmpAccAtomicUpdateStatement( // %2 = fir.load %1 : !fir.ref // omp.atomic.update %0 : !fir.ref { // ^bb0(%arg0: i32): - // %3 = fir.load %1 : !fir.ref - // %4 = arith.addi %arg0, %3 : i32 + // %3 = arith.addi %arg0, %2 : i32 // omp.yield(%3 : i32) // } // return // } - Fortran::lower::ExprToValueMap exprValueOverrides; + auto getArgExpression = + [](std::list::const_iterator it) { + const auto &arg{std::get((*it).t)}; + const auto *parserExpr{ + std::get_if>(&arg.u)}; + return parserExpr; + }; + // Lower any non atomic sub-expression before the atomic operation, and // map its lowered value to the semantic representation. - const Fortran::lower::SomeExpr *nonAtomicSubExpr{nullptr}; - std::visit( - [&](const auto &op) -> void { - using T = std::decay_t; - if constexpr (std::is_base_of::value) { - const auto &exprLeft{std::get<0>(op.t)}; - const auto &exprRight{std::get<1>(op.t)}; - if (exprLeft.value().source == assignmentStmtVariable.GetSource()) - nonAtomicSubExpr = Fortran::semantics::GetExpr(exprRight); - else - nonAtomicSubExpr = Fortran::semantics::GetExpr(exprLeft); - } + Fortran::lower::ExprToValueMap exprValueOverrides; + // Max and min intrinsics can have a list of Args. Hence we need a list + // of nonAtomicSubExprs to hoist. Currently, only the load is hoisted. + llvm::SmallVector nonAtomicSubExprs; + Fortran::common::visit( + Fortran::common::visitors{ + [&](const common::Indirection &funcRef) + -> void { + const auto &args{std::get>( + funcRef.value().v.t)}; + std::list::const_iterator beginIt = + args.begin(); + std::list::const_iterator endIt = args.end(); + const auto *exprFirst{getArgExpression(beginIt)}; + if (exprFirst && exprFirst->value().source == + assignmentStmtVariable.GetSource()) { + // Add everything except the first + beginIt++; + } else { + // Add everything except the last + endIt--; + } + std::list::const_iterator it; + for (it = beginIt; it != endIt; it++) { + const common::Indirection *expr = + getArgExpression(it); + if (expr) + nonAtomicSubExprs.push_back(Fortran::semantics::GetExpr(*expr)); + } + }, + [&](const auto &op) -> void { + using T = std::decay_t; + if constexpr (std::is_base_of< + Fortran::parser::Expr::IntrinsicBinary, + T>::value) { + const auto &exprLeft{std::get<0>(op.t)}; + const auto &exprRight{std::get<1>(op.t)}; + if (exprLeft.value().source == assignmentStmtVariable.GetSource()) + nonAtomicSubExprs.push_back( + Fortran::semantics::GetExpr(exprRight)); + else + nonAtomicSubExprs.push_back( + Fortran::semantics::GetExpr(exprLeft)); + } + }, }, assignmentStmtExpr.u); StatementContext nonAtomicStmtCtx; - if (nonAtomicSubExpr) { + if (!nonAtomicSubExprs.empty()) { // Generate non atomic part before all the atomic operations. auto insertionPoint = firOpBuilder.saveInsertionPoint(); if (atomicCaptureOp) firOpBuilder.setInsertionPoint(atomicCaptureOp); - mlir::Value nonAtomicVal = fir::getBase(converter.genExprValue( - currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx)); - exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal); + mlir::Value nonAtomicVal; + for (auto *nonAtomicSubExpr : nonAtomicSubExprs) { + nonAtomicVal = fir::getBase(converter.genExprValue( + currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx)); + exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal); + } if (atomicCaptureOp) firOpBuilder.restoreInsertionPoint(insertionPoint); } diff --git a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 index b2a993ddd8251..7b51a9cceb0ee 100644 --- a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 +++ b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 @@ -2,22 +2,51 @@ ! RUN: bbc -hlfir -fopenacc -emit-hlfir %s -o - | FileCheck %s ! RUN: %flang_fc1 -flang-experimental-hlfir -emit-hlfir -fopenacc %s -o - | FileCheck %s -subroutine sb - integer :: x, y - - !$acc atomic update - x = x + y -end subroutine - !CHECK-LABEL: @_QPsb +subroutine sb +!CHECK: %[[W_REF:.*]] = fir.alloca i32 {bindc_name = "w", uniq_name = "_QFsbEw"} +!CHECK: %[[W_DECL:.*]]:2 = hlfir.declare %[[W_REF]] {uniq_name = "_QFsbEw"} : (!fir.ref) -> (!fir.ref, !fir.ref) !CHECK: %[[X_REF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsbEx"} !CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) !CHECK: %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"} !CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[Z_REF:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFsbEz"} +!CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z_REF]] {uniq_name = "_QFsbEz"} : (!fir.ref) -> (!fir.ref, !fir.ref) + integer :: w, x, y, z + !CHECK: %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref !CHECK: acc.atomic.update %[[X_DECL]]#1 : !fir.ref { !CHECK: ^bb0(%[[ARG_X:.*]]: i32): !CHECK: %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32 !CHECK: acc.yield %[[X_UPDATE_VAL]] : i32 !CHECK: } + !$acc atomic update + x = x + y + +!CHECK: %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref +!CHECK: acc.atomic.update %[[X_DECL]]#1 : !fir.ref { +!CHECK: ^bb0(%[[ARG_X:.*]]: i32): +!CHECK: %[[X_UPDATE_VAL:.*]] = arith.ori %[[ARG_X]], %[[Y_VAL]] : i32 +!CHECK: acc.yield %[[X_UPDATE_VAL]] : i32 +!CHECK: } + !$acc atomic update + x = ior(x,y) + +!CHECK: %[[W_VAL:.*]] = fir.load %[[W_DECL]]#0 : !fir.ref +!CHECK: %[[X_VAL:.*]] = fir.load %[[X_DECL]]#0 : !fir.ref +!CHECK: %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref +!CHECK: acc.atomic.update %[[Z_DECL]]#1 : !fir.ref { +!CHECK: ^bb0(%[[ARG_Z:.*]]: i32): +!CHECK: %[[WX_CMP:.*]] = arith.cmpi slt, %[[W_VAL]], %[[X_VAL]] : i32 +!CHECK: %[[WX_MIN:.*]] = arith.select %[[WX_CMP]], %[[W_VAL]], %[[X_VAL]] : i32 +!CHECK: %[[WXY_CMP:.*]] = arith.cmpi slt, %[[WX_MIN]], %[[Y_VAL]] : i32 +!CHECK: %[[WXY_MIN:.*]] = arith.select %[[WXY_CMP]], %[[WX_MIN]], %[[Y_VAL]] : i32 +!CHECK: %[[WXYZ_CMP:.*]] = arith.cmpi slt, %[[WXY_MIN]], %[[ARG_Z]] : i32 +!CHECK: %[[WXYZ_MIN:.*]] = arith.select %[[WXYZ_CMP]], %[[WXY_MIN]], %[[ARG_Z]] : i32 +!CHECK: acc.yield %[[WXYZ_MIN]] : i32 +!CHECK: } + !$acc atomic update + z = min(w,x,y,z) + !CHECK: return +end subroutine diff --git a/flang/test/Lower/OpenMP/FIR/atomic-update.f90 b/flang/test/Lower/OpenMP/FIR/atomic-update.f90 index f4ebeef48cac4..bd3d4ace440ee 100644 --- a/flang/test/Lower/OpenMP/FIR/atomic-update.f90 +++ b/flang/test/Lower/OpenMP/FIR/atomic-update.f90 @@ -65,10 +65,10 @@ program OmpAtomicUpdate !CHECK: %[[RESULT:.*]] = arith.subi %[[ARG]], {{.*}} : i32 !CHECK: omp.yield(%[[RESULT]] : i32) !CHECK: } +!CHECK: %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref +!CHECK: %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref !CHECK: omp.atomic.update memory_order(relaxed) %[[Y]] : !fir.ref { !CHECK: ^bb0(%[[ARG:.*]]: i32): -!CHECK: %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref -!CHECK: %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref !CHECK: %{{.*}} = arith.cmpi sgt, %[[ARG]], %[[LOADED_X]] : i32 !CHECK: %{{.*}} = arith.select %{{.*}}, %[[ARG]], %[[LOADED_X]] : i32 !CHECK: %{{.*}} = arith.cmpi sgt, %{{.*}}, %[[LOADED_Z]] : i32 diff --git a/flang/test/Lower/OpenMP/atomic-update.f90 b/flang/test/Lower/OpenMP/atomic-update.f90 index e6319f70f3736..10da97c68c24a 100644 --- a/flang/test/Lower/OpenMP/atomic-update.f90 +++ b/flang/test/Lower/OpenMP/atomic-update.f90 @@ -25,13 +25,15 @@ program OmpAtomicUpdate !CHECK: %[[VAL_K_SHAPED:.*]] = fir.shape %[[VAL_c5]] : (index) -> !fir.shape<1> !CHECK: %[[VAL_K_DECLARE:.*]]:2 = hlfir.declare %[[VAL_K_ALLOCA]](%[[VAL_K_SHAPED]]) {{.*}} +!CHECK: %[[VAL_W_ALLOCA:.*]] = fir.alloca i32 {bindc_name = "w", uniq_name = "_QFEw"} +!CHECK: %[[VAL_W_DECLARE:.*]]:2 = hlfir.declare %[[VAL_W_ALLOCA]] {uniq_name = "_QFEw"} : (!fir.ref) -> (!fir.ref, !fir.ref) !CHECK: %[[VAL_X_ALLOCA:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"} !CHECK: %[[VAL_X_DECLARE:.*]]:2 = hlfir.declare %[[VAL_X_ALLOCA]] {uniq_name = "_QFEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) !CHECK: %[[VAL_Y_ALLOCA:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"} !CHECK: %[[VAL_Y_DECLARE:.*]]:2 = hlfir.declare %[[VAL_Y_ALLOCA]] {uniq_name = "_QFEy"} : (!fir.ref) -> (!fir.ref, !fir.ref) !CHECK: %[[VAL_Z_ALLOCA:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFEz"} !CHECK: %[[VAL_Z_DECLARE:.*]]:2 = hlfir.declare %[[VAL_Z_ALLOCA]] {uniq_name = "_QFEz"} : (!fir.ref) -> (!fir.ref, !fir.ref) - integer :: x, y, z + integer :: w, x, y, z integer, pointer :: a, b integer, target :: c, d integer(1) :: i1 @@ -95,10 +97,10 @@ program OmpAtomicUpdate !$omp atomic relaxed update hint(omp_sync_hint_uncontended) x = x - 1 -!CHECK: omp.atomic.update memory_order(relaxed) %[[VAL_Y_DECLARE]]#1 : !fir.ref { -!CHECK: ^bb0(%[[ARG:.*]]: i32): !CHECK: %[[VAL_C_LOADED:.*]] = fir.load %[[VAL_C_DECLARE]]#0 : !fir.ref !CHECK: %[[VAL_D_LOADED:.*]] = fir.load %[[VAL_D_DECLARE]]#0 : !fir.ref +!CHECK: omp.atomic.update memory_order(relaxed) %[[VAL_Y_DECLARE]]#1 : !fir.ref { +!CHECK: ^bb0(%[[ARG:.*]]: i32): !CHECK: {{.*}} = arith.cmpi sgt, %[[ARG]], {{.*}} : i32 !CHECK: {{.*}} = arith.select {{.*}}, %[[ARG]], {{.*}} : i32 !CHECK: {{.*}} = arith.cmpi sgt, {{.*}} @@ -146,4 +148,39 @@ program OmpAtomicUpdate !$omp atomic i1 = i1 + 1 !$omp end atomic + +!CHECK: %[[VAL_X_LOADED:.*]] = fir.load %[[VAL_X_DECLARE]]#0 : !fir.ref +!CHECK: omp.atomic.update %[[VAL_Y_DECLARE]]#1 : !fir.ref { +!CHECK: ^bb0(%[[ARG_Y:.*]]: i32): +!CHECK: %[[Y_UPDATE_VAL:.*]] = arith.andi %[[VAL_X_LOADED]], %[[ARG_Y]] : i32 +!CHECK: omp.yield(%[[Y_UPDATE_VAL]] : i32) +!CHECK: } + !$omp atomic update + y = iand(x,y) + +!CHECK: %[[VAL_X_LOADED:.*]] = fir.load %[[VAL_X_DECLARE]]#0 : !fir.ref +!CHECK: omp.atomic.update %[[VAL_Y_DECLARE]]#1 : !fir.ref { +!CHECK: ^bb0(%[[ARG_Y:.*]]: i32): +!CHECK: %[[Y_UPDATE_VAL:.*]] = arith.xori %[[VAL_X_LOADED]], %[[ARG_Y]] : i32 +!CHECK: omp.yield(%[[Y_UPDATE_VAL]] : i32) +!CHECK: } + !$omp atomic update + y = ieor(x,y) + +!CHECK: %[[VAL_X_LOADED:.*]] = fir.load %[[VAL_X_DECLARE]]#0 : !fir.ref +!CHECK: %[[VAL_Y_LOADED:.*]] = fir.load %[[VAL_Y_DECLARE]]#0 : !fir.ref +!CHECK: %[[VAL_Z_LOADED:.*]] = fir.load %[[VAL_Z_DECLARE]]#0 : !fir.ref +!CHECK: omp.atomic.update %[[VAL_W_DECLARE]]#1 : !fir.ref { +!CHECK: ^bb0(%[[ARG_W:.*]]: i32): +!CHECK: %[[WX_CMP:.*]] = arith.cmpi sgt, %[[ARG_W]], %[[VAL_X_LOADED]] : i32 +!CHECK: %[[WX_MIN:.*]] = arith.select %[[WX_CMP]], %[[ARG_W]], %[[VAL_X_LOADED]] : i32 +!CHECK: %[[WXY_CMP:.*]] = arith.cmpi sgt, %[[WX_MIN]], %[[VAL_Y_LOADED]] : i32 +!CHECK: %[[WXY_MIN:.*]] = arith.select %[[WXY_CMP]], %[[WX_MIN]], %[[VAL_Y_LOADED]] : i32 +!CHECK: %[[WXYZ_CMP:.*]] = arith.cmpi sgt, %[[WXY_MIN]], %[[VAL_Z_LOADED]] : i32 +!CHECK: %[[WXYZ_MIN:.*]] = arith.select %[[WXYZ_CMP]], %[[WXY_MIN]], %[[VAL_Z_LOADED]] : i32 +!CHECK: omp.yield(%[[WXYZ_MIN]] : i32) +!CHECK: } + !$omp atomic update + w = max(w,x,y,z) + end program OmpAtomicUpdate