Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions flang/include/flang/Parser/openmp-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define FORTRAN_PARSER_OPENMP_UTILS_H

#include "flang/Common/indirection.h"
#include "flang/Common/template.h"
#include "flang/Parser/parse-tree.h"
#include "llvm/Frontend/OpenMP/OMP.h"

Expand Down Expand Up @@ -127,7 +128,88 @@ template <typename T> struct IsStatement<Statement<T>> {
std::optional<Label> GetStatementLabel(const ExecutionPartConstruct &x);
std::optional<Label> GetFinalLabel(const OpenMPConstruct &x);

namespace detail {
// Clauses with flangClass = "OmpObjectList".
using MemberObjectListClauses =
std::tuple<OmpClause::Copyin, OmpClause::Copyprivate, OmpClause::Exclusive,
OmpClause::Firstprivate, OmpClause::HasDeviceAddr, OmpClause::Inclusive,
OmpClause::IsDevicePtr, OmpClause::Link, OmpClause::Private,
OmpClause::Shared, OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;

// Clauses with flangClass = "OmpSomeClause", and OmpObjectList a
// member of tuple OmpSomeClause::t.
using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;

// Does U have WrapperTrait (i.e. has a member 'v'), and if so, is T the
// type of v?
template <typename T, typename U, bool IsWrapper> struct WrappedInType {
static constexpr bool value{false};
};

template <typename T, typename U> struct WrappedInType<T, U, true> {
static constexpr bool value{std::is_same_v<T, decltype(U::v)>};
};

// Same as WrappedInType, but with a list of types Us. Satisfied if any
// type U in Us satisfies WrappedInType<T, U>.
template <typename...> struct WrappedInTypes;

template <typename T> struct WrappedInTypes<T> {
static constexpr bool value{false};
};

template <typename T, typename U, typename... Us>
struct WrappedInTypes<T, U, Us...> {
static constexpr bool value{WrappedInType<T, U, WrapperTrait<U>>::value ||
WrappedInTypes<T, Us...>::value};
};

// Same as WrappedInTypes, but takes type list in a form of a tuple or
// a variant.
template <typename...> struct WrappedInTupleOrVariant {
static constexpr bool value{false};
};
template <typename T, typename... Us>
struct WrappedInTupleOrVariant<T, std::tuple<Us...>> {
static constexpr bool value{WrappedInTypes<T, Us...>::value};
};
template <typename T, typename... Us>
struct WrappedInTupleOrVariant<T, std::variant<Us...>> {
static constexpr bool value{WrappedInTypes<T, Us...>::value};
};
template <typename T, typename U>
constexpr bool WrappedInTupleOrVariantV{WrappedInTupleOrVariant<T, U>::value};
} // namespace detail

template <typename T> const OmpObjectList *GetOmpObjectList(const T &clause) {
using namespace detail;
static_assert(std::is_class_v<T>, "Unexpected argument type");

if constexpr (common::HasMember<T, decltype(OmpClause::u)>) {
if constexpr (common::HasMember<T, MemberObjectListClauses>) {
return &clause.v;
} else if constexpr (common::HasMember<T, TupleObjectListClauses>) {
return &std::get<OmpObjectList>(clause.v.t);
} else {
return nullptr;
}
} else if constexpr (WrappedInTupleOrVariantV<T, TupleObjectListClauses>) {
return &std::get<OmpObjectList>(clause.t);
} else if constexpr (WrappedInTupleOrVariantV<T, decltype(OmpClause::u)>) {
return nullptr;
} else {
// The condition should be type-dependent, but it should always be false.
static_assert(sizeof(T) < 0 && "Unexpected argument type");
}
}

const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause);
const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x);

template <typename T>
const T *GetFirstArgument(const OmpDirectiveSpecification &spec) {
Expand Down
45 changes: 11 additions & 34 deletions flang/lib/Parser/openmp-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,43 +117,20 @@ std::optional<Label> GetFinalLabel(const OpenMPConstruct &x) {
}

const OmpObjectList *GetOmpObjectList(const OmpClause &clause) {
// Clauses with OmpObjectList as its data member
using MemberObjectListClauses = std::tuple<OmpClause::Copyin,
OmpClause::Copyprivate, OmpClause::Exclusive, OmpClause::Firstprivate,
OmpClause::HasDeviceAddr, OmpClause::Inclusive, OmpClause::IsDevicePtr,
OmpClause::Link, OmpClause::Private, OmpClause::Shared,
OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;

// Clauses with OmpObjectList in the tuple
using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;

// TODO:: Generate the tuples using TableGen.
return common::visit([](auto &&s) { return GetOmpObjectList(s); }, clause.u);
}

const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause) {
return common::visit(
common::visitors{
[&](const OmpClause::Depend &x) -> const OmpObjectList * {
if (auto *taskDep{std::get_if<OmpDependClause::TaskDep>(&x.v.u)}) {
return &std::get<OmpObjectList>(taskDep->t);
} else {
return nullptr;
}
},
[&](const auto &x) -> const OmpObjectList * {
using Ty = std::decay_t<decltype(x)>;
if constexpr (common::HasMember<Ty, MemberObjectListClauses>) {
return &x.v;
} else if constexpr (common::HasMember<Ty,
TupleObjectListClauses>) {
return &std::get<OmpObjectList>(x.v.t);
} else {
return nullptr;
}
},
[](const OmpDoacross &) { return nullptr; },
[](const OmpDependClause::TaskDep &x) { return GetOmpObjectList(x); },
},
clause.u);
clause.v.u);
}

const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x) {
return &std::get<OmpObjectList>(x.t);
}

const BlockConstruct *GetFortranBlockConstruct(
Expand Down
9 changes: 3 additions & 6 deletions flang/lib/Semantics/check-omp-loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,9 +480,8 @@ void OmpStructureChecker::CheckDistLinear(

// Collect symbols of all the variables from linear clauses
for (auto &clause : clauses.v) {
if (auto *linearClause{std::get_if<parser::OmpClause::Linear>(&clause.u)}) {
auto &objects{std::get<parser::OmpObjectList>(linearClause->v.t)};
GetSymbolsInObjectList(objects, indexVars);
if (std::get_if<parser::OmpClause::Linear>(&clause.u)) {
GetSymbolsInObjectList(*parser::omp::GetOmpObjectList(clause), indexVars);
}
}

Expand Down Expand Up @@ -604,8 +603,6 @@ void OmpStructureChecker::Leave(const parser::OpenMPLoopConstruct &x) {
auto *maybeModifier{OmpGetUniqueModifier<ReductionModifier>(modifiers)};
if (maybeModifier &&
maybeModifier->v == ReductionModifier::Value::Inscan) {
const auto &objectList{
std::get<parser::OmpObjectList>(reductionClause->v.t)};
auto checkReductionSymbolInScan = [&](const parser::Name *name) {
if (auto &symbol = name->symbol) {
if (!symbol->test(Symbol::Flag::OmpInclusiveScan) &&
Expand All @@ -618,7 +615,7 @@ void OmpStructureChecker::Leave(const parser::OpenMPLoopConstruct &x) {
}
}
};
for (const auto &ompObj : objectList.v) {
for (const auto &ompObj : parser::omp::GetOmpObjectList(clause)->v) {
common::visit(
common::visitors{
[&](const parser::Designator &designator) {
Expand Down
75 changes: 33 additions & 42 deletions flang/lib/Semantics/check-omp-structure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,11 +624,9 @@ void OmpStructureChecker::CheckMultListItems() {

// Linear clause
for (auto [_, clause] : FindClauses(llvm::omp::Clause::OMPC_linear)) {
auto &linearClause{std::get<parser::OmpClause::Linear>(clause->u)};
std::list<parser::Name> nameList;
SymbolSourceMap symbols;
GetSymbolsInObjectList(
std::get<parser::OmpObjectList>(linearClause.v.t), symbols);
GetSymbolsInObjectList(*GetOmpObjectList(*clause), symbols);
llvm::transform(symbols, std::back_inserter(nameList), [&](auto &&pair) {
return parser::Name{pair.second, const_cast<Symbol *>(pair.first)};
});
Expand Down Expand Up @@ -2101,29 +2099,29 @@ void OmpStructureChecker::Leave(const parser::OpenMPDeclareTargetConstruct &x) {
}
}

bool toClauseFound{false}, deviceTypeClauseFound{false},
enterClauseFound{false};
bool toClauseFound{false};
bool deviceTypeClauseFound{false};
bool enterClauseFound{false};
for (const parser::OmpClause &clause : x.v.Clauses().v) {
common::visit(
common::visitors{
[&](const parser::OmpClause::To &toClause) {
toClauseFound = true;
auto &objList{std::get<parser::OmpObjectList>(toClause.v.t)};
CheckSymbolNames(dirName.source, objList);
CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
CheckThreadprivateOrDeclareTargetVar(objList);
},
[&](const parser::OmpClause::Link &linkClause) {
CheckSymbolNames(dirName.source, linkClause.v);
CheckVarIsNotPartOfAnotherVar(dirName.source, linkClause.v);
CheckThreadprivateOrDeclareTargetVar(linkClause.v);
},
[&](const parser::OmpClause::Enter &enterClause) {
enterClauseFound = true;
auto &objList{std::get<parser::OmpObjectList>(enterClause.v.t)};
CheckSymbolNames(dirName.source, objList);
CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
CheckThreadprivateOrDeclareTargetVar(objList);
[&](const auto &c) {
using TypeC = llvm::remove_cvref_t<decltype(c)>;
if constexpr ( //
std::is_same_v<TypeC, parser::OmpClause::Enter> ||
std::is_same_v<TypeC, parser::OmpClause::Link> ||
std::is_same_v<TypeC, parser::OmpClause::To>) {
auto &objList{*GetOmpObjectList(c)};
CheckSymbolNames(dirName.source, objList);
CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
CheckThreadprivateOrDeclareTargetVar(objList);
}
if constexpr (std::is_same_v<TypeC, parser::OmpClause::Enter>) {
enterClauseFound = true;
}
if constexpr (std::is_same_v<TypeC, parser::OmpClause::To>) {
toClauseFound = true;
}
},
[&](const parser::OmpClause::DeviceType &deviceTypeClause) {
deviceTypeClauseFound = true;
Expand All @@ -2134,7 +2132,6 @@ void OmpStructureChecker::Leave(const parser::OpenMPDeclareTargetConstruct &x) {
deviceConstructFound_ = true;
}
},
[&](const auto &) {},
},
clause.u);

Expand Down Expand Up @@ -2424,12 +2421,8 @@ void OmpStructureChecker::CheckTargetUpdate() {
}
if (toWrapper && fromWrapper) {
SymbolSourceMap toSymbols, fromSymbols;
auto &fromClause{std::get<parser::OmpClause::From>(fromWrapper->u).v};
auto &toClause{std::get<parser::OmpClause::To>(toWrapper->u).v};
GetSymbolsInObjectList(
std::get<parser::OmpObjectList>(fromClause.t), fromSymbols);
GetSymbolsInObjectList(
std::get<parser::OmpObjectList>(toClause.t), toSymbols);
GetSymbolsInObjectList(*GetOmpObjectList(*fromWrapper), fromSymbols);
GetSymbolsInObjectList(*GetOmpObjectList(*toWrapper), toSymbols);

for (auto &[symbol, source] : toSymbols) {
auto fromSymbol{fromSymbols.find(symbol)};
Expand Down Expand Up @@ -3269,7 +3262,7 @@ void OmpStructureChecker::Leave(const parser::OmpClauseList &) {
const auto &irClause{
std::get<parser::OmpClause::InReduction>(dataEnvClause->u)};
checkVarAppearsInDataEnvClause(
std::get<parser::OmpObjectList>(irClause.v.t), "IN_REDUCTION");
*GetOmpObjectList(irClause), "IN_REDUCTION");
}
}
}
Expand Down Expand Up @@ -3436,7 +3429,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Destroy &x) {

void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_reduction);
auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
auto &objects{*GetOmpObjectList(x)};

if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_reduction,
GetContext().clauseSource, context_)) {
Expand Down Expand Up @@ -3476,7 +3469,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {

void OmpStructureChecker::Enter(const parser::OmpClause::InReduction &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_in_reduction);
auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
auto &objects{*GetOmpObjectList(x)};

if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_in_reduction,
GetContext().clauseSource, context_)) {
Expand All @@ -3494,7 +3487,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::InReduction &x) {

void OmpStructureChecker::Enter(const parser::OmpClause::TaskReduction &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_task_reduction);
auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
auto &objects{*GetOmpObjectList(x)};

if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_task_reduction,
GetContext().clauseSource, context_)) {
Expand Down Expand Up @@ -4347,8 +4340,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) {
}};

evaluate::ExpressionAnalyzer ea{context_};
const auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
for (auto &object : objects.v) {
for (auto &object : GetOmpObjectList(x)->v) {
if (const parser::Designator *d{GetDesignatorFromObj(object)}) {
if (auto &&expr{ea.Analyze(*d)}) {
if (hasBasePointer(*expr)) {
Expand Down Expand Up @@ -4501,7 +4493,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Depend &x) {
}
}
if (taskDep) {
auto &objList{std::get<parser::OmpObjectList>(taskDep->t)};
auto &objList{*GetOmpObjectList(*taskDep)};
if (dir == llvm::omp::OMPD_depobj) {
// [5.0:255:13], [5.1:288:6], [5.2:322:26]
// A depend clause on a depobj construct must only specify one locator.
Expand Down Expand Up @@ -4647,7 +4639,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Copyprivate &x) {
void OmpStructureChecker::Enter(const parser::OmpClause::Lastprivate &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_lastprivate);

const auto &objectList{std::get<parser::OmpObjectList>(x.v.t)};
const auto &objectList{*GetOmpObjectList(x)};
CheckVarIsNotPartOfAnotherVar(
GetContext().clauseSource, objectList, "LASTPRIVATE");
CheckCrayPointee(objectList, "LASTPRIVATE");
Expand Down Expand Up @@ -4889,9 +4881,8 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Enter &x) {
x.v, llvm::omp::OMPC_enter, GetContext().clauseSource, context_)) {
return;
}
const parser::OmpObjectList &objList{std::get<parser::OmpObjectList>(x.v.t)};
SymbolSourceMap symbols;
GetSymbolsInObjectList(objList, symbols);
GetSymbolsInObjectList(*GetOmpObjectList(x), symbols);
for (const auto &[symbol, source] : symbols) {
if (!IsExtendedListItem(*symbol)) {
context_.SayWithDecl(*symbol, source,
Expand All @@ -4914,7 +4905,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::From &x) {
CheckIteratorModifier(*iter);
}

const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
const auto &objList{*GetOmpObjectList(x)};
SymbolSourceMap symbols;
GetSymbolsInObjectList(objList, symbols);
CheckVariableListItem(symbols);
Expand Down Expand Up @@ -4954,7 +4945,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::To &x) {
CheckIteratorModifier(*iter);
}

const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
const auto &objList{*GetOmpObjectList(x)};
SymbolSourceMap symbols;
GetSymbolsInObjectList(objList, symbols);
CheckVariableListItem(symbols);
Expand Down
Loading