Skip to content

[flang][OpenMP] Add support for complex reductions #87488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 8, 2024

Conversation

Leporacanthicus
Copy link
Contributor

No description provided.

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we support passing complex by val or is this only for by-ref reductions? IIRC we require by-ref reductions to have a trivial type so I guess we don't support that, but it isn't obvious from this patch

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Apr 3, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 3, 2024

@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-flang-fir-hlfir

Author: Mats Petersson (Leporacanthicus)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/87488.diff

2 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ReductionProcessor.cpp (+18-5)
  • (modified) flang/lib/Lower/OpenMP/ReductionProcessor.h (+20-1)
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index 0d05ca5aee658b..df9fc321238f03 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -13,6 +13,8 @@
 #include "ReductionProcessor.h"
 
 #include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/ConvertType.h"
+#include "flang/Optimizer/Builder/Complex.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
@@ -130,7 +132,7 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
                                           fir::FirOpBuilder &builder) {
   type = fir::unwrapRefType(type);
   if (!fir::isa_integer(type) && !fir::isa_real(type) &&
-      !mlir::isa<fir::LogicalType>(type))
+      !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
     TODO(loc, "Reduction of some types is not supported");
   switch (redId) {
   case ReductionIdentifier::MAX: {
@@ -174,6 +176,17 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
   case ReductionIdentifier::OR:
   case ReductionIdentifier::EQV:
   case ReductionIdentifier::NEQV:
+    if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
+      mlir::Type realTy =
+          Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
+      //      mlir::FloatType realTy =
+      //      mlir::dyn_cast<mlir::FloatType>(cplxTy.getElementType());
+      //      const llvm::fltSemantics &sem = (realTy).getFloatSemantics();
+      mlir::Value init = builder.createRealConstant(
+          loc, realTy, getOperationIdentity(redId, loc));
+      return fir::factory::Complex{builder, loc}.createComplex(type, init,
+                                                               init);
+    }
     if (type.isa<mlir::FloatType>())
       return builder.create<mlir::arith::ConstantOp>(
           loc, type,
@@ -228,13 +241,13 @@ mlir::Value ReductionProcessor::createScalarCombiner(
     break;
   case ReductionIdentifier::ADD:
     reductionOp =
-        getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
-            builder, type, loc, op1, op2);
+        getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
+                              fir::AddcOp>(builder, type, loc, op1, op2);
     break;
   case ReductionIdentifier::MULTIPLY:
     reductionOp =
-        getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
-            builder, type, loc, op1, op2);
+        getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
+                              fir::MulcOp>(builder, type, loc, op1, op2);
     break;
   case ReductionIdentifier::AND: {
     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index ee2732547fc288..7ea252fde3602e 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -97,6 +97,10 @@ class ReductionProcessor {
                                            fir::FirOpBuilder &builder);
 
   template <typename FloatOp, typename IntegerOp>
+  static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+                                           mlir::Type type, mlir::Location loc,
+                                           mlir::Value op1, mlir::Value op2);
+  template <typename FloatOp, typename IntegerOp, typename ComplexOp>
   static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
                                            mlir::Type type, mlir::Location loc,
                                            mlir::Value op1, mlir::Value op2);
@@ -136,12 +140,27 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
                                           mlir::Value op1, mlir::Value op2) {
   type = fir::unwrapRefType(type);
   assert(type.isIntOrIndexOrFloat() &&
-         "only integer and float types are currently supported");
+         "only integer, float and complex types are currently supported");
   if (type.isIntOrIndex())
     return builder.create<IntegerOp>(loc, op1, op2);
   return builder.create<FloatOp>(loc, op1, op2);
 }
 
+template <typename FloatOp, typename IntegerOp, typename ComplexOp>
+mlir::Value
+ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
+                                          mlir::Type type, mlir::Location loc,
+                                          mlir::Value op1, mlir::Value op2) {
+  assert(type.isIntOrIndexOrFloat() ||
+         fir::isa_complex(type) &&
+             "only integer, float and complex types are currently supported");
+  if (type.isIntOrIndex())
+    return builder.create<IntegerOp>(loc, op1, op2);
+  if (fir::isa_real(type))
+    return builder.create<FloatOp>(loc, op1, op2);
+  return builder.create<ComplexOp>(loc, op1, op2);
+}
+
 } // namespace omp
 } // namespace lower
 } // namespace Fortran

@tblah tblah changed the title Add support for complex reductions [flang][OpenMP] Add support for complex reductions Apr 4, 2024
Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this, LGTM (but I see a merge conflict)

@tblah tblah requested a review from pawosm-arm April 4, 2024 14:15
@tblah
Copy link
Contributor

tblah commented Apr 4, 2024

Do we support passing complex by val or is this only for by-ref reductions? IIRC we require by-ref reductions to have a trivial type so I guess we don't support that, but it isn't obvious from this patch

I can see from the tests that we do support by-value complex reductions so ignore this comment

The SALMON application uses OpenMP reductions on complex values,
which wasn't supported in Flang. This adds the basic support
for this functionality.
@ye-luo
Copy link
Contributor

ye-luo commented Apr 6, 2024

Verified. It resolves #87839 and the results are correct.

@Leporacanthicus Leporacanthicus merged commit 221f438 into llvm:main Apr 8, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants