Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flang/include/flang/Parser/parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -5260,15 +5260,15 @@ using NestedConstruct =
struct OpenMPLoopConstruct {
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
: t({std::move(a), std::nullopt, std::nullopt}) {}
: t({std::move(a), std::list<NestedConstruct>(), std::nullopt}) {}

const OmpBeginLoopDirective &BeginDir() const {
return std::get<OmpBeginLoopDirective>(t);
}
const std::optional<OmpEndLoopDirective> &EndDir() const {
return std::get<std::optional<OmpEndLoopDirective>>(t);
}
std::tuple<OmpBeginLoopDirective, std::optional<NestedConstruct>,
std::tuple<OmpBeginLoopDirective, std::list<NestedConstruct>,
std::optional<OmpEndLoopDirective>>
t;
};
Expand Down
7 changes: 7 additions & 0 deletions flang/include/flang/Semantics/openmp-directive-sets.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,17 @@ static const OmpDirectiveSet loopConstructSet{
Directive::OMPD_teams_distribute_parallel_do_simd,
Directive::OMPD_teams_distribute_simd,
Directive::OMPD_teams_loop,
Directive::OMPD_fuse,
Directive::OMPD_tile,
Directive::OMPD_unroll,
};

static const OmpDirectiveSet loopTransformationSet{
Directive::OMPD_tile,
Directive::OMPD_unroll,
Directive::OMPD_fuse,
};

static const OmpDirectiveSet nonPartialVarSet{
Directive::OMPD_allocate,
Directive::OMPD_allocators,
Expand Down
15 changes: 11 additions & 4 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3471,6 +3471,13 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_tile:
genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_fuse: {
unsigned version = semaCtx.langOptions().OpenMPVersion;
if (!semaCtx.langOptions().OpenMPSimd)
TODO(loc, "Unhandled loop directive (" +
llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
break;
}
case llvm::omp::Directive::OMPD_unroll:
genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
Expand Down Expand Up @@ -3918,12 +3925,12 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,

mlir::Location currentLocation = converter.genLocation(beginSpec.source);

auto &optLoopCons =
std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t);
if (optLoopCons.has_value()) {
auto &loopConsList =
std::get<std::list<parser::NestedConstruct>>(loopConstruct.t);
for (auto &loopCons : loopConsList) {
if (auto *ompNestedLoopCons{
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&*optLoopCons)}) {
&loopCons)}) {
llvm::omp::Directive nestedDirective =
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
switch (nestedDirective) {
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Lower/OpenMP/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,13 +631,13 @@ static void processTileSizesFromOpenMPConstruct(
if (!ompCons)
return;
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
const auto &nestedOptional =
std::get<std::optional<parser::NestedConstruct>>(ompLoop->t);
assert(nestedOptional.has_value() &&
const auto &loopConsList =
std::get<std::list<parser::NestedConstruct>>(ompLoop->t);
assert(loopConsList.size() == 1 &&
"Expected a DoConstruct or OpenMPLoopConstruct");
const auto *innerConstruct =
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&(nestedOptional.value()));
&(loopConsList.front()));
if (innerConstruct) {
const auto &innerLoopDirective = innerConstruct->value();
const parser::OmpDirectiveSpecification &innerBeginSpec =
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Parser/openmp-parsers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2037,6 +2037,7 @@ static constexpr DirectiveSet GetLoopDirectives() {
unsigned(Directive::OMPD_teams_distribute_parallel_do_simd),
unsigned(Directive::OMPD_teams_distribute_simd),
unsigned(Directive::OMPD_teams_loop),
unsigned(Directive::OMPD_fuse),
unsigned(Directive::OMPD_tile),
unsigned(Directive::OMPD_unroll),
};
Expand Down
3 changes: 1 addition & 2 deletions flang/lib/Parser/unparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2725,8 +2725,7 @@ class UnparseVisitor {
}
void Unparse(const OpenMPLoopConstruct &x) {
Walk(std::get<OmpBeginLoopDirective>(x.t));
Walk(std::get<std::optional<std::variant<DoConstruct,
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t));
Walk(std::get<std::list<parser::NestedConstruct>>(x.t));
Walk(std::get<std::optional<OmpEndLoopDirective>>(x.t));
}
void Unparse(const BasedPointer &x) {
Expand Down
130 changes: 78 additions & 52 deletions flang/lib/Semantics/canonicalize-omp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "canonicalize-omp.h"
#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/openmp-directive-sets.h"
#include "flang/Semantics/semantics.h"

// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
Expand Down Expand Up @@ -137,33 +138,45 @@ class CanonicalizationOfOmp {
"A DO loop must follow the %s directive"_err_en_US,
parser::ToUpperCaseLetters(dirName.source.ToString()));
};
auto tileUnrollError = [](const parser::OmpDirectiveName &dirName,
parser::Messages &messages) {
auto transformUnrollError = [](const parser::OmpDirectiveName &dirName,
parser::Messages &messages) {
messages.Say(dirName.source,
"If a loop construct has been fully unrolled, it cannot then be tiled"_err_en_US,
"If a loop construct has been fully unrolled, it cannot then be further transformed"_err_en_US,
parser::ToUpperCaseLetters(dirName.source.ToString()));
};
auto missingEndFuse = [](auto &dir, auto &messages) {
messages.Say(dir.source,
"The %s construct requires the END FUSE directive"_err_en_US,
parser::ToUpperCaseLetters(dir.source.ToString()));
};

bool endFuseNeeded = beginName.v == llvm::omp::Directive::OMPD_fuse;

nextIt = it;
while (++nextIt != block.end()) {
nextIt++;
while (nextIt != block.end()) {
// Ignore compiler directives.
if (GetConstructIf<parser::CompilerDirective>(*nextIt))
if (GetConstructIf<parser::CompilerDirective>(*nextIt)) {
nextIt++;
continue;
}

if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
if (doCons->GetLoopControl()) {
// move DoConstruct
std::get<std::optional<std::variant<parser::DoConstruct,
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
std::move(*doCons);
std::get<std::list<parser::NestedConstruct>>(x.t).push_back(
std::move(*doCons));
nextIt = block.erase(nextIt);
// try to match OmpEndLoopDirective
if (nextIt != block.end()) {
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
nextIt = block.erase(nextIt);
auto &endDirName = endDir->DirName();
if (endDirName.v != llvm::omp::Directive::OMPD_fuse) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
nextIt = block.erase(nextIt);
}
}
}
} else {
Expand All @@ -173,53 +186,48 @@ class CanonicalizationOfOmp {
}
} else if (auto *ompLoopCons{
GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
// We should allow UNROLL and TILE constructs to be inserted between an
// OpenMP Loop Construct and the DO loop itself
// We should allow loop transformation constructs to be inserted between
// an OpenMP Loop Construct and the DO loop itself
auto &nestedBeginDirective = ompLoopCons->BeginDir();
auto &nestedBeginName = nestedBeginDirective.DirName();
if ((nestedBeginName.v == llvm::omp::Directive::OMPD_unroll ||
nestedBeginName.v == llvm::omp::Directive::OMPD_tile) &&
!(nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
beginName.v == llvm::omp::Directive::OMPD_tile)) {
// iterate through the remaining block items to find the end directive
// for the unroll/tile directive.
parser::Block::iterator endIt;
endIt = nextIt;
while (endIt != block.end()) {
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*endIt)}) {
auto &endDirName = endDir->DirName();
if (endDirName.v == beginName.v) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
endIt = block.erase(endIt);
continue;
if (llvm::omp::loopTransformationSet.test(nestedBeginName.v)) {
if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
llvm::omp::loopTransformationSet.test(beginName.v)) {
// if a loop has been unrolled, the user can not then transform that
// loop as it has been unrolled
const parser::OmpClauseList &unrollClauseList{
nestedBeginDirective.Clauses()};
if (unrollClauseList.v.empty()) {
// if the clause list is empty for an unroll construct, we assume
// the loop is being fully unrolled
transformUnrollError(beginName, messages_);
} else {
// parse the clauses for the unroll directive to find the full
// clause
for (auto &clause : unrollClauseList.v) {
if (clause.Id() == llvm::omp::OMPC_full) {
transformUnrollError(beginName, messages_);
}
}
}
++endIt;
}
RewriteOpenMPLoopConstruct(*ompLoopCons, block, nextIt);
auto &ompLoop = std::get<std::optional<parser::NestedConstruct>>(x.t);
ompLoop =
std::optional<parser::NestedConstruct>{parser::NestedConstruct{
common::Indirection{std::move(*ompLoopCons)}}};
auto &loopConsList =
std::get<std::list<parser::NestedConstruct>>(x.t);
loopConsList.push_back(parser::NestedConstruct{
common::Indirection{std::move(*ompLoopCons)}});
nextIt = block.erase(nextIt);
} else if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
beginName.v == llvm::omp::Directive::OMPD_tile) {
// if a loop has been unrolled, the user can not then tile that loop
// as it has been unrolled
const parser::OmpClauseList &unrollClauseList{
nestedBeginDirective.Clauses()};
if (unrollClauseList.v.empty()) {
// if the clause list is empty for an unroll construct, we assume
// the loop is being fully unrolled
tileUnrollError(beginName, messages_);
} else {
// parse the clauses for the unroll directive to find the full
// clause
for (auto &clause : unrollClauseList.v) {
if (clause.Id() == llvm::omp::OMPC_full) {
tileUnrollError(beginName, messages_);
// check the following block item to find the end directive
// for the loop transform directive.
if (nextIt != block.end()) {
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
auto &endDirName = endDir->DirName();
if (endDirName.v == beginName.v &&
endDirName.v != llvm::omp::Directive::OMPD_fuse) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
nextIt = block.erase(nextIt);
}
}
}
Expand All @@ -231,11 +239,29 @@ class CanonicalizationOfOmp {
} else {
missingDoConstruct(beginName, messages_);
}

if (endFuseNeeded && nextIt != block.end()) {
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
auto &endDirName = endDir->DirName();
if (endDirName.v == llvm::omp::Directive::OMPD_fuse) {
endFuseNeeded = false;
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
nextIt = block.erase(nextIt);
}
}
}
if (endFuseNeeded)
continue;
// If we get here, we either found a loop, or issued an error message.
return;
}
if (nextIt == block.end()) {
missingDoConstruct(beginName, messages_);
if (endFuseNeeded)
missingEndFuse(beginName, messages_);
else
missingDoConstruct(beginName, messages_);
}
}

Expand Down
Loading