-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][MLIR][OpenMP] make reduction by-ref toggled per variable #92244
Conversation
@llvm/pr-subscribers-flang-openmp @llvm/pr-subscribers-mlir-openmp Author: Tom Eccles (tblah) ChangesFixes #88935 Toggling reduction by-ref broke when multiple reduction clauses were used. Decisions made for the by-ref status for later clauses could then invalidate decisions for earlier clauses. For example,
The first clause would choose by value reduction and generate by-value reduction regions, but then after this the second clause would force by-ref to support the array argument. But by the time the second clause is processed, the first clause has already had the wrong kind of reduction regions generated. This is solved by toggling whether a variable should be reduced by reference per variable. In the above example, this allows only Patch is 116.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92244.diff 47 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 0ea87314d571f..a57b96e365999 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -933,20 +933,18 @@ bool ClauseProcessor::processReduction(
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &clause,
const Fortran::parser::CharBlock &) {
- // Use local lists of reductions to prevent variables from other
- // already-processed reduction clauses from impacting this reduction.
- // For example, the whole `reductionVars` array is queried to decide
- // whether to do the reduction byref.
llvm::SmallVector<mlir::Value> reductionVars;
+ llvm::SmallVector<bool> reduceVarByRef;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
- rp.addDeclareReduction(currentLocation, converter, clause,
- reductionVars, reductionDeclSymbols,
- outReductionSyms ? &reductionSyms : nullptr);
+ rp.addDeclareReduction(
+ currentLocation, converter, clause, reductionVars, reduceVarByRef,
+ reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
+ llvm::copy(reduceVarByRef, std::back_inserter(result.reduceVarByRef));
llvm::copy(reductionDeclSymbols,
std::back_inserter(result.reductionDeclSymbols));
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f21acdd64d7c3..aaf0c6501ab20 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1007,8 +1007,6 @@ static void genParallelClauses(
if (processReduction) {
cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
- if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
- clauseOps.reductionByRefAttr = converter.getFirOpBuilder().getUnitAttr();
}
}
@@ -1200,7 +1198,6 @@ static void genWsloopClauses(
mlir::Location loc, mlir::omp::WsloopClauseOps &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processNowait(clauseOps);
cp.processOrdered(clauseOps);
@@ -1208,9 +1205,6 @@ static void genWsloopClauses(
cp.processSchedule(stmtCtx, clauseOps);
// TODO Support delayed privatization.
- if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
- clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
-
cp.processTODO<clause::Allocate, clause::Linear, clause::Order>(
loc, llvm::omp::Directive::OMPD_do);
}
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index b3f08eb81c799..689f3adc0a429 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -657,25 +657,17 @@ mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
return decl;
}
-// TODO: By-ref vs by-val reductions are currently toggled for the whole
-// operation (possibly effecting multiple reduction variables).
-// This could cause a problem with openmp target reductions because
-// by-ref trivial types may not be supported.
-bool ReductionProcessor::doReductionByRef(
- const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
- if (reductionVars.empty())
- return false;
+static bool doReductionByRef(mlir::Value reductionVar) {
if (forceByrefReduction)
return true;
- for (mlir::Value reductionVar : reductionVars) {
- if (auto declare =
- mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
- reductionVar = declare.getMemref();
+ if (auto declare =
+ mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
+ reductionVar = declare.getMemref();
+
+ if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
+ return true;
- if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
- return true;
- }
return false;
}
@@ -684,6 +676,7 @@ void ReductionProcessor::addDeclareReduction(
Fortran::lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*reductionSymbols) {
@@ -764,8 +757,8 @@ void ReductionProcessor::addDeclareReduction(
"reduction input var is a reference");
reductionVars.push_back(symVal);
+ reduceVarByRef.push_back(doReductionByRef(symVal));
}
- const bool isByRef = doReductionByRef(reductionVars);
if (const auto &redDefinedOp =
std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
@@ -787,7 +780,7 @@ void ReductionProcessor::addDeclareReduction(
break;
}
- for (mlir::Value symVal : reductionVars) {
+ for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
const auto &kindMap = firOpBuilder.getKindMap();
if (mlir::isa<fir::LogicalType>(redType.getEleTy()))
@@ -811,7 +804,7 @@ void ReductionProcessor::addDeclareReduction(
*reductionIntrinsic)) {
ReductionProcessor::ReductionIdentifier redId =
ReductionProcessor::getReductionType(*reductionIntrinsic);
- for (mlir::Value symVal : reductionVars) {
+ for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
if (!redType.getEleTy().isIntOrIndexOrFloat())
TODO(currentLocation,
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 8b116a4c52041..95d77c8154415 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -73,9 +73,6 @@ class ReductionProcessor {
static const Fortran::semantics::SourceName
getRealName(const omp::clause::ProcedureDesignator &pd);
- static bool
- doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);
-
static std::string getReductionName(llvm::StringRef name,
const fir::KindMapping &kindMap,
mlir::Type ty, bool isByRef);
@@ -128,6 +125,7 @@ class ReductionProcessor {
Fortran::lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*reductionSymbols = nullptr);
diff --git a/flang/test/Fir/omp-reduction-embox-codegen.fir b/flang/test/Fir/omp-reduction-embox-codegen.fir
index 7602012ebc5c9..1645e1a407ad4 100644
--- a/flang/test/Fir/omp-reduction-embox-codegen.fir
+++ b/flang/test/Fir/omp-reduction-embox-codegen.fir
@@ -25,7 +25,7 @@ omp.declare_reduction @test_reduction : !fir.ref<!fir.box<i32>> init {
func.func @_QQmain() attributes {fir.bindc_name = "reduce"} {
%4 = fir.alloca !fir.box<i32>
- omp.parallel byref reduction(@test_reduction %4 -> %arg0 : !fir.ref<!fir.box<i32>>) {
+ omp.parallel reduction(byref @test_reduction %4 -> %arg0 : !fir.ref<!fir.box<i32>>) {
omp.terminator
}
return
diff --git a/flang/test/Lower/OpenMP/default-clause-byref.f90 b/flang/test/Lower/OpenMP/default-clause-byref.f90
index 7cc2bc2e0c710..7893c4d7d5732 100644
--- a/flang/test/Lower/OpenMP/default-clause-byref.f90
+++ b/flang/test/Lower/OpenMP/default-clause-byref.f90
@@ -351,7 +351,7 @@ subroutine skipped_default_clause_checks()
type(it)::iii
!CHECK: omp.parallel {
-!CHECK: omp.wsloop byref reduction(@min_byref_i32 %[[VAL_Z_DECLARE]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK: omp.wsloop reduction(byref @min_byref_i32 %[[VAL_Z_DECLARE]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
!CHECK-NEXT: omp.loop_nest (%[[ARG:.*]]) {{.*}} {
!CHECK: omp.yield
!CHECK: }
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
index 49d1142ea4b6a..72e91680a4310 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
@@ -26,5 +26,5 @@ subroutine red_and_delayed_private
! CHECK-LABEL: _QPred_and_delayed_private
! CHECK: omp.parallel
-! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90 b/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90
index 2a1d26407b27e..7347d9324feac 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90
@@ -40,7 +40,7 @@
!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK: omp.parallel byref reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK: omp.parallel reduction(byref @[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
@@ -65,7 +65,7 @@ subroutine simple_int_add
!CHECK: %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
-!CHECK: omp.parallel byref reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK: omp.parallel reduction(byref @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<f32>
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
@@ -94,7 +94,7 @@ subroutine simple_real_add
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK: omp.parallel byref reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
+!CHECK: omp.parallel reduction(byref @[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, byref @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
!CHECK: %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90 b/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90
index 28216ef91c3a3..f6d3b0b73f738 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90
@@ -95,7 +95,7 @@ program reduce
! CHECK: %[[VAL_14:.*]] = arith.constant 0 : i32
! CHECK: %[[VAL_15:.*]] = arith.constant 10 : i32
! CHECK: %[[VAL_16:.*]] = arith.constant 1 : i32
-! CHECK: omp.wsloop byref reduction(@add_reduction_byref_box_heap_Uxi32 %[[VAL_3]]#0 -> %[[VAL_17:.*]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) {
+! CHECK: omp.wsloop reduction(byref @add_reduction_byref_box_heap_Uxi32 %[[VAL_3]]#0 -> %[[VAL_17:.*]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) {
! CHECK-NEXT: omp.loop_nest (%[[VAL_18:.*]]) : i32 = (%[[VAL_14]]) to (%[[VAL_15]]) inclusive step (%[[VAL_16]]) {
! CHECK: %[[VAL_19:.*]]:2 = hlfir.declare %[[VAL_17]] {fortran_attrs = {{.*}}<allocatable>, uniq_name = "_QFEr"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
! CHECK: fir.store %[[VAL_18]] to %[[VAL_13]]#1 : !fir.ref<i32>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90
index 8202e6d897157..b44fe4c1f4cc2 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90
@@ -77,7 +77,7 @@ program reduce
! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_5]]) {uniq_name = "_QFEi"} : (!fir.ref<!fir.array<3x2xi32>>, !fir.shapeshift<2>) -> (!fir.box<!fir.array<3x2xi32>>, !fir.ref<!fir.array<3x2xi32>>)
! CHECK: %[[VAL_7:.*]] = fir.alloca !fir.box<!fir.array<3x2xi32>>
! CHECK: fir.store %[[VAL_6]]#0 to %[[VAL_7]] : !fir.ref<!fir.box<!fir.array<3x2xi32>>>
-! CHECK: omp.parallel byref reduction(@add_reduction_byref_box_3x2xi32 %[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref<!fir.box<!fir.array<3x2xi32>>>) {
+! CHECK: omp.parallel reduction(byref @add_reduction_byref_box_3x2xi32 %[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref<!fir.box<!fir.array<3x2xi32>>>) {
! CHECK: %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_8]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3x2xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3x2xi32>>>, !fir.ref<!fir.box<!fir.array<3x2xi32>>>)
! CHECK: %[[VAL_10:.*]] = arith.constant 3 : i32
! CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_9]]#0 : !fir.ref<!fir.box<!fir.array<3x2xi32>>>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array.f90
index 34f4ee0a9eb3a..60b21c9b1ebbe 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array.f90
@@ -70,7 +70,7 @@ program reduce
! CHECK: %[[VAL_4:.*]] = fir.embox %[[VAL_3]]#0(%[[VAL_2]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
! CHECK: %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
! CHECK: fir.store %[[VAL_4]] to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
-! CHECK: omp.parallel byref reduction(@add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
+! CHECK: omp.parallel reduction(byref @add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3xi32>>>, !fir.ref<!fir.box<!fir.array<3xi32>>>)
! CHECK: %[[VAL_8:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_9:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array2.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
index aa14092554eda..5d4c86d1d76e8 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
@@ -69,7 +69,7 @@ program reduce
! CHECK: %[[VAL_4:.*]] = fir.embox %[[VAL_3]]#0(%[[VAL_2]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
! CHECK: %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
! CHECK: fir.store %[[VAL_4]] to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
-! CHECK: omp.parallel byref reduction(@add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
+! CHECK: omp.parallel reduction(byref @add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3xi32>>>, !fir.ref<!fir.box<!fir.array<3xi32>>>)
! CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
! CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-byref.f90 b/flang/test/Lower/OpenMP/parallel-reduction-byref.f90
index fdcdf0c0b8d95..5685e2c584ace 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-byref.f90
@@ -21,7 +21,7 @@
!CHECK: %[[RED_ACCUM_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[RED_ACCUM_REF]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[C0:[_a-z0-9]+]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[C0]] to %[[RED_ACCUM_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK: omp.parallel byref reduction(@[[REDUCTION_DECLARE]] %[[RED_ACCUM_DECL]]#0 -> %[[PRIVATE_RED:[a-z0-9]+]] : !fir.ref<i32>) {
+!CHECK: omp.parallel reduction(byref @[[REDUCTION_DECLARE]] %[[RED_ACCUM_DECL]]#0 -> %[[PRIVATE_RED:[a-z0-9]+]] : !fir.ref<i32>) {
!CHECK: %[[PRIVATE_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[PRIVATE_RED]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[C1:[_a-z0-9]+]] = arith.constant 1 : i32
!CHECK: hlfir.assign %[[C1]] to %[[PRIVATE_DECL]]#0 : i32, !fir.ref<i32>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction3.f90 b/flang/test/Lower/OpenMP/parallel-reduction3.f90
index 17d805c0d142b..47b743a558b49 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction3.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction3.f90
@@ -74,7 +74,7 @@
! CHECK: %[[VAL_18:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_19:.*]] = fir.alloca !fir.box<!fir.array<?xi32>>
! CHECK: fir.store %[[VAL_12]]#0 to %[[VAL_19]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
-! CHECK: omp.wsloop byref reduction(@add_reduction_byref_box_Uxi32 %[[VAL_19]] -> %[[VAL_20:.*]] : !fir.ref<!fir.box<!fir.array<?xi32>>>) {
+! CHECK: omp.wsloop reduction(byref @add_reduction_byref_box_Uxi32 %[[VAL_19]] -> %[[VAL_20:.*]] : !fir.ref<!fir.box<!fir.array<?xi32>>>) {
! CHECK-NEXT: omp.loop_nest (%[[VAL_21:.*]]) : i32 = (%[[VAL_16]]) to (%[[VAL_17]]) inclusive step (%[[VAL_18]]) {
! CHECK: %[[VAL_22:.*]]:2 = hlfir.declare %[[VAL_20]] {uniq_name = "_QFsEc"} : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> (!fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.array<?xi32>>>)
! CHECK: fir.store %[[VAL_21]] to %[[VAL_15]]#1 : !fir.ref<i32>
diff --git a/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90 b/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90
index 66c80c31917ba..32caac39778de 100644
--- a/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90
@@ -4,7 +4,7 @@
! RUN: flang-new -fc1 -fopenmp -mmlir --force-byref-reduction -emit-hlfir %s -o - | FileCheck %s
! CHECK: omp.parallel {
-! CHECK: omp.wsloop byref reduction(@add_reduction_byref_i32
+! CHECK: omp.wsloop reduction(byref @add...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Tom Eccles (tblah) ChangesFixes #88935 Toggling reduction by-ref broke when multiple reduction clauses were used. Decisions made for the by-ref status for later clauses could then invalidate decisions for earlier clauses. For example,
The first clause would choose by value reduction and generate by-value reduction regions, but then after this the second clause would force by-ref to support the array argument. But by the time the second clause is processed, the first clause has already had the wrong kind of reduction regions generated. This is solved by toggling whether a variable should be reduced by reference per variable. In the above example, this allows only Patch is 116.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92244.diff 47 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 0ea87314d571f..a57b96e365999 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -933,20 +933,18 @@ bool ClauseProcessor::processReduction(
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &clause,
const Fortran::parser::CharBlock &) {
- // Use local lists of reductions to prevent variables from other
- // already-processed reduction clauses from impacting this reduction.
- // For example, the whole `reductionVars` array is queried to decide
- // whether to do the reduction byref.
llvm::SmallVector<mlir::Value> reductionVars;
+ llvm::SmallVector<bool> reduceVarByRef;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
- rp.addDeclareReduction(currentLocation, converter, clause,
- reductionVars, reductionDeclSymbols,
- outReductionSyms ? &reductionSyms : nullptr);
+ rp.addDeclareReduction(
+ currentLocation, converter, clause, reductionVars, reduceVarByRef,
+ reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
+ llvm::copy(reduceVarByRef, std::back_inserter(result.reduceVarByRef));
llvm::copy(reductionDeclSymbols,
std::back_inserter(result.reductionDeclSymbols));
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f21acdd64d7c3..aaf0c6501ab20 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1007,8 +1007,6 @@ static void genParallelClauses(
if (processReduction) {
cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
- if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
- clauseOps.reductionByRefAttr = converter.getFirOpBuilder().getUnitAttr();
}
}
@@ -1200,7 +1198,6 @@ static void genWsloopClauses(
mlir::Location loc, mlir::omp::WsloopClauseOps &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processNowait(clauseOps);
cp.processOrdered(clauseOps);
@@ -1208,9 +1205,6 @@ static void genWsloopClauses(
cp.processSchedule(stmtCtx, clauseOps);
// TODO Support delayed privatization.
- if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
- clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
-
cp.processTODO<clause::Allocate, clause::Linear, clause::Order>(
loc, llvm::omp::Directive::OMPD_do);
}
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index b3f08eb81c799..689f3adc0a429 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -657,25 +657,17 @@ mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
return decl;
}
-// TODO: By-ref vs by-val reductions are currently toggled for the whole
-// operation (possibly effecting multiple reduction variables).
-// This could cause a problem with openmp target reductions because
-// by-ref trivial types may not be supported.
-bool ReductionProcessor::doReductionByRef(
- const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
- if (reductionVars.empty())
- return false;
+static bool doReductionByRef(mlir::Value reductionVar) {
if (forceByrefReduction)
return true;
- for (mlir::Value reductionVar : reductionVars) {
- if (auto declare =
- mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
- reductionVar = declare.getMemref();
+ if (auto declare =
+ mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
+ reductionVar = declare.getMemref();
+
+ if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
+ return true;
- if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
- return true;
- }
return false;
}
@@ -684,6 +676,7 @@ void ReductionProcessor::addDeclareReduction(
Fortran::lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*reductionSymbols) {
@@ -764,8 +757,8 @@ void ReductionProcessor::addDeclareReduction(
"reduction input var is a reference");
reductionVars.push_back(symVal);
+ reduceVarByRef.push_back(doReductionByRef(symVal));
}
- const bool isByRef = doReductionByRef(reductionVars);
if (const auto &redDefinedOp =
std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
@@ -787,7 +780,7 @@ void ReductionProcessor::addDeclareReduction(
break;
}
- for (mlir::Value symVal : reductionVars) {
+ for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
const auto &kindMap = firOpBuilder.getKindMap();
if (mlir::isa<fir::LogicalType>(redType.getEleTy()))
@@ -811,7 +804,7 @@ void ReductionProcessor::addDeclareReduction(
*reductionIntrinsic)) {
ReductionProcessor::ReductionIdentifier redId =
ReductionProcessor::getReductionType(*reductionIntrinsic);
- for (mlir::Value symVal : reductionVars) {
+ for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
if (!redType.getEleTy().isIntOrIndexOrFloat())
TODO(currentLocation,
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 8b116a4c52041..95d77c8154415 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -73,9 +73,6 @@ class ReductionProcessor {
static const Fortran::semantics::SourceName
getRealName(const omp::clause::ProcedureDesignator &pd);
- static bool
- doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);
-
static std::string getReductionName(llvm::StringRef name,
const fir::KindMapping &kindMap,
mlir::Type ty, bool isByRef);
@@ -128,6 +125,7 @@ class ReductionProcessor {
Fortran::lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*reductionSymbols = nullptr);
diff --git a/flang/test/Fir/omp-reduction-embox-codegen.fir b/flang/test/Fir/omp-reduction-embox-codegen.fir
index 7602012ebc5c9..1645e1a407ad4 100644
--- a/flang/test/Fir/omp-reduction-embox-codegen.fir
+++ b/flang/test/Fir/omp-reduction-embox-codegen.fir
@@ -25,7 +25,7 @@ omp.declare_reduction @test_reduction : !fir.ref<!fir.box<i32>> init {
func.func @_QQmain() attributes {fir.bindc_name = "reduce"} {
%4 = fir.alloca !fir.box<i32>
- omp.parallel byref reduction(@test_reduction %4 -> %arg0 : !fir.ref<!fir.box<i32>>) {
+ omp.parallel reduction(byref @test_reduction %4 -> %arg0 : !fir.ref<!fir.box<i32>>) {
omp.terminator
}
return
diff --git a/flang/test/Lower/OpenMP/default-clause-byref.f90 b/flang/test/Lower/OpenMP/default-clause-byref.f90
index 7cc2bc2e0c710..7893c4d7d5732 100644
--- a/flang/test/Lower/OpenMP/default-clause-byref.f90
+++ b/flang/test/Lower/OpenMP/default-clause-byref.f90
@@ -351,7 +351,7 @@ subroutine skipped_default_clause_checks()
type(it)::iii
!CHECK: omp.parallel {
-!CHECK: omp.wsloop byref reduction(@min_byref_i32 %[[VAL_Z_DECLARE]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK: omp.wsloop reduction(byref @min_byref_i32 %[[VAL_Z_DECLARE]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
!CHECK-NEXT: omp.loop_nest (%[[ARG:.*]]) {{.*}} {
!CHECK: omp.yield
!CHECK: }
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
index 49d1142ea4b6a..72e91680a4310 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
@@ -26,5 +26,5 @@ subroutine red_and_delayed_private
! CHECK-LABEL: _QPred_and_delayed_private
! CHECK: omp.parallel
-! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90 b/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90
index 2a1d26407b27e..7347d9324feac 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90
@@ -40,7 +40,7 @@
!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK: omp.parallel byref reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK: omp.parallel reduction(byref @[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
@@ -65,7 +65,7 @@ subroutine simple_int_add
!CHECK: %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
-!CHECK: omp.parallel byref reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK: omp.parallel reduction(byref @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<f32>
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
@@ -94,7 +94,7 @@ subroutine simple_real_add
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK: omp.parallel byref reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
+!CHECK: omp.parallel reduction(byref @[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, byref @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
!CHECK: %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90 b/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90
index 28216ef91c3a3..f6d3b0b73f738 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90
@@ -95,7 +95,7 @@ program reduce
! CHECK: %[[VAL_14:.*]] = arith.constant 0 : i32
! CHECK: %[[VAL_15:.*]] = arith.constant 10 : i32
! CHECK: %[[VAL_16:.*]] = arith.constant 1 : i32
-! CHECK: omp.wsloop byref reduction(@add_reduction_byref_box_heap_Uxi32 %[[VAL_3]]#0 -> %[[VAL_17:.*]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) {
+! CHECK: omp.wsloop reduction(byref @add_reduction_byref_box_heap_Uxi32 %[[VAL_3]]#0 -> %[[VAL_17:.*]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) {
! CHECK-NEXT: omp.loop_nest (%[[VAL_18:.*]]) : i32 = (%[[VAL_14]]) to (%[[VAL_15]]) inclusive step (%[[VAL_16]]) {
! CHECK: %[[VAL_19:.*]]:2 = hlfir.declare %[[VAL_17]] {fortran_attrs = {{.*}}<allocatable>, uniq_name = "_QFEr"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
! CHECK: fir.store %[[VAL_18]] to %[[VAL_13]]#1 : !fir.ref<i32>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90
index 8202e6d897157..b44fe4c1f4cc2 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90
@@ -77,7 +77,7 @@ program reduce
! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_5]]) {uniq_name = "_QFEi"} : (!fir.ref<!fir.array<3x2xi32>>, !fir.shapeshift<2>) -> (!fir.box<!fir.array<3x2xi32>>, !fir.ref<!fir.array<3x2xi32>>)
! CHECK: %[[VAL_7:.*]] = fir.alloca !fir.box<!fir.array<3x2xi32>>
! CHECK: fir.store %[[VAL_6]]#0 to %[[VAL_7]] : !fir.ref<!fir.box<!fir.array<3x2xi32>>>
-! CHECK: omp.parallel byref reduction(@add_reduction_byref_box_3x2xi32 %[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref<!fir.box<!fir.array<3x2xi32>>>) {
+! CHECK: omp.parallel reduction(byref @add_reduction_byref_box_3x2xi32 %[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref<!fir.box<!fir.array<3x2xi32>>>) {
! CHECK: %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_8]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3x2xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3x2xi32>>>, !fir.ref<!fir.box<!fir.array<3x2xi32>>>)
! CHECK: %[[VAL_10:.*]] = arith.constant 3 : i32
! CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_9]]#0 : !fir.ref<!fir.box<!fir.array<3x2xi32>>>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array.f90
index 34f4ee0a9eb3a..60b21c9b1ebbe 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array.f90
@@ -70,7 +70,7 @@ program reduce
! CHECK: %[[VAL_4:.*]] = fir.embox %[[VAL_3]]#0(%[[VAL_2]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
! CHECK: %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
! CHECK: fir.store %[[VAL_4]] to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
-! CHECK: omp.parallel byref reduction(@add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
+! CHECK: omp.parallel reduction(byref @add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3xi32>>>, !fir.ref<!fir.box<!fir.array<3xi32>>>)
! CHECK: %[[VAL_8:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_9:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array2.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
index aa14092554eda..5d4c86d1d76e8 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
@@ -69,7 +69,7 @@ program reduce
! CHECK: %[[VAL_4:.*]] = fir.embox %[[VAL_3]]#0(%[[VAL_2]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
! CHECK: %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
! CHECK: fir.store %[[VAL_4]] to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
-! CHECK: omp.parallel byref reduction(@add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
+! CHECK: omp.parallel reduction(byref @add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3xi32>>>, !fir.ref<!fir.box<!fir.array<3xi32>>>)
! CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
! CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-byref.f90 b/flang/test/Lower/OpenMP/parallel-reduction-byref.f90
index fdcdf0c0b8d95..5685e2c584ace 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-byref.f90
@@ -21,7 +21,7 @@
!CHECK: %[[RED_ACCUM_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[RED_ACCUM_REF]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[C0:[_a-z0-9]+]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[C0]] to %[[RED_ACCUM_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK: omp.parallel byref reduction(@[[REDUCTION_DECLARE]] %[[RED_ACCUM_DECL]]#0 -> %[[PRIVATE_RED:[a-z0-9]+]] : !fir.ref<i32>) {
+!CHECK: omp.parallel reduction(byref @[[REDUCTION_DECLARE]] %[[RED_ACCUM_DECL]]#0 -> %[[PRIVATE_RED:[a-z0-9]+]] : !fir.ref<i32>) {
!CHECK: %[[PRIVATE_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[PRIVATE_RED]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[C1:[_a-z0-9]+]] = arith.constant 1 : i32
!CHECK: hlfir.assign %[[C1]] to %[[PRIVATE_DECL]]#0 : i32, !fir.ref<i32>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction3.f90 b/flang/test/Lower/OpenMP/parallel-reduction3.f90
index 17d805c0d142b..47b743a558b49 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction3.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction3.f90
@@ -74,7 +74,7 @@
! CHECK: %[[VAL_18:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_19:.*]] = fir.alloca !fir.box<!fir.array<?xi32>>
! CHECK: fir.store %[[VAL_12]]#0 to %[[VAL_19]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
-! CHECK: omp.wsloop byref reduction(@add_reduction_byref_box_Uxi32 %[[VAL_19]] -> %[[VAL_20:.*]] : !fir.ref<!fir.box<!fir.array<?xi32>>>) {
+! CHECK: omp.wsloop reduction(byref @add_reduction_byref_box_Uxi32 %[[VAL_19]] -> %[[VAL_20:.*]] : !fir.ref<!fir.box<!fir.array<?xi32>>>) {
! CHECK-NEXT: omp.loop_nest (%[[VAL_21:.*]]) : i32 = (%[[VAL_16]]) to (%[[VAL_17]]) inclusive step (%[[VAL_18]]) {
! CHECK: %[[VAL_22:.*]]:2 = hlfir.declare %[[VAL_20]] {uniq_name = "_QFsEc"} : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> (!fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.array<?xi32>>>)
! CHECK: fir.store %[[VAL_21]] to %[[VAL_15]]#1 : !fir.ref<i32>
diff --git a/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90 b/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90
index 66c80c31917ba..32caac39778de 100644
--- a/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90
@@ -4,7 +4,7 @@
! RUN: flang-new -fc1 -fopenmp -mmlir --force-byref-reduction -emit-hlfir %s -o - | FileCheck %s
! CHECK: omp.parallel {
-! CHECK: omp.wsloop byref reduction(@add_reduction_byref_i32
+! CHECK: omp.wsloop reduction(byref @add...
[truncated]
|
4880fd3
to
57abf53
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks mostly OK to me. Have some minor comments inline.
Would adding the byref
attribute to the Reduction Declaration Operation work?
threads complete. | ||
accumulator variables in `reduction_vars`, symbols referring to reduction | ||
declarations in the `reductions` attribute, and whether the reduction | ||
variable should be passed into the redution region by value or by reference |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
variable should be passed into the redution region by value or by reference | |
variable should be passed into the reduction region by value or by reference |
auto reductionVarsByRef = getReductionVarsByref(); | ||
if (reductionVarsByRef && | ||
reductionVarsByRef->size() != getReductionVars().size()) | ||
return emitOpError() | ||
<< "expected as many reduction variable by reference attributes " | ||
"as reduction variables"; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be in verifyReductionVarList
?
Is there a test for this error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I couldn't think of a way to test for this error because we can't write text-form IR like this. I would have to create the operation with C++ like in a unit test, but I didn't know how to catch an error from a unit test.
llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionVarsByref()); | ||
assert(isByRef.size() == wsloopOp.getNumReductionVars()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this guaranteed by the verifier?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I intended this more like a code comment. Would it be clearer to remove?
@@ -467,6 +468,10 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { | |||
wsloopOp.setReductionsAttr( | |||
ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); | |||
wsloopOp.getReductionVarsMutable().append(reductionVariables); | |||
llvm::SmallVector<bool> byRefVec; | |||
byRefVec.resize(reductionVariables.size(), false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Add a comment why this is all false
.
auto privateByRef = DenseBoolArrayAttr::get(parser.getContext(), {}); | ||
if (failed(parseClauseWithRegionArgs(parser, region, privateVarOperands, | ||
privateVarsTypes, privatizerSymbols, | ||
regionPrivateArgs))) | ||
privateVarsTypes, privateByRef, | ||
privatizerSymbols, regionPrivateArgs))) | ||
return failure(); | ||
if (llvm::any_of(privateByRef.asArrayRef(), | ||
[](bool byref) { return byref; })) { | ||
parser.emitError(parser.getCurrentLocation(), | ||
"private clause cannot have byref attributes"); | ||
return failure(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this change necessary?
Do we have a test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is necessary because there is no toggle for byref on variable privatization. One alternative would be to silently ignore it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is that required because parseClauseWithRegionArgs
is common for reduction and private?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes
auto reductionVarsByRef = getReductionVarsByref(); | ||
if (reductionVarsByRef && | ||
reductionVarsByRef->size() != getReductionVars().size()) | ||
return emitOpError() | ||
<< "expected as many reduction variable by reference attributes " | ||
"as reduction variables"; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be in verifyReductionVarList?
Is there a test for this error?
Thanks for taking a look. Yes I think adding it to the ReductionDeclareOp would also work |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but wait for approval from Kiran
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, this looks good to me as well!
Fixes llvm#88935 Toggling reduction by-ref broke when multiple reduction clauses were used. Decisions made for the by-ref status for later clauses could then invalidate decisions for earlier clauses. For example, ``` reduction(+:scalar,scalar2) reduction(+:array) ``` The first clause would choose by value reduction and generate by-value reduction regions, but then after this the second clause would force by-ref to support the array argument. But by the time the second clause is processed, the first clause has already had the wrong kind of reduction regions generated. This is solved by toggling whether a variable should be reduced by reference per variable. In the above example, this allows only `array` to be reduced by ref.
e575e3c
to
ce7af67
Compare
Fixes #88935
Toggling reduction by-ref broke when multiple reduction clauses were used. Decisions made for the by-ref status for later clauses could then invalidate decisions for earlier clauses. For example,
The first clause would choose by value reduction and generate by-value reduction regions, but then after this the second clause would force by-ref to support the array argument. But by the time the second clause is processed, the first clause has already had the wrong kind of reduction regions generated.
This is solved by toggling whether a variable should be reduced by reference per variable. In the above example, this allows only
array
to be reduced by ref.