From 65b02a4c45431c335021c0131264ffdba65c0524 Mon Sep 17 00:00:00 2001 From: Susan Tan Date: Tue, 7 Oct 2025 16:47:15 -0700 Subject: [PATCH 01/10] add initial implementation --- flang/lib/Lower/Bridge.cpp | 69 +++++++++++++++++++++++++++++++++++-- flang/lib/Lower/OpenACC.cpp | 4 --- 2 files changed, 66 insertions(+), 7 deletions(-) diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 780d56f085f69..f75f648fbdfcc 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3190,15 +3190,39 @@ class FirConverter : public Fortran::lower::AbstractConverter { std::get_if(&acc.u); Fortran::lower::pft::Evaluation *curEval = &getEval(); + // Determine collapse depth/force and loopCount + bool collapseForce = false; + uint64_t collapseDepth = 1; + uint64_t loopCount = 1; + auto parseCollapse = [&](const Fortran::parser::AccClauseList &cl) + -> std::pair { + bool force = false; + uint64_t depth = 1; + for (const Fortran::parser::AccClause &clause : cl.v) { + if (const auto *collapseClause = + std::get_if(&clause.u)) { + const Fortran::parser::AccCollapseArg &arg = collapseClause->v; + force = std::get(arg.t); + const auto &intExpr = + std::get(arg.t); + if (const auto *expr = Fortran::semantics::GetExpr(intExpr)) { + if (auto v = Fortran::evaluate::ToInt64(*expr)) + depth = *v; + } + break; + } + } + return {force, depth}; + }; if (accLoop || accCombined) { - uint64_t loopCount; if (accLoop) { const Fortran::parser::AccBeginLoopDirective &beginLoopDir = std::get(accLoop->t); const Fortran::parser::AccClauseList &clauseList = std::get(beginLoopDir.t); loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList); + std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList); } else if (accCombined) { const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir = std::get( @@ -3206,6 +3230,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { const Fortran::parser::AccClauseList &clauseList = std::get(beginCombinedDir.t); loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList); + std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList); } if (curEval->lowerAsStructured()) { @@ -3215,8 +3240,46 @@ class FirConverter : public Fortran::lower::AbstractConverter { } } - for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) - genFIR(e); + // Collect prologue and tail (after-inner) statements if force + llvm::SmallVector prologue, tail; + if (collapseForce && loopCount > 1 && getEval().lowerAsStructured()) { + auto hasKids = [](Fortran::lower::pft::Evaluation *ev) -> bool { + return ev && ev->hasNestedEvaluations(); + }; + Fortran::lower::pft::Evaluation *parent = &getEval(); + uint64_t levelsToProcess = std::min(collapseDepth, loopCount); + for (uint64_t lvl = 0; lvl + 1 < levelsToProcess; ++lvl) { + if (!hasKids(parent)) break; + Fortran::lower::pft::Evaluation *childLoop = nullptr; + tail.clear(); + auto &kids = parent->getNestedEvaluations(); + for (auto it = kids.begin(); it != kids.end(); ++it) { + if (it->getIf()) { + childLoop = &*it; + for (auto it2 = std::next(it); it2 != kids.end(); ++it2) + tail.push_back(&*it2); + break; + } + prologue.push_back(&*it); + } + if (!childLoop) break; + parent = childLoop; + } + } + + // Prologue sink + for (auto *e : prologue) + genFIR(*e); + + // Lower the loop body as usual + if (curEval && curEval->hasNestedEvaluations()) { + for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) + genFIR(e); + } + + // Epilogue sink + for (auto *e : tail) + genFIR(*e); localSymbols.popScope(); builder->restoreInsertionPoint(insertPt); diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 4a9e49435a907..4653f40e77948 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -2406,10 +2406,6 @@ static mlir::acc::LoopOp createLoopOp( std::get_if( &clause.u)) { const Fortran::parser::AccCollapseArg &arg = collapseClause->v; - const auto &force = std::get(arg.t); - if (force) - TODO(clauseLocation, "OpenACC collapse force modifier"); - const auto &intExpr = std::get(arg.t); const auto *expr = Fortran::semantics::GetExpr(intExpr); From 64d01104612b5641c2b8b4685f6d3f82762456ad Mon Sep 17 00:00:00 2001 From: Susan Tan Date: Tue, 7 Oct 2025 16:58:08 -0700 Subject: [PATCH 02/10] tweak --- flang/lib/Lower/OpenACC.cpp | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 4653f40e77948..e24e784895fe8 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -2144,8 +2144,23 @@ static void processDoLoopBounds( locs.push_back(converter.genLocation( Fortran::parser::FindSourceLocation(outerDoConstruct))); } else { - auto *doCons = crtEval->getIf(); - assert(doCons && "expect do construct"); + // Safely locate the next inner DoConstruct within this eval. + const Fortran::parser::DoConstruct *doCons = nullptr; + if (crtEval && crtEval->hasNestedEvaluations()) { + for (Fortran::lower::pft::Evaluation &child : + crtEval->getNestedEvaluations()) { + if (auto *cand = child.getIf()) { + doCons = cand; + // Prepare to descend for the next iteration + crtEval = &child; + break; + } + } + } + if (!doCons) { + // No deeper loop; stop collecting collapsed bounds. + break; + } loopControl = &*doCons->GetLoopControl(); locs.push_back(converter.genLocation( Fortran::parser::FindSourceLocation(*doCons))); @@ -2172,8 +2187,7 @@ static void processDoLoopBounds( inclusiveBounds.push_back(true); - if (i < loopsToProcess - 1) - crtEval = &*std::next(crtEval->getNestedEvaluations().begin()); + // crtEval already updated when descending; no blind increment here. } } } From c7e7321e65fcf9a6138cf526a4c40d5f303f636b Mon Sep 17 00:00:00 2001 From: Susan Tan Date: Wed, 8 Oct 2025 08:49:42 -0700 Subject: [PATCH 03/10] tweak --- flang/lib/Lower/Bridge.cpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index f75f648fbdfcc..b406de9a739ff 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3267,14 +3267,26 @@ class FirConverter : public Fortran::lower::AbstractConverter { } } + // Track sunk evaluations to avoid double-lowering + llvm::SmallPtrSet sunk; + for (auto *e : prologue) sunk.insert(e); + for (auto *e : tail) sunk.insert(e); + // Prologue sink for (auto *e : prologue) genFIR(*e); - // Lower the loop body as usual + // Lower the loop body as usual, skipping already-sunk evals if (curEval && curEval->hasNestedEvaluations()) { - for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) + for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) { + if (sunk.contains(&e)) continue; + genFIR(e); + } + } else if (getEval().hasNestedEvaluations()) { + for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) { + if (sunk.contains(&e)) continue; genFIR(e); + } } // Epilogue sink From 48733f4d7258267ea2cd8b99bfe764f6fd559feb Mon Sep 17 00:00:00 2001 From: Susan Tan Date: Wed, 8 Oct 2025 11:27:28 -0700 Subject: [PATCH 04/10] code cleanup --- flang/lib/Lower/Bridge.cpp | 70 +++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index b406de9a739ff..32eb382e2c34f 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3240,58 +3240,58 @@ class FirConverter : public Fortran::lower::AbstractConverter { } } - // Collect prologue and tail (after-inner) statements if force - llvm::SmallVector prologue, tail; - if (collapseForce && loopCount > 1 && getEval().lowerAsStructured()) { - auto hasKids = [](Fortran::lower::pft::Evaluation *ev) -> bool { - return ev && ev->hasNestedEvaluations(); - }; + const bool isStructured = curEval && curEval->lowerAsStructured(); + if (isStructured && collapseForce && collapseDepth > 1) { + // force: collect prologue/epilogue for the first collapseDepth nested loops + // and sink them into the innermost loop body at that depth + llvm::SmallVector prologue, epilogue; Fortran::lower::pft::Evaluation *parent = &getEval(); - uint64_t levelsToProcess = std::min(collapseDepth, loopCount); - for (uint64_t lvl = 0; lvl + 1 < levelsToProcess; ++lvl) { - if (!hasKids(parent)) break; - Fortran::lower::pft::Evaluation *childLoop = nullptr; - tail.clear(); + Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr; + for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) { + epilogue.clear(); auto &kids = parent->getNestedEvaluations(); + // Collect all non-loop statements before the next inner loop as prologue, + // then mark remaining siblings as epilogue and descend into the inner loop. + Fortran::lower::pft::Evaluation *childLoop = nullptr; for (auto it = kids.begin(); it != kids.end(); ++it) { if (it->getIf()) { childLoop = &*it; for (auto it2 = std::next(it); it2 != kids.end(); ++it2) - tail.push_back(&*it2); + epilogue.push_back(&*it2); break; } prologue.push_back(&*it); } - if (!childLoop) break; + // Semantics guarantees collapseDepth does not exceed nest depth + // so childLoop must be found here. + assert(childLoop && "Expected inner DoConstruct for collapse"); parent = childLoop; + innermostLoopEval = childLoop; } - } - // Track sunk evaluations to avoid double-lowering - llvm::SmallPtrSet sunk; - for (auto *e : prologue) sunk.insert(e); - for (auto *e : tail) sunk.insert(e); + // Track sunk evaluations (avoid double-lowering) + llvm::SmallPtrSet sunk; + for (auto *e : prologue) sunk.insert(e); + for (auto *e : epilogue) sunk.insert(e); - // Prologue sink - for (auto *e : prologue) - genFIR(*e); + auto emit = [&](llvm::SmallVector &lst) { + for (auto *e : lst) genFIR(*e); + }; - // Lower the loop body as usual, skipping already-sunk evals - if (curEval && curEval->hasNestedEvaluations()) { - for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) { - if (sunk.contains(&e)) continue; - genFIR(e); - } - } else if (getEval().hasNestedEvaluations()) { - for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) { - if (sunk.contains(&e)) continue; + // Sink prologue + emit(prologue); + + // Lower innermost loop body, skipping sunk + for (Fortran::lower::pft::Evaluation &e : innermostLoopEval->getNestedEvaluations()) + if (!sunk.contains(&e)) genFIR(e); + + // Sink epilogue + emit(epilogue); + } else { + // Normal lowering + for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) genFIR(e); - } } - - // Epilogue sink - for (auto *e : tail) - genFIR(*e); localSymbols.popScope(); builder->restoreInsertionPoint(insertPt); From 32b5f71aa9b88ff8ab7681ed23f5e939572a6cca Mon Sep 17 00:00:00 2001 From: Susan Tan Date: Wed, 8 Oct 2025 11:57:49 -0700 Subject: [PATCH 05/10] add a test --- flang/lib/Lower/Bridge.cpp | 14 +++---- .../acc-loop-collapse-force-lowering.f90 | 41 +++++++++++++++++++ 2 files changed, 48 insertions(+), 7 deletions(-) create mode 100644 flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 32eb382e2c34f..3d331cdad3d43 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3274,19 +3274,19 @@ class FirConverter : public Fortran::lower::AbstractConverter { for (auto *e : prologue) sunk.insert(e); for (auto *e : epilogue) sunk.insert(e); - auto emit = [&](llvm::SmallVector &lst) { - for (auto *e : lst) genFIR(*e); - }; + auto sink = + [&](llvm::SmallVector &lst) { + for (auto *e : lst) + genFIR(*e); + }; - // Sink prologue - emit(prologue); + sink(prologue); // Lower innermost loop body, skipping sunk for (Fortran::lower::pft::Evaluation &e : innermostLoopEval->getNestedEvaluations()) if (!sunk.contains(&e)) genFIR(e); - // Sink epilogue - emit(epilogue); + sink(epilogue); } else { // Normal lowering for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) diff --git a/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 new file mode 100644 index 0000000000000..ca932c1b159ba --- /dev/null +++ b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 @@ -0,0 +1,41 @@ +! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s + +! Verify collapse(force:2) sinks prologue (between loops) and epilogue (after inner loop) +! into the acc.loop region body. + +subroutine collapse_force_sink(n, m) + integer, intent(in) :: n, m + real, dimension(n,m) :: a + real, dimension(n) :: bb, cc + integer :: i, j + + !$acc parallel loop collapse(force:2) + do i = 1, n + bb(i) = 4.2 ! prologue (between loops) + do j = 1, m + a(i,j) = a(i,j) + 2.0 + end do + cc(i) = 7.3 ! epilogue (after inner loop) + end do + !$acc end parallel loop +end subroutine + +! CHECK: func.func @_QPcollapse_force_sink( +! CHECK: acc.parallel +! Ensure outer acc.loop is combined(parallel) +! CHECK: acc.loop combined(parallel) +! Prologue: constant 4.2 and an assign before inner loop +! CHECK: arith.constant 4.200000e+00 +! CHECK: hlfir.assign +! Inner loop and its body include 2.0 add and an assign +! CHECK: acc.loop +! CHECK: arith.constant 2.000000e+00 +! CHECK: arith.addf +! CHECK: hlfir.assign +! Epilogue: constant 7.3 and an assign after inner loop +! CHECK: arith.constant 7.300000e+00 +! CHECK: hlfir.assign +! And the outer acc.loop has collapse = [2] +! CHECK: } attributes {collapse = [2] + + From fa52cbb69007c2ecf4d1818289fb0cb93a7b1e94 Mon Sep 17 00:00:00 2001 From: Susan Tan Date: Wed, 8 Oct 2025 12:18:17 -0700 Subject: [PATCH 06/10] cleanup code --- flang/lib/Lower/OpenACC.cpp | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index e24e784895fe8..c376609ee1b5b 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -2145,25 +2145,23 @@ static void processDoLoopBounds( Fortran::parser::FindSourceLocation(outerDoConstruct))); } else { // Safely locate the next inner DoConstruct within this eval. - const Fortran::parser::DoConstruct *doCons = nullptr; + const Fortran::parser::DoConstruct *innerDo = nullptr; if (crtEval && crtEval->hasNestedEvaluations()) { - for (Fortran::lower::pft::Evaluation &child : - crtEval->getNestedEvaluations()) { - if (auto *cand = child.getIf()) { - doCons = cand; + for (Fortran::lower::pft::Evaluation &child : crtEval->getNestedEvaluations()) { + if (auto *stmt = child.getIf()) { + innerDo = stmt; // Prepare to descend for the next iteration crtEval = &child; break; } } } - if (!doCons) { - // No deeper loop; stop collecting collapsed bounds. - break; - } - loopControl = &*doCons->GetLoopControl(); + if (!innerDo) + break; // No deeper loop; stop collecting collapsed bounds. + + loopControl = &*innerDo->GetLoopControl(); locs.push_back(converter.genLocation( - Fortran::parser::FindSourceLocation(*doCons))); + Fortran::parser::FindSourceLocation(*innerDo))); } const Fortran::parser::LoopControl::Bounds *bounds = From e2141ff297091d0f3da86dd70425c5f6d542f546 Mon Sep 17 00:00:00 2001 From: Susan Tan Date: Wed, 8 Oct 2025 12:38:53 -0700 Subject: [PATCH 07/10] refactor code to parse in OpenACC.cpp --- flang/include/flang/Lower/OpenACC.h | 3 +++ flang/lib/Lower/Bridge.cpp | 6 ++++-- flang/lib/Lower/OpenACC.cpp | 14 ++++++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h index 4622dbc8ccf64..f6ec3658eff30 100644 --- a/flang/include/flang/Lower/OpenACC.h +++ b/flang/include/flang/Lower/OpenACC.h @@ -122,6 +122,9 @@ void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *, /// clause. uint64_t getLoopCountForCollapseAndTile(const Fortran::parser::AccClauseList &); +/// Returns only the collapse(N) depth (defaults to 1 when absent). +uint64_t getLoopCountForCollapse(const Fortran::parser::AccClauseList &); + /// Checks whether the current insertion point is inside OpenACC loop. bool isInOpenACCLoop(fir::FirOpBuilder &); diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 3d331cdad3d43..8482bba4ecbf8 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3222,7 +3222,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { const Fortran::parser::AccClauseList &clauseList = std::get(beginLoopDir.t); loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList); - std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList); + collapseDepth = Fortran::lower::getLoopCountForCollapse(clauseList); + std::tie(collapseForce, std::ignore) = parseCollapse(clauseList); } else if (accCombined) { const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir = std::get( @@ -3230,7 +3231,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { const Fortran::parser::AccClauseList &clauseList = std::get(beginCombinedDir.t); loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList); - std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList); + collapseDepth = Fortran::lower::getLoopCountForCollapse(clauseList); + std::tie(collapseForce, std::ignore) = parseCollapse(clauseList); } if (curEval->lowerAsStructured()) { diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index c376609ee1b5b..90edc102e13a0 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -4889,6 +4889,20 @@ uint64_t Fortran::lower::getLoopCountForCollapseAndTile( return collapseLoopCount; } +uint64_t Fortran::lower::getLoopCountForCollapse( + const Fortran::parser::AccClauseList &clauseList) { + for (const Fortran::parser::AccClause &clause : clauseList.v) { + if (const auto *collapseClause = + std::get_if(&clause.u)) { + const Fortran::parser::AccCollapseArg &arg = collapseClause->v; + const auto &collapseValue = + std::get(arg.t); + return *Fortran::semantics::GetIntValue(collapseValue); + } + } + return 1; +} + /// Create an ACC loop operation for a DO construct when inside ACC compute /// constructs This serves as a bridge between regular DO construct handling and /// ACC loop creation From c27d440f5510bb441aaa23ec96b4218142d53a73 Mon Sep 17 00:00:00 2001 From: Susan Tan Date: Wed, 8 Oct 2025 12:54:15 -0700 Subject: [PATCH 08/10] use collapseLoopCount in calculating tile loopcount --- flang/lib/Lower/OpenACC.cpp | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 90edc102e13a0..fb8c882cb30fa 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -4868,25 +4868,16 @@ void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder, uint64_t Fortran::lower::getLoopCountForCollapseAndTile( const Fortran::parser::AccClauseList &clauseList) { - uint64_t collapseLoopCount = 1; + uint64_t collapseLoopCount = getLoopCountForCollapse(clauseList); uint64_t tileLoopCount = 1; for (const Fortran::parser::AccClause &clause : clauseList.v) { - if (const auto *collapseClause = - std::get_if(&clause.u)) { - const parser::AccCollapseArg &arg = collapseClause->v; - const auto &collapseValue{std::get(arg.t)}; - collapseLoopCount = *Fortran::semantics::GetIntValue(collapseValue); - } if (const auto *tileClause = std::get_if(&clause.u)) { const parser::AccTileExprList &tileExprList = tileClause->v; - const std::list &listTileExpr = tileExprList.v; - tileLoopCount = listTileExpr.size(); + tileLoopCount = tileExprList.v.size(); } } - if (tileLoopCount > collapseLoopCount) - return tileLoopCount; - return collapseLoopCount; + return tileLoopCount > collapseLoopCount ? tileLoopCount : collapseLoopCount; } uint64_t Fortran::lower::getLoopCountForCollapse( From 8c11d6fb3f73c02ceb5c7c035071e88a09ce4131 Mon Sep 17 00:00:00 2001 From: Susan Tan Date: Wed, 8 Oct 2025 16:06:17 -0700 Subject: [PATCH 09/10] format --- flang/lib/Lower/Bridge.cpp | 23 ++++++++++++++--------- flang/lib/Lower/OpenACC.cpp | 3 ++- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 8482bba4ecbf8..e3f59d475b24c 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3195,7 +3195,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { uint64_t collapseDepth = 1; uint64_t loopCount = 1; auto parseCollapse = [&](const Fortran::parser::AccClauseList &cl) - -> std::pair { + -> std::pair { bool force = false; uint64_t depth = 1; for (const Fortran::parser::AccClause &clause : cl.v) { @@ -3244,16 +3244,17 @@ class FirConverter : public Fortran::lower::AbstractConverter { const bool isStructured = curEval && curEval->lowerAsStructured(); if (isStructured && collapseForce && collapseDepth > 1) { - // force: collect prologue/epilogue for the first collapseDepth nested loops - // and sink them into the innermost loop body at that depth + // force: collect prologue/epilogue for the first collapseDepth nested + // loops and sink them into the innermost loop body at that depth llvm::SmallVector prologue, epilogue; Fortran::lower::pft::Evaluation *parent = &getEval(); Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr; for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) { epilogue.clear(); auto &kids = parent->getNestedEvaluations(); - // Collect all non-loop statements before the next inner loop as prologue, - // then mark remaining siblings as epilogue and descend into the inner loop. + // Collect all non-loop statements before the next inner loop as + // prologue, then mark remaining siblings as epilogue and descend into + // the inner loop. Fortran::lower::pft::Evaluation *childLoop = nullptr; for (auto it = kids.begin(); it != kids.end(); ++it) { if (it->getIf()) { @@ -3273,8 +3274,10 @@ class FirConverter : public Fortran::lower::AbstractConverter { // Track sunk evaluations (avoid double-lowering) llvm::SmallPtrSet sunk; - for (auto *e : prologue) sunk.insert(e); - for (auto *e : epilogue) sunk.insert(e); + for (auto *e : prologue) + sunk.insert(e); + for (auto *e : epilogue) + sunk.insert(e); auto sink = [&](llvm::SmallVector &lst) { @@ -3285,8 +3288,10 @@ class FirConverter : public Fortran::lower::AbstractConverter { sink(prologue); // Lower innermost loop body, skipping sunk - for (Fortran::lower::pft::Evaluation &e : innermostLoopEval->getNestedEvaluations()) - if (!sunk.contains(&e)) genFIR(e); + for (Fortran::lower::pft::Evaluation &e : + innermostLoopEval->getNestedEvaluations()) + if (!sunk.contains(&e)) + genFIR(e); sink(epilogue); } else { diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index fb8c882cb30fa..c38fd3a78c393 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -2147,7 +2147,8 @@ static void processDoLoopBounds( // Safely locate the next inner DoConstruct within this eval. const Fortran::parser::DoConstruct *innerDo = nullptr; if (crtEval && crtEval->hasNestedEvaluations()) { - for (Fortran::lower::pft::Evaluation &child : crtEval->getNestedEvaluations()) { + for (Fortran::lower::pft::Evaluation &child : + crtEval->getNestedEvaluations()) { if (auto *stmt = child.getIf()) { innerDo = stmt; // Prepare to descend for the next iteration From 7cca8e4ee98c7f8f6c2df67784e40a0ab6c55dc5 Mon Sep 17 00:00:00 2001 From: Susan Tan Date: Wed, 8 Oct 2025 16:24:13 -0700 Subject: [PATCH 10/10] refactor --- flang/include/flang/Lower/OpenACC.h | 6 ++++-- flang/lib/Lower/Bridge.cpp | 28 ++++------------------------ flang/lib/Lower/OpenACC.cpp | 12 ++++++++---- 3 files changed, 16 insertions(+), 30 deletions(-) diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h index f6ec3658eff30..69f1f5be753e6 100644 --- a/flang/include/flang/Lower/OpenACC.h +++ b/flang/include/flang/Lower/OpenACC.h @@ -122,8 +122,10 @@ void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *, /// clause. uint64_t getLoopCountForCollapseAndTile(const Fortran::parser::AccClauseList &); -/// Returns only the collapse(N) depth (defaults to 1 when absent). -uint64_t getLoopCountForCollapse(const Fortran::parser::AccClauseList &); +/// Parse collapse clause and return {size, force}. If absent, returns +/// {1,false}. +std::pair +getCollapseSizeAndForce(const Fortran::parser::AccClauseList &); /// Checks whether the current insertion point is inside OpenACC loop. bool isInOpenACCLoop(fir::FirOpBuilder &); diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index e3f59d475b24c..b41ebd7d15a00 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3194,26 +3194,6 @@ class FirConverter : public Fortran::lower::AbstractConverter { bool collapseForce = false; uint64_t collapseDepth = 1; uint64_t loopCount = 1; - auto parseCollapse = [&](const Fortran::parser::AccClauseList &cl) - -> std::pair { - bool force = false; - uint64_t depth = 1; - for (const Fortran::parser::AccClause &clause : cl.v) { - if (const auto *collapseClause = - std::get_if(&clause.u)) { - const Fortran::parser::AccCollapseArg &arg = collapseClause->v; - force = std::get(arg.t); - const auto &intExpr = - std::get(arg.t); - if (const auto *expr = Fortran::semantics::GetExpr(intExpr)) { - if (auto v = Fortran::evaluate::ToInt64(*expr)) - depth = *v; - } - break; - } - } - return {force, depth}; - }; if (accLoop || accCombined) { if (accLoop) { @@ -3222,8 +3202,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { const Fortran::parser::AccClauseList &clauseList = std::get(beginLoopDir.t); loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList); - collapseDepth = Fortran::lower::getLoopCountForCollapse(clauseList); - std::tie(collapseForce, std::ignore) = parseCollapse(clauseList); + std::tie(collapseDepth, collapseForce) = + Fortran::lower::getCollapseSizeAndForce(clauseList); } else if (accCombined) { const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir = std::get( @@ -3231,8 +3211,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { const Fortran::parser::AccClauseList &clauseList = std::get(beginCombinedDir.t); loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList); - collapseDepth = Fortran::lower::getLoopCountForCollapse(clauseList); - std::tie(collapseForce, std::ignore) = parseCollapse(clauseList); + std::tie(collapseDepth, collapseForce) = + Fortran::lower::getCollapseSizeAndForce(clauseList); } if (curEval->lowerAsStructured()) { diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index c38fd3a78c393..0aed144fc5123 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -4869,7 +4869,7 @@ void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder, uint64_t Fortran::lower::getLoopCountForCollapseAndTile( const Fortran::parser::AccClauseList &clauseList) { - uint64_t collapseLoopCount = getLoopCountForCollapse(clauseList); + uint64_t collapseLoopCount = getCollapseSizeAndForce(clauseList).first; uint64_t tileLoopCount = 1; for (const Fortran::parser::AccClause &clause : clauseList.v) { if (const auto *tileClause = @@ -4881,18 +4881,22 @@ uint64_t Fortran::lower::getLoopCountForCollapseAndTile( return tileLoopCount > collapseLoopCount ? tileLoopCount : collapseLoopCount; } -uint64_t Fortran::lower::getLoopCountForCollapse( +std::pair Fortran::lower::getCollapseSizeAndForce( const Fortran::parser::AccClauseList &clauseList) { + uint64_t size = 1; + bool force = false; for (const Fortran::parser::AccClause &clause : clauseList.v) { if (const auto *collapseClause = std::get_if(&clause.u)) { const Fortran::parser::AccCollapseArg &arg = collapseClause->v; + force = std::get(arg.t); const auto &collapseValue = std::get(arg.t); - return *Fortran::semantics::GetIntValue(collapseValue); + size = *Fortran::semantics::GetIntValue(collapseValue); + break; } } - return 1; + return {size, force}; } /// Create an ACC loop operation for a DO construct when inside ACC compute