Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions flang/include/flang/Parser/openmp-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ struct DirectiveNameScope {
}

static OmpDirectiveName GetOmpDirectiveName(const OmpBeginLoopDirective &x) {
auto &dir{std::get<OmpLoopDirective>(x.t)};
return MakeName(dir.source, dir.v);
return x.DirName();
}

static OmpDirectiveName GetOmpDirectiveName(const OpenMPSectionConstruct &x) {
Expand Down
19 changes: 11 additions & 8 deletions flang/include/flang/Parser/parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -5158,16 +5158,12 @@ struct OpenMPStandaloneConstruct {
u;
};

struct OmpBeginLoopDirective {
TUPLE_CLASS_BOILERPLATE(OmpBeginLoopDirective);
std::tuple<OmpLoopDirective, OmpClauseList> t;
CharBlock source;
struct OmpBeginLoopDirective : public OmpBeginDirective {
INHERITED_TUPLE_CLASS_BOILERPLATE(OmpBeginLoopDirective, OmpBeginDirective);
};

struct OmpEndLoopDirective {
TUPLE_CLASS_BOILERPLATE(OmpEndLoopDirective);
std::tuple<OmpLoopDirective, OmpClauseList> t;
CharBlock source;
struct OmpEndLoopDirective : public OmpEndDirective {
INHERITED_TUPLE_CLASS_BOILERPLATE(OmpEndLoopDirective, OmpEndDirective);
};

// OpenMP directives enclosing do loop
Expand All @@ -5177,6 +5173,13 @@ struct OpenMPLoopConstruct {
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
: t({std::move(a), std::nullopt, std::nullopt}) {}

const OmpBeginLoopDirective &BeginDir() const {
return std::get<OmpBeginLoopDirective>(t);
}
const std::optional<OmpEndLoopDirective> &EndDir() const {
return std::get<std::optional<OmpEndLoopDirective>>(t);
}
std::tuple<OmpBeginLoopDirective, std::optional<NestedConstruct>,
std::optional<OmpEndLoopDirective>>
t;
Expand Down
56 changes: 17 additions & 39 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,26 +408,15 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
const parser::OmpClauseList *beginClauseList = nullptr;
const parser::OmpClauseList *endClauseList = nullptr;
common::visit(
common::visitors{
[&](const parser::OmpBlockConstruct &ompConstruct) {
beginClauseList = &ompConstruct.BeginDir().Clauses();
if (auto &endSpec = ompConstruct.EndDir())
endClauseList = &endSpec->Clauses();
},
[&](const parser::OpenMPLoopConstruct &ompConstruct) {
const auto &beginDirective =
std::get<parser::OmpBeginLoopDirective>(ompConstruct.t);
beginClauseList =
&std::get<parser::OmpClauseList>(beginDirective.t);

if (auto &endDirective =
std::get<std::optional<parser::OmpEndLoopDirective>>(
ompConstruct.t)) {
endClauseList =
&std::get<parser::OmpClauseList>(endDirective->t);
}
},
[&](const auto &) {}},
[&](const auto &construct) {
using Type = llvm::remove_cvref_t<decltype(construct)>;
if constexpr (std::is_same_v<Type, parser::OmpBlockConstruct> ||
std::is_same_v<Type, parser::OpenMPLoopConstruct>) {
beginClauseList = &construct.BeginDir().Clauses();
if (auto &endSpec = construct.EndDir())
endClauseList = &endSpec->Clauses();
}
},
ompEval->u);

assert(beginClauseList && "expected begin directive");
Expand Down Expand Up @@ -3820,19 +3809,12 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval,
const parser::OpenMPLoopConstruct &loopConstruct) {
const auto &beginLoopDirective =
std::get<parser::OmpBeginLoopDirective>(loopConstruct.t);
List<Clause> clauses = makeClauses(
std::get<parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
if (auto &endLoopDirective =
std::get<std::optional<parser::OmpEndLoopDirective>>(
loopConstruct.t)) {
clauses.append(makeClauses(
std::get<parser::OmpClauseList>(endLoopDirective->t), semaCtx));
}
const parser::OmpDirectiveSpecification &beginSpec = loopConstruct.BeginDir();
List<Clause> clauses = makeClauses(beginSpec.Clauses(), semaCtx);
if (auto &endSpec = loopConstruct.EndDir())
clauses.append(makeClauses(endSpec->Clauses(), semaCtx));

mlir::Location currentLocation =
converter.genLocation(beginLoopDirective.source);
mlir::Location currentLocation = converter.genLocation(beginSpec.source);

auto &optLoopCons =
std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t);
Expand All @@ -3858,13 +3840,10 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
}
}

llvm::omp::Directive directive =
parser::omp::GetOmpDirectiveName(beginLoopDirective).v;
const parser::CharBlock &source =
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).source;
const parser::OmpDirectiveName &beginName = beginSpec.DirName();
ConstructQueue queue{
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
eval, source, directive, clauses)};
eval, beginName.source, beginName.v, clauses)};
genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
queue.begin());
}
Expand Down Expand Up @@ -4047,8 +4026,7 @@ bool Fortran::lower::isOpenMPTargetConstruct(
dir = block->BeginDir().DirId();
} else if (const auto *loop =
std::get_if<parser::OpenMPLoopConstruct>(&omp.u)) {
const auto &begin = std::get<parser::OmpBeginLoopDirective>(loop->t);
dir = std::get<parser::OmpLoopDirective>(begin.t).v;
dir = loop->BeginDir().DirId();
}
return llvm::omp::allTargetSet.test(dir);
}
Expand Down
13 changes: 4 additions & 9 deletions flang/lib/Lower/OpenMP/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,16 +616,11 @@ static void processTileSizesFromOpenMPConstruct(
&(nestedOptional.value()));
if (innerConstruct) {
const auto &innerLoopDirective = innerConstruct->value();
const auto &innerBegin =
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
const auto &innerDirective =
std::get<parser::OmpLoopDirective>(innerBegin.t).v;

if (innerDirective == llvm::omp::Directive::OMPD_tile) {
const parser::OmpDirectiveSpecification &innerBeginSpec =
innerLoopDirective.BeginDir();
if (innerBeginSpec.DirId() == llvm::omp::Directive::OMPD_tile) {
// Get the size values from parse tree and convert to a vector.
const auto &innerClauseList{
std::get<parser::OmpClauseList>(innerBegin.t)};
for (const auto &clause : innerClauseList.v) {
for (const auto &clause : innerBeginSpec.Clauses().v) {
if (const auto tclause{
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
processFun(tclause);
Expand Down
155 changes: 84 additions & 71 deletions flang/lib/Parser/openmp-parsers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,32 @@
#include "flang/Parser/openmp-utils.h"
#include "flang/Parser/parse-tree.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Bitset.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include "llvm/Support/MathExtras.h"

#include <algorithm>
#include <cctype>
#include <iterator>
#include <list>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
Comment on lines +28 to +38
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you apply include-what-you-use, or where do these come from?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some time I was thinking about using a filter function instead of a set of directives in the parser. I went to add an include of functional (for std::function) only to see that there were no std includes at all. So I searched the file for "std::" and added includes for everything I found. I thought about doing it in a separate PR, but then I thought that it probably wasn't worth the effort.


// OpenMP Directives and Clauses
namespace Fortran::parser {
using namespace Fortran::parser::omp;

using DirectiveSet =
llvm::Bitset<llvm::NextPowerOf2(llvm::omp::Directive_enumSize)>;

// Helper function to print the buffer contents starting at the current point.
[[maybe_unused]] static std::string ahead(const ParseState &state) {
return std::string(
Expand Down Expand Up @@ -1349,95 +1366,46 @@ TYPE_PARSER(sourced(construct<OpenMPUtilityConstruct>(
TYPE_PARSER(sourced(construct<OmpMetadirectiveDirective>(
verbatim("METADIRECTIVE"_tok), Parser<OmpClauseList>{})))

// Omp directives enclosing do loop
TYPE_PARSER(sourced(construct<OmpLoopDirective>(first(
"DISTRIBUTE PARALLEL DO SIMD" >>
pure(llvm::omp::Directive::OMPD_distribute_parallel_do_simd),
"DISTRIBUTE PARALLEL DO" >>
pure(llvm::omp::Directive::OMPD_distribute_parallel_do),
"DISTRIBUTE SIMD" >> pure(llvm::omp::Directive::OMPD_distribute_simd),
"DISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_distribute),
"DO SIMD" >> pure(llvm::omp::Directive::OMPD_do_simd),
"DO" >> pure(llvm::omp::Directive::OMPD_do),
"LOOP" >> pure(llvm::omp::Directive::OMPD_loop),
"MASKED TASKLOOP SIMD" >>
pure(llvm::omp::Directive::OMPD_masked_taskloop_simd),
"MASKED TASKLOOP" >> pure(llvm::omp::Directive::OMPD_masked_taskloop),
"MASTER TASKLOOP SIMD" >>
pure(llvm::omp::Directive::OMPD_master_taskloop_simd),
"MASTER TASKLOOP" >> pure(llvm::omp::Directive::OMPD_master_taskloop),
"PARALLEL DO SIMD" >> pure(llvm::omp::Directive::OMPD_parallel_do_simd),
"PARALLEL DO" >> pure(llvm::omp::Directive::OMPD_parallel_do),
"PARALLEL MASKED TASKLOOP SIMD" >>
pure(llvm::omp::Directive::OMPD_parallel_masked_taskloop_simd),
"PARALLEL MASKED TASKLOOP" >>
pure(llvm::omp::Directive::OMPD_parallel_masked_taskloop),
"PARALLEL MASTER TASKLOOP SIMD" >>
pure(llvm::omp::Directive::OMPD_parallel_master_taskloop_simd),
"PARALLEL MASTER TASKLOOP" >>
pure(llvm::omp::Directive::OMPD_parallel_master_taskloop),
"SIMD" >> pure(llvm::omp::Directive::OMPD_simd),
"TARGET LOOP" >> pure(llvm::omp::Directive::OMPD_target_loop),
"TARGET PARALLEL DO SIMD" >>
pure(llvm::omp::Directive::OMPD_target_parallel_do_simd),
"TARGET PARALLEL DO" >> pure(llvm::omp::Directive::OMPD_target_parallel_do),
"TARGET PARALLEL LOOP" >>
pure(llvm::omp::Directive::OMPD_target_parallel_loop),
"TARGET SIMD" >> pure(llvm::omp::Directive::OMPD_target_simd),
"TARGET TEAMS DISTRIBUTE PARALLEL DO SIMD" >>
pure(llvm::omp::Directive::
OMPD_target_teams_distribute_parallel_do_simd),
"TARGET TEAMS DISTRIBUTE PARALLEL DO" >>
pure(llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do),
"TARGET TEAMS DISTRIBUTE SIMD" >>
pure(llvm::omp::Directive::OMPD_target_teams_distribute_simd),
"TARGET TEAMS DISTRIBUTE" >>
pure(llvm::omp::Directive::OMPD_target_teams_distribute),
"TARGET TEAMS LOOP" >> pure(llvm::omp::Directive::OMPD_target_teams_loop),
"TASKLOOP SIMD" >> pure(llvm::omp::Directive::OMPD_taskloop_simd),
"TASKLOOP" >> pure(llvm::omp::Directive::OMPD_taskloop),
"TEAMS DISTRIBUTE PARALLEL DO SIMD" >>
pure(llvm::omp::Directive::OMPD_teams_distribute_parallel_do_simd),
"TEAMS DISTRIBUTE PARALLEL DO" >>
pure(llvm::omp::Directive::OMPD_teams_distribute_parallel_do),
"TEAMS DISTRIBUTE SIMD" >>
pure(llvm::omp::Directive::OMPD_teams_distribute_simd),
"TEAMS DISTRIBUTE" >> pure(llvm::omp::Directive::OMPD_teams_distribute),
"TEAMS LOOP" >> pure(llvm::omp::Directive::OMPD_teams_loop),
"TILE" >> pure(llvm::omp::Directive::OMPD_tile),
"UNROLL" >> pure(llvm::omp::Directive::OMPD_unroll)))))

TYPE_PARSER(sourced(construct<OmpBeginLoopDirective>(
sourced(Parser<OmpLoopDirective>{}), Parser<OmpClauseList>{})))

static inline constexpr auto IsDirective(llvm::omp::Directive dir) {
return [dir](const OmpDirectiveName &name) -> bool { return dir == name.v; };
}

static inline constexpr auto IsMemberOf(const DirectiveSet &dirs) {
return [&dirs](const OmpDirectiveName &name) -> bool {
return dirs.test(llvm::to_underlying(name.v));
};
}

struct OmpBeginDirectiveParser {
using resultType = OmpDirectiveSpecification;

constexpr OmpBeginDirectiveParser(llvm::omp::Directive dir) : dir_(dir) {}
constexpr OmpBeginDirectiveParser(DirectiveSet dirs) : dirs_(dirs) {}
constexpr OmpBeginDirectiveParser(llvm::omp::Directive dir) {
dirs_.set(llvm::to_underlying(dir));
}

std::optional<resultType> Parse(ParseState &state) const {
auto &&p{predicated(Parser<OmpDirectiveName>{}, IsDirective(dir_)) >=
auto &&p{predicated(Parser<OmpDirectiveName>{}, IsMemberOf(dirs_)) >=
Parser<OmpDirectiveSpecification>{}};
return p.Parse(state);
}

private:
llvm::omp::Directive dir_;
DirectiveSet dirs_;
};

struct OmpEndDirectiveParser {
using resultType = OmpDirectiveSpecification;

constexpr OmpEndDirectiveParser(llvm::omp::Directive dir) : dir_(dir) {}
constexpr OmpEndDirectiveParser(DirectiveSet dirs) : dirs_(dirs) {}
constexpr OmpEndDirectiveParser(llvm::omp::Directive dir) {
dirs_.set(llvm::to_underlying(dir));
}

std::optional<resultType> Parse(ParseState &state) const {
if (startOmpLine.Parse(state)) {
if (auto endToken{verbatim("END"_sptok).Parse(state)}) {
if (auto &&dirSpec{OmpBeginDirectiveParser(dir_).Parse(state)}) {
if (auto &&dirSpec{OmpBeginDirectiveParser(dirs_).Parse(state)}) {
// Extend the "source" on both the OmpDirectiveName and the
// OmpDirectiveNameSpecification.
CharBlock &nameSource{std::get<OmpDirectiveName>(dirSpec->t).source};
Expand All @@ -1451,7 +1419,7 @@ struct OmpEndDirectiveParser {
}

private:
llvm::omp::Directive dir_;
DirectiveSet dirs_;
};

struct OmpStatementConstructParser {
Expand Down Expand Up @@ -1946,11 +1914,56 @@ TYPE_CONTEXT_PARSER("OpenMP construct"_en_US,
construct<OpenMPConstruct>(Parser<OpenMPAssumeConstruct>{}),
construct<OpenMPConstruct>(Parser<OpenMPCriticalConstruct>{}))))

static constexpr DirectiveSet GetLoopDirectives() {
using Directive = llvm::omp::Directive;
constexpr DirectiveSet loopDirectives{
unsigned(Directive::OMPD_distribute),
unsigned(Directive::OMPD_distribute_parallel_do),
unsigned(Directive::OMPD_distribute_parallel_do_simd),
unsigned(Directive::OMPD_distribute_simd),
unsigned(Directive::OMPD_do),
unsigned(Directive::OMPD_do_simd),
unsigned(Directive::OMPD_loop),
unsigned(Directive::OMPD_masked_taskloop),
unsigned(Directive::OMPD_masked_taskloop_simd),
unsigned(Directive::OMPD_master_taskloop),
unsigned(Directive::OMPD_master_taskloop_simd),
unsigned(Directive::OMPD_parallel_do),
unsigned(Directive::OMPD_parallel_do_simd),
unsigned(Directive::OMPD_parallel_masked_taskloop),
unsigned(Directive::OMPD_parallel_masked_taskloop_simd),
unsigned(Directive::OMPD_parallel_master_taskloop),
unsigned(Directive::OMPD_parallel_master_taskloop_simd),
unsigned(Directive::OMPD_simd),
unsigned(Directive::OMPD_target_loop),
unsigned(Directive::OMPD_target_parallel_do),
unsigned(Directive::OMPD_target_parallel_do_simd),
unsigned(Directive::OMPD_target_parallel_loop),
unsigned(Directive::OMPD_target_simd),
unsigned(Directive::OMPD_target_teams_distribute),
unsigned(Directive::OMPD_target_teams_distribute_parallel_do),
unsigned(Directive::OMPD_target_teams_distribute_parallel_do_simd),
unsigned(Directive::OMPD_target_teams_distribute_simd),
unsigned(Directive::OMPD_target_teams_loop),
unsigned(Directive::OMPD_taskloop),
unsigned(Directive::OMPD_taskloop_simd),
unsigned(Directive::OMPD_teams_distribute),
unsigned(Directive::OMPD_teams_distribute_parallel_do),
unsigned(Directive::OMPD_teams_distribute_parallel_do_simd),
unsigned(Directive::OMPD_teams_distribute_simd),
unsigned(Directive::OMPD_teams_loop),
unsigned(Directive::OMPD_tile),
unsigned(Directive::OMPD_unroll),
};
return loopDirectives;
}

TYPE_PARSER(sourced(construct<OmpBeginLoopDirective>(
sourced(OmpBeginDirectiveParser(GetLoopDirectives())))))

// END OMP Loop directives
TYPE_PARSER(
startOmpLine >> sourced(construct<OmpEndLoopDirective>(
sourced("END"_tok >> Parser<OmpLoopDirective>{}),
Parser<OmpClauseList>{})))
TYPE_PARSER(sourced(construct<OmpEndLoopDirective>(
sourced(OmpEndDirectiveParser(GetLoopDirectives())))))

TYPE_PARSER(construct<OpenMPLoopConstruct>(
Parser<OmpBeginLoopDirective>{} / endOmpLine))
Expand Down
Loading