[MLIR][Flang][OpenMP] Make omp.wsloop into a loop wrapper #88403

Closed
skatrak wants to merge 8 commits into main from users/skatrak/spr/loop-nest-04-wsloop-mlir

@skatrak (Contributor) commented Apr 11, 2024

This patch updates the definition of omp.wsloop to enforce the restrictions of a wrapper operation. Because this operation is so widely used, the patch touches several areas (a minimal before/after sketch of the new representation follows this list):

  • Update the MLIR definition of omp.wsloop, as well as its parser/printer, builder and verifier.
  • Update verifiers for omp.ordered.region, omp.cancel and omp.cancellation_point to correctly check for a parent omp.wsloop.
  • Update MLIR to LLVM IR translation of omp.wsloop to keep working after the change in representation. Another patch should be created to reduce the current code duplication between omp.wsloop and omp.simd after introducing a common omp.loop_nest operation.
  • Update the pass lowering scf.parallel to OpenMP so that it produces the new expected representation.
  • Update Flang lowering to implement the omp.wsloop representation changes, including updates to lastprivate and reduction handling, so that no operations are added inside the wrapper and entry block arguments are attached to the right operation.
  • Fix unit tests broken due to the representation change.
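
For reference, a minimal before/after sketch of the representation change, adapted from the test updates included in this patch (bound, step and induction variable names are illustrative):

```mlir
// Before: omp.wsloop carries the loop bounds, step and body directly.
omp.wsloop for (%iv) : i32 = (%lb) to (%ub) inclusive step (%step) {
  // ... loop body ...
  omp.yield
}

// After: omp.wsloop is only a wrapper; bounds, step and body move into a
// nested omp.loop_nest, and the wrapper region ends with omp.terminator.
omp.wsloop {
  omp.loop_nest (%iv) : i32 = (%lb) to (%ub) inclusive step (%step) {
    // ... loop body ...
    omp.yield
  }
  omp.terminator
}
```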

This patch introduces an operation intended to hold the loop information associated
with the `omp.distribute`, `omp.simdloop`, `omp.taskloop` and `omp.wsloop`
operations. This is a stopgap solution to unblock work on transitioning these
operations to becoming wrappers, as discussed in
[this RFC](https://discourse.llvm.org/t/rfc-representing-combined-composite-constructs-in-the-openmp-dialect/76986).
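
As a sketch, `omp.loop_nest` carries the induction variables, bounds and steps of a rectangular (possibly collapsed) loop nest and is terminated by `omp.yield`. The two-dimensional form below follows the syntax used in the tests touched by this stack; value names are illustrative:

```mlir
omp.wsloop {
  // Collapsed two-dimensional loop nest: induction variables, lower bounds,
  // upper bounds and steps are listed positionally; the body runs once per
  // (%i, %j) iteration.
  omp.loop_nest (%i, %j) : i32 = (%lb0, %lb1) to (%ub0, %ub1) inclusive step (%st0, %st1) {
    // ... loop body ...
    omp.yield
  }
  omp.terminator
}
```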

Long-term, this operation will likely be replaced by `omp.canonical_loop`,
which is being designed to address missing support for loop transformations,
etc.

This patch defines a common interface to be shared by all OpenMP loop wrapper
operations. The main restrictions these operations must meet in order to be
considered a wrapper are:

- They contain a single region.
- Their region contains a single block.
- Their block only contains another loop wrapper or `omp.loop_nest` and a
terminator.

The new interface is attached to the `omp.parallel`, `omp.wsloop`,
`omp.simdloop`, `omp.distribute` and `omp.taskloop` operations. The wrapper
restrictions are not yet enforced for these operations, since enforcing them
now would break existing OpenMP loop-generating code; enforcement will instead
be introduced progressively in subsequent patches.
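
To illustrate these restrictions, a minimal sketch of a conforming wrapper, as the verifier would eventually require it (value names are illustrative):

```mlir
// A single region with a single block that contains only the wrapped loop
// and a terminator.
omp.wsloop {
  omp.loop_nest (%iv) : i32 = (%lb) to (%ub) inclusive step (%step) {
    // ... loop body ...
    omp.yield
  }
  omp.terminator
}
// A wrapper whose block held any other operation next to the loop (for
// example, an arith.constant) would no longer satisfy the interface.
```
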
This patch updates the definition of `omp.simdloop` to enforce the restrictions
of a wrapper operation. It has been renamed to `omp.simd`, to better reflect
the naming used in the spec. All uses of "simdloop" in function names have been
updated accordingly.

Some changes to Flang lowering and OpenMP to LLVM IR translation are
introduced to avoid compilation and test failures; the eventual long-term
solution might be different.
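
As a rough before/after sketch of the `omp.simdloop` to `omp.simd` change described above (clauses are omitted and details may differ slightly from the actual patch):

```mlir
// Before: omp.simdloop holds the loop directly.
omp.simdloop for (%iv) : i32 = (%lb) to (%ub) step (%step) {
  // ... loop body ...
  omp.yield
}

// After: omp.simd is a loop wrapper around omp.loop_nest.
omp.simd {
  omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
    // ... loop body ...
    omp.yield
  }
  omp.terminator
}
```
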
This patch updates the definition of `omp.wsloop` to enforce the restrictions
of a wrapper operation. Because this operation is so widely used, the patch
touches several areas:

- Update the MLIR definition of `omp.wsloop`, as well as its parser/printer,
builder and verifier.
- Update verifiers for `omp.ordered.region`, `omp.cancel` and
`omp.cancellation_point` to correctly check for a parent `omp.wsloop`.
- Update MLIR to LLVM IR translation of `omp.wsloop` to keep working after the
change in representation. Another patch should be created to reduce the current
code duplication between `omp.wsloop` and `omp.simd` after introducing a common
`omp.loop_nest` operation.
- Update the pass lowering `scf.parallel` to OpenMP so that it produces the
new expected representation.
- Update Flang lowering to implement the `omp.wsloop` representation changes,
including updates to `lastprivate` and `reduction` handling, so that no
operations are added inside the wrapper and entry block arguments are attached
to the right operation.
- Fix unit tests broken due to the representation change.
@llvmbot (Collaborator) commented Apr 11, 2024

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

Patch is 758.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88403.diff

110 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+27-21)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+43-74)
  • (modified) flang/test/Fir/convert-to-llvm-openmp-and-fir.fir (+61-47)
  • (modified) flang/test/Lower/OpenMP/FIR/copyin.f90 (+11-5)
  • (modified) flang/test/Lower/OpenMP/FIR/lastprivate-commonblock.f90 (+4-1)
  • (modified) flang/test/Lower/OpenMP/FIR/location.f90 (+10-7)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-lastprivate-clause-scalar.f90 (+36-12)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-private-clause-fixes.f90 (+26-23)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-private-clause.f90 (+60-54)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-wsloop-firstpriv.f90 (+10-2)
  • (modified) flang/test/Lower/OpenMP/FIR/parallel-wsloop.f90 (+74-54)
  • (modified) flang/test/Lower/OpenMP/FIR/stop-stmt-in-region.f90 (+21-18)
  • (modified) flang/test/Lower/OpenMP/FIR/target.f90 (+4-1)
  • (modified) flang/test/Lower/OpenMP/FIR/unstructured.f90 (+110-89)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-chunks.f90 (+28-19)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-collapse.f90 (+16-13)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-monotonic.f90 (+17-13)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-nonmonotonic.f90 (+17-14)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-ordered.f90 (+12-6)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-add-byref.f90 (+106-85)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-add.f90 (+106-85)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-iand-byref.f90 (+3-1)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-iand.f90 (+3-1)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ieor-byref.f90 (+3-1)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ieor.f90 (+3-1)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ior-byref.f90 (+3-1)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ior.f90 (+3-1)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-eqv-byref.f90 (+75-69)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-eqv.f90 (+75-69)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-neqv-byref.f90 (+75-69)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-neqv.f90 (+75-69)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-max-byref.f90 (+18-13)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-max.f90 (+18-13)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-min-byref.f90 (+18-14)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-min.f90 (+18-14)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-simd.f90 (+16-13)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop-variable.f90 (+93-79)
  • (modified) flang/test/Lower/OpenMP/FIR/wsloop.f90 (+36-30)
  • (modified) flang/test/Lower/OpenMP/Todo/omp-default-clause-inner-loop.f90 (+4-1)
  • (modified) flang/test/Lower/OpenMP/copyin.f90 (+17-11)
  • (modified) flang/test/Lower/OpenMP/default-clause-byref.f90 (+4-1)
  • (modified) flang/test/Lower/OpenMP/default-clause.f90 (+4-1)
  • (modified) flang/test/Lower/OpenMP/hlfir-wsloop.f90 (+7-5)
  • (modified) flang/test/Lower/OpenMP/lastprivate-commonblock.f90 (+31-28)
  • (modified) flang/test/Lower/OpenMP/lastprivate-iv.f90 (+48-42)
  • (modified) flang/test/Lower/OpenMP/location.f90 (+10-7)
  • (modified) flang/test/Lower/OpenMP/parallel-lastprivate-clause-scalar.f90 (+36-12)
  • (modified) flang/test/Lower/OpenMP/parallel-private-clause-fixes.f90 (+26-23)
  • (modified) flang/test/Lower/OpenMP/parallel-private-clause.f90 (+56-50)
  • (modified) flang/test/Lower/OpenMP/parallel-wsloop-firstpriv.f90 (+8-2)
  • (modified) flang/test/Lower/OpenMP/parallel-wsloop.f90 (+79-59)
  • (modified) flang/test/Lower/OpenMP/stop-stmt-in-region.f90 (+21-18)
  • (modified) flang/test/Lower/OpenMP/target.f90 (+4-1)
  • (modified) flang/test/Lower/OpenMP/unstructured.f90 (+110-89)
  • (modified) flang/test/Lower/OpenMP/wsloop-chunks.f90 (+28-19)
  • (modified) flang/test/Lower/OpenMP/wsloop-collapse.f90 (+16-13)
  • (modified) flang/test/Lower/OpenMP/wsloop-monotonic.f90 (+10-8)
  • (modified) flang/test/Lower/OpenMP/wsloop-nonmonotonic.f90 (+11-8)
  • (modified) flang/test/Lower/OpenMP/wsloop-ordered.f90 (+12-6)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-byref.f90 (+120-99)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir-byref.f90 (+10-8)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir.f90 (+10-8)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-add.f90 (+120-99)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array.f90 (+19-16)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array2.f90 (+27-24)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-iand-byref.f90 (+13-11)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-iand.f90 (+13-11)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-ieor-byref.f90 (+4-2)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-ieor.f90 (+4-2)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-ior-byref.f90 (+13-11)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-ior.f90 (+13-11)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-and-byref.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv-byref.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-neqv-byref.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-neqv.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-or-byref.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-or.f90 (+70-64)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-byref.f90 (+48-42)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir-byref.f90 (+14-12)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir.f90 (+14-12)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max.f90 (+48-42)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-min-byref.f90 (+49-43)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-min.f90 (+49-43)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-min2.f90 (+9-7)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-mul-byref.f90 (+113-99)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-mul.f90 (+113-99)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-multi.f90 (+26-23)
  • (modified) flang/test/Lower/OpenMP/wsloop-simd.f90 (+16-13)
  • (modified) flang/test/Lower/OpenMP/wsloop-unstructured.f90 (+21-18)
  • (modified) flang/test/Lower/OpenMP/wsloop-variable.f90 (+91-76)
  • (modified) flang/test/Lower/OpenMP/wsloop.f90 (+39-33)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+21-35)
  • (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+40-12)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+47-85)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+32-34)
  • (modified) mlir/test/CAPI/execution_engine.c (+5-2)
  • (modified) mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir (+36-28)
  • (modified) mlir/test/Conversion/SCFToOpenMP/reductions.mlir (+4)
  • (modified) mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir (+23-8)
  • (modified) mlir/test/Dialect/LLVMIR/legalize-for-export.mlir (+11-8)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+167-90)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+326-236)
  • (modified) mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir (+7-4)
  • (modified) mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir (+10-7)
  • (modified) mlir/test/Target/LLVMIR/omptarget-wsloop.mlir (+12-6)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+404-337)
  • (modified) mlir/test/Target/LLVMIR/openmp-nested.mlir (+18-12)
  • (modified) mlir/test/Target/LLVMIR/openmp-reduction.mlir (+63-50)
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index e114ab9f4548ab..645c351ac6c08c 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -133,8 +133,14 @@ void DataSharingProcessor::insertBarrier() {
 }
 
 void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
+  mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+  mlir::omp::LoopNestOp loopOp;
+  if (auto wrapper = mlir::dyn_cast<mlir::omp::LoopWrapperInterface>(op))
+    loopOp = wrapper.isWrapper()
+                 ? mlir::cast<mlir::omp::LoopNestOp>(wrapper.getWrappedLoop())
+                 : nullptr;
+
   bool cmpCreated = false;
-  mlir::OpBuilder::InsertPoint localInsPt = firOpBuilder.saveInsertionPoint();
   for (const omp::Clause &clause : clauses) {
     if (clause.id != llvm::omp::OMPC_lastprivate)
       continue;
@@ -213,18 +219,20 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
       // Update the original variable just before exiting the worksharing
       // loop. Conversion as follows:
       //
-      //                       omp.wsloop {
-      // omp.wsloop {            ...
-      //    ...                  store
-      //    store       ===>     %v = arith.addi %iv, %step
-      //    omp.yield            %cmp = %step < 0 ? %v < %ub : %v > %ub
-      // }                       fir.if %cmp {
-      //                           fir.store %v to %loopIV
-      //                           ^%lpv_update_blk:
-      //                         }
-      //                         omp.yield
-      //                       }
-      //
+      // omp.wsloop {             omp.wsloop {
+      //   omp.loop_nest {          omp.loop_nest {
+      //     ...                      ...
+      //     store          ===>      store
+      //     omp.yield                %v = arith.addi %iv, %step
+      //   }                          %cmp = %step < 0 ? %v < %ub : %v > %ub
+      //   omp.terminator             fir.if %cmp {
+      // }                              fir.store %v to %loopIV
+      //                                ^%lpv_update_blk:
+      //                              }
+      //                              omp.yield
+      //                            }
+      //                            omp.terminator
+      //                          }
 
       // Only generate the compare once in presence of multiple LastPrivate
       // clauses.
@@ -232,14 +240,13 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
         continue;
       cmpCreated = true;
 
-      mlir::Location loc = op->getLoc();
-      mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
+      mlir::Location loc = loopOp.getLoc();
+      mlir::Operation *lastOper = loopOp.getRegion().back().getTerminator();
       firOpBuilder.setInsertionPoint(lastOper);
 
-      mlir::Value iv = op->getRegion(0).front().getArguments()[0];
-      mlir::Value ub =
-          mlir::dyn_cast<mlir::omp::WsloopOp>(op).getUpperBound()[0];
-      mlir::Value step = mlir::dyn_cast<mlir::omp::WsloopOp>(op).getStep()[0];
+      mlir::Value iv = loopOp.getIVs()[0];
+      mlir::Value ub = loopOp.getUpperBound()[0];
+      mlir::Value step = loopOp.getStep()[0];
 
       // v = iv + step
       // cmp = step < 0 ? v < ub : v > ub
@@ -258,7 +265,7 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
       auto ifOp = firOpBuilder.create<fir::IfOp>(loc, cmpOp, /*else*/ false);
       firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
       assert(loopIV && "loopIV was not set");
-      firOpBuilder.create<fir::StoreOp>(op->getLoc(), v, loopIV);
+      firOpBuilder.create<fir::StoreOp>(loopOp.getLoc(), v, loopIV);
       lastPrivIP = firOpBuilder.saveInsertionPoint();
     } else {
       TODO(converter.getCurrentLocation(),
@@ -266,7 +273,6 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
            "simd/worksharing-loop");
     }
   }
-  firOpBuilder.restoreInsertionPoint(localInsPt);
 }
 
 void DataSharingProcessor::collectSymbols(
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 1800fcb19dcd2e..b21351382b6bdf 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1626,7 +1626,9 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
 static llvm::SmallVector<const Fortran::semantics::Symbol *>
 genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
             mlir::Location &loc,
-            llvm::ArrayRef<const Fortran::semantics::Symbol *> args) {
+            llvm::ArrayRef<const Fortran::semantics::Symbol *> args,
+            llvm::ArrayRef<const Fortran::semantics::Symbol *> wrapperSyms = {},
+            llvm::ArrayRef<mlir::BlockArgument> wrapperArgs = {}) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   auto &region = op->getRegion(0);
 
@@ -1637,6 +1639,14 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
   llvm::SmallVector<mlir::Location> locs(args.size(), loc);
   firOpBuilder.createBlock(&region, {}, tiv, locs);
+
+  // Bind the entry block arguments of parent wrappers to the corresponding
+  // symbols. Do it here so that any hlfir.declare operations created as a
+  // result are inserted inside of the omp.loop_nest rather than the wrapper
+  // operations.
+  for (auto [arg, prv] : llvm::zip_equal(wrapperSyms, wrapperArgs))
+    converter.bindSymbol(*arg, prv);
+
   // The argument is not currently in memory, so make a temporary for the
   // argument, and store it there, then bind that location to the argument.
   mlir::Operation *storeOp = nullptr;
@@ -1650,58 +1660,6 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
   return llvm::SmallVector<const Fortran::semantics::Symbol *>(args);
 }
 
-static llvm::SmallVector<const Fortran::semantics::Symbol *>
-genLoopAndReductionVars(
-    mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
-    mlir::Location &loc,
-    llvm::ArrayRef<const Fortran::semantics::Symbol *> loopArgs,
-    llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs,
-    llvm::ArrayRef<mlir::Type> reductionTypes) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  llvm::SmallVector<mlir::Type> blockArgTypes;
-  llvm::SmallVector<mlir::Location> blockArgLocs;
-  blockArgTypes.reserve(loopArgs.size() + reductionArgs.size());
-  blockArgLocs.reserve(blockArgTypes.size());
-  mlir::Block *entryBlock;
-
-  if (loopArgs.size()) {
-    std::size_t loopVarTypeSize = 0;
-    for (const Fortran::semantics::Symbol *arg : loopArgs)
-      loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
-    mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
-    std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(),
-                loopVarType);
-    std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc);
-  }
-  if (reductionArgs.size()) {
-    llvm::copy(reductionTypes, std::back_inserter(blockArgTypes));
-    std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc);
-  }
-  entryBlock = firOpBuilder.createBlock(&op->getRegion(0), {}, blockArgTypes,
-                                        blockArgLocs);
-  // The argument is not currently in memory, so make a temporary for the
-  // argument, and store it there, then bind that location to the argument.
-  if (loopArgs.size()) {
-    mlir::Operation *storeOp = nullptr;
-    for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) {
-      mlir::Value indexVal =
-          fir::getBase(op->getRegion(0).front().getArgument(argIndex));
-      storeOp =
-          createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
-    }
-    firOpBuilder.setInsertionPointAfter(storeOp);
-  }
-  // Bind the reduction arguments to their block arguments
-  for (auto [arg, prv] : llvm::zip_equal(
-           reductionArgs,
-           llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) {
-    converter.bindSymbol(*arg, prv);
-  }
-
-  return llvm::SmallVector<const Fortran::semantics::Symbol *>(loopArgs);
-}
-
 static void createSimd(Fortran::lower::AbstractConverter &converter,
                        Fortran::semantics::SemanticsContext &semaCtx,
                        Fortran::lower::pft::Evaluation &eval,
@@ -1797,28 +1755,26 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
   if (ReductionProcessor::doReductionByRef(reductionVars))
     byrefOperand = firOpBuilder.getUnitAttr();
 
-  auto wsLoopOp = firOpBuilder.create<mlir::omp::WsloopOp>(
-      loc, lowerBound, upperBound, step, linearVars, linearStepVars,
-      reductionVars,
+  auto wsloopOp = firOpBuilder.create<mlir::omp::WsloopOp>(
+      loc, linearVars, linearStepVars, reductionVars,
       reductionDeclSymbols.empty()
           ? nullptr
           : mlir::ArrayAttr::get(firOpBuilder.getContext(),
                                  reductionDeclSymbols),
       scheduleValClauseOperand, scheduleChunkClauseOperand,
-      /*schedule_modifiers=*/nullptr,
-      /*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
-      orderedClauseOperand, orderClauseOperand,
-      /*inclusive=*/firOpBuilder.getUnitAttr());
+      /*schedule_modifiers=*/nullptr, /*simd_modifier=*/nullptr,
+      nowaitClauseOperand, byrefOperand, orderedClauseOperand,
+      orderClauseOperand);
 
   // Handle attribute based clauses.
   if (cp.processOrdered(orderedClauseOperand))
-    wsLoopOp.setOrderedValAttr(orderedClauseOperand);
+    wsloopOp.setOrderedValAttr(orderedClauseOperand);
 
   if (cp.processSchedule(scheduleValClauseOperand, scheduleModClauseOperand,
                          scheduleSimdClauseOperand)) {
-    wsLoopOp.setScheduleValAttr(scheduleValClauseOperand);
-    wsLoopOp.setScheduleModifierAttr(scheduleModClauseOperand);
-    wsLoopOp.setSimdModifierAttr(scheduleSimdClauseOperand);
+    wsloopOp.setScheduleValAttr(scheduleValClauseOperand);
+    wsloopOp.setScheduleModifierAttr(scheduleModClauseOperand);
+    wsloopOp.setSimdModifierAttr(scheduleSimdClauseOperand);
   }
   // In FORTRAN `nowait` clause occur at the end of `omp do` directive.
   // i.e
@@ -1828,23 +1784,36 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
   if (endClauseList) {
     if (ClauseProcessor(converter, semaCtx, *endClauseList)
             .processNowait(nowaitClauseOperand))
-      wsLoopOp.setNowaitAttr(nowaitClauseOperand);
+      wsloopOp.setNowaitAttr(nowaitClauseOperand);
   }
 
+  // Create omp.wsloop wrapper and populate entry block arguments with reduction
+  // variables.
+  llvm::SmallVector<mlir::Location> reductionLocs(reductionSymbols.size(), loc);
+  mlir::Block *wsloopEntryBlock = firOpBuilder.createBlock(
+      &wsloopOp.getRegion(), {}, reductionTypes, reductionLocs);
+  firOpBuilder.setInsertionPoint(
+      Fortran::lower::genOpenMPTerminator(firOpBuilder, wsloopOp, loc));
+
+  // Create nested omp.loop_nest and fill body with loop contents.
+  auto loopOp = firOpBuilder.create<mlir::omp::LoopNestOp>(
+      loc, lowerBound, upperBound, step,
+      /*inclusive=*/firOpBuilder.getUnitAttr());
+
   auto *nestedEval = getCollapsedLoopEval(
       eval, Fortran::lower::getCollapseValue(beginClauseList));
 
   auto ivCallback = [&](mlir::Operation *op) {
-    return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols,
-                                   reductionTypes);
+    return genLoopVars(op, converter, loc, iv, reductionSymbols,
+                       wsloopEntryBlock->getArguments());
   };
 
   createBodyOfOp<mlir::omp::WsloopOp>(
-      *wsLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
-                     .setClauses(&beginClauseList)
-                     .setDataSharingProcessor(&dsp)
-                     .setReductions(&reductionSymbols, &reductionTypes)
-                     .setGenRegionEntryCb(ivCallback));
+      *loopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
+                   .setClauses(&beginClauseList)
+                   .setDataSharingProcessor(&dsp)
+                   .setReductions(&reductionSymbols, &reductionTypes)
+                   .setGenRegionEntryCb(ivCallback));
 }
 
 static void createSimdWsloop(
@@ -2430,8 +2399,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
 mlir::Operation *Fortran::lower::genOpenMPTerminator(fir::FirOpBuilder &builder,
                                                      mlir::Operation *op,
                                                      mlir::Location loc) {
-  if (mlir::isa<mlir::omp::WsloopOp, mlir::omp::DeclareReductionOp,
-                mlir::omp::AtomicUpdateOp, mlir::omp::LoopNestOp>(op))
+  if (mlir::isa<mlir::omp::AtomicUpdateOp, mlir::omp::DeclareReductionOp,
+                mlir::omp::LoopNestOp>(op))
     return builder.create<mlir::omp::YieldOp>(loc);
   return builder.create<mlir::omp::TerminatorOp>(loc);
 }
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index fa7979e8875afc..c7c609bbb35623 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -7,15 +7,17 @@ func.func @_QPsb1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<!
   omp.parallel  {
     %1 = fir.alloca i32 {adapt.valuebyref, pinned}
     %2 = fir.load %arg0 : !fir.ref<i32>
-    omp.wsloop nowait
-    for (%arg2) : i32 = (%c1_i32) to (%2) inclusive step (%c1_i32)  {
-      fir.store %arg2 to %1 : !fir.ref<i32>
-      %3 = fir.load %1 : !fir.ref<i32>
-      %4 = fir.convert %3 : (i32) -> i64
-      %5 = arith.subi %4, %c1_i64 : i64
-      %6 = fir.coordinate_of %arg1, %5 : (!fir.ref<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
-      fir.store %3 to %6 : !fir.ref<i32>
-      omp.yield
+    omp.wsloop nowait {
+      omp.loop_nest (%arg2) : i32 = (%c1_i32) to (%2) inclusive step (%c1_i32)  {
+        fir.store %arg2 to %1 : !fir.ref<i32>
+        %3 = fir.load %1 : !fir.ref<i32>
+        %4 = fir.convert %3 : (i32) -> i64
+        %5 = arith.subi %4, %c1_i64 : i64
+        %6 = fir.coordinate_of %arg1, %5 : (!fir.ref<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
+        fir.store %3 to %6 : !fir.ref<i32>
+        omp.yield
+      }
+      omp.terminator
     }
     omp.terminator
   }
@@ -31,7 +33,7 @@ func.func @_QPsb1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<!
 // CHECK:      %[[I_VAR:.*]] = llvm.alloca %[[ONE_3]] x i32 {pinned} : (i64) -> !llvm.ptr
 // CHECK:      %[[N:.*]] = llvm.load %[[N_REF]] : !llvm.ptr -> i32
 // CHECK: omp.wsloop nowait
-// CHECK-SAME: for (%[[I:.*]]) : i32 = (%[[ONE_2]]) to (%[[N]]) inclusive step (%[[ONE_2]]) {
+// CHECK-NEXT: omp.loop_nest (%[[I:.*]]) : i32 = (%[[ONE_2]]) to (%[[N]]) inclusive step (%[[ONE_2]]) {
 // CHECK:   llvm.store %[[I]], %[[I_VAR]] : i32, !llvm.ptr
 // CHECK:   %[[I1:.*]] = llvm.load %[[I_VAR]] : !llvm.ptr -> i32
 // CHECK:   %[[I1_EXT:.*]] = llvm.sext %[[I1]] : i32 to i64
@@ -42,6 +44,8 @@ func.func @_QPsb1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<!
 // CHECK: }
 // CHECK: omp.terminator
 // CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
 // CHECK: llvm.return
 // CHECK: }
 
@@ -79,13 +83,16 @@ func.func @_QPsb(%arr: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr"}) {
   omp.parallel   {
     %c1 = arith.constant 1 : i32
     %c50 = arith.constant 50 : i32
-    omp.wsloop   for  (%indx) : i32 = (%c1) to (%c50) inclusive step (%c1) {
-      %1 = fir.convert %indx : (i32) -> i64
-      %c1_i64 = arith.constant 1 : i64
-      %2 = arith.subi %1, %c1_i64 : i64
-      %3 = fir.coordinate_of %arr, %2 : (!fir.box<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
-      fir.store %indx to %3 : !fir.ref<i32>
-      omp.yield
+    omp.wsloop {
+      omp.loop_nest (%indx) : i32 = (%c1) to (%c50) inclusive step (%c1) {
+        %1 = fir.convert %indx : (i32) -> i64
+        %c1_i64 = arith.constant 1 : i64
+        %2 = arith.subi %1, %c1_i64 : i64
+        %3 = fir.coordinate_of %arr, %2 : (!fir.box<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
+        fir.store %indx to %3 : !fir.ref<i32>
+        omp.yield
+      }
+      omp.terminator
     }
     omp.terminator
   }
@@ -98,9 +105,11 @@ func.func @_QPsb(%arr: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr"}) {
 // CHECK:    omp.parallel   {
 // CHECK:      %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
 // CHECK:      %[[C50:.*]] = llvm.mlir.constant(50 : i32) : i32
-// CHECK:      omp.wsloop   for  (%[[INDX:.*]]) : i32 = (%[[C1]]) to (%[[C50]]) inclusive step (%[[C1]]) {
-// CHECK:        llvm.store %[[INDX]], %{{.*}} : i32, !llvm.ptr
-// CHECK:        omp.yield
+// CHECK:      omp.wsloop {
+// CHECK-NEXT:   omp.loop_nest (%[[INDX:.*]]) : i32 = (%[[C1]]) to (%[[C50]]) inclusive step (%[[C1]]) {
+// CHECK:          llvm.store %[[INDX]], %{{.*}} : i32, !llvm.ptr
+// CHECK:          omp.yield
+// CHECK:        omp.terminator
 // CHECK:      omp.terminator
 // CHECK:    llvm.return
 
@@ -708,18 +717,20 @@ func.func @_QPsb() {
 // CHECK-SAME: %[[ARRAY_REF:.*]]: !llvm.ptr
 // CHECK:    %[[RED_ACCUMULATOR:.*]] = llvm.alloca %2 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
 // CHECK:    omp.parallel   {
-// CHECK:      omp.wsloop   reduction(@[[EQV_REDUCTION]] %[[RED_ACCUMULATOR]] -> %[[PRV:.+]] : !llvm.ptr) for
-// CHECK:        %[[ARRAY_ELEM_REF:.*]] = llvm.getelementptr %[[ARRAY_REF]][0, %{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr
-// CHECK:        %[[ARRAY_ELEM:.*]] = llvm.load %[[ARRAY_ELEM_REF]] : !llvm.ptr -> i32
-// CHECK:        %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
-// CHECK:        %[[ZERO_1:.*]] = llvm.mlir.constant(0 : i64) : i32
-// CHECK:        %[[ARGVAL_1:.*]] = llvm.icmp "ne" %[[LPRV]], %[[ZERO_1]] : i32
-// CHECK:        %[[ZERO_2:.*]] = llvm.mlir.constant(0 : i64) : i32
-// CHECK:        %[[ARGVAL_2:.*]] = llvm.icmp "ne" %[[ARRAY_ELEM]], %[[ZERO_2]] : i32
-// CHECK:        %[[RES:.*]] = llvm.icmp "eq" %[[ARGVAL_2]], %[[ARGVAL_1]] : i1
-// CHECK:        %[[RES_EXT:.*]] = llvm.zext %[[RES]] : i1 to i32
-// CHECK:        llvm.store %[[RES_EXT]], %[[PRV]] : i32, !llvm.ptr
-// CHECK:        omp.yield
+// CHECK:      omp.wsloop reduction(@[[EQV_REDUCTION]] %[[RED_ACCUMULATOR]] -> %[[PRV:.+]] : !llvm.ptr) {
+// CHECK-NEXT:   omp.loop_nest
+// CHECK:          %[[ARRAY_ELEM_REF:.*]] = llvm.getelementptr %[[ARRAY_REF]][0, %{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr
+// CHECK:          %[[ARRAY_ELEM:.*]] = llvm.load %[[ARRAY_ELEM_REF]] : !llvm.ptr -> i32
+// CHECK:          %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
+// CHECK:          %[[ZERO_1:.*]] = llvm.mlir.constant(0 : i64) : i32
+// CHECK:          %[[ARGVAL_1:.*]] = llvm.icmp "ne" %[[LPRV]], %[[ZERO_1]] : i32
+// CHECK:          %[[ZERO_2:.*]] = llvm.mlir.constant(0 : i64) : i32
+// CHECK:          %[[ARGVAL_2:.*]] = llvm.icmp "ne" %[[ARRAY_ELEM]], %[[ZERO_2]] : i32
+// CHECK:          %[[RES:.*]] = llvm.icmp "eq" %[[ARGVAL_2]], %[[ARGVAL_1]] : i1
+// CHECK:          %[[RES_EXT:.*]] = llvm.zext %[[RES]] : i1 to i32
+// CHECK:          llvm.store %[[RES_EXT]], %[[PRV]] : i32, !llvm.ptr
+// CHECK:          omp.yield
+// CHECK:        omp.terminator
 // CHECK:      omp.terminator
 // CHECK:    llvm.return
 
@@ -747,21 +758,24 @@ func.func @_QPsimple_reduction(%arg0: !fir.ref<!fir.array<100x!fir.logical<4>>>
     %c1_i32 = arith.constant 1 : i32
     %c100_i32 = arith.constant 100 : i32
     %c1_i32_0 = arith.constant 1 : i32
-    omp.wsloop   reduction(@eqv_reduction %1 -> %prv : !fir.ref<!fir.logical<4>>) for  (%arg1) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32_0) {
-      fir.store %arg1 to %3 : !fir.ref<i32>
-      %4 = fir.load %3 : !fir.ref<i32>
-      %5 = fir.convert %4 : (i32) -> i64
-      %c1_i64 = arith.constant 1 : i64
-      %6 = arith.subi %5, %c1_i64 : i64
-      %7 = fir.coordinate_of %arg0, %6 : (!fir.ref<!fir.array<100x!fir.logical<4>>>, i64) -> !fir.ref<!fir.logical<4>>
-      %8 = fir.load %7 : !fir.ref<!fir.logical<4>>
-      %lprv = fir.load %prv : !fir.ref<!fir.logical<4>>
-      %lprv1 = fir.convert %lprv : (!fir.logical<...
[truncated]

@skatrak (Contributor, Author) commented Apr 16, 2024

I'm planning to split this into a PR stack after landing #87365, since it's too large to review. However, only the last commit of the stack will compile and pass tests, so they all would have to land simultaneously. I'm open to suggestions on how to best achieve this.

Base automatically changed from users/skatrak/spr/loop-nest-03-simd-mlir to main on April 17, 2024 at 10:28
@skatrak closed this on Apr 18, 2024
@skatrak deleted the users/skatrak/spr/loop-nest-04-wsloop-mlir branch on April 18, 2024 at 10:57
@Meinersbur (Member)

Thanks a lot for the effort to split this into smaller patches. Very appreciated.
