Skip to content

Conversation

SusanTan
Copy link
Contributor

@SusanTan SusanTan commented Oct 8, 2025

Currently the force clause collapse (force:num_level) is NYI. Added support to sink any prologue and epilogue code to the inner most level as specified.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir openacc labels Oct 8, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 8, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Susan Tan (ス-ザン タン) (SusanTan)

Changes

Currently the force clause collapse (force:num_level) is NYI. Added support to sink any prologue and epilogue code to the inner most level as specified.


Full diff: https://github.com/llvm/llvm-project/pull/162534.diff

3 Files Affected:

  • (modified) flang/lib/Lower/Bridge.cpp (+78-3)
  • (modified) flang/lib/Lower/OpenACC.cpp (+18-8)
  • (added) flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 (+41)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 780d56f085f69..3d331cdad3d43 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3190,15 +3190,39 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         std::get_if<Fortran::parser::OpenACCCombinedConstruct>(&acc.u);
 
     Fortran::lower::pft::Evaluation *curEval = &getEval();
+    // Determine collapse depth/force and loopCount
+    bool collapseForce = false;
+    uint64_t collapseDepth = 1;
+    uint64_t loopCount = 1;
+    auto parseCollapse = [&](const Fortran::parser::AccClauseList &cl)
+                             -> std::pair<bool, uint64_t> {
+      bool force = false;
+      uint64_t depth = 1;
+      for (const Fortran::parser::AccClause &clause : cl.v) {
+        if (const auto *collapseClause =
+                std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
+          const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
+          force = std::get<bool>(arg.t);
+          const auto &intExpr =
+              std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
+          if (const auto *expr = Fortran::semantics::GetExpr(intExpr)) {
+            if (auto v = Fortran::evaluate::ToInt64(*expr))
+              depth = *v;
+          }
+          break;
+        }
+      }
+      return {force, depth};
+    };
 
     if (accLoop || accCombined) {
-      uint64_t loopCount;
       if (accLoop) {
         const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
             std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
         const Fortran::parser::AccClauseList &clauseList =
             std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
         loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
+        std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
       } else if (accCombined) {
         const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir =
             std::get<Fortran::parser::AccBeginCombinedDirective>(
@@ -3206,6 +3230,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         const Fortran::parser::AccClauseList &clauseList =
             std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t);
         loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
+        std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
       }
 
       if (curEval->lowerAsStructured()) {
@@ -3215,8 +3240,58 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       }
     }
 
-    for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
-      genFIR(e);
+    const bool isStructured = curEval && curEval->lowerAsStructured();
+    if (isStructured && collapseForce && collapseDepth > 1) {
+      // force: collect prologue/epilogue for the first collapseDepth nested loops
+      // and sink them into the innermost loop body at that depth
+      llvm::SmallVector<Fortran::lower::pft::Evaluation *> prologue, epilogue;
+      Fortran::lower::pft::Evaluation *parent = &getEval();
+      Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr;
+      for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) {
+        epilogue.clear();
+        auto &kids = parent->getNestedEvaluations();
+        // Collect all non-loop statements before the next inner loop as prologue,
+        // then mark remaining siblings as epilogue and descend into the inner loop.
+        Fortran::lower::pft::Evaluation *childLoop = nullptr;
+        for (auto it = kids.begin(); it != kids.end(); ++it) {
+          if (it->getIf<Fortran::parser::DoConstruct>()) {
+            childLoop = &*it;
+            for (auto it2 = std::next(it); it2 != kids.end(); ++it2)
+              epilogue.push_back(&*it2);
+            break;
+          }
+          prologue.push_back(&*it);
+        }
+        // Semantics guarantees collapseDepth does not exceed nest depth
+        // so childLoop must be found here.
+        assert(childLoop && "Expected inner DoConstruct for collapse");
+        parent = childLoop;
+        innermostLoopEval = childLoop;
+      }
+
+      // Track sunk evaluations (avoid double-lowering)
+      llvm::SmallPtrSet<const Fortran::lower::pft::Evaluation *, 16> sunk;
+      for (auto *e : prologue)  sunk.insert(e);
+      for (auto *e : epilogue)  sunk.insert(e);
+
+      auto sink =
+          [&](llvm::SmallVector<Fortran::lower::pft::Evaluation *> &lst) {
+            for (auto *e : lst)
+              genFIR(*e);
+          };
+
+      sink(prologue);
+
+      // Lower innermost loop body, skipping sunk
+      for (Fortran::lower::pft::Evaluation &e : innermostLoopEval->getNestedEvaluations())
+        if (!sunk.contains(&e)) genFIR(e);
+
+      sink(epilogue);
+    } else {
+      // Normal lowering
+      for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
+        genFIR(e);
+    }
     localSymbols.popScope();
     builder->restoreInsertionPoint(insertPt);
 
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 4a9e49435a907..e24e784895fe8 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2144,8 +2144,23 @@ static void processDoLoopBounds(
         locs.push_back(converter.genLocation(
             Fortran::parser::FindSourceLocation(outerDoConstruct)));
       } else {
-        auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
-        assert(doCons && "expect do construct");
+        // Safely locate the next inner DoConstruct within this eval.
+        const Fortran::parser::DoConstruct *doCons = nullptr;
+        if (crtEval && crtEval->hasNestedEvaluations()) {
+          for (Fortran::lower::pft::Evaluation &child :
+               crtEval->getNestedEvaluations()) {
+            if (auto *cand = child.getIf<Fortran::parser::DoConstruct>()) {
+              doCons = cand;
+              // Prepare to descend for the next iteration
+              crtEval = &child;
+              break;
+            }
+          }
+        }
+        if (!doCons) {
+          // No deeper loop; stop collecting collapsed bounds.
+          break;
+        }
         loopControl = &*doCons->GetLoopControl();
         locs.push_back(converter.genLocation(
             Fortran::parser::FindSourceLocation(*doCons)));
@@ -2172,8 +2187,7 @@ static void processDoLoopBounds(
 
       inclusiveBounds.push_back(true);
 
-      if (i < loopsToProcess - 1)
-        crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
+      // crtEval already updated when descending; no blind increment here.
     }
   }
 }
@@ -2406,10 +2420,6 @@ static mlir::acc::LoopOp createLoopOp(
                    std::get_if<Fortran::parser::AccClause::Collapse>(
                        &clause.u)) {
       const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
-      const auto &force = std::get<bool>(arg.t);
-      if (force)
-        TODO(clauseLocation, "OpenACC collapse force modifier");
-
       const auto &intExpr =
           std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
       const auto *expr = Fortran::semantics::GetExpr(intExpr);
diff --git a/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
new file mode 100644
index 0000000000000..ca932c1b159ba
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
@@ -0,0 +1,41 @@
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+! Verify collapse(force:2) sinks prologue (between loops) and epilogue (after inner loop)
+! into the acc.loop region body.
+
+subroutine collapse_force_sink(n, m)
+  integer, intent(in) :: n, m
+  real, dimension(n,m) :: a
+  real, dimension(n) :: bb, cc
+  integer :: i, j
+
+  !$acc parallel loop collapse(force:2)
+  do i = 1, n
+    bb(i) = 4.2     ! prologue (between loops)
+    do j = 1, m
+      a(i,j) = a(i,j) + 2.0
+    end do
+    cc(i) = 7.3     ! epilogue (after inner loop)
+  end do
+  !$acc end parallel loop
+end subroutine
+
+! CHECK: func.func @_QPcollapse_force_sink(
+! CHECK: acc.parallel
+! Ensure outer acc.loop is combined(parallel)
+! CHECK: acc.loop combined(parallel)
+! Prologue: constant 4.2 and an assign before inner loop
+! CHECK: arith.constant 4.200000e+00
+! CHECK: hlfir.assign
+! Inner loop and its body include 2.0 add and an assign
+! CHECK: acc.loop
+! CHECK: arith.constant 2.000000e+00
+! CHECK: arith.addf
+! CHECK: hlfir.assign
+! Epilogue: constant 7.3 and an assign after inner loop
+! CHECK: arith.constant 7.300000e+00
+! CHECK: hlfir.assign
+! And the outer acc.loop has collapse = [2]
+! CHECK: } attributes {collapse = [2]
+
+

@llvmbot
Copy link
Member

llvmbot commented Oct 8, 2025

@llvm/pr-subscribers-openacc

Author: Susan Tan (ス-ザン タン) (SusanTan)

Changes

Currently the force clause collapse (force:num_level) is NYI. Added support to sink any prologue and epilogue code to the inner most level as specified.


Full diff: https://github.com/llvm/llvm-project/pull/162534.diff

3 Files Affected:

  • (modified) flang/lib/Lower/Bridge.cpp (+78-3)
  • (modified) flang/lib/Lower/OpenACC.cpp (+18-8)
  • (added) flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 (+41)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 780d56f085f69..3d331cdad3d43 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3190,15 +3190,39 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         std::get_if<Fortran::parser::OpenACCCombinedConstruct>(&acc.u);
 
     Fortran::lower::pft::Evaluation *curEval = &getEval();
+    // Determine collapse depth/force and loopCount
+    bool collapseForce = false;
+    uint64_t collapseDepth = 1;
+    uint64_t loopCount = 1;
+    auto parseCollapse = [&](const Fortran::parser::AccClauseList &cl)
+                             -> std::pair<bool, uint64_t> {
+      bool force = false;
+      uint64_t depth = 1;
+      for (const Fortran::parser::AccClause &clause : cl.v) {
+        if (const auto *collapseClause =
+                std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
+          const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
+          force = std::get<bool>(arg.t);
+          const auto &intExpr =
+              std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
+          if (const auto *expr = Fortran::semantics::GetExpr(intExpr)) {
+            if (auto v = Fortran::evaluate::ToInt64(*expr))
+              depth = *v;
+          }
+          break;
+        }
+      }
+      return {force, depth};
+    };
 
     if (accLoop || accCombined) {
-      uint64_t loopCount;
       if (accLoop) {
         const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
             std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
         const Fortran::parser::AccClauseList &clauseList =
             std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
         loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
+        std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
       } else if (accCombined) {
         const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir =
             std::get<Fortran::parser::AccBeginCombinedDirective>(
@@ -3206,6 +3230,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         const Fortran::parser::AccClauseList &clauseList =
             std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t);
         loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
+        std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
       }
 
       if (curEval->lowerAsStructured()) {
@@ -3215,8 +3240,58 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       }
     }
 
-    for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
-      genFIR(e);
+    const bool isStructured = curEval && curEval->lowerAsStructured();
+    if (isStructured && collapseForce && collapseDepth > 1) {
+      // force: collect prologue/epilogue for the first collapseDepth nested loops
+      // and sink them into the innermost loop body at that depth
+      llvm::SmallVector<Fortran::lower::pft::Evaluation *> prologue, epilogue;
+      Fortran::lower::pft::Evaluation *parent = &getEval();
+      Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr;
+      for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) {
+        epilogue.clear();
+        auto &kids = parent->getNestedEvaluations();
+        // Collect all non-loop statements before the next inner loop as prologue,
+        // then mark remaining siblings as epilogue and descend into the inner loop.
+        Fortran::lower::pft::Evaluation *childLoop = nullptr;
+        for (auto it = kids.begin(); it != kids.end(); ++it) {
+          if (it->getIf<Fortran::parser::DoConstruct>()) {
+            childLoop = &*it;
+            for (auto it2 = std::next(it); it2 != kids.end(); ++it2)
+              epilogue.push_back(&*it2);
+            break;
+          }
+          prologue.push_back(&*it);
+        }
+        // Semantics guarantees collapseDepth does not exceed nest depth
+        // so childLoop must be found here.
+        assert(childLoop && "Expected inner DoConstruct for collapse");
+        parent = childLoop;
+        innermostLoopEval = childLoop;
+      }
+
+      // Track sunk evaluations (avoid double-lowering)
+      llvm::SmallPtrSet<const Fortran::lower::pft::Evaluation *, 16> sunk;
+      for (auto *e : prologue)  sunk.insert(e);
+      for (auto *e : epilogue)  sunk.insert(e);
+
+      auto sink =
+          [&](llvm::SmallVector<Fortran::lower::pft::Evaluation *> &lst) {
+            for (auto *e : lst)
+              genFIR(*e);
+          };
+
+      sink(prologue);
+
+      // Lower innermost loop body, skipping sunk
+      for (Fortran::lower::pft::Evaluation &e : innermostLoopEval->getNestedEvaluations())
+        if (!sunk.contains(&e)) genFIR(e);
+
+      sink(epilogue);
+    } else {
+      // Normal lowering
+      for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
+        genFIR(e);
+    }
     localSymbols.popScope();
     builder->restoreInsertionPoint(insertPt);
 
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 4a9e49435a907..e24e784895fe8 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2144,8 +2144,23 @@ static void processDoLoopBounds(
         locs.push_back(converter.genLocation(
             Fortran::parser::FindSourceLocation(outerDoConstruct)));
       } else {
-        auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
-        assert(doCons && "expect do construct");
+        // Safely locate the next inner DoConstruct within this eval.
+        const Fortran::parser::DoConstruct *doCons = nullptr;
+        if (crtEval && crtEval->hasNestedEvaluations()) {
+          for (Fortran::lower::pft::Evaluation &child :
+               crtEval->getNestedEvaluations()) {
+            if (auto *cand = child.getIf<Fortran::parser::DoConstruct>()) {
+              doCons = cand;
+              // Prepare to descend for the next iteration
+              crtEval = &child;
+              break;
+            }
+          }
+        }
+        if (!doCons) {
+          // No deeper loop; stop collecting collapsed bounds.
+          break;
+        }
         loopControl = &*doCons->GetLoopControl();
         locs.push_back(converter.genLocation(
             Fortran::parser::FindSourceLocation(*doCons)));
@@ -2172,8 +2187,7 @@ static void processDoLoopBounds(
 
       inclusiveBounds.push_back(true);
 
-      if (i < loopsToProcess - 1)
-        crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
+      // crtEval already updated when descending; no blind increment here.
     }
   }
 }
@@ -2406,10 +2420,6 @@ static mlir::acc::LoopOp createLoopOp(
                    std::get_if<Fortran::parser::AccClause::Collapse>(
                        &clause.u)) {
       const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
-      const auto &force = std::get<bool>(arg.t);
-      if (force)
-        TODO(clauseLocation, "OpenACC collapse force modifier");
-
       const auto &intExpr =
           std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
       const auto *expr = Fortran::semantics::GetExpr(intExpr);
diff --git a/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
new file mode 100644
index 0000000000000..ca932c1b159ba
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
@@ -0,0 +1,41 @@
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+! Verify collapse(force:2) sinks prologue (between loops) and epilogue (after inner loop)
+! into the acc.loop region body.
+
+subroutine collapse_force_sink(n, m)
+  integer, intent(in) :: n, m
+  real, dimension(n,m) :: a
+  real, dimension(n) :: bb, cc
+  integer :: i, j
+
+  !$acc parallel loop collapse(force:2)
+  do i = 1, n
+    bb(i) = 4.2     ! prologue (between loops)
+    do j = 1, m
+      a(i,j) = a(i,j) + 2.0
+    end do
+    cc(i) = 7.3     ! epilogue (after inner loop)
+  end do
+  !$acc end parallel loop
+end subroutine
+
+! CHECK: func.func @_QPcollapse_force_sink(
+! CHECK: acc.parallel
+! Ensure outer acc.loop is combined(parallel)
+! CHECK: acc.loop combined(parallel)
+! Prologue: constant 4.2 and an assign before inner loop
+! CHECK: arith.constant 4.200000e+00
+! CHECK: hlfir.assign
+! Inner loop and its body include 2.0 add and an assign
+! CHECK: acc.loop
+! CHECK: arith.constant 2.000000e+00
+! CHECK: arith.addf
+! CHECK: hlfir.assign
+! Epilogue: constant 7.3 and an assign after inner loop
+! CHECK: arith.constant 7.300000e+00
+! CHECK: hlfir.assign
+! And the outer acc.loop has collapse = [2]
+! CHECK: } attributes {collapse = [2]
+
+

Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! nice work!

Copy link

github-actions bot commented Oct 8, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

force = std::get<bool>(arg.t);
const auto &intExpr =
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
if (const auto *expr = Fortran::semantics::GetExpr(intExpr)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having collapse clause parsing in a single place would be ideal. We currently have getLoopCountForCollapseAndTile in OpenACC.cpp. Any chance to make something similar - like getLoopCountForCollapse - and then use it both here and in getLoopCountForCollapseAndTile.

@clementval
Copy link
Contributor

Nit: can you change your title prefix to [flang][openacc]

@SusanTan SusanTan changed the title [ACC] Add support for force clause for loop collapse [flang][openacc] Add support for force clause for loop collapse Oct 8, 2025
Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants