Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][OpenMP][OMPIRBuilder][mlir] Optionally pass reduction vars by ref #84304

Merged
merged 5 commits into from
Mar 13, 2024

Conversation

tblah
Copy link
Contributor

@tblah tblah commented Mar 7, 2024

Previously reduction variables were always passed by value into and out of the initialization and combiner regions of the OpenMP reduction declare operation.

This worked well for reductions of primitive types (and might perform better than passing by reference). But passing by reference will be useful for array and derived type reductions (e.g. to move allocation inside of the init region).

Passing reductions by reference requires different LLVM-IR generation when lowering from MLIR because some of the loads/stores/allocations will now be moved inside of the init and combiner regions. This alternate code generation is requested using a new attribute to omp.wsloop and omp.parallel.

Existing lowerings from mlir are unaffected (these will continue to use the by-value argument passing.

Flang will continue to pass by-value argument passing for trivial types unless a (hidden) command line argument is supplied. Non-trivial types will always use the by-ref lowering.

Array reductions are not ready yet (but are coming very soon). In the meantime, this is tested by forcing existing reductions to use by-ref.

Commit series for by-ref OpenMP reductions 3/3

tblah and others added 3 commits March 7, 2024 10:53
This will be useful for OpenMP too.

I changed the definition slightly to use `fir::isa_ref_type` (which also
includes llvm pointers) because I think it reads better using the common
type helpers. There shouldn't be any llvm pointers in lowering so this
isn't a functional change.

Commit series for by-ref openmp reductions: 1/3
TBAA builder assumed that all loads/stores are inside of functions and
hit an assertion once it found loads and stores inside of an
omp::ReductionDeclareOp.

For now just don't add TBAA tags to those loads and stores. They would
end up in a different TBAA tree to the host function after
OpenMPIRBuilder inlines them anyway so there isn't an easy way of making
this work.

Commit series for by-ref OpenMP reductions: 2/3
… ref

Previously reduction variables were always passed by value into and out
of the initialization and combiner regions of the OpenMP reduction
declare operation.

This worked well for reductions of primitive types (and might perform
better than passing by reference). But passing by reference will be
useful for array and derived type reductions (e.g. to move allocation
inside of the init region).

Passing reductions by reference requires different LLVM-IR generation
when lowering from MLIR because some of the loads/stores/allocations will
now be moved inside of the init and combiner regions. This alternate
code generation is requested using a new attribute to omp.wsloop and
omp.parallel.

Existing lowerings from mlir are unaffected (these will continue to use
the by-value argument passing.

Flang will continue to pass by-value argument passing for trivial types
unless a (hidden) command line argument is supplied. Non-trivial types
will always use the by-ref lowering.

Array reductions are not ready yet (but are coming very soon). In the
meantime, this is tested by forcing existing reductions to use by-ref.

Commit series for by-ref OpenMP reductions 3/3

Co-authored-by: Mats Petersson <mats.petersson@arm.com>
@llvmbot llvmbot added mlir:llvm mlir flang Flang issues not falling into any other category mlir:openmp flang:fir-hlfir flang:openmp clang:openmp OpenMP related changes to Clang labels Mar 7, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 7, 2024

@llvm/pr-subscribers-flang-codegen
@llvm/pr-subscribers-openacc
@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-flang-fir-hlfir

Author: Tom Eccles (tblah)

Changes

Previously reduction variables were always passed by value into and out of the initialization and combiner regions of the OpenMP reduction declare operation.

This worked well for reductions of primitive types (and might perform better than passing by reference). But passing by reference will be useful for array and derived type reductions (e.g. to move allocation inside of the init region).

Passing reductions by reference requires different LLVM-IR generation when lowering from MLIR because some of the loads/stores/allocations will now be moved inside of the init and combiner regions. This alternate code generation is requested using a new attribute to omp.wsloop and omp.parallel.

Existing lowerings from mlir are unaffected (these will continue to use the by-value argument passing.

Flang will continue to pass by-value argument passing for trivial types unless a (hidden) command line argument is supplied. Non-trivial types will always use the by-ref lowering.

Array reductions are not ready yet (but are coming very soon). In the meantime, this is tested by forcing existing reductions to use by-ref.

Commit series for by-ref OpenMP reductions 3/3


Patch is 297.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84304.diff

37 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+13-5)
  • (modified) flang/lib/Lower/OpenMP/ReductionProcessor.cpp (+93-27)
  • (modified) flang/lib/Lower/OpenMP/ReductionProcessor.h (+13-5)
  • (added) flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90 (+117)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-add-byref.f90 (+392)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-iand-byref.f90 (+46)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ieor-byref.f90 (+45)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ior-byref.f90 (+45)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-eqv-byref.f90 (+187)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-neqv-byref.f90 (+189)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-max-byref.f90 (+90)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-min-byref.f90 (+91)
  • (added) flang/test/Lower/OpenMP/default-clause-byref.f90 (+385)
  • (added) flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 (+30)
  • (added) flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90 (+125)
  • (added) flang/test/Lower/OpenMP/parallel-reduction-byref.f90 (+44)
  • (added) flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90 (+16)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-add-byref.f90 (+433)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir-byref.f90 (+58)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-iand-byref.f90 (+64)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-ieor-byref.f90 (+55)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-ior-byref.f90 (+64)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-logical-and-byref.f90 (+206)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv-byref.f90 (+202)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-logical-neqv-byref.f90 (+207)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-logical-or-byref.f90 (+204)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-max-2-byref.f90 (+20)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-max-byref.f90 (+152)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir-byref.f90 (+62)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-min-byref.f90 (+154)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-mul-byref.f90 (+414)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+3-1)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+22-8)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+11-1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+3-2)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+72-27)
  • (added) mlir/test/Target/LLVMIR/openmp-reduction-byref.mlir (+66)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 185e0316870e94..d9648f3d692cc6 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -600,6 +600,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
     return reductionSymbols;
   };
 
+  mlir::UnitAttr byrefAttr;
+  if (ReductionProcessor::doReductionByRef(reductionVars))
+    byrefAttr = converter.getFirOpBuilder().getUnitAttr();
+
   OpWithBodyGenInfo genInfo =
       OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
           .setGenNested(genNested)
@@ -619,7 +623,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
             : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
                                    reductionDeclSymbols),
         procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
-        /*privatizers=*/nullptr);
+        /*privatizers=*/nullptr, byrefAttr);
   }
 
   bool privatize = !outerCombined;
@@ -683,7 +687,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
       delayedPrivatizationInfo.privatizers.empty()
           ? nullptr
           : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
-                                 privatizers));
+                                 privatizers),
+      byrefAttr);
 }
 
 static mlir::omp::SectionOp
@@ -1568,7 +1573,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
   mlir::omp::ClauseOrderKindAttr orderClauseOperand;
   mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
-  mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
+  mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
   mlir::IntegerAttr orderedClauseOperand;
   mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
   std::size_t loopVarTypeSize;
@@ -1585,6 +1590,9 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
   convertLoopBounds(converter, loc, lowerBound, upperBound, step,
                     loopVarTypeSize);
 
+  if (ReductionProcessor::doReductionByRef(reductionVars))
+    byrefOperand = firOpBuilder.getUnitAttr();
+
   auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
       loc, lowerBound, upperBound, step, linearVars, linearStepVars,
       reductionVars,
@@ -1594,8 +1602,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
                                  reductionDeclSymbols),
       scheduleValClauseOperand, scheduleChunkClauseOperand,
       /*schedule_modifiers=*/nullptr,
-      /*simd_modifier=*/nullptr, nowaitClauseOperand, orderedClauseOperand,
-      orderClauseOperand,
+      /*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
+      orderedClauseOperand, orderClauseOperand,
       /*inclusive=*/firOpBuilder.getUnitAttr());
 
   // Handle attribute based clauses.
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index a8b98f3f567249..34959e95dce04a 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -14,9 +14,16 @@
 
 #include "flang/Lower/AbstractConverter.h"
 #include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Parser/tools.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "llvm/Support/CommandLine.h"
+
+static llvm::cl::opt<bool> forceByrefReduction(
+    "force-byref-reduction",
+    llvm::cl::desc("Pass all reduction arguments by reference"),
+    llvm::cl::Hidden);
 
 namespace Fortran {
 namespace lower {
@@ -76,16 +83,24 @@ bool ReductionProcessor::supportedIntrinsicProcReduction(
 }
 
 std::string ReductionProcessor::getReductionName(llvm::StringRef name,
-                                                 mlir::Type ty) {
+                                                 mlir::Type ty, bool isByRef) {
+  ty = fir::unwrapRefType(ty);
+
+  // extra string to distinguish reduction functions for variables passed by
+  // reference
+  llvm::StringRef byrefAddition{""};
+  if (isByRef)
+    byrefAddition = "_byref";
+
   return (llvm::Twine(name) +
           (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
-          llvm::Twine(ty.getIntOrFloatBitWidth()))
+          llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition)
       .str();
 }
 
 std::string ReductionProcessor::getReductionName(
     Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-    mlir::Type ty) {
+    mlir::Type ty, bool isByRef) {
   std::string reductionName;
 
   switch (intrinsicOp) {
@@ -108,13 +123,14 @@ std::string ReductionProcessor::getReductionName(
     break;
   }
 
-  return getReductionName(reductionName, ty);
+  return getReductionName(reductionName, ty, isByRef);
 }
 
 mlir::Value
 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
                                           ReductionIdentifier redId,
                                           fir::FirOpBuilder &builder) {
+  type = fir::unwrapRefType(type);
   assert((fir::isa_integer(type) || fir::isa_real(type) ||
           type.isa<fir::LogicalType>()) &&
          "only integer, logical and real types are currently supported");
@@ -188,6 +204,7 @@ mlir::Value ReductionProcessor::createScalarCombiner(
     fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
     mlir::Type type, mlir::Value op1, mlir::Value op2) {
   mlir::Value reductionOp;
+  type = fir::unwrapRefType(type);
   switch (redId) {
   case ReductionIdentifier::MAX:
     reductionOp =
@@ -268,7 +285,8 @@ mlir::Value ReductionProcessor::createScalarCombiner(
 
 mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
     fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
+    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
+    bool isByRef) {
   mlir::OpBuilder::InsertionGuard guard(builder);
   mlir::ModuleOp module = builder.getModule();
 
@@ -278,14 +296,24 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
     return decl;
 
   mlir::OpBuilder modBuilder(module.getBodyRegion());
+  mlir::Type valTy = fir::unwrapRefType(type);
+  if (!isByRef)
+    type = valTy;
 
   decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
                                                           type);
   builder.createBlock(&decl.getInitializerRegion(),
                       decl.getInitializerRegion().end(), {type}, {loc});
   builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+
   mlir::Value init = getReductionInitValue(loc, type, redId, builder);
-  builder.create<mlir::omp::YieldOp>(loc, init);
+  if (isByRef) {
+    mlir::Value alloca = builder.create<fir::AllocaOp>(loc, valTy);
+    builder.createStoreWithConvert(loc, init, alloca);
+    builder.create<mlir::omp::YieldOp>(loc, alloca);
+  } else {
+    builder.create<mlir::omp::YieldOp>(loc, init);
+  }
 
   builder.createBlock(&decl.getReductionRegion(),
                       decl.getReductionRegion().end(), {type, type},
@@ -294,14 +322,41 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+  mlir::Value outAddr = op1;
+
+  op1 = builder.loadIfRef(loc, op1);
+  op2 = builder.loadIfRef(loc, op2);
 
   mlir::Value reductionOp =
       createScalarCombiner(builder, loc, redId, type, op1, op2);
-  builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+  if (isByRef) {
+    builder.create<fir::StoreOp>(loc, reductionOp, outAddr);
+    builder.create<mlir::omp::YieldOp>(loc, outAddr);
+  } else {
+    builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+  }
 
   return decl;
 }
 
+bool ReductionProcessor::doReductionByRef(
+    const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
+  if (reductionVars.empty())
+    return false;
+  if (forceByrefReduction)
+    return true;
+
+  for (mlir::Value reductionVar : reductionVars) {
+    if (auto declare =
+            mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
+      reductionVar = declare.getMemref();
+
+    if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
+      return true;
+  }
+  return false;
+}
+
 void ReductionProcessor::addReductionDecl(
     mlir::Location currentLocation,
     Fortran::lower::AbstractConverter &converter,
@@ -315,6 +370,24 @@ void ReductionProcessor::addReductionDecl(
   const auto &redOperator{
       std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
   const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+
+  // initial pass to collect all recuction vars so we can figure out if this
+  // should happen byref
+  for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+    if (const auto *name{
+            Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+      if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+        if (reductionSymbols)
+          reductionSymbols->push_back(symbol);
+        mlir::Value symVal = converter.getSymbolAddress(*symbol);
+        if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+          symVal = declOp.getBase();
+        reductionVars.push_back(symVal);
+      }
+    }
+  }
+  const bool isByRef = doReductionByRef(reductionVars);
+
   if (const auto &redDefinedOp =
           std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
     const auto &intrinsicOp{
@@ -338,23 +411,20 @@ void ReductionProcessor::addReductionDecl(
       if (const auto *name{
               Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
         if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-          if (reductionSymbols)
-            reductionSymbols->push_back(symbol);
           mlir::Value symVal = converter.getSymbolAddress(*symbol);
           if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
             symVal = declOp.getBase();
-          mlir::Type redType =
-              symVal.getType().cast<fir::ReferenceType>().getEleTy();
-          reductionVars.push_back(symVal);
-          if (redType.isa<fir::LogicalType>())
+          auto redType = symVal.getType().cast<fir::ReferenceType>();
+          if (redType.getEleTy().isa<fir::LogicalType>())
             decl = createReductionDecl(
                 firOpBuilder,
-                getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
-                redType, currentLocation);
-          else if (redType.isIntOrIndexOrFloat()) {
-            decl = createReductionDecl(firOpBuilder,
-                                       getReductionName(intrinsicOp, redType),
-                                       redId, redType, currentLocation);
+                getReductionName(intrinsicOp, firOpBuilder.getI1Type(),
+                                 isByRef),
+                redId, redType, currentLocation, isByRef);
+          else if (redType.getEleTy().isIntOrIndexOrFloat()) {
+            decl = createReductionDecl(
+                firOpBuilder, getReductionName(intrinsicOp, redType, isByRef),
+                redId, redType, currentLocation, isByRef);
           } else {
             TODO(currentLocation, "Reduction of some types is not supported");
           }
@@ -374,21 +444,17 @@ void ReductionProcessor::addReductionDecl(
         if (const auto *name{
                 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
           if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-            if (reductionSymbols)
-              reductionSymbols->push_back(symbol);
             mlir::Value symVal = converter.getSymbolAddress(*symbol);
             if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
               symVal = declOp.getBase();
-            mlir::Type redType =
-                symVal.getType().cast<fir::ReferenceType>().getEleTy();
-            reductionVars.push_back(symVal);
-            assert(redType.isIntOrIndexOrFloat() &&
+            auto redType = symVal.getType().cast<fir::ReferenceType>();
+            assert(redType.getEleTy().isIntOrIndexOrFloat() &&
                    "Unsupported reduction type");
             decl = createReductionDecl(
                 firOpBuilder,
                 getReductionName(getRealName(*reductionIntrinsic).ToString(),
-                                 redType),
-                redId, redType, currentLocation);
+                                 redType, isByRef),
+                redId, redType, currentLocation, isByRef);
             reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
                 firOpBuilder.getContext(), decl.getSymName()));
           }
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 00770fe81d1ef6..679580f2a3cac7 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -14,6 +14,7 @@
 #define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
 
 #include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/type.h"
@@ -71,11 +72,15 @@ class ReductionProcessor {
   static const Fortran::semantics::SourceName
   getRealName(const Fortran::parser::ProcedureDesignator &pd);
 
-  static std::string getReductionName(llvm::StringRef name, mlir::Type ty);
+  static bool
+  doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);
+
+  static std::string getReductionName(llvm::StringRef name, mlir::Type ty,
+                                      bool isByRef);
 
   static std::string getReductionName(
       Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Type ty);
+      mlir::Type ty, bool isByRef);
 
   /// This function returns the identity value of the operator \p
   /// reductionOpName. For example:
@@ -103,9 +108,11 @@ class ReductionProcessor {
   /// symbol table. The declaration has a constant initializer with the neutral
   /// value `initValue`, and the reduction combiner carried over from `reduce`.
   /// TODO: Generalize this for non-integer types, add atomic region.
-  static mlir::omp::ReductionDeclareOp createReductionDecl(
-      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-      const ReductionIdentifier redId, mlir::Type type, mlir::Location loc);
+  static mlir::omp::ReductionDeclareOp
+  createReductionDecl(fir::FirOpBuilder &builder,
+                      llvm::StringRef reductionOpName,
+                      const ReductionIdentifier redId, mlir::Type type,
+                      mlir::Location loc, bool isByRef);
 
   /// Creates a reduction declaration and associates it with an OpenMP block
   /// directive.
@@ -124,6 +131,7 @@ mlir::Value
 ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
                                           mlir::Type type, mlir::Location loc,
                                           mlir::Value op1, mlir::Value op2) {
+  type = fir::unwrapRefType(type);
   assert(type.isIntOrIndexOrFloat() &&
          "only integer and float types are currently supported");
   if (type.isIntOrIndex())
diff --git a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90 b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90
new file mode 100644
index 00000000000000..ca432662b77c44
--- /dev/null
+++ b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90
@@ -0,0 +1,117 @@
+! RUN: bbc -emit-fir -hlfir=false -fopenmp --force-byref-reduction -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -flang-deprecated-no-hlfir -fopenmp -mmlir --force-byref-reduction -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_F32_NAME:.*]] : !fir.ref<f32>
+!CHECK-SAME: init {
+!CHECK: ^bb0(%{{.*}}: !fir.ref<f32>):
+!CHECK:  %[[C0_1:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK:  %[[REF:.*]] = fir.alloca f32
+!CHECKL  fir.store [[%C0_1]] to %[[REF]] : !fir.ref<f32>
+!CHECK:  omp.yield(%[[REF]] : !fir.ref<f32>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<f32>, %[[ARG1:.*]]: !fir.ref<f32>):
+!CHECK:  %[[LD0:.*]] = fir.load %[[ARG0]] : !fir.ref<f32>
+!CHECK:  %[[LD1:.*]] = fir.load %[[ARG1]] : !fir.ref<f32>
+!CHECK:  %[[RES:.*]] = arith.addf %[[LD0]], %[[LD1]] {{.*}}: f32
+!CHECK:  fir.store %[[RES]] to %[[ARG0]] : !fir.ref<f32>
+!CHECK:  omp.yield(%[[ARG0]] : !fir.ref<f32>)
+!CHECK: }
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_I32_NAME:.*]] : !fir.ref<i32>
+!CHECK-SAME: init {
+!CHECK: ^bb0(%{{.*}}: !fir.ref<i32>):
+!CHECK:  %[[C0_1:.*]] = arith.constant 0 : i32
+!CHECK:  %[[REF:.*]] = fir.alloca i32
+!CHECKL  fir.store [[%C0_1]] to %[[REF]] : !fir.ref<i32>
+!CHECK:  omp.yield(%[[REF]] : !fir.ref<i32>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<i32>, %[[ARG1:.*]]: !fir.ref<i32>):
+!CHECK:  %[[LD0:.*]] = fir.load %[[ARG0]] : !fir.ref<i32>
+!CHECK:  %[[LD1:.*]] = fir.load %[[ARG1]] : !fir.ref<i32>
+!CHECK:  %[[RES:.*]] = arith.addi %[[LD0]], %[[LD1]] : i32
+!CHECK:  fir.store %[[RES]] to %[[ARG0]] : !fir.ref<i32>
+!CHECK:  omp.yield(%[[ARG0]] : !fir.ref<i32>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_int_add
+!CHECK:  %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
+!CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK:  fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK:  omp.parallel byref reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK:    %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
+!CHECK:    %[[I_INCR:.+]] = arith.constant 1 : i32
+!CHECK:    %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
+!CHECK:    fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
+!CHECK:    omp.terminator
+!CHECK:  }
+!CHECK: return
+subroutine simple_int_add
+    integer :: i
+    i = 0
+
+    !$omp parallel reduction(+:i)
+    i = i + 1
+    !$omp end parallel
+
+    print *, i
+end subroutine
+
+!CHECK-LABEL: func.func @_QPsimple_real_add
+!CHECK:  %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
+!CHECK:  %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK:  fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK:  omp.parallel byref reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK:    %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
+!CHECK:    %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
+!CHECK:    %[[RES]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
+!CHECK:    fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
+!CHECK:    omp.terminator
+!CHECK:  }
+!CHECK: return
+subroutine simple_real_add
+    real :: r
+    r = 0.0
+
+    !$omp parallel reduction(+:r)
+    r = r + 1.5
+    !$omp end parallel
+
+    print *, r
+end subroutine
+
+!CHECK-LABEL: func.func @_QPint_real_add
+!CHECK:  %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFint_real_addEi"}
+!CHECK:  %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFint_real_addEr"}
+!CHECK:  %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK:  fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK:  fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK:  omp.parallel byref reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
+!CHECK:    %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
+!CHECK:    %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
+!CHECK:    %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
+!CHECK:    fir.store %[[RES1]] to %[[PRV1]]
+!CHECK:    %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
+!CHECK:    %[[I_INCR:.*]] = arith.constant 3 : i32
+!CHECK:    %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
+!CHECK:    fir.store %[[RES0]] t...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 7, 2024

@llvm/pr-subscribers-mlir

Author: Tom Eccles (tblah)

Changes

Previously reduction variables were always passed by value into and out of the initialization and combiner regions of the OpenMP reduction declare operation.

This worked well for reductions of primitive types (and might perform better than passing by reference). But passing by reference will be useful for array and derived type reductions (e.g. to move allocation inside of the init region).

Passing reductions by reference requires different LLVM-IR generation when lowering from MLIR because some of the loads/stores/allocations will now be moved inside of the init and combiner regions. This alternate code generation is requested using a new attribute to omp.wsloop and omp.parallel.

Existing lowerings from mlir are unaffected (these will continue to use the by-value argument passing.

Flang will continue to pass by-value argument passing for trivial types unless a (hidden) command line argument is supplied. Non-trivial types will always use the by-ref lowering.

Array reductions are not ready yet (but are coming very soon). In the meantime, this is tested by forcing existing reductions to use by-ref.

Commit series for by-ref OpenMP reductions 3/3


Patch is 297.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84304.diff

37 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+13-5)
  • (modified) flang/lib/Lower/OpenMP/ReductionProcessor.cpp (+93-27)
  • (modified) flang/lib/Lower/OpenMP/ReductionProcessor.h (+13-5)
  • (added) flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90 (+117)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-add-byref.f90 (+392)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-iand-byref.f90 (+46)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ieor-byref.f90 (+45)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ior-byref.f90 (+45)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-eqv-byref.f90 (+187)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-neqv-byref.f90 (+189)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-max-byref.f90 (+90)
  • (added) flang/test/Lower/OpenMP/FIR/wsloop-reduction-min-byref.f90 (+91)
  • (added) flang/test/Lower/OpenMP/default-clause-byref.f90 (+385)
  • (added) flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 (+30)
  • (added) flang/test/Lower/OpenMP/parallel-reduction-add-byref.f90 (+125)
  • (added) flang/test/Lower/OpenMP/parallel-reduction-byref.f90 (+44)
  • (added) flang/test/Lower/OpenMP/parallel-wsloop-reduction-byref.f90 (+16)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-add-byref.f90 (+433)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir-byref.f90 (+58)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-iand-byref.f90 (+64)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-ieor-byref.f90 (+55)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-ior-byref.f90 (+64)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-logical-and-byref.f90 (+206)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv-byref.f90 (+202)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-logical-neqv-byref.f90 (+207)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-logical-or-byref.f90 (+204)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-max-2-byref.f90 (+20)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-max-byref.f90 (+152)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir-byref.f90 (+62)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-min-byref.f90 (+154)
  • (added) flang/test/Lower/OpenMP/wsloop-reduction-mul-byref.f90 (+414)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+3-1)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+22-8)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+11-1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+3-2)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+72-27)
  • (added) mlir/test/Target/LLVMIR/openmp-reduction-byref.mlir (+66)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 185e0316870e94..d9648f3d692cc6 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -600,6 +600,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
     return reductionSymbols;
   };
 
+  mlir::UnitAttr byrefAttr;
+  if (ReductionProcessor::doReductionByRef(reductionVars))
+    byrefAttr = converter.getFirOpBuilder().getUnitAttr();
+
   OpWithBodyGenInfo genInfo =
       OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
           .setGenNested(genNested)
@@ -619,7 +623,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
             : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
                                    reductionDeclSymbols),
         procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
-        /*privatizers=*/nullptr);
+        /*privatizers=*/nullptr, byrefAttr);
   }
 
   bool privatize = !outerCombined;
@@ -683,7 +687,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
       delayedPrivatizationInfo.privatizers.empty()
           ? nullptr
           : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
-                                 privatizers));
+                                 privatizers),
+      byrefAttr);
 }
 
 static mlir::omp::SectionOp
@@ -1568,7 +1573,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
   mlir::omp::ClauseOrderKindAttr orderClauseOperand;
   mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
-  mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
+  mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
   mlir::IntegerAttr orderedClauseOperand;
   mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
   std::size_t loopVarTypeSize;
@@ -1585,6 +1590,9 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
   convertLoopBounds(converter, loc, lowerBound, upperBound, step,
                     loopVarTypeSize);
 
+  if (ReductionProcessor::doReductionByRef(reductionVars))
+    byrefOperand = firOpBuilder.getUnitAttr();
+
   auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
       loc, lowerBound, upperBound, step, linearVars, linearStepVars,
       reductionVars,
@@ -1594,8 +1602,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
                                  reductionDeclSymbols),
       scheduleValClauseOperand, scheduleChunkClauseOperand,
       /*schedule_modifiers=*/nullptr,
-      /*simd_modifier=*/nullptr, nowaitClauseOperand, orderedClauseOperand,
-      orderClauseOperand,
+      /*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
+      orderedClauseOperand, orderClauseOperand,
       /*inclusive=*/firOpBuilder.getUnitAttr());
 
   // Handle attribute based clauses.
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index a8b98f3f567249..34959e95dce04a 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -14,9 +14,16 @@
 
 #include "flang/Lower/AbstractConverter.h"
 #include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Parser/tools.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "llvm/Support/CommandLine.h"
+
+static llvm::cl::opt<bool> forceByrefReduction(
+    "force-byref-reduction",
+    llvm::cl::desc("Pass all reduction arguments by reference"),
+    llvm::cl::Hidden);
 
 namespace Fortran {
 namespace lower {
@@ -76,16 +83,24 @@ bool ReductionProcessor::supportedIntrinsicProcReduction(
 }
 
 std::string ReductionProcessor::getReductionName(llvm::StringRef name,
-                                                 mlir::Type ty) {
+                                                 mlir::Type ty, bool isByRef) {
+  ty = fir::unwrapRefType(ty);
+
+  // extra string to distinguish reduction functions for variables passed by
+  // reference
+  llvm::StringRef byrefAddition{""};
+  if (isByRef)
+    byrefAddition = "_byref";
+
   return (llvm::Twine(name) +
           (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
-          llvm::Twine(ty.getIntOrFloatBitWidth()))
+          llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition)
       .str();
 }
 
 std::string ReductionProcessor::getReductionName(
     Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-    mlir::Type ty) {
+    mlir::Type ty, bool isByRef) {
   std::string reductionName;
 
   switch (intrinsicOp) {
@@ -108,13 +123,14 @@ std::string ReductionProcessor::getReductionName(
     break;
   }
 
-  return getReductionName(reductionName, ty);
+  return getReductionName(reductionName, ty, isByRef);
 }
 
 mlir::Value
 ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
                                           ReductionIdentifier redId,
                                           fir::FirOpBuilder &builder) {
+  type = fir::unwrapRefType(type);
   assert((fir::isa_integer(type) || fir::isa_real(type) ||
           type.isa<fir::LogicalType>()) &&
          "only integer, logical and real types are currently supported");
@@ -188,6 +204,7 @@ mlir::Value ReductionProcessor::createScalarCombiner(
     fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
     mlir::Type type, mlir::Value op1, mlir::Value op2) {
   mlir::Value reductionOp;
+  type = fir::unwrapRefType(type);
   switch (redId) {
   case ReductionIdentifier::MAX:
     reductionOp =
@@ -268,7 +285,8 @@ mlir::Value ReductionProcessor::createScalarCombiner(
 
 mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
     fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
+    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
+    bool isByRef) {
   mlir::OpBuilder::InsertionGuard guard(builder);
   mlir::ModuleOp module = builder.getModule();
 
@@ -278,14 +296,24 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
     return decl;
 
   mlir::OpBuilder modBuilder(module.getBodyRegion());
+  mlir::Type valTy = fir::unwrapRefType(type);
+  if (!isByRef)
+    type = valTy;
 
   decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
                                                           type);
   builder.createBlock(&decl.getInitializerRegion(),
                       decl.getInitializerRegion().end(), {type}, {loc});
   builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+
   mlir::Value init = getReductionInitValue(loc, type, redId, builder);
-  builder.create<mlir::omp::YieldOp>(loc, init);
+  if (isByRef) {
+    mlir::Value alloca = builder.create<fir::AllocaOp>(loc, valTy);
+    builder.createStoreWithConvert(loc, init, alloca);
+    builder.create<mlir::omp::YieldOp>(loc, alloca);
+  } else {
+    builder.create<mlir::omp::YieldOp>(loc, init);
+  }
 
   builder.createBlock(&decl.getReductionRegion(),
                       decl.getReductionRegion().end(), {type, type},
@@ -294,14 +322,41 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+  mlir::Value outAddr = op1;
+
+  op1 = builder.loadIfRef(loc, op1);
+  op2 = builder.loadIfRef(loc, op2);
 
   mlir::Value reductionOp =
       createScalarCombiner(builder, loc, redId, type, op1, op2);
-  builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+  if (isByRef) {
+    builder.create<fir::StoreOp>(loc, reductionOp, outAddr);
+    builder.create<mlir::omp::YieldOp>(loc, outAddr);
+  } else {
+    builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+  }
 
   return decl;
 }
 
+bool ReductionProcessor::doReductionByRef(
+    const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
+  if (reductionVars.empty())
+    return false;
+  if (forceByrefReduction)
+    return true;
+
+  for (mlir::Value reductionVar : reductionVars) {
+    if (auto declare =
+            mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
+      reductionVar = declare.getMemref();
+
+    if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
+      return true;
+  }
+  return false;
+}
+
 void ReductionProcessor::addReductionDecl(
     mlir::Location currentLocation,
     Fortran::lower::AbstractConverter &converter,
@@ -315,6 +370,24 @@ void ReductionProcessor::addReductionDecl(
   const auto &redOperator{
       std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
   const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+
+  // initial pass to collect all recuction vars so we can figure out if this
+  // should happen byref
+  for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+    if (const auto *name{
+            Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+      if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+        if (reductionSymbols)
+          reductionSymbols->push_back(symbol);
+        mlir::Value symVal = converter.getSymbolAddress(*symbol);
+        if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+          symVal = declOp.getBase();
+        reductionVars.push_back(symVal);
+      }
+    }
+  }
+  const bool isByRef = doReductionByRef(reductionVars);
+
   if (const auto &redDefinedOp =
           std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
     const auto &intrinsicOp{
@@ -338,23 +411,20 @@ void ReductionProcessor::addReductionDecl(
       if (const auto *name{
               Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
         if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-          if (reductionSymbols)
-            reductionSymbols->push_back(symbol);
           mlir::Value symVal = converter.getSymbolAddress(*symbol);
           if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
             symVal = declOp.getBase();
-          mlir::Type redType =
-              symVal.getType().cast<fir::ReferenceType>().getEleTy();
-          reductionVars.push_back(symVal);
-          if (redType.isa<fir::LogicalType>())
+          auto redType = symVal.getType().cast<fir::ReferenceType>();
+          if (redType.getEleTy().isa<fir::LogicalType>())
             decl = createReductionDecl(
                 firOpBuilder,
-                getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
-                redType, currentLocation);
-          else if (redType.isIntOrIndexOrFloat()) {
-            decl = createReductionDecl(firOpBuilder,
-                                       getReductionName(intrinsicOp, redType),
-                                       redId, redType, currentLocation);
+                getReductionName(intrinsicOp, firOpBuilder.getI1Type(),
+                                 isByRef),
+                redId, redType, currentLocation, isByRef);
+          else if (redType.getEleTy().isIntOrIndexOrFloat()) {
+            decl = createReductionDecl(
+                firOpBuilder, getReductionName(intrinsicOp, redType, isByRef),
+                redId, redType, currentLocation, isByRef);
           } else {
             TODO(currentLocation, "Reduction of some types is not supported");
           }
@@ -374,21 +444,17 @@ void ReductionProcessor::addReductionDecl(
         if (const auto *name{
                 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
           if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-            if (reductionSymbols)
-              reductionSymbols->push_back(symbol);
             mlir::Value symVal = converter.getSymbolAddress(*symbol);
             if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
               symVal = declOp.getBase();
-            mlir::Type redType =
-                symVal.getType().cast<fir::ReferenceType>().getEleTy();
-            reductionVars.push_back(symVal);
-            assert(redType.isIntOrIndexOrFloat() &&
+            auto redType = symVal.getType().cast<fir::ReferenceType>();
+            assert(redType.getEleTy().isIntOrIndexOrFloat() &&
                    "Unsupported reduction type");
             decl = createReductionDecl(
                 firOpBuilder,
                 getReductionName(getRealName(*reductionIntrinsic).ToString(),
-                                 redType),
-                redId, redType, currentLocation);
+                                 redType, isByRef),
+                redId, redType, currentLocation, isByRef);
             reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
                 firOpBuilder.getContext(), decl.getSymName()));
           }
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 00770fe81d1ef6..679580f2a3cac7 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -14,6 +14,7 @@
 #define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
 
 #include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/type.h"
@@ -71,11 +72,15 @@ class ReductionProcessor {
   static const Fortran::semantics::SourceName
   getRealName(const Fortran::parser::ProcedureDesignator &pd);
 
-  static std::string getReductionName(llvm::StringRef name, mlir::Type ty);
+  static bool
+  doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);
+
+  static std::string getReductionName(llvm::StringRef name, mlir::Type ty,
+                                      bool isByRef);
 
   static std::string getReductionName(
       Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Type ty);
+      mlir::Type ty, bool isByRef);
 
   /// This function returns the identity value of the operator \p
   /// reductionOpName. For example:
@@ -103,9 +108,11 @@ class ReductionProcessor {
   /// symbol table. The declaration has a constant initializer with the neutral
   /// value `initValue`, and the reduction combiner carried over from `reduce`.
   /// TODO: Generalize this for non-integer types, add atomic region.
-  static mlir::omp::ReductionDeclareOp createReductionDecl(
-      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-      const ReductionIdentifier redId, mlir::Type type, mlir::Location loc);
+  static mlir::omp::ReductionDeclareOp
+  createReductionDecl(fir::FirOpBuilder &builder,
+                      llvm::StringRef reductionOpName,
+                      const ReductionIdentifier redId, mlir::Type type,
+                      mlir::Location loc, bool isByRef);
 
   /// Creates a reduction declaration and associates it with an OpenMP block
   /// directive.
@@ -124,6 +131,7 @@ mlir::Value
 ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
                                           mlir::Type type, mlir::Location loc,
                                           mlir::Value op1, mlir::Value op2) {
+  type = fir::unwrapRefType(type);
   assert(type.isIntOrIndexOrFloat() &&
          "only integer and float types are currently supported");
   if (type.isIntOrIndex())
diff --git a/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90 b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90
new file mode 100644
index 00000000000000..ca432662b77c44
--- /dev/null
+++ b/flang/test/Lower/OpenMP/FIR/parallel-reduction-add-byref.f90
@@ -0,0 +1,117 @@
+! RUN: bbc -emit-fir -hlfir=false -fopenmp --force-byref-reduction -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -flang-deprecated-no-hlfir -fopenmp -mmlir --force-byref-reduction -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_F32_NAME:.*]] : !fir.ref<f32>
+!CHECK-SAME: init {
+!CHECK: ^bb0(%{{.*}}: !fir.ref<f32>):
+!CHECK:  %[[C0_1:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK:  %[[REF:.*]] = fir.alloca f32
+!CHECKL  fir.store [[%C0_1]] to %[[REF]] : !fir.ref<f32>
+!CHECK:  omp.yield(%[[REF]] : !fir.ref<f32>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<f32>, %[[ARG1:.*]]: !fir.ref<f32>):
+!CHECK:  %[[LD0:.*]] = fir.load %[[ARG0]] : !fir.ref<f32>
+!CHECK:  %[[LD1:.*]] = fir.load %[[ARG1]] : !fir.ref<f32>
+!CHECK:  %[[RES:.*]] = arith.addf %[[LD0]], %[[LD1]] {{.*}}: f32
+!CHECK:  fir.store %[[RES]] to %[[ARG0]] : !fir.ref<f32>
+!CHECK:  omp.yield(%[[ARG0]] : !fir.ref<f32>)
+!CHECK: }
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_I32_NAME:.*]] : !fir.ref<i32>
+!CHECK-SAME: init {
+!CHECK: ^bb0(%{{.*}}: !fir.ref<i32>):
+!CHECK:  %[[C0_1:.*]] = arith.constant 0 : i32
+!CHECK:  %[[REF:.*]] = fir.alloca i32
+!CHECKL  fir.store [[%C0_1]] to %[[REF]] : !fir.ref<i32>
+!CHECK:  omp.yield(%[[REF]] : !fir.ref<i32>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<i32>, %[[ARG1:.*]]: !fir.ref<i32>):
+!CHECK:  %[[LD0:.*]] = fir.load %[[ARG0]] : !fir.ref<i32>
+!CHECK:  %[[LD1:.*]] = fir.load %[[ARG1]] : !fir.ref<i32>
+!CHECK:  %[[RES:.*]] = arith.addi %[[LD0]], %[[LD1]] : i32
+!CHECK:  fir.store %[[RES]] to %[[ARG0]] : !fir.ref<i32>
+!CHECK:  omp.yield(%[[ARG0]] : !fir.ref<i32>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_int_add
+!CHECK:  %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
+!CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK:  fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK:  omp.parallel byref reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
+!CHECK:    %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
+!CHECK:    %[[I_INCR:.+]] = arith.constant 1 : i32
+!CHECK:    %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
+!CHECK:    fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
+!CHECK:    omp.terminator
+!CHECK:  }
+!CHECK: return
+subroutine simple_int_add
+    integer :: i
+    i = 0
+
+    !$omp parallel reduction(+:i)
+    i = i + 1
+    !$omp end parallel
+
+    print *, i
+end subroutine
+
+!CHECK-LABEL: func.func @_QPsimple_real_add
+!CHECK:  %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
+!CHECK:  %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK:  fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK:  omp.parallel byref reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
+!CHECK:    %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
+!CHECK:    %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
+!CHECK:    %[[RES]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
+!CHECK:    fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
+!CHECK:    omp.terminator
+!CHECK:  }
+!CHECK: return
+subroutine simple_real_add
+    real :: r
+    r = 0.0
+
+    !$omp parallel reduction(+:r)
+    r = r + 1.5
+    !$omp end parallel
+
+    print *, r
+end subroutine
+
+!CHECK-LABEL: func.func @_QPint_real_add
+!CHECK:  %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFint_real_addEi"}
+!CHECK:  %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFint_real_addEr"}
+!CHECK:  %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK:  fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK:  %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK:  fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK:  omp.parallel byref reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
+!CHECK:    %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
+!CHECK:    %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
+!CHECK:    %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
+!CHECK:    fir.store %[[RES1]] to %[[PRV1]]
+!CHECK:    %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
+!CHECK:    %[[I_INCR:.*]] = arith.constant 3 : i32
+!CHECK:    %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
+!CHECK:    fir.store %[[RES0]] t...
[truncated]

@tblah
Copy link
Contributor Author

tblah commented Mar 7, 2024

co-authored with @Leporacanthicus (github seems to have taken the tag out of the commit message but shows it in the UI)

Base automatically changed from users/tblah/omp_byref_2 to main March 8, 2024 10:34
Copy link
Contributor

@Leporacanthicus Leporacanthicus left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

reductionVar = declare.getMemref();

if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
return true;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this imply that all reductions on a clause have to be by ref or by val? E.g. if we have an array reduction on the clause does that mean an integer reduction also changes to byref?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it does. I did this to keep things simpler. Currently byref vs byval is toggled over the whole wsloop or parallel region. A more sophisticated implementation could instead track this per reduction argument. I chose not to do this to keep things simple.

I suspect that in most cases, if an integer reduction and an array reduction are used together, the array reduction would take long enough that the performance loss from doing the integer reduction by reference would not be significant. I have not measured this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's probably true wrt performance, but I believe when it comes to openmp target reductions doing reductions on basic types will affect correctness. As far as I remember you do not need to ensure manually that for example the INTEGER exists on the target device, whereas you do for an array (and would need to if the integer is passed by reference).

I think fixing that in a subsequent patch is probably fine though, as long as we add a TODO mentioning that ideally it should be considered separately per argument.

This makes sure that the generated IR doesn't change as a result of this
PR. The generated IR looks wrong to me (no reduction is generated at
all), but that is a matter for another patch.
Copy link
Member

@DavidTruby DavidTruby left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@tblah tblah merged commit f46f5a0 into main Mar 13, 2024
6 checks passed
@tblah tblah deleted the users/tblah/omp_byref_3 branch March 13, 2024 14:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:openmp OpenMP related changes to Clang flang:codegen flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category mlir:llvm mlir:openmp mlir openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants