Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flang][OpenMP][Lower] NFC: Move clause processing helpers into the ClauseProcessor #85258

Merged
merged 1 commit into from
Mar 19, 2024

Conversation

skatrak
Copy link
Contributor

@skatrak skatrak commented Mar 14, 2024

This patch moves some code in PFT to MLIR OpenMP lowering to the ClauseProcessor class. This is so that some behavior that is related to certain clauses stays within the ClauseProcessor and it's not the caller the one responsible for always doing this when the clause is present.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Mar 14, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 14, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Sergio Afonso (skatrak)

Changes

This patch moves some code in PFT to MLIR OpenMP lowering to the ClauseProcessor class. This is so that some behavior that is related to certain clauses stays within the ClauseProcessor and it's not the caller the one responsible for always doing this when the clause is present.


Full diff: https://github.com/llvm/llvm-project/pull/85258.diff

5 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+28-3)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+7-8)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+9-63)
  • (modified) flang/lib/Lower/OpenMP/Utils.cpp (+19)
  • (modified) flang/lib/Lower/OpenMP/Utils.h (+3)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index a41f8312a28c9e..4769a1d4138a18 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -230,6 +230,25 @@ addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
   }
 }
 
+static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
+                              mlir::Location loc,
+                              llvm::SmallVectorImpl<mlir::Value> &lowerBound,
+                              llvm::SmallVectorImpl<mlir::Value> &upperBound,
+                              llvm::SmallVectorImpl<mlir::Value> &step,
+                              std::size_t loopVarTypeSize) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  // The types of lower bound, upper bound, and step are converted into the
+  // type of the loop variable if necessary.
+  mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
+  for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
+    lowerBound[it] =
+        firOpBuilder.createConvert(loc, loopVarType, lowerBound[it]);
+    upperBound[it] =
+        firOpBuilder.createConvert(loc, loopVarType, upperBound[it]);
+    step[it] = firOpBuilder.createConvert(loc, loopVarType, step[it]);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // ClauseProcessor unique clauses
 //===----------------------------------------------------------------------===//
@@ -239,8 +258,7 @@ bool ClauseProcessor::processCollapse(
     llvm::SmallVectorImpl<mlir::Value> &lowerBound,
     llvm::SmallVectorImpl<mlir::Value> &upperBound,
     llvm::SmallVectorImpl<mlir::Value> &step,
-    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
-    std::size_t &loopVarTypeSize) const {
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const {
   bool found = false;
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
@@ -259,7 +277,7 @@ bool ClauseProcessor::processCollapse(
     found = true;
   }
 
-  loopVarTypeSize = 0;
+  std::size_t loopVarTypeSize = 0;
   do {
     Fortran::lower::pft::Evaluation *doLoop =
         &doConstructEval->getFirstNestedEvaluation();
@@ -290,6 +308,9 @@ bool ClauseProcessor::processCollapse(
         &*std::next(doConstructEval->getNestedEvaluations().begin());
   } while (collapseValue > 0);
 
+  convertLoopBounds(converter, currentLocation, lowerBound, upperBound, step,
+                    loopVarTypeSize);
+
   return found;
 }
 
@@ -961,6 +982,7 @@ bool ClauseProcessor::processMap(
 bool ClauseProcessor::processReduction(
     mlir::Location currentLocation,
     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+    llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
     const {
@@ -971,6 +993,9 @@ bool ClauseProcessor::processReduction(
         rp.addReductionDecl(currentLocation, converter, reductionClause->v,
                             reductionVars, reductionDeclSymbols,
                             reductionSymbols);
+        reductionTypes.reserve(reductionVars.size());
+        llvm::transform(reductionVars, std::back_inserter(reductionTypes),
+                        [](mlir::Value v) { return v.getType(); });
       });
 }
 
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 11aff0be25053d..dda8b4ad21c4f4 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -54,14 +54,12 @@ class ClauseProcessor {
       : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
 
   // 'Unique' clauses: They can appear at most once in the clause list.
-  bool
-  processCollapse(mlir::Location currentLocation,
-                  Fortran::lower::pft::Evaluation &eval,
-                  llvm::SmallVectorImpl<mlir::Value> &lowerBound,
-                  llvm::SmallVectorImpl<mlir::Value> &upperBound,
-                  llvm::SmallVectorImpl<mlir::Value> &step,
-                  llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
-                  std::size_t &loopVarTypeSize) const;
+  bool processCollapse(
+      mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
+      llvm::SmallVectorImpl<mlir::Value> &lowerBound,
+      llvm::SmallVectorImpl<mlir::Value> &upperBound,
+      llvm::SmallVectorImpl<mlir::Value> &step,
+      llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const;
   bool processDefault() const;
   bool processDevice(Fortran::lower::StatementContext &stmtCtx,
                      mlir::Value &result) const;
@@ -125,6 +123,7 @@ class ClauseProcessor {
   bool
   processReduction(mlir::Location currentLocation,
                    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+                   llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
                    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
                    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
                        *reductionSymbols = nullptr) const;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 25bb4d9cff5d16..3d9a763a1390f2 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -214,24 +214,6 @@ static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter,
   firOpBuilder.restoreInsertionPoint(insPt);
 }
 
-static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
-                                 std::size_t loopVarTypeSize) {
-  // OpenMP runtime requires 32-bit or 64-bit loop variables.
-  loopVarTypeSize = loopVarTypeSize * 8;
-  if (loopVarTypeSize < 32) {
-    loopVarTypeSize = 32;
-  } else if (loopVarTypeSize > 64) {
-    loopVarTypeSize = 64;
-    mlir::emitWarning(converter.getCurrentLocation(),
-                      "OpenMP loop iteration variable cannot have more than 64 "
-                      "bits size and will be narrowed into 64 bits.");
-  }
-  assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) &&
-         "OpenMP loop iteration variable size must be transformed into 32-bit "
-         "or 64-bit");
-  return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
-}
-
 static mlir::Operation *
 createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
                               mlir::Location loc, mlir::Value indexVal,
@@ -570,6 +552,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
   mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
   llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
       reductionVars;
+  llvm::SmallVector<mlir::Type> reductionTypes;
   llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
   llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
 
@@ -581,13 +564,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
   cp.processDefault();
   cp.processAllocate(allocatorOperands, allocateOperands);
   if (!outerCombined)
-    cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols,
-                        &reductionSymbols);
-
-  llvm::SmallVector<mlir::Type> reductionTypes;
-  reductionTypes.reserve(reductionVars.size());
-  llvm::transform(reductionVars, std::back_inserter(reductionTypes),
-                  [](mlir::Value v) { return v.getType(); });
+    cp.processReduction(currentLocation, reductionVars, reductionTypes,
+                        reductionDeclSymbols, &reductionSymbols);
 
   auto reductionCallback = [&](mlir::Operation *op) {
     llvm::SmallVector<mlir::Location> locs(reductionVars.size(),
@@ -1476,25 +1454,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
       standaloneConstruct.u);
 }
 
-static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
-                              mlir::Location loc,
-                              llvm::SmallVectorImpl<mlir::Value> &lowerBound,
-                              llvm::SmallVectorImpl<mlir::Value> &upperBound,
-                              llvm::SmallVectorImpl<mlir::Value> &step,
-                              std::size_t loopVarTypeSize) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  // The types of lower bound, upper bound, and step are converted into the
-  // type of the loop variable if necessary.
-  mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
-  for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
-    lowerBound[it] =
-        firOpBuilder.createConvert(loc, loopVarType, lowerBound[it]);
-    upperBound[it] =
-        firOpBuilder.createConvert(loc, loopVarType, upperBound[it]);
-    step[it] = firOpBuilder.createConvert(loc, loopVarType, step[it]);
-  }
-}
-
 static llvm::SmallVector<const Fortran::semantics::Symbol *>
 genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
             mlir::Location &loc,
@@ -1590,16 +1549,15 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
   llvm::SmallVector<mlir::Value> alignedVars, nontemporalVars;
   llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
+  llvm::SmallVector<mlir::Type> reductionTypes;
   llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
   mlir::omp::ClauseOrderKindAttr orderClauseOperand;
   mlir::IntegerAttr simdlenClauseOperand, safelenClauseOperand;
-  std::size_t loopVarTypeSize;
 
   ClauseProcessor cp(converter, semaCtx, loopOpClauseList);
-  cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv,
-                     loopVarTypeSize);
+  cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
   cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
-  cp.processReduction(loc, reductionVars, reductionDeclSymbols);
+  cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols);
   cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
                ifClauseOperand);
   cp.processSimdlen(simdlenClauseOperand);
@@ -1610,9 +1568,6 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
                  Fortran::parser::OmpClause::Nontemporal,
                  Fortran::parser::OmpClause::Order>(loc, ompDirective);
 
-  convertLoopBounds(converter, loc, lowerBound, upperBound, step,
-                    loopVarTypeSize);
-
   mlir::TypeRange resultType;
   auto simdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>(
       loc, resultType, lowerBound, upperBound, step, alignedVars,
@@ -1650,6 +1605,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
   llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
   llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
+  llvm::SmallVector<mlir::Type> reductionTypes;
   llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
   llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
   mlir::omp::ClauseOrderKindAttr orderClauseOperand;
@@ -1657,20 +1613,15 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
   mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
   mlir::IntegerAttr orderedClauseOperand;
   mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
-  std::size_t loopVarTypeSize;
 
   ClauseProcessor cp(converter, semaCtx, beginClauseList);
-  cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv,
-                     loopVarTypeSize);
+  cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
   cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
-  cp.processReduction(loc, reductionVars, reductionDeclSymbols,
+  cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols,
                       &reductionSymbols);
   cp.processTODO<Fortran::parser::OmpClause::Linear,
                  Fortran::parser::OmpClause::Order>(loc, ompDirective);
 
-  convertLoopBounds(converter, loc, lowerBound, upperBound, step,
-                    loopVarTypeSize);
-
   if (ReductionProcessor::doReductionByRef(reductionVars))
     byrefOperand = firOpBuilder.getUnitAttr();
 
@@ -1711,11 +1662,6 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
   auto *nestedEval = getCollapsedLoopEval(
       eval, Fortran::lower::getCollapseValue(beginClauseList));
 
-  llvm::SmallVector<mlir::Type> reductionTypes;
-  reductionTypes.reserve(reductionVars.size());
-  llvm::transform(reductionVars, std::back_inserter(reductionTypes),
-                  [](mlir::Value v) { return v.getType(); });
-
   auto ivCallback = [&](mlir::Operation *op) {
     return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols,
                                    reductionTypes);
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 49517f62895dfe..84e17d3fa61cd2 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -14,6 +14,7 @@
 
 #include <flang/Lower/AbstractConverter.h>
 #include <flang/Lower/ConvertType.h>
+#include <flang/Optimizer/Builder/FIRBuilder.h>
 #include <flang/Parser/parse-tree.h>
 #include <flang/Parser/tools.h>
 #include <flang/Semantics/tools.h>
@@ -55,6 +56,24 @@ void genObjectList(const Fortran::parser::OmpObjectList &objectList,
   }
 }
 
+mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
+                          std::size_t loopVarTypeSize) {
+  // OpenMP runtime requires 32-bit or 64-bit loop variables.
+  loopVarTypeSize = loopVarTypeSize * 8;
+  if (loopVarTypeSize < 32) {
+    loopVarTypeSize = 32;
+  } else if (loopVarTypeSize > 64) {
+    loopVarTypeSize = 64;
+    mlir::emitWarning(converter.getCurrentLocation(),
+                      "OpenMP loop iteration variable cannot have more than 64 "
+                      "bits size and will be narrowed into 64 bits.");
+  }
+  assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) &&
+         "OpenMP loop iteration variable size must be transformed into 32-bit "
+         "or 64-bit");
+  return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
+}
+
 void gatherFuncAndVarSyms(
     const Fortran::parser::OmpObjectList &objList,
     mlir::omp::DeclareTargetCaptureClause clause,
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 76a15e8bcaab9b..5f2aaafcae55d4 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -50,6 +50,9 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
                 mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
                 bool isVal = false);
 
+mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
+                          std::size_t loopVarTypeSize);
+
 void gatherFuncAndVarSyms(
     const Fortran::parser::OmpObjectList &objList,
     mlir::omp::DeclareTargetCaptureClause clause,

Copy link
Contributor

@kparzysz kparzysz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@skatrak skatrak force-pushed the nfc-clauseprocessor-helpers branch from f264104 to 7620ee6 Compare March 15, 2024 14:20
…lauseProcessor

This patch moves some code in PFT to MLIR OpenMP lowering to the
`ClauseProcessor` class. This is so that some behavior that is related to
certain clauses stays within the `ClauseProcessor` and it's not the caller the
one responsible for always doing this when the clause is present.
@skatrak skatrak force-pushed the nfc-clauseprocessor-helpers branch from 7620ee6 to 74d7776 Compare March 19, 2024 11:49
@skatrak skatrak merged commit 2f2f16f into llvm:main Mar 19, 2024
3 of 4 checks passed
@skatrak skatrak deleted the nfc-clauseprocessor-helpers branch March 19, 2024 11:49
skatrak added a commit that referenced this pull request Mar 19, 2024
…to the ClauseProcessor (#85258)"

Reverting due to failing gfortran test.

This reverts commit 2f2f16f.
skatrak added a commit to skatrak/llvm-project that referenced this pull request Mar 19, 2024
…nto the ClauseProcessor (llvm#85258)"

This patch contains slight modifications to the reverted PR llvm#85258 to avoid
issues with constructs containing multiple reduction clauses, uncovered by
a test on the gfortran testsuite.

This reverts commit 9f80444.
skatrak added a commit to skatrak/llvm-project that referenced this pull request Mar 20, 2024
…nto the ClauseProcessor (llvm#85258)"

This patch contains slight modifications to the reverted PR llvm#85258 to avoid
issues with constructs containing multiple reduction clauses, uncovered by
a test on the gfortran testsuite.

This reverts commit 9f80444.
skatrak added a commit that referenced this pull request Mar 21, 2024
…nto the ClauseProcessor (#85258)" (#85807)

This patch contains slight modifications to the reverted PR #85258 to
avoid issues with constructs containing multiple reduction clauses,
uncovered by a test on the gfortran testsuite.

This reverts commit 9f80444.
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
…lauseProcessor (llvm#85258)

This patch moves some code in PFT to MLIR OpenMP lowering to the
`ClauseProcessor` class. This is so that some behavior that is related
to certain clauses stays within the `ClauseProcessor` and it's not the
caller the one responsible for always doing this when the clause is
present.
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
…to the ClauseProcessor (llvm#85258)"

Reverting due to failing gfortran test.

This reverts commit 2f2f16f.
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
…nto the ClauseProcessor (llvm#85258)" (llvm#85807)

This patch contains slight modifications to the reverted PR llvm#85258 to
avoid issues with constructs containing multiple reduction clauses,
uncovered by a test on the gfortran testsuite.

This reverts commit 9f80444.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants