Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][OpenMP] Make several function local to OpenMP.cpp, NFC #86726

Merged
merged 2 commits into from
Mar 28, 2024

Conversation

kparzysz
Copy link
Contributor

There were several functions, mostly reduction-related, that were only called from OpenMP.cpp. Remove them from OpenMP.h, and make them local in OpenMP.cpp:

  • genOpenMPReduction
  • findReductionChain
  • getConvertFromReductionOp
  • updateReduction
  • removeStoreOp

Also, move the function bodies out of the "public" section.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Mar 26, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 26, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Krzysztof Parzyszek (kparzysz)

Changes

There were several functions, mostly reduction-related, that were only called from OpenMP.cpp. Remove them from OpenMP.h, and make them local in OpenMP.cpp:

  • genOpenMPReduction
  • findReductionChain
  • getConvertFromReductionOp
  • updateReduction
  • removeStoreOp

Also, move the function bodies out of the "public" section.


Patch is 20.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86726.diff

2 Files Affected:

  • (modified) flang/include/flang/Lower/OpenMP.h (-12)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+207-210)
diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
index 3b22a652d1fc1e..6e150ef4e8e82f 100644
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -19,7 +19,6 @@
 #include <utility>
 
 namespace mlir {
-class Value;
 class Operation;
 class Location;
 namespace omp {
@@ -30,7 +29,6 @@ enum class DeclareTargetCaptureClause : uint32_t;
 
 namespace fir {
 class FirOpBuilder;
-class ConvertOp;
 } // namespace fir
 
 namespace Fortran {
@@ -84,16 +82,6 @@ void genOpenMPSymbolProperties(AbstractConverter &converter,
 int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
 void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
 void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
-void genOpenMPReduction(AbstractConverter &,
-                        Fortran::semantics::SemanticsContext &,
-                        const Fortran::parser::OmpClauseList &clauseList);
-
-mlir::Operation *findReductionChain(mlir::Value, mlir::Value * = nullptr);
-fir::ConvertOp getConvertFromReductionOp(mlir::Operation *, mlir::Value);
-void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
-                     mlir::Value, fir::ConvertOp * = nullptr);
-void removeStoreOp(mlir::Operation *, mlir::Value);
-
 bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
 bool isOpenMPDeviceDeclareTarget(Fortran::lower::AbstractConverter &,
                                  Fortran::semantics::SemanticsContext &,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 0cf2a8f97040a8..0a728b65afbf06 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -237,6 +237,213 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
   return storeOp;
 }
 
+static mlir::Operation *
+findReductionChain(mlir::Value loadVal, mlir::Value *reductionVal = nullptr) {
+  for (mlir::OpOperand &loadOperand : loadVal.getUses()) {
+    if (mlir::Operation *reductionOp = loadOperand.getOwner()) {
+      if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
+        for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) {
+          if (mlir::Operation *reductionOp = convertOperand.getOwner())
+            return reductionOp;
+        }
+      }
+      for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
+        if (auto store =
+                mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
+          if (store.getMemref() == *reductionVal) {
+            store.erase();
+            return reductionOp;
+          }
+        }
+        if (auto assign =
+                mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) {
+          if (assign.getLhs() == *reductionVal) {
+            assign.erase();
+            return reductionOp;
+          }
+        }
+      }
+    }
+  }
+  return nullptr;
+}
+
+// for a logical operator 'op' reduction X = X op Y
+// This function returns the operation responsible for converting Y from
+// fir.logical<4> to i1
+static fir::ConvertOp getConvertFromReductionOp(mlir::Operation *reductionOp,
+                                                mlir::Value loadVal) {
+  for (mlir::Value reductionOperand : reductionOp->getOperands()) {
+    if (auto convertOp =
+            mlir::dyn_cast<fir::ConvertOp>(reductionOperand.getDefiningOp())) {
+      if (convertOp.getOperand() == loadVal)
+        continue;
+      return convertOp;
+    }
+  }
+  return nullptr;
+}
+
+static void updateReduction(mlir::Operation *op,
+                            fir::FirOpBuilder &firOpBuilder,
+                            mlir::Value loadVal, mlir::Value reductionVal,
+                            fir::ConvertOp *convertOp = nullptr) {
+  mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
+  firOpBuilder.setInsertionPoint(op);
+
+  mlir::Value reductionOp;
+  if (convertOp)
+    reductionOp = convertOp->getOperand();
+  else if (op->getOperand(0) == loadVal)
+    reductionOp = op->getOperand(1);
+  else
+    reductionOp = op->getOperand(0);
+
+  firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reductionOp,
+                                              reductionVal);
+  firOpBuilder.restoreInsertionPoint(insertPtDel);
+}
+
+static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) {
+  for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) {
+    if (auto convertReduction =
+            mlir::dyn_cast<fir::ConvertOp>(reductionOpUse)) {
+      for (mlir::Operation *convertReductionUse :
+           convertReduction.getRes().getUsers()) {
+        if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
+          if (storeOp.getMemref() == symVal)
+            storeOp.erase();
+        }
+        if (auto assignOp =
+                mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
+          if (assignOp.getLhs() == symVal)
+            assignOp.erase();
+        }
+      }
+    }
+  }
+}
+
+// Generate an OpenMP reduction operation.
+// TODO: Currently assumes it is either an integer addition/multiplication
+// reduction, or a logical and reduction. Generalize this for various reduction
+// operation types.
+// TODO: Generate the reduction operation during lowering instead of creating
+// and removing operations since this is not a robust approach. Also, removing
+// ops in the builder (instead of a rewriter) is probably not the best approach.
+static void genOpenMPReduction(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::semantics::SemanticsContext &semaCtx,
+    const Fortran::parser::OmpClauseList &clauseList) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  List<Clause> clauses{makeList(clauseList, semaCtx)};
+
+  for (const Clause &clause : clauses) {
+    if (const auto &reductionClause =
+            std::get_if<clause::Reduction>(&clause.u)) {
+      const auto &redOperatorList{
+          std::get<clause::Reduction::ReductionIdentifiers>(
+              reductionClause->t)};
+      assert(redOperatorList.size() == 1 && "Expecting single operator");
+      const auto &redOperator = redOperatorList.front();
+      const auto &objects{std::get<ObjectList>(reductionClause->t)};
+      if (const auto *reductionOp =
+              std::get_if<clause::DefinedOperator>(&redOperator.u)) {
+        const auto &intrinsicOp{
+            std::get<clause::DefinedOperator::IntrinsicOperator>(
+                reductionOp->u)};
+
+        switch (intrinsicOp) {
+        case clause::DefinedOperator::IntrinsicOperator::Add:
+        case clause::DefinedOperator::IntrinsicOperator::Multiply:
+        case clause::DefinedOperator::IntrinsicOperator::AND:
+        case clause::DefinedOperator::IntrinsicOperator::EQV:
+        case clause::DefinedOperator::IntrinsicOperator::OR:
+        case clause::DefinedOperator::IntrinsicOperator::NEQV:
+          break;
+        default:
+          continue;
+        }
+        for (const Object &object : objects) {
+          if (const Fortran::semantics::Symbol *symbol = object.id()) {
+            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+              reductionVal = declOp.getBase();
+            mlir::Type reductionType =
+                reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
+            if (!reductionType.isa<fir::LogicalType>()) {
+              if (!reductionType.isIntOrIndexOrFloat())
+                continue;
+            }
+            for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
+              if (auto loadOp =
+                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+                mlir::Value loadVal = loadOp.getRes();
+                if (reductionType.isa<fir::LogicalType>()) {
+                  mlir::Operation *reductionOp = findReductionChain(loadVal);
+                  fir::ConvertOp convertOp =
+                      getConvertFromReductionOp(reductionOp, loadVal);
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal, &convertOp);
+                  removeStoreOp(reductionOp, reductionVal);
+                } else if (mlir::Operation *reductionOp =
+                               findReductionChain(loadVal, &reductionVal)) {
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal);
+                }
+              }
+            }
+          }
+        }
+      } else if (const auto *reductionIntrinsic =
+                     std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
+        if (!ReductionProcessor::supportedIntrinsicProcReduction(
+                *reductionIntrinsic))
+          continue;
+        ReductionProcessor::ReductionIdentifier redId =
+            ReductionProcessor::getReductionType(*reductionIntrinsic);
+        for (const Object &object : objects) {
+          if (const Fortran::semantics::Symbol *symbol = object.id()) {
+            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+              reductionVal = declOp.getBase();
+            for (const mlir::OpOperand &reductionValUse :
+                 reductionVal.getUses()) {
+              if (auto loadOp =
+                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+                mlir::Value loadVal = loadOp.getRes();
+                // Max is lowered as a compare -> select.
+                // Match the pattern here.
+                mlir::Operation *reductionOp =
+                    findReductionChain(loadVal, &reductionVal);
+                if (reductionOp == nullptr)
+                  continue;
+
+                if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
+                    redId == ReductionProcessor::ReductionIdentifier::MIN) {
+                  assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+                         "Selection Op not found in reduction intrinsic");
+                  mlir::Operation *compareOp =
+                      getCompareFromReductionOp(reductionOp, loadVal);
+                  updateReduction(compareOp, firOpBuilder, loadVal,
+                                  reductionVal);
+                }
+                if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
+                    redId == ReductionProcessor::ReductionIdentifier::IEOR ||
+                    redId == ReductionProcessor::ReductionIdentifier::IAND) {
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
 struct OpWithBodyGenInfo {
   /// A type for a code-gen callback function. This takes as argument the op for
   /// which the code is being generated and returns the arguments of the op's
@@ -2339,216 +2546,6 @@ void Fortran::lower::genDeclareTargetIntGlobal(
   }
 }
 
-// Generate an OpenMP reduction operation.
-// TODO: Currently assumes it is either an integer addition/multiplication
-// reduction, or a logical and reduction. Generalize this for various reduction
-// operation types.
-// TODO: Generate the reduction operation during lowering instead of creating
-// and removing operations since this is not a robust approach. Also, removing
-// ops in the builder (instead of a rewriter) is probably not the best approach.
-void Fortran::lower::genOpenMPReduction(
-    Fortran::lower::AbstractConverter &converter,
-    Fortran::semantics::SemanticsContext &semaCtx,
-    const Fortran::parser::OmpClauseList &clauseList) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  List<Clause> clauses{makeList(clauseList, semaCtx)};
-
-  for (const Clause &clause : clauses) {
-    if (const auto &reductionClause =
-            std::get_if<clause::Reduction>(&clause.u)) {
-      const auto &redOperatorList{
-          std::get<clause::Reduction::ReductionIdentifiers>(
-              reductionClause->t)};
-      assert(redOperatorList.size() == 1 && "Expecting single operator");
-      const auto &redOperator = redOperatorList.front();
-      const auto &objects{std::get<ObjectList>(reductionClause->t)};
-      if (const auto *reductionOp =
-              std::get_if<clause::DefinedOperator>(&redOperator.u)) {
-        const auto &intrinsicOp{
-            std::get<clause::DefinedOperator::IntrinsicOperator>(
-                reductionOp->u)};
-
-        switch (intrinsicOp) {
-        case clause::DefinedOperator::IntrinsicOperator::Add:
-        case clause::DefinedOperator::IntrinsicOperator::Multiply:
-        case clause::DefinedOperator::IntrinsicOperator::AND:
-        case clause::DefinedOperator::IntrinsicOperator::EQV:
-        case clause::DefinedOperator::IntrinsicOperator::OR:
-        case clause::DefinedOperator::IntrinsicOperator::NEQV:
-          break;
-        default:
-          continue;
-        }
-        for (const Object &object : objects) {
-          if (const Fortran::semantics::Symbol *symbol = object.id()) {
-            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
-            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
-              reductionVal = declOp.getBase();
-            mlir::Type reductionType =
-                reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
-            if (!reductionType.isa<fir::LogicalType>()) {
-              if (!reductionType.isIntOrIndexOrFloat())
-                continue;
-            }
-            for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
-              if (auto loadOp =
-                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
-                mlir::Value loadVal = loadOp.getRes();
-                if (reductionType.isa<fir::LogicalType>()) {
-                  mlir::Operation *reductionOp = findReductionChain(loadVal);
-                  fir::ConvertOp convertOp =
-                      getConvertFromReductionOp(reductionOp, loadVal);
-                  updateReduction(reductionOp, firOpBuilder, loadVal,
-                                  reductionVal, &convertOp);
-                  removeStoreOp(reductionOp, reductionVal);
-                } else if (mlir::Operation *reductionOp =
-                               findReductionChain(loadVal, &reductionVal)) {
-                  updateReduction(reductionOp, firOpBuilder, loadVal,
-                                  reductionVal);
-                }
-              }
-            }
-          }
-        }
-      } else if (const auto *reductionIntrinsic =
-                     std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
-        if (!ReductionProcessor::supportedIntrinsicProcReduction(
-                *reductionIntrinsic))
-          continue;
-        ReductionProcessor::ReductionIdentifier redId =
-            ReductionProcessor::getReductionType(*reductionIntrinsic);
-        for (const Object &object : objects) {
-          if (const Fortran::semantics::Symbol *symbol = object.id()) {
-            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
-            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
-              reductionVal = declOp.getBase();
-            for (const mlir::OpOperand &reductionValUse :
-                 reductionVal.getUses()) {
-              if (auto loadOp =
-                      mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
-                mlir::Value loadVal = loadOp.getRes();
-                // Max is lowered as a compare -> select.
-                // Match the pattern here.
-                mlir::Operation *reductionOp =
-                    findReductionChain(loadVal, &reductionVal);
-                if (reductionOp == nullptr)
-                  continue;
-
-                if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
-                    redId == ReductionProcessor::ReductionIdentifier::MIN) {
-                  assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
-                         "Selection Op not found in reduction intrinsic");
-                  mlir::Operation *compareOp =
-                      getCompareFromReductionOp(reductionOp, loadVal);
-                  updateReduction(compareOp, firOpBuilder, loadVal,
-                                  reductionVal);
-                }
-                if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
-                    redId == ReductionProcessor::ReductionIdentifier::IEOR ||
-                    redId == ReductionProcessor::ReductionIdentifier::IAND) {
-                  updateReduction(reductionOp, firOpBuilder, loadVal,
-                                  reductionVal);
-                }
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-}
-
-mlir::Operation *Fortran::lower::findReductionChain(mlir::Value loadVal,
-                                                    mlir::Value *reductionVal) {
-  for (mlir::OpOperand &loadOperand : loadVal.getUses()) {
-    if (mlir::Operation *reductionOp = loadOperand.getOwner()) {
-      if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
-        for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) {
-          if (mlir::Operation *reductionOp = convertOperand.getOwner())
-            return reductionOp;
-        }
-      }
-      for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
-        if (auto store =
-                mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
-          if (store.getMemref() == *reductionVal) {
-            store.erase();
-            return reductionOp;
-          }
-        }
-        if (auto assign =
-                mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) {
-          if (assign.getLhs() == *reductionVal) {
-            assign.erase();
-            return reductionOp;
-          }
-        }
-      }
-    }
-  }
-  return nullptr;
-}
-
-// for a logical operator 'op' reduction X = X op Y
-// This function returns the operation responsible for converting Y from
-// fir.logical<4> to i1
-fir::ConvertOp
-Fortran::lower::getConvertFromReductionOp(mlir::Operation *reductionOp,
-                                          mlir::Value loadVal) {
-  for (mlir::Value reductionOperand : reductionOp->getOperands()) {
-    if (auto convertOp =
-            mlir::dyn_cast<fir::ConvertOp>(reductionOperand.getDefiningOp())) {
-      if (convertOp.getOperand() == loadVal)
-        continue;
-      return convertOp;
-    }
-  }
-  return nullptr;
-}
-
-void Fortran::lower::updateReduction(mlir::Operation *op,
-                                     fir::FirOpBuilder &firOpBuilder,
-                                     mlir::Value loadVal,
-                                     mlir::Value reductionVal,
-                                     fir::ConvertOp *convertOp) {
-  mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
-  firOpBuilder.setInsertionPoint(op);
-
-  mlir::Value reductionOp;
-  if (convertOp)
-    reductionOp = convertOp->getOperand();
-  else if (op->getOperand(0) == loadVal)
-    reductionOp = op->getOperand(1);
-  else
-    reductionOp = op->getOperand(0);
-
-  firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reductionOp,
-                                              reductionVal);
-  firOpBuilder.restoreInsertionPoint(insertPtDel);
-}
-
-void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp,
-                                   mlir::Value symVal) {
-  for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) {
-    if (auto convertReduction =
-            mlir::dyn_cast<fir::ConvertOp>(reductionOpUse)) {
-      for (mlir::Operation *convertReductionUse :
-           convertReduction.getRes().getUsers()) {
-        if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
-          if (storeOp.getMemref() == symVal)
-            storeOp.erase();
-        }
-        if (auto assignOp =
...
[truncated]

Copy link

github-actions bot commented Mar 26, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

There were several functions, mostly reduction-related, that were only called
from OpenMP.cpp. Remove them from OpenMP.h, and make them local in OpenMP.cpp:
- genOpenMPReduction
- findReductionChain
- getConvertFromReductionOp
- updateReduction
- removeStoreOp

Also, move the function bodies out of the "public" section.
@kparzysz kparzysz force-pushed the users/kparzysz/c13-reduction branch from edd931d to 5be4681 Compare March 26, 2024 20:12
Copy link
Contributor

@skatrak skatrak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@kparzysz kparzysz merged commit 7919975 into llvm:main Mar 28, 2024
4 checks passed
@kparzysz kparzysz deleted the users/kparzysz/c13-reduction branch March 28, 2024 12:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants