Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flang/include/flang/Lower/OpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
/// clause.
uint64_t getLoopCountForCollapseAndTile(const Fortran::parser::AccClauseList &);

/// Parse the collapse clause of the given clause list and return
/// {size, force}: `size` is the integer constant argument of the clause and
/// `force` is true when the `force` modifier is present. Only the first
/// collapse clause found is considered. If absent, returns {1,false}.
std::pair<uint64_t, bool>
getCollapseSizeAndForce(const Fortran::parser::AccClauseList &);

/// Checks whether the current insertion point is inside an OpenACC loop.
bool isInOpenACCLoop(fir::FirOpBuilder &);

Expand Down
68 changes: 65 additions & 3 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3190,22 +3190,29 @@ class FirConverter : public Fortran::lower::AbstractConverter {
std::get_if<Fortran::parser::OpenACCCombinedConstruct>(&acc.u);

Fortran::lower::pft::Evaluation *curEval = &getEval();
// Determine collapse depth/force and loopCount
bool collapseForce = false;
uint64_t collapseDepth = 1;
uint64_t loopCount = 1;

if (accLoop || accCombined) {
uint64_t loopCount;
if (accLoop) {
const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
std::tie(collapseDepth, collapseForce) =
Fortran::lower::getCollapseSizeAndForce(clauseList);
} else if (accCombined) {
const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir =
std::get<Fortran::parser::AccBeginCombinedDirective>(
accCombined->t);
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
std::tie(collapseDepth, collapseForce) =
Fortran::lower::getCollapseSizeAndForce(clauseList);
}

if (curEval->lowerAsStructured()) {
Expand All @@ -3215,8 +3222,63 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
}

for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
genFIR(e);
const bool isStructured = curEval && curEval->lowerAsStructured();
if (isStructured && collapseForce && collapseDepth > 1) {
// force: collect prologue/epilogue for the first collapseDepth nested
// loops and sink them into the innermost loop body at that depth
llvm::SmallVector<Fortran::lower::pft::Evaluation *> prologue, epilogue;
Fortran::lower::pft::Evaluation *parent = &getEval();
Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr;
for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) {
epilogue.clear();
auto &kids = parent->getNestedEvaluations();
// Collect all non-loop statements before the next inner loop as
// prologue, then mark remaining siblings as epilogue and descend into
// the inner loop.
Fortran::lower::pft::Evaluation *childLoop = nullptr;
for (auto it = kids.begin(); it != kids.end(); ++it) {
if (it->getIf<Fortran::parser::DoConstruct>()) {
childLoop = &*it;
for (auto it2 = std::next(it); it2 != kids.end(); ++it2)
epilogue.push_back(&*it2);
break;
}
prologue.push_back(&*it);
}
// Semantics guarantees collapseDepth does not exceed nest depth
// so childLoop must be found here.
assert(childLoop && "Expected inner DoConstruct for collapse");
parent = childLoop;
innermostLoopEval = childLoop;
}

// Track sunk evaluations (avoid double-lowering)
llvm::SmallPtrSet<const Fortran::lower::pft::Evaluation *, 16> sunk;
for (auto *e : prologue)
sunk.insert(e);
for (auto *e : epilogue)
sunk.insert(e);

auto sink =
[&](llvm::SmallVector<Fortran::lower::pft::Evaluation *> &lst) {
for (auto *e : lst)
genFIR(*e);
};

sink(prologue);

// Lower innermost loop body, skipping sunk
for (Fortran::lower::pft::Evaluation &e :
innermostLoopEval->getNestedEvaluations())
if (!sunk.contains(&e))
genFIR(e);

sink(epilogue);
} else {
// Normal lowering
for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
genFIR(e);
}
localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);

Expand Down
62 changes: 40 additions & 22 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2144,11 +2144,25 @@ static void processDoLoopBounds(
locs.push_back(converter.genLocation(
Fortran::parser::FindSourceLocation(outerDoConstruct)));
} else {
auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
assert(doCons && "expect do construct");
loopControl = &*doCons->GetLoopControl();
// Safely locate the next inner DoConstruct within this eval.
const Fortran::parser::DoConstruct *innerDo = nullptr;
if (crtEval && crtEval->hasNestedEvaluations()) {
for (Fortran::lower::pft::Evaluation &child :
crtEval->getNestedEvaluations()) {
if (auto *stmt = child.getIf<Fortran::parser::DoConstruct>()) {
innerDo = stmt;
// Prepare to descend for the next iteration
crtEval = &child;
break;
}
}
}
if (!innerDo)
break; // No deeper loop; stop collecting collapsed bounds.

loopControl = &*innerDo->GetLoopControl();
locs.push_back(converter.genLocation(
Fortran::parser::FindSourceLocation(*doCons)));
Fortran::parser::FindSourceLocation(*innerDo)));
}

const Fortran::parser::LoopControl::Bounds *bounds =
Expand All @@ -2172,8 +2186,7 @@ static void processDoLoopBounds(

inclusiveBounds.push_back(true);

if (i < loopsToProcess - 1)
crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
// crtEval already updated when descending; no blind increment here.
}
}
}
Expand Down Expand Up @@ -2406,10 +2419,6 @@ static mlir::acc::LoopOp createLoopOp(
std::get_if<Fortran::parser::AccClause::Collapse>(
&clause.u)) {
const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
const auto &force = std::get<bool>(arg.t);
if (force)
TODO(clauseLocation, "OpenACC collapse force modifier");

const auto &intExpr =
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
const auto *expr = Fortran::semantics::GetExpr(intExpr);
Expand Down Expand Up @@ -4860,25 +4869,34 @@ void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,

/// Return how many nested loops are associated with the construct through its
/// `collapse` and `tile` clauses: the larger of the collapse size and the
/// number of tile expressions. Each count defaults to 1 when its clause is
/// absent. If several tile clauses appear, the last one determines the tile
/// count.
uint64_t Fortran::lower::getLoopCountForCollapseAndTile(
    const Fortran::parser::AccClauseList &clauseList) {
  // The collapse size comes from the shared helper so that the clause is
  // interpreted in exactly one place.
  const uint64_t collapseLoopCount = getCollapseSizeAndForce(clauseList).first;
  uint64_t tileLoopCount = 1;
  for (const Fortran::parser::AccClause &clause : clauseList.v) {
    if (const auto *tileClause =
            std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
      const parser::AccTileExprList &tileExprList = tileClause->v;
      // One tile expression per loop level being tiled.
      tileLoopCount = tileExprList.v.size();
    }
  }
  return tileLoopCount > collapseLoopCount ? tileLoopCount : collapseLoopCount;
}

/// Extract the size and the `force` modifier of the first `collapse` clause
/// found in \p clauseList.
/// \return {size, force}, or {1, false} when no collapse clause is present.
std::pair<uint64_t, bool> Fortran::lower::getCollapseSizeAndForce(
    const Fortran::parser::AccClauseList &clauseList) {
  uint64_t size = 1;
  bool force = false;
  for (const Fortran::parser::AccClause &clause : clauseList.v) {
    if (const auto *collapseClause =
            std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
      const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
      // The argument tuple carries the force flag and the constant size.
      force = std::get<bool>(arg.t);
      const auto &collapseValue =
          std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
      size = *Fortran::semantics::GetIntValue(collapseValue);
      // Only the first collapse clause is considered.
      break;
    }
  }
  return {size, force};
}

/// Create an ACC loop operation for a DO construct when inside ACC compute
Expand Down
41 changes: 41 additions & 0 deletions flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s

! Verify collapse(force:2) sinks prologue (between loops) and epilogue (after inner loop)
! into the acc.loop region body.

! Test input: a two-level loop nest that is not perfectly nested, placed
! under a combined "parallel loop" construct with collapse(force:2).
! The statements around the inner loop use distinct real constants
! (4.2, 2.0, 7.3) so the FileCheck patterns in this file can tell the
! prologue, inner-loop body and epilogue apart in the lowered output.
subroutine collapse_force_sink(n, m)
integer, intent(in) :: n, m
real, dimension(n,m) :: a
real, dimension(n) :: bb, cc
integer :: i, j

!$acc parallel loop collapse(force:2)
do i = 1, n
! Statement between the two loop headers.
bb(i) = 4.2 ! prologue (between loops)
do j = 1, m
a(i,j) = a(i,j) + 2.0
end do
! Statement after the inner loop but still inside the outer loop.
cc(i) = 7.3 ! epilogue (after inner loop)
end do
!$acc end parallel loop
end subroutine

! CHECK: func.func @_QPcollapse_force_sink(
! CHECK: acc.parallel
! Ensure outer acc.loop is combined(parallel)
! CHECK: acc.loop combined(parallel)
! Prologue: constant 4.2 and an assign before inner loop
! CHECK: arith.constant 4.200000e+00
! CHECK: hlfir.assign
! Inner loop and its body include 2.0 add and an assign
! CHECK: acc.loop
! CHECK: arith.constant 2.000000e+00
! CHECK: arith.addf
! CHECK: hlfir.assign
! Epilogue: constant 7.3 and an assign after inner loop
! CHECK: arith.constant 7.300000e+00
! CHECK: hlfir.assign
! And the outer acc.loop has collapse = [2]
! CHECK: } attributes {collapse = [2]