Skip to content

Conversation

@NouTimbaler
Copy link

This patch is a follow-up from #161213 and adds the omp.fuse loop transformation for the OpenMP dialect. Used for lowering a !$omp fuse in Flang.
Added Lowering and end2end tests.

@github-actions
Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot llvmbot added mlir:llvm mlir flang Flang issues not falling into any other category mlir:openmp flang:fir-hlfir flang:openmp flang:semantics flang:parser clang:openmp OpenMP related changes to Clang openmp:libomp OpenMP host runtime labels Nov 20, 2025
@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2025

@llvm/pr-subscribers-flang-parser
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-openmp

Author: Ferran Toda (NouTimbaler)

Changes

This patch is a follow-up from #161213 and adds the omp.fuse loop transformation for the OpenMP dialect. Used for lowering a !$omp fuse in Flang.
Added Lowering and end2end tests.


Patch is 124.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168898.diff

41 Files Affected:

  • (modified) flang/include/flang/Parser/openmp-utils.h (+3)
  • (modified) flang/include/flang/Semantics/openmp-directive-sets.h (+7)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+1)
  • (modified) flang/lib/Lower/OpenMP/Clauses.cpp (+4-1)
  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+2-1)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+83-23)
  • (modified) flang/lib/Lower/OpenMP/Utils.cpp (+17-11)
  • (modified) flang/lib/Lower/OpenMP/Utils.h (+4-2)
  • (modified) flang/lib/Parser/openmp-parsers.cpp (+1)
  • (modified) flang/lib/Parser/openmp-utils.cpp (+17)
  • (modified) flang/lib/Semantics/canonicalize-omp.cpp (+72-45)
  • (modified) flang/lib/Semantics/check-omp-loop.cpp (+106-33)
  • (modified) flang/lib/Semantics/check-omp-structure.cpp (+5-3)
  • (modified) flang/lib/Semantics/check-omp-structure.h (+2)
  • (modified) flang/lib/Semantics/resolve-directives.cpp (+109-101)
  • (modified) flang/lib/Semantics/rewrite-parse-tree.cpp (+24-15)
  • (added) flang/test/Lower/OpenMP/fuse01.f90 (+93)
  • (added) flang/test/Lower/OpenMP/fuse02.f90 (+123)
  • (added) flang/test/Parser/OpenMP/fail-looprange.f90 (+11)
  • (added) flang/test/Parser/OpenMP/fuse-looprange.f90 (+38)
  • (added) flang/test/Parser/OpenMP/fuse01.f90 (+28)
  • (added) flang/test/Parser/OpenMP/fuse02.f90 (+97)
  • (added) flang/test/Parser/OpenMP/loop-transformation-construct04.f90 (+80)
  • (added) flang/test/Parser/OpenMP/loop-transformation-construct05.f90 (+90)
  • (added) flang/test/Semantics/OpenMP/loop-transformation-clauses01.f90 (+66)
  • (modified) flang/test/Semantics/OpenMP/loop-transformation-construct01.f90 (+2-2)
  • (added) flang/test/Semantics/OpenMP/loop-transformation-construct02.f90 (+93)
  • (added) flang/test/Semantics/OpenMP/loop-transformation-construct03.f90 (+39)
  • (added) flang/test/Semantics/OpenMP/loop-transformation-construct04.f90 (+47)
  • (modified) flang/test/Semantics/OpenMP/tile02.f90 (+1-1)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+53)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+111)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+34)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+68)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+54)
  • (added) mlir/test/Dialect/OpenMP/cli-fuse.mlir (+114)
  • (added) mlir/test/Dialect/OpenMP/invalid-fuse.mlir (+100)
  • (added) mlir/test/Target/LLVMIR/openmp-cli-fuse01.mlir (+100)
  • (added) mlir/test/Target/LLVMIR/openmp-cli-fuse02.mlir (+140)
  • (added) openmp/runtime/test/transform/fuse/do-looprange.f90 (+60)
  • (added) openmp/runtime/test/transform/fuse/do.f90 (+52)
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index 36556f8dd7f4a..7396e57144b90 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -123,6 +123,9 @@ template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
 const OpenMPDeclarativeConstruct *GetOmp(const DeclarationConstruct &x);
 const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x);
 
+const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x);
+const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x);
+
 const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
 
 template <typename T>
diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index 01e8481e05721..609a7be700c28 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -275,10 +275,17 @@ static const OmpDirectiveSet loopConstructSet{
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
     Directive::OMPD_teams_loop,
+    Directive::OMPD_fuse,
     Directive::OMPD_tile,
     Directive::OMPD_unroll,
 };
 
+static const OmpDirectiveSet loopTransformationSet{
+    Directive::OMPD_tile,
+    Directive::OMPD_unroll,
+    Directive::OMPD_fuse,
+};
+
 static const OmpDirectiveSet nonPartialVarSet{
     Directive::OMPD_allocate,
     Directive::OMPD_allocators,
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 4a392381287d5..ab3a174c7ad69 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -279,6 +279,7 @@ bool ClauseProcessor::processCollapse(
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
 
   int64_t numCollapse = collectLoopRelatedInfo(converter, currentLocation, eval,
+                                               eval.getFirstNestedEvaluation(),
                                                clauses, loopResult, iv);
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   collapseResult.collapseNumLoops = firOpBuilder.getI64IntegerAttr(numCollapse);
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index b1a3c3d3c5439..f2defc62dce91 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1063,7 +1063,10 @@ Link make(const parser::OmpClause::Link &inp,
 
 LoopRange make(const parser::OmpClause::Looprange &inp,
                semantics::SemanticsContext &semaCtx) {
-  llvm_unreachable("Unimplemented: looprange");
+  auto &t0 = std::get<0>(inp.v.t);
+  auto &t1 = std::get<1>(inp.v.t);
+  return LoopRange{{/*First*/ makeExpr(t0, semaCtx),
+                    /*Count*/ makeExpr(t1, semaCtx)}};
 }
 
 Map make(const parser::OmpClause::Map &inp,
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 83c2eda0a2dc7..da9480123513f 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -347,7 +347,8 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
     mlir::omp::LoopRelatedClauseOps result;
     llvm::SmallVector<const semantics::Symbol *> iv;
     collectLoopRelatedInfo(converter, converter.getCurrentLocation(), eval,
-                           clauses, result, iv);
+                           eval.getFirstNestedEvaluation(), clauses, result,
+                           iv);
 
     // Update the original variable just before exiting the worksharing
     // loop. Conversion as follows:
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c6487349c4056..2d981f421a4ae 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1982,9 +1982,9 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 static void genCanonicalLoopNest(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
-    mlir::Location loc, const ConstructQueue &queue,
-    ConstructQueue::const_iterator item, size_t numLoops,
-    llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
+    lower::pft::Evaluation &nestedEval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item,
+    size_t numLoops, llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
   assert(loops.empty() && "Expecting empty list to fill");
   assert(numLoops >= 1 && "Expecting at least one loop");
 
@@ -1992,7 +1992,8 @@ static void genCanonicalLoopNest(
 
   mlir::omp::LoopRelatedClauseOps loopInfo;
   llvm::SmallVector<const semantics::Symbol *, 3> ivs;
-  collectLoopRelatedInfo(converter, loc, eval, numLoops, loopInfo, ivs);
+  collectLoopRelatedInfo(converter, loc, eval, nestedEval, numLoops, loopInfo,
+                         ivs);
   assert(ivs.size() == numLoops &&
          "Expected to parse as many loop variables as there are loops");
 
@@ -2014,7 +2015,7 @@ static void genCanonicalLoopNest(
 
   // Step 1: Loop prologues
   // Computing the trip count must happen before entering the outermost loop
-  lower::pft::Evaluation *innermostEval = &eval.getFirstNestedEvaluation();
+  lower::pft::Evaluation *innermostEval = &nestedEval;
   for ([[maybe_unused]] auto iv : ivs) {
     if (innermostEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
       // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct.
@@ -2186,7 +2187,8 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::omp::CanonicalLoopOp, 3> canonLoops;
   canonLoops.reserve(numLoops);
 
-  genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item,
+  genCanonicalLoopNest(converter, symTable, semaCtx, eval,
+                       eval.getFirstNestedEvaluation(), loc, queue, item,
                        numLoops, canonLoops);
   assert((canonLoops.size() == numLoops) &&
          "Expecting the predetermined number of loops");
@@ -2217,6 +2219,58 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter,
                             sizesClause.sizes);
 }
 
+static void genFuseOp(Fortran::lower::AbstractConverter &converter,
+                      Fortran::lower::SymMap &symTable,
+                      lower::StatementContext &stmtCtx,
+                      Fortran::semantics::SemanticsContext &semaCtx,
+                      Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+                      const ConstructQueue &queue,
+                      ConstructQueue::const_iterator item) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  int32_t first = 0;
+  int32_t count = 0;
+  auto iter = llvm::find_if(item->clauses, [](const Clause &clause) {
+    return clause.id == llvm::omp::Clause::OMPC_looprange;
+  });
+  if (iter != item->clauses.end()) {
+    const auto &looprange = std::get<clause::LoopRange>(iter->u);
+    first = evaluate::ToInt64(std::get<0>(looprange.t)).value();
+    count = evaluate::ToInt64(std::get<1>(looprange.t)).value();
+  }
+
+  llvm::SmallVector<mlir::Value> applyees;
+  for (auto &child : eval.getNestedEvaluations()) {
+    // Skip OmpEndLoopDirective
+    if (&child == &eval.getLastNestedEvaluation())
+      break;
+
+    // Emit the associated loop
+    llvm::SmallVector<mlir::omp::CanonicalLoopOp> canonLoops;
+    genCanonicalLoopNest(converter, symTable, semaCtx, eval, child, loc, queue,
+                         item, 1, canonLoops);
+
+    auto cli = llvm::getSingleElement(canonLoops).getCli();
+    applyees.push_back(cli);
+  }
+  // One generated loop + one for each loop not inside the specified looprange
+  // if present
+  llvm::SmallVector<mlir::Value> generatees;
+  int64_t numGeneratees = count == 0 ? 1 : applyees.size() - count + 1;
+  for (int i = 0; i < numGeneratees; i++) {
+    auto fusedCLI = mlir::omp::NewCliOp::create(firOpBuilder, loc);
+    generatees.push_back(fusedCLI);
+  }
+  auto op = mlir::omp::FuseOp::create(firOpBuilder, loc, generatees, applyees);
+
+  if (count != 0) {
+    mlir::IntegerAttr firstAttr = firOpBuilder.getI32IntegerAttr(first);
+    mlir::IntegerAttr countAttr = firOpBuilder.getI32IntegerAttr(count);
+    op->setAttr("first", firstAttr);
+    op->setAttr("count", countAttr);
+  }
+}
+
 static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
                         Fortran::lower::SymMap &symTable,
                         lower::StatementContext &stmtCtx,
@@ -2233,7 +2287,8 @@ static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
 
   // Emit the associated loop
   llvm::SmallVector<mlir::omp::CanonicalLoopOp, 1> canonLoops;
-  genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, 1,
+  genCanonicalLoopNest(converter, symTable, semaCtx, eval,
+                       eval.getFirstNestedEvaluation(), loc, queue, item, 1,
                        canonLoops);
 
   llvm::SmallVector<mlir::Value, 1> applyees;
@@ -3507,6 +3562,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
   case llvm::omp::Directive::OMPD_tile:
     genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
     break;
+  case llvm::omp::Directive::OMPD_fuse:
+    genFuseOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
+    break;
   case llvm::omp::Directive::OMPD_unroll:
     genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
     break;
@@ -3962,22 +4020,24 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
   mlir::Location currentLocation = converter.genLocation(beginSpec.source);
 
-  if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
-          loopConstruct.GetNestedConstruct()) {
-    llvm::omp::Directive nestedDirective =
-        parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
-    switch (nestedDirective) {
-    case llvm::omp::Directive::OMPD_tile:
-      // Skip OMPD_tile since the tile sizes will be retrieved when
-      // generating the omp.loop_nest op.
-      break;
-    default: {
-      unsigned version = semaCtx.langOptions().OpenMPVersion;
-      TODO(currentLocation,
-           "Applying a loop-associated on the loop generated by the " +
-               llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
-               " construct");
-    }
+  for (auto &construct : std::get<parser::Block>(loopConstruct.t)) {
+    if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
+            parser::omp::GetOmpLoop(construct)) {
+      llvm::omp::Directive nestedDirective =
+          parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
+      switch (nestedDirective) {
+      case llvm::omp::Directive::OMPD_tile:
+        // Skip OMPD_tile since the tile sizes will be retrieved when
+        // generating the omp.loop_nest op.
+        break;
+      default: {
+        unsigned version = semaCtx.langOptions().OpenMPVersion;
+        TODO(currentLocation,
+             "Applying a loop-associated on the loop generated by the " +
+                 llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
+                 " construct");
+      }
+      }
     }
   }
 
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 7d7a4869ab3a6..913e4d1e69500 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -812,13 +812,14 @@ void collectTileSizesFromOpenMPConstruct(
 
 int64_t collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
   int64_t numCollapse = 1;
 
   // Collect the loops to collapse.
-  lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
+  lower::pft::Evaluation *doConstructEval = &nestedEval;
   if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
     TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
   }
@@ -830,21 +831,21 @@ int64_t collectLoopRelatedInfo(
     numCollapse = collapseValue;
   }
 
-  collectLoopRelatedInfo(converter, currentLocation, eval, numCollapse, result,
-                         iv);
+  collectLoopRelatedInfo(converter, currentLocation, eval, nestedEval,
+                         numCollapse, result, iv);
   return numCollapse;
 }
 
 void collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, int64_t numCollapse,
-    mlir::omp::LoopRelatedClauseOps &result,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    int64_t numCollapse, mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
 
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
   // Collect the loops to collapse.
-  lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
+  lower::pft::Evaluation *doConstructEval = &nestedEval;
   if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
     TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
   }
@@ -852,10 +853,15 @@ void collectLoopRelatedInfo(
   // Collect sizes from tile directive if present.
   std::int64_t sizesLengthValue = 0l;
   if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
-    processTileSizesFromOpenMPConstruct(
-        ompCons, [&](const parser::OmpClause::Sizes *tclause) {
-          sizesLengthValue = tclause->v.size();
-        });
+    if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
+      const parser::OmpDirectiveSpecification &beginSpec{ompLoop->BeginDir()};
+      if (beginSpec.DirId() == llvm::omp::Directive::OMPD_tile) {
+        processTileSizesFromOpenMPConstruct(
+            ompCons, [&](const parser::OmpClause::Sizes *tclause) {
+              sizesLengthValue = tclause->v.size();
+            });
+      }
+    }
   }
 
   std::int64_t collapseValue = std::max(numCollapse, sizesLengthValue);
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 2960b663b08b2..886a5c1835f7e 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -169,13 +169,15 @@ void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,
 
 int64_t collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
 
 void collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, std::int64_t collapseValue,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    std::int64_t collapseValue,
     // const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index e2da60ed19de8..231eea8841d4b 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -2260,6 +2260,7 @@ static constexpr DirectiveSet GetLoopDirectives() {
       unsigned(Directive::OMPD_teams_distribute_parallel_do_simd),
       unsigned(Directive::OMPD_teams_distribute_simd),
       unsigned(Directive::OMPD_teams_loop),
+      unsigned(Directive::OMPD_fuse),
       unsigned(Directive::OMPD_tile),
       unsigned(Directive::OMPD_unroll),
   };
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index 2424828293c73..dfe8dbdd5ac9e 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/openmp-utils.cpp
@@ -41,6 +41,23 @@ const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x) {
   return nullptr;
 }
 
+const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x) {
+  if (auto *construct{GetOmp(x)}) {
+    if (auto *omp{std::get_if<OpenMPLoopConstruct>(&construct->u)}) {
+      return omp;
+    }
+  }
+  return nullptr;
+}
+const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x) {
+  if (auto *y{std::get_if<ExecutableConstruct>(&x.u)}) {
+    if (auto *z{std::get_if<common::Indirection<DoConstruct>>(&y->u)}) {
+      return &z->value();
+    }
+  }
+  return nullptr;
+}
+
 const OmpObjectList *GetOmpObjectList(const OmpClause &clause) {
   // Clauses with OmpObjectList as its data member
   using MemberObjectListClauses = std::tuple<OmpClause::Copyin,
diff --git a/flang/lib/Semantics/canonicalize-omp.cpp b/flang/lib/Semantics/canonicalize-omp.cpp
index 0cec1969e0978..f7c53d6d8f4c4 100644
--- a/flang/lib/Semantics/canonicalize-omp.cpp
+++ b/flang/lib/Semantics/canonicalize-omp.cpp
@@ -9,6 +9,7 @@
 #include "canonicalize-omp.h"
 #include "flang/Parser/parse-tree-visitor.h"
 #include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/openmp-directive-sets.h"
 #include "flang/Semantics/semantics.h"
 
 // After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
@@ -136,20 +137,30 @@ class CanonicalizationOfOmp {
           "A DO loop must follow the %s directive"_err_en_US,
           parser::ToUpperCaseLetters(dirName.source.ToString()));
     };
-    auto tileUnrollError = [](const parser::OmpDirectiveName &dirName,
-                               parser::Messages &messages) {
+    auto transformUnrollError = [](const parser::OmpDirectiveName &dirName,
+                                    parser::Messages &messages) {
       messages.Say(dirName.source,
-          "If a loop construct has been fully unrolled, it cannot then be tiled"_err_en_US,
+          "If a loop construct has been fully unrolled, it cannot then be further transformed"_err_en_US,
           parser::ToUpperCaseLetters(dirName.source.ToString()));
     };
+    auto missingEndFuse = [](auto &dir, auto &messages) {
+      messages.Say(dir.source,
+          "The %s construct requires the END FUSE directive"_err_en_US,
+          parser::ToUpperCaseLetters(dir.source.ToString()));
+    };
+
+    bool endFuseNeeded = beginName.v == llvm::omp::Directive::OMPD_fuse;
 
     auto &body{std::get<parser::Block>(x.t)};
 
     nextIt = it;
-    while (++nextIt != block.end()) {
+    nextIt++;
+    while (nextIt != block.end()) {
       // Ignore compiler directives.
-      if (GetConstructIf<parser::CompilerDirective>(*nextIt))
+      if (GetConstructIf<parser::CompilerDirective>(*nextIt)) {
+        nextIt++;
         continue;
+      }
 
       if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
         if (doCons->GetLoopControl()) {
@@ -160,9 +171,12 @@ class CanonicalizationOfOmp {
           if (nextIt != block.end()) {
             if (auto *endDir{
                     GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
-              std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
-                  std::move(*endDir);
-              nextIt = block.erase(nextIt);
+              auto &endDirName = endDir->DirName();
+              if (endDirName.v != llvm::omp::Directive::OMPD_fuse) {
+                std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
+                    std::move(*endDir);
+                nextIt = block.erase(nextIt);
+              }
             }
           }
         } else {
@@ -172,50 +186,45 @@ class CanonicalizationOfOmp {
         }
       } else if (auto *ompLoopCons{
                      GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
-        // We should allow UNROLL and TILE constructs to be inserted between an
-        // OpenMP Loop Construct and the DO loop itself
+        // We should allow loop transformation constructs to be inserted between
+   ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2025

@llvm/pr-subscribers-mlir-llvm

Author: Ferran Toda (NouTimbaler)

Changes

This patch is a follow-up from #161213 and adds the omp.fuse loop transformation for the OpenMP dialect. Used for lowering a !$omp fuse in Flang.
Added Lowering and end2end tests.


Patch is 124.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168898.diff

41 Files Affected:

  • (modified) flang/include/flang/Parser/openmp-utils.h (+3)
  • (modified) flang/include/flang/Semantics/openmp-directive-sets.h (+7)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+1)
  • (modified) flang/lib/Lower/OpenMP/Clauses.cpp (+4-1)
  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+2-1)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+83-23)
  • (modified) flang/lib/Lower/OpenMP/Utils.cpp (+17-11)
  • (modified) flang/lib/Lower/OpenMP/Utils.h (+4-2)
  • (modified) flang/lib/Parser/openmp-parsers.cpp (+1)
  • (modified) flang/lib/Parser/openmp-utils.cpp (+17)
  • (modified) flang/lib/Semantics/canonicalize-omp.cpp (+72-45)
  • (modified) flang/lib/Semantics/check-omp-loop.cpp (+106-33)
  • (modified) flang/lib/Semantics/check-omp-structure.cpp (+5-3)
  • (modified) flang/lib/Semantics/check-omp-structure.h (+2)
  • (modified) flang/lib/Semantics/resolve-directives.cpp (+109-101)
  • (modified) flang/lib/Semantics/rewrite-parse-tree.cpp (+24-15)
  • (added) flang/test/Lower/OpenMP/fuse01.f90 (+93)
  • (added) flang/test/Lower/OpenMP/fuse02.f90 (+123)
  • (added) flang/test/Parser/OpenMP/fail-looprange.f90 (+11)
  • (added) flang/test/Parser/OpenMP/fuse-looprange.f90 (+38)
  • (added) flang/test/Parser/OpenMP/fuse01.f90 (+28)
  • (added) flang/test/Parser/OpenMP/fuse02.f90 (+97)
  • (added) flang/test/Parser/OpenMP/loop-transformation-construct04.f90 (+80)
  • (added) flang/test/Parser/OpenMP/loop-transformation-construct05.f90 (+90)
  • (added) flang/test/Semantics/OpenMP/loop-transformation-clauses01.f90 (+66)
  • (modified) flang/test/Semantics/OpenMP/loop-transformation-construct01.f90 (+2-2)
  • (added) flang/test/Semantics/OpenMP/loop-transformation-construct02.f90 (+93)
  • (added) flang/test/Semantics/OpenMP/loop-transformation-construct03.f90 (+39)
  • (added) flang/test/Semantics/OpenMP/loop-transformation-construct04.f90 (+47)
  • (modified) flang/test/Semantics/OpenMP/tile02.f90 (+1-1)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+53)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+111)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+34)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+68)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+54)
  • (added) mlir/test/Dialect/OpenMP/cli-fuse.mlir (+114)
  • (added) mlir/test/Dialect/OpenMP/invalid-fuse.mlir (+100)
  • (added) mlir/test/Target/LLVMIR/openmp-cli-fuse01.mlir (+100)
  • (added) mlir/test/Target/LLVMIR/openmp-cli-fuse02.mlir (+140)
  • (added) openmp/runtime/test/transform/fuse/do-looprange.f90 (+60)
  • (added) openmp/runtime/test/transform/fuse/do.f90 (+52)
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index 36556f8dd7f4a..7396e57144b90 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -123,6 +123,9 @@ template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
 const OpenMPDeclarativeConstruct *GetOmp(const DeclarationConstruct &x);
 const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x);
 
+const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x);
+const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x);
+
 const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
 
 template <typename T>
diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index 01e8481e05721..609a7be700c28 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -275,10 +275,17 @@ static const OmpDirectiveSet loopConstructSet{
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
     Directive::OMPD_teams_loop,
+    Directive::OMPD_fuse,
     Directive::OMPD_tile,
     Directive::OMPD_unroll,
 };
 
+static const OmpDirectiveSet loopTransformationSet{
+    Directive::OMPD_tile,
+    Directive::OMPD_unroll,
+    Directive::OMPD_fuse,
+};
+
 static const OmpDirectiveSet nonPartialVarSet{
     Directive::OMPD_allocate,
     Directive::OMPD_allocators,
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 4a392381287d5..ab3a174c7ad69 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -279,6 +279,7 @@ bool ClauseProcessor::processCollapse(
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
 
   int64_t numCollapse = collectLoopRelatedInfo(converter, currentLocation, eval,
+                                               eval.getFirstNestedEvaluation(),
                                                clauses, loopResult, iv);
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   collapseResult.collapseNumLoops = firOpBuilder.getI64IntegerAttr(numCollapse);
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index b1a3c3d3c5439..f2defc62dce91 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1063,7 +1063,10 @@ Link make(const parser::OmpClause::Link &inp,
 
 LoopRange make(const parser::OmpClause::Looprange &inp,
                semantics::SemanticsContext &semaCtx) {
-  llvm_unreachable("Unimplemented: looprange");
+  auto &t0 = std::get<0>(inp.v.t);
+  auto &t1 = std::get<1>(inp.v.t);
+  return LoopRange{{/*First*/ makeExpr(t0, semaCtx),
+                    /*Count*/ makeExpr(t1, semaCtx)}};
 }
 
 Map make(const parser::OmpClause::Map &inp,
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 83c2eda0a2dc7..da9480123513f 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -347,7 +347,8 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
     mlir::omp::LoopRelatedClauseOps result;
     llvm::SmallVector<const semantics::Symbol *> iv;
     collectLoopRelatedInfo(converter, converter.getCurrentLocation(), eval,
-                           clauses, result, iv);
+                           eval.getFirstNestedEvaluation(), clauses, result,
+                           iv);
 
     // Update the original variable just before exiting the worksharing
     // loop. Conversion as follows:
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c6487349c4056..2d981f421a4ae 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1982,9 +1982,9 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 static void genCanonicalLoopNest(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
-    mlir::Location loc, const ConstructQueue &queue,
-    ConstructQueue::const_iterator item, size_t numLoops,
-    llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
+    lower::pft::Evaluation &nestedEval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item,
+    size_t numLoops, llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
   assert(loops.empty() && "Expecting empty list to fill");
   assert(numLoops >= 1 && "Expecting at least one loop");
 
@@ -1992,7 +1992,8 @@ static void genCanonicalLoopNest(
 
   mlir::omp::LoopRelatedClauseOps loopInfo;
   llvm::SmallVector<const semantics::Symbol *, 3> ivs;
-  collectLoopRelatedInfo(converter, loc, eval, numLoops, loopInfo, ivs);
+  collectLoopRelatedInfo(converter, loc, eval, nestedEval, numLoops, loopInfo,
+                         ivs);
   assert(ivs.size() == numLoops &&
          "Expected to parse as many loop variables as there are loops");
 
@@ -2014,7 +2015,7 @@ static void genCanonicalLoopNest(
 
   // Step 1: Loop prologues
   // Computing the trip count must happen before entering the outermost loop
-  lower::pft::Evaluation *innermostEval = &eval.getFirstNestedEvaluation();
+  lower::pft::Evaluation *innermostEval = &nestedEval;
   for ([[maybe_unused]] auto iv : ivs) {
     if (innermostEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
       // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct.
@@ -2186,7 +2187,8 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::omp::CanonicalLoopOp, 3> canonLoops;
   canonLoops.reserve(numLoops);
 
-  genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item,
+  genCanonicalLoopNest(converter, symTable, semaCtx, eval,
+                       eval.getFirstNestedEvaluation(), loc, queue, item,
                        numLoops, canonLoops);
   assert((canonLoops.size() == numLoops) &&
          "Expecting the predetermined number of loops");
@@ -2217,6 +2219,58 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter,
                             sizesClause.sizes);
 }
 
+static void genFuseOp(Fortran::lower::AbstractConverter &converter,
+                      Fortran::lower::SymMap &symTable,
+                      lower::StatementContext &stmtCtx,
+                      Fortran::semantics::SemanticsContext &semaCtx,
+                      Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+                      const ConstructQueue &queue,
+                      ConstructQueue::const_iterator item) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  int32_t first = 0;
+  int32_t count = 0;
+  auto iter = llvm::find_if(item->clauses, [](const Clause &clause) {
+    return clause.id == llvm::omp::Clause::OMPC_looprange;
+  });
+  if (iter != item->clauses.end()) {
+    const auto &looprange = std::get<clause::LoopRange>(iter->u);
+    first = evaluate::ToInt64(std::get<0>(looprange.t)).value();
+    count = evaluate::ToInt64(std::get<1>(looprange.t)).value();
+  }
+
+  llvm::SmallVector<mlir::Value> applyees;
+  for (auto &child : eval.getNestedEvaluations()) {
+    // Skip OmpEndLoopDirective
+    if (&child == &eval.getLastNestedEvaluation())
+      break;
+
+    // Emit the associated loop
+    llvm::SmallVector<mlir::omp::CanonicalLoopOp> canonLoops;
+    genCanonicalLoopNest(converter, symTable, semaCtx, eval, child, loc, queue,
+                         item, 1, canonLoops);
+
+    auto cli = llvm::getSingleElement(canonLoops).getCli();
+    applyees.push_back(cli);
+  }
+  // One generated loop + one for each loop not inside the specified looprange
+  // if present
+  llvm::SmallVector<mlir::Value> generatees;
+  int64_t numGeneratees = count == 0 ? 1 : applyees.size() - count + 1;
+  for (int i = 0; i < numGeneratees; i++) {
+    auto fusedCLI = mlir::omp::NewCliOp::create(firOpBuilder, loc);
+    generatees.push_back(fusedCLI);
+  }
+  auto op = mlir::omp::FuseOp::create(firOpBuilder, loc, generatees, applyees);
+
+  if (count != 0) {
+    mlir::IntegerAttr firstAttr = firOpBuilder.getI32IntegerAttr(first);
+    mlir::IntegerAttr countAttr = firOpBuilder.getI32IntegerAttr(count);
+    op->setAttr("first", firstAttr);
+    op->setAttr("count", countAttr);
+  }
+}
+
 static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
                         Fortran::lower::SymMap &symTable,
                         lower::StatementContext &stmtCtx,
@@ -2233,7 +2287,8 @@ static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
 
   // Emit the associated loop
   llvm::SmallVector<mlir::omp::CanonicalLoopOp, 1> canonLoops;
-  genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, 1,
+  genCanonicalLoopNest(converter, symTable, semaCtx, eval,
+                       eval.getFirstNestedEvaluation(), loc, queue, item, 1,
                        canonLoops);
 
   llvm::SmallVector<mlir::Value, 1> applyees;
@@ -3507,6 +3562,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
   case llvm::omp::Directive::OMPD_tile:
     genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
     break;
+  case llvm::omp::Directive::OMPD_fuse:
+    genFuseOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
+    break;
   case llvm::omp::Directive::OMPD_unroll:
     genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
     break;
@@ -3962,22 +4020,24 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
   mlir::Location currentLocation = converter.genLocation(beginSpec.source);
 
-  if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
-          loopConstruct.GetNestedConstruct()) {
-    llvm::omp::Directive nestedDirective =
-        parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
-    switch (nestedDirective) {
-    case llvm::omp::Directive::OMPD_tile:
-      // Skip OMPD_tile since the tile sizes will be retrieved when
-      // generating the omp.loop_nest op.
-      break;
-    default: {
-      unsigned version = semaCtx.langOptions().OpenMPVersion;
-      TODO(currentLocation,
-           "Applying a loop-associated on the loop generated by the " +
-               llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
-               " construct");
-    }
+  for (auto &construct : std::get<parser::Block>(loopConstruct.t)) {
+    if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
+            parser::omp::GetOmpLoop(construct)) {
+      llvm::omp::Directive nestedDirective =
+          parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
+      switch (nestedDirective) {
+      case llvm::omp::Directive::OMPD_tile:
+        // Skip OMPD_tile since the tile sizes will be retrieved when
+        // generating the omp.loop_nest op.
+        break;
+      default: {
+        unsigned version = semaCtx.langOptions().OpenMPVersion;
+        TODO(currentLocation,
+             "Applying a loop-associated on the loop generated by the " +
+                 llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
+                 " construct");
+      }
+      }
     }
   }
 
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 7d7a4869ab3a6..913e4d1e69500 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -812,13 +812,14 @@ void collectTileSizesFromOpenMPConstruct(
 
 int64_t collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
   int64_t numCollapse = 1;
 
   // Collect the loops to collapse.
-  lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
+  lower::pft::Evaluation *doConstructEval = &nestedEval;
   if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
     TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
   }
@@ -830,21 +831,21 @@ int64_t collectLoopRelatedInfo(
     numCollapse = collapseValue;
   }
 
-  collectLoopRelatedInfo(converter, currentLocation, eval, numCollapse, result,
-                         iv);
+  collectLoopRelatedInfo(converter, currentLocation, eval, nestedEval,
+                         numCollapse, result, iv);
   return numCollapse;
 }
 
 void collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, int64_t numCollapse,
-    mlir::omp::LoopRelatedClauseOps &result,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    int64_t numCollapse, mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
 
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
   // Collect the loops to collapse.
-  lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
+  lower::pft::Evaluation *doConstructEval = &nestedEval;
   if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
     TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
   }
@@ -852,10 +853,15 @@ void collectLoopRelatedInfo(
   // Collect sizes from tile directive if present.
   std::int64_t sizesLengthValue = 0l;
   if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
-    processTileSizesFromOpenMPConstruct(
-        ompCons, [&](const parser::OmpClause::Sizes *tclause) {
-          sizesLengthValue = tclause->v.size();
-        });
+    if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
+      const parser::OmpDirectiveSpecification &beginSpec{ompLoop->BeginDir()};
+      if (beginSpec.DirId() == llvm::omp::Directive::OMPD_tile) {
+        processTileSizesFromOpenMPConstruct(
+            ompCons, [&](const parser::OmpClause::Sizes *tclause) {
+              sizesLengthValue = tclause->v.size();
+            });
+      }
+    }
   }
 
   std::int64_t collapseValue = std::max(numCollapse, sizesLengthValue);
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 2960b663b08b2..886a5c1835f7e 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -169,13 +169,15 @@ void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,
 
 int64_t collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
 
 void collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, std::int64_t collapseValue,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    std::int64_t collapseValue,
     // const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index e2da60ed19de8..231eea8841d4b 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -2260,6 +2260,7 @@ static constexpr DirectiveSet GetLoopDirectives() {
       unsigned(Directive::OMPD_teams_distribute_parallel_do_simd),
       unsigned(Directive::OMPD_teams_distribute_simd),
       unsigned(Directive::OMPD_teams_loop),
+      unsigned(Directive::OMPD_fuse),
       unsigned(Directive::OMPD_tile),
       unsigned(Directive::OMPD_unroll),
   };
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index 2424828293c73..dfe8dbdd5ac9e 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/openmp-utils.cpp
@@ -41,6 +41,23 @@ const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x) {
   return nullptr;
 }
 
+const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x) {
+  if (auto *construct{GetOmp(x)}) {
+    if (auto *omp{std::get_if<OpenMPLoopConstruct>(&construct->u)}) {
+      return omp;
+    }
+  }
+  return nullptr;
+}
+const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x) {
+  if (auto *y{std::get_if<ExecutableConstruct>(&x.u)}) {
+    if (auto *z{std::get_if<common::Indirection<DoConstruct>>(&y->u)}) {
+      return &z->value();
+    }
+  }
+  return nullptr;
+}
+
 const OmpObjectList *GetOmpObjectList(const OmpClause &clause) {
   // Clauses with OmpObjectList as its data member
   using MemberObjectListClauses = std::tuple<OmpClause::Copyin,
diff --git a/flang/lib/Semantics/canonicalize-omp.cpp b/flang/lib/Semantics/canonicalize-omp.cpp
index 0cec1969e0978..f7c53d6d8f4c4 100644
--- a/flang/lib/Semantics/canonicalize-omp.cpp
+++ b/flang/lib/Semantics/canonicalize-omp.cpp
@@ -9,6 +9,7 @@
 #include "canonicalize-omp.h"
 #include "flang/Parser/parse-tree-visitor.h"
 #include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/openmp-directive-sets.h"
 #include "flang/Semantics/semantics.h"
 
 // After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
@@ -136,20 +137,30 @@ class CanonicalizationOfOmp {
           "A DO loop must follow the %s directive"_err_en_US,
           parser::ToUpperCaseLetters(dirName.source.ToString()));
     };
-    auto tileUnrollError = [](const parser::OmpDirectiveName &dirName,
-                               parser::Messages &messages) {
+    auto transformUnrollError = [](const parser::OmpDirectiveName &dirName,
+                                    parser::Messages &messages) {
       messages.Say(dirName.source,
-          "If a loop construct has been fully unrolled, it cannot then be tiled"_err_en_US,
+          "If a loop construct has been fully unrolled, it cannot then be further transformed"_err_en_US,
           parser::ToUpperCaseLetters(dirName.source.ToString()));
     };
+    auto missingEndFuse = [](auto &dir, auto &messages) {
+      messages.Say(dir.source,
+          "The %s construct requires the END FUSE directive"_err_en_US,
+          parser::ToUpperCaseLetters(dir.source.ToString()));
+    };
+
+    bool endFuseNeeded = beginName.v == llvm::omp::Directive::OMPD_fuse;
 
     auto &body{std::get<parser::Block>(x.t)};
 
     nextIt = it;
-    while (++nextIt != block.end()) {
+    nextIt++;
+    while (nextIt != block.end()) {
       // Ignore compiler directives.
-      if (GetConstructIf<parser::CompilerDirective>(*nextIt))
+      if (GetConstructIf<parser::CompilerDirective>(*nextIt)) {
+        nextIt++;
         continue;
+      }
 
       if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
         if (doCons->GetLoopControl()) {
@@ -160,9 +171,12 @@ class CanonicalizationOfOmp {
           if (nextIt != block.end()) {
             if (auto *endDir{
                     GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
-              std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
-                  std::move(*endDir);
-              nextIt = block.erase(nextIt);
+              auto &endDirName = endDir->DirName();
+              if (endDirName.v != llvm::omp::Directive::OMPD_fuse) {
+                std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
+                    std::move(*endDir);
+                nextIt = block.erase(nextIt);
+              }
             }
           }
         } else {
@@ -172,50 +186,45 @@ class CanonicalizationOfOmp {
         }
       } else if (auto *ompLoopCons{
                      GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
-        // We should allow UNROLL and TILE constructs to be inserted between an
-        // OpenMP Loop Construct and the DO loop itself
+        // We should allow loop transformation constructs to be inserted between
+   ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2025

@llvm/pr-subscribers-mlir-openmp

Author: Ferran Toda (NouTimbaler)

Changes

This patch is a follow-up from #161213 and adds the omp.fuse loop transformation for the OpenMP dialect. Used for lowering a !$omp fuse in Flang.
Added Lowering and end2end tests.


Patch is 124.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168898.diff

41 Files Affected:

  • (modified) flang/include/flang/Parser/openmp-utils.h (+3)
  • (modified) flang/include/flang/Semantics/openmp-directive-sets.h (+7)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+1)
  • (modified) flang/lib/Lower/OpenMP/Clauses.cpp (+4-1)
  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+2-1)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+83-23)
  • (modified) flang/lib/Lower/OpenMP/Utils.cpp (+17-11)
  • (modified) flang/lib/Lower/OpenMP/Utils.h (+4-2)
  • (modified) flang/lib/Parser/openmp-parsers.cpp (+1)
  • (modified) flang/lib/Parser/openmp-utils.cpp (+17)
  • (modified) flang/lib/Semantics/canonicalize-omp.cpp (+72-45)
  • (modified) flang/lib/Semantics/check-omp-loop.cpp (+106-33)
  • (modified) flang/lib/Semantics/check-omp-structure.cpp (+5-3)
  • (modified) flang/lib/Semantics/check-omp-structure.h (+2)
  • (modified) flang/lib/Semantics/resolve-directives.cpp (+109-101)
  • (modified) flang/lib/Semantics/rewrite-parse-tree.cpp (+24-15)
  • (added) flang/test/Lower/OpenMP/fuse01.f90 (+93)
  • (added) flang/test/Lower/OpenMP/fuse02.f90 (+123)
  • (added) flang/test/Parser/OpenMP/fail-looprange.f90 (+11)
  • (added) flang/test/Parser/OpenMP/fuse-looprange.f90 (+38)
  • (added) flang/test/Parser/OpenMP/fuse01.f90 (+28)
  • (added) flang/test/Parser/OpenMP/fuse02.f90 (+97)
  • (added) flang/test/Parser/OpenMP/loop-transformation-construct04.f90 (+80)
  • (added) flang/test/Parser/OpenMP/loop-transformation-construct05.f90 (+90)
  • (added) flang/test/Semantics/OpenMP/loop-transformation-clauses01.f90 (+66)
  • (modified) flang/test/Semantics/OpenMP/loop-transformation-construct01.f90 (+2-2)
  • (added) flang/test/Semantics/OpenMP/loop-transformation-construct02.f90 (+93)
  • (added) flang/test/Semantics/OpenMP/loop-transformation-construct03.f90 (+39)
  • (added) flang/test/Semantics/OpenMP/loop-transformation-construct04.f90 (+47)
  • (modified) flang/test/Semantics/OpenMP/tile02.f90 (+1-1)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+53)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+111)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+34)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+68)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+54)
  • (added) mlir/test/Dialect/OpenMP/cli-fuse.mlir (+114)
  • (added) mlir/test/Dialect/OpenMP/invalid-fuse.mlir (+100)
  • (added) mlir/test/Target/LLVMIR/openmp-cli-fuse01.mlir (+100)
  • (added) mlir/test/Target/LLVMIR/openmp-cli-fuse02.mlir (+140)
  • (added) openmp/runtime/test/transform/fuse/do-looprange.f90 (+60)
  • (added) openmp/runtime/test/transform/fuse/do.f90 (+52)
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index 36556f8dd7f4a..7396e57144b90 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -123,6 +123,9 @@ template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
 const OpenMPDeclarativeConstruct *GetOmp(const DeclarationConstruct &x);
 const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x);
 
+const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x);
+const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x);
+
 const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
 
 template <typename T>
diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index 01e8481e05721..609a7be700c28 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -275,10 +275,17 @@ static const OmpDirectiveSet loopConstructSet{
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
     Directive::OMPD_teams_loop,
+    Directive::OMPD_fuse,
     Directive::OMPD_tile,
     Directive::OMPD_unroll,
 };
 
+static const OmpDirectiveSet loopTransformationSet{
+    Directive::OMPD_tile,
+    Directive::OMPD_unroll,
+    Directive::OMPD_fuse,
+};
+
 static const OmpDirectiveSet nonPartialVarSet{
     Directive::OMPD_allocate,
     Directive::OMPD_allocators,
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 4a392381287d5..ab3a174c7ad69 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -279,6 +279,7 @@ bool ClauseProcessor::processCollapse(
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
 
   int64_t numCollapse = collectLoopRelatedInfo(converter, currentLocation, eval,
+                                               eval.getFirstNestedEvaluation(),
                                                clauses, loopResult, iv);
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   collapseResult.collapseNumLoops = firOpBuilder.getI64IntegerAttr(numCollapse);
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index b1a3c3d3c5439..f2defc62dce91 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1063,7 +1063,10 @@ Link make(const parser::OmpClause::Link &inp,
 
 LoopRange make(const parser::OmpClause::Looprange &inp,
                semantics::SemanticsContext &semaCtx) {
-  llvm_unreachable("Unimplemented: looprange");
+  auto &t0 = std::get<0>(inp.v.t);
+  auto &t1 = std::get<1>(inp.v.t);
+  return LoopRange{{/*First*/ makeExpr(t0, semaCtx),
+                    /*Count*/ makeExpr(t1, semaCtx)}};
 }
 
 Map make(const parser::OmpClause::Map &inp,
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 83c2eda0a2dc7..da9480123513f 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -347,7 +347,8 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
     mlir::omp::LoopRelatedClauseOps result;
     llvm::SmallVector<const semantics::Symbol *> iv;
     collectLoopRelatedInfo(converter, converter.getCurrentLocation(), eval,
-                           clauses, result, iv);
+                           eval.getFirstNestedEvaluation(), clauses, result,
+                           iv);
 
     // Update the original variable just before exiting the worksharing
     // loop. Conversion as follows:
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c6487349c4056..2d981f421a4ae 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1982,9 +1982,9 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 static void genCanonicalLoopNest(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
-    mlir::Location loc, const ConstructQueue &queue,
-    ConstructQueue::const_iterator item, size_t numLoops,
-    llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
+    lower::pft::Evaluation &nestedEval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item,
+    size_t numLoops, llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
   assert(loops.empty() && "Expecting empty list to fill");
   assert(numLoops >= 1 && "Expecting at least one loop");
 
@@ -1992,7 +1992,8 @@ static void genCanonicalLoopNest(
 
   mlir::omp::LoopRelatedClauseOps loopInfo;
   llvm::SmallVector<const semantics::Symbol *, 3> ivs;
-  collectLoopRelatedInfo(converter, loc, eval, numLoops, loopInfo, ivs);
+  collectLoopRelatedInfo(converter, loc, eval, nestedEval, numLoops, loopInfo,
+                         ivs);
   assert(ivs.size() == numLoops &&
          "Expected to parse as many loop variables as there are loops");
 
@@ -2014,7 +2015,7 @@ static void genCanonicalLoopNest(
 
   // Step 1: Loop prologues
   // Computing the trip count must happen before entering the outermost loop
-  lower::pft::Evaluation *innermostEval = &eval.getFirstNestedEvaluation();
+  lower::pft::Evaluation *innermostEval = &nestedEval;
   for ([[maybe_unused]] auto iv : ivs) {
     if (innermostEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
       // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct.
@@ -2186,7 +2187,8 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::omp::CanonicalLoopOp, 3> canonLoops;
   canonLoops.reserve(numLoops);
 
-  genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item,
+  genCanonicalLoopNest(converter, symTable, semaCtx, eval,
+                       eval.getFirstNestedEvaluation(), loc, queue, item,
                        numLoops, canonLoops);
   assert((canonLoops.size() == numLoops) &&
          "Expecting the predetermined number of loops");
@@ -2217,6 +2219,58 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter,
                             sizesClause.sizes);
 }
 
+static void genFuseOp(Fortran::lower::AbstractConverter &converter,
+                      Fortran::lower::SymMap &symTable,
+                      lower::StatementContext &stmtCtx,
+                      Fortran::semantics::SemanticsContext &semaCtx,
+                      Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+                      const ConstructQueue &queue,
+                      ConstructQueue::const_iterator item) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  int32_t first = 0;
+  int32_t count = 0;
+  auto iter = llvm::find_if(item->clauses, [](const Clause &clause) {
+    return clause.id == llvm::omp::Clause::OMPC_looprange;
+  });
+  if (iter != item->clauses.end()) {
+    const auto &looprange = std::get<clause::LoopRange>(iter->u);
+    first = evaluate::ToInt64(std::get<0>(looprange.t)).value();
+    count = evaluate::ToInt64(std::get<1>(looprange.t)).value();
+  }
+
+  llvm::SmallVector<mlir::Value> applyees;
+  for (auto &child : eval.getNestedEvaluations()) {
+    // Skip OmpEndLoopDirective
+    if (&child == &eval.getLastNestedEvaluation())
+      break;
+
+    // Emit the associated loop
+    llvm::SmallVector<mlir::omp::CanonicalLoopOp> canonLoops;
+    genCanonicalLoopNest(converter, symTable, semaCtx, eval, child, loc, queue,
+                         item, 1, canonLoops);
+
+    auto cli = llvm::getSingleElement(canonLoops).getCli();
+    applyees.push_back(cli);
+  }
+  // One generated loop + one for each loop not inside the specified looprange
+  // if present
+  llvm::SmallVector<mlir::Value> generatees;
+  int64_t numGeneratees = count == 0 ? 1 : applyees.size() - count + 1;
+  for (int i = 0; i < numGeneratees; i++) {
+    auto fusedCLI = mlir::omp::NewCliOp::create(firOpBuilder, loc);
+    generatees.push_back(fusedCLI);
+  }
+  auto op = mlir::omp::FuseOp::create(firOpBuilder, loc, generatees, applyees);
+
+  if (count != 0) {
+    mlir::IntegerAttr firstAttr = firOpBuilder.getI32IntegerAttr(first);
+    mlir::IntegerAttr countAttr = firOpBuilder.getI32IntegerAttr(count);
+    op->setAttr("first", firstAttr);
+    op->setAttr("count", countAttr);
+  }
+}
+
 static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
                         Fortran::lower::SymMap &symTable,
                         lower::StatementContext &stmtCtx,
@@ -2233,7 +2287,8 @@ static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
 
   // Emit the associated loop
   llvm::SmallVector<mlir::omp::CanonicalLoopOp, 1> canonLoops;
-  genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, 1,
+  genCanonicalLoopNest(converter, symTable, semaCtx, eval,
+                       eval.getFirstNestedEvaluation(), loc, queue, item, 1,
                        canonLoops);
 
   llvm::SmallVector<mlir::Value, 1> applyees;
@@ -3507,6 +3562,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
   case llvm::omp::Directive::OMPD_tile:
     genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
     break;
+  case llvm::omp::Directive::OMPD_fuse:
+    genFuseOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
+    break;
   case llvm::omp::Directive::OMPD_unroll:
     genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
     break;
@@ -3962,22 +4020,24 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
   mlir::Location currentLocation = converter.genLocation(beginSpec.source);
 
-  if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
-          loopConstruct.GetNestedConstruct()) {
-    llvm::omp::Directive nestedDirective =
-        parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
-    switch (nestedDirective) {
-    case llvm::omp::Directive::OMPD_tile:
-      // Skip OMPD_tile since the tile sizes will be retrieved when
-      // generating the omp.loop_nest op.
-      break;
-    default: {
-      unsigned version = semaCtx.langOptions().OpenMPVersion;
-      TODO(currentLocation,
-           "Applying a loop-associated on the loop generated by the " +
-               llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
-               " construct");
-    }
+  for (auto &construct : std::get<parser::Block>(loopConstruct.t)) {
+    if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
+            parser::omp::GetOmpLoop(construct)) {
+      llvm::omp::Directive nestedDirective =
+          parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
+      switch (nestedDirective) {
+      case llvm::omp::Directive::OMPD_tile:
+        // Skip OMPD_tile since the tile sizes will be retrieved when
+        // generating the omp.loop_nest op.
+        break;
+      default: {
+        unsigned version = semaCtx.langOptions().OpenMPVersion;
+        TODO(currentLocation,
+             "Applying a loop-associated on the loop generated by the " +
+                 llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
+                 " construct");
+      }
+      }
     }
   }
 
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 7d7a4869ab3a6..913e4d1e69500 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -812,13 +812,14 @@ void collectTileSizesFromOpenMPConstruct(
 
 int64_t collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
   int64_t numCollapse = 1;
 
   // Collect the loops to collapse.
-  lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
+  lower::pft::Evaluation *doConstructEval = &nestedEval;
   if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
     TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
   }
@@ -830,21 +831,21 @@ int64_t collectLoopRelatedInfo(
     numCollapse = collapseValue;
   }
 
-  collectLoopRelatedInfo(converter, currentLocation, eval, numCollapse, result,
-                         iv);
+  collectLoopRelatedInfo(converter, currentLocation, eval, nestedEval,
+                         numCollapse, result, iv);
   return numCollapse;
 }
 
 void collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, int64_t numCollapse,
-    mlir::omp::LoopRelatedClauseOps &result,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    int64_t numCollapse, mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
 
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
   // Collect the loops to collapse.
-  lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
+  lower::pft::Evaluation *doConstructEval = &nestedEval;
   if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
     TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
   }
@@ -852,10 +853,15 @@ void collectLoopRelatedInfo(
   // Collect sizes from tile directive if present.
   std::int64_t sizesLengthValue = 0l;
   if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
-    processTileSizesFromOpenMPConstruct(
-        ompCons, [&](const parser::OmpClause::Sizes *tclause) {
-          sizesLengthValue = tclause->v.size();
-        });
+    if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
+      const parser::OmpDirectiveSpecification &beginSpec{ompLoop->BeginDir()};
+      if (beginSpec.DirId() == llvm::omp::Directive::OMPD_tile) {
+        processTileSizesFromOpenMPConstruct(
+            ompCons, [&](const parser::OmpClause::Sizes *tclause) {
+              sizesLengthValue = tclause->v.size();
+            });
+      }
+    }
   }
 
   std::int64_t collapseValue = std::max(numCollapse, sizesLengthValue);
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 2960b663b08b2..886a5c1835f7e 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -169,13 +169,15 @@ void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,
 
 int64_t collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
 
 void collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, std::int64_t collapseValue,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    std::int64_t collapseValue,
     // const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index e2da60ed19de8..231eea8841d4b 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -2260,6 +2260,7 @@ static constexpr DirectiveSet GetLoopDirectives() {
       unsigned(Directive::OMPD_teams_distribute_parallel_do_simd),
       unsigned(Directive::OMPD_teams_distribute_simd),
       unsigned(Directive::OMPD_teams_loop),
+      unsigned(Directive::OMPD_fuse),
       unsigned(Directive::OMPD_tile),
       unsigned(Directive::OMPD_unroll),
   };
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index 2424828293c73..dfe8dbdd5ac9e 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/openmp-utils.cpp
@@ -41,6 +41,23 @@ const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x) {
   return nullptr;
 }
 
+const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x) {
+  if (auto *construct{GetOmp(x)}) {
+    if (auto *omp{std::get_if<OpenMPLoopConstruct>(&construct->u)}) {
+      return omp;
+    }
+  }
+  return nullptr;
+}
+const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x) {
+  if (auto *y{std::get_if<ExecutableConstruct>(&x.u)}) {
+    if (auto *z{std::get_if<common::Indirection<DoConstruct>>(&y->u)}) {
+      return &z->value();
+    }
+  }
+  return nullptr;
+}
+
 const OmpObjectList *GetOmpObjectList(const OmpClause &clause) {
   // Clauses with OmpObjectList as its data member
   using MemberObjectListClauses = std::tuple<OmpClause::Copyin,
diff --git a/flang/lib/Semantics/canonicalize-omp.cpp b/flang/lib/Semantics/canonicalize-omp.cpp
index 0cec1969e0978..f7c53d6d8f4c4 100644
--- a/flang/lib/Semantics/canonicalize-omp.cpp
+++ b/flang/lib/Semantics/canonicalize-omp.cpp
@@ -9,6 +9,7 @@
 #include "canonicalize-omp.h"
 #include "flang/Parser/parse-tree-visitor.h"
 #include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/openmp-directive-sets.h"
 #include "flang/Semantics/semantics.h"
 
 // After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
@@ -136,20 +137,30 @@ class CanonicalizationOfOmp {
           "A DO loop must follow the %s directive"_err_en_US,
           parser::ToUpperCaseLetters(dirName.source.ToString()));
     };
-    auto tileUnrollError = [](const parser::OmpDirectiveName &dirName,
-                               parser::Messages &messages) {
+    auto transformUnrollError = [](const parser::OmpDirectiveName &dirName,
+                                    parser::Messages &messages) {
       messages.Say(dirName.source,
-          "If a loop construct has been fully unrolled, it cannot then be tiled"_err_en_US,
+          "If a loop construct has been fully unrolled, it cannot then be further transformed"_err_en_US,
           parser::ToUpperCaseLetters(dirName.source.ToString()));
     };
+    auto missingEndFuse = [](auto &dir, auto &messages) {
+      messages.Say(dir.source,
+          "The %s construct requires the END FUSE directive"_err_en_US,
+          parser::ToUpperCaseLetters(dir.source.ToString()));
+    };
+
+    bool endFuseNeeded = beginName.v == llvm::omp::Directive::OMPD_fuse;
 
     auto &body{std::get<parser::Block>(x.t)};
 
     nextIt = it;
-    while (++nextIt != block.end()) {
+    nextIt++;
+    while (nextIt != block.end()) {
       // Ignore compiler directives.
-      if (GetConstructIf<parser::CompilerDirective>(*nextIt))
+      if (GetConstructIf<parser::CompilerDirective>(*nextIt)) {
+        nextIt++;
         continue;
+      }
 
       if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
         if (doCons->GetLoopControl()) {
@@ -160,9 +171,12 @@ class CanonicalizationOfOmp {
           if (nextIt != block.end()) {
             if (auto *endDir{
                     GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
-              std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
-                  std::move(*endDir);
-              nextIt = block.erase(nextIt);
+              auto &endDirName = endDir->DirName();
+              if (endDirName.v != llvm::omp::Directive::OMPD_fuse) {
+                std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
+                    std::move(*endDir);
+                nextIt = block.erase(nextIt);
+              }
             }
           }
         } else {
@@ -172,50 +186,45 @@ class CanonicalizationOfOmp {
         }
       } else if (auto *ompLoopCons{
                      GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
-        // We should allow UNROLL and TILE constructs to be inserted between an
-        // OpenMP Loop Construct and the DO loop itself
+        // We should allow loop transformation constructs to be inserted between
+   ...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

clang:openmp OpenMP related changes to Clang flang:fir-hlfir flang:openmp flang:parser flang:semantics flang Flang issues not falling into any other category mlir:llvm mlir:openmp mlir openmp:libomp OpenMP host runtime

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants