Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][MLIR][OpenMP] make reduction by-ref toggled per variable #92244

Merged
merged 6 commits into from
May 16, 2024

Conversation

tblah
Copy link
Contributor

@tblah tblah commented May 15, 2024

Fixes #88935

Toggling reduction by-ref broke when multiple reduction clauses were used. Decisions made for the by-ref status for later clauses could then invalidate decisions for earlier clauses. For example,

reduction(+:scalar,scalar2) reduction(+:array)

The first clause would choose by value reduction and generate by-value reduction regions, but then after this the second clause would force by-ref to support the array argument. But by the time the second clause is processed, the first clause has already had the wrong kind of reduction regions generated.

This is solved by toggling whether a variable should be reduced by reference per variable. In the above example, this allows only array to be reduced by ref.

@llvmbot
Copy link

llvmbot commented May 15, 2024

@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-mlir-openmp

Author: Tom Eccles (tblah)

Changes

Fixes #88935

Toggling reduction by-ref broke when multiple reduction clauses were used. Decisions made for the by-ref status for later clauses could then invalidate decisions for earlier clauses. For example,

reduction(+:scalar,scalar2) reduction(+:array)

The first clause would choose by value reduction and generate by-value reduction regions, but then after this the second clause would force by-ref to support the array argument. But by the time the second clause is processed, the first clause has already had the wrong kind of reduction regions generated.

This is solved by toggling whether a variable should be reduced by reference per variable. In the above example, this allows only array to be reduced by ref.


Patch is 116.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92244.diff

47 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+5-7)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (-6)
  • (modified) flang/lib/Lower/OpenMP/ReductionProcessor.cpp (+11-18)
  • (modified) flang/lib/Lower/OpenMP/ReductionProcessor.h (+1-3)
  • (modified) flang/test/Fir/omp-reduction-embox-codegen.fir (+1-1)
  • (modified) flang/test/Lower/OpenMP/default-clause-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-array.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-array2.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction3.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-byref.f90 (+7-7)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-allocatable.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array-assumed-shape.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array2.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-iand-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-ieor-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-ior-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-and-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-neqv-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-or-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-2-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-min-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-mul-byref.f90 (+7-7)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-multiple-clauses.f90 (+164)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+1-1)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+10-8)
  • (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+9-3)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+27-32)
  • (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+5)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+84-36)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+26-17)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+38)
  • (modified) mlir/test/Target/LLVMIR/openmp-parallel-reduction-cleanup.mlir (+1-1)
  • (modified) mlir/test/Target/LLVMIR/openmp-reduction-byref.mlir (+1-1)
  • (modified) mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir (+1-1)
  • (modified) mlir/test/Target/LLVMIR/openmp-wsloop-reduction-cleanup.mlir (+1-1)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 0ea87314d571f..a57b96e365999 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -933,20 +933,18 @@ bool ClauseProcessor::processReduction(
   return findRepeatableClause<omp::clause::Reduction>(
       [&](const omp::clause::Reduction &clause,
           const Fortran::parser::CharBlock &) {
-        // Use local lists of reductions to prevent variables from other
-        // already-processed reduction clauses from impacting this reduction.
-        // For example, the whole `reductionVars` array is queried to decide
-        // whether to do the reduction byref.
         llvm::SmallVector<mlir::Value> reductionVars;
+        llvm::SmallVector<bool> reduceVarByRef;
         llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
         llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
         ReductionProcessor rp;
-        rp.addDeclareReduction(currentLocation, converter, clause,
-                               reductionVars, reductionDeclSymbols,
-                               outReductionSyms ? &reductionSyms : nullptr);
+        rp.addDeclareReduction(
+            currentLocation, converter, clause, reductionVars, reduceVarByRef,
+            reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr);
 
         // Copy local lists into the output.
         llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
+        llvm::copy(reduceVarByRef, std::back_inserter(result.reduceVarByRef));
         llvm::copy(reductionDeclSymbols,
                    std::back_inserter(result.reductionDeclSymbols));
 
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f21acdd64d7c3..aaf0c6501ab20 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1007,8 +1007,6 @@ static void genParallelClauses(
 
   if (processReduction) {
     cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
-    if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
-      clauseOps.reductionByRefAttr = converter.getFirOpBuilder().getUnitAttr();
   }
 }
 
@@ -1200,7 +1198,6 @@ static void genWsloopClauses(
     mlir::Location loc, mlir::omp::WsloopClauseOps &clauseOps,
     llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processNowait(clauseOps);
   cp.processOrdered(clauseOps);
@@ -1208,9 +1205,6 @@ static void genWsloopClauses(
   cp.processSchedule(stmtCtx, clauseOps);
   // TODO Support delayed privatization.
 
-  if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
-    clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
-
   cp.processTODO<clause::Allocate, clause::Linear, clause::Order>(
       loc, llvm::omp::Directive::OMPD_do);
 }
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index b3f08eb81c799..689f3adc0a429 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -657,25 +657,17 @@ mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
   return decl;
 }
 
-// TODO: By-ref vs by-val reductions are currently toggled for the whole
-//       operation (possibly effecting multiple reduction variables).
-//       This could cause a problem with openmp target reductions because
-//       by-ref trivial types may not be supported.
-bool ReductionProcessor::doReductionByRef(
-    const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
-  if (reductionVars.empty())
-    return false;
+static bool doReductionByRef(mlir::Value reductionVar) {
   if (forceByrefReduction)
     return true;
 
-  for (mlir::Value reductionVar : reductionVars) {
-    if (auto declare =
-            mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
-      reductionVar = declare.getMemref();
+  if (auto declare =
+          mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
+    reductionVar = declare.getMemref();
+
+  if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
+    return true;
 
-    if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
-      return true;
-  }
   return false;
 }
 
@@ -684,6 +676,7 @@ void ReductionProcessor::addDeclareReduction(
     Fortran::lower::AbstractConverter &converter,
     const omp::clause::Reduction &reduction,
     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+    llvm::SmallVectorImpl<bool> &reduceVarByRef,
     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
         *reductionSymbols) {
@@ -764,8 +757,8 @@ void ReductionProcessor::addDeclareReduction(
            "reduction input var is a reference");
 
     reductionVars.push_back(symVal);
+    reduceVarByRef.push_back(doReductionByRef(symVal));
   }
-  const bool isByRef = doReductionByRef(reductionVars);
 
   if (const auto &redDefinedOp =
           std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
@@ -787,7 +780,7 @@ void ReductionProcessor::addDeclareReduction(
       break;
     }
 
-    for (mlir::Value symVal : reductionVars) {
+    for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
       auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
       const auto &kindMap = firOpBuilder.getKindMap();
       if (mlir::isa<fir::LogicalType>(redType.getEleTy()))
@@ -811,7 +804,7 @@ void ReductionProcessor::addDeclareReduction(
             *reductionIntrinsic)) {
       ReductionProcessor::ReductionIdentifier redId =
           ReductionProcessor::getReductionType(*reductionIntrinsic);
-      for (mlir::Value symVal : reductionVars) {
+      for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
         auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
         if (!redType.getEleTy().isIntOrIndexOrFloat())
           TODO(currentLocation,
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 8b116a4c52041..95d77c8154415 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -73,9 +73,6 @@ class ReductionProcessor {
   static const Fortran::semantics::SourceName
   getRealName(const omp::clause::ProcedureDesignator &pd);
 
-  static bool
-  doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);
-
   static std::string getReductionName(llvm::StringRef name,
                                       const fir::KindMapping &kindMap,
                                       mlir::Type ty, bool isByRef);
@@ -128,6 +125,7 @@ class ReductionProcessor {
       Fortran::lower::AbstractConverter &converter,
       const omp::clause::Reduction &reduction,
       llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+      llvm::SmallVectorImpl<bool> &reduceVarByRef,
       llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
       llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
           *reductionSymbols = nullptr);
diff --git a/flang/test/Fir/omp-reduction-embox-codegen.fir b/flang/test/Fir/omp-reduction-embox-codegen.fir
index 7602012ebc5c9..1645e1a407ad4 100644
--- a/flang/test/Fir/omp-reduction-embox-codegen.fir
+++ b/flang/test/Fir/omp-reduction-embox-codegen.fir
@@ -25,7 +25,7 @@ omp.declare_reduction @test_reduction : !fir.ref<!fir.box<i32>> init {
 
 func.func @_QQmain() attributes {fir.bindc_name = "reduce"} {
   %4 = fir.alloca !fir.box<i32>
-  omp.parallel byref reduction(@test_reduction %4 -> %arg0 : !fir.ref<!fir.box<i32>>) {
+  omp.parallel reduction(byref @test_reduction %4 -> %arg0 : !fir.ref<!fir.box<i32>>) {
     omp.terminator
   }
   return
diff --git a/flang/test/Lower/OpenMP/default-clause-byref.f90 b/flang/test/Lower/OpenMP/default-clause-byref.f90
index 7cc2bc2e0c710..7893c4d7d5732 100644
--- a/flang/test/Lower/OpenMP/default-clause-byref.f90
+++ b/flang/test/Lower/OpenMP/default-clause-byref.f90
@@ -351,7 +351,7 @@ subroutine skipped_default_clause_checks()
        type(it)::iii
 
 !CHECK: omp.parallel {
-!CHECK: omp.wsloop byref reduction(@min_byref_i32 %[[VAL_Z_DECLARE]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK: omp.wsloop reduction(byref @min_byref_i32 %[[VAL_Z_DECLARE]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
 !CHECK-NEXT: omp.loop_nest (%[[ARG:.*]]) {{.*}} {
 !CHECK: omp.yield
 !CHECK: }
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
index 49d1142ea4b6a..72e91680a4310 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
@@ -26,5 +26,5 @@ subroutine red_and_delayed_private
 
 ! CHECK-LABEL: _QPred_and_delayed_private
 ! CHECK: omp.parallel
-! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
 ! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90 b/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90
index 2a1d26407b27e..7347d9324feac 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90
@@ -40,7 +40,7 @@
 !CHECK:  %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
 !CHECK:  hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK:  omp.parallel byref reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK:  omp.parallel reduction(byref @[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
 !CHECK:    %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:    %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<i32>
 !CHECK:    %[[I_INCR:.*]] = arith.constant 1 : i32
@@ -65,7 +65,7 @@ subroutine simple_int_add
 !CHECK:  %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 !CHECK:  %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
 !CHECK:  hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
-!CHECK:  omp.parallel byref reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK:  omp.parallel reduction(byref @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
 !CHECK:    %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 !CHECK:    %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<f32>
 !CHECK:    %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
@@ -94,7 +94,7 @@ subroutine simple_real_add
 !CHECK:  hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
 !CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
 !CHECK:  hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK:  omp.parallel byref reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
+!CHECK:  omp.parallel reduction(byref @[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, byref @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
 !CHECK:    %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:    %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 !CHECK:    %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90 b/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90
index 28216ef91c3a3..f6d3b0b73f738 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90
@@ -95,7 +95,7 @@ program reduce
 ! CHECK:             %[[VAL_14:.*]] = arith.constant 0 : i32
 ! CHECK:             %[[VAL_15:.*]] = arith.constant 10 : i32
 ! CHECK:             %[[VAL_16:.*]] = arith.constant 1 : i32
-! CHECK:             omp.wsloop byref reduction(@add_reduction_byref_box_heap_Uxi32 %[[VAL_3]]#0 -> %[[VAL_17:.*]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) {
+! CHECK:             omp.wsloop reduction(byref @add_reduction_byref_box_heap_Uxi32 %[[VAL_3]]#0 -> %[[VAL_17:.*]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) {
 ! CHECK-NEXT:          omp.loop_nest (%[[VAL_18:.*]]) : i32 = (%[[VAL_14]]) to (%[[VAL_15]]) inclusive step (%[[VAL_16]]) {
 ! CHECK:                 %[[VAL_19:.*]]:2 = hlfir.declare %[[VAL_17]] {fortran_attrs = {{.*}}<allocatable>, uniq_name = "_QFEr"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 ! CHECK:                 fir.store %[[VAL_18]] to %[[VAL_13]]#1 : !fir.ref<i32>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90
index 8202e6d897157..b44fe4c1f4cc2 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90
@@ -77,7 +77,7 @@ program reduce
 ! CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_5]]) {uniq_name = "_QFEi"} : (!fir.ref<!fir.array<3x2xi32>>, !fir.shapeshift<2>) -> (!fir.box<!fir.array<3x2xi32>>, !fir.ref<!fir.array<3x2xi32>>)
 ! CHECK:           %[[VAL_7:.*]] = fir.alloca !fir.box<!fir.array<3x2xi32>>
 ! CHECK:           fir.store %[[VAL_6]]#0 to %[[VAL_7]] : !fir.ref<!fir.box<!fir.array<3x2xi32>>>
-! CHECK:           omp.parallel byref reduction(@add_reduction_byref_box_3x2xi32 %[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref<!fir.box<!fir.array<3x2xi32>>>) {
+! CHECK:           omp.parallel reduction(byref @add_reduction_byref_box_3x2xi32 %[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref<!fir.box<!fir.array<3x2xi32>>>) {
 ! CHECK:             %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_8]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3x2xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3x2xi32>>>, !fir.ref<!fir.box<!fir.array<3x2xi32>>>)
 ! CHECK:             %[[VAL_10:.*]] = arith.constant 3 : i32
 ! CHECK:             %[[VAL_11:.*]] = fir.load %[[VAL_9]]#0 : !fir.ref<!fir.box<!fir.array<3x2xi32>>>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array.f90
index 34f4ee0a9eb3a..60b21c9b1ebbe 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array.f90
@@ -70,7 +70,7 @@ program reduce
 ! CHECK:           %[[VAL_4:.*]] = fir.embox %[[VAL_3]]#0(%[[VAL_2]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
 ! CHECK:           %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
 ! CHECK:           fir.store %[[VAL_4]] to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
-! CHECK:           omp.parallel byref reduction(@add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
+! CHECK:           omp.parallel reduction(byref @add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
 ! CHECK:             %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3xi32>>>, !fir.ref<!fir.box<!fir.array<3xi32>>>)
 ! CHECK:             %[[VAL_8:.*]] = arith.constant 1 : i32
 ! CHECK:             %[[VAL_9:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array2.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
index aa14092554eda..5d4c86d1d76e8 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
@@ -69,7 +69,7 @@ program reduce
 ! CHECK:           %[[VAL_4:.*]] = fir.embox %[[VAL_3]]#0(%[[VAL_2]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
 ! CHECK:           %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
 ! CHECK:           fir.store %[[VAL_4]] to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
-! CHECK:           omp.parallel byref reduction(@add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
+! CHECK:           omp.parallel reduction(byref @add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
 ! CHECK:             %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3xi32>>>, !fir.ref<!fir.box<!fir.array<3xi32>>>)
 ! CHECK:             %[[VAL_8:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
 ! CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-byref.f90 b/flang/test/Lower/OpenMP/parallel-reduction-byref.f90
index fdcdf0c0b8d95..5685e2c584ace 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-byref.f90
@@ -21,7 +21,7 @@
 !CHECK:    %[[RED_ACCUM_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[RED_ACCUM_REF]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:    %[[C0:[_a-z0-9]+]] = arith.constant 0 : i32
 !CHECK:    hlfir.assign %[[C0]] to %[[RED_ACCUM_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK:    omp.parallel byref reduction(@[[REDUCTION_DECLARE]] %[[RED_ACCUM_DECL]]#0 -> %[[PRIVATE_RED:[a-z0-9]+]] : !fir.ref<i32>) {
+!CHECK:    omp.parallel reduction(byref @[[REDUCTION_DECLARE]] %[[RED_ACCUM_DECL]]#0 -> %[[PRIVATE_RED:[a-z0-9]+]] : !fir.ref<i32>) {
 !CHECK:      %[[PRIVATE_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[PRIVATE_RED]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:      %[[C1:[_a-z0-9]+]] = arith.constant 1 : i32
 !CHECK:      hlfir.assign %[[C1]] to %[[PRIVATE_DECL]]#0 : i32, !fir.ref<i32>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction3.f90 b/flang/test/Lower/OpenMP/parallel-reduction3.f90
index 17d805c0d142b..47b743a558b49 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction3.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction3.f90
@@ -74,7 +74,7 @@
 ! CHECK:             %[[VAL_18:.*]] = arith.constant 1 : i32
 ! CHECK:             %[[VAL_19:.*]] = fir.alloca !fir.box<!fir.array<?xi32>>
 ! CHECK:             fir.store %[[VAL_12]]#0 to %[[VAL_19]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
-! CHECK:             omp.wsloop byref reduction(@add_reduction_byref_box_Uxi32 %[[VAL_19]] -> %[[VAL_20:.*]] : !fir.ref<!fir.box<!fir.array<?xi32>>>) {
+! CHECK:             omp.wsloop reduction(byref @add_reduction_byref_box_Uxi32 %[[VAL_19]] -> %[[VAL_20:.*]] : !fir.ref<!fir.box<!fir.array<?xi32>>>) {
 ! CHECK-NEXT:          omp.loop_nest (%[[VAL_21:.*]]) : i32 = (%[[VAL_16]]) to (%[[VAL_17]]) inclusive step (%[[VAL_18]]) {
 ! CHECK:                 %[[VAL_22:.*]]:2 = hlfir.declare %[[VAL_20]] {uniq_name = "_QFsEc"} : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> (!fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.array<?xi32>>>)
 ! CHECK:                 fir.store %[[VAL_21]] to %[[VAL_15]]#1 : !fir.ref<i32>
diff --git a/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90 b/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90
index 66c80c31917ba..32caac39778de 100644
--- a/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90
@@ -4,7 +4,7 @@
 ! RUN: flang-new -fc1 -fopenmp -mmlir --force-byref-reduction -emit-hlfir %s -o - | FileCheck %s
 
 ! CHECK: omp.parallel {
-! CHECK: omp.wsloop byref reduction(@add_reduction_byref_i32
+! CHECK: omp.wsloop reduction(byref @add...
[truncated]

@llvmbot
Copy link

llvmbot commented May 15, 2024

@llvm/pr-subscribers-mlir

Author: Tom Eccles (tblah)

Changes

Fixes #88935

Toggling reduction by-ref broke when multiple reduction clauses were used. Decisions made for the by-ref status for later clauses could then invalidate decisions for earlier clauses. For example,

reduction(+:scalar,scalar2) reduction(+:array)

The first clause would choose by value reduction and generate by-value reduction regions, but then after this the second clause would force by-ref to support the array argument. But by the time the second clause is processed, the first clause has already had the wrong kind of reduction regions generated.

This is solved by toggling whether a variable should be reduced by reference per variable. In the above example, this allows only array to be reduced by ref.


Patch is 116.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92244.diff

47 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+5-7)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (-6)
  • (modified) flang/lib/Lower/OpenMP/ReductionProcessor.cpp (+11-18)
  • (modified) flang/lib/Lower/OpenMP/ReductionProcessor.h (+1-3)
  • (modified) flang/test/Fir/omp-reduction-embox-codegen.fir (+1-1)
  • (modified) flang/test/Lower/OpenMP/default-clause-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-array.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-array2.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-reduction3.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-byref.f90 (+7-7)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-allocatable.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array-assumed-shape.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-array2.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-iand-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-ieor-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-ior-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-and-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-neqv-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-or-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-2-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir-byref.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-min-byref.f90 (+3-3)
  • (modified) flang/test/Lower/OpenMP/wsloop-reduction-mul-byref.f90 (+7-7)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-multiple-clauses.f90 (+164)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+1-1)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+10-8)
  • (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+9-3)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+27-32)
  • (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+5)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+84-36)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+26-17)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+38)
  • (modified) mlir/test/Target/LLVMIR/openmp-parallel-reduction-cleanup.mlir (+1-1)
  • (modified) mlir/test/Target/LLVMIR/openmp-reduction-byref.mlir (+1-1)
  • (modified) mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir (+1-1)
  • (modified) mlir/test/Target/LLVMIR/openmp-wsloop-reduction-cleanup.mlir (+1-1)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 0ea87314d571f..a57b96e365999 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -933,20 +933,18 @@ bool ClauseProcessor::processReduction(
   return findRepeatableClause<omp::clause::Reduction>(
       [&](const omp::clause::Reduction &clause,
           const Fortran::parser::CharBlock &) {
-        // Use local lists of reductions to prevent variables from other
-        // already-processed reduction clauses from impacting this reduction.
-        // For example, the whole `reductionVars` array is queried to decide
-        // whether to do the reduction byref.
         llvm::SmallVector<mlir::Value> reductionVars;
+        llvm::SmallVector<bool> reduceVarByRef;
         llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
         llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
         ReductionProcessor rp;
-        rp.addDeclareReduction(currentLocation, converter, clause,
-                               reductionVars, reductionDeclSymbols,
-                               outReductionSyms ? &reductionSyms : nullptr);
+        rp.addDeclareReduction(
+            currentLocation, converter, clause, reductionVars, reduceVarByRef,
+            reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr);
 
         // Copy local lists into the output.
         llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
+        llvm::copy(reduceVarByRef, std::back_inserter(result.reduceVarByRef));
         llvm::copy(reductionDeclSymbols,
                    std::back_inserter(result.reductionDeclSymbols));
 
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f21acdd64d7c3..aaf0c6501ab20 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1007,8 +1007,6 @@ static void genParallelClauses(
 
   if (processReduction) {
     cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
-    if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
-      clauseOps.reductionByRefAttr = converter.getFirOpBuilder().getUnitAttr();
   }
 }
 
@@ -1200,7 +1198,6 @@ static void genWsloopClauses(
     mlir::Location loc, mlir::omp::WsloopClauseOps &clauseOps,
     llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processNowait(clauseOps);
   cp.processOrdered(clauseOps);
@@ -1208,9 +1205,6 @@ static void genWsloopClauses(
   cp.processSchedule(stmtCtx, clauseOps);
   // TODO Support delayed privatization.
 
-  if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
-    clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
-
   cp.processTODO<clause::Allocate, clause::Linear, clause::Order>(
       loc, llvm::omp::Directive::OMPD_do);
 }
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index b3f08eb81c799..689f3adc0a429 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -657,25 +657,17 @@ mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
   return decl;
 }
 
-// TODO: By-ref vs by-val reductions are currently toggled for the whole
-//       operation (possibly effecting multiple reduction variables).
-//       This could cause a problem with openmp target reductions because
-//       by-ref trivial types may not be supported.
-bool ReductionProcessor::doReductionByRef(
-    const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
-  if (reductionVars.empty())
-    return false;
+static bool doReductionByRef(mlir::Value reductionVar) {
   if (forceByrefReduction)
     return true;
 
-  for (mlir::Value reductionVar : reductionVars) {
-    if (auto declare =
-            mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
-      reductionVar = declare.getMemref();
+  if (auto declare =
+          mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
+    reductionVar = declare.getMemref();
+
+  if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
+    return true;
 
-    if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
-      return true;
-  }
   return false;
 }
 
@@ -684,6 +676,7 @@ void ReductionProcessor::addDeclareReduction(
     Fortran::lower::AbstractConverter &converter,
     const omp::clause::Reduction &reduction,
     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+    llvm::SmallVectorImpl<bool> &reduceVarByRef,
     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
         *reductionSymbols) {
@@ -764,8 +757,8 @@ void ReductionProcessor::addDeclareReduction(
            "reduction input var is a reference");
 
     reductionVars.push_back(symVal);
+    reduceVarByRef.push_back(doReductionByRef(symVal));
   }
-  const bool isByRef = doReductionByRef(reductionVars);
 
   if (const auto &redDefinedOp =
           std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
@@ -787,7 +780,7 @@ void ReductionProcessor::addDeclareReduction(
       break;
     }
 
-    for (mlir::Value symVal : reductionVars) {
+    for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
       auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
       const auto &kindMap = firOpBuilder.getKindMap();
       if (mlir::isa<fir::LogicalType>(redType.getEleTy()))
@@ -811,7 +804,7 @@ void ReductionProcessor::addDeclareReduction(
             *reductionIntrinsic)) {
       ReductionProcessor::ReductionIdentifier redId =
           ReductionProcessor::getReductionType(*reductionIntrinsic);
-      for (mlir::Value symVal : reductionVars) {
+      for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
         auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
         if (!redType.getEleTy().isIntOrIndexOrFloat())
           TODO(currentLocation,
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 8b116a4c52041..95d77c8154415 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -73,9 +73,6 @@ class ReductionProcessor {
   static const Fortran::semantics::SourceName
   getRealName(const omp::clause::ProcedureDesignator &pd);
 
-  static bool
-  doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);
-
   static std::string getReductionName(llvm::StringRef name,
                                       const fir::KindMapping &kindMap,
                                       mlir::Type ty, bool isByRef);
@@ -128,6 +125,7 @@ class ReductionProcessor {
       Fortran::lower::AbstractConverter &converter,
       const omp::clause::Reduction &reduction,
       llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+      llvm::SmallVectorImpl<bool> &reduceVarByRef,
       llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
       llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
           *reductionSymbols = nullptr);
diff --git a/flang/test/Fir/omp-reduction-embox-codegen.fir b/flang/test/Fir/omp-reduction-embox-codegen.fir
index 7602012ebc5c9..1645e1a407ad4 100644
--- a/flang/test/Fir/omp-reduction-embox-codegen.fir
+++ b/flang/test/Fir/omp-reduction-embox-codegen.fir
@@ -25,7 +25,7 @@ omp.declare_reduction @test_reduction : !fir.ref<!fir.box<i32>> init {
 
 func.func @_QQmain() attributes {fir.bindc_name = "reduce"} {
   %4 = fir.alloca !fir.box<i32>
-  omp.parallel byref reduction(@test_reduction %4 -> %arg0 : !fir.ref<!fir.box<i32>>) {
+  omp.parallel reduction(byref @test_reduction %4 -> %arg0 : !fir.ref<!fir.box<i32>>) {
     omp.terminator
   }
   return
diff --git a/flang/test/Lower/OpenMP/default-clause-byref.f90 b/flang/test/Lower/OpenMP/default-clause-byref.f90
index 7cc2bc2e0c710..7893c4d7d5732 100644
--- a/flang/test/Lower/OpenMP/default-clause-byref.f90
+++ b/flang/test/Lower/OpenMP/default-clause-byref.f90
@@ -351,7 +351,7 @@ subroutine skipped_default_clause_checks()
        type(it)::iii
 
 !CHECK: omp.parallel {
-!CHECK: omp.wsloop byref reduction(@min_byref_i32 %[[VAL_Z_DECLARE]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK: omp.wsloop reduction(byref @min_byref_i32 %[[VAL_Z_DECLARE]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
 !CHECK-NEXT: omp.loop_nest (%[[ARG:.*]]) {{.*}} {
 !CHECK: omp.yield
 !CHECK: }
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
index 49d1142ea4b6a..72e91680a4310 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
@@ -26,5 +26,5 @@ subroutine red_and_delayed_private
 
 ! CHECK-LABEL: _QPred_and_delayed_private
 ! CHECK: omp.parallel
-! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
 ! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90 b/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90
index 2a1d26407b27e..7347d9324feac 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90
@@ -40,7 +40,7 @@
 !CHECK:  %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
 !CHECK:  hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK:  omp.parallel byref reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK:  omp.parallel reduction(byref @[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
 !CHECK:    %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:    %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<i32>
 !CHECK:    %[[I_INCR:.*]] = arith.constant 1 : i32
@@ -65,7 +65,7 @@ subroutine simple_int_add
 !CHECK:  %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 !CHECK:  %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
 !CHECK:  hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
-!CHECK:  omp.parallel byref reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK:  omp.parallel reduction(byref @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
 !CHECK:    %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 !CHECK:    %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<f32>
 !CHECK:    %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
@@ -94,7 +94,7 @@ subroutine simple_real_add
 !CHECK:  hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
 !CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
 !CHECK:  hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK:  omp.parallel byref reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
+!CHECK:  omp.parallel reduction(byref @[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, byref @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
 !CHECK:    %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:    %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 !CHECK:    %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90 b/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90
index 28216ef91c3a3..f6d3b0b73f738 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-allocatable-array.f90
@@ -95,7 +95,7 @@ program reduce
 ! CHECK:             %[[VAL_14:.*]] = arith.constant 0 : i32
 ! CHECK:             %[[VAL_15:.*]] = arith.constant 10 : i32
 ! CHECK:             %[[VAL_16:.*]] = arith.constant 1 : i32
-! CHECK:             omp.wsloop byref reduction(@add_reduction_byref_box_heap_Uxi32 %[[VAL_3]]#0 -> %[[VAL_17:.*]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) {
+! CHECK:             omp.wsloop reduction(byref @add_reduction_byref_box_heap_Uxi32 %[[VAL_3]]#0 -> %[[VAL_17:.*]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) {
 ! CHECK-NEXT:          omp.loop_nest (%[[VAL_18:.*]]) : i32 = (%[[VAL_14]]) to (%[[VAL_15]]) inclusive step (%[[VAL_16]]) {
 ! CHECK:                 %[[VAL_19:.*]]:2 = hlfir.declare %[[VAL_17]] {fortran_attrs = {{.*}}<allocatable>, uniq_name = "_QFEr"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 ! CHECK:                 fir.store %[[VAL_18]] to %[[VAL_13]]#1 : !fir.ref<i32>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90
index 8202e6d897157..b44fe4c1f4cc2 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array-lb.f90
@@ -77,7 +77,7 @@ program reduce
 ! CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_5]]) {uniq_name = "_QFEi"} : (!fir.ref<!fir.array<3x2xi32>>, !fir.shapeshift<2>) -> (!fir.box<!fir.array<3x2xi32>>, !fir.ref<!fir.array<3x2xi32>>)
 ! CHECK:           %[[VAL_7:.*]] = fir.alloca !fir.box<!fir.array<3x2xi32>>
 ! CHECK:           fir.store %[[VAL_6]]#0 to %[[VAL_7]] : !fir.ref<!fir.box<!fir.array<3x2xi32>>>
-! CHECK:           omp.parallel byref reduction(@add_reduction_byref_box_3x2xi32 %[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref<!fir.box<!fir.array<3x2xi32>>>) {
+! CHECK:           omp.parallel reduction(byref @add_reduction_byref_box_3x2xi32 %[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref<!fir.box<!fir.array<3x2xi32>>>) {
 ! CHECK:             %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_8]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3x2xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3x2xi32>>>, !fir.ref<!fir.box<!fir.array<3x2xi32>>>)
 ! CHECK:             %[[VAL_10:.*]] = arith.constant 3 : i32
 ! CHECK:             %[[VAL_11:.*]] = fir.load %[[VAL_9]]#0 : !fir.ref<!fir.box<!fir.array<3x2xi32>>>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array.f90
index 34f4ee0a9eb3a..60b21c9b1ebbe 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array.f90
@@ -70,7 +70,7 @@ program reduce
 ! CHECK:           %[[VAL_4:.*]] = fir.embox %[[VAL_3]]#0(%[[VAL_2]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
 ! CHECK:           %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
 ! CHECK:           fir.store %[[VAL_4]] to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
-! CHECK:           omp.parallel byref reduction(@add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
+! CHECK:           omp.parallel reduction(byref @add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
 ! CHECK:             %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3xi32>>>, !fir.ref<!fir.box<!fir.array<3xi32>>>)
 ! CHECK:             %[[VAL_8:.*]] = arith.constant 1 : i32
 ! CHECK:             %[[VAL_9:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-array2.f90 b/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
index aa14092554eda..5d4c86d1d76e8 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-array2.f90
@@ -69,7 +69,7 @@ program reduce
 ! CHECK:           %[[VAL_4:.*]] = fir.embox %[[VAL_3]]#0(%[[VAL_2]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
 ! CHECK:           %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
 ! CHECK:           fir.store %[[VAL_4]] to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
-! CHECK:           omp.parallel byref reduction(@add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
+! CHECK:           omp.parallel reduction(byref @add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
 ! CHECK:             %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3xi32>>>, !fir.ref<!fir.box<!fir.array<3xi32>>>)
 ! CHECK:             %[[VAL_8:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
 ! CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-byref.f90 b/flang/test/Lower/OpenMP/parallel-reduction-byref.f90
index fdcdf0c0b8d95..5685e2c584ace 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction-byref.f90
@@ -21,7 +21,7 @@
 !CHECK:    %[[RED_ACCUM_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[RED_ACCUM_REF]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:    %[[C0:[_a-z0-9]+]] = arith.constant 0 : i32
 !CHECK:    hlfir.assign %[[C0]] to %[[RED_ACCUM_DECL]]#0 : i32, !fir.ref<i32>
-!CHECK:    omp.parallel byref reduction(@[[REDUCTION_DECLARE]] %[[RED_ACCUM_DECL]]#0 -> %[[PRIVATE_RED:[a-z0-9]+]] : !fir.ref<i32>) {
+!CHECK:    omp.parallel reduction(byref @[[REDUCTION_DECLARE]] %[[RED_ACCUM_DECL]]#0 -> %[[PRIVATE_RED:[a-z0-9]+]] : !fir.ref<i32>) {
 !CHECK:      %[[PRIVATE_DECL:[_a-z0-9]+]]:2 = hlfir.declare %[[PRIVATE_RED]] {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 !CHECK:      %[[C1:[_a-z0-9]+]] = arith.constant 1 : i32
 !CHECK:      hlfir.assign %[[C1]] to %[[PRIVATE_DECL]]#0 : i32, !fir.ref<i32>
diff --git a/flang/test/Lower/OpenMP/parallel-reduction3.f90 b/flang/test/Lower/OpenMP/parallel-reduction3.f90
index 17d805c0d142b..47b743a558b49 100644
--- a/flang/test/Lower/OpenMP/parallel-reduction3.f90
+++ b/flang/test/Lower/OpenMP/parallel-reduction3.f90
@@ -74,7 +74,7 @@
 ! CHECK:             %[[VAL_18:.*]] = arith.constant 1 : i32
 ! CHECK:             %[[VAL_19:.*]] = fir.alloca !fir.box<!fir.array<?xi32>>
 ! CHECK:             fir.store %[[VAL_12]]#0 to %[[VAL_19]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
-! CHECK:             omp.wsloop byref reduction(@add_reduction_byref_box_Uxi32 %[[VAL_19]] -> %[[VAL_20:.*]] : !fir.ref<!fir.box<!fir.array<?xi32>>>) {
+! CHECK:             omp.wsloop reduction(byref @add_reduction_byref_box_Uxi32 %[[VAL_19]] -> %[[VAL_20:.*]] : !fir.ref<!fir.box<!fir.array<?xi32>>>) {
 ! CHECK-NEXT:          omp.loop_nest (%[[VAL_21:.*]]) : i32 = (%[[VAL_16]]) to (%[[VAL_17]]) inclusive step (%[[VAL_18]]) {
 ! CHECK:                 %[[VAL_22:.*]]:2 = hlfir.declare %[[VAL_20]] {uniq_name = "_QFsEc"} : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> (!fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.array<?xi32>>>)
 ! CHECK:                 fir.store %[[VAL_21]] to %[[VAL_15]]#1 : !fir.ref<i32>
diff --git a/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90 b/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90
index 66c80c31917ba..32caac39778de 100644
--- a/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90
@@ -4,7 +4,7 @@
 ! RUN: flang-new -fc1 -fopenmp -mmlir --force-byref-reduction -emit-hlfir %s -o - | FileCheck %s
 
 ! CHECK: omp.parallel {
-! CHECK: omp.wsloop byref reduction(@add_reduction_byref_i32
+! CHECK: omp.wsloop reduction(byref @add...
[truncated]

@tblah tblah force-pushed the ecclescake/multiple-reduction-clauses branch from 4880fd3 to 57abf53 Compare May 15, 2024 11:02
Copy link
Contributor

@kiranchandramohan kiranchandramohan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks mostly OK to me. Have some minor comments inline.

Would adding the byref attribute to the Reduction Declaration Operation work?

threads complete.
accumulator variables in `reduction_vars`, symbols referring to reduction
declarations in the `reductions` attribute, and whether the reduction
variable should be passed into the redution region by value or by reference
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
variable should be passed into the redution region by value or by reference
variable should be passed into the reduction region by value or by reference

Comment on lines 1523 to 1529
auto reductionVarsByRef = getReductionVarsByref();
if (reductionVarsByRef &&
reductionVarsByRef->size() != getReductionVars().size())
return emitOpError()
<< "expected as many reduction variable by reference attributes "
"as reduction variables";

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be in verifyReductionVarList?

Is there a test for this error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't think of a way to test for this error because we can't write text-form IR like this. I would have to create the operation with C++ like in a unit test, but I didn't know how to catch an error from a unit test.

Comment on lines +939 to +940
llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionVarsByref());
assert(isByRef.size() == wsloopOp.getNumReductionVars());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this guaranteed by the verifier?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I intended this more like a code comment. Would it be clearer to remove?

@@ -467,6 +468,10 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
wsloopOp.setReductionsAttr(
ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
wsloopOp.getReductionVarsMutable().append(reductionVariables);
llvm::SmallVector<bool> byRefVec;
byRefVec.resize(reductionVariables.size(), false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Add a comment why this is all false.

Comment on lines +514 to +524
auto privateByRef = DenseBoolArrayAttr::get(parser.getContext(), {});
if (failed(parseClauseWithRegionArgs(parser, region, privateVarOperands,
privateVarsTypes, privatizerSymbols,
regionPrivateArgs)))
privateVarsTypes, privateByRef,
privatizerSymbols, regionPrivateArgs)))
return failure();
if (llvm::any_of(privateByRef.asArrayRef(),
[](bool byref) { return byref; })) {
parser.emitError(parser.getCurrentLocation(),
"private clause cannot have byref attributes");
return failure();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change necessary?

Do we have a test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is necessary because there is no toggle for byref on variable privatization. One alternative would be to silently ignore it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that required because parseClauseWithRegionArgs is common for reduction and private?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

Comment on lines 1712 to 1718
auto reductionVarsByRef = getReductionVarsByref();
if (reductionVarsByRef &&
reductionVarsByRef->size() != getReductionVars().size())
return emitOpError()
<< "expected as many reduction variable by reference attributes "
"as reduction variables";

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be in verifyReductionVarList?

Is there a test for this error?

@tblah
Copy link
Contributor Author

tblah commented May 16, 2024

Looks mostly OK to me. Have some minor comments inline.

Would adding the byref attribute to the Reduction Declaration Operation work?

Thanks for taking a look. Yes I think adding it to the ReductionDeclareOp would also work

Copy link
Contributor

@kiranchandramohan kiranchandramohan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Copy link
Member

@DavidTruby DavidTruby left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but wait for approval from Kiran

Copy link
Contributor

@skatrak skatrak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, this looks good to me as well!

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp Show resolved Hide resolved
tblah added 6 commits May 16, 2024 14:19
Fixes llvm#88935

Toggling reduction by-ref broke when multiple reduction clauses were
used. Decisions made for the by-ref status for later clauses could then
invalidate decisions for earlier clauses. For example,

```
reduction(+:scalar,scalar2) reduction(+:array)
```

The first clause would choose by value reduction and generate by-value
reduction regions, but then after this the second clause would force
by-ref to support the array argument. But by the time the second clause
is processed, the first clause has already had the wrong kind of
reduction regions generated.

This is solved by toggling whether a variable should be reduced by
reference per variable. In the above example, this allows only `array`
to be reduced by ref.
@tblah tblah force-pushed the ecclescake/multiple-reduction-clauses branch from e575e3c to ce7af67 Compare May 16, 2024 14:23
@tblah tblah merged commit 74a8754 into llvm:main May 16, 2024
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:openmp OpenMP related changes to Clang flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category mlir:llvm mlir:openmp mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Flang][OpenMP] assertion failure on reduction
5 participants