-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[flang][OpenMP] Add support for complex reductions #87488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we support passing complex by val or is this only for by-ref reductions? IIRC we require by-ref reductions to have a trivial type so I guess we don't support that, but it isn't obvious from this patch
d76b93a
to
18d570c
Compare
@llvm/pr-subscribers-flang-openmp @llvm/pr-subscribers-flang-fir-hlfir Author: Mats Petersson (Leporacanthicus) ChangesFull diff: https://github.com/llvm/llvm-project/pull/87488.diff 2 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index 0d05ca5aee658b..df9fc321238f03 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -13,6 +13,8 @@
#include "ReductionProcessor.h"
#include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/ConvertType.h"
+#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
@@ -130,7 +132,7 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
fir::FirOpBuilder &builder) {
type = fir::unwrapRefType(type);
if (!fir::isa_integer(type) && !fir::isa_real(type) &&
- !mlir::isa<fir::LogicalType>(type))
+ !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
TODO(loc, "Reduction of some types is not supported");
switch (redId) {
case ReductionIdentifier::MAX: {
@@ -174,6 +176,17 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
case ReductionIdentifier::OR:
case ReductionIdentifier::EQV:
case ReductionIdentifier::NEQV:
+ if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
+ mlir::Type realTy =
+ Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
+ // mlir::FloatType realTy =
+ // mlir::dyn_cast<mlir::FloatType>(cplxTy.getElementType());
+ // const llvm::fltSemantics &sem = (realTy).getFloatSemantics();
+ mlir::Value init = builder.createRealConstant(
+ loc, realTy, getOperationIdentity(redId, loc));
+ return fir::factory::Complex{builder, loc}.createComplex(type, init,
+ init);
+ }
if (type.isa<mlir::FloatType>())
return builder.create<mlir::arith::ConstantOp>(
loc, type,
@@ -228,13 +241,13 @@ mlir::Value ReductionProcessor::createScalarCombiner(
break;
case ReductionIdentifier::ADD:
reductionOp =
- getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
- builder, type, loc, op1, op2);
+ getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
+ fir::AddcOp>(builder, type, loc, op1, op2);
break;
case ReductionIdentifier::MULTIPLY:
reductionOp =
- getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
- builder, type, loc, op1, op2);
+ getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
+ fir::MulcOp>(builder, type, loc, op1, op2);
break;
case ReductionIdentifier::AND: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index ee2732547fc288..7ea252fde3602e 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -97,6 +97,10 @@ class ReductionProcessor {
fir::FirOpBuilder &builder);
template <typename FloatOp, typename IntegerOp>
+ static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+ mlir::Type type, mlir::Location loc,
+ mlir::Value op1, mlir::Value op2);
+ template <typename FloatOp, typename IntegerOp, typename ComplexOp>
static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
mlir::Type type, mlir::Location loc,
mlir::Value op1, mlir::Value op2);
@@ -136,12 +140,27 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
mlir::Value op1, mlir::Value op2) {
type = fir::unwrapRefType(type);
assert(type.isIntOrIndexOrFloat() &&
- "only integer and float types are currently supported");
+ "only integer, float and complex types are currently supported");
if (type.isIntOrIndex())
return builder.create<IntegerOp>(loc, op1, op2);
return builder.create<FloatOp>(loc, op1, op2);
}
+template <typename FloatOp, typename IntegerOp, typename ComplexOp>
+mlir::Value
+ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
+ mlir::Type type, mlir::Location loc,
+ mlir::Value op1, mlir::Value op2) {
+ assert(type.isIntOrIndexOrFloat() ||
+ fir::isa_complex(type) &&
+ "only integer, float and complex types are currently supported");
+ if (type.isIntOrIndex())
+ return builder.create<IntegerOp>(loc, op1, op2);
+ if (fir::isa_real(type))
+ return builder.create<FloatOp>(loc, op1, op2);
+ return builder.create<ComplexOp>(loc, op1, op2);
+}
+
} // namespace omp
} // namespace lower
} // namespace Fortran
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this, LGTM (but I see a merge conflict)
I can see from the tests that we do support by-value complex reductions so ignore this comment |
The SALMON application uses OpenMP reductions on complex values, which wasn't supported in Flang. This adds the basic support for this functionality.
7b092c4
to
a659f44
Compare
Verified. It resolves #87839 and the results are correct. |
No description provided.