Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flang][OpenMP][OpenACC] Hoist nonAtomic Expr in atomic intrinsics #72131

Merged
merged 1 commit into from
Nov 16, 2023

Conversation

kiranchandramohan
Copy link
Contributor

Hoist non-atomic expressions in atomic intrinsics. Use a list to collect the non-atomic expressions since the max and min intrinsics can have more than two operands.

Hoisting makes the lowering to LLVMIR of iand,ior,ieor intrinsics trivial. For max and min this still results in multiple instructions in the atomic region but the loads are removed, which should help improve performance.

Hoist non-atomic expressions in atomic intrinsics. Use a list to
collect the non-atomic expressions since the max and min intrinsics
can have more than two operands.

Hoisting makes the lowering to LLVMIR of iand,ior,ieor intrinsics
trivial. For max and min this still results in multiple instructions
in the atomic region but the loads are removed, which should help
improve performance.
@llvmbot
Copy link
Collaborator

llvmbot commented Nov 13, 2023

@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-openacc

Author: Kiran Chandramohan (kiranchandramohan)

Changes

Hoist non-atomic expressions in atomic intrinsics. Use a list to collect the non-atomic expressions since the max and min intrinsics can have more than two operands.

Hoisting makes the lowering to LLVMIR of iand,ior,ieor intrinsics trivial. For max and min this still results in multiple instructions in the atomic region but the loads are removed, which should help improve performance.


Full diff: https://github.com/llvm/llvm-project/pull/72131.diff

4 Files Affected:

  • (modified) flang/lib/Lower/DirectivesCommon.h (+62-21)
  • (modified) flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 (+36-7)
  • (modified) flang/test/Lower/OpenMP/FIR/atomic-update.f90 (+2-2)
  • (modified) flang/test/Lower/OpenMP/atomic-update.f90 (+40-3)
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index f4903f607a2da8c..90216d3aafcd5f0 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -198,7 +198,7 @@ static inline void genOmpAccAtomicUpdateStatement(
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
 
-  //  Create the omp.atomic.update or acc.atmoic.update operation
+  //  Create the omp.atomic.update or acc.atomic.update operation
   //
   //  func.func @_QPsb() {
   //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
@@ -206,40 +206,81 @@ static inline void genOmpAccAtomicUpdateStatement(
   //    %2 = fir.load %1 : !fir.ref<i32>
   //    omp.atomic.update   %0 : !fir.ref<i32> {
   //    ^bb0(%arg0: i32):
-  //      %3 = fir.load %1 : !fir.ref<i32>
-  //      %4 = arith.addi %arg0, %3 : i32
+  //      %3 = arith.addi %arg0, %2 : i32
   //      omp.yield(%3 : i32)
   //    }
   //    return
   //  }
 
-  Fortran::lower::ExprToValueMap exprValueOverrides;
+  auto getArgExpression =
+      [](std::list<parser::ActualArgSpec>::const_iterator it) {
+        const auto &arg{std::get<parser::ActualArg>((*it).t)};
+        const auto *parserExpr{
+            std::get_if<common::Indirection<parser::Expr>>(&arg.u)};
+        return parserExpr;
+      };
+
   // Lower any non atomic sub-expression before the atomic operation, and
   // map its lowered value to the semantic representation.
-  const Fortran::lower::SomeExpr *nonAtomicSubExpr{nullptr};
-  std::visit(
-      [&](const auto &op) -> void {
-        using T = std::decay_t<decltype(op)>;
-        if constexpr (std::is_base_of<Fortran::parser::Expr::IntrinsicBinary,
-                                      T>::value) {
-          const auto &exprLeft{std::get<0>(op.t)};
-          const auto &exprRight{std::get<1>(op.t)};
-          if (exprLeft.value().source == assignmentStmtVariable.GetSource())
-            nonAtomicSubExpr = Fortran::semantics::GetExpr(exprRight);
-          else
-            nonAtomicSubExpr = Fortran::semantics::GetExpr(exprLeft);
-        }
+  Fortran::lower::ExprToValueMap exprValueOverrides;
+  // Max and min intrinsics can have a list of Args. Hence we need a list
+  // of nonAtomicSubExprs to hoist. Currently, only the load is hoisted.
+  llvm::SmallVector<const Fortran::lower::SomeExpr *> nonAtomicSubExprs;
+  Fortran::common::visit(
+      Fortran::common::visitors{
+          [&](const common::Indirection<parser::FunctionReference> &funcRef)
+              -> void {
+            const auto &args{std::get<std::list<parser::ActualArgSpec>>(
+                funcRef.value().v.t)};
+            std::list<parser::ActualArgSpec>::const_iterator beginIt =
+                args.begin();
+            std::list<parser::ActualArgSpec>::const_iterator endIt = args.end();
+            const auto *exprFirst{getArgExpression(beginIt)};
+            if (exprFirst && exprFirst->value().source ==
+                                 assignmentStmtVariable.GetSource()) {
+              // Add everything except the first
+              beginIt++;
+            } else {
+              // Add everything except the last
+              endIt--;
+            }
+            std::list<parser::ActualArgSpec>::const_iterator it;
+            for (it = beginIt; it != endIt; it++) {
+              const common::Indirection<parser::Expr> *expr =
+                  getArgExpression(it);
+              if (expr)
+                nonAtomicSubExprs.push_back(Fortran::semantics::GetExpr(*expr));
+            }
+          },
+          [&](const auto &op) -> void {
+            using T = std::decay_t<decltype(op)>;
+            if constexpr (std::is_base_of<
+                              Fortran::parser::Expr::IntrinsicBinary,
+                              T>::value) {
+              const auto &exprLeft{std::get<0>(op.t)};
+              const auto &exprRight{std::get<1>(op.t)};
+              if (exprLeft.value().source == assignmentStmtVariable.GetSource())
+                nonAtomicSubExprs.push_back(
+                    Fortran::semantics::GetExpr(exprRight));
+              else
+                nonAtomicSubExprs.push_back(
+                    Fortran::semantics::GetExpr(exprLeft));
+            }
+          },
       },
       assignmentStmtExpr.u);
   StatementContext nonAtomicStmtCtx;
-  if (nonAtomicSubExpr) {
+  if (!nonAtomicSubExprs.empty()) {
     // Generate non atomic part before all the atomic operations.
     auto insertionPoint = firOpBuilder.saveInsertionPoint();
     if (atomicCaptureOp)
       firOpBuilder.setInsertionPoint(atomicCaptureOp);
-    mlir::Value nonAtomicVal = fir::getBase(converter.genExprValue(
-        currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
-    exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
+    mlir::Value nonAtomicVal;
+    for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
+      nonAtomicVal = fir::getBase(converter.genExprValue(
+          currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+      exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
+    }
     if (atomicCaptureOp)
       firOpBuilder.restoreInsertionPoint(insertionPoint);
   }
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
index b2a993ddd825105..7b51a9cceb0eecb 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
@@ -2,22 +2,51 @@
 ! RUN: bbc -hlfir -fopenacc -emit-hlfir %s -o - | FileCheck %s
 ! RUN: %flang_fc1 -flang-experimental-hlfir -emit-hlfir -fopenacc %s -o - | FileCheck %s
 
-subroutine sb
-  integer :: x, y
-
-  !$acc atomic update
-    x = x + y
-end subroutine
-
 !CHECK-LABEL: @_QPsb
+subroutine sb
+!CHECK:   %[[W_REF:.*]] = fir.alloca i32 {bindc_name = "w", uniq_name = "_QFsbEw"}
+!CHECK:   %[[W_DECL:.*]]:2 = hlfir.declare %[[W_REF]] {uniq_name = "_QFsbEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:   %[[X_REF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsbEx"}
 !CHECK:   %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:   %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
 !CHECK:   %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:   %[[Z_REF:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFsbEz"}
+!CHECK:   %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z_REF]] {uniq_name = "_QFsbEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  integer :: w, x, y, z
+
 !CHECK:   %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
 !CHECK:   acc.atomic.update   %[[X_DECL]]#1 : !fir.ref<i32> {
 !CHECK:   ^bb0(%[[ARG_X:.*]]: i32):
 !CHECK:     %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
 !CHECK:     acc.yield %[[X_UPDATE_VAL]] : i32
 !CHECK:   }
+  !$acc atomic update
+    x = x + y
+
+!CHECK:   %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
+!CHECK:   acc.atomic.update %[[X_DECL]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG_X:.*]]: i32):
+!CHECK:     %[[X_UPDATE_VAL:.*]] = arith.ori %[[ARG_X]], %[[Y_VAL]] : i32
+!CHECK:     acc.yield %[[X_UPDATE_VAL]] : i32
+!CHECK:   }
+  !$acc atomic update
+    x = ior(x,y)
+
+!CHECK:   %[[W_VAL:.*]] = fir.load %[[W_DECL]]#0 : !fir.ref<i32>
+!CHECK:   %[[X_VAL:.*]] = fir.load %[[X_DECL]]#0 : !fir.ref<i32>
+!CHECK:   %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
+!CHECK:   acc.atomic.update %[[Z_DECL]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG_Z:.*]]: i32):
+!CHECK:     %[[WX_CMP:.*]] = arith.cmpi slt, %[[W_VAL]], %[[X_VAL]] : i32
+!CHECK:     %[[WX_MIN:.*]] = arith.select %[[WX_CMP]], %[[W_VAL]], %[[X_VAL]] : i32
+!CHECK:     %[[WXY_CMP:.*]] = arith.cmpi slt, %[[WX_MIN]], %[[Y_VAL]] : i32
+!CHECK:     %[[WXY_MIN:.*]] = arith.select %[[WXY_CMP]], %[[WX_MIN]], %[[Y_VAL]] : i32
+!CHECK:     %[[WXYZ_CMP:.*]] = arith.cmpi slt, %[[WXY_MIN]], %[[ARG_Z]] : i32
+!CHECK:     %[[WXYZ_MIN:.*]] = arith.select %[[WXYZ_CMP]], %[[WXY_MIN]], %[[ARG_Z]] : i32
+!CHECK:     acc.yield %[[WXYZ_MIN]] : i32
+!CHECK:   }
+  !$acc atomic update
+    z = min(w,x,y,z)
+
 !CHECK:   return
+end subroutine
diff --git a/flang/test/Lower/OpenMP/FIR/atomic-update.f90 b/flang/test/Lower/OpenMP/FIR/atomic-update.f90
index f4ebeef48cac4a4..bd3d4ace440ee80 100644
--- a/flang/test/Lower/OpenMP/FIR/atomic-update.f90
+++ b/flang/test/Lower/OpenMP/FIR/atomic-update.f90
@@ -65,10 +65,10 @@ program OmpAtomicUpdate
 !CHECK:    %[[RESULT:.*]] = arith.subi %[[ARG]], {{.*}} : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
+!CHECK:  %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
+!CHECK:  %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref<i32>
 !CHECK:  omp.atomic.update   memory_order(relaxed) %[[Y]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
-!CHECK:    %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref<i32>
 !CHECK:    %{{.*}} = arith.cmpi sgt, %[[ARG]], %[[LOADED_X]] : i32
 !CHECK:    %{{.*}} = arith.select %{{.*}}, %[[ARG]], %[[LOADED_X]] : i32
 !CHECK:    %{{.*}} = arith.cmpi sgt, %{{.*}}, %[[LOADED_Z]] : i32
diff --git a/flang/test/Lower/OpenMP/atomic-update.f90 b/flang/test/Lower/OpenMP/atomic-update.f90
index e6319f70f373657..10da97c68c24a17 100644
--- a/flang/test/Lower/OpenMP/atomic-update.f90
+++ b/flang/test/Lower/OpenMP/atomic-update.f90
@@ -25,13 +25,15 @@ program OmpAtomicUpdate
 !CHECK: %[[VAL_K_SHAPED:.*]] = fir.shape %[[VAL_c5]] : (index) -> !fir.shape<1>
 !CHECK: %[[VAL_K_DECLARE:.*]]:2 = hlfir.declare %[[VAL_K_ALLOCA]](%[[VAL_K_SHAPED]]) {{.*}}
 
+!CHECK: %[[VAL_W_ALLOCA:.*]] = fir.alloca i32 {bindc_name = "w", uniq_name = "_QFEw"}
+!CHECK: %[[VAL_W_DECLARE:.*]]:2 = hlfir.declare %[[VAL_W_ALLOCA]] {uniq_name = "_QFEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: %[[VAL_X_ALLOCA:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
 !CHECK: %[[VAL_X_DECLARE:.*]]:2 = hlfir.declare %[[VAL_X_ALLOCA]] {uniq_name = "_QFEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: %[[VAL_Y_ALLOCA:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
 !CHECK: %[[VAL_Y_DECLARE:.*]]:2 = hlfir.declare %[[VAL_Y_ALLOCA]] {uniq_name = "_QFEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: %[[VAL_Z_ALLOCA:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFEz"}
 !CHECK: %[[VAL_Z_DECLARE:.*]]:2 = hlfir.declare %[[VAL_Z_ALLOCA]] {uniq_name = "_QFEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-    integer :: x, y, z
+    integer :: w, x, y, z
     integer, pointer :: a, b
     integer, target :: c, d
     integer(1) :: i1
@@ -95,10 +97,10 @@ program OmpAtomicUpdate
    !$omp atomic relaxed update hint(omp_sync_hint_uncontended)
         x = x - 1
 
-!CHECK: omp.atomic.update memory_order(relaxed) %[[VAL_Y_DECLARE]]#1 : !fir.ref<i32> {
-!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK:  %[[VAL_C_LOADED:.*]] = fir.load %[[VAL_C_DECLARE]]#0 : !fir.ref<i32>
 !CHECK:  %[[VAL_D_LOADED:.*]] = fir.load %[[VAL_D_DECLARE]]#0 : !fir.ref<i32>
+!CHECK: omp.atomic.update memory_order(relaxed) %[[VAL_Y_DECLARE]]#1 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK:  {{.*}} = arith.cmpi sgt, %[[ARG]], {{.*}} : i32
 !CHECK:  {{.*}} = arith.select {{.*}}, %[[ARG]], {{.*}} : i32
 !CHECK:  {{.*}} = arith.cmpi sgt, {{.*}}
@@ -146,4 +148,39 @@ program OmpAtomicUpdate
    !$omp atomic
       i1 = i1 + 1
     !$omp end atomic
+
+!CHECK:   %[[VAL_X_LOADED:.*]] = fir.load %[[VAL_X_DECLARE]]#0 : !fir.ref<i32>
+!CHECK:   omp.atomic.update %[[VAL_Y_DECLARE]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG_Y:.*]]: i32):
+!CHECK:     %[[Y_UPDATE_VAL:.*]] = arith.andi %[[VAL_X_LOADED]], %[[ARG_Y]] : i32
+!CHECK:     omp.yield(%[[Y_UPDATE_VAL]] : i32)
+!CHECK:   }
+  !$omp atomic update
+    y = iand(x,y)
+
+!CHECK:   %[[VAL_X_LOADED:.*]] = fir.load %[[VAL_X_DECLARE]]#0 : !fir.ref<i32>
+!CHECK:   omp.atomic.update %[[VAL_Y_DECLARE]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG_Y:.*]]: i32):
+!CHECK:     %[[Y_UPDATE_VAL:.*]] = arith.xori %[[VAL_X_LOADED]], %[[ARG_Y]] : i32
+!CHECK:     omp.yield(%[[Y_UPDATE_VAL]] : i32)
+!CHECK:   }
+  !$omp atomic update
+    y = ieor(x,y)
+
+!CHECK:   %[[VAL_X_LOADED:.*]] = fir.load %[[VAL_X_DECLARE]]#0 : !fir.ref<i32>
+!CHECK:   %[[VAL_Y_LOADED:.*]] = fir.load %[[VAL_Y_DECLARE]]#0 : !fir.ref<i32>
+!CHECK:   %[[VAL_Z_LOADED:.*]] = fir.load %[[VAL_Z_DECLARE]]#0 : !fir.ref<i32>
+!CHECK:   omp.atomic.update %[[VAL_W_DECLARE]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG_W:.*]]: i32):
+!CHECK:     %[[WX_CMP:.*]] = arith.cmpi sgt, %[[ARG_W]], %[[VAL_X_LOADED]] : i32
+!CHECK:     %[[WX_MIN:.*]] = arith.select %[[WX_CMP]], %[[ARG_W]], %[[VAL_X_LOADED]] : i32
+!CHECK:     %[[WXY_CMP:.*]] = arith.cmpi sgt, %[[WX_MIN]], %[[VAL_Y_LOADED]] : i32
+!CHECK:     %[[WXY_MIN:.*]] = arith.select %[[WXY_CMP]], %[[WX_MIN]], %[[VAL_Y_LOADED]] : i32
+!CHECK:     %[[WXYZ_CMP:.*]] = arith.cmpi sgt, %[[WXY_MIN]], %[[VAL_Z_LOADED]] : i32
+!CHECK:     %[[WXYZ_MIN:.*]] = arith.select %[[WXYZ_CMP]], %[[WXY_MIN]], %[[VAL_Z_LOADED]] : i32
+!CHECK:     omp.yield(%[[WXYZ_MIN]] : i32)
+!CHECK:   }
+  !$omp atomic update
+    w = max(w,x,y,z)
+
 end program OmpAtomicUpdate

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 13, 2023

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kiran Chandramohan (kiranchandramohan)

Changes

Hoist non-atomic expressions in atomic intrinsics. Use a list to collect the non-atomic expressions since the max and min intrinsics can have more than two operands.

Hoisting makes the lowering to LLVMIR of iand,ior,ieor intrinsics trivial. For max and min this still results in multiple instructions in the atomic region but the loads are removed, which should help improve performance.


Full diff: https://github.com/llvm/llvm-project/pull/72131.diff

4 Files Affected:

  • (modified) flang/lib/Lower/DirectivesCommon.h (+62-21)
  • (modified) flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 (+36-7)
  • (modified) flang/test/Lower/OpenMP/FIR/atomic-update.f90 (+2-2)
  • (modified) flang/test/Lower/OpenMP/atomic-update.f90 (+40-3)
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index f4903f607a2da8c..90216d3aafcd5f0 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -198,7 +198,7 @@ static inline void genOmpAccAtomicUpdateStatement(
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
 
-  //  Create the omp.atomic.update or acc.atmoic.update operation
+  //  Create the omp.atomic.update or acc.atomic.update operation
   //
   //  func.func @_QPsb() {
   //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
@@ -206,40 +206,81 @@ static inline void genOmpAccAtomicUpdateStatement(
   //    %2 = fir.load %1 : !fir.ref<i32>
   //    omp.atomic.update   %0 : !fir.ref<i32> {
   //    ^bb0(%arg0: i32):
-  //      %3 = fir.load %1 : !fir.ref<i32>
-  //      %4 = arith.addi %arg0, %3 : i32
+  //      %3 = arith.addi %arg0, %2 : i32
   //      omp.yield(%3 : i32)
   //    }
   //    return
   //  }
 
-  Fortran::lower::ExprToValueMap exprValueOverrides;
+  auto getArgExpression =
+      [](std::list<parser::ActualArgSpec>::const_iterator it) {
+        const auto &arg{std::get<parser::ActualArg>((*it).t)};
+        const auto *parserExpr{
+            std::get_if<common::Indirection<parser::Expr>>(&arg.u)};
+        return parserExpr;
+      };
+
   // Lower any non atomic sub-expression before the atomic operation, and
   // map its lowered value to the semantic representation.
-  const Fortran::lower::SomeExpr *nonAtomicSubExpr{nullptr};
-  std::visit(
-      [&](const auto &op) -> void {
-        using T = std::decay_t<decltype(op)>;
-        if constexpr (std::is_base_of<Fortran::parser::Expr::IntrinsicBinary,
-                                      T>::value) {
-          const auto &exprLeft{std::get<0>(op.t)};
-          const auto &exprRight{std::get<1>(op.t)};
-          if (exprLeft.value().source == assignmentStmtVariable.GetSource())
-            nonAtomicSubExpr = Fortran::semantics::GetExpr(exprRight);
-          else
-            nonAtomicSubExpr = Fortran::semantics::GetExpr(exprLeft);
-        }
+  Fortran::lower::ExprToValueMap exprValueOverrides;
+  // Max and min intrinsics can have a list of Args. Hence we need a list
+  // of nonAtomicSubExprs to hoist. Currently, only the load is hoisted.
+  llvm::SmallVector<const Fortran::lower::SomeExpr *> nonAtomicSubExprs;
+  Fortran::common::visit(
+      Fortran::common::visitors{
+          [&](const common::Indirection<parser::FunctionReference> &funcRef)
+              -> void {
+            const auto &args{std::get<std::list<parser::ActualArgSpec>>(
+                funcRef.value().v.t)};
+            std::list<parser::ActualArgSpec>::const_iterator beginIt =
+                args.begin();
+            std::list<parser::ActualArgSpec>::const_iterator endIt = args.end();
+            const auto *exprFirst{getArgExpression(beginIt)};
+            if (exprFirst && exprFirst->value().source ==
+                                 assignmentStmtVariable.GetSource()) {
+              // Add everything except the first
+              beginIt++;
+            } else {
+              // Add everything except the last
+              endIt--;
+            }
+            std::list<parser::ActualArgSpec>::const_iterator it;
+            for (it = beginIt; it != endIt; it++) {
+              const common::Indirection<parser::Expr> *expr =
+                  getArgExpression(it);
+              if (expr)
+                nonAtomicSubExprs.push_back(Fortran::semantics::GetExpr(*expr));
+            }
+          },
+          [&](const auto &op) -> void {
+            using T = std::decay_t<decltype(op)>;
+            if constexpr (std::is_base_of<
+                              Fortran::parser::Expr::IntrinsicBinary,
+                              T>::value) {
+              const auto &exprLeft{std::get<0>(op.t)};
+              const auto &exprRight{std::get<1>(op.t)};
+              if (exprLeft.value().source == assignmentStmtVariable.GetSource())
+                nonAtomicSubExprs.push_back(
+                    Fortran::semantics::GetExpr(exprRight));
+              else
+                nonAtomicSubExprs.push_back(
+                    Fortran::semantics::GetExpr(exprLeft));
+            }
+          },
       },
       assignmentStmtExpr.u);
   StatementContext nonAtomicStmtCtx;
-  if (nonAtomicSubExpr) {
+  if (!nonAtomicSubExprs.empty()) {
     // Generate non atomic part before all the atomic operations.
     auto insertionPoint = firOpBuilder.saveInsertionPoint();
     if (atomicCaptureOp)
       firOpBuilder.setInsertionPoint(atomicCaptureOp);
-    mlir::Value nonAtomicVal = fir::getBase(converter.genExprValue(
-        currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
-    exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
+    mlir::Value nonAtomicVal;
+    for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
+      nonAtomicVal = fir::getBase(converter.genExprValue(
+          currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+      exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
+    }
     if (atomicCaptureOp)
       firOpBuilder.restoreInsertionPoint(insertionPoint);
   }
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
index b2a993ddd825105..7b51a9cceb0eecb 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
@@ -2,22 +2,51 @@
 ! RUN: bbc -hlfir -fopenacc -emit-hlfir %s -o - | FileCheck %s
 ! RUN: %flang_fc1 -flang-experimental-hlfir -emit-hlfir -fopenacc %s -o - | FileCheck %s
 
-subroutine sb
-  integer :: x, y
-
-  !$acc atomic update
-    x = x + y
-end subroutine
-
 !CHECK-LABEL: @_QPsb
+subroutine sb
+!CHECK:   %[[W_REF:.*]] = fir.alloca i32 {bindc_name = "w", uniq_name = "_QFsbEw"}
+!CHECK:   %[[W_DECL:.*]]:2 = hlfir.declare %[[W_REF]] {uniq_name = "_QFsbEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:   %[[X_REF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsbEx"}
 !CHECK:   %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:   %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
 !CHECK:   %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:   %[[Z_REF:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFsbEz"}
+!CHECK:   %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z_REF]] {uniq_name = "_QFsbEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  integer :: w, x, y, z
+
 !CHECK:   %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
 !CHECK:   acc.atomic.update   %[[X_DECL]]#1 : !fir.ref<i32> {
 !CHECK:   ^bb0(%[[ARG_X:.*]]: i32):
 !CHECK:     %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
 !CHECK:     acc.yield %[[X_UPDATE_VAL]] : i32
 !CHECK:   }
+  !$acc atomic update
+    x = x + y
+
+!CHECK:   %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
+!CHECK:   acc.atomic.update %[[X_DECL]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG_X:.*]]: i32):
+!CHECK:     %[[X_UPDATE_VAL:.*]] = arith.ori %[[ARG_X]], %[[Y_VAL]] : i32
+!CHECK:     acc.yield %[[X_UPDATE_VAL]] : i32
+!CHECK:   }
+  !$acc atomic update
+    x = ior(x,y)
+
+!CHECK:   %[[W_VAL:.*]] = fir.load %[[W_DECL]]#0 : !fir.ref<i32>
+!CHECK:   %[[X_VAL:.*]] = fir.load %[[X_DECL]]#0 : !fir.ref<i32>
+!CHECK:   %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
+!CHECK:   acc.atomic.update %[[Z_DECL]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG_Z:.*]]: i32):
+!CHECK:     %[[WX_CMP:.*]] = arith.cmpi slt, %[[W_VAL]], %[[X_VAL]] : i32
+!CHECK:     %[[WX_MIN:.*]] = arith.select %[[WX_CMP]], %[[W_VAL]], %[[X_VAL]] : i32
+!CHECK:     %[[WXY_CMP:.*]] = arith.cmpi slt, %[[WX_MIN]], %[[Y_VAL]] : i32
+!CHECK:     %[[WXY_MIN:.*]] = arith.select %[[WXY_CMP]], %[[WX_MIN]], %[[Y_VAL]] : i32
+!CHECK:     %[[WXYZ_CMP:.*]] = arith.cmpi slt, %[[WXY_MIN]], %[[ARG_Z]] : i32
+!CHECK:     %[[WXYZ_MIN:.*]] = arith.select %[[WXYZ_CMP]], %[[WXY_MIN]], %[[ARG_Z]] : i32
+!CHECK:     acc.yield %[[WXYZ_MIN]] : i32
+!CHECK:   }
+  !$acc atomic update
+    z = min(w,x,y,z)
+
 !CHECK:   return
+end subroutine
diff --git a/flang/test/Lower/OpenMP/FIR/atomic-update.f90 b/flang/test/Lower/OpenMP/FIR/atomic-update.f90
index f4ebeef48cac4a4..bd3d4ace440ee80 100644
--- a/flang/test/Lower/OpenMP/FIR/atomic-update.f90
+++ b/flang/test/Lower/OpenMP/FIR/atomic-update.f90
@@ -65,10 +65,10 @@ program OmpAtomicUpdate
 !CHECK:    %[[RESULT:.*]] = arith.subi %[[ARG]], {{.*}} : i32
 !CHECK:    omp.yield(%[[RESULT]] : i32)
 !CHECK:  }
+!CHECK:  %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
+!CHECK:  %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref<i32>
 !CHECK:  omp.atomic.update   memory_order(relaxed) %[[Y]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_X:.*]] = fir.load %[[X]] : !fir.ref<i32>
-!CHECK:    %[[LOADED_Z:.*]] = fir.load %[[Z]] : !fir.ref<i32>
 !CHECK:    %{{.*}} = arith.cmpi sgt, %[[ARG]], %[[LOADED_X]] : i32
 !CHECK:    %{{.*}} = arith.select %{{.*}}, %[[ARG]], %[[LOADED_X]] : i32
 !CHECK:    %{{.*}} = arith.cmpi sgt, %{{.*}}, %[[LOADED_Z]] : i32
diff --git a/flang/test/Lower/OpenMP/atomic-update.f90 b/flang/test/Lower/OpenMP/atomic-update.f90
index e6319f70f373657..10da97c68c24a17 100644
--- a/flang/test/Lower/OpenMP/atomic-update.f90
+++ b/flang/test/Lower/OpenMP/atomic-update.f90
@@ -25,13 +25,15 @@ program OmpAtomicUpdate
 !CHECK: %[[VAL_K_SHAPED:.*]] = fir.shape %[[VAL_c5]] : (index) -> !fir.shape<1>
 !CHECK: %[[VAL_K_DECLARE:.*]]:2 = hlfir.declare %[[VAL_K_ALLOCA]](%[[VAL_K_SHAPED]]) {{.*}}
 
+!CHECK: %[[VAL_W_ALLOCA:.*]] = fir.alloca i32 {bindc_name = "w", uniq_name = "_QFEw"}
+!CHECK: %[[VAL_W_DECLARE:.*]]:2 = hlfir.declare %[[VAL_W_ALLOCA]] {uniq_name = "_QFEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: %[[VAL_X_ALLOCA:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
 !CHECK: %[[VAL_X_DECLARE:.*]]:2 = hlfir.declare %[[VAL_X_ALLOCA]] {uniq_name = "_QFEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: %[[VAL_Y_ALLOCA:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
 !CHECK: %[[VAL_Y_DECLARE:.*]]:2 = hlfir.declare %[[VAL_Y_ALLOCA]] {uniq_name = "_QFEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK: %[[VAL_Z_ALLOCA:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFEz"}
 !CHECK: %[[VAL_Z_DECLARE:.*]]:2 = hlfir.declare %[[VAL_Z_ALLOCA]] {uniq_name = "_QFEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-    integer :: x, y, z
+    integer :: w, x, y, z
     integer, pointer :: a, b
     integer, target :: c, d
     integer(1) :: i1
@@ -95,10 +97,10 @@ program OmpAtomicUpdate
    !$omp atomic relaxed update hint(omp_sync_hint_uncontended)
         x = x - 1
 
-!CHECK: omp.atomic.update memory_order(relaxed) %[[VAL_Y_DECLARE]]#1 : !fir.ref<i32> {
-!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK:  %[[VAL_C_LOADED:.*]] = fir.load %[[VAL_C_DECLARE]]#0 : !fir.ref<i32>
 !CHECK:  %[[VAL_D_LOADED:.*]] = fir.load %[[VAL_D_DECLARE]]#0 : !fir.ref<i32>
+!CHECK: omp.atomic.update memory_order(relaxed) %[[VAL_Y_DECLARE]]#1 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK:  {{.*}} = arith.cmpi sgt, %[[ARG]], {{.*}} : i32
 !CHECK:  {{.*}} = arith.select {{.*}}, %[[ARG]], {{.*}} : i32
 !CHECK:  {{.*}} = arith.cmpi sgt, {{.*}}
@@ -146,4 +148,39 @@ program OmpAtomicUpdate
    !$omp atomic
       i1 = i1 + 1
     !$omp end atomic
+
+!CHECK:   %[[VAL_X_LOADED:.*]] = fir.load %[[VAL_X_DECLARE]]#0 : !fir.ref<i32>
+!CHECK:   omp.atomic.update %[[VAL_Y_DECLARE]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG_Y:.*]]: i32):
+!CHECK:     %[[Y_UPDATE_VAL:.*]] = arith.andi %[[VAL_X_LOADED]], %[[ARG_Y]] : i32
+!CHECK:     omp.yield(%[[Y_UPDATE_VAL]] : i32)
+!CHECK:   }
+  !$omp atomic update
+    y = iand(x,y)
+
+!CHECK:   %[[VAL_X_LOADED:.*]] = fir.load %[[VAL_X_DECLARE]]#0 : !fir.ref<i32>
+!CHECK:   omp.atomic.update %[[VAL_Y_DECLARE]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG_Y:.*]]: i32):
+!CHECK:     %[[Y_UPDATE_VAL:.*]] = arith.xori %[[VAL_X_LOADED]], %[[ARG_Y]] : i32
+!CHECK:     omp.yield(%[[Y_UPDATE_VAL]] : i32)
+!CHECK:   }
+  !$omp atomic update
+    y = ieor(x,y)
+
+!CHECK:   %[[VAL_X_LOADED:.*]] = fir.load %[[VAL_X_DECLARE]]#0 : !fir.ref<i32>
+!CHECK:   %[[VAL_Y_LOADED:.*]] = fir.load %[[VAL_Y_DECLARE]]#0 : !fir.ref<i32>
+!CHECK:   %[[VAL_Z_LOADED:.*]] = fir.load %[[VAL_Z_DECLARE]]#0 : !fir.ref<i32>
+!CHECK:   omp.atomic.update %[[VAL_W_DECLARE]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG_W:.*]]: i32):
+!CHECK:     %[[WX_CMP:.*]] = arith.cmpi sgt, %[[ARG_W]], %[[VAL_X_LOADED]] : i32
+!CHECK:     %[[WX_MIN:.*]] = arith.select %[[WX_CMP]], %[[ARG_W]], %[[VAL_X_LOADED]] : i32
+!CHECK:     %[[WXY_CMP:.*]] = arith.cmpi sgt, %[[WX_MIN]], %[[VAL_Y_LOADED]] : i32
+!CHECK:     %[[WXY_MIN:.*]] = arith.select %[[WXY_CMP]], %[[WX_MIN]], %[[VAL_Y_LOADED]] : i32
+!CHECK:     %[[WXYZ_CMP:.*]] = arith.cmpi sgt, %[[WXY_MIN]], %[[VAL_Z_LOADED]] : i32
+!CHECK:     %[[WXYZ_MIN:.*]] = arith.select %[[WXYZ_CMP]], %[[WXY_MIN]], %[[VAL_Z_LOADED]] : i32
+!CHECK:     omp.yield(%[[WXYZ_MIN]] : i32)
+!CHECK:   }
+  !$omp atomic update
+    w = max(w,x,y,z)
+
 end program OmpAtomicUpdate

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Kiran, looks great

Copy link
Contributor

@NimishMishra NimishMishra left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for me as well.

@kiranchandramohan kiranchandramohan merged commit 25d0f9f into llvm:main Nov 16, 2023
7 checks passed
sr-tream pushed a commit to sr-tream/llvm-project that referenced this pull request Nov 20, 2023
…lvm#72131)

Hoist non-atomic expressions in atomic intrinsics. Use a list to collect
the non-atomic expressions since the max and min intrinsics can have
more than two operands.

Hoisting makes the lowering to LLVMIR of iand,ior,ieor intrinsics
trivial. For max and min this still results in multiple instructions in
the atomic region but the loads are removed, which should help improve
performance.
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
…lvm#72131)

Hoist non-atomic expressions in atomic intrinsics. Use a list to collect
the non-atomic expressions since the max and min intrinsics can have
more than two operands.

Hoisting makes the lowering to LLVMIR of iand,ior,ieor intrinsics
trivial. For max and min this still results in multiple instructions in
the atomic region but the loads are removed, which should help improve
performance.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants