Skip to content

Conversation

@jsjodin
Copy link
Contributor

@jsjodin jsjodin commented Nov 17, 2025

This patch add support for lowering of custom reductions to MLIR. It also enhances the capability of the pass to automatically mark functions as "declare target" by traversing custom reduction initializers and combiners.

@llvmbot
Copy link
Member

llvmbot commented Nov 17, 2025

@llvm/pr-subscribers-offload
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-openmp

Author: Jan Leyonberg (jsjodin)

Changes

This patch add support for lowering of custom reductions to MLIR. It also enhances the capability of the pass to automatically mark functions as "declare target" by traversing custom reduction initializers and combiners.


Patch is 49.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168417.diff

14 Files Affected:

  • (modified) flang/include/flang/Lower/Support/ReductionProcessor.h (+18)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+60)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+4)
  • (modified) flang/lib/Lower/OpenMP/Clauses.cpp (+16-1)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+130-3)
  • (modified) flang/lib/Lower/Support/ReductionProcessor.cpp (+65-25)
  • (modified) flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp (+102-37)
  • (removed) flang/test/Lower/OpenMP/Todo/omp-declare-reduction-initsub.f90 (-28)
  • (removed) flang/test/Lower/OpenMP/Todo/omp-declare-reduction.f90 (-10)
  • (added) flang/test/Lower/OpenMP/declare-target-deferred-marking-reductions.f90 (+37)
  • (added) flang/test/Lower/OpenMP/omp-declare-reduction-derivedtype.f90 (+112)
  • (added) flang/test/Lower/OpenMP/omp-declare-reduction-initsub.f90 (+59)
  • (added) flang/test/Lower/OpenMP/omp-declare-reduction.f90 (+33)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+13-3)
diff --git a/flang/include/flang/Lower/Support/ReductionProcessor.h b/flang/include/flang/Lower/Support/ReductionProcessor.h
index 66f26b3b55630..bd0447360f089 100644
--- a/flang/include/flang/Lower/Support/ReductionProcessor.h
+++ b/flang/include/flang/Lower/Support/ReductionProcessor.h
@@ -40,6 +40,13 @@ namespace omp {
 
 class ReductionProcessor {
 public:
+  using GenInitValueCBTy =
+      std::function<mlir::Value(fir::FirOpBuilder &builder, mlir::Location loc,
+                                mlir::Type type, mlir::Value ompOrig)>;
+  using GenCombinerCBTy = std::function<void(
+      fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
+      mlir::Value op1, mlir::Value op2, bool isByRef)>;
+
   // TODO: Move this enumeration to the OpenMP dialect
   enum ReductionIdentifier {
     ID,
@@ -58,6 +65,9 @@ class ReductionProcessor {
     IEOR
   };
 
+  static bool doReductionByRef(mlir::Type reductionType);
+  static bool doReductionByRef(mlir::Value reductionVar);
+
   static ReductionIdentifier
   getReductionType(const omp::clause::ProcedureDesignator &pd);
 
@@ -109,6 +119,14 @@ class ReductionProcessor {
                                           ReductionIdentifier redId,
                                           mlir::Type type, mlir::Value op1,
                                           mlir::Value op2);
+  /// Creates an OpenMP reduction declaration and inserts it into the provided
+  /// symbol table. The init and combiner regions are generated by the callback
+  /// functions genCombinerCB and genInitValueCB.
+  template <typename DeclareRedType>
+  static DeclareRedType createDeclareReductionHelper(
+      AbstractConverter &converter, llvm::StringRef reductionOpName,
+      mlir::Type type, mlir::Location loc, bool isByRef,
+      GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB);
 
   /// Creates an OpenMP reduction declaration and inserts it into the provided
   /// symbol table. The declaration has a constant initializer with the neutral
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index e018a2d937435..fadfb29b07a28 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -13,6 +13,7 @@
 #include "ClauseProcessor.h"
 #include "Utils.h"
 
+#include "flang/Lower/ConvertCall.h"
 #include "flang/Lower/ConvertExprToHLFIR.h"
 #include "flang/Lower/OpenMP/Clauses.h"
 #include "flang/Lower/PFTBuilder.h"
@@ -402,6 +403,65 @@ bool ClauseProcessor::processInclusive(
   return false;
 }
 
+bool ClauseProcessor::processInitializer(
+    lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
+    ReductionProcessor::GenInitValueCBTy &genInitValueCB) const {
+  if (auto *clause = findUniqueClause<omp::clause::Initializer>()) {
+    genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc,
+                                 mlir::Type type, mlir::Value ompOrig) {
+      lower::SymMapScope scope(symMap);
+      const parser::OmpInitializerExpression &iexpr = inp.v.v;
+      const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
+      const std::list<parser::OmpStylizedDeclaration> &declList =
+          std::get<std::list<parser::OmpStylizedDeclaration>>(styleInstance.t);
+      mlir::Value ompPrivVar;
+      for (const parser::OmpStylizedDeclaration &decl : declList) {
+        auto &name = std::get<parser::ObjectName>(decl.var.t);
+        assert(name.symbol && "Name does not have a symbol");
+        mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
+        fir::StoreOp::create(builder, loc, ompOrig, addr);
+        fir::FortranVariableFlagsEnum extraFlags = {};
+        fir::FortranVariableFlagsAttr attributes =
+            Fortran::lower::translateSymbolAttributes(builder.getContext(),
+                                                      *name.symbol, extraFlags);
+        auto declareOp = hlfir::DeclareOp::create(
+            builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
+            0, attributes);
+        if (name.ToString() == "omp_priv")
+          ompPrivVar = declareOp.getResult(0);
+        symMap.addVariableDefinition(*name.symbol, declareOp);
+      }
+      // Lower the expression/function call
+      lower::StatementContext stmtCtx;
+      mlir::Value result = common::visit(
+          common::visitors{
+              [&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
+                convertCallToHLFIR(loc, converter, procRef, std::nullopt,
+                                   symMap, stmtCtx);
+                auto privVal = fir::LoadOp::create(builder, loc, ompPrivVar);
+                return privVal;
+              },
+              [&](const auto &expr) -> mlir::Value {
+                mlir::Value exprResult = fir::getBase(convertExprToValue(
+                    loc, converter, clause->v, symMap, stmtCtx));
+                // Conversion can either give a value or a refrence to a value,
+                // we need to return the reduction type, so an optional load may
+                // be generated.
+                if (auto refType = llvm::dyn_cast<fir::ReferenceType>(
+                        exprResult.getType()))
+                  if (ompPrivVar.getType() == refType)
+                    exprResult = fir::LoadOp::create(builder, loc, exprResult);
+                return exprResult;
+              }},
+          clause->v.u);
+      stmtCtx.finalizeAndPop();
+      return result;
+    };
+    return true;
+  }
+  return false;
+}
+
 bool ClauseProcessor::processMergeable(
     mlir::omp::MergeableClauseOps &result) const {
   return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index d524b4ddc8ac4..529b871330052 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -18,6 +18,7 @@
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/OpenMP/Clauses.h"
+#include "flang/Lower/Support/ReductionProcessor.h"
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Parser/dump-parse-tree.h"
 #include "flang/Parser/parse-tree.h"
@@ -88,6 +89,9 @@ class ClauseProcessor {
   bool processHint(mlir::omp::HintClauseOps &result) const;
   bool processInclusive(mlir::Location currentLocation,
                         mlir::omp::InclusiveClauseOps &result) const;
+  bool processInitializer(
+      lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
+      ReductionProcessor::GenInitValueCBTy &genInitValueCB) const;
   bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
   bool processNogroup(mlir::omp::NogroupClauseOps &result) const;
   bool processNowait(mlir::omp::NowaitClauseOps &result) const;
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index b1a3c3d3c5439..cf8d9a7ee6596 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -981,7 +981,22 @@ Init make(const parser::OmpClause::Init &inp,
 
 Initializer make(const parser::OmpClause::Initializer &inp,
                  semantics::SemanticsContext &semaCtx) {
-  llvm_unreachable("Empty: initializer");
+  const parser::OmpInitializerExpression &iexpr = inp.v.v;
+  const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
+  const parser::OmpStylizedInstance::Instance &instance =
+      std::get<parser::OmpStylizedInstance::Instance>(styleInstance.t);
+  if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
+    auto &expr = std::get<parser::Expr>(as->t);
+    return Initializer{makeExpr(expr, semaCtx)};
+  } else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
+    if (call->typedCall) {
+      const auto &procRef = *call->typedCall;
+      semantics::SomeExpr evalProcRef{procRef};
+      return Initializer{evalProcRef};
+    }
+  } else {
+    llvm_unreachable("Unexpected initializer");
+  }
 }
 
 InReduction make(const parser::OmpClause::InReduction &inp,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f822fe3c8dd71..c4b174db8ac22 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -18,12 +18,15 @@
 #include "Decomposer.h"
 #include "Utils.h"
 #include "flang/Common/idioms.h"
+#include "flang/Evaluate/type.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertExpr.h"
+#include "flang/Lower/ConvertExprToHLFIR.h"
 #include "flang/Lower/ConvertVariable.h"
 #include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/OpenMP/Clauses.h"
 #include "flang/Lower/StatementContext.h"
+#include "flang/Lower/Support/ReductionProcessor.h"
 #include "flang/Lower/SymbolMap.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -2847,7 +2850,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   // TODO: Add private syms and vars.
   args.reduction.syms = reductionSyms;
   args.reduction.vars = clauseOps.reductionVars;
-
   return genOpWithBody<mlir::omp::TeamsOp>(
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
                         llvm::omp::Directive::OMPD_teams)
@@ -3563,12 +3565,137 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
     TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
 }
 
+static bool
+processReductionCombiner(lower::AbstractConverter &converter,
+                         lower::SymMap &symTable,
+                         semantics::SemanticsContext &semaCtx,
+                         const parser::OmpReductionSpecifier &specifier,
+                         ReductionProcessor::GenCombinerCBTy &genCombinerCB) {
+  const auto &combinerExpression =
+      std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
+          .value();
+  const parser::OmpStylizedInstance &combinerInstance =
+      combinerExpression.v.front();
+  const parser::OmpStylizedInstance::Instance &instance =
+    std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
+  if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
+    auto &expr = std::get<parser::Expr>(as->t);
+    genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
+                        mlir::Type type, mlir::Value lhs, mlir::Value rhs,
+                        bool isByRef) {
+      const auto &evalExpr = makeExpr(expr, semaCtx);
+      lower::SymMapScope scope(symTable);
+      const std::list<parser::OmpStylizedDeclaration> &declList =
+        std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
+      for (const parser::OmpStylizedDeclaration &decl : declList) {
+        auto &name = std::get<parser::ObjectName>(decl.var.t);
+        mlir::Value addr = lhs;
+        mlir::Type type = lhs.getType();
+        bool isRhs = name.ToString() == std::string("omp_in");
+        if (isRhs) {
+          addr = rhs;
+          type = rhs.getType();
+        }
+
+        assert(name.symbol && "Reduction object name does not have a symbol");
+        if (!fir::conformsWithPassByRef(type)) {
+            addr = builder.createTemporary(loc, type);
+            fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
+        }
+        fir::FortranVariableFlagsEnum extraFlags = {};
+        fir::FortranVariableFlagsAttr attributes =
+          Fortran::lower::translateSymbolAttributes(builder.getContext(),
+                                                    *name.symbol, extraFlags);
+        auto declareOp = hlfir::DeclareOp::create(
+            builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
+            0, attributes);
+        symTable.addVariableDefinition(*name.symbol, declareOp);
+      }
+
+      lower::StatementContext stmtCtx;
+      mlir::Value result = fir::getBase(
+          convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
+      if (auto refType =
+          llvm::dyn_cast<fir::ReferenceType>(result.getType()))
+        if (lhs.getType() == refType.getElementType())
+          result = fir::LoadOp::create(builder, loc, result);
+      stmtCtx.finalizeAndPop();
+      if (isByRef) {
+        fir::StoreOp::create(builder, loc, result, lhs);
+        mlir::omp::YieldOp::create(builder, loc, lhs);
+      } else {
+        mlir::omp::YieldOp::create(builder, loc, result);
+      }
+
+      return result;
+    };
+  }
+  return true;
+}
+
+// Getting the type from a symbol compared to a DeclSpec is simpler since we do
+// not need to consider derived vs intrinsic types. Semantics is guaranteed to
+// generate these symbols.
+static mlir::Type
+getReductionType(lower::AbstractConverter &converter,
+                 const parser::OmpReductionSpecifier &specifier) {
+  const auto &combinerExpression =
+      std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
+          .value();
+  const parser::OmpStylizedInstance &combinerInstance =
+      combinerExpression.v.front();
+  const std::list<parser::OmpStylizedDeclaration> &declList =
+      std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
+  const parser::OmpStylizedDeclaration &decl = declList.front();
+  const auto &name = std::get<parser::ObjectName>(decl.var.t);
+  const auto &symbol = semantics::SymbolRef(*name.symbol);
+  mlir::Type reductionType = converter.genType(symbol);
+  return reductionType;
+}
+
 static void genOMP(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
-  if (!semaCtx.langOptions().OpenMPSimd)
-    TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct");
+  if (!semaCtx.langOptions().OpenMPSimd) {
+    const parser::OmpArgumentList &args{
+        declareReductionConstruct.v.Arguments()};
+    const parser::OmpArgument &arg{args.v.front()};
+    const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);
+
+    if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
+      TODO(converter.getCurrentLocation(),
+           "multiple types in declare target is not yet supported");
+
+    mlir::Type reductionType = getReductionType(converter, specifier);
+    ReductionProcessor::GenCombinerCBTy genCombinerCB;
+    processReductionCombiner(converter, symTable, semaCtx, specifier,
+                             genCombinerCB);
+    const parser::OmpClauseList &initializer =
+        declareReductionConstruct.v.Clauses();
+    if (initializer.v.size() > 0) {
+      List<Clause> clauses = makeClauses(initializer, semaCtx);
+      ReductionProcessor::GenInitValueCBTy genInitValueCB;
+      ClauseProcessor cp(converter, semaCtx, clauses);
+      const parser::OmpClause::Initializer &iclause{
+          std::get<parser::OmpClause::Initializer>(initializer.v.front().u)};
+      cp.processInitializer(symTable, iclause, genInitValueCB);
+      const auto &identifier =
+          std::get<parser::OmpReductionIdentifier>(specifier.t);
+      const auto &designator =
+          std::get<parser::ProcedureDesignator>(identifier.u);
+      const auto &reductionName = std::get<parser::Name>(designator.u);
+      bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
+      ReductionProcessor::createDeclareReductionHelper<
+          mlir::omp::DeclareReductionOp>(
+          converter, reductionName.ToString(), reductionType,
+          converter.getCurrentLocation(), isByRef, genCombinerCB,
+          genInitValueCB);
+    } else {
+      TODO(converter.getCurrentLocation(),
+           "declare target without an initializer clause is not yet supported");
+    }
+  }
 }
 
 static void
diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp
index 605a5b6b20b94..283e5ea73c319 100644
--- a/flang/lib/Lower/Support/ReductionProcessor.cpp
+++ b/flang/lib/Lower/Support/ReductionProcessor.cpp
@@ -462,7 +462,7 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
                         bool isByRef) {
   ty = fir::unwrapRefType(ty);
 
-  if (fir::isa_trivial(ty)) {
+  if (fir::isa_trivial(ty) || fir::isa_derived(ty)) {
     mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
     mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
 
@@ -501,7 +501,7 @@ static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
 template <typename OpType>
 static void createReductionAllocAndInitRegions(
     AbstractConverter &converter, mlir::Location loc, OpType &reductionDecl,
-    const ReductionProcessor::ReductionIdentifier redId, mlir::Type type,
+    ReductionProcessor::GenInitValueCBTy genInitValueCB, mlir::Type type,
     bool isByRef) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   auto yield = [&](mlir::Value ret) { genYield<OpType>(builder, loc, ret); };
@@ -523,9 +523,8 @@ static void createReductionAllocAndInitRegions(
 
   mlir::Type ty = fir::unwrapRefType(type);
   builder.setInsertionPointToEnd(initBlock);
-  mlir::Value initValue = ReductionProcessor::getReductionInitValue(
-      loc, unwrapSeqOrBoxedType(ty), redId, builder);
-
+  mlir::Value initValue =
+      genInitValueCB(builder, loc, ty, initBlock->getArgument(0));
   if (isByRef) {
     populateByRefInitAndCleanupRegions(
         converter, loc, type, initValue, initBlock,
@@ -536,7 +535,7 @@ static void createReductionAllocAndInitRegions(
         /*isDoConcurrent*/ std::is_same_v<OpType, fir::DeclareReductionOp>);
   }
 
-  if (fir::isa_trivial(ty)) {
+  if (fir::isa_trivial(ty) || fir::isa_derived(ty)) {
     if (isByRef) {
       // alloc region
       builder.setInsertionPointToEnd(allocBlock);
@@ -556,18 +555,18 @@ static void createReductionAllocAndInitRegions(
   yield(boxAlloca);
 }
 
-template <typename OpType>
-OpType ReductionProcessor::createDeclareReduction(
+template <typename DeclareRedType>
+DeclareRedType ReductionProcessor::createDeclareReductionHelper(
     AbstractConverter &converter, llvm::StringRef reductionOpName,
-    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
-    bool isByRef) {
+    mlir::Type type, mlir::Location loc, bool isByRef,
+    GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   mlir::OpBuilder::InsertionGuard guard(builder);
   mlir::ModuleOp module = builder.getModule();
 
   assert(!reductionOpName.empty());
 
-  auto decl = module.lookupSymbol<OpType>(reductionOpName);
+  auto decl = module.lookupSymbol<DeclareRedType>(reductionOpName);
   if (decl)
     return decl;
 
@@ -576,23 +575,54 @@ OpType ReductionProcessor::createDeclareReduction(
   if (!isByRef)
     type = valTy;
 
-  decl = OpType::create(modBuilder, loc, reductionOpName, type);
-  createReductionAllocAndInitRegions(converter, loc, decl, redId, type,
+  decl = DeclareRedType::create(modBuilder, loc, reductionOpName, type);
+  createReductionAllocAndInitRegions(converter, loc, decl, genInitValueCB, type,
                                      isByRef);
-
   builder.createBlock(&decl.getReductionRegion(),
                       decl.getReductionRegion().end(), {type, type},
                       {loc, loc});
-
   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-  genCombiner<OpType>(builder, loc, redId, type, op1, op2, isByRef);
-
+  genCombinerCB(builder, loc, type, op1, op2, isByRef);
   return decl;
 }
 
-static bool doReductionByRef(mlir::Value reductionVar) {
+template <typename OpType>
+OpType ReductionProcessor::createDeclareReduction(
+    AbstractConverter &converter, llvm::StringRef reductionOpName,
+    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
+    bool isByRef) {
+  auto genInitValueCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
+                            mlir::Type type, mlir::Value val) {
+    mlir::Type ty = fir::unwrapRefType(type);
+    mlir::Value initValue = ReductionProcessor::getReductionInitValue(
+        loc, unwrapSeqOrBoxedType(ty), redId, builder);...
[truncated]

@github-actions
Copy link

github-actions bot commented Nov 17, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@github-actions
Copy link

github-actions bot commented Nov 17, 2025

🐧 Linux x64 Test Results

  • 7362 tests passed
  • 595 tests skipped

Copy link
Contributor

@agozillon agozillon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR LGTM, I left a couple of nits but they're mainly questions, only do them if they make sense :-) Otherwise, if it's possible yet, would it be possible to add some runtime tests to the offload/fortran directory? Just to make sure everything remains working as intended from end to end!


bool ReductionProcessor::doReductionByRef(mlir::Value reductionVar) {
if (forceByrefReduction)
return true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIt: Might not be possible or make sense, but since we call doRecutionByRef that checks the same thing, is it possible to merge it into just one check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about skipping the early return but decided to keep it. It seemed to make it more obvious that it always returns true if the flag it set.

Copy link
Member

@ergawy ergawy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Jan, took an initial look. Just have a few comments.

@jsjodin jsjodin requested a review from ergawy November 18, 2025 20:22
Copy link
Member

@ergawy ergawy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks Jan!

@jsjodin jsjodin force-pushed the jleyonberg/custom-reductions-up branch from 50c7cf9 to 1d14774 Compare November 20, 2025 15:38
Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for contributing this. The overall design looks good to me but I think we need to do more to correctly init and de-init derived types.

auto &name = std::get<parser::ObjectName>(decl.var.t);
assert(name.symbol && "Name does not have a symbol");
mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
fir::StoreOp::create(builder, loc, ompOrig, addr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the guarantee that ompOrig is a primitive type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't have to be primitive. It can be a derived type, as long as all the data can be contained in one unit. I would like to restrict the current PR to these simpler types. I can add a check that the reduction type to not include boxed types/references and put a TODO if that acceptable?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay yeah that works for me. Thanks!

!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: [[MAXTYPE]]):
!CHECK: %[[OMP_PRIV:.*]] = fir.alloca [[MAXTYPE]]
!CHECK: %[[OMP_ORIG:.*]] = fir.alloca [[MAXTYPE]]
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<[[MAXTYPE]]>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised that a simple store works for non-scalar types.

But I think we need more than this. Derived types can have initialisers and destructors (as can members of the derived type...). I suggest you use the init region generation for privatization (which correctly handles this) and then add the user initialization code after it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, for the general case this is not enough, but for the simpler types addressed in this PR it works. We are not able to generate code for the more complicated types for target offloading yet since they result in library calls that cause link errors.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I understand this PR, this will also apply to CPU reductions (thank you for implementing this!). Please could you add a test to make sure a todo message is generated for these types that need library calls instead of a misscompile.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the TODO and a test to limit the reduction type.

@jsjodin jsjodin requested a review from tblah November 20, 2025 17:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants