// Patch overview -- flang OpenACC lowering: accept the `force` modifier of the
// `collapse` clause. These leading lines sit ahead of the first `diff --git`
// header and are not part of any hunk.
//
// flang/include/flang/Lower/OpenACC.h
//   Declares getCollapseSizeAndForce(), documented there as returning
//   {size, force} for the collapse clause and {1,false} when it is absent.
//
// flang/lib/Lower/Bridge.cpp (hunks inside class FirConverter)
//   Reads collapse size/force right after the existing
//   getLoopCountForCollapseAndTile() call, for both the loop and the combined
//   construct. When the evaluation is lowered as structured, force is set and
//   the depth is > 1, it walks (collapseDepth - 1) levels starting from
//   getEval(): at each level the siblings that precede the first nested loop
//   are appended to `prologue` and the siblings that follow it are stored in
//   `epilogue`. It then lowers the prologue, the nested evaluations of the
//   last loop reached, and finally the epilogue. Every other case keeps the
//   previous behaviour of lowering curEval->getNestedEvaluations().
//
// flang/lib/Lower/OpenACC.cpp
//   Hunk headed processDoLoopBounds: the next inner loop is now searched for
//   among the nested evaluations (the removed code stepped to the second
//   nested evaluation unconditionally), and bound collection stops with
//   `break` when no inner loop is found.
//   Hunk headed createLoopOp: removes the
//   TODO(clauseLocation, "OpenACC collapse force modifier") bail-out.
//   getLoopCountForCollapseAndTile(): takes the collapse count from
//   getCollapseSizeAndForce() and still returns the larger of the tile and
//   collapse counts.
//   getCollapseSizeAndForce(): new; stops at the first collapse clause found.
//
// flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
//   New lit test for collapse(force:2) on a `parallel loop` with one statement
//   before and one after the inner loop.
//
// NOTE(review): in this copy the line structure of the patch is collapsed and
//   template argument lists look stripped (e.g. `std::pair`,
//   `std::get_if(&acc.u)`, `llvm::SmallVector prologue`, `getIf()`); compare
//   against the original patch before applying or compiling anything from it.
// NOTE(review): `epilogue.clear()` runs at the top of every level while
//   `prologue` keeps accumulating, so only the trailing siblings of the last
//   level visited are kept. Trailing statements of an earlier level do not
//   appear to be lowered anywhere on the force path -- confirm for depths that
//   visit more than one level with trailing statements.
// NOTE(review): `sunk` holds siblings of the visited loops, but the skip loop
//   iterates the nested evaluations of the last loop itself, so
//   `sunk.contains(&e)` looks like it can never be true -- confirm intent.
// NOTE(review): `*Fortran::semantics::GetIntValue(collapseValue)` is
//   dereferenced without a check (the removed code did the same); presumably
//   semantic checking guarantees a constant here -- TODO confirm.
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h index 4622dbc8ccf64..69f1f5be753e6 100644 --- a/flang/include/flang/Lower/OpenACC.h +++ b/flang/include/flang/Lower/OpenACC.h @@ -122,6 +122,11 @@ void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *, /// clause. uint64_t getLoopCountForCollapseAndTile(const Fortran::parser::AccClauseList &); +/// Parse collapse clause and return {size, force}. If absent, returns +/// {1,false}. +std::pair +getCollapseSizeAndForce(const Fortran::parser::AccClauseList &); + /// Checks whether the current insertion point is inside OpenACC loop. bool isInOpenACCLoop(fir::FirOpBuilder &); diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 780d56f085f69..b41ebd7d15a00 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3190,15 +3190,20 @@ class FirConverter : public Fortran::lower::AbstractConverter { std::get_if(&acc.u); Fortran::lower::pft::Evaluation *curEval = &getEval(); + // Determine collapse depth/force and loopCount + bool collapseForce = false; + uint64_t collapseDepth = 1; + uint64_t loopCount = 1; if (accLoop || accCombined) { - uint64_t loopCount; if (accLoop) { const Fortran::parser::AccBeginLoopDirective &beginLoopDir = std::get(accLoop->t); const Fortran::parser::AccClauseList &clauseList = std::get(beginLoopDir.t); loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList); + std::tie(collapseDepth, collapseForce) = + Fortran::lower::getCollapseSizeAndForce(clauseList); } else if (accCombined) { const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir = std::get( @@ -3206,6 +3211,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { const Fortran::parser::AccClauseList &clauseList = std::get(beginCombinedDir.t); loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList); + std::tie(collapseDepth, collapseForce) = + Fortran::lower::getCollapseSizeAndForce(clauseList); } if 
(curEval->lowerAsStructured()) { @@ -3215,8 +3222,63 @@ class FirConverter : public Fortran::lower::AbstractConverter { } } - for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) - genFIR(e); + const bool isStructured = curEval && curEval->lowerAsStructured(); + if (isStructured && collapseForce && collapseDepth > 1) { + // force: collect prologue/epilogue for the first collapseDepth nested + // loops and sink them into the innermost loop body at that depth + llvm::SmallVector prologue, epilogue; + Fortran::lower::pft::Evaluation *parent = &getEval(); + Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr; + for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) { + epilogue.clear(); + auto &kids = parent->getNestedEvaluations(); + // Collect all non-loop statements before the next inner loop as + // prologue, then mark remaining siblings as epilogue and descend into + // the inner loop. + Fortran::lower::pft::Evaluation *childLoop = nullptr; + for (auto it = kids.begin(); it != kids.end(); ++it) { + if (it->getIf()) { + childLoop = &*it; + for (auto it2 = std::next(it); it2 != kids.end(); ++it2) + epilogue.push_back(&*it2); + break; + } + prologue.push_back(&*it); + } + // Semantics guarantees collapseDepth does not exceed nest depth + // so childLoop must be found here. 
+ assert(childLoop && "Expected inner DoConstruct for collapse"); + parent = childLoop; + innermostLoopEval = childLoop; + } + + // Track sunk evaluations (avoid double-lowering) + llvm::SmallPtrSet sunk; + for (auto *e : prologue) + sunk.insert(e); + for (auto *e : epilogue) + sunk.insert(e); + + auto sink = + [&](llvm::SmallVector &lst) { + for (auto *e : lst) + genFIR(*e); + }; + + sink(prologue); + + // Lower innermost loop body, skipping sunk + for (Fortran::lower::pft::Evaluation &e : + innermostLoopEval->getNestedEvaluations()) + if (!sunk.contains(&e)) + genFIR(e); + + sink(epilogue); + } else { + // Normal lowering + for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) + genFIR(e); + } localSymbols.popScope(); builder->restoreInsertionPoint(insertPt); diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 4a9e49435a907..0aed144fc5123 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -2144,11 +2144,25 @@ static void processDoLoopBounds( locs.push_back(converter.genLocation( Fortran::parser::FindSourceLocation(outerDoConstruct))); } else { - auto *doCons = crtEval->getIf(); - assert(doCons && "expect do construct"); - loopControl = &*doCons->GetLoopControl(); + // Safely locate the next inner DoConstruct within this eval. + const Fortran::parser::DoConstruct *innerDo = nullptr; + if (crtEval && crtEval->hasNestedEvaluations()) { + for (Fortran::lower::pft::Evaluation &child : + crtEval->getNestedEvaluations()) { + if (auto *stmt = child.getIf()) { + innerDo = stmt; + // Prepare to descend for the next iteration + crtEval = &child; + break; + } + } + } + if (!innerDo) + break; // No deeper loop; stop collecting collapsed bounds. 
+ + loopControl = &*innerDo->GetLoopControl(); locs.push_back(converter.genLocation( - Fortran::parser::FindSourceLocation(*doCons))); + Fortran::parser::FindSourceLocation(*innerDo))); } const Fortran::parser::LoopControl::Bounds *bounds = @@ -2172,8 +2186,7 @@ static void processDoLoopBounds( inclusiveBounds.push_back(true); - if (i < loopsToProcess - 1) - crtEval = &*std::next(crtEval->getNestedEvaluations().begin()); + // crtEval already updated when descending; no blind increment here. } } } @@ -2406,10 +2419,6 @@ static mlir::acc::LoopOp createLoopOp( std::get_if( &clause.u)) { const Fortran::parser::AccCollapseArg &arg = collapseClause->v; - const auto &force = std::get(arg.t); - if (force) - TODO(clauseLocation, "OpenACC collapse force modifier"); - const auto &intExpr = std::get(arg.t); const auto *expr = Fortran::semantics::GetExpr(intExpr); @@ -4860,25 +4869,34 @@ void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder, uint64_t Fortran::lower::getLoopCountForCollapseAndTile( const Fortran::parser::AccClauseList &clauseList) { - uint64_t collapseLoopCount = 1; + uint64_t collapseLoopCount = getCollapseSizeAndForce(clauseList).first; uint64_t tileLoopCount = 1; for (const Fortran::parser::AccClause &clause : clauseList.v) { - if (const auto *collapseClause = - std::get_if(&clause.u)) { - const parser::AccCollapseArg &arg = collapseClause->v; - const auto &collapseValue{std::get(arg.t)}; - collapseLoopCount = *Fortran::semantics::GetIntValue(collapseValue); - } if (const auto *tileClause = std::get_if(&clause.u)) { const parser::AccTileExprList &tileExprList = tileClause->v; - const std::list &listTileExpr = tileExprList.v; - tileLoopCount = listTileExpr.size(); + tileLoopCount = tileExprList.v.size(); + } + } + return tileLoopCount > collapseLoopCount ? 
tileLoopCount : collapseLoopCount; +} + +std::pair Fortran::lower::getCollapseSizeAndForce( + const Fortran::parser::AccClauseList &clauseList) { + uint64_t size = 1; + bool force = false; + for (const Fortran::parser::AccClause &clause : clauseList.v) { + if (const auto *collapseClause = + std::get_if(&clause.u)) { + const Fortran::parser::AccCollapseArg &arg = collapseClause->v; + force = std::get(arg.t); + const auto &collapseValue = + std::get(arg.t); + size = *Fortran::semantics::GetIntValue(collapseValue); + break; } } - if (tileLoopCount > collapseLoopCount) - return tileLoopCount; - return collapseLoopCount; + return {size, force}; } /// Create an ACC loop operation for a DO construct when inside ACC compute diff --git a/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 new file mode 100644 index 0000000000000..ca932c1b159ba --- /dev/null +++ b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 @@ -0,0 +1,41 @@ +! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s + +! Verify collapse(force:2) sinks prologue (between loops) and epilogue (after inner loop) +! into the acc.loop region body. + +subroutine collapse_force_sink(n, m) + integer, intent(in) :: n, m + real, dimension(n,m) :: a + real, dimension(n) :: bb, cc + integer :: i, j + + !$acc parallel loop collapse(force:2) + do i = 1, n + bb(i) = 4.2 ! prologue (between loops) + do j = 1, m + a(i,j) = a(i,j) + 2.0 + end do + cc(i) = 7.3 ! epilogue (after inner loop) + end do + !$acc end parallel loop +end subroutine + +! CHECK: func.func @_QPcollapse_force_sink( +! CHECK: acc.parallel +! Ensure outer acc.loop is combined(parallel) +! CHECK: acc.loop combined(parallel) +! Prologue: constant 4.2 and an assign before inner loop +! CHECK: arith.constant 4.200000e+00 +! CHECK: hlfir.assign +! Inner loop and its body include 2.0 add and an assign +! CHECK: acc.loop +! CHECK: arith.constant 2.000000e+00 +! 
CHECK: arith.addf +! CHECK: hlfir.assign +! Epilogue: constant 7.3 and an assign after inner loop +! CHECK: arith.constant 7.300000e+00 +! CHECK: hlfir.assign +! And the outer acc.loop has collapse = [2] +! CHECK: } attributes {collapse = [2] + +