[flang][OpenMP][OMPIRBuilder][mlir] Optionally pass reduction vars by ref #84304
This will be useful for OpenMP too. I changed the definition slightly to use `fir::isa_ref_type` (which also includes LLVM pointers) because I think it reads better using the common type helpers. There shouldn't be any LLVM pointers in lowering, so this isn't a functional change. Commit series for by-ref OpenMP reductions: 1/3
TBAA builder assumed that all loads/stores are inside of functions and hit an assertion once it found loads and stores inside of an omp::ReductionDeclareOp. For now just don't add TBAA tags to those loads and stores. They would end up in a different TBAA tree to the host function after OpenMPIRBuilder inlines them anyway so there isn't an easy way of making this work. Commit series for by-ref OpenMP reductions: 2/3
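As a rough illustration of the workaround described above, the check amounts to "only tag a load/store if it has an enclosing function". The sketch below models this in plain C++; the `Op` struct, `enclosingFunction`, and `shouldAttachTbaa` are hypothetical stand-ins for illustration, not the actual FIR/MLIR API.

```cpp
#include <cassert>
#include <string>

// Hypothetical, simplified operation node. Real MLIR operations are far
// richer; this keeps only a parent link and a "is a function" flag.
struct Op {
  std::string name;
  const Op *parent;
  bool isFunction;
};

// Walk up the parent chain and return the nearest enclosing function,
// or nullptr if the operation is not nested inside any function.
const Op *enclosingFunction(const Op &op) {
  for (const Op *p = op.parent; p != nullptr; p = p->parent)
    if (p->isFunction)
      return p;
  return nullptr;
}

// Only memory operations that live inside a function get a TBAA tag;
// anything else (e.g. inside a reduction declaration) is skipped.
bool shouldAttachTbaa(const Op &memOp) {
  return enclosingFunction(memOp) != nullptr;
}

// Small usage example: one load inside a function, one inside a
// reduction declaration that sits directly in the module.
bool sketchHolds() {
  Op module{"module", nullptr, false};
  Op func{"func.func", &module, true};
  Op decl{"omp.reduction.declare", &module, false};
  Op loadInFunc{"fir.load", &func, false};
  Op loadInDecl{"fir.load", &decl, false};
  assert(shouldAttachTbaa(loadInFunc));
  assert(!shouldAttachTbaa(loadInDecl));
  return true;
}
```

Skipping the tag instead of asserting keeps the pass usable on modules that contain reduction declarations, at the cost of no TBAA information for those loads and stores.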
… ref
Previously, reduction variables were always passed by value into and out of the initialization and combiner regions of the OpenMP reduction declare operation. This worked well for reductions of primitive types (and might perform better than passing by reference). But passing by reference will be useful for array and derived type reductions (e.g. to move allocation inside of the init region).

Passing reductions by reference requires different LLVM-IR generation when lowering from MLIR because some of the loads/stores/allocations will now be moved inside of the init and combiner regions. This alternate code generation is requested using a new attribute on omp.wsloop and omp.parallel. Existing lowerings from MLIR are unaffected (these will continue to use by-value argument passing).

Flang will continue to use by-value argument passing for trivial types unless a (hidden) command line argument is supplied. Non-trivial types will always use the by-ref lowering. Array reductions are not ready yet (but are coming very soon). In the meantime, this is tested by forcing existing reductions to use by-ref.

Commit series for by-ref OpenMP reductions: 3/3
Co-authored-by: Mats Petersson <mats.petersson@arm.com>
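To make the difference between the two calling conventions concrete, here is a minimal plain-C++ sketch of an add reduction in both styles. It is only an analogy for the init and combiner regions described above: `ByValueAdd`, `ByRefAdd`, `reduceByValue`, and `reduceByRef` are hypothetical names, the "threads" are simulated by a sequential loop, and heap storage stands in for the alloca emitted inside the init region.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// By-value convention: init yields a value, the combiner takes two values
// and yields the combined value.
struct ByValueAdd {
  static int init() { return 0; }
  static int combine(int lhs, int rhs) { return lhs + rhs; }
};

// By-reference convention: init allocates and initializes private storage
// and yields its address; the combiner loads both operands, stores the
// result through the first argument, and yields that address.
struct ByRefAdd {
  static int *init(std::vector<std::unique_ptr<int>> &storage) {
    storage.push_back(std::make_unique<int>(0)); // allocation lives in init
    return storage.back().get();
  }
  static int *combine(int *lhs, int *rhs) {
    assert(lhs != nullptr && rhs != nullptr);
    int result = *lhs + *rhs; // load both operands
    *lhs = result;            // store into the first operand
    return lhs;               // yield the address, not the value
  }
};

// Simulate `numThreads` private copies that each add 1, by value.
int reduceByValue(int original, int numThreads) {
  int acc = original;
  for (int t = 0; t < numThreads; ++t) {
    int priv = ByValueAdd::init();
    priv = priv + 1;
    acc = ByValueAdd::combine(acc, priv);
  }
  return acc;
}

// The same reduction, by reference.
int reduceByRef(int original, int numThreads) {
  std::vector<std::unique_ptr<int>> storage;
  int shared = original;
  for (int t = 0; t < numThreads; ++t) {
    int *priv = ByRefAdd::init(storage);
    *priv += 1;
    ByRefAdd::combine(&shared, priv);
  }
  return shared;
}
```

Both versions compute the same result. The by-ref form matters once the private copy is something that cannot be carried around as a single SSA value, such as an array, because the allocation then belongs inside init.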
@llvm/pr-subscribers-flang-codegen @llvm/pr-subscribers-flang-fir-hlfir
Author: Tom Eccles (tblah)
Patch is 297.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84304.diff
37 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 185e0316870e94..d9648f3d692cc6 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -600,6 +600,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
return reductionSymbols;
};
+ mlir::UnitAttr byrefAttr;
+ if (ReductionProcessor::doReductionByRef(reductionVars))
+ byrefAttr = converter.getFirOpBuilder().getUnitAttr();
+
OpWithBodyGenInfo genInfo =
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
@@ -619,7 +623,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols),
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
- /*privatizers=*/nullptr);
+ /*privatizers=*/nullptr, byrefAttr);
}
bool privatize = !outerCombined;
@@ -683,7 +687,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
delayedPrivatizationInfo.privatizers.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- privatizers));
+ privatizers),
+ byrefAttr);
}
static mlir::omp::SectionOp
@@ -1568,7 +1573,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
- mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
+ mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
mlir::IntegerAttr orderedClauseOperand;
mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
std::size_t loopVarTypeSize;
@@ -1585,6 +1590,9 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
convertLoopBounds(converter, loc, lowerBound, upperBound, step,
loopVarTypeSize);
+ if (ReductionProcessor::doReductionByRef(reductionVars))
+ byrefOperand = firOpBuilder.getUnitAttr();
+
auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
loc, lowerBound, upperBound, step, linearVars, linearStepVars,
reductionVars,
@@ -1594,8 +1602,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
reductionDeclSymbols),
scheduleValClauseOperand, scheduleChunkClauseOperand,
/*schedule_modifiers=*/nullptr,
- /*simd_modifier=*/nullptr, nowaitClauseOperand, orderedClauseOperand,
- orderClauseOperand,
+ /*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
+ orderedClauseOperand, orderClauseOperand,
/*inclusive=*/firOpBuilder.getUnitAttr());
// Handle attribute based clauses.
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index a8b98f3f567249..34959e95dce04a 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -14,9 +14,16 @@
#include "flang/Lower/AbstractConverter.h"
#include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "llvm/Support/CommandLine.h"
+
+static llvm::cl::opt<bool> forceByrefReduction(
+ "force-byref-reduction",
+ llvm::cl::desc("Pass all reduction arguments by reference"),
+ llvm::cl::Hidden);
namespace Fortran {
namespace lower {
@@ -76,16 +83,24 @@ bool ReductionProcessor::supportedIntrinsicProcReduction(
}
std::string ReductionProcessor::getReductionName(llvm::StringRef name,
- mlir::Type ty) {
+ mlir::Type ty, bool isByRef) {
+ ty = fir::unwrapRefType(ty);
+
+ // extra string to distinguish reduction functions for variables passed by
+ // reference
+ llvm::StringRef byrefAddition{""};
+ if (isByRef)
+ byrefAddition = "_byref";
+
return (llvm::Twine(name) +
(ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
- llvm::Twine(ty.getIntOrFloatBitWidth()))
+ llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition)
.str();
}
std::string ReductionProcessor::getReductionName(
Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type ty) {
+ mlir::Type ty, bool isByRef) {
std::string reductionName;
switch (intrinsicOp) {
@@ -108,13 +123,14 @@ std::string ReductionProcessor::getReductionName(
break;
}
- return getReductionName(reductionName, ty);
+ return getReductionName(reductionName, ty, isByRef);
}
mlir::Value
ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
ReductionIdentifier redId,
fir::FirOpBuilder &builder) {
+ type = fir::unwrapRefType(type);
assert((fir::isa_integer(type) || fir::isa_real(type) ||
type.isa<fir::LogicalType>()) &&
"only integer, logical and real types are currently supported");
@@ -188,6 +204,7 @@ mlir::Value ReductionProcessor::createScalarCombiner(
fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
mlir::Type type, mlir::Value op1, mlir::Value op2) {
mlir::Value reductionOp;
+ type = fir::unwrapRefType(type);
switch (redId) {
case ReductionIdentifier::MAX:
reductionOp =
@@ -268,7 +285,8 @@ mlir::Value ReductionProcessor::createScalarCombiner(
mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
+ const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
+ bool isByRef) {
mlir::OpBuilder::InsertionGuard guard(builder);
mlir::ModuleOp module = builder.getModule();
@@ -278,14 +296,24 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
return decl;
mlir::OpBuilder modBuilder(module.getBodyRegion());
+ mlir::Type valTy = fir::unwrapRefType(type);
+ if (!isByRef)
+ type = valTy;
decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
type);
builder.createBlock(&decl.getInitializerRegion(),
decl.getInitializerRegion().end(), {type}, {loc});
builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+
mlir::Value init = getReductionInitValue(loc, type, redId, builder);
- builder.create<mlir::omp::YieldOp>(loc, init);
+ if (isByRef) {
+ mlir::Value alloca = builder.create<fir::AllocaOp>(loc, valTy);
+ builder.createStoreWithConvert(loc, init, alloca);
+ builder.create<mlir::omp::YieldOp>(loc, alloca);
+ } else {
+ builder.create<mlir::omp::YieldOp>(loc, init);
+ }
builder.createBlock(&decl.getReductionRegion(),
decl.getReductionRegion().end(), {type, type},
@@ -294,14 +322,41 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+ mlir::Value outAddr = op1;
+
+ op1 = builder.loadIfRef(loc, op1);
+ op2 = builder.loadIfRef(loc, op2);
mlir::Value reductionOp =
createScalarCombiner(builder, loc, redId, type, op1, op2);
- builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+ if (isByRef) {
+ builder.create<fir::StoreOp>(loc, reductionOp, outAddr);
+ builder.create<mlir::omp::YieldOp>(loc, outAddr);
+ } else {
+ builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+ }
return decl;
}
+bool ReductionProcessor::doReductionByRef(
+ const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
+ if (reductionVars.empty())
+ return false;
+ if (forceByrefReduction)
+ return true;
+
+ for (mlir::Value reductionVar : reductionVars) {
+ if (auto declare =
+ mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
+ reductionVar = declare.getMemref();
+
+ if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
+ return true;
+ }
+ return false;
+}
+
void ReductionProcessor::addReductionDecl(
mlir::Location currentLocation,
Fortran::lower::AbstractConverter &converter,
@@ -315,6 +370,24 @@ void ReductionProcessor::addReductionDecl(
const auto &redOperator{
std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+
+ // initial pass to collect all reduction vars so we can figure out if this
+ // should happen byref
+ for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+ if (const auto *name{
+ Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+ if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ if (reductionSymbols)
+ reductionSymbols->push_back(symbol);
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+ symVal = declOp.getBase();
+ reductionVars.push_back(symVal);
+ }
+ }
+ }
+ const bool isByRef = doReductionByRef(reductionVars);
+
if (const auto &redDefinedOp =
std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
const auto &intrinsicOp{
@@ -338,23 +411,20 @@ void ReductionProcessor::addReductionDecl(
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- if (reductionSymbols)
- reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- if (redType.isa<fir::LogicalType>())
+ auto redType = symVal.getType().cast<fir::ReferenceType>();
+ if (redType.getEleTy().isa<fir::LogicalType>())
decl = createReductionDecl(
firOpBuilder,
- getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
- redType, currentLocation);
- else if (redType.isIntOrIndexOrFloat()) {
- decl = createReductionDecl(firOpBuilder,
- getReductionName(intrinsicOp, redType),
- redId, redType, currentLocation);
+ getReductionName(intrinsicOp, firOpBuilder.getI1Type(),
+ isByRef),
+ redId, redType, currentLocation, isByRef);
+ else if (redType.getEleTy().isIntOrIndexOrFloat()) {
+ decl = createReductionDecl(
+ firOpBuilder, getReductionName(intrinsicOp, redType, isByRef),
+ redId, redType, currentLocation, isByRef);
} else {
TODO(currentLocation, "Reduction of some types is not supported");
}
@@ -374,21 +444,17 @@ void ReductionProcessor::addReductionDecl(
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- if (reductionSymbols)
- reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- assert(redType.isIntOrIndexOrFloat() &&
+ auto redType = symVal.getType().cast<fir::ReferenceType>();
+ assert(redType.getEleTy().isIntOrIndexOrFloat() &&
"Unsupported reduction type");
decl = createReductionDecl(
firOpBuilder,
getReductionName(getRealName(*reductionIntrinsic).ToString(),
- redType),
- redId, redType, currentLocation);
+ redType, isByRef),
+ redId, redType, currentLocation, isByRef);
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
firOpBuilder.getContext(), decl.getSymName()));
}
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 00770fe81d1ef6..679580f2a3cac7 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -14,6 +14,7 @@
#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/type.h"
@@ -71,11 +72,15 @@ class ReductionProcessor {
static const Fortran::semantics::SourceName
getRealName(const Fortran::parser::ProcedureDesignator &pd);
- static std::string getReductionName(llvm::StringRef name, mlir::Type ty);
+ static bool
+ doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);
+
+ static std::string getReductionName(llvm::StringRef name, mlir::Type ty,
+ bool isByRef);
static std::string getReductionName(
Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type ty);
+ mlir::Type ty, bool isByRef);
/// This function returns the identity value of the operator \p
/// reductionOpName. For example:
@@ -103,9 +108,11 @@ class ReductionProcessor {
/// symbol table. The declaration has a constant initializer with the neutral
/// value `initValue`, and the reduction combiner carried over from `reduce`.
/// TODO: Generalize this for non-integer types, add atomic region.
- static mlir::omp::ReductionDeclareOp createReductionDecl(
- fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- const ReductionIdentifier redId, mlir::Type type, mlir::Location loc);
+ static mlir::omp::ReductionDeclareOp
+ createReductionDecl(fir::FirOpBuilder &builder,
+ llvm::StringRef reductionOpName,
+ const ReductionIdentifier redId, mlir::Type type,
+ mlir::Location loc, bool isByRef);
/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
@@ -124,6 +131,7 @@ mlir::Value
ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
mlir::Type type, mlir::Location loc,
mlir::Value op1, mlir::Value op2) {
+ type = fir::unwrapRefType(type);
assert(type.isIntOrIndexOrFloat() &&
"only integer and float types are currently supported");
if (type.isIntOrIndex())
diff --git a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90 b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90
new file mode 100644
index 00000000000000..ca432662b77c44
--- /dev/null
+++ b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90
@@ -0,0 +1,117 @@
+! RUN: bbc -emit-fir -hlfir=false -fopenmp --force-byref-reduction -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -flang-deprecated-no-hlfir -fopenmp -mmlir --force-byref-reduction -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_F32_NAME:.*]] : !fir.ref<f32>
+!CHECK-SAME: init {
+!CHECK: ^bb0(%{{.*}}: !fir.ref<f32>):
+!CHECK: %[[C0_1:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK: %[[REF:.*]] = fir.alloca f32
+!CHECK: fir.store %[[C0_1]] to %[[REF]] : !fir.ref<f32>
+!CHECK: omp.yield(%[[REF]] : !fir.ref<f32>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<f32>, %[[ARG1:.*]]: !fir.ref<f32>):
+!CHECK: %[[LD0:.*]] = fir.load %[[ARG0]] : !fir.ref<f32>
+!CHECK: %[[LD1:.*]] = fir.load %[[ARG1]] : !fir.ref<f32>
+!CHECK: %[[RES:.*]] = arith.addf %[[LD0]], %[[LD1]] {{.*}}: f32
+!CHECK: fir.store %[[RES]] to %[[ARG0]] : !fir.ref<f32>
+!CHECK: omp.yield(%[[ARG0]] : !fir.ref<f32>)
+!CHECK: }
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_I32_NAME:.*]] : !fir.ref<i32>
+!CHECK-SAME: init {
+!CHECK: ^bb0(%{{.*}}: !fir.ref<i32>):
+!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
+!CHECK: %[[REF:.*]] = fir.alloca i32
+!CHECK: fir.store %[[C0_1]] to %[[REF]] : !fir.ref<i32>
+!CHECK: omp.yield(%[[REF]] : !fir.ref<i32>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<i32>, %[[ARG1:.*]]: !fir.ref<i32>):
+!CHECK: %[[LD0:.*]] = fir.load %[[ARG0]] : !fir.ref<i32>
+!CHECK: %[[LD1:.*]] = fir.load %[[ARG1]] : !fir.ref<i32>
+!CHECK: %[[RES:.*]] = arith.addi %[[LD0]], %[[LD1]] : i32
+!CHECK: fir.store %[[RES]] to %[[ARG0]] : !fir.ref<i32>
+!CHECK: omp.yield(%[[ARG0]] : !fir.ref<i32>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_int_add
+!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
+!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK: omp.parallel byref reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
+!CHECK: %[[I_INCR:.+]] = arith.constant 1 : i32
+!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
+!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_int_add
+ integer :: i
+ i = 0
+
+ !$omp parallel reduction(+:i)
+ i = i + 1
+ !$omp end parallel
+
+ print *, i
+end subroutine
+
+!CHECK-LABEL: func.func @_QPsimple_real_add
+!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
+!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK: omp.parallel byref reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
+!CHECK: %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
+!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
+!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_real_add
+ real :: r
+ r = 0.0
+
+ !$omp parallel reduction(+:r)
+ r = r + 1.5
+ !$omp end parallel
+
+ print *, r
+end subroutine
+
+!CHECK-LABEL: func.func @_QPint_real_add
+!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFint_real_addEi"}
+!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFint_real_addEr"}
+!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK: omp.parallel byref reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
+!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
+!CHECK: %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
+!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
+!CHECK: fir.store %[[RES1]] to %[[PRV1]]
+!CHECK: %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
+!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
+!CHECK: %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
+!CHECK: fir.store %[[RES0]] t...
[truncated]
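For readers skimming the diff above, the naming and by-ref decision logic in `ReductionProcessor` can be summarized with a small standalone sketch. It mirrors the shape of `getReductionName` and `doReductionByRef` from the patch but is not the real code: `ScalarType` and `ReductionVar` are hypothetical stand-ins for `mlir::Type` and `mlir::Value`, the `forceByRef` parameter stands in for the hidden `--force-byref-reduction` option, and the base name `add_reduction` is only an example value.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for an unwrapped scalar MLIR type.
struct ScalarType {
  bool isInteger;
  unsigned bitWidth;
};

// Hypothetical stand-in for a reduction variable; `isTrivial` plays the
// role of fir::isa_trivial on the unwrapped type.
struct ReductionVar {
  bool isTrivial;
};

// Build "<base>_i_<bits>" or "<base>_f_<bits>", with a "_byref" suffix so
// by-value and by-ref declarations of the same reduction do not collide.
std::string reductionName(const std::string &base, ScalarType ty,
                          bool isByRef) {
  std::string name =
      base + (ty.isInteger ? "_i_" : "_f_") + std::to_string(ty.bitWidth);
  if (isByRef)
    name += "_byref";
  return name;
}

// Use by-ref if forced, or if any reduction variable is non-trivial.
// An empty reduction list never requests by-ref.
bool doReductionByRef(const std::vector<ReductionVar> &vars,
                      bool forceByRef) {
  if (vars.empty())
    return false;
  if (forceByRef)
    return true;
  for (const ReductionVar &var : vars)
    if (!var.isTrivial)
      return true;
  return false;
}

// Usage example: the decision is made once per operation, so a single
// non-trivial variable switches every reduction on that op to by-ref.
void exampleUsage() {
  std::vector<ReductionVar> vars{{true}, {false}};
  bool byRef = doReductionByRef(vars, /*forceByRef=*/false);
  assert(byRef);
  assert(reductionName("add_reduction", {true, 32}, byRef) ==
         "add_reduction_i_32_byref");
}
```

Note that the decision is per operation rather than per variable, which is why the patch makes an initial pass over all reduction variables before creating any declaration.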
symVal = declOp.getBase();
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- assert(redType.isIntOrIndexOrFloat() &&
+ auto redType = symVal.getType().cast<fir::ReferenceType>();
+ assert(redType.getEleTy().isIntOrIndexOrFloat() &&
"Unsupported reduction type");
decl = createReductionDecl(
firOpBuilder,
getReductionName(getRealName(*reductionIntrinsic).ToString(),
- redType),
- redId, redType, currentLocation);
+ redType, isByRef),
+ redId, redType, currentLocation, isByRef);
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
firOpBuilder.getContext(), decl.getSymName()));
}
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 00770fe81d1ef6..679580f2a3cac7 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -14,6 +14,7 @@
#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/type.h"
@@ -71,11 +72,15 @@ class ReductionProcessor {
static const Fortran::semantics::SourceName
getRealName(const Fortran::parser::ProcedureDesignator &pd);
- static std::string getReductionName(llvm::StringRef name, mlir::Type ty);
+ static bool
+ doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);
+
+ static std::string getReductionName(llvm::StringRef name, mlir::Type ty,
+ bool isByRef);
static std::string getReductionName(
Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type ty);
+ mlir::Type ty, bool isByRef);
/// This function returns the identity value of the operator \p
/// reductionOpName. For example:
@@ -103,9 +108,11 @@ class ReductionProcessor {
/// symbol table. The declaration has a constant initializer with the neutral
/// value `initValue`, and the reduction combiner carried over from `reduce`.
/// TODO: Generalize this for non-integer types, add atomic region.
- static mlir::omp::ReductionDeclareOp createReductionDecl(
- fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- const ReductionIdentifier redId, mlir::Type type, mlir::Location loc);
+ static mlir::omp::ReductionDeclareOp
+ createReductionDecl(fir::FirOpBuilder &builder,
+ llvm::StringRef reductionOpName,
+ const ReductionIdentifier redId, mlir::Type type,
+ mlir::Location loc, bool isByRef);
/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
@@ -124,6 +131,7 @@ mlir::Value
ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
mlir::Type type, mlir::Location loc,
mlir::Value op1, mlir::Value op2) {
+ type = fir::unwrapRefType(type);
assert(type.isIntOrIndexOrFloat() &&
"only integer and float types are currently supported");
if (type.isIntOrIndex())
diff --git a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90 b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90
new file mode 100644
index 00000000000000..ca432662b77c44
--- /dev/null
+++ b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90
@@ -0,0 +1,117 @@
+! RUN: bbc -emit-fir -hlfir=false -fopenmp --force-byref-reduction -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -flang-deprecated-no-hlfir -fopenmp -mmlir --force-byref-reduction -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_F32_NAME:.*]] : !fir.ref<f32>
+!CHECK-SAME: init {
+!CHECK: ^bb0(%{{.*}}: !fir.ref<f32>):
+!CHECK: %[[C0_1:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK: %[[REF:.*]] = fir.alloca f32
+!CHECK: fir.store %[[C0_1]] to %[[REF]] : !fir.ref<f32>
+!CHECK: omp.yield(%[[REF]] : !fir.ref<f32>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<f32>, %[[ARG1:.*]]: !fir.ref<f32>):
+!CHECK: %[[LD0:.*]] = fir.load %[[ARG0]] : !fir.ref<f32>
+!CHECK: %[[LD1:.*]] = fir.load %[[ARG1]] : !fir.ref<f32>
+!CHECK: %[[RES:.*]] = arith.addf %[[LD0]], %[[LD1]] {{.*}}: f32
+!CHECK: fir.store %[[RES]] to %[[ARG0]] : !fir.ref<f32>
+!CHECK: omp.yield(%[[ARG0]] : !fir.ref<f32>)
+!CHECK: }
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_I32_NAME:.*]] : !fir.ref<i32>
+!CHECK-SAME: init {
+!CHECK: ^bb0(%{{.*}}: !fir.ref<i32>):
+!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
+!CHECK: %[[REF:.*]] = fir.alloca i32
+!CHECK: fir.store %[[C0_1]] to %[[REF]] : !fir.ref<i32>
+!CHECK: omp.yield(%[[REF]] : !fir.ref<i32>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<i32>, %[[ARG1:.*]]: !fir.ref<i32>):
+!CHECK: %[[LD0:.*]] = fir.load %[[ARG0]] : !fir.ref<i32>
+!CHECK: %[[LD1:.*]] = fir.load %[[ARG1]] : !fir.ref<i32>
+!CHECK: %[[RES:.*]] = arith.addi %[[LD0]], %[[LD1]] : i32
+!CHECK: fir.store %[[RES]] to %[[ARG0]] : !fir.ref<i32>
+!CHECK: omp.yield(%[[ARG0]] : !fir.ref<i32>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_int_add
+!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
+!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK: omp.parallel byref reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
+!CHECK: %[[I_INCR:.+]] = arith.constant 1 : i32
+!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
+!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_int_add
+ integer :: i
+ i = 0
+
+ !$omp parallel reduction(+:i)
+ i = i + 1
+ !$omp end parallel
+
+ print *, i
+end subroutine
+
+!CHECK-LABEL: func.func @_QPsimple_real_add
+!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
+!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK: omp.parallel byref reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
+!CHECK: %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
+!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
+!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_real_add
+ real :: r
+ r = 0.0
+
+ !$omp parallel reduction(+:r)
+ r = r + 1.5
+ !$omp end parallel
+
+ print *, r
+end subroutine
+
+!CHECK-LABEL: func.func @_QPint_real_add
+!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFint_real_addEi"}
+!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFint_real_addEr"}
+!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK: omp.parallel byref reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
+!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
+!CHECK: %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
+!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
+!CHECK: fir.store %[[RES1]] to %[[PRV1]]
+!CHECK: %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
+!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
+!CHECK: %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
+!CHECK: fir.store %[[RES0]] t...
[truncated]
Co-authored with @Leporacanthicus (GitHub seems to have taken the tag out of the commit message but shows it in the UI).
LGTM.
reductionVar = declare.getMemref();

if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
return true;
Does this imply that all reductions on a clause have to be by ref or by val? E.g. if we have an array reduction on the clause does that mean an integer reduction also changes to byref?
Yes, it does. Currently byref vs byval is toggled over the whole wsloop or parallel region. A more sophisticated implementation could instead track this per reduction argument; I chose not to do this to keep things simple.
I suspect that in most cases, if an integer reduction and an array reduction are used together, the array reduction would take long enough that the performance loss from doing the integer reduction by reference would not be significant. I have not measured this.
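To make the trade-off concrete, here is a minimal standalone sketch of the two policies discussed above. It is not the Flang code: `TypeKind`, `isTrivial`, and `byRefPerArgument` are hypothetical stand-ins introduced for illustration, and only the control flow of `doReductionByRef` is modelled on the function in this patch.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the type categories involved (assumption: the
// real code inspects MLIR types via fir::isa_trivial instead).
enum class TypeKind { Integer, Real, Logical, Array, Derived };

// Assumed simplification: arrays and derived types are the non-trivial ones.
bool isTrivial(TypeKind kind) {
  return kind != TypeKind::Array && kind != TypeKind::Derived;
}

// Clause-wide policy, mirroring the control flow in this patch: one
// non-trivial variable switches every reduction on the construct to by-ref.
bool doReductionByRef(const std::vector<TypeKind> &reductionVars,
                      bool forceByRef) {
  if (reductionVars.empty())
    return false;
  if (forceByRef)
    return true;
  for (TypeKind kind : reductionVars)
    if (!isTrivial(kind))
      return true;
  return false;
}

// Hypothetical per-argument alternative: each variable gets its own answer,
// so an integer next to an array would stay by-value.
std::vector<bool> byRefPerArgument(const std::vector<TypeKind> &reductionVars,
                                   bool forceByRef) {
  std::vector<bool> result;
  result.reserve(reductionVars.size());
  for (TypeKind kind : reductionVars)
    result.push_back(forceByRef || !isTrivial(kind));
  return result;
}
```

With an integer and an array on the same construct, the clause-wide function returns true for both, while the per-argument variant would keep the integer by value.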
That's probably true wrt performance, but I believe when it comes to OpenMP target reductions, doing reductions on basic types by reference will affect correctness. As far as I remember, you do not need to ensure manually that, for example, the INTEGER exists on the target device, whereas you do for an array (and would need to if the integer is passed by reference).
I think fixing that in a subsequent patch is probably fine though, as long as we add a TODO mentioning that ideally it should be considered separately per argument.
This makes sure that the generated IR doesn't change as a result of this PR. The generated IR looks wrong to me (no reduction is generated at all), but that is a matter for another patch.
LGTM
Previously reduction variables were always passed by value into and out of the initialization and combiner regions of the OpenMP reduction declare operation.
This worked well for reductions of primitive types (and might perform better than passing by reference). But passing by reference will be useful for array and derived type reductions (e.g. to move allocation inside of the init region).
Passing reductions by reference requires different LLVM-IR generation when lowering from MLIR because some of the loads/stores/allocations will now be moved inside of the init and combiner regions. This alternate code generation is requested using a new attribute to omp.wsloop and omp.parallel.
Existing lowerings from MLIR are unaffected (these will continue to use by-value argument passing).
Flang will continue to use by-value argument passing for trivial types unless a (hidden) command line argument is supplied. Non-trivial types will always use the by-ref lowering.
Array reductions are not ready yet (but are coming very soon). In the meantime, this is tested by forcing existing reductions to use by-ref.
Commit series for by-ref OpenMP reductions 3/3
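
As a rough illustration of the difference described in this commit message, the following standalone C++ sketch contrasts the two calling conventions for an integer `+` reduction. It is an analogy only: the function names are hypothetical, plain pointers stand in for `!fir.ref<i32>`, and the real implementation emits MLIR/LLVM-IR rather than C++.

```cpp
#include <cassert>
#include <memory>

// By-value convention: the init region yields the neutral value and the
// combiner takes two values and yields their combination.
int initByValue() { return 0; }
int combineByValue(int lhs, int rhs) { return lhs + rhs; }

// By-reference convention: the init region performs the allocation itself,
// stores the neutral value, and yields the reference.
std::unique_ptr<int> initByRef() { return std::make_unique<int>(0); }

// The combiner loads both operands, combines them, stores the result back
// through the first argument, and yields that same reference.
int *combineByRef(int *lhs, const int *rhs) {
  int a = *lhs; // corresponds to loading the first block argument
  int b = *rhs; // corresponds to loading the second block argument
  *lhs = a + b; // corresponds to storing the result to the first argument
  return lhs;   // corresponds to yielding the first argument
}
```

Because the by-reference init owns the allocation, the same shape could later carry arrays or derived types, which is the motivation given above.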