-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[flang][OpenMP] Store Block in OpenMPLoopConstruct, add access functions #168078
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Instead of storing a variant with specific types, store parser::Block as the body. Add two access functions to make the traversal of the nest simpler. This will allow storing loop-nest sequences in the future.
|
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-semantics Author: Krzysztof Parzyszek (kparzysz) ChangesInstead of storing a variant with specific types, store parser::Block as the body. Add two access functions to make the traversal of the nest simpler. This will allow storing loop-nest sequences in the future. Patch is 48.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168078.diff 19 Files Affected:
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index b1765f927d6c9..60d2ad0b764b9 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -5345,12 +5345,10 @@ struct OmpEndLoopDirective : public OmpEndDirective {
};
// OpenMP directives enclosing do loop
-using NestedConstruct =
- std::variant<DoConstruct, common::Indirection<OpenMPLoopConstruct>>;
struct OpenMPLoopConstruct {
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
- : t({std::move(a), std::nullopt, std::nullopt}) {}
+ : t({std::move(a), Block{}, std::nullopt}) {}
const OmpBeginLoopDirective &BeginDir() const {
return std::get<OmpBeginLoopDirective>(t);
@@ -5358,8 +5356,10 @@ struct OpenMPLoopConstruct {
const std::optional<OmpEndLoopDirective> &EndDir() const {
return std::get<std::optional<OmpEndLoopDirective>>(t);
}
- std::tuple<OmpBeginLoopDirective, std::optional<NestedConstruct>,
- std::optional<OmpEndLoopDirective>>
+ const DoConstruct *GetNestedLoop() const;
+ const OpenMPLoopConstruct *GetNestedConstruct() const;
+
+ std::tuple<OmpBeginLoopDirective, Block, std::optional<OmpEndLoopDirective>>
t;
};
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index fe80c46c23d06..2c00131858bdc 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -3936,27 +3936,22 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
mlir::Location currentLocation = converter.genLocation(beginSpec.source);
- auto &optLoopCons =
- std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t);
- if (optLoopCons.has_value()) {
- if (auto *ompNestedLoopCons{
- std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
- &*optLoopCons)}) {
- llvm::omp::Directive nestedDirective =
- parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
- switch (nestedDirective) {
- case llvm::omp::Directive::OMPD_tile:
- // Skip OMPD_tile since the tile sizes will be retrieved when
- // generating the omp.loop_nest op.
- break;
- default: {
- unsigned version = semaCtx.langOptions().OpenMPVersion;
- TODO(currentLocation,
- "Applying a loop-associated on the loop generated by the " +
- llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
- " construct");
- }
- }
+ if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
+ loopConstruct.GetNestedConstruct()) {
+ llvm::omp::Directive nestedDirective =
+ parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
+ switch (nestedDirective) {
+ case llvm::omp::Directive::OMPD_tile:
+ // Skip OMPD_tile since the tile sizes will be retrieved when
+ // generating the omp.loop_nest op.
+ break;
+ default: {
+ unsigned version = semaCtx.langOptions().OpenMPVersion;
+ TODO(currentLocation,
+ "Applying a loop-associated on the loop generated by the " +
+ llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
+ " construct");
+ }
}
}
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 6487f599df72a..faec60d81ce84 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -631,17 +631,9 @@ static void processTileSizesFromOpenMPConstruct(
if (!ompCons)
return;
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
- const auto &nestedOptional =
- std::get<std::optional<parser::NestedConstruct>>(ompLoop->t);
- assert(nestedOptional.has_value() &&
- "Expected a DoConstruct or OpenMPLoopConstruct");
- const auto *innerConstruct =
- std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
- &(nestedOptional.value()));
- if (innerConstruct) {
- const auto &innerLoopDirective = innerConstruct->value();
+ if (auto *innerConstruct = ompLoop->GetNestedConstruct()) {
const parser::OmpDirectiveSpecification &innerBeginSpec =
- innerLoopDirective.BeginDir();
+ innerConstruct->BeginDir();
if (innerBeginSpec.DirId() == llvm::omp::Directive::OMPD_tile) {
// Get the size values from parse tree and convert to a vector.
for (const auto &clause : innerBeginSpec.Clauses().v) {
diff --git a/flang/lib/Parser/parse-tree.cpp b/flang/lib/Parser/parse-tree.cpp
index ad0016e1404f9..60e51895cdcea 100644
--- a/flang/lib/Parser/parse-tree.cpp
+++ b/flang/lib/Parser/parse-tree.cpp
@@ -7,8 +7,10 @@
//===----------------------------------------------------------------------===//
#include "flang/Parser/parse-tree.h"
+
#include "flang/Common/idioms.h"
#include "flang/Common/indirection.h"
+#include "flang/Parser/openmp-utils.h"
#include "flang/Parser/tools.h"
#include "flang/Parser/user-state.h"
#include "llvm/ADT/ArrayRef.h"
@@ -432,6 +434,20 @@ const OmpClauseList &OmpDirectiveSpecification::Clauses() const {
return empty;
}
+const DoConstruct *OpenMPLoopConstruct::GetNestedLoop() const {
+ if (auto &body{std::get<Block>(t)}; !body.empty()) {
+ return Unwrap<DoConstruct>(body.front());
+ }
+ return nullptr;
+}
+
+const OpenMPLoopConstruct *OpenMPLoopConstruct::GetNestedConstruct() const {
+ if (auto &body{std::get<Block>(t)}; !body.empty()) {
+ return Unwrap<OpenMPLoopConstruct>(body.front());
+ }
+ return nullptr;
+}
+
static bool InitCharBlocksFromStrings(llvm::MutableArrayRef<CharBlock> blocks,
llvm::ArrayRef<std::string> strings) {
for (auto [i, n] : llvm::enumerate(strings)) {
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index e3bc3cdc42ffb..f81200d092b11 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -2706,12 +2706,6 @@ class UnparseVisitor {
Put("\n");
EndOpenMP();
}
- void Unparse(const OpenMPLoopConstruct &x) {
- Walk(std::get<OmpBeginLoopDirective>(x.t));
- Walk(std::get<std::optional<std::variant<DoConstruct,
- common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t));
- Walk(std::get<std::optional<OmpEndLoopDirective>>(x.t));
- }
void Unparse(const BasedPointer &x) {
Put('('), Walk(std::get<0>(x.t)), Put(","), Walk(std::get<1>(x.t));
Walk("(", std::get<std::optional<ArraySpec>>(x.t), ")"), Put(')');
diff --git a/flang/lib/Semantics/canonicalize-omp.cpp b/flang/lib/Semantics/canonicalize-omp.cpp
index a11c5250b1ab4..0cec1969e0978 100644
--- a/flang/lib/Semantics/canonicalize-omp.cpp
+++ b/flang/lib/Semantics/canonicalize-omp.cpp
@@ -143,6 +143,8 @@ class CanonicalizationOfOmp {
parser::ToUpperCaseLetters(dirName.source.ToString()));
};
+ auto &body{std::get<parser::Block>(x.t)};
+
nextIt = it;
while (++nextIt != block.end()) {
// Ignore compiler directives.
@@ -152,9 +154,7 @@ class CanonicalizationOfOmp {
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
if (doCons->GetLoopControl()) {
// move DoConstruct
- std::get<std::optional<std::variant<parser::DoConstruct,
- common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
- std::move(*doCons);
+ body.push_back(std::move(*nextIt));
nextIt = block.erase(nextIt);
// try to match OmpEndLoopDirective
if (nextIt != block.end()) {
@@ -198,10 +198,7 @@ class CanonicalizationOfOmp {
++endIt;
}
RewriteOpenMPLoopConstruct(*ompLoopCons, block, nextIt);
- auto &ompLoop = std::get<std::optional<parser::NestedConstruct>>(x.t);
- ompLoop =
- std::optional<parser::NestedConstruct>{parser::NestedConstruct{
- common::Indirection{std::move(*ompLoopCons)}}};
+ body.push_back(std::move(*nextIt));
nextIt = block.erase(nextIt);
} else if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
beginName.v == llvm::omp::Directive::OMPD_tile) {
diff --git a/flang/lib/Semantics/check-omp-loop.cpp b/flang/lib/Semantics/check-omp-loop.cpp
index aaaa2d6e78280..3d3596b500880 100644
--- a/flang/lib/Semantics/check-omp-loop.cpp
+++ b/flang/lib/Semantics/check-omp-loop.cpp
@@ -285,13 +285,9 @@ void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
}
SetLoopInfo(x);
- auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
- if (optLoopCons.has_value()) {
- if (const auto &doConstruct{
- std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
- const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
- CheckNoBranching(doBlock, beginName.v, beginName.source);
- }
+ if (const auto *doConstruct{x.GetNestedLoop()}) {
+ const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
+ CheckNoBranching(doBlock, beginName.v, beginName.source);
}
CheckLoopItrVariableIsInt(x);
CheckAssociatedLoopConstraints(x);
@@ -314,46 +310,34 @@ const parser::Name OmpStructureChecker::GetLoopIndex(
}
void OmpStructureChecker::SetLoopInfo(const parser::OpenMPLoopConstruct &x) {
- auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
- if (optLoopCons.has_value()) {
- if (const auto &loopConstruct{
- std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
- const parser::DoConstruct *loop{&*loopConstruct};
- if (loop && loop->IsDoNormal()) {
- const parser::Name &itrVal{GetLoopIndex(loop)};
- SetLoopIv(itrVal.symbol);
- }
+ if (const auto *loop{x.GetNestedLoop()}) {
+ if (loop->IsDoNormal()) {
+ const parser::Name &itrVal{GetLoopIndex(loop)};
+ SetLoopIv(itrVal.symbol);
}
}
}
void OmpStructureChecker::CheckLoopItrVariableIsInt(
const parser::OpenMPLoopConstruct &x) {
- auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
- if (optLoopCons.has_value()) {
- if (const auto &loopConstruct{
- std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
-
- for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
- if (loop->IsDoNormal()) {
- const parser::Name &itrVal{GetLoopIndex(loop)};
- if (itrVal.symbol) {
- const auto *type{itrVal.symbol->GetType()};
- if (!type->IsNumeric(TypeCategory::Integer)) {
- context_.Say(itrVal.source,
- "The DO loop iteration"
- " variable must be of the type integer."_err_en_US,
- itrVal.ToString());
- }
- }
+ for (const parser::DoConstruct *loop{x.GetNestedLoop()}; loop;) {
+ if (loop->IsDoNormal()) {
+ const parser::Name &itrVal{GetLoopIndex(loop)};
+ if (itrVal.symbol) {
+ const auto *type{itrVal.symbol->GetType()};
+ if (!type->IsNumeric(TypeCategory::Integer)) {
+ context_.Say(itrVal.source,
+ "The DO loop iteration"
+ " variable must be of the type integer."_err_en_US,
+ itrVal.ToString());
}
- // Get the next DoConstruct if block is not empty.
- const auto &block{std::get<parser::Block>(loop->t)};
- const auto it{block.begin()};
- loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
- : nullptr;
}
}
+ // Get the next DoConstruct if block is not empty.
+ const auto &block{std::get<parser::Block>(loop->t)};
+ const auto it{block.begin()};
+ loop =
+ it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it) : nullptr;
}
}
@@ -417,29 +401,23 @@ void OmpStructureChecker::CheckDistLinear(
// Match the loop index variables with the collected symbols from linear
// clauses.
- auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
- if (optLoopCons.has_value()) {
- if (const auto &loopConstruct{
- std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
- for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
- if (loop->IsDoNormal()) {
- const parser::Name &itrVal{GetLoopIndex(loop)};
- if (itrVal.symbol) {
- // Remove the symbol from the collected set
- indexVars.erase(&itrVal.symbol->GetUltimate());
- }
- collapseVal--;
- if (collapseVal == 0) {
- break;
- }
- }
- // Get the next DoConstruct if block is not empty.
- const auto &block{std::get<parser::Block>(loop->t)};
- const auto it{block.begin()};
- loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
- : nullptr;
+ for (const parser::DoConstruct *loop{x.GetNestedLoop()}; loop;) {
+ if (loop->IsDoNormal()) {
+ const parser::Name &itrVal{GetLoopIndex(loop)};
+ if (itrVal.symbol) {
+ // Remove the symbol from the collected set
+ indexVars.erase(&itrVal.symbol->GetUltimate());
+ }
+ collapseVal--;
+ if (collapseVal == 0) {
+ break;
}
}
+ // Get the next DoConstruct if block is not empty.
+ const auto &block{std::get<parser::Block>(loop->t)};
+ const auto it{block.begin()};
+ loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
+ : nullptr;
}
// Show error for the remaining variables
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index f1658943ab2e1..cb075ff710e1e 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -2047,13 +2047,9 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPLoopConstruct &x) {
SetContextAssociatedLoopLevel(GetNumAffectedLoopsFromLoopConstruct(x));
if (beginName.v == llvm::omp::Directive::OMPD_do) {
- auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
- if (optLoopCons.has_value()) {
- if (const auto &doConstruct{
- std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
- if (doConstruct->IsDoWhile()) {
- return true;
- }
+ if (const parser::DoConstruct *doConstruct{x.GetNestedLoop()}) {
+ if (doConstruct->IsDoWhile()) {
+ return true;
}
}
}
@@ -2210,18 +2206,8 @@ void OmpAttributeVisitor::CollectNumAffectedLoopsFromInnerLoopContruct(
const parser::OpenMPLoopConstruct &x,
llvm::SmallVector<std::int64_t> &levels,
llvm::SmallVector<const parser::OmpClause *> &clauses) {
-
- const auto &nestedOptional =
- std::get<std::optional<parser::NestedConstruct>>(x.t);
- assert(nestedOptional.has_value() &&
- "Expected a DoConstruct or OpenMPLoopConstruct");
- const auto *innerConstruct =
- std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
- &(nestedOptional.value()));
-
- if (innerConstruct) {
- CollectNumAffectedLoopsFromLoopConstruct(
- innerConstruct->value(), levels, clauses);
+ if (auto *innerConstruct{x.GetNestedConstruct()}) {
+ CollectNumAffectedLoopsFromLoopConstruct(*innerConstruct, levels, clauses);
}
}
@@ -2286,24 +2272,12 @@ void OmpAttributeVisitor::CheckPerfectNestAndRectangularLoop(
// Find the associated region by skipping nested loop-associated constructs
// such as loop transformations
- const parser::NestedConstruct *innermostAssocRegion{nullptr};
const parser::OpenMPLoopConstruct *innermostConstruct{&x};
- while (const auto &innerAssocStmt{
- std::get<std::optional<parser::NestedConstruct>>(
- innermostConstruct->t)}) {
- innermostAssocRegion = &(innerAssocStmt.value());
- if (const auto *innerConstruct{
- std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
- innermostAssocRegion)}) {
- innermostConstruct = &innerConstruct->value();
- } else {
- break;
- }
+ while (auto *nested{innermostConstruct->GetNestedConstruct()}) {
+ innermostConstruct = nested;
}
- if (!innermostAssocRegion)
- return;
- const auto &outer{std::get_if<parser::DoConstruct>(innermostAssocRegion)};
+ const auto *outer{innermostConstruct->GetNestedLoop()};
if (!outer)
return;
@@ -2398,61 +2372,51 @@ void OmpAttributeVisitor::PrivatizeAssociatedLoopIndexAndCheckLoopLevel(
const parser::OmpClause *clause{GetAssociatedClause()};
bool hasCollapseClause{
clause ? (clause->Id() == llvm::omp::OMPC_collapse) : false};
- const parser::OpenMPLoopConstruct *innerMostLoop = &x;
- const parser::NestedConstruct *innerMostNest = nullptr;
- while (auto &optLoopCons{
- std::get<std::optional<parser::NestedConstruct>>(innerMostLoop->t)}) {
- innerMostNest = &(optLoopCons.value());
- if (const auto *innerLoop{
- std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
- innerMostNest)}) {
- innerMostLoop = &(innerLoop->value());
- } else
- break;
- }
- if (innerMostNest) {
- if (const auto &outer{std::get_if<parser::DoConstruct>(innerMostNest)}) {
- for (const parser::DoConstruct *loop{&*outer}; loop && level > 0;
- --level) {
- if (loop->IsDoConcurrent()) {
- // DO CONCURRENT is explicitly allowed for the LOOP construct so long
- // as there isn't a COLLAPSE clause
- if (isLoopConstruct) {
- if (hasCollapseClause) {
- // hasCollapseClause implies clause != nullptr
- context_.Say(clause->source,
- "DO CONCURRENT loops cannot be used with the COLLAPSE clause."_err_en_US);
- }
- } else {
- auto &stmt =
- std::get<parser::Statement<parser::NonLabelDoStmt>>(loop->t);
- context_.Say(stmt.source,
- "DO CONCURRENT loops cannot form part of a loop nest."_err_en_US);
+ const parser::OpenMPLoopConstruct *innerMostNest = &x;
+ while (auto *nested{innerMostNest->GetNestedConstruct()}) {
+ innerMostNest = nested;
+ }
+
+ if (const auto *outer{innerMostNest->GetNestedLoop()}) {
+ for (const parser::DoConstruct *loop{&*outer}; loop && level > 0; --level) {
+ if (loop->IsDoConcurrent()) {
+ // DO CONCURRENT is explicitly allowed for the LOOP construct so long
+ // as there isn't a COLLAPSE clause
+ if (isLoopConstruct) {
+ if (hasCollapseClause) {
+ // hasCollapseClause implies clause != nullptr
+ context_.Say(clause->source,
+ "DO CONCURRENT loops cannot be used with the COLLAPSE clause."_err_en_US);
}
+ } else {
+ auto &stmt =
+ std::get<parser::Statement<parser::NonLabelDoStmt>>(loop->t);
+ context_.Say(stmt.source,
+ "DO CONCURRENT loops cannot form part of a loop nest."_err_en_US);
}
- // go through all the nested do-loops and resolve index variables
- const parser::Name *iv{GetLoopIndex(*loop)};
- if (iv) {
- if (auto *symbol{ResolveOmp(*iv, ivDSA, currScope())}) {
- SetSymbolDSA(*symbol, {Symbol::Flag::OmpPreDetermined, ivDSA});
- iv->symbol = symbol; // adjust the symbol within region
- AddToContextObjectWithDSA(*symbol, ivDSA);
- }
-
- const auto &block{std::get<parser::Block>(loop->t)};
- const auto it{block.begin()};
- loop = it != block.end() ? GetDoConstructIf(*it) : nullptr;
+ }
+ // go through all the nested do-loops and resolve index variables
+ const parser::Name *iv{GetLoopIndex(*loop)};
+ if (iv) {
+ if (auto *symbol{ResolveOmp(*iv, ivDSA, currScope())}) {
+ SetSymbolDSA(*symbol, {Symbol::Flag::OmpPreDetermined, ivDSA});
+ iv->symbol = symbol; // adjust the symbol within region
+ AddToContextObjectWithDSA(*symbol, ivDSA);
}
+
+ const auto &block{std::get<parser::Block>(loop->t)};
+ const auto it{block.begin()};
+ loop = it != block.end() ? GetDoConstructIf(*it) : nullptr;
}
- CheckAssocLoopLevel(level, GetAssociatedClause());
-...
[truncated]
|
|
@llvm/pr-subscribers-flang-openmp Author: Krzysztof Parzyszek (kparzysz) ChangesInstead of storing a variant with specific types, store parser::Block as the body. Add two access functions to make the traversal of the nest simpler. This will allow storing loop-nest sequences in the future. Patch is 48.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168078.diff 19 Files Affected:
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index b1765f927d6c9..60d2ad0b764b9 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -5345,12 +5345,10 @@ struct OmpEndLoopDirective : public OmpEndDirective {
};
// OpenMP directives enclosing do loop
-using NestedConstruct =
- std::variant<DoConstruct, common::Indirection<OpenMPLoopConstruct>>;
struct OpenMPLoopConstruct {
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
- : t({std::move(a), std::nullopt, std::nullopt}) {}
+ : t({std::move(a), Block{}, std::nullopt}) {}
const OmpBeginLoopDirective &BeginDir() const {
return std::get<OmpBeginLoopDirective>(t);
@@ -5358,8 +5356,10 @@ struct OpenMPLoopConstruct {
const std::optional<OmpEndLoopDirective> &EndDir() const {
return std::get<std::optional<OmpEndLoopDirective>>(t);
}
- std::tuple<OmpBeginLoopDirective, std::optional<NestedConstruct>,
- std::optional<OmpEndLoopDirective>>
+ const DoConstruct *GetNestedLoop() const;
+ const OpenMPLoopConstruct *GetNestedConstruct() const;
+
+ std::tuple<OmpBeginLoopDirective, Block, std::optional<OmpEndLoopDirective>>
t;
};
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index fe80c46c23d06..2c00131858bdc 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -3936,27 +3936,22 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
mlir::Location currentLocation = converter.genLocation(beginSpec.source);
- auto &optLoopCons =
- std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t);
- if (optLoopCons.has_value()) {
- if (auto *ompNestedLoopCons{
- std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
- &*optLoopCons)}) {
- llvm::omp::Directive nestedDirective =
- parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
- switch (nestedDirective) {
- case llvm::omp::Directive::OMPD_tile:
- // Skip OMPD_tile since the tile sizes will be retrieved when
- // generating the omp.loop_nest op.
- break;
- default: {
- unsigned version = semaCtx.langOptions().OpenMPVersion;
- TODO(currentLocation,
- "Applying a loop-associated on the loop generated by the " +
- llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
- " construct");
- }
- }
+ if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
+ loopConstruct.GetNestedConstruct()) {
+ llvm::omp::Directive nestedDirective =
+ parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
+ switch (nestedDirective) {
+ case llvm::omp::Directive::OMPD_tile:
+ // Skip OMPD_tile since the tile sizes will be retrieved when
+ // generating the omp.loop_nest op.
+ break;
+ default: {
+ unsigned version = semaCtx.langOptions().OpenMPVersion;
+ TODO(currentLocation,
+ "Applying a loop-associated on the loop generated by the " +
+ llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
+ " construct");
+ }
}
}
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 6487f599df72a..faec60d81ce84 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -631,17 +631,9 @@ static void processTileSizesFromOpenMPConstruct(
if (!ompCons)
return;
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
- const auto &nestedOptional =
- std::get<std::optional<parser::NestedConstruct>>(ompLoop->t);
- assert(nestedOptional.has_value() &&
- "Expected a DoConstruct or OpenMPLoopConstruct");
- const auto *innerConstruct =
- std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
- &(nestedOptional.value()));
- if (innerConstruct) {
- const auto &innerLoopDirective = innerConstruct->value();
+ if (auto *innerConstruct = ompLoop->GetNestedConstruct()) {
const parser::OmpDirectiveSpecification &innerBeginSpec =
- innerLoopDirective.BeginDir();
+ innerConstruct->BeginDir();
if (innerBeginSpec.DirId() == llvm::omp::Directive::OMPD_tile) {
// Get the size values from parse tree and convert to a vector.
for (const auto &clause : innerBeginSpec.Clauses().v) {
diff --git a/flang/lib/Parser/parse-tree.cpp b/flang/lib/Parser/parse-tree.cpp
index ad0016e1404f9..60e51895cdcea 100644
--- a/flang/lib/Parser/parse-tree.cpp
+++ b/flang/lib/Parser/parse-tree.cpp
@@ -7,8 +7,10 @@
//===----------------------------------------------------------------------===//
#include "flang/Parser/parse-tree.h"
+
#include "flang/Common/idioms.h"
#include "flang/Common/indirection.h"
+#include "flang/Parser/openmp-utils.h"
#include "flang/Parser/tools.h"
#include "flang/Parser/user-state.h"
#include "llvm/ADT/ArrayRef.h"
@@ -432,6 +434,20 @@ const OmpClauseList &OmpDirectiveSpecification::Clauses() const {
return empty;
}
+const DoConstruct *OpenMPLoopConstruct::GetNestedLoop() const {
+ if (auto &body{std::get<Block>(t)}; !body.empty()) {
+ return Unwrap<DoConstruct>(body.front());
+ }
+ return nullptr;
+}
+
+const OpenMPLoopConstruct *OpenMPLoopConstruct::GetNestedConstruct() const {
+ if (auto &body{std::get<Block>(t)}; !body.empty()) {
+ return Unwrap<OpenMPLoopConstruct>(body.front());
+ }
+ return nullptr;
+}
+
static bool InitCharBlocksFromStrings(llvm::MutableArrayRef<CharBlock> blocks,
llvm::ArrayRef<std::string> strings) {
for (auto [i, n] : llvm::enumerate(strings)) {
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index e3bc3cdc42ffb..f81200d092b11 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -2706,12 +2706,6 @@ class UnparseVisitor {
Put("\n");
EndOpenMP();
}
- void Unparse(const OpenMPLoopConstruct &x) {
- Walk(std::get<OmpBeginLoopDirective>(x.t));
- Walk(std::get<std::optional<std::variant<DoConstruct,
- common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t));
- Walk(std::get<std::optional<OmpEndLoopDirective>>(x.t));
- }
void Unparse(const BasedPointer &x) {
Put('('), Walk(std::get<0>(x.t)), Put(","), Walk(std::get<1>(x.t));
Walk("(", std::get<std::optional<ArraySpec>>(x.t), ")"), Put(')');
diff --git a/flang/lib/Semantics/canonicalize-omp.cpp b/flang/lib/Semantics/canonicalize-omp.cpp
index a11c5250b1ab4..0cec1969e0978 100644
--- a/flang/lib/Semantics/canonicalize-omp.cpp
+++ b/flang/lib/Semantics/canonicalize-omp.cpp
@@ -143,6 +143,8 @@ class CanonicalizationOfOmp {
parser::ToUpperCaseLetters(dirName.source.ToString()));
};
+ auto &body{std::get<parser::Block>(x.t)};
+
nextIt = it;
while (++nextIt != block.end()) {
// Ignore compiler directives.
@@ -152,9 +154,7 @@ class CanonicalizationOfOmp {
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
if (doCons->GetLoopControl()) {
// move DoConstruct
- std::get<std::optional<std::variant<parser::DoConstruct,
- common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
- std::move(*doCons);
+ body.push_back(std::move(*nextIt));
nextIt = block.erase(nextIt);
// try to match OmpEndLoopDirective
if (nextIt != block.end()) {
@@ -198,10 +198,7 @@ class CanonicalizationOfOmp {
++endIt;
}
RewriteOpenMPLoopConstruct(*ompLoopCons, block, nextIt);
- auto &ompLoop = std::get<std::optional<parser::NestedConstruct>>(x.t);
- ompLoop =
- std::optional<parser::NestedConstruct>{parser::NestedConstruct{
- common::Indirection{std::move(*ompLoopCons)}}};
+ body.push_back(std::move(*nextIt));
nextIt = block.erase(nextIt);
} else if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
beginName.v == llvm::omp::Directive::OMPD_tile) {
diff --git a/flang/lib/Semantics/check-omp-loop.cpp b/flang/lib/Semantics/check-omp-loop.cpp
index aaaa2d6e78280..3d3596b500880 100644
--- a/flang/lib/Semantics/check-omp-loop.cpp
+++ b/flang/lib/Semantics/check-omp-loop.cpp
@@ -285,13 +285,9 @@ void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
}
SetLoopInfo(x);
- auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
- if (optLoopCons.has_value()) {
- if (const auto &doConstruct{
- std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
- const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
- CheckNoBranching(doBlock, beginName.v, beginName.source);
- }
+ if (const auto *doConstruct{x.GetNestedLoop()}) {
+ const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
+ CheckNoBranching(doBlock, beginName.v, beginName.source);
}
CheckLoopItrVariableIsInt(x);
CheckAssociatedLoopConstraints(x);
@@ -314,46 +310,34 @@ const parser::Name OmpStructureChecker::GetLoopIndex(
}
void OmpStructureChecker::SetLoopInfo(const parser::OpenMPLoopConstruct &x) {
- auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
- if (optLoopCons.has_value()) {
- if (const auto &loopConstruct{
- std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
- const parser::DoConstruct *loop{&*loopConstruct};
- if (loop && loop->IsDoNormal()) {
- const parser::Name &itrVal{GetLoopIndex(loop)};
- SetLoopIv(itrVal.symbol);
- }
+ if (const auto *loop{x.GetNestedLoop()}) {
+ if (loop->IsDoNormal()) {
+ const parser::Name &itrVal{GetLoopIndex(loop)};
+ SetLoopIv(itrVal.symbol);
}
}
}
void OmpStructureChecker::CheckLoopItrVariableIsInt(
const parser::OpenMPLoopConstruct &x) {
- auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
- if (optLoopCons.has_value()) {
- if (const auto &loopConstruct{
- std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
-
- for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
- if (loop->IsDoNormal()) {
- const parser::Name &itrVal{GetLoopIndex(loop)};
- if (itrVal.symbol) {
- const auto *type{itrVal.symbol->GetType()};
- if (!type->IsNumeric(TypeCategory::Integer)) {
- context_.Say(itrVal.source,
- "The DO loop iteration"
- " variable must be of the type integer."_err_en_US,
- itrVal.ToString());
- }
- }
+ for (const parser::DoConstruct *loop{x.GetNestedLoop()}; loop;) {
+ if (loop->IsDoNormal()) {
+ const parser::Name &itrVal{GetLoopIndex(loop)};
+ if (itrVal.symbol) {
+ const auto *type{itrVal.symbol->GetType()};
+ if (!type->IsNumeric(TypeCategory::Integer)) {
+ context_.Say(itrVal.source,
+ "The DO loop iteration"
+ " variable must be of the type integer."_err_en_US,
+ itrVal.ToString());
}
- // Get the next DoConstruct if block is not empty.
- const auto &block{std::get<parser::Block>(loop->t)};
- const auto it{block.begin()};
- loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
- : nullptr;
}
}
+ // Get the next DoConstruct if block is not empty.
+ const auto &block{std::get<parser::Block>(loop->t)};
+ const auto it{block.begin()};
+ loop =
+ it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it) : nullptr;
}
}
@@ -417,29 +401,23 @@ void OmpStructureChecker::CheckDistLinear(
// Match the loop index variables with the collected symbols from linear
// clauses.
- auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
- if (optLoopCons.has_value()) {
- if (const auto &loopConstruct{
- std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
- for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
- if (loop->IsDoNormal()) {
- const parser::Name &itrVal{GetLoopIndex(loop)};
- if (itrVal.symbol) {
- // Remove the symbol from the collected set
- indexVars.erase(&itrVal.symbol->GetUltimate());
- }
- collapseVal--;
- if (collapseVal == 0) {
- break;
- }
- }
- // Get the next DoConstruct if block is not empty.
- const auto &block{std::get<parser::Block>(loop->t)};
- const auto it{block.begin()};
- loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
- : nullptr;
+ for (const parser::DoConstruct *loop{x.GetNestedLoop()}; loop;) {
+ if (loop->IsDoNormal()) {
+ const parser::Name &itrVal{GetLoopIndex(loop)};
+ if (itrVal.symbol) {
+ // Remove the symbol from the collected set
+ indexVars.erase(&itrVal.symbol->GetUltimate());
+ }
+ collapseVal--;
+ if (collapseVal == 0) {
+ break;
}
}
+ // Get the next DoConstruct if block is not empty.
+ const auto &block{std::get<parser::Block>(loop->t)};
+ const auto it{block.begin()};
+ loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
+ : nullptr;
}
// Show error for the remaining variables
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index f1658943ab2e1..cb075ff710e1e 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -2047,13 +2047,9 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPLoopConstruct &x) {
SetContextAssociatedLoopLevel(GetNumAffectedLoopsFromLoopConstruct(x));
if (beginName.v == llvm::omp::Directive::OMPD_do) {
- auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
- if (optLoopCons.has_value()) {
- if (const auto &doConstruct{
- std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
- if (doConstruct->IsDoWhile()) {
- return true;
- }
+ if (const parser::DoConstruct *doConstruct{x.GetNestedLoop()}) {
+ if (doConstruct->IsDoWhile()) {
+ return true;
}
}
}
@@ -2210,18 +2206,8 @@ void OmpAttributeVisitor::CollectNumAffectedLoopsFromInnerLoopContruct(
const parser::OpenMPLoopConstruct &x,
llvm::SmallVector<std::int64_t> &levels,
llvm::SmallVector<const parser::OmpClause *> &clauses) {
-
- const auto &nestedOptional =
- std::get<std::optional<parser::NestedConstruct>>(x.t);
- assert(nestedOptional.has_value() &&
- "Expected a DoConstruct or OpenMPLoopConstruct");
- const auto *innerConstruct =
- std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
- &(nestedOptional.value()));
-
- if (innerConstruct) {
- CollectNumAffectedLoopsFromLoopConstruct(
- innerConstruct->value(), levels, clauses);
+ if (auto *innerConstruct{x.GetNestedConstruct()}) {
+ CollectNumAffectedLoopsFromLoopConstruct(*innerConstruct, levels, clauses);
}
}
@@ -2286,24 +2272,12 @@ void OmpAttributeVisitor::CheckPerfectNestAndRectangularLoop(
// Find the associated region by skipping nested loop-associated constructs
// such as loop transformations
- const parser::NestedConstruct *innermostAssocRegion{nullptr};
const parser::OpenMPLoopConstruct *innermostConstruct{&x};
- while (const auto &innerAssocStmt{
- std::get<std::optional<parser::NestedConstruct>>(
- innermostConstruct->t)}) {
- innermostAssocRegion = &(innerAssocStmt.value());
- if (const auto *innerConstruct{
- std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
- innermostAssocRegion)}) {
- innermostConstruct = &innerConstruct->value();
- } else {
- break;
- }
+ while (auto *nested{innermostConstruct->GetNestedConstruct()}) {
+ innermostConstruct = nested;
}
- if (!innermostAssocRegion)
- return;
- const auto &outer{std::get_if<parser::DoConstruct>(innermostAssocRegion)};
+ const auto *outer{innermostConstruct->GetNestedLoop()};
if (!outer)
return;
@@ -2398,61 +2372,51 @@ void OmpAttributeVisitor::PrivatizeAssociatedLoopIndexAndCheckLoopLevel(
const parser::OmpClause *clause{GetAssociatedClause()};
bool hasCollapseClause{
clause ? (clause->Id() == llvm::omp::OMPC_collapse) : false};
- const parser::OpenMPLoopConstruct *innerMostLoop = &x;
- const parser::NestedConstruct *innerMostNest = nullptr;
- while (auto &optLoopCons{
- std::get<std::optional<parser::NestedConstruct>>(innerMostLoop->t)}) {
- innerMostNest = &(optLoopCons.value());
- if (const auto *innerLoop{
- std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
- innerMostNest)}) {
- innerMostLoop = &(innerLoop->value());
- } else
- break;
- }
- if (innerMostNest) {
- if (const auto &outer{std::get_if<parser::DoConstruct>(innerMostNest)}) {
- for (const parser::DoConstruct *loop{&*outer}; loop && level > 0;
- --level) {
- if (loop->IsDoConcurrent()) {
- // DO CONCURRENT is explicitly allowed for the LOOP construct so long
- // as there isn't a COLLAPSE clause
- if (isLoopConstruct) {
- if (hasCollapseClause) {
- // hasCollapseClause implies clause != nullptr
- context_.Say(clause->source,
- "DO CONCURRENT loops cannot be used with the COLLAPSE clause."_err_en_US);
- }
- } else {
- auto &stmt =
- std::get<parser::Statement<parser::NonLabelDoStmt>>(loop->t);
- context_.Say(stmt.source,
- "DO CONCURRENT loops cannot form part of a loop nest."_err_en_US);
+ const parser::OpenMPLoopConstruct *innerMostNest = &x;
+ while (auto *nested{innerMostNest->GetNestedConstruct()}) {
+ innerMostNest = nested;
+ }
+
+ if (const auto *outer{innerMostNest->GetNestedLoop()}) {
+ for (const parser::DoConstruct *loop{&*outer}; loop && level > 0; --level) {
+ if (loop->IsDoConcurrent()) {
+ // DO CONCURRENT is explicitly allowed for the LOOP construct so long
+ // as there isn't a COLLAPSE clause
+ if (isLoopConstruct) {
+ if (hasCollapseClause) {
+ // hasCollapseClause implies clause != nullptr
+ context_.Say(clause->source,
+ "DO CONCURRENT loops cannot be used with the COLLAPSE clause."_err_en_US);
}
+ } else {
+ auto &stmt =
+ std::get<parser::Statement<parser::NonLabelDoStmt>>(loop->t);
+ context_.Say(stmt.source,
+ "DO CONCURRENT loops cannot form part of a loop nest."_err_en_US);
}
- // go through all the nested do-loops and resolve index variables
- const parser::Name *iv{GetLoopIndex(*loop)};
- if (iv) {
- if (auto *symbol{ResolveOmp(*iv, ivDSA, currScope())}) {
- SetSymbolDSA(*symbol, {Symbol::Flag::OmpPreDetermined, ivDSA});
- iv->symbol = symbol; // adjust the symbol within region
- AddToContextObjectWithDSA(*symbol, ivDSA);
- }
-
- const auto &block{std::get<parser::Block>(loop->t)};
- const auto it{block.begin()};
- loop = it != block.end() ? GetDoConstructIf(*it) : nullptr;
+ }
+ // go through all the nested do-loops and resolve index variables
+ const parser::Name *iv{GetLoopIndex(*loop)};
+ if (iv) {
+ if (auto *symbol{ResolveOmp(*iv, ivDSA, currScope())}) {
+ SetSymbolDSA(*symbol, {Symbol::Flag::OmpPreDetermined, ivDSA});
+ iv->symbol = symbol; // adjust the symbol within region
+ AddToContextObjectWithDSA(*symbol, ivDSA);
}
+
+ const auto &block{std::get<parser::Block>(loop->t)};
+ const auto it{block.begin()};
+ loop = it != block.end() ? GetDoConstructIf(*it) : nullptr;
}
- CheckAssocLoopLevel(level, GetAssociatedClause());
-...
[truncated]
|
Stylie777
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Meinersbur
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "flang/Parser/parse-tree.h" | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] extra empty line
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is some difference in behavior between the clang-format used in the CI, and the one I'm using. Specifically, mine doesn't care whether the header has the same name as the .cpp file being formatted, while the CI one will put it at the front of the includes. An empty line will stop that include from being reordered, so I can c-f the entire file without worrying about the order of includes.
Instead of storing a variant with specific types, store parser::Block as the body. Add two access functions to make the traversal of the nest simpler.
This will allow storing loop-nest sequences in the future.