Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] Allow lowering of sub-expressions to be overridden #69944

Merged
merged 2 commits into from Oct 25, 2023

Conversation

jeanPerier
Copy link
Contributor

@jeanPerier jeanPerier commented Oct 23, 2023

OpenACC/OpenMP atomic lowering needs a finer control over expression lowering. This patch allows mapping evaluate::Expr to mlir::Value so that any subsequent expression lowering will use these values when an operand is a mapped Expr.

This is an alternative to #69866 From which I took the test.

The same tests as in #69866 are failing because the "non atomic part" is now out of the atomic.update op, which in some cases is causing verification failures because this is generated in the middle of an omp.atomic.capture. I did not try fixing these failures. My patch is about the lowering infrastructure and how to use it rather than the OpenMP semantics.

Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking at this @jeanPerier. This is a nice alternative!

@NimishMishra
Copy link
Contributor

Thanks for this. In my opinion, this seems cleaner. Pinging @kiranchandramohan for his take on how to move further on this.

@kiranchandramohan
Copy link
Contributor

Thanks @jeanPerier for this extension of the expression lowering infrastructure. I agree that this is much cleaner than the custom lowering. The only question I have is whether this addition of state to expression lowering will make any future parallel lowering of expressions impossible.

Thanks for this. In my opinion, this seems cleaner. Pinging @kiranchandramohan for his take on how to move further on this.

I think you have to work with Jean on how to merge this patch. Whether the current patch can go in as it is with the failing OpenMP & OpenACC tests XFAILed, or only the expression lowering infrastructure improvement goes in and then you can update the OpenMP/OpenACC lowering code using this new infrastructure.

@kiranchandramohan
Copy link
Contributor

which in some cases is causing verification failures because this is generated in the middle of an omp.atomic.capture.

I guess the failures are because the OpenMP capture statements do not expect any statements other than omp.atomic.read, omp.atomic.write or omp.atomic.update. In these cases, the hoisting for the expression can go above the omp.atomic.capture.

Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! Thank you for taking the time to create this infrastructure! :) Definitely helps the atomic implementation!

@jeanPerier
Copy link
Contributor Author

The only question I have is whether this addition of state to expression lowering will make any future parallel lowering of expressions impossible.

Thanks for the review. Yes, this change does not help, but expression lowering is already depending on the converter state since it is providing the symbol map (at least when expression lowering is called via the converter).

We would first need to define the lowering parallelism requirements here. Is to lower expressions of a single statement in parallel, or is it to allow different program units to be lowered in parallel? The former seems hard and of little practicality to me. The latter makes more sense, and then, I think we could modify the bridge a bit in order to have a "converter" instance per thread/program unit lowering, in which case this change is not a blocker (sub-expression override would only impact a single thread/program unit lowering).

jeanPerier and others added 2 commits October 24, 2023 01:26
OpenACC/OpenMP atomic lowering needs a finer control over expression
lowering. This patch allows mapping evaluate::Expr<T> to mlir::Value
so that any subsequent expression lowering will use these values when
an operand is a mapped Expr<T>.

This is an alternative to llvm#69866
From which I took the test.

The same test as in llvm#69866 are
failing because the "non atomic part" is now out of the atomic.update
op, which in some case causing verification failure because this is
generated in the middle of an omp.atomic.capture. I did not try fixing
these failures. My patch is about the lowering infrastructure and how to
use it rather than the OpenMP semantics.

Co-authored-by: Nimish Mishra <neelam.nimish@gmail.com>
@jeanPerier
Copy link
Contributor Author

think you have to work with Jean on how to merge this patch. Whether the current patch can go in as it is with the failing OpenMP & OpenACC tests XFAILed, or only the expression lowering infrastructure improvement goes in and then you can update the OpenMP/OpenACC lowering code using this new infrastructure.

Thanks @kiranchandramohan, with your guidance to generate the non atomic expr evaluation outside of the atomic.capture. Please @razvanlupusoru, @clementval, and @NimishMishra, check that the updated tests match what you expect. Then I can push all of it if it easier.

(I rebased with hope to pass the git fetch error of the build bot)

@jeanPerier jeanPerier marked this pull request as ready for review October 24, 2023 08:39
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp openacc labels Oct 24, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 24, 2023

@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-openacc

Author: None (jeanPerier)

Changes

OpenACC/OpenMP atomic lowering needs a finer control over expression lowering. This patch allows mapping evaluate::Expr<T> to mlir::Value so that any subsequent expression lowering will use these values when an operand is a mapped Expr<T>.

This is an alternative to #69866 From which I took the test.

The same tests as in #69866 are failing because the "non atomic part" is now out of the atomic.update op, which in some cases is causing verification failures because this is generated in the middle of an omp.atomic.capture. I did not try fixing these failures. My patch is about the lowering infrastructure and how to use it rather than the OpenMP semantics.


Patch is 32.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69944.diff

12 Files Affected:

  • (modified) flang/include/flang/Lower/AbstractConverter.h (+10)
  • (modified) flang/lib/Lower/Bridge.cpp (+11)
  • (modified) flang/lib/Lower/ConvertExpr.cpp (+17)
  • (modified) flang/lib/Lower/ConvertExprToHLFIR.cpp (+11)
  • (modified) flang/lib/Lower/DirectivesCommon.h (+55-99)
  • (modified) flang/test/Lower/OpenACC/acc-atomic-capture.f90 (+5-5)
  • (modified) flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 (+2-2)
  • (modified) flang/test/Lower/OpenACC/acc-atomic-update.f90 (+5-5)
  • (modified) flang/test/Lower/OpenMP/FIR/atomic-capture.f90 (+5-5)
  • (modified) flang/test/Lower/OpenMP/FIR/atomic-update.f90 (+11-11)
  • (modified) flang/test/Lower/OpenMP/atomic-update-hlfir.f90 (+2-2)
  • (added) flang/test/Lower/OpenMP/common-atomic-lowering.f90 (+74)
diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index c792e75f1146499..fa67729fe036684 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -60,6 +60,8 @@ using SomeExpr = Fortran::evaluate::Expr<Fortran::evaluate::SomeType>;
 using SymbolRef = Fortran::common::Reference<const Fortran::semantics::Symbol>;
 class StatementContext;
 
+using ExprToValueMap = llvm::DenseMap<const SomeExpr *, mlir::Value>;
+
 //===----------------------------------------------------------------------===//
 // AbstractConverter interface
 //===----------------------------------------------------------------------===//
@@ -90,6 +92,14 @@ class AbstractConverter {
   /// added or replaced at the inner-most level of the local symbol map.
   virtual void bindSymbol(SymbolRef sym, const fir::ExtendedValue &exval) = 0;
 
+  /// Override lowering of expression with pre-lowered values.
+  /// Associate mlir::Value to evaluate::Expr. All subsequent call to
+  /// genExprXXX() will replace any occurrence of an overridden
+  /// expression in the expression tree by the pre-lowered values.
+  virtual void overrideExprValues(const ExprToValueMap *) = 0;
+  void resetExprOverrides() { overrideExprValues(nullptr); }
+  virtual const ExprToValueMap *getExprOverrides() = 0;
+
   /// Get the label set associated with a symbol.
   virtual bool lookupLabelSet(SymbolRef sym, pft::LabelSet &labelSet) = 0;
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index c3afd91d7453caa..2bfab3a250e6916 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -504,6 +504,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     addSymbol(sym, exval, /*forced=*/true);
   }
 
+  void
+  overrideExprValues(const Fortran::lower::ExprToValueMap *map) override final {
+    exprValueOverrides = map;
+  }
+
+  const Fortran::lower::ExprToValueMap *getExprOverrides() override final {
+    return exprValueOverrides;
+  }
+
   bool lookupLabelSet(Fortran::lower::SymbolRef sym,
                       Fortran::lower::pft::LabelSet &labelSet) override final {
     Fortran::lower::pft::FunctionLikeUnit &owningProc =
@@ -4890,6 +4899,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Whether an OpenMP target region or declare target function/subroutine
   /// intended for device offloading has been detected
   bool ompDeviceCodeFound = false;
+
+  const Fortran::lower::ExprToValueMap *exprValueOverrides{nullptr};
 };
 
 } // namespace
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 6d2ac62b61b74c3..1a2b3856c526716 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -2963,8 +2963,21 @@ class ScalarExprLowering {
     return asArray(x);
   }
 
+  template <typename A>
+  mlir::Value getIfOverridenExpr(const Fortran::evaluate::Expr<A> &x) {
+    if (const Fortran::lower::ExprToValueMap *map =
+            converter.getExprOverrides()) {
+      Fortran::lower::SomeExpr someExpr = toEvExpr(x);
+      if (auto match = map->find(&someExpr); match != map->end())
+        return match->second;
+    }
+    return mlir::Value{};
+  }
+
   template <typename A>
   ExtValue gen(const Fortran::evaluate::Expr<A> &x) {
+    if (mlir::Value val = getIfOverridenExpr(x))
+      return val;
     // Whole array symbols or components, and results of transformational
     // functions already have a storage and the scalar expression lowering path
     // is used to not create a new temporary storage.
@@ -2978,6 +2991,8 @@ class ScalarExprLowering {
   }
   template <typename A>
   ExtValue genval(const Fortran::evaluate::Expr<A> &x) {
+    if (mlir::Value val = getIfOverridenExpr(x))
+      return val;
     if (isScalar(x) || Fortran::evaluate::UnwrapWholeSymbolDataRef(x) ||
         inInitializer)
       return std::visit([&](const auto &e) { return genval(e); }, x.u);
@@ -2987,6 +3002,8 @@ class ScalarExprLowering {
   template <int KIND>
   ExtValue genval(const Fortran::evaluate::Expr<Fortran::evaluate::Type<
                       Fortran::common::TypeCategory::Logical, KIND>> &exp) {
+    if (mlir::Value val = getIfOverridenExpr(exp))
+      return val;
     return std::visit([&](const auto &e) { return genval(e); }, exp.u);
   }
 
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index 4cf29c9aecbf577..1da6a5bdd54784e 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -1423,6 +1423,17 @@ class HlfirBuilder {
 
   template <typename T>
   hlfir::EntityWithAttributes gen(const Fortran::evaluate::Expr<T> &expr) {
+    if (const Fortran::lower::ExprToValueMap *map =
+            getConverter().getExprOverrides()) {
+      if constexpr (std::is_same_v<T, Fortran::evaluate::SomeType>) {
+        if (auto match = map->find(&expr); match != map->end())
+          return hlfir::EntityWithAttributes{match->second};
+      } else {
+        Fortran::lower::SomeExpr someExpr = toEvExpr(expr);
+        if (auto match = map->find(&someExpr); match != map->end())
+          return hlfir::EntityWithAttributes{match->second};
+      }
+    }
     return std::visit([&](const auto &x) { return gen(x); }, expr.u);
   }
 
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index ed44598bc925212..91801aad9c074a4 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -200,62 +200,13 @@ static inline void genOmpAccAtomicUpdateStatement(
     mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable,
     const Fortran::parser::Expr &assignmentStmtExpr,
     [[maybe_unused]] const AtomicListT *leftHandClauseList,
-    [[maybe_unused]] const AtomicListT *rightHandClauseList) {
+    [[maybe_unused]] const AtomicListT *rightHandClauseList,
+    mlir::Operation *atomicCaptureOp = nullptr) {
   // Generate `omp.atomic.update` operation for atomic assignment statements
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
 
-  const auto *varDesignator =
-      std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
-          &assignmentStmtVariable.u);
-  assert(varDesignator && "Variable designator for atomic update assignment "
-                          "statement does not exist");
-  const Fortran::parser::Name *name =
-      Fortran::semantics::getDesignatorNameIfDataRef(varDesignator->value());
-  if (!name)
-    TODO(converter.getCurrentLocation(),
-         "Array references as atomic update variable");
-  assert(name && name->symbol &&
-         "No symbol attached to atomic update variable");
-  if (Fortran::semantics::IsAllocatableOrPointer(name->symbol->GetUltimate()))
-    converter.bindSymbol(*name->symbol, lhsAddr);
-
-  //  Lowering is in two steps :
-  //  subroutine sb
-  //    integer :: a, b
-  //    !$omp atomic update
-  //      a = a + b
-  //  end subroutine
-  //
-  //  1. Lower to scf.execute_region_op
-  //
-  //  func.func @_QPsb() {
-  //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
-  //    %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
-  //    %2 = scf.execute_region -> i32 {
-  //      %3 = fir.load %0 : !fir.ref<i32>
-  //      %4 = fir.load %1 : !fir.ref<i32>
-  //      %5 = arith.addi %3, %4 : i32
-  //      scf.yield %5 : i32
-  //    }
-  //    return
-  //  }
-  auto tempOp =
-      firOpBuilder.create<mlir::scf::ExecuteRegionOp>(currentLocation, varType);
-  firOpBuilder.createBlock(&tempOp.getRegion());
-  mlir::Block &block = tempOp.getRegion().back();
-  firOpBuilder.setInsertionPointToEnd(&block);
-  Fortran::lower::StatementContext stmtCtx;
-  mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
-      *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
-  mlir::Value convertResult =
-      firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
-  // Insert the terminator: YieldOp.
-  firOpBuilder.create<mlir::scf::YieldOp>(currentLocation, convertResult);
-  firOpBuilder.setInsertionPointToStart(&block);
-
-  //  2. Create the omp.atomic.update Operation using the Operations in the
-  //     temporary scf.execute_region Operation.
+  //  Create the omp.atomic.update Operation
   //
   //  func.func @_QPsb() {
   //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
@@ -269,11 +220,37 @@ static inline void genOmpAccAtomicUpdateStatement(
   //    }
   //    return
   //  }
-  mlir::Value updateVar = converter.getSymbolAddress(*name->symbol);
-  if (auto decl = updateVar.getDefiningOp<hlfir::DeclareOp>())
-    updateVar = decl.getBase();
 
-  firOpBuilder.setInsertionPointAfter(tempOp);
+  Fortran::lower::ExprToValueMap exprValueOverrides;
+  // Lower any non atomic sub-expression before the atomic operation, and
+  // map its lowered value to the semantic representation.
+  const Fortran::lower::SomeExpr *nonAtomicSubExpr{nullptr};
+  std::visit(
+      [&](const auto &op) -> void {
+        using T = std::decay_t<decltype(op)>;
+        if constexpr (std::is_base_of<Fortran::parser::Expr::IntrinsicBinary,
+                                      T>::value) {
+          const auto &exprLeft{std::get<0>(op.t)};
+          const auto &exprRight{std::get<1>(op.t)};
+          if (exprLeft.value().source == assignmentStmtVariable.GetSource())
+            nonAtomicSubExpr = Fortran::semantics::GetExpr(exprRight);
+          else
+            nonAtomicSubExpr = Fortran::semantics::GetExpr(exprLeft);
+        }
+      },
+      assignmentStmtExpr.u);
+  StatementContext nonAtomicStmtCtx;
+  if (nonAtomicSubExpr) {
+    // Generate non atomic part before all the atomic operations.
+    auto insertionPoint = firOpBuilder.saveInsertionPoint();
+    if (atomicCaptureOp)
+      firOpBuilder.setInsertionPoint(atomicCaptureOp);
+    mlir::Value nonAtomicVal = fir::getBase(converter.genExprValue(
+        currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
+    exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
+    if (atomicCaptureOp)
+      firOpBuilder.restoreInsertionPoint(insertionPoint);
+  }
 
   mlir::Operation *atomicUpdateOp = nullptr;
   if constexpr (std::is_same<AtomicListT,
@@ -289,10 +266,10 @@ static inline void genOmpAccAtomicUpdateStatement(
       genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
                                             hint, memoryOrder);
     atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
-        currentLocation, updateVar, hint, memoryOrder);
+        currentLocation, lhsAddr, hint, memoryOrder);
   } else {
     atomicUpdateOp = firOpBuilder.create<mlir::acc::AtomicUpdateOp>(
-        currentLocation, updateVar);
+        currentLocation, lhsAddr);
   }
 
   llvm::SmallVector<mlir::Type> varTys = {varType};
@@ -301,38 +278,25 @@ static inline void genOmpAccAtomicUpdateStatement(
   mlir::Value val =
       fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));
 
-  llvm::SmallVector<mlir::Operation *> ops;
-  for (mlir::Operation &op : tempOp.getRegion().getOps())
-    ops.push_back(&op);
-
-  // SCF Yield is converted to OMP Yield. All other operations are copied
-  for (mlir::Operation *op : ops) {
-    if (auto y = mlir::dyn_cast<mlir::scf::YieldOp>(op)) {
-      firOpBuilder.setInsertionPointToEnd(
-          &atomicUpdateOp->getRegion(0).front());
-      if constexpr (std::is_same<AtomicListT,
-                                 Fortran::parser::OmpAtomicClauseList>()) {
-        firOpBuilder.create<mlir::omp::YieldOp>(currentLocation,
-                                                y.getResults());
-      } else {
-        firOpBuilder.create<mlir::acc::YieldOp>(currentLocation,
-                                                y.getResults());
-      }
-      op->erase();
+  exprValueOverrides.try_emplace(
+      Fortran::semantics::GetExpr(assignmentStmtVariable), val);
+  {
+    // statement context inside the atomic block.
+    converter.overrideExprValues(&exprValueOverrides);
+    Fortran::lower::StatementContext atomicStmtCtx;
+    mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx));
+    mlir::Value convertResult =
+        firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
+    if constexpr (std::is_same<AtomicListT,
+                               Fortran::parser::OmpAtomicClauseList>()) {
+      firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, convertResult);
     } else {
-      op->remove();
-      atomicUpdateOp->getRegion(0).front().push_back(op);
+      firOpBuilder.create<mlir::acc::YieldOp>(currentLocation, convertResult);
     }
+    converter.resetExprOverrides();
   }
-
-  // Remove the load and replace all uses of load with the block argument
-  for (mlir::Operation &op : atomicUpdateOp->getRegion(0).getOps()) {
-    fir::LoadOp y = mlir::dyn_cast<fir::LoadOp>(&op);
-    if (y && y.getMemref() == updateVar)
-      y.getRes().replaceAllUsesWith(val);
-  }
-
-  tempOp.erase();
+  firOpBuilder.setInsertionPointAfter(atomicUpdateOp);
 }
 
 /// Processes an atomic construct with write clause.
@@ -423,11 +387,7 @@ void genOmpAccAtomicUpdate(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  mlir::Type varType =
-      fir::getBase(
-          converter.genExprValue(
-              *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
-          .getType();
+  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
   genOmpAccAtomicUpdateStatement<AtomicListT>(
       converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr,
       leftHandClauseList, rightHandClauseList);
@@ -450,11 +410,7 @@ void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value lhsAddr = fir::getBase(converter.genExprAddr(
       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
-  mlir::Type varType =
-      fir::getBase(
-          converter.genExprValue(
-              *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
-          .getType();
+  mlir::Type varType = fir::unwrapRefType(lhsAddr.getType());
   // If atomic-clause is not present on the construct, the behaviour is as if
   // the update clause is specified (for both OpenMP and OpenACC).
   genOmpAccAtomicUpdateStatement<AtomicListT>(
@@ -551,7 +507,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
       genOmpAccAtomicUpdateStatement<AtomicListT>(
           converter, stmt1RHSArg, stmt2VarType, stmt2Var, stmt2Expr,
           /*leftHandClauseList=*/nullptr,
-          /*rightHandClauseList=*/nullptr);
+          /*rightHandClauseList=*/nullptr, atomicCaptureOp);
     } else {
       // Atomic capture construct is of the form [capture-stmt, write-stmt]
       const Fortran::semantics::SomeExpr &fromExpr =
@@ -580,7 +536,7 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
     genOmpAccAtomicUpdateStatement<AtomicListT>(
         converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
         /*leftHandClauseList=*/nullptr,
-        /*rightHandClauseList=*/nullptr);
+        /*rightHandClauseList=*/nullptr, atomicCaptureOp);
   }
   firOpBuilder.setInsertionPointToEnd(&block);
   if constexpr (std::is_same<AtomicListT,
diff --git a/flang/test/Lower/OpenACC/acc-atomic-capture.f90 b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
index 1a5fe8d57c1533a..382991cf7221ba7 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-capture.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-capture.f90
@@ -7,11 +7,11 @@ program acc_atomic_capture_test
 
 !CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
 !CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: acc.atomic.capture {
 !CHECK: acc.atomic.read %[[X]] = %[[Y]] : !fir.ref<i32>
 !CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[ARG]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }
@@ -23,10 +23,10 @@ program acc_atomic_capture_test
     !$acc end atomic
 
 
+!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: acc.atomic.capture {
 !CHECK: acc.atomic.update %[[Y]] : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
-!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref<i32>
 !CHECK: %[[result:.*]] = arith.muli %[[temp]], %[[ARG]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }
@@ -76,12 +76,12 @@ subroutine pointers_in_atomic_capture()
 !CHECK: %[[loaded_A_addr:.*]] = fir.box_addr %[[loaded_A]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK: %[[loaded_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK: %[[loaded_B_addr:.*]] = fir.box_addr %[[loaded_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
-!CHECK: acc.atomic.capture   {
-!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
-!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK: %[[PRIVATE_LOADED_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
 !CHECK: %[[PRIVATE_LOADED_B_addr:.*]] = fir.box_addr %[[PRIVATE_LOADED_B]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
 !CHECK: %[[loaded_value:.*]] = fir.load %[[PRIVATE_LOADED_B_addr]] : !fir.ptr<i32>
+!CHECK: acc.atomic.capture   {
+!CHECK: acc.atomic.update %[[loaded_A_addr]] : !fir.ptr<i32> {
+!CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK: %[[result:.*]] = arith.addi %[[ARG]], %[[loaded_value]] : i32
 !CHECK: acc.yield %[[result]] : i32
 !CHECK: }
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90 b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
index 24dd0ee5a8999e4..b2a993ddd825105 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update-hlfir.f90
@@ -14,9 +14,9 @@ subroutine sb
 !CHECK:   %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFsbEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:   %[[Y_REF:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFsbEy"}
 !CHECK:   %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_REF]] {uniq_name = "_QFsbEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-!CHECK:   acc.atomic.update   %[[X_DECL]]#0 : !fir.ref<i32> {
+!CHECK:   %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
+!CHECK:   acc.atomic.update   %[[X_DECL]]#1 : !fir.ref<i32> {
 !CHECK:   ^bb0(%[[ARG_X:.*]]: i32):
-!CHECK:     %[[Y_VAL:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
 !CHECK:     %[[X_UPDATE_VAL:.*]] = arith.addi %[[ARG_X]], %[[Y_VAL]] : i32
 !CHECK:     acc.yield %[[X_UPDATE_VAL]] : i32
 !CHECK:   }
diff --git a/flang/test/Lower/OpenACC/acc-atomic-update.f90 b/flang/test/Lower/OpenACC/acc-atomic-update.f90
index 546d012982e23ba..96e1d5e58e6f482 100644
--- a/flang/test/Lower/OpenACC/acc-atomic-update.f90
+++ b/flang/test/Lower/OpenACC/acc-atomic-update.f90
@@ -31,25 +31,25 @@ program acc_atomic_update_test
 !CHECK: %{{.*}} = fir.convert %[[D_ADDR]] : (!fir.ref<i32>) -> !fir.ptr<i32>
 !CHECK: fir.store {{.*}} to %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
 !CHECK: %[[LOADED_A:.*]] = fir.load %[[A_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
+!CHECK: %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
 !CHECK:  acc.atomic.update   %[[LOADED_A]] : !fir.ptr<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    %[[LOADED_B:.*]] = fir.load %[[B_ADDR]] : !fir.ref<!fir.ptr<i32>>
-!CHECK:    %{{.*}} = fir.load %[[LOADED_B]] : !fir.ptr<i32>
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], %{{.*}} : i32
 !CHECK:    acc.yield %[[RESULT]] : i32
 !CHECK: }
     !$acc atomic update
         a = a + b 
 
+!CHECK: {{.*}} = arith.constant 1 : i32
 !CHECK: acc.atomic.update   %[[Y]] : !fir.ref<i32> {
 !CHECK:  ^bb0(%[[ARG:.*]]: i32):
-!CHECK:    {{.*}} = arith.constant 1 : i32
 !CHECK:    %[[RESULT:.*]] = arith.addi %[[ARG]], {{.*}} : i32
 !CHECK:    acc.yiel...
[truncated]

Copy link
Contributor

@kiranchandramohan kiranchandramohan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would first need to define the lowering parallelism requirements here. Is to lower expressions of a single statement in parallel, or is it to allow different program units to be lowered in parallel? The former seems hard and of little practicality to me. The latter makes more sense, and then, I think we could modify the bridge a bit in order to have a "converter" instance per thread/program unit lowering, in which case this change is not a blocker (sub-expression override would only impact a single thread/program unit lowering).

Makes sense.

Thanks @jeanPerier once again for your help here.

nonAtomicSubExpr = Fortran::semantics::GetExpr(exprRight);
else
nonAtomicSubExpr = Fortran::semantics::GetExpr(exprLeft);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming we can do something similar here for the allowed intrinsics too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose you can indeed go through the intrinsic function arguments here too yes, @NimishMishra, I do not plan to work on this, so I am happy if you do in a later patch!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Copy link
Contributor

@NimishMishra NimishMishra left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@jeanPerier I suppose we will have to work on lowering of intrinsic procedures too. If you do not plan to extend this PR to include those, can I look into those?

@jeanPerier jeanPerier merged commit b6b0756 into llvm:main Oct 25, 2023
8 checks passed
@jeanPerier jeanPerier deleted the jpr-69866-alternative branch October 25, 2023 07:22
clementval added a commit that referenced this pull request Oct 30, 2023
After #69944 lowering of array ref in atomic operation works properly.
Add some lowering test to catch up regression in the future.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants