-
Notifications
You must be signed in to change notification settings - Fork 10.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][OpenMP] Make several function local to OpenMP.cpp, NFC #86726
Conversation
@llvm/pr-subscribers-flang-fir-hlfir Author: Krzysztof Parzyszek (kparzysz) ChangesThere were several functions, mostly reduction-related, that were only called from OpenMP.cpp. Remove them from OpenMP.h, and make them local in OpenMP.cpp:
Also, move the function bodies out of the "public" section. Patch is 20.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86726.diff 2 Files Affected:
diff --git a/flang/include/flang/Lower/OpenMP.h b/flang/include/flang/Lower/OpenMP.h
index 3b22a652d1fc1e..6e150ef4e8e82f 100644
--- a/flang/include/flang/Lower/OpenMP.h
+++ b/flang/include/flang/Lower/OpenMP.h
@@ -19,7 +19,6 @@
#include <utility>
namespace mlir {
-class Value;
class Operation;
class Location;
namespace omp {
@@ -30,7 +29,6 @@ enum class DeclareTargetCaptureClause : uint32_t;
namespace fir {
class FirOpBuilder;
-class ConvertOp;
} // namespace fir
namespace Fortran {
@@ -84,16 +82,6 @@ void genOpenMPSymbolProperties(AbstractConverter &converter,
int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
-void genOpenMPReduction(AbstractConverter &,
- Fortran::semantics::SemanticsContext &,
- const Fortran::parser::OmpClauseList &clauseList);
-
-mlir::Operation *findReductionChain(mlir::Value, mlir::Value * = nullptr);
-fir::ConvertOp getConvertFromReductionOp(mlir::Operation *, mlir::Value);
-void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
- mlir::Value, fir::ConvertOp * = nullptr);
-void removeStoreOp(mlir::Operation *, mlir::Value);
-
bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
bool isOpenMPDeviceDeclareTarget(Fortran::lower::AbstractConverter &,
Fortran::semantics::SemanticsContext &,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 0cf2a8f97040a8..0a728b65afbf06 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -237,6 +237,213 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
return storeOp;
}
+static mlir::Operation *
+findReductionChain(mlir::Value loadVal, mlir::Value *reductionVal = nullptr) {
+ for (mlir::OpOperand &loadOperand : loadVal.getUses()) {
+ if (mlir::Operation *reductionOp = loadOperand.getOwner()) {
+ if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
+ for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) {
+ if (mlir::Operation *reductionOp = convertOperand.getOwner())
+ return reductionOp;
+ }
+ }
+ for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
+ if (auto store =
+ mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
+ if (store.getMemref() == *reductionVal) {
+ store.erase();
+ return reductionOp;
+ }
+ }
+ if (auto assign =
+ mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) {
+ if (assign.getLhs() == *reductionVal) {
+ assign.erase();
+ return reductionOp;
+ }
+ }
+ }
+ }
+ }
+ return nullptr;
+}
+
+// for a logical operator 'op' reduction X = X op Y
+// This function returns the operation responsible for converting Y from
+// fir.logical<4> to i1
+static fir::ConvertOp getConvertFromReductionOp(mlir::Operation *reductionOp,
+ mlir::Value loadVal) {
+ for (mlir::Value reductionOperand : reductionOp->getOperands()) {
+ if (auto convertOp =
+ mlir::dyn_cast<fir::ConvertOp>(reductionOperand.getDefiningOp())) {
+ if (convertOp.getOperand() == loadVal)
+ continue;
+ return convertOp;
+ }
+ }
+ return nullptr;
+}
+
+static void updateReduction(mlir::Operation *op,
+ fir::FirOpBuilder &firOpBuilder,
+ mlir::Value loadVal, mlir::Value reductionVal,
+ fir::ConvertOp *convertOp = nullptr) {
+ mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
+ firOpBuilder.setInsertionPoint(op);
+
+ mlir::Value reductionOp;
+ if (convertOp)
+ reductionOp = convertOp->getOperand();
+ else if (op->getOperand(0) == loadVal)
+ reductionOp = op->getOperand(1);
+ else
+ reductionOp = op->getOperand(0);
+
+ firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reductionOp,
+ reductionVal);
+ firOpBuilder.restoreInsertionPoint(insertPtDel);
+}
+
+static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) {
+ for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) {
+ if (auto convertReduction =
+ mlir::dyn_cast<fir::ConvertOp>(reductionOpUse)) {
+ for (mlir::Operation *convertReductionUse :
+ convertReduction.getRes().getUsers()) {
+ if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
+ if (storeOp.getMemref() == symVal)
+ storeOp.erase();
+ }
+ if (auto assignOp =
+ mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
+ if (assignOp.getLhs() == symVal)
+ assignOp.erase();
+ }
+ }
+ }
+ }
+}
+
+// Generate an OpenMP reduction operation.
+// TODO: Currently assumes it is either an integer addition/multiplication
+// reduction, or a logical and reduction. Generalize this for various reduction
+// operation types.
+// TODO: Generate the reduction operation during lowering instead of creating
+// and removing operations since this is not a robust approach. Also, removing
+// ops in the builder (instead of a rewriter) is probably not the best approach.
+static void genOpenMPReduction(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const Fortran::parser::OmpClauseList &clauseList) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+ List<Clause> clauses{makeList(clauseList, semaCtx)};
+
+ for (const Clause &clause : clauses) {
+ if (const auto &reductionClause =
+ std::get_if<clause::Reduction>(&clause.u)) {
+ const auto &redOperatorList{
+ std::get<clause::Reduction::ReductionIdentifiers>(
+ reductionClause->t)};
+ assert(redOperatorList.size() == 1 && "Expecting single operator");
+ const auto &redOperator = redOperatorList.front();
+ const auto &objects{std::get<ObjectList>(reductionClause->t)};
+ if (const auto *reductionOp =
+ std::get_if<clause::DefinedOperator>(&redOperator.u)) {
+ const auto &intrinsicOp{
+ std::get<clause::DefinedOperator::IntrinsicOperator>(
+ reductionOp->u)};
+
+ switch (intrinsicOp) {
+ case clause::DefinedOperator::IntrinsicOperator::Add:
+ case clause::DefinedOperator::IntrinsicOperator::Multiply:
+ case clause::DefinedOperator::IntrinsicOperator::AND:
+ case clause::DefinedOperator::IntrinsicOperator::EQV:
+ case clause::DefinedOperator::IntrinsicOperator::OR:
+ case clause::DefinedOperator::IntrinsicOperator::NEQV:
+ break;
+ default:
+ continue;
+ }
+ for (const Object &object : objects) {
+ if (const Fortran::semantics::Symbol *symbol = object.id()) {
+ mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+ reductionVal = declOp.getBase();
+ mlir::Type reductionType =
+ reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
+ if (!reductionType.isa<fir::LogicalType>()) {
+ if (!reductionType.isIntOrIndexOrFloat())
+ continue;
+ }
+ for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
+ if (auto loadOp =
+ mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+ mlir::Value loadVal = loadOp.getRes();
+ if (reductionType.isa<fir::LogicalType>()) {
+ mlir::Operation *reductionOp = findReductionChain(loadVal);
+ fir::ConvertOp convertOp =
+ getConvertFromReductionOp(reductionOp, loadVal);
+ updateReduction(reductionOp, firOpBuilder, loadVal,
+ reductionVal, &convertOp);
+ removeStoreOp(reductionOp, reductionVal);
+ } else if (mlir::Operation *reductionOp =
+ findReductionChain(loadVal, &reductionVal)) {
+ updateReduction(reductionOp, firOpBuilder, loadVal,
+ reductionVal);
+ }
+ }
+ }
+ }
+ }
+ } else if (const auto *reductionIntrinsic =
+ std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
+ if (!ReductionProcessor::supportedIntrinsicProcReduction(
+ *reductionIntrinsic))
+ continue;
+ ReductionProcessor::ReductionIdentifier redId =
+ ReductionProcessor::getReductionType(*reductionIntrinsic);
+ for (const Object &object : objects) {
+ if (const Fortran::semantics::Symbol *symbol = object.id()) {
+ mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+ reductionVal = declOp.getBase();
+ for (const mlir::OpOperand &reductionValUse :
+ reductionVal.getUses()) {
+ if (auto loadOp =
+ mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+ mlir::Value loadVal = loadOp.getRes();
+ // Max is lowered as a compare -> select.
+ // Match the pattern here.
+ mlir::Operation *reductionOp =
+ findReductionChain(loadVal, &reductionVal);
+ if (reductionOp == nullptr)
+ continue;
+
+ if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
+ redId == ReductionProcessor::ReductionIdentifier::MIN) {
+ assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+ "Selection Op not found in reduction intrinsic");
+ mlir::Operation *compareOp =
+ getCompareFromReductionOp(reductionOp, loadVal);
+ updateReduction(compareOp, firOpBuilder, loadVal,
+ reductionVal);
+ }
+ if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
+ redId == ReductionProcessor::ReductionIdentifier::IEOR ||
+ redId == ReductionProcessor::ReductionIdentifier::IAND) {
+ updateReduction(reductionOp, firOpBuilder, loadVal,
+ reductionVal);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
struct OpWithBodyGenInfo {
/// A type for a code-gen callback function. This takes as argument the op for
/// which the code is being generated and returns the arguments of the op's
@@ -2339,216 +2546,6 @@ void Fortran::lower::genDeclareTargetIntGlobal(
}
}
-// Generate an OpenMP reduction operation.
-// TODO: Currently assumes it is either an integer addition/multiplication
-// reduction, or a logical and reduction. Generalize this for various reduction
-// operation types.
-// TODO: Generate the reduction operation during lowering instead of creating
-// and removing operations since this is not a robust approach. Also, removing
-// ops in the builder (instead of a rewriter) is probably not the best approach.
-void Fortran::lower::genOpenMPReduction(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauseList) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
- List<Clause> clauses{makeList(clauseList, semaCtx)};
-
- for (const Clause &clause : clauses) {
- if (const auto &reductionClause =
- std::get_if<clause::Reduction>(&clause.u)) {
- const auto &redOperatorList{
- std::get<clause::Reduction::ReductionIdentifiers>(
- reductionClause->t)};
- assert(redOperatorList.size() == 1 && "Expecting single operator");
- const auto &redOperator = redOperatorList.front();
- const auto &objects{std::get<ObjectList>(reductionClause->t)};
- if (const auto *reductionOp =
- std::get_if<clause::DefinedOperator>(&redOperator.u)) {
- const auto &intrinsicOp{
- std::get<clause::DefinedOperator::IntrinsicOperator>(
- reductionOp->u)};
-
- switch (intrinsicOp) {
- case clause::DefinedOperator::IntrinsicOperator::Add:
- case clause::DefinedOperator::IntrinsicOperator::Multiply:
- case clause::DefinedOperator::IntrinsicOperator::AND:
- case clause::DefinedOperator::IntrinsicOperator::EQV:
- case clause::DefinedOperator::IntrinsicOperator::OR:
- case clause::DefinedOperator::IntrinsicOperator::NEQV:
- break;
- default:
- continue;
- }
- for (const Object &object : objects) {
- if (const Fortran::semantics::Symbol *symbol = object.id()) {
- mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
- reductionVal = declOp.getBase();
- mlir::Type reductionType =
- reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
- if (!reductionType.isa<fir::LogicalType>()) {
- if (!reductionType.isIntOrIndexOrFloat())
- continue;
- }
- for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
- if (auto loadOp =
- mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
- mlir::Value loadVal = loadOp.getRes();
- if (reductionType.isa<fir::LogicalType>()) {
- mlir::Operation *reductionOp = findReductionChain(loadVal);
- fir::ConvertOp convertOp =
- getConvertFromReductionOp(reductionOp, loadVal);
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal, &convertOp);
- removeStoreOp(reductionOp, reductionVal);
- } else if (mlir::Operation *reductionOp =
- findReductionChain(loadVal, &reductionVal)) {
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal);
- }
- }
- }
- }
- }
- } else if (const auto *reductionIntrinsic =
- std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
- if (!ReductionProcessor::supportedIntrinsicProcReduction(
- *reductionIntrinsic))
- continue;
- ReductionProcessor::ReductionIdentifier redId =
- ReductionProcessor::getReductionType(*reductionIntrinsic);
- for (const Object &object : objects) {
- if (const Fortran::semantics::Symbol *symbol = object.id()) {
- mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
- reductionVal = declOp.getBase();
- for (const mlir::OpOperand &reductionValUse :
- reductionVal.getUses()) {
- if (auto loadOp =
- mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
- mlir::Value loadVal = loadOp.getRes();
- // Max is lowered as a compare -> select.
- // Match the pattern here.
- mlir::Operation *reductionOp =
- findReductionChain(loadVal, &reductionVal);
- if (reductionOp == nullptr)
- continue;
-
- if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
- redId == ReductionProcessor::ReductionIdentifier::MIN) {
- assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
- "Selection Op not found in reduction intrinsic");
- mlir::Operation *compareOp =
- getCompareFromReductionOp(reductionOp, loadVal);
- updateReduction(compareOp, firOpBuilder, loadVal,
- reductionVal);
- }
- if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
- redId == ReductionProcessor::ReductionIdentifier::IEOR ||
- redId == ReductionProcessor::ReductionIdentifier::IAND) {
- updateReduction(reductionOp, firOpBuilder, loadVal,
- reductionVal);
- }
- }
- }
- }
- }
- }
- }
- }
-}
-
-mlir::Operation *Fortran::lower::findReductionChain(mlir::Value loadVal,
- mlir::Value *reductionVal) {
- for (mlir::OpOperand &loadOperand : loadVal.getUses()) {
- if (mlir::Operation *reductionOp = loadOperand.getOwner()) {
- if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) {
- for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) {
- if (mlir::Operation *reductionOp = convertOperand.getOwner())
- return reductionOp;
- }
- }
- for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
- if (auto store =
- mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
- if (store.getMemref() == *reductionVal) {
- store.erase();
- return reductionOp;
- }
- }
- if (auto assign =
- mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) {
- if (assign.getLhs() == *reductionVal) {
- assign.erase();
- return reductionOp;
- }
- }
- }
- }
- }
- return nullptr;
-}
-
-// for a logical operator 'op' reduction X = X op Y
-// This function returns the operation responsible for converting Y from
-// fir.logical<4> to i1
-fir::ConvertOp
-Fortran::lower::getConvertFromReductionOp(mlir::Operation *reductionOp,
- mlir::Value loadVal) {
- for (mlir::Value reductionOperand : reductionOp->getOperands()) {
- if (auto convertOp =
- mlir::dyn_cast<fir::ConvertOp>(reductionOperand.getDefiningOp())) {
- if (convertOp.getOperand() == loadVal)
- continue;
- return convertOp;
- }
- }
- return nullptr;
-}
-
-void Fortran::lower::updateReduction(mlir::Operation *op,
- fir::FirOpBuilder &firOpBuilder,
- mlir::Value loadVal,
- mlir::Value reductionVal,
- fir::ConvertOp *convertOp) {
- mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
- firOpBuilder.setInsertionPoint(op);
-
- mlir::Value reductionOp;
- if (convertOp)
- reductionOp = convertOp->getOperand();
- else if (op->getOperand(0) == loadVal)
- reductionOp = op->getOperand(1);
- else
- reductionOp = op->getOperand(0);
-
- firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reductionOp,
- reductionVal);
- firOpBuilder.restoreInsertionPoint(insertPtDel);
-}
-
-void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp,
- mlir::Value symVal) {
- for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) {
- if (auto convertReduction =
- mlir::dyn_cast<fir::ConvertOp>(reductionOpUse)) {
- for (mlir::Operation *convertReductionUse :
- convertReduction.getRes().getUsers()) {
- if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) {
- if (storeOp.getMemref() == symVal)
- storeOp.erase();
- }
- if (auto assignOp =
...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There were several functions, mostly reduction-related, that were only called from OpenMP.cpp. Remove them from OpenMP.h, and make them local in OpenMP.cpp: - genOpenMPReduction - findReductionChain - getConvertFromReductionOp - updateReduction - removeStoreOp Also, move the function bodies out of the "public" section.
edd931d
to
5be4681
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
There were several functions, mostly reduction-related, that were only called from OpenMP.cpp. Remove them from OpenMP.h, and make them local in OpenMP.cpp:
Also, move the function bodies out of the "public" section.