
[Flang][OpenMP][Lower] Update workshare-loop lowering (5/5) #89215

Merged

skatrak merged 27 commits into main from users/skatrak/spr/wsloop-wrapper-05-flang on Apr 24, 2024

Conversation

skatrak
Contributor

@skatrak skatrak commented Apr 18, 2024

This patch updates lowering from PFT to MLIR of workshare loops to follow the loop wrapper approach. Unit tests impacted by this change are also updated.

As the last patch of the stack, this should compile and pass unit tests.

This patch updates the definition of `omp.wsloop` to enforce the restrictions
of a loop wrapper operation.

Related tests are updated, but this PR on its own will not pass premerge tests.
All patches in the stack are needed before it can be compiled and pass tests.
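
For reference, the loop wrapper restriction enforced here means an `omp.wsloop` region holds a single nested `omp.loop_nest` (plus a terminator), rather than owning the loop bounds and induction variable itself. A minimal sketch of the resulting IR, based on the updated tests in this stack (SSA names are illustrative):

```mlir
omp.wsloop nowait {
  omp.loop_nest (%iv) : i32 = (%lb) to (%ub) inclusive step (%step) {
    // loop body, indexed by the induction variable %iv
    omp.yield
  }
  omp.terminator
}
```
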
This patch updates verifiers for `omp.ordered.region`, `omp.cancel` and
`omp.cancellation_point`, which check for a parent `omp.wsloop`.

After transitioning to a loop wrapper-based approach, the expected direct
parent will become `omp.loop_nest` instead, so verifiers need to take this into
account.

This PR on its own will not pass premerge tests. All patches in the stack are
needed before it can be compiled and pass tests.
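
A hedged sketch of the nesting these verifiers now have to look through (only the wrapper/loop nesting is taken from this stack; the placement described in the comment is illustrative):

```mlir
omp.wsloop {
  omp.loop_nest (%i) : i32 = (%lb) to (%ub) inclusive step (%step) {
    // An omp.cancel, omp.cancellation_point or omp.ordered.region created here
    // now has omp.loop_nest as its direct parent, so the verifiers must walk up
    // past the loop_nest to find the expected omp.wsloop wrapper.
    omp.yield
  }
  omp.terminator
}
```
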
This patch makes changes to the `scf.parallel` to `omp.parallel` + `omp.wsloop`
lowering pass in order to introduce a nested `omp.loop_nest` as well, and to
follow the new loop wrapper role for `omp.wsloop`.

This PR on its own will not pass premerge tests. All patches in the stack are
needed before it can be compiled and pass tests.
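
A rough sketch of the target nesting produced by the updated pass; the `scf.parallel` input shown in the comment and the exact types and operands are illustrative assumptions, not copied from this patch:

```mlir
// Input (illustrative): scf.parallel (%i) = (%lb) to (%ub) step (%s) { ... }
// Expected output structure after the conversion pass:
omp.parallel {
  omp.wsloop {
    omp.loop_nest (%i) : index = (%lb) to (%ub) step (%s) {
      // converted loop body
      omp.yield
    }
    omp.terminator
  }
  omp.terminator
}
```
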
This patch introduces minimal changes to the MLIR to LLVM IR translation of
`omp.wsloop` to support the loop wrapper approach.

There is `omp.loop_nest`-related translation code that should be extracted and
shared among all loop operations (e.g. `omp.simd`), which could also help when
adding support for compound constructs later on. This first approach is only
intended to keep things working after the transition to loop wrappers, not to
add support for the other use cases that the transition enables.

This PR on its own will not pass premerge tests. All patches in the stack are
needed before it can be compiled and pass tests.
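
To illustrate why that `omp.loop_nest` translation code is a candidate for sharing, the same nested-loop form sits under other loop wrappers as well; a sketch assuming the `omp.simd` wrapper with no clauses prints as shown (its exact assembly is not taken from this patch):

```mlir
omp.simd {
  omp.loop_nest (%i) : i32 = (%lb) to (%ub) inclusive step (%step) {
    // the bounds/IV handling of this loop_nest is what could be shared
    // between the different wrapper translations
    omp.yield
  }
  omp.terminator
}
```
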
This patch updates lowering from PFT to MLIR of workshare loops to follow the
loop wrapper approach. Unit tests impacted by this change are also updated.

As the last patch of the stack, this should compile and pass unit tests.
@llvmbot
Collaborator

llvmbot commented Apr 18, 2024

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch updates lowering from PFT to MLIR of workshare loops to follow the loop wrapper approach. Unit tests impacted by this change are also updated.

As the last patch of the stack, this should compile and pass unit tests.


Patch is 626.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89215.diff

95 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+26-19)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+42-73)
  • (modified) flang/test/Fir/convert-to-llvm-openmp-and-fir.fir (+61-47)
  • (modified) flang/test/Lower/OpenMP/FIR/copyin.f90 (+11-5)
  • (modified) flang/test/Lower/OpenMP/FIR/lastprivate-commonblock.f90 (+4-1)
  • (modified) flang/test/Lower/OpenMP/FIR/location.f90 (+10-7)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-lastprivate-clause-scalar.f90 (+36-12)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-private-clause-fixes.f90 (+26-23)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-private-clause.f90 (+60-54)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-wsloop-firstpriv.f90 (+10-2)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-wsloop.f90 (+74-54)
  • (modified) flang/test/Lower/OpenMP/FIR/stop-stmt-in-region.f90 (+21-18)
  • (modified) flang/test/Lower/OpenMP/FIR/target.f90 (+4-1)
  • (modified) flang/test/Lower/OpenMP/FIR/unstructured.f90 (+110-89)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-chunks.f90 (+28-19)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-collapse.f90 (+16-13)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-monotonic.f90 (+17-13)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-nonmonotonic.f90 (+17-14)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-ordered.f90 (+12-6)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-add-byref.f90 (+106-85)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-add.f90 (+106-85)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-iand-byref.f90 (+3-1)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-iand.f90 (+3-1)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ieor-byref.f90 (+3-1)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ieor.f90 (+3-1)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ior-byref.f90 (+3-1)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ior.f90 (+3-1)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-eqv-byref.f90 (+75-69)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-eqv.f90 (+75-69)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-neqv-byref.f90 (+75-69)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-neqv.f90 (+75-69)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-max-byref.f90 (+18-13)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-max.f90 (+18-13)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-min-byref.f90 (+18-14)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-min.f90 (+18-14)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-simd.f90 (+16-13)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-variable.f90 (+93-79)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop.f90 (+36-30)
  • (modified) flang/test/Lower/OpenMP/Todo/omp-default-clause-inner-loop.f90 (+4-1)
  • (modified) flang/test/Lower/OpenMP/copyin.f90 (+17-11)
  • (modified) flang/test/Lower/OpenMP/default-clause-byref.f90 (+4-1)
  • (modified) flang/test/Lower/OpenMP/default-clause.f90 (+4-1)
  • (modified) flang/test/Lower/OpenMP/hlfir-wsloop.f90 (+7-5)
  • (modified) flang/test/Lower/OpenMP/lastprivate-commonblock.f90 (+31-28)
  • (modified) flang/test/Lower/OpenMP/lastprivate-iv.f90 (+48-42)
  • (modified) flang/test/Lower/OpenMP/location.f90 (+10-7)
  • (modified) flang/test/Lower/OpenMP/parallel-lastprivate-clause-scalar.f90 (+36-12)
  • (modified) flang/test/Lower/OpenMP/parallel-private-clause-fixes.f90 (+26-23)
  • (modified) flang/test/Lower/OpenMP/parallel-private-clause.f90 (+56-50)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction3.f90 (+26-23)
  • (modified) flang/test/Lower/OpenMP/parallel-wsloop-firstpriv.f90 (+8-2)
  • (modified) flang/test/Lower/OpenMP/parallel-wsloop.f90 (+79-59)
  • (modified) flang/test/Lower/OpenMP/stop-stmt-in-region.f90 (+21-18)
  • (modified) flang/test/Lower/OpenMP/target.f90 (+4-1)
  • (modified) flang/test/Lower/OpenMP/unstructured.f90 (+110-89)
  • (modified) flang/test/Lower/OpenMP/wsloop-chunks.f90 (+28-19)
  • (modified) flang/test/Lower/OpenMP/wsloop-collapse.f90 (+16-13)
  • (modified) flang/test/Lower/OpenMP/wsloop-monotonic.f90 (+10-8)
  • (modified) flang/test/Lower/OpenMP/wsloop-nonmonotonic.f90 (+11-8)
  • (modified) flang/test/Lower/OpenMP/wsloop-ordered.f90 (+12-6)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-byref.f90 (+120-99)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir-byref.f90 (+10-8)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir.f90 (+10-8)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-add.f90 (+120-99)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array-assumed-shape.f90 (+18-15)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array.f90 (+19-16)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array2.f90 (+27-24)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-iand-byref.f90 (+13-11)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-iand.f90 (+13-11)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-ieor-byref.f90 (+4-2)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-ieor.f90 (+4-2)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-ior-byref.f90 (+13-11)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-ior.f90 (+13-11)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-and-byref.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv-byref.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-neqv-byref.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-neqv.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-or-byref.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-or.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-byref.f90 (+48-42)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir-byref.f90 (+14-12)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir.f90 (+14-12)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max.f90 (+48-42)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-min-byref.f90 (+49-43)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-min.f90 (+49-43)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-min2.f90 (+9-7)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-mul-byref.f90 (+113-99)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-mul.f90 (+113-99)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-multi.f90 (+26-23)
  • (modified) flang/test/Lower/OpenMP/wsloop-simd.f90 (+16-13)
  • (modified) flang/test/Lower/OpenMP/wsloop-unstructured.f90 (+21-18)
  • (modified) flang/test/Lower/OpenMP/wsloop-variable.f90 (+91-76)
  • (modified) flang/test/Lower/OpenMP/wsloop.f90 (+39-33)
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 5a42e6a6aa4175..d98711c8a8900c 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -135,6 +135,12 @@ void DataSharingProcessor::insertBarrier() {
 }
 
 void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
+  mlir::omp::LoopNestOp loopOp;
+  if (auto wrapper = mlir::dyn_cast<mlir::omp::LoopWrapperInterface>(op))
+    loopOp = wrapper.isWrapper()
+                 ? mlir::cast<mlir::omp::LoopNestOp>(wrapper.getWrappedLoop())
+                 : nullptr;
+
   bool cmpCreated = false;
   mlir::OpBuilder::InsertPoint localInsPt = firOpBuilder.saveInsertionPoint();
   for (const omp::Clause &clause : clauses) {
@@ -215,18 +221,20 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
       // Update the original variable just before exiting the worksharing
       // loop. Conversion as follows:
       //
-      //                       omp.wsloop {
-      // omp.wsloop {            ...
-      //    ...                  store
-      //    store       ===>     %v = arith.addi %iv, %step
-      //    omp.yield            %cmp = %step < 0 ? %v < %ub : %v > %ub
-      // }                       fir.if %cmp {
-      //                           fir.store %v to %loopIV
-      //                           ^%lpv_update_blk:
-      //                         }
-      //                         omp.yield
-      //                       }
-      //
+      // omp.wsloop {             omp.wsloop {
+      //   omp.loop_nest {          omp.loop_nest {
+      //     ...                      ...
+      //     store          ===>      store
+      //     omp.yield                %v = arith.addi %iv, %step
+      //   }                          %cmp = %step < 0 ? %v < %ub : %v > %ub
+      //   omp.terminator             fir.if %cmp {
+      // }                              fir.store %v to %loopIV
+      //                                ^%lpv_update_blk:
+      //                              }
+      //                              omp.yield
+      //                            }
+      //                            omp.terminator
+      //                          }
 
       // Only generate the compare once in presence of multiple LastPrivate
       // clauses.
@@ -234,14 +242,13 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
         continue;
       cmpCreated = true;
 
-      mlir::Location loc = op->getLoc();
-      mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
+      mlir::Location loc = loopOp.getLoc();
+      mlir::Operation *lastOper = loopOp.getRegion().back().getTerminator();
       firOpBuilder.setInsertionPoint(lastOper);
 
-      mlir::Value iv = op->getRegion(0).front().getArguments()[0];
-      mlir::Value ub =
-          mlir::dyn_cast<mlir::omp::WsloopOp>(op).getUpperBound()[0];
-      mlir::Value step = mlir::dyn_cast<mlir::omp::WsloopOp>(op).getStep()[0];
+      mlir::Value iv = loopOp.getIVs()[0];
+      mlir::Value ub = loopOp.getUpperBound()[0];
+      mlir::Value step = loopOp.getStep()[0];
 
       // v = iv + step
       // cmp = step < 0 ? v < ub : v > ub
@@ -260,7 +267,7 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
       auto ifOp = firOpBuilder.create<fir::IfOp>(loc, cmpOp, /*else*/ false);
       firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
       assert(loopIV && "loopIV was not set");
-      firOpBuilder.create<fir::StoreOp>(op->getLoc(), v, loopIV);
+      firOpBuilder.create<fir::StoreOp>(loopOp.getLoc(), v, loopIV);
       lastPrivIP = firOpBuilder.saveInsertionPoint();
     } else {
       TODO(converter.getCurrentLocation(),
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index bb38082b245ef5..98a1eab4b614fc 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -369,7 +369,9 @@ getDeclareTargetFunctionDevice(
 static llvm::SmallVector<const Fortran::semantics::Symbol *>
 genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
             mlir::Location &loc,
-            llvm::ArrayRef<const Fortran::semantics::Symbol *> args) {
+            llvm::ArrayRef<const Fortran::semantics::Symbol *> args,
+            llvm::ArrayRef<const Fortran::semantics::Symbol *> wrapperSyms = {},
+            llvm::ArrayRef<mlir::BlockArgument> wrapperArgs = {}) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   auto &region = op->getRegion(0);
 
@@ -380,6 +382,14 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
   llvm::SmallVector<mlir::Location> locs(args.size(), loc);
   firOpBuilder.createBlock(&region, {}, tiv, locs);
+
+  // Bind the entry block arguments of parent wrappers to the corresponding
+  // symbols. Do it here so that any hlfir.declare operations created as a
+  // result are inserted inside of the omp.loop_nest rather than the wrapper
+  // operations.
+  for (auto [arg, prv] : llvm::zip_equal(wrapperSyms, wrapperArgs))
+    converter.bindSymbol(*arg, prv);
+
   // The argument is not currently in memory, so make a temporary for the
   // argument, and store it there, then bind that location to the argument.
   mlir::Operation *storeOp = nullptr;
@@ -410,58 +420,6 @@ static void genReductionVars(
   }
 }
 
-static llvm::SmallVector<const Fortran::semantics::Symbol *>
-genLoopAndReductionVars(
-    mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
-    mlir::Location &loc,
-    llvm::ArrayRef<const Fortran::semantics::Symbol *> loopArgs,
-    llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs,
-    llvm::ArrayRef<mlir::Type> reductionTypes) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  llvm::SmallVector<mlir::Type> blockArgTypes;
-  llvm::SmallVector<mlir::Location> blockArgLocs;
-  blockArgTypes.reserve(loopArgs.size() + reductionArgs.size());
-  blockArgLocs.reserve(blockArgTypes.size());
-  mlir::Block *entryBlock;
-
-  if (loopArgs.size()) {
-    std::size_t loopVarTypeSize = 0;
-    for (const Fortran::semantics::Symbol *arg : loopArgs)
-      loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
-    mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
-    std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(),
-                loopVarType);
-    std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc);
-  }
-  if (reductionArgs.size()) {
-    llvm::copy(reductionTypes, std::back_inserter(blockArgTypes));
-    std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc);
-  }
-  entryBlock = firOpBuilder.createBlock(&op->getRegion(0), {}, blockArgTypes,
-                                        blockArgLocs);
-  // The argument is not currently in memory, so make a temporary for the
-  // argument, and store it there, then bind that location to the argument.
-  if (loopArgs.size()) {
-    mlir::Operation *storeOp = nullptr;
-    for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) {
-      mlir::Value indexVal =
-          fir::getBase(op->getRegion(0).front().getArgument(argIndex));
-      storeOp =
-          createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
-    }
-    firOpBuilder.setInsertionPointAfter(storeOp);
-  }
-  // Bind the reduction arguments to their block arguments
-  for (auto [arg, prv] : llvm::zip_equal(
-           reductionArgs,
-           llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) {
-    converter.bindSymbol(*arg, prv);
-  }
-
-  return llvm::SmallVector<const Fortran::semantics::Symbol *>(loopArgs);
-}
-
 static void
 markDeclareTarget(mlir::Operation *op,
                   Fortran::lower::AbstractConverter &converter,
@@ -1292,20 +1250,16 @@ static void genWsloopClauses(
     Fortran::lower::AbstractConverter &converter,
     Fortran::semantics::SemanticsContext &semaCtx,
     Fortran::lower::StatementContext &stmtCtx,
-    Fortran::lower::pft::Evaluation &eval,
     const Fortran::parser::OmpClauseList &beginClauses,
     const Fortran::parser::OmpClauseList *endClauses, mlir::Location loc,
     mlir::omp::WsloopClauseOps &clauseOps,
-    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
     llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   ClauseProcessor bcp(converter, semaCtx, beginClauses);
-  bcp.processCollapse(loc, eval, clauseOps, iv);
   bcp.processOrdered(clauseOps);
   bcp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
   bcp.processSchedule(stmtCtx, clauseOps);
-  clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr();
   // TODO Support delayed privatization.
 
   if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
@@ -1844,34 +1798,49 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
             Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
             const Fortran::parser::OmpClauseList &beginClauseList,
             const Fortran::parser::OmpClauseList *endClauseList) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval);
   dsp.processStep1();
 
   Fortran::lower::StatementContext stmtCtx;
-  mlir::omp::WsloopClauseOps clauseOps;
+  mlir::omp::LoopNestClauseOps loopClauseOps;
+  mlir::omp::WsloopClauseOps wsClauseOps;
   llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
   llvm::SmallVector<mlir::Type> reductionTypes;
   llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
-  genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauseList,
-                   endClauseList, loc, clauseOps, iv, reductionTypes,
-                   reductionSyms);
+  genLoopNestClauses(converter, semaCtx, eval, beginClauseList, loc,
+                     loopClauseOps, iv);
+  genWsloopClauses(converter, semaCtx, stmtCtx, beginClauseList, endClauseList,
+                   loc, wsClauseOps, reductionTypes, reductionSyms);
+
+  // Create omp.wsloop wrapper and populate entry block arguments with reduction
+  // variables.
+  auto wsloopOp = firOpBuilder.create<mlir::omp::WsloopOp>(loc, wsClauseOps);
+  llvm::SmallVector<mlir::Location> reductionLocs(reductionSyms.size(), loc);
+  mlir::Block *wsloopEntryBlock = firOpBuilder.createBlock(
+      &wsloopOp.getRegion(), {}, reductionTypes, reductionLocs);
+  firOpBuilder.setInsertionPoint(
+      Fortran::lower::genOpenMPTerminator(firOpBuilder, wsloopOp, loc));
+
+  // Create nested omp.loop_nest and fill body with loop contents.
+  auto loopOp = firOpBuilder.create<mlir::omp::LoopNestOp>(loc, loopClauseOps);
 
   auto *nestedEval = getCollapsedLoopEval(
       eval, Fortran::lower::getCollapseValue(beginClauseList));
 
   auto ivCallback = [&](mlir::Operation *op) {
-    return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms,
-                                   reductionTypes);
+    return genLoopVars(op, converter, loc, iv, reductionSyms,
+                       wsloopEntryBlock->getArguments());
   };
 
-  return genOpWithBody<mlir::omp::WsloopOp>(
-      OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval,
-                        llvm::omp::Directive::OMPD_do)
-          .setClauses(&beginClauseList)
-          .setDataSharingProcessor(&dsp)
-          .setReductions(&reductionSyms, &reductionTypes)
-          .setGenRegionEntryCb(ivCallback),
-      clauseOps);
+  createBodyOfOp(*loopOp,
+                 OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval,
+                                   llvm::omp::Directive::OMPD_do)
+                     .setClauses(&beginClauseList)
+                     .setDataSharingProcessor(&dsp)
+                     .setReductions(&reductionSyms, &reductionTypes)
+                     .setGenRegionEntryCb(ivCallback));
+  return wsloopOp;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2542,8 +2511,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
 mlir::Operation *Fortran::lower::genOpenMPTerminator(fir::FirOpBuilder &builder,
                                                      mlir::Operation *op,
                                                      mlir::Location loc) {
-  if (mlir::isa<mlir::omp::WsloopOp, mlir::omp::DeclareReductionOp,
-                mlir::omp::AtomicUpdateOp, mlir::omp::LoopNestOp>(op))
+  if (mlir::isa<mlir::omp::AtomicUpdateOp, mlir::omp::DeclareReductionOp,
+                mlir::omp::LoopNestOp>(op))
     return builder.create<mlir::omp::YieldOp>(loc);
   return builder.create<mlir::omp::TerminatorOp>(loc);
 }
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index fa7979e8875afc..c7c609bbb35623 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -7,15 +7,17 @@ func.func @_QPsb1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<!
   omp.parallel  {
     %1 = fir.alloca i32 {adapt.valuebyref, pinned}
     %2 = fir.load %arg0 : !fir.ref<i32>
-    omp.wsloop nowait
-    for (%arg2) : i32 = (%c1_i32) to (%2) inclusive step (%c1_i32)  {
-      fir.store %arg2 to %1 : !fir.ref<i32>
-      %3 = fir.load %1 : !fir.ref<i32>
-      %4 = fir.convert %3 : (i32) -> i64
-      %5 = arith.subi %4, %c1_i64 : i64
-      %6 = fir.coordinate_of %arg1, %5 : (!fir.ref<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
-      fir.store %3 to %6 : !fir.ref<i32>
-      omp.yield
+    omp.wsloop nowait {
+      omp.loop_nest (%arg2) : i32 = (%c1_i32) to (%2) inclusive step (%c1_i32)  {
+        fir.store %arg2 to %1 : !fir.ref<i32>
+        %3 = fir.load %1 : !fir.ref<i32>
+        %4 = fir.convert %3 : (i32) -> i64
+        %5 = arith.subi %4, %c1_i64 : i64
+        %6 = fir.coordinate_of %arg1, %5 : (!fir.ref<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
+        fir.store %3 to %6 : !fir.ref<i32>
+        omp.yield
+      }
+      omp.terminator
     }
     omp.terminator
   }
@@ -31,7 +33,7 @@ func.func @_QPsb1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<!
 // CHECK:      %[[I_VAR:.*]] = llvm.alloca %[[ONE_3]] x i32 {pinned} : (i64) -> !llvm.ptr
 // CHECK:      %[[N:.*]] = llvm.load %[[N_REF]] : !llvm.ptr -> i32
 // CHECK: omp.wsloop nowait
-// CHECK-SAME: for (%[[I:.*]]) : i32 = (%[[ONE_2]]) to (%[[N]]) inclusive step (%[[ONE_2]]) {
+// CHECK-NEXT: omp.loop_nest (%[[I:.*]]) : i32 = (%[[ONE_2]]) to (%[[N]]) inclusive step (%[[ONE_2]]) {
 // CHECK:   llvm.store %[[I]], %[[I_VAR]] : i32, !llvm.ptr
 // CHECK:   %[[I1:.*]] = llvm.load %[[I_VAR]] : !llvm.ptr -> i32
 // CHECK:   %[[I1_EXT:.*]] = llvm.sext %[[I1]] : i32 to i64
@@ -42,6 +44,8 @@ func.func @_QPsb1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<!
 // CHECK: }
 // CHECK: omp.terminator
 // CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
 // CHECK: llvm.return
 // CHECK: }
 
@@ -79,13 +83,16 @@ func.func @_QPsb(%arr: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr"}) {
   omp.parallel   {
     %c1 = arith.constant 1 : i32
     %c50 = arith.constant 50 : i32
-    omp.wsloop   for  (%indx) : i32 = (%c1) to (%c50) inclusive step (%c1) {
-      %1 = fir.convert %indx : (i32) -> i64
-      %c1_i64 = arith.constant 1 : i64
-      %2 = arith.subi %1, %c1_i64 : i64
-      %3 = fir.coordinate_of %arr, %2 : (!fir.box<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
-      fir.store %indx to %3 : !fir.ref<i32>
-      omp.yield
+    omp.wsloop {
+      omp.loop_nest (%indx) : i32 = (%c1) to (%c50) inclusive step (%c1) {
+        %1 = fir.convert %indx : (i32) -> i64
+        %c1_i64 = arith.constant 1 : i64
+        %2 = arith.subi %1, %c1_i64 : i64
+        %3 = fir.coordinate_of %arr, %2 : (!fir.box<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
+        fir.store %indx to %3 : !fir.ref<i32>
+        omp.yield
+      }
+      omp.terminator
     }
     omp.terminator
   }
@@ -98,9 +105,11 @@ func.func @_QPsb(%arr: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr"}) {
 // CHECK:    omp.parallel   {
 // CHECK:      %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
 // CHECK:      %[[C50:.*]] = llvm.mlir.constant(50 : i32) : i32
-// CHECK:      omp.wsloop   for  (%[[INDX:.*]]) : i32 = (%[[C1]]) to (%[[C50]]) inclusive step (%[[C1]]) {
-// CHECK:        llvm.store %[[INDX]], %{{.*}} : i32, !llvm.ptr
-// CHECK:        omp.yield
+// CHECK:      omp.wsloop {
+// CHECK-NEXT:   omp.loop_nest (%[[INDX:.*]]) : i32 = (%[[C1]]) to (%[[C50]]) inclusive step (%[[C1]]) {
+// CHECK:          llvm.store %[[INDX]], %{{.*}} : i32, !llvm.ptr
+// CHECK:          omp.yield
+// CHECK:        omp.terminator
 // CHECK:      omp.terminator
 // CHECK:    llvm.return
 
@@ -708,18 +717,20 @@ func.func @_QPsb() {
 // CHECK-SAME: %[[ARRAY_REF:.*]]: !llvm.ptr
 // CHECK:    %[[RED_ACCUMULATOR:.*]] = llvm.alloca %2 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
 // CHECK:    omp.parallel   {
-// CHECK:      omp.wsloop   reduction(@[[EQV_REDUCTION]] %[[RED_ACCUMULATOR]] -> %[[PRV:.+]] : !llvm.ptr) for
-// CHECK:        %[[ARRAY_ELEM_REF:.*]] = llvm.getelementptr %[[ARRAY_REF]][0, %{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr
-// CHECK:        %[[ARRAY_ELEM:.*]] = llvm.load %[[ARRAY_ELEM_REF]] : !llvm.ptr -> i32
-// CHECK:        %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
-// CHECK:        %[[ZERO_1:.*]] = llvm.mlir.constant(0 : i64) : i32
-// CHECK:        %[[ARGVAL_1:.*]] = llvm.icmp "ne" %[[LPRV]], %[[ZERO_1]] : i32
-// CHECK:        %[[ZERO_2:.*]] = llvm.mlir.constant(0 : i64) : i32
-// CHECK:        %[[ARGVAL_2:.*]] = llvm.icmp "ne" %[[ARRAY_ELEM]], %[[ZERO_2]] : i32
-// CHECK:        %[[RES:.*]] = llvm.icmp "eq" %[[ARGVAL_2]], %[[ARGVAL_1]] : i1
-// CHECK:        %[[RES_EXT:.*]] = llvm.zext %[[RES]] : i1 to i32
-// CHECK:        llvm.store %[[RES_EXT]], %[[PRV]] : i32, !llvm.ptr
-// CHECK:        omp.yield
+// CHECK:      omp.wsloop reduction(@[[EQV_REDUCTION]] %[[RED_ACCUMULATOR]] -> %[[PRV:.+]] : !llvm.ptr) {
+// CHECK-NEXT:   omp.loop_nest
+// CHECK:          %[[ARRAY_ELEM_REF:.*]] = llvm.getelementptr %[[ARRAY_REF]][0, %{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr
+// CHECK:          %[[ARRAY_ELEM:.*]] = llvm.load %[[ARRAY_ELEM_REF]] : !llvm.ptr -> i32
+// CHECK:          %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
+// CHECK:          %[[ZERO_1:.*]] = llvm.mlir.constant(0 : i64) : i32
+// CHECK:          %[[ARGVAL_1:.*]] = llvm.icmp "ne" %[[LPRV]], %[[ZERO_1]] : i32
+// CHECK:          %[[ZERO_2:.*]] = llvm.mlir.constant(0 : i64) : i32
+// CHECK:          %[[ARGVAL_2:.*]] = llvm.icmp "ne" %[[ARRAY_ELEM]], %[[ZERO_2]] : i32
+// CHECK:          %[[RES:.*]] = llvm.icmp "eq" %[[ARGVAL_2]], %[[ARGVAL_1]] : i1
+// CHECK:          %[[RES_EXT:.*]] = llvm.zext %[[RES]] : i1 to i32
+// CHECK:          llvm.store %[[RES_EXT]], %[[PRV]] : i32, !llvm.ptr
+// CHECK:          omp.yield
+// CHECK:        omp.terminator
 // CHECK:      omp.terminator
 // CHECK:    llvm.return
 
@@ -747,21 +758,24 @@ func.func @_QPsimple_reduction(%arg0: !fir.ref<!fir.array<100x!fir.logical<4>>>
     %c1_i32 = arith.constant 1 : i32
     %c100_i32 = arith.constant 100 : i32
     %c1_i32_0 = arith.constant 1 : i32
-    omp.wsloop   reduction(@eqv_reduction %1 -> %prv : !fir.ref<!fir.logical<4>>) for  (%arg1) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32_0) {
-      fir.store %arg1 to %3 : !fir.ref<i32>
-      %4 = fir.load %3 : !fir.ref<i32>
-      %5 = fir.convert %4 : (i32) -> i64
-      %c1_i64 = arith.constant 1 : i64
-      %6 = arith.subi %5, %c1_i64 : i64
-      %7 = fir.coordinate_of %arg0, %6 : (!fir.ref<!fir.array<100x!fir.logical<4>>>, i64) -> !fir.ref<!fir.logical<4>>
-      %8 = fir.load %7 : !fir.ref<!fir.logical<4>>
-      %lprv = fir.load %prv : !fir.ref<!fir.logical<4>>
-      %lprv1 = fir.convert %lprv : (!fir.logical<4>) -> i1
-      %9 = fir.convert %8 : (!fir.logical<4>) -> i1
-      %10 = arith.cmpi eq, %9, %lprv1 : i1
-      %11 = fir.convert %10 : (i1) -> !fir.logical<4>
-      fir.store %11 to %prv : !fir.ref<!fir.logical<4>>
-      omp.yield
...
[truncated]

Contributor

@tblah tblah left a comment

LGTM

flang/lib/Lower/OpenMP/OpenMP.cpp (review thread: outdated, resolved)
Base automatically changed from users/skatrak/spr/wsloop-wrapper-04-llvm-ir to main April 24, 2024 13:29
@skatrak skatrak merged commit ca4dbc2 into main Apr 24, 2024
2 of 3 checks passed
@skatrak skatrak deleted the users/skatrak/spr/wsloop-wrapper-05-flang branch April 24, 2024 13:30
Labels
flang:fir-hlfir, flang:openmp, flang (Flang issues not falling into any other category)
4 participants