Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions flang/include/flang/Parser/parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -5345,21 +5345,21 @@ struct OmpEndLoopDirective : public OmpEndDirective {
};

// OpenMP directives enclosing do loop
using NestedConstruct =
std::variant<DoConstruct, common::Indirection<OpenMPLoopConstruct>>;
struct OpenMPLoopConstruct {
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
: t({std::move(a), std::nullopt, std::nullopt}) {}
: t({std::move(a), Block{}, std::nullopt}) {}

const OmpBeginLoopDirective &BeginDir() const {
return std::get<OmpBeginLoopDirective>(t);
}
const std::optional<OmpEndLoopDirective> &EndDir() const {
return std::get<std::optional<OmpEndLoopDirective>>(t);
}
std::tuple<OmpBeginLoopDirective, std::optional<NestedConstruct>,
std::optional<OmpEndLoopDirective>>
const DoConstruct *GetNestedLoop() const;
const OpenMPLoopConstruct *GetNestedConstruct() const;

std::tuple<OmpBeginLoopDirective, Block, std::optional<OmpEndLoopDirective>>
t;
};

Expand Down
37 changes: 16 additions & 21 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3936,27 +3936,22 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,

mlir::Location currentLocation = converter.genLocation(beginSpec.source);

auto &optLoopCons =
std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t);
if (optLoopCons.has_value()) {
if (auto *ompNestedLoopCons{
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&*optLoopCons)}) {
llvm::omp::Directive nestedDirective =
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
switch (nestedDirective) {
case llvm::omp::Directive::OMPD_tile:
// Skip OMPD_tile since the tile sizes will be retrieved when
// generating the omp.loop_nest op.
break;
default: {
unsigned version = semaCtx.langOptions().OpenMPVersion;
TODO(currentLocation,
"Applying a loop-associated on the loop generated by the " +
llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
" construct");
}
}
if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
loopConstruct.GetNestedConstruct()) {
llvm::omp::Directive nestedDirective =
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
switch (nestedDirective) {
case llvm::omp::Directive::OMPD_tile:
// Skip OMPD_tile since the tile sizes will be retrieved when
// generating the omp.loop_nest op.
break;
default: {
unsigned version = semaCtx.langOptions().OpenMPVersion;
TODO(currentLocation,
"Applying a loop-associated on the loop generated by the " +
llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
" construct");
}
}
}

Expand Down
12 changes: 2 additions & 10 deletions flang/lib/Lower/OpenMP/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,17 +631,9 @@ static void processTileSizesFromOpenMPConstruct(
if (!ompCons)
return;
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
const auto &nestedOptional =
std::get<std::optional<parser::NestedConstruct>>(ompLoop->t);
assert(nestedOptional.has_value() &&
"Expected a DoConstruct or OpenMPLoopConstruct");
const auto *innerConstruct =
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&(nestedOptional.value()));
if (innerConstruct) {
const auto &innerLoopDirective = innerConstruct->value();
if (auto *innerConstruct = ompLoop->GetNestedConstruct()) {
const parser::OmpDirectiveSpecification &innerBeginSpec =
innerLoopDirective.BeginDir();
innerConstruct->BeginDir();
if (innerBeginSpec.DirId() == llvm::omp::Directive::OMPD_tile) {
// Get the size values from parse tree and convert to a vector.
for (const auto &clause : innerBeginSpec.Clauses().v) {
Expand Down
16 changes: 16 additions & 0 deletions flang/lib/Parser/parse-tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
//===----------------------------------------------------------------------===//

#include "flang/Parser/parse-tree.h"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] extra empty line

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is some difference in behavior between the clang-format used in the CI, and the one I'm using. Specifically, mine doesn't care whether the header has the same name as the .cpp file being formatted, while the CI one will put it at the front of the includes. An empty line will stop that include from being reordered, so I can c-f the entire file without worrying about the order of includes.

#include "flang/Common/idioms.h"
#include "flang/Common/indirection.h"
#include "flang/Parser/openmp-utils.h"
#include "flang/Parser/tools.h"
#include "flang/Parser/user-state.h"
#include "llvm/ADT/ArrayRef.h"
Expand Down Expand Up @@ -432,6 +434,20 @@ const OmpClauseList &OmpDirectiveSpecification::Clauses() const {
return empty;
}

const DoConstruct *OpenMPLoopConstruct::GetNestedLoop() const {
if (auto &body{std::get<Block>(t)}; !body.empty()) {
return Unwrap<DoConstruct>(body.front());
}
return nullptr;
}

const OpenMPLoopConstruct *OpenMPLoopConstruct::GetNestedConstruct() const {
if (auto &body{std::get<Block>(t)}; !body.empty()) {
return Unwrap<OpenMPLoopConstruct>(body.front());
}
return nullptr;
}

static bool InitCharBlocksFromStrings(llvm::MutableArrayRef<CharBlock> blocks,
llvm::ArrayRef<std::string> strings) {
for (auto [i, n] : llvm::enumerate(strings)) {
Expand Down
6 changes: 0 additions & 6 deletions flang/lib/Parser/unparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2706,12 +2706,6 @@ class UnparseVisitor {
Put("\n");
EndOpenMP();
}
void Unparse(const OpenMPLoopConstruct &x) {
Walk(std::get<OmpBeginLoopDirective>(x.t));
Walk(std::get<std::optional<std::variant<DoConstruct,
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t));
Walk(std::get<std::optional<OmpEndLoopDirective>>(x.t));
}
void Unparse(const BasedPointer &x) {
Put('('), Walk(std::get<0>(x.t)), Put(","), Walk(std::get<1>(x.t));
Walk("(", std::get<std::optional<ArraySpec>>(x.t), ")"), Put(')');
Expand Down
11 changes: 4 additions & 7 deletions flang/lib/Semantics/canonicalize-omp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ class CanonicalizationOfOmp {
parser::ToUpperCaseLetters(dirName.source.ToString()));
};

auto &body{std::get<parser::Block>(x.t)};

nextIt = it;
while (++nextIt != block.end()) {
// Ignore compiler directives.
Expand All @@ -152,9 +154,7 @@ class CanonicalizationOfOmp {
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
if (doCons->GetLoopControl()) {
// move DoConstruct
std::get<std::optional<std::variant<parser::DoConstruct,
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
std::move(*doCons);
body.push_back(std::move(*nextIt));
nextIt = block.erase(nextIt);
// try to match OmpEndLoopDirective
if (nextIt != block.end()) {
Expand Down Expand Up @@ -198,10 +198,7 @@ class CanonicalizationOfOmp {
++endIt;
}
RewriteOpenMPLoopConstruct(*ompLoopCons, block, nextIt);
auto &ompLoop = std::get<std::optional<parser::NestedConstruct>>(x.t);
ompLoop =
std::optional<parser::NestedConstruct>{parser::NestedConstruct{
common::Indirection{std::move(*ompLoopCons)}}};
body.push_back(std::move(*nextIt));
nextIt = block.erase(nextIt);
} else if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
beginName.v == llvm::omp::Directive::OMPD_tile) {
Expand Down
96 changes: 37 additions & 59 deletions flang/lib/Semantics/check-omp-loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,13 +285,9 @@ void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
}
SetLoopInfo(x);

auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
if (optLoopCons.has_value()) {
if (const auto &doConstruct{
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
CheckNoBranching(doBlock, beginName.v, beginName.source);
}
if (const auto *doConstruct{x.GetNestedLoop()}) {
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
CheckNoBranching(doBlock, beginName.v, beginName.source);
}
CheckLoopItrVariableIsInt(x);
CheckAssociatedLoopConstraints(x);
Expand All @@ -314,46 +310,34 @@ const parser::Name OmpStructureChecker::GetLoopIndex(
}

void OmpStructureChecker::SetLoopInfo(const parser::OpenMPLoopConstruct &x) {
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
if (optLoopCons.has_value()) {
if (const auto &loopConstruct{
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
const parser::DoConstruct *loop{&*loopConstruct};
if (loop && loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
SetLoopIv(itrVal.symbol);
}
if (const auto *loop{x.GetNestedLoop()}) {
if (loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
SetLoopIv(itrVal.symbol);
}
}
}

void OmpStructureChecker::CheckLoopItrVariableIsInt(
const parser::OpenMPLoopConstruct &x) {
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
if (optLoopCons.has_value()) {
if (const auto &loopConstruct{
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {

for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
if (loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
if (itrVal.symbol) {
const auto *type{itrVal.symbol->GetType()};
if (!type->IsNumeric(TypeCategory::Integer)) {
context_.Say(itrVal.source,
"The DO loop iteration"
" variable must be of the type integer."_err_en_US,
itrVal.ToString());
}
}
for (const parser::DoConstruct *loop{x.GetNestedLoop()}; loop;) {
if (loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
if (itrVal.symbol) {
const auto *type{itrVal.symbol->GetType()};
if (!type->IsNumeric(TypeCategory::Integer)) {
context_.Say(itrVal.source,
"The DO loop iteration"
" variable must be of the type integer."_err_en_US,
itrVal.ToString());
}
// Get the next DoConstruct if block is not empty.
const auto &block{std::get<parser::Block>(loop->t)};
const auto it{block.begin()};
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
: nullptr;
}
}
// Get the next DoConstruct if block is not empty.
const auto &block{std::get<parser::Block>(loop->t)};
const auto it{block.begin()};
loop =
it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it) : nullptr;
}
}

Expand Down Expand Up @@ -417,29 +401,23 @@ void OmpStructureChecker::CheckDistLinear(

// Match the loop index variables with the collected symbols from linear
// clauses.
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
if (optLoopCons.has_value()) {
if (const auto &loopConstruct{
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
if (loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
if (itrVal.symbol) {
// Remove the symbol from the collected set
indexVars.erase(&itrVal.symbol->GetUltimate());
}
collapseVal--;
if (collapseVal == 0) {
break;
}
}
// Get the next DoConstruct if block is not empty.
const auto &block{std::get<parser::Block>(loop->t)};
const auto it{block.begin()};
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
: nullptr;
for (const parser::DoConstruct *loop{x.GetNestedLoop()}; loop;) {
if (loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
if (itrVal.symbol) {
// Remove the symbol from the collected set
indexVars.erase(&itrVal.symbol->GetUltimate());
}
collapseVal--;
if (collapseVal == 0) {
break;
}
}
// Get the next DoConstruct if block is not empty.
const auto &block{std::get<parser::Block>(loop->t)};
const auto it{block.begin()};
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
: nullptr;
}

// Show error for the remaining variables
Expand Down
Loading
Loading