-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[flang][openacc] Add support for force clause for loop collapse #162534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
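@llvm/pr-subscribers-flang-fir-hlfir Author: Susan Tan (ス-ザン タン) (SusanTan) Changes: Currently the force clause collapse (force:num_level) is NYI. Added support to sink any prologue and epilogue code to the innermost level as specified. Full diff: https://github.com/llvm/llvm-project/pull/162534.diff 3 Files Affected: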
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 780d56f085f69..3d331cdad3d43 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3190,15 +3190,39 @@ class FirConverter : public Fortran::lower::AbstractConverter {
std::get_if<Fortran::parser::OpenACCCombinedConstruct>(&acc.u);
Fortran::lower::pft::Evaluation *curEval = &getEval();
+ // Determine collapse depth/force and loopCount
+ bool collapseForce = false;
+ uint64_t collapseDepth = 1;
+ uint64_t loopCount = 1;
+ auto parseCollapse = [&](const Fortran::parser::AccClauseList &cl)
+ -> std::pair<bool, uint64_t> {
+ bool force = false;
+ uint64_t depth = 1;
+ for (const Fortran::parser::AccClause &clause : cl.v) {
+ if (const auto *collapseClause =
+ std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
+ const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
+ force = std::get<bool>(arg.t);
+ const auto &intExpr =
+ std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
+ if (const auto *expr = Fortran::semantics::GetExpr(intExpr)) {
+ if (auto v = Fortran::evaluate::ToInt64(*expr))
+ depth = *v;
+ }
+ break;
+ }
+ }
+ return {force, depth};
+ };
if (accLoop || accCombined) {
- uint64_t loopCount;
if (accLoop) {
const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
+ std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
} else if (accCombined) {
const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir =
std::get<Fortran::parser::AccBeginCombinedDirective>(
@@ -3206,6 +3230,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
+ std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
}
if (curEval->lowerAsStructured()) {
@@ -3215,8 +3240,58 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
}
- for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
- genFIR(e);
+ const bool isStructured = curEval && curEval->lowerAsStructured();
+ if (isStructured && collapseForce && collapseDepth > 1) {
+ // force: collect prologue/epilogue for the first collapseDepth nested loops
+ // and sink them into the innermost loop body at that depth
+ llvm::SmallVector<Fortran::lower::pft::Evaluation *> prologue, epilogue;
+ Fortran::lower::pft::Evaluation *parent = &getEval();
+ Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr;
+ for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) {
+ epilogue.clear();
+ auto &kids = parent->getNestedEvaluations();
+ // Collect all non-loop statements before the next inner loop as prologue,
+ // then mark remaining siblings as epilogue and descend into the inner loop.
+ Fortran::lower::pft::Evaluation *childLoop = nullptr;
+ for (auto it = kids.begin(); it != kids.end(); ++it) {
+ if (it->getIf<Fortran::parser::DoConstruct>()) {
+ childLoop = &*it;
+ for (auto it2 = std::next(it); it2 != kids.end(); ++it2)
+ epilogue.push_back(&*it2);
+ break;
+ }
+ prologue.push_back(&*it);
+ }
+ // Semantics guarantees collapseDepth does not exceed nest depth
+ // so childLoop must be found here.
+ assert(childLoop && "Expected inner DoConstruct for collapse");
+ parent = childLoop;
+ innermostLoopEval = childLoop;
+ }
+
+ // Track sunk evaluations (avoid double-lowering)
+ llvm::SmallPtrSet<const Fortran::lower::pft::Evaluation *, 16> sunk;
+ for (auto *e : prologue) sunk.insert(e);
+ for (auto *e : epilogue) sunk.insert(e);
+
+ auto sink =
+ [&](llvm::SmallVector<Fortran::lower::pft::Evaluation *> &lst) {
+ for (auto *e : lst)
+ genFIR(*e);
+ };
+
+ sink(prologue);
+
+ // Lower innermost loop body, skipping sunk
+ for (Fortran::lower::pft::Evaluation &e : innermostLoopEval->getNestedEvaluations())
+ if (!sunk.contains(&e)) genFIR(e);
+
+ sink(epilogue);
+ } else {
+ // Normal lowering
+ for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
+ genFIR(e);
+ }
localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 4a9e49435a907..e24e784895fe8 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2144,8 +2144,23 @@ static void processDoLoopBounds(
locs.push_back(converter.genLocation(
Fortran::parser::FindSourceLocation(outerDoConstruct)));
} else {
- auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
- assert(doCons && "expect do construct");
+ // Safely locate the next inner DoConstruct within this eval.
+ const Fortran::parser::DoConstruct *doCons = nullptr;
+ if (crtEval && crtEval->hasNestedEvaluations()) {
+ for (Fortran::lower::pft::Evaluation &child :
+ crtEval->getNestedEvaluations()) {
+ if (auto *cand = child.getIf<Fortran::parser::DoConstruct>()) {
+ doCons = cand;
+ // Prepare to descend for the next iteration
+ crtEval = &child;
+ break;
+ }
+ }
+ }
+ if (!doCons) {
+ // No deeper loop; stop collecting collapsed bounds.
+ break;
+ }
loopControl = &*doCons->GetLoopControl();
locs.push_back(converter.genLocation(
Fortran::parser::FindSourceLocation(*doCons)));
@@ -2172,8 +2187,7 @@ static void processDoLoopBounds(
inclusiveBounds.push_back(true);
- if (i < loopsToProcess - 1)
- crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
+ // crtEval already updated when descending; no blind increment here.
}
}
}
@@ -2406,10 +2420,6 @@ static mlir::acc::LoopOp createLoopOp(
std::get_if<Fortran::parser::AccClause::Collapse>(
&clause.u)) {
const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
- const auto &force = std::get<bool>(arg.t);
- if (force)
- TODO(clauseLocation, "OpenACC collapse force modifier");
-
const auto &intExpr =
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
const auto *expr = Fortran::semantics::GetExpr(intExpr);
diff --git a/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
new file mode 100644
index 0000000000000..ca932c1b159ba
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
@@ -0,0 +1,41 @@
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+! Verify collapse(force:2) sinks prologue (between loops) and epilogue (after inner loop)
+! into the acc.loop region body.
+
+subroutine collapse_force_sink(n, m)
+ integer, intent(in) :: n, m
+ real, dimension(n,m) :: a
+ real, dimension(n) :: bb, cc
+ integer :: i, j
+
+ !$acc parallel loop collapse(force:2)
+ do i = 1, n
+ bb(i) = 4.2 ! prologue (between loops)
+ do j = 1, m
+ a(i,j) = a(i,j) + 2.0
+ end do
+ cc(i) = 7.3 ! epilogue (after inner loop)
+ end do
+ !$acc end parallel loop
+end subroutine
+
+! CHECK: func.func @_QPcollapse_force_sink(
+! CHECK: acc.parallel
+! Ensure outer acc.loop is combined(parallel)
+! CHECK: acc.loop combined(parallel)
+! Prologue: constant 4.2 and an assign before inner loop
+! CHECK: arith.constant 4.200000e+00
+! CHECK: hlfir.assign
+! Inner loop and its body include 2.0 add and an assign
+! CHECK: acc.loop
+! CHECK: arith.constant 2.000000e+00
+! CHECK: arith.addf
+! CHECK: hlfir.assign
+! Epilogue: constant 7.3 and an assign after inner loop
+! CHECK: arith.constant 7.300000e+00
+! CHECK: hlfir.assign
+! And the outer acc.loop has collapse = [2]
+! CHECK: } attributes {collapse = [2]
+
+
|
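@llvm/pr-subscribers-openacc Author: Susan Tan (ス-ザン タン) (SusanTan) Changes: Currently the force clause collapse (force:num_level) is NYI. Added support to sink any prologue and epilogue code to the innermost level as specified. Full diff: https://github.com/llvm/llvm-project/pull/162534.diff 3 Files Affected: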
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 780d56f085f69..3d331cdad3d43 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3190,15 +3190,39 @@ class FirConverter : public Fortran::lower::AbstractConverter {
std::get_if<Fortran::parser::OpenACCCombinedConstruct>(&acc.u);
Fortran::lower::pft::Evaluation *curEval = &getEval();
+ // Determine collapse depth/force and loopCount
+ bool collapseForce = false;
+ uint64_t collapseDepth = 1;
+ uint64_t loopCount = 1;
+ auto parseCollapse = [&](const Fortran::parser::AccClauseList &cl)
+ -> std::pair<bool, uint64_t> {
+ bool force = false;
+ uint64_t depth = 1;
+ for (const Fortran::parser::AccClause &clause : cl.v) {
+ if (const auto *collapseClause =
+ std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
+ const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
+ force = std::get<bool>(arg.t);
+ const auto &intExpr =
+ std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
+ if (const auto *expr = Fortran::semantics::GetExpr(intExpr)) {
+ if (auto v = Fortran::evaluate::ToInt64(*expr))
+ depth = *v;
+ }
+ break;
+ }
+ }
+ return {force, depth};
+ };
if (accLoop || accCombined) {
- uint64_t loopCount;
if (accLoop) {
const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
+ std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
} else if (accCombined) {
const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir =
std::get<Fortran::parser::AccBeginCombinedDirective>(
@@ -3206,6 +3230,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
+ std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
}
if (curEval->lowerAsStructured()) {
@@ -3215,8 +3240,58 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
}
- for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
- genFIR(e);
+ const bool isStructured = curEval && curEval->lowerAsStructured();
+ if (isStructured && collapseForce && collapseDepth > 1) {
+ // force: collect prologue/epilogue for the first collapseDepth nested loops
+ // and sink them into the innermost loop body at that depth
+ llvm::SmallVector<Fortran::lower::pft::Evaluation *> prologue, epilogue;
+ Fortran::lower::pft::Evaluation *parent = &getEval();
+ Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr;
+ for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) {
+ epilogue.clear();
+ auto &kids = parent->getNestedEvaluations();
+ // Collect all non-loop statements before the next inner loop as prologue,
+ // then mark remaining siblings as epilogue and descend into the inner loop.
+ Fortran::lower::pft::Evaluation *childLoop = nullptr;
+ for (auto it = kids.begin(); it != kids.end(); ++it) {
+ if (it->getIf<Fortran::parser::DoConstruct>()) {
+ childLoop = &*it;
+ for (auto it2 = std::next(it); it2 != kids.end(); ++it2)
+ epilogue.push_back(&*it2);
+ break;
+ }
+ prologue.push_back(&*it);
+ }
+ // Semantics guarantees collapseDepth does not exceed nest depth
+ // so childLoop must be found here.
+ assert(childLoop && "Expected inner DoConstruct for collapse");
+ parent = childLoop;
+ innermostLoopEval = childLoop;
+ }
+
+ // Track sunk evaluations (avoid double-lowering)
+ llvm::SmallPtrSet<const Fortran::lower::pft::Evaluation *, 16> sunk;
+ for (auto *e : prologue) sunk.insert(e);
+ for (auto *e : epilogue) sunk.insert(e);
+
+ auto sink =
+ [&](llvm::SmallVector<Fortran::lower::pft::Evaluation *> &lst) {
+ for (auto *e : lst)
+ genFIR(*e);
+ };
+
+ sink(prologue);
+
+ // Lower innermost loop body, skipping sunk
+ for (Fortran::lower::pft::Evaluation &e : innermostLoopEval->getNestedEvaluations())
+ if (!sunk.contains(&e)) genFIR(e);
+
+ sink(epilogue);
+ } else {
+ // Normal lowering
+ for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
+ genFIR(e);
+ }
localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 4a9e49435a907..e24e784895fe8 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2144,8 +2144,23 @@ static void processDoLoopBounds(
locs.push_back(converter.genLocation(
Fortran::parser::FindSourceLocation(outerDoConstruct)));
} else {
- auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
- assert(doCons && "expect do construct");
+ // Safely locate the next inner DoConstruct within this eval.
+ const Fortran::parser::DoConstruct *doCons = nullptr;
+ if (crtEval && crtEval->hasNestedEvaluations()) {
+ for (Fortran::lower::pft::Evaluation &child :
+ crtEval->getNestedEvaluations()) {
+ if (auto *cand = child.getIf<Fortran::parser::DoConstruct>()) {
+ doCons = cand;
+ // Prepare to descend for the next iteration
+ crtEval = &child;
+ break;
+ }
+ }
+ }
+ if (!doCons) {
+ // No deeper loop; stop collecting collapsed bounds.
+ break;
+ }
loopControl = &*doCons->GetLoopControl();
locs.push_back(converter.genLocation(
Fortran::parser::FindSourceLocation(*doCons)));
@@ -2172,8 +2187,7 @@ static void processDoLoopBounds(
inclusiveBounds.push_back(true);
- if (i < loopsToProcess - 1)
- crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
+ // crtEval already updated when descending; no blind increment here.
}
}
}
@@ -2406,10 +2420,6 @@ static mlir::acc::LoopOp createLoopOp(
std::get_if<Fortran::parser::AccClause::Collapse>(
&clause.u)) {
const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
- const auto &force = std::get<bool>(arg.t);
- if (force)
- TODO(clauseLocation, "OpenACC collapse force modifier");
-
const auto &intExpr =
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
const auto *expr = Fortran::semantics::GetExpr(intExpr);
diff --git a/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
new file mode 100644
index 0000000000000..ca932c1b159ba
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
@@ -0,0 +1,41 @@
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+! Verify collapse(force:2) sinks prologue (between loops) and epilogue (after inner loop)
+! into the acc.loop region body.
+
+subroutine collapse_force_sink(n, m)
+ integer, intent(in) :: n, m
+ real, dimension(n,m) :: a
+ real, dimension(n) :: bb, cc
+ integer :: i, j
+
+ !$acc parallel loop collapse(force:2)
+ do i = 1, n
+ bb(i) = 4.2 ! prologue (between loops)
+ do j = 1, m
+ a(i,j) = a(i,j) + 2.0
+ end do
+ cc(i) = 7.3 ! epilogue (after inner loop)
+ end do
+ !$acc end parallel loop
+end subroutine
+
+! CHECK: func.func @_QPcollapse_force_sink(
+! CHECK: acc.parallel
+! Ensure outer acc.loop is combined(parallel)
+! CHECK: acc.loop combined(parallel)
+! Prologue: constant 4.2 and an assign before inner loop
+! CHECK: arith.constant 4.200000e+00
+! CHECK: hlfir.assign
+! Inner loop and its body include 2.0 add and an assign
+! CHECK: acc.loop
+! CHECK: arith.constant 2.000000e+00
+! CHECK: arith.addf
+! CHECK: hlfir.assign
+! Epilogue: constant 7.3 and an assign after inner loop
+! CHECK: arith.constant 7.300000e+00
+! CHECK: hlfir.assign
+! And the outer acc.loop has collapse = [2]
+! CHECK: } attributes {collapse = [2]
+
+
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! nice work!
✅ With the latest revision this PR passed the C/C++ code formatter. |
flang/lib/Lower/Bridge.cpp
Outdated
force = std::get<bool>(arg.t); | ||
const auto &intExpr = | ||
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t); | ||
if (const auto *expr = Fortran::semantics::GetExpr(intExpr)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having collapse clause parsing in a single place would be ideal. We currently have getLoopCountForCollapseAndTile
in OpenACC.cpp. Is there any chance of making something similar - like getLoopCountForCollapse - and then using it both here and in getLoopCountForCollapseAndTile?
Nit: can you change your title prefix to [flang][openacc] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
Currently the force clause
collapse (force:num_level)
is NYI (not yet implemented). This change adds support for sinking any prologue and epilogue code into the innermost loop level, as specified.