diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index be2117efbabc0..be9531b574c60 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -31,6 +31,7 @@ #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Support/CommandLine.h" @@ -48,27 +49,76 @@ using DeclareTargetCapturePair = // Common helper functions //===----------------------------------------------------------------------===// -static Fortran::semantics::Symbol * -getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { - Fortran::semantics::Symbol *sym = nullptr; - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::Designator &designator) { - if (auto *arrayEle = - Fortran::parser::Unwrap( - designator)) { - sym = GetFirstName(arrayEle->base).symbol; - } else if (auto *structComp = Fortran::parser::Unwrap< - Fortran::parser::StructureComponent>(designator)) { - sym = structComp->component.symbol; - } else if (const Fortran::parser::Name *name = - Fortran::semantics::getDesignatorNameIfDataRef( - designator)) { - sym = name->symbol; - } - }, - [&](const Fortran::parser::Name &name) { sym = name.symbol; }}, - ompObject.u); +static llvm::ArrayRef getWorksharing() { + static llvm::omp::Directive worksharing[] = { + llvm::omp::Directive::OMPD_do, llvm::omp::Directive::OMPD_for, + llvm::omp::Directive::OMPD_scope, llvm::omp::Directive::OMPD_sections, + llvm::omp::Directive::OMPD_single, llvm::omp::Directive::OMPD_workshare, + }; + return worksharing; +} + +static llvm::ArrayRef getWorksharingLoop() { + static llvm::omp::Directive worksharingLoop[] = { + llvm::omp::Directive::OMPD_do, + llvm::omp::Directive::OMPD_for, + }; + return worksharingLoop; +} + +static uint32_t getOpenMPVersion(const mlir::ModuleOp &mod) { + if (mlir::Attribute verAttr = mod->getAttr("omp.version")) + return llvm::cast(verAttr).getVersion(); + llvm_unreachable("Exoecting OpenMP version attribute in module"); +} + +static std::pair +getOmpObjectSymbolAndBase(const Fortran::parser::Name &name) { + return std::make_pair(name.symbol, nullptr); +} + +static std::pair +getOmpObjectSymbolAndBase(const Fortran::parser::Designator &designator) { + if (auto *arrayEle = + Fortran::parser::Unwrap(designator)) { + auto *sym = GetFirstName(arrayEle->base).symbol; + // Array elements don't have their own symbols, instead the base symbol + // is used. + return std::make_pair(sym, sym); + } + if (auto *structComp = + Fortran::parser::Unwrap( + designator)) { + auto *sym = structComp->component.symbol; + auto *base = GetFirstName(structComp->base).symbol; + return std::make_pair(sym, base); + } + if (const Fortran::parser::Name *name = + Fortran::semantics::getDesignatorNameIfDataRef(designator)) { + return getOmpObjectSymbolAndBase(*name); + } + llvm_unreachable("Cannot obtain symbols for designtor"); +} + +static std::pair +getOmpObjectSymbolAndBase(const Fortran::parser::OmpObject &object) { + std::pair syms; + std::visit(Fortran::common::visitors{ + [&](const Fortran::parser::Designator &designator) { + syms = getOmpObjectSymbolAndBase(designator); + }, + [&](const Fortran::parser::Name &name) { + syms = getOmpObjectSymbolAndBase(name); + }}, + object.u); + return syms; +} + +template +static Fortran::semantics::Symbol *getOmpObjectSymbol(Object &&object) { + auto *sym = + std::get<0>(getOmpObjectSymbolAndBase(std::forward(object))); + assert(sym != nullptr); return sym; } @@ -142,6 +192,1034 @@ static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter, converter.genEval(e); } +//===----------------------------------------------------------------------===// +// Directive decomposition +//===----------------------------------------------------------------------===// + +namespace { +struct DirectiveInfo { + llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown; + llvm::SmallVector clauses; +}; + +struct CompositeInfo { + CompositeInfo(const mlir::ModuleOp &modOp, + Fortran::lower::pft::Evaluation &ev, + llvm::omp::Directive compDir, + const std::list &clauses); + using ClauseSet = std::set; + + bool split(); + void addClause(const Fortran::parser::OmpClause *clause); + + DirectiveInfo *findDirective(llvm::omp::Directive dirId) { + for (DirectiveInfo &dir : leafs) { + if (dir.id == dirId) + return &dir; + } + return nullptr; + } + ClauseSet *findClauses(const Fortran::parser::OmpObject &object) { + const Fortran::semantics::Symbol *sym = getOmpObjectSymbol(object); + if (auto found = syms.find(sym); found != syms.end()) + return &found->second; + return nullptr; + } + + llvm::SmallVector leafs; // Ordered outer to inner. + llvm::DenseMap syms; + llvm::DenseSet mapBases; + Fortran::lower::pft::Evaluation &eval; + const mlir::ModuleOp &mod; + + // List of clauses applied to the combined/composite directive. + // Processing of the LINEAR clause can result in FIRSTPRIVATE and/or + // LASTPRIVATE added to this list. + llvm::SmallVector clauses; + // Storage for the OmpClause's created during splitting. + llvm::SmallVector> storage; +}; +} // namespace + +static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const DirectiveInfo &dirInfo); +static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const CompositeInfo &compInfo); + +namespace detail { +template +llvm::omp::Clause getClauseIdForClass(C &&) { + using namespace Fortran; + using A = llvm::remove_cvref_t; // A is referenced in OMP.inc +#define GEN_FLANG_CLAUSE_PARSER_KIND_MAP +#include "llvm/Frontend/OpenMP/OMP.inc" +} + +template +typename std::remove_reference_t::iterator +find_unique(Container &&container, Predicate &&pred) { + auto first = std::find_if(container.begin(), container.end(), pred); + if (first == container.end()) + return first; + auto second = std::find_if(std::next(first), container.end(), pred); + if (second == container.end()) + return first; + return container.end(); +} +} // namespace detail + +static llvm::omp::Clause getClauseId(const Fortran::parser::OmpClause &clause) { + return std::visit([](auto &&s) { return detail::getClauseIdForClass(s); }, + clause.u); +} + +namespace detail { +template +auto clauseDispatch( + Clause &&clause, llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, CompositeInfo &compInfo, + Handler &&handler, + typename llvm::remove_cvref_t::EmptyTrait value = {}) { + return handler(std::forward(clause), clauseId, clauseNode, compInfo); +} + +template +auto clauseDispatch( + Clause &&clause, llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, CompositeInfo &compInfo, + Handler &&handler, + typename llvm::remove_cvref_t::WrapperTrait value = {}) { + return handler(std::forward(clause), clauseId, clause.v, clauseNode, + compInfo); +} + +template +auto visit_clause(const Fortran::parser::OmpClause &clause, Handler &&handler, + CompositeInfo &compInfo) { + return std::visit( + [&](auto &&actual) { + return detail::clauseDispatch(actual, getClauseId(clause), &clause, + compInfo, handler); + }, + clause.u); +} +} // namespace detail + +static Fortran::semantics::Symbol * +getIterationVariableSymbol(const Fortran::lower::pft::Evaluation &eval) { + return eval.visit(Fortran::common::visitors{ + [&](const Fortran::parser::DoConstruct &doLoop) { + if (const auto &maybeCtrl = doLoop.GetLoopControl()) { + using LoopControl = Fortran::parser::LoopControl; + if (auto *bounds = std::get_if(&maybeCtrl->u)) { + static_assert( + std::is_same_vname), + Fortran::parser::Scalar>); + return bounds->name.thing.symbol; + } + } + return static_cast(nullptr); + }, + [](auto &&) { + return static_cast(nullptr); + }, + }); +} + +static void addSymsToMap(const Fortran::parser::OmpObjectList &objects, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + for (const Fortran::parser::OmpObject &object : objects.v) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(object); + compInfo.syms[sym].insert(clauseNode); + } +} + +static void addSymToMap(const Fortran::parser::Name &name, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(name); + compInfo.syms[sym].insert(clauseNode); +} + +static void addSymToMap(const Fortran::parser::Designator &designator, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + Fortran::semantics::Symbol *sym = getOmpObjectSymbol(designator); + compInfo.syms[sym].insert(clauseNode); +} + +template +static void addClauseSymsToMap(Clause &&clause, llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Do nothing for clauses represented by empty classes. +} + +template +static void addClauseSymsToMap(Clause &&clause, llvm::omp::Clause clauseId, + const Fortran::parser::OmpObjectList &objects, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + addSymsToMap(objects, clauseNode, compInfo); +} + +static void +addClauseSymsToMap(const Fortran::parser::OmpClause::Aligned &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpAlignedClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // contents.t<0> -> OmpObjectList + addSymsToMap(std::get<0>(contents.t), clauseNode, compInfo); +} + +static void +addClauseSymsToMap(const Fortran::parser::OmpClause::Allocate &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpAllocateClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // contents.t<1> -> OmpObjectList + addSymsToMap(std::get<1>(contents.t), clauseNode, compInfo); +} + +static void addClauseSymsToMap(const Fortran::parser::OmpClause::Depend &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpDependClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + std::visit(Fortran::common::visitors{ + [&](const Fortran::parser::OmpDependClause::InOut &inout) { + // inout.t<1> -> std::list + for (const auto &designator : std::get<1>(inout.t)) + addSymToMap(designator, clauseNode, compInfo); + }, + [](auto &&) { + // No objects in the other alternatives. + }, + }, + contents.u); +} + +static void +addClauseSymsToMap(const Fortran::parser::OmpClause::InReduction &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpInReductionClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // contents.t<1> -> OmpObjectList + addSymsToMap(std::get<1>(contents.t), clauseNode, compInfo); +} + +static void addClauseSymsToMap(const Fortran::parser::OmpClause::Linear &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpLinearClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // contents.u is a variant where both alternatives have member `names` + // that is std::list. + std::visit( + [&](auto &&s) { + for (const Fortran::parser::Name &name : s.names) + addSymToMap(name, clauseNode, compInfo); + }, + contents.u); +} + +static void addClauseSymsToMap(const Fortran::parser::OmpClause::Map &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpMapClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // contents.t<1> -> OmpObjectList + const Fortran::parser::OmpObjectList &objects = std::get<1>(contents.t); + + addSymsToMap(objects, clauseNode, compInfo); + + // Additionally, add base symbols to the 'mapBases' set. + for (const Fortran::parser::OmpObject &object : objects.v) { + if (auto *base = std::get<1>(getOmpObjectSymbolAndBase(object))) + compInfo.mapBases.insert(base); + } +} + +template +static void +addClauseSymsToMap(Clause &&clause, llvm::omp::Clause clauseId, + const Fortran::parser::OmpReductionClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Reduction, TaskReduction + // contents.t<1> -> OmpObjectList + addSymsToMap(std::get<1>(contents.t), clauseNode, compInfo); +} + +template +static void addClauseSymsToMap(Clause &&clause, llvm::omp::Clause clauseId, + const std::list &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // NonTemporal, Uniform + for (const Fortran::parser::Name &name : contents) + addSymToMap(name, clauseNode, compInfo); +} + +template +static void addClauseSymsToMap(Clause &&clause, llvm::omp::Clause clauseId, + WrappedType &&wrapped, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Make sure that we are not missing anything: list all the wrapped + // types that do not contain (or reference) any objects. + using namespace Fortran::parser; + static_assert( + llvm::is_one_of< + llvm::remove_cvref_t, OmpAtomicDefaultMemOrderClause, + ConstantExpr, ScalarIntConstantExpr, ScalarIntExpr, ScalarLogicalExpr, + std::optional, std::optional, + std::list, OmpDefaultClause, OmpDefaultmapClause, + OmpDeviceClause, OmpDeviceTypeClause, OmpIfClause, OmpOrderClause, + OmpProcBindClause, OmpScheduleClause>::value); +} + +CompositeInfo::CompositeInfo( + const mlir::ModuleOp &modOp, Fortran::lower::pft::Evaluation &ev, + llvm::omp::Directive compDir, + const std::list &clauses) + : eval(ev), mod(modOp) { + for (llvm::omp::Directive dir : llvm::omp::getLeafConstructs(compDir)) + leafs.push_back(DirectiveInfo{dir}); + + for (const Fortran::parser::OmpClause &clause : clauses) + addClause(&clause); +} + +void CompositeInfo::addClause(const Fortran::parser::OmpClause *clause) { + clauses.push_back(clause); + detail::visit_clause( + *clause, [](auto &&...args) { addClauseSymsToMap(args...); }, *this); +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const DirectiveInfo &dirInfo) { + os << llvm::omp::getOpenMPDirectiveName(dirInfo.id); + for (auto [index, clause] : llvm::enumerate(dirInfo.clauses)) { + os << (index == 0 ? '\t' : ' '); + os << llvm::omp::getOpenMPClauseName(getClauseId(*clause)); + } + return os; +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const CompositeInfo &compInfo) { + for (const auto &[index, dirInfo] : llvm::enumerate(compInfo.leafs)) + os << "leaf[" << index << "]: " << dirInfo << '\n'; + + os << "syms:\n"; + for (const auto &[sym, clauses] : compInfo.syms) { + os << *sym << " -> {"; + for (const auto *clause : clauses) + os << ' ' << llvm::omp::getOpenMPClauseName(getClauseId(*clause)); + os << " }\n"; + } + os << "mapBases: {"; + for (const auto &sym : compInfo.mapBases) + os << ' ' << *sym; + os << " }\n"; + return os; +} + +// Apply a clause to the only directive that allows it. If there are no +// directives that allow it, or if there is more that one, do not apply +// anything and return false, otherwise return true. +static bool applyToUnique(llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + uint32_t version = getOpenMPVersion(compInfo.mod); + auto unique = detail::find_unique(compInfo.leafs, [=](const auto &dirInfo) { + return llvm::omp::isAllowedClauseForDirective(dirInfo.id, clauseId, + version); + }); + + if (unique != compInfo.leafs.end()) { + unique->clauses.push_back(clauseNode); + return true; + } + return false; +} + +// Apply a clause to the first directive in given range that allows it. +// If such a directive does not exist, return false, otherwise return true. +template +static bool applyToFirst(llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + const mlir::ModuleOp &mod, + llvm::iterator_range range) { + if (range.empty()) + return false; + + uint32_t version = getOpenMPVersion(mod); + for (DirectiveInfo &dir : range) { + if (!llvm::omp::isAllowedClauseForDirective(dir.id, clauseId, version)) + continue; + dir.clauses.push_back(clauseNode); + return true; + } + return false; +} + +// Apply a clause to the innermost directive that allows it. If such a +// directive does not exist, return false, otherwise return true. +static bool applyToInnermost(llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + return applyToFirst(clauseId, clauseNode, compInfo.mod, + llvm::reverse(compInfo.leafs)); +} + +// Apply a clause to the outermost directive that allows it. If such a +// directive does not exist, return false, otherwise return true. +static bool applyToOutermost(llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + return applyToFirst(clauseId, clauseNode, compInfo.mod, + llvm::iterator_range(compInfo.leafs)); +} + +template +static bool applyIf(llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo, + Predicate shouldApply) { + bool applied = false; + uint32_t version = getOpenMPVersion(compInfo.mod); + for (DirectiveInfo &dir : compInfo.leafs) { + if (!llvm::omp::isAllowedClauseForDirective(dir.id, clauseId, version)) + continue; + if (!shouldApply(dir)) + continue; + dir.clauses.push_back(clauseNode); + applied = true; + } + + return applied; +} + +static bool applyToAll(llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + return applyIf(clauseId, clauseNode, compInfo, [](auto) { return true; }); +} + +template +static bool applyClause(Clause &&clause, llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // The default behavior is to find the unique directive to which the + // given clause may be applied. If there are no such directives, or + // if there are multiple ones, flag an error. + // From "OpenMP Application Programming Interface", Version 5.2: + // "Some clauses are permitted only on a single leaf construct of the + // combined or composite construct, in which case the effect is as if + // the clause is applied to that specific construct." (p339, 31-33) + if (applyToUnique(clauseId, clauseNode, compInfo)) + return true; + + llvm::errs() << "handle empty class:" + << llvm::omp::getOpenMPClauseName(clauseId) << '\n'; + return false; +} + +// Clauses that expected to only be applicable to a single leaf construct. +template +static bool applyClause(Clause &&clause, llvm::omp::Clause clauseId, + WrappedType &&, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + if (applyToUnique(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "handle wrapper class for generic type:" + << llvm::omp::getOpenMPClauseName(clauseId) << '\n'; + return false; +} + +// COLLAPSE +template +static bool applyClause(const Fortran::parser::OmpClause::Collapse &clause, + llvm::omp::Clause clauseId, WrappedType &&contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Apply COLLAPSE to the innermost directive. If it's not one that + // allows it flag an error. + if (!compInfo.leafs.empty()) { + DirectiveInfo &last = compInfo.leafs.back(); + uint32_t version = getOpenMPVersion(compInfo.mod); + + if (llvm::omp::isAllowedClauseForDirective(last.id, clauseId, version)) { + last.clauses.push_back(clauseNode); + return true; + } + } + + llvm::errs() << "Cannot apply COLLAPSE\n"; + return false; +} + +// PRIVATE +static bool applyClause(const Fortran::parser::OmpClause::Private &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpObjectList &objects, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + if (applyToInnermost(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "Cannot apply PRIVATE\n"; + return false; +} + +// FIRSTPRIVATE +static bool applyClause(const Fortran::parser::OmpClause::Firstprivate &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpObjectList &objects, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + bool applied = false; + + // S Section 17.2 + // S The effect of the firstprivate clause is as if it is applied to one + // S or more leaf constructs as follows: + + // S - To the distribute construct if it is among the constituent constructs; + // S - To the teams construct if it is among the constituent constructs and + // S the distribute construct is not; + auto hasDistribute = compInfo.findDirective(llvm::omp::OMPD_distribute); + auto hasTeams = compInfo.findDirective(llvm::omp::OMPD_teams); + if (hasDistribute != nullptr) { + hasDistribute->clauses.push_back(clauseNode); + applied = true; + // S If the teams construct is among the constituent constructs and the + // S effect is not as if the firstprivate clause is applied to it by the + // S above rules, then the effect is as if the shared clause with the + // S same list item is applied to the teams construct. + if (hasTeams != nullptr) { + // TODO: Apply SHARED(objects) + } + } else if (hasTeams != nullptr) { + hasTeams->clauses.push_back(clauseNode); + applied = true; + } + + // S - To a worksharing construct that accepts the clause if one is among + // S the constituent constructs; + auto findWorksharing = [&]() { + auto worksharing = getWorksharing(); + for (DirectiveInfo &dir : compInfo.leafs) { + auto found = llvm::find(worksharing, dir.id); + if (found != std::end(worksharing)) + return &dir; + } + return static_cast(nullptr); + }; + + auto hasWorksharing = findWorksharing(); + if (hasWorksharing != nullptr) { + hasWorksharing->clauses.push_back(clauseNode); + applied = true; + } + + // S - To the taskloop construct if it is among the constituent constructs; + auto hasTaskloop = compInfo.findDirective(llvm::omp::OMPD_taskloop); + if (hasTaskloop != nullptr) { + hasTaskloop->clauses.push_back(clauseNode); + applied = true; + } + + // S - To the parallel construct if it is among the constituent constructs + // S and neither a taskloop construct nor a worksharing construct that + // S accepts the clause is among them; + auto hasParallel = compInfo.findDirective(llvm::omp::OMPD_parallel); + if (hasParallel != nullptr) { + if (hasTaskloop == nullptr && hasWorksharing == nullptr) { + hasParallel->clauses.push_back(clauseNode); + applied = true; + } else { + // S If the parallel construct is among the constituent constructs and + // S the effect is not as if the firstprivate clause is applied to it by + // S the above rules, then the effect is as if the shared clause with + // S the same list item is applied to the parallel construct. + // TODO: apply SHARED(objects) to PARALLEL + } + } + + // S - To the target construct if it is among the constituent constructs + // S and the same list item neither appears in a lastprivate clause nor + // S is the base variable or base pointer of a list item that appears in + // S a map clause. + auto objInLastprivate = [&](const Fortran::parser::OmpObject &object) { + if (CompositeInfo::ClauseSet *clauses = compInfo.findClauses(object)) { + for (const Fortran::parser::OmpClause *clause : *clauses) { + if (getClauseId(*clause) == llvm::omp::Clause::OMPC_lastprivate) + return true; + } + } + return false; + }; + + auto hasTarget = compInfo.findDirective(llvm::omp::OMPD_target); + if (hasTarget != nullptr) { + for (const Fortran::parser::OmpObject &object : objects.v) { + if (objInLastprivate(object)) + continue; + if (compInfo.mapBases.contains(getOmpObjectSymbol(object))) + continue; + // TODO: Add FIRSTPRIVATE(object) to clause list + // TODO: may need a new OmpObjectList here + applied = true; + } + } + + return applied; +} + +// LASTPRIVATE +static bool applyClause(const Fortran::parser::OmpClause::Lastprivate &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpObjectList &objects, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // S The effect of the lastprivate clause is as if it is applied to all leaf + // S constructs that permit the clause. + if (!applyToAll(clauseId, clauseNode, compInfo)) { + llvm::errs() << "Cannot apply LASTPRIVATE\n"; + return false; + } + + // S If the parallel construct is among the constituent constructs and the + // S list item is not also specified in the firstprivate clause, then the + // S effect of the lastprivate clause is as if the shared clause with the + // S same list item is applied to the parallel construct. + auto inFirstprivate = [&](const Fortran::parser::OmpObject &object) { + if (auto *clauses = compInfo.findClauses(object)) { + for (const Fortran::parser::OmpClause *clause : *clauses) { + if (getClauseId(*clause) == llvm::omp::Clause::OMPC_firstprivate) + return true; + } + } + return false; + }; + + if (auto hasParallel = compInfo.findDirective(llvm::omp::OMPD_parallel)) { + for (const Fortran::parser::OmpObject &object : objects.v) { + if (!inFirstprivate(object)) { + // TODO Add SHARED(object) to PARALLEL + } + } + } + + // S If the teams construct is among the constituent constructs and the + // S list item is not also specified in the firstprivate clause, then the + // S effect of the lastprivate clause is as if the shared clause with the + // S same list item is applied to the teams construct. + if (auto hasTeams = compInfo.findDirective(llvm::omp::OMPD_teams)) { + for (const Fortran::parser::OmpObject &object : objects.v) { + if (!inFirstprivate(object)) { + // TODO Add SHARED(object) to TEAMS + } + } + } + + // S If the target construct is among the constituent constructs and the + // S list item is not the base variable or base pointer of a list item that + // S appears in a map clause, the effect of the lastprivate clause is as if + // S the same list item appears in a map clause with a map-type of tofrom. + if (auto hasTarget = compInfo.findDirective(llvm::omp::OMPD_target)) { + for (const Fortran::parser::OmpObject &object : objects.v) { + const Fortran::semantics::Symbol *sym = getOmpObjectSymbol(object); + // See if symbol is a base symbol in MAP. + if (!compInfo.mapBases.contains(sym)) { + // TODO Add MAP(tofrom, object) to TARGET. + } + } + } + + return false; +} + +// SHARED +template +static bool applyClause(const Fortran::parser::OmpClause::Shared &clause, + llvm::omp::Clause clauseId, WrappedType &&contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Apply SHARED to the all leafs that allow it. + if (applyToAll(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "Cannot apply SHARED\n"; + return false; +} + +// DEFAULT +template +static bool applyClause(const Fortran::parser::OmpClause::Default &clause, + llvm::omp::Clause clauseId, WrappedType &&contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Apply DEFAULT to the all leafs that allow it. + if (applyToAll(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "Cannot apply DEFAULT\n"; + return false; +} + +// THREAD_LIMIT +template +static bool applyClause(const Fortran::parser::OmpClause::ThreadLimit &clause, + llvm::omp::Clause clauseId, WrappedType &&contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Apply THREAD_LIMIT to the all leafs that allow it. + if (applyToAll(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "Cannot apply THREAD_LIMIT\n"; + return false; +} + +// ORDER +template +static bool applyClause(const Fortran::parser::OmpClause::Order &clause, + llvm::omp::Clause clauseId, WrappedType &&contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // Apply ORDER to the all leafs that allow it. + if (applyToAll(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "Cannot apply ORDER\n"; + return false; +} + +// ALLOCATE +template +static bool applyClause(const Fortran::parser::OmpClause::Allocate &clause, + llvm::omp::Clause clauseId, WrappedType &&contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // S The effect of the allocate clause is as if it is applied to all leaf + // S constructs that permit the clause and to which a data-sharing attribute + // S clause that may create a private copy of the same list item is applied. + + // XXX This one may need to be applied at the end, once we know which leaf + // constructs have what data-sharing attributes. Or maybe do all data-sharing + // first, then the rest of the clauses? + + // TODO + llvm::errs() << "Cannot apply ALLOCATE\n"; + return false; +} + +// REDUCTION +static bool applyClause(const Fortran::parser::OmpClause::Reduction &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpReductionClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // S The effect of the reduction clause is as if it is applied to all leaf + // S constructs that permit the clause, except for the following constructs: + // S - The parallel construct, when combined with the sections, worksharing- + // S loop, loop, or taskloop construct; and + // S - The teams construct, when combined with the loop construct. + bool applyToParallel = true, applyToTeams = true; + + auto hasParallel = + compInfo.findDirective(llvm::omp::Directive::OMPD_parallel); + if (hasParallel) { + auto exclusions = llvm::concat( + getWorksharingLoop(), llvm::ArrayRef{ + llvm::omp::Directive::OMPD_loop, + llvm::omp::Directive::OMPD_sections, + llvm::omp::Directive::OMPD_taskloop, + }); + auto present = [&](llvm::omp::Directive id) { + return compInfo.findDirective(id) != nullptr; + }; + + if (llvm::any_of(exclusions, present)) + applyToParallel = false; + } + + auto hasTeams = compInfo.findDirective(llvm::omp::Directive::OMPD_teams); + if (hasTeams) { + // The only exclusion is OMPD_loop. + if (compInfo.findDirective(llvm::omp::Directive::OMPD_loop)) + applyToTeams = false; + } + + // S For the parallel and teams constructs above, the effect of the + // S reduction clause instead is as if each list item or, for any list + // S item that is an array item, its corresponding base array or base + // S pointer appears in a shared clause for the construct. + for (auto dir : {hasParallel, hasTeams}) { + if (dir == nullptr) + continue; + // TODO apply SHARED(objects) to *dir. + } + + // TODO: Apply the following. + // S If the task reduction-modifier is specified, the effect is as if + // S it only modifies the behavior of the reduction clause on the innermost + // S leaf construct that accepts the modifier (see Section 5.5.8). If the + // S inscan reduction-modifier is specified, the effect is as if it modifies + // S the behavior of the reduction clause on all constructs of the combined + // S construct to which the clause is applied and that accept the modifier. + + bool applied = + applyIf(clauseId, clauseNode, compInfo, [&](DirectiveInfo &dir) { + if (!applyToParallel && &dir == hasParallel) + return false; + if (!applyToTeams && &dir == hasTeams) + return false; + return true; + }); + + // TODO: Apply the following. + // S If a list item in a reduction clause on a combined target construct + // S does not have the same base variable or base pointer as a list item + // S in a map clause on the construct, then the effect is as if the list + // S item in the reduction clause appears as a list item in a map clause + // S with a map-type of tofrom. + + return applied; +} + +// IF +static bool applyClause(const Fortran::parser::OmpClause::If &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpIfClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + using namespace Fortran::parser; + auto &modifier = + std::get>(contents.t); + + if (modifier) { + llvm::omp::Directive dirId = llvm::omp::Directive::OMPD_unknown; + + switch (*modifier) { + case OmpIfClause::DirectiveNameModifier::Parallel: + dirId = llvm::omp::Directive::OMPD_parallel; + break; + case OmpIfClause::DirectiveNameModifier::Simd: + dirId = llvm::omp::Directive::OMPD_simd; + break; + case OmpIfClause::DirectiveNameModifier::Target: + dirId = llvm::omp::Directive::OMPD_target; + break; + case OmpIfClause::DirectiveNameModifier::Task: + dirId = llvm::omp::Directive::OMPD_task; + break; + case OmpIfClause::DirectiveNameModifier::Taskloop: + dirId = llvm::omp::Directive::OMPD_taskloop; + break; + case OmpIfClause::DirectiveNameModifier::Teams: + dirId = llvm::omp::Directive::OMPD_teams; + break; + + case OmpIfClause::DirectiveNameModifier::TargetData: + case OmpIfClause::DirectiveNameModifier::TargetEnterData: + case OmpIfClause::DirectiveNameModifier::TargetExitData: + case OmpIfClause::DirectiveNameModifier::TargetUpdate: + default: + llvm::errs() << "Invalid modifier in IF clause\n"; + return false; + } + + if (auto *hasDir = compInfo.findDirective(dirId)) { + hasDir->clauses.push_back(clauseNode); + return true; + } + llvm::errs() << "Directive from modifier not found\n"; + return false; + } + + if (applyToAll(clauseId, clauseNode, compInfo)) + return true; + + llvm::errs() << "Cannot apply IF\n"; + return false; +} + +// LINEAR +static bool applyClause(const Fortran::parser::OmpClause::Linear &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpLinearClause &contents, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + // S The effect of the linear clause is as if it is applied to the innermost + // S leaf construct. + if (applyToInnermost(clauseId, clauseNode, compInfo)) { + llvm::errs() << "Cannot apply LINEAR\n"; + return false; + } + + // The rest is about SIMD. + if (!compInfo.findDirective(llvm::omp::OMPD_simd)) + return true; + + const std::list &names = std::visit( + [](auto &&s) { + // Both alternatives have member "names". + return s.names; + }, + contents.u); + Fortran::semantics::Symbol *iterVarSym = + getIterationVariableSymbol(compInfo.eval); + + // S Additionally, if the list item is not the iteration variable of a + // S simd or worksharing-loop SIMD construct, the effect on the outer leaf + // S constructs is as if the list item was specified in firstprivate and + // S lastprivate clauses on the combined or composite construct, [...] + // + // S If a list item of the linear clause is the iteration variable of a + // S simd or worksharing-loop SIMD construct and it is not declared in + // S the construct, the effect on the outer leaf constructs is as if the + // S list item was specified in a lastprivate clause on the combined or + // S composite construct [...] + + // It's not clear how an object can be listed in a clause AND be the + // iteration variable of a construct in which is it declared. If an + // object is declared in the construct, then the declaration is located + // after the clause listing it. + + // Lists of objects that will be used to construct FIRSTPRIVATE and + // LASTPRIVATE clauses. + std::list first, last; + + auto makeObjectFromName = [](Fortran::parser::Name name) { + // Pass "name" by copy. + static_assert(!std::is_lvalue_reference_v); + + auto source = name.source; + Fortran::parser::Designator designator( + Fortran::parser::DataRef(std::move(name))); + designator.source = source; + + return Fortran::parser::OmpObject(std::move(designator)); + }; + + for (const Fortran::parser::Name &name : names) { + last.emplace_back(makeObjectFromName(name)); + if (getOmpObjectSymbol(name) != iterVarSym) + first.emplace_back(makeObjectFromName(name)); + } + + auto addClause = [&](auto &&specific) { + // Take a specific clause, i.e. Fortran::parse::OmpClause::Xyz, + // wrap it into a general OmpClause, and add it to compInfo. + auto general = + std::make_unique(std::move(specific)); + compInfo.storage.emplace_back(std::move(general)); + compInfo.addClause(compInfo.storage.back().get()); + }; + + if (!first.empty()) { + Fortran::parser::OmpObjectList objList(std::move(first)); + addClause(Fortran::parser::OmpClause::Firstprivate(std::move(objList))); + } + if (!last.empty()) { + Fortran::parser::OmpObjectList objList(std::move(first)); + addClause(Fortran::parser::OmpClause::Lastprivate(std::move(objList))); + } + + return true; +} + +// NOWAIT +static bool applyClause(const Fortran::parser::OmpClause::Nowait &clause, + llvm::omp::Clause clauseId, + const Fortran::parser::OmpClause *clauseNode, + CompositeInfo &compInfo) { + if (applyToOutermost(clauseId, clauseNode, compInfo)) + return true; + llvm::errs() << "Cannot apply NOWAIT\n"; + return false; +} + +bool CompositeInfo::split() { + bool success = true; + + // First we need to apply LINEAR, because it can generate additional + // FIRSTPRIVATE and LASTPRIVATE clauses that apply to the compound/ + // composite construct. + // Collect them separately, because they may modify the clause list. + llvm::SmallVector linears; + for (const auto *clause : clauses) { + if (getClauseId(*clause) == llvm::omp::Clause::OMPC_linear) + linears.push_back(clause); + } + for (const auto *clause : linears) { + success = success && + detail::visit_clause( + *clause, [&](auto &&...args) { return applyClause(args...); }, + *this); + } + + // ALLOCATE clauses need to be applied last since they need to see + // which directives have data-privatizing clauses. + auto skip = [&](auto *clause) { + switch (getClauseId(*clause)) { + case llvm::omp::Clause::OMPC_allocate: + case llvm::omp::Clause::OMPC_linear: + return true; + default: + return false; + } + }; + + // Apply (almost) all clauses. + for (const auto *clause : clauses) { + if (skip(clause)) + continue; + success = success && + detail::visit_clause( + *clause, [&](auto &&...args) { return applyClause(args...); }, + *this); + } + + // Apply ALLOCATE. + for (const auto *clause : clauses) { + if (getClauseId(*clause) != llvm::omp::Clause::OMPC_allocate) + continue; + success = success && + detail::visit_clause( + *clause, [&](auto &&...args) { return applyClause(args...); }, + *this); + } + + return success; +} + +static void +splitCompositeConstruct(const mlir::ModuleOp &modOp, + Fortran::lower::pft::Evaluation &eval, + llvm::omp::Directive compDir, + const std::list &clauses) { + llvm::errs() << "composite name:" + << llvm::omp::getOpenMPDirectiveName(compDir) << '\n'; + llvm::errs() << "clause list:"; + for (auto &clause : clauses) + llvm::errs() << ' ' << llvm::omp::getOpenMPClauseName(getClauseId(clause)); + llvm::errs() << '\n'; + + CompositeInfo compInfo(modOp, eval, compDir, clauses); + llvm::errs() << "compInfo.1\n" << compInfo << '\n'; + + bool success = compInfo.split(); + + // Dump + llvm::errs() << "success:" << success << '\n'; + llvm::errs() << "compInfo.2\n" << compInfo << '\n'; +} + //===----------------------------------------------------------------------===// // DataSharingProcessor //===----------------------------------------------------------------------===// @@ -3347,6 +4425,10 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { const auto &beginLoopDirective = std::get(loopConstruct.t); + // Test call + splitCompositeConstruct(converter.getFirOpBuilder().getModule(), eval, + std::get<0>(beginLoopDirective.t).v, + std::get<1>(beginLoopDirective.t).v); const auto &loopOpClauseList = std::get(beginLoopDirective.t); mlir::Location currentLocation = diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp index 98d9258e023e5..b3a8c9fb80f9c 100644 --- a/flang/tools/bbc/bbc.cpp +++ b/flang/tools/bbc/bbc.cpp @@ -342,7 +342,6 @@ static mlir::LogicalResult convertFortranSourceToMLIR( semanticsContext.targetCharacteristics(), parsing.allCooked(), targetTriple, kindMap, loweringOptions, {}, semanticsContext.languageFeatures(), targetMachine); - burnside.lower(parseTree, semanticsContext); mlir::ModuleOp mlirModule = burnside.getModule(); if (enableOpenMP) { if (enableOpenMPGPU && !enableOpenMPDevice) { @@ -358,6 +357,7 @@ static mlir::LogicalResult convertFortranSourceToMLIR( setOffloadModuleInterfaceAttributes(mlirModule, offloadModuleOpts); setOpenMPVersionAttribute(mlirModule, setOpenMPVersion); } + burnside.lower(parseTree, semanticsContext); std::error_code ec; std::string outputName = outputFilename; if (!outputName.size()) diff --git a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td index 31578710365b2..139c794cd4985 100644 --- a/llvm/include/llvm/Frontend/Directive/DirectiveBase.td +++ b/llvm/include/llvm/Frontend/Directive/DirectiveBase.td @@ -152,6 +152,9 @@ class Directive { // List of clauses that are required. list requiredClauses = []; + // List of names of leaf constituent directives. + list leafs = []; + // Set directive used by default when unknown. bit isDefault = false; } diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index 2388abac81ceb..d51f471466669 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -771,6 +771,7 @@ def OMP_TargetParallel : Directive<"target parallel"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "parallel"]; } def OMP_TargetParallelFor : Directive<"target parallel for"> { let allowedClauses = [ @@ -803,6 +804,7 @@ def OMP_TargetParallelFor : Directive<"target parallel for"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "parallel", "for"]; } def OMP_TargetParallelDo : Directive<"target parallel do"> { let allowedClauses = [ @@ -833,6 +835,7 @@ def OMP_TargetParallelDo : Directive<"target parallel do"> { VersionedClause, VersionedClause ]; + let leafs = ["target", "parallel", "do"]; } def OMP_TargetUpdate : Directive<"target update"> { let allowedClauses = [ @@ -866,6 +869,7 @@ def OMP_ParallelFor : Directive<"parallel for"> { VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "for"]; } def OMP_ParallelDo : Directive<"parallel do"> { let allowedClauses = [ @@ -887,6 +891,7 @@ def OMP_ParallelDo : Directive<"parallel do"> { VersionedClause, VersionedClause ]; + let leafs = ["parallel", "do"]; } def OMP_ParallelForSimd : Directive<"parallel for simd"> { let allowedClauses = [ @@ -912,6 +917,7 @@ def OMP_ParallelForSimd : Directive<"parallel for simd"> { VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "for", "simd"]; } def OMP_ParallelDoSimd : Directive<"parallel do simd"> { let allowedClauses = [ @@ -938,6 +944,7 @@ def OMP_ParallelDoSimd : Directive<"parallel do simd"> { VersionedClause, VersionedClause ]; + let leafs = ["parallel", "do", "simd"]; } def OMP_ParallelMaster : Directive<"parallel master"> { let allowedClauses = [ @@ -953,6 +960,7 @@ def OMP_ParallelMaster : Directive<"parallel master"> { VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "master"]; } def OMP_ParallelMasked : Directive<"parallel masked"> { let allowedClauses = [ @@ -969,6 +977,7 @@ def OMP_ParallelMasked : Directive<"parallel masked"> { VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "masked"]; } def OMP_ParallelSections : Directive<"parallel sections"> { let allowedClauses = [ @@ -987,6 +996,7 @@ def OMP_ParallelSections : Directive<"parallel sections"> { VersionedClause, VersionedClause ]; + let leafs = ["parallel", "sections"]; } def OMP_ForSimd : Directive<"for simd"> { let allowedClauses = [ @@ -1007,6 +1017,7 @@ def OMP_ForSimd : Directive<"for simd"> { VersionedClause, VersionedClause ]; + let leafs = ["for", "simd"]; } def OMP_DoSimd : Directive<"do simd"> { let allowedClauses = [ @@ -1027,6 +1038,7 @@ def OMP_DoSimd : Directive<"do simd"> { VersionedClause, VersionedClause ]; + let leafs = ["do", "simd"]; } def OMP_CancellationPoint : Directive<"cancellation point"> {} def OMP_DeclareReduction : Directive<"declare reduction"> {} @@ -1104,6 +1116,7 @@ def OMP_TaskLoopSimd : Directive<"taskloop simd"> { VersionedClause, VersionedClause ]; + let leafs = ["taskloop", "simd"]; } def OMP_Distribute : Directive<"distribute"> { let allowedClauses = [ @@ -1156,6 +1169,7 @@ def OMP_DistributeParallelFor : Directive<"distribute parallel for"> { VersionedClause, VersionedClause, ]; + let leafs = ["distribute", "parallel", "for"]; } def OMP_DistributeParallelDo : Directive<"distribute parallel do"> { let allowedClauses = [ @@ -1179,6 +1193,7 @@ def OMP_DistributeParallelDo : Directive<"distribute parallel do"> { VersionedClause, VersionedClause ]; + let leafs = ["distribute", "parallel", "do"]; } def OMP_DistributeParallelForSimd : Directive<"distribute parallel for simd"> { let allowedClauses = [ @@ -1204,6 +1219,7 @@ def OMP_DistributeParallelForSimd : Directive<"distribute parallel for simd"> { VersionedClause, VersionedClause, ]; + let leafs = ["distribute", "parallel", "for", "simd"]; } def OMP_DistributeParallelDoSimd : Directive<"distribute parallel do simd"> { let allowedClauses = [ @@ -1228,6 +1244,7 @@ def OMP_DistributeParallelDoSimd : Directive<"distribute parallel do simd"> { VersionedClause, VersionedClause ]; + let leafs = ["distribute", "parallel", "do", "simd"]; } def OMP_DistributeSimd : Directive<"distribute simd"> { let allowedClauses = [ @@ -1254,6 +1271,7 @@ def OMP_DistributeSimd : Directive<"distribute simd"> { VersionedClause, VersionedClause ]; + let leafs = ["distribute", "simd"]; } def OMP_TargetParallelForSimd : Directive<"target parallel for simd"> { @@ -1291,6 +1309,7 @@ def OMP_TargetParallelForSimd : Directive<"target parallel for simd"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "parallel", "for", "simd"]; } def OMP_TargetParallelDoSimd : Directive<"target parallel do simd"> { let allowedClauses = [ @@ -1322,6 +1341,7 @@ def OMP_TargetParallelDoSimd : Directive<"target parallel do simd"> { VersionedClause, VersionedClause ]; + let leafs = ["target", "parallel", "do", "simd"]; } def OMP_TargetSimd : Directive<"target simd"> { let allowedClauses = [ @@ -1356,6 +1376,7 @@ def OMP_TargetSimd : Directive<"target simd"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "simd"]; } def OMP_TeamsDistribute : Directive<"teams distribute"> { let allowedClauses = [ @@ -1375,6 +1396,7 @@ def OMP_TeamsDistribute : Directive<"teams distribute"> { let allowedOnceClauses = [ VersionedClause ]; + let leafs = ["teams", "distribute"]; } def OMP_TeamsDistributeSimd : Directive<"teams distribute simd"> { let allowedClauses = [ @@ -1400,6 +1422,7 @@ def OMP_TeamsDistributeSimd : Directive<"teams distribute simd"> { VersionedClause, VersionedClause ]; + let leafs = ["teams", "distribute", "simd"]; } def OMP_TeamsDistributeParallelForSimd : @@ -1428,6 +1451,7 @@ def OMP_TeamsDistributeParallelForSimd : VersionedClause, VersionedClause, ]; + let leafs = ["teams", "distribute", "parallel", "for", "simd"]; } def OMP_TeamsDistributeParallelDoSimd : Directive<"teams distribute parallel do simd"> { @@ -1456,6 +1480,7 @@ def OMP_TeamsDistributeParallelDoSimd : VersionedClause, VersionedClause ]; + let leafs = ["teams", "distribute", "parallel", "do", "simd"]; } def OMP_TeamsDistributeParallelFor : Directive<"teams distribute parallel for"> { @@ -1479,6 +1504,7 @@ def OMP_TeamsDistributeParallelFor : VersionedClause, VersionedClause, ]; + let leafs = ["teams", "distribute", "parallel", "for"]; } def OMP_TeamsDistributeParallelDo : Directive<"teams distribute parallel do"> { @@ -1505,6 +1531,7 @@ let allowedOnceClauses = [ VersionedClause, VersionedClause ]; + let leafs = ["teams", "distribute", "parallel", "do"]; } def OMP_TargetTeams : Directive<"target teams"> { let allowedClauses = [ @@ -1532,6 +1559,7 @@ def OMP_TargetTeams : Directive<"target teams"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "teams"]; } def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> { let allowedClauses = [ @@ -1560,6 +1588,7 @@ def OMP_TargetTeamsDistribute : Directive<"target teams distribute"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "teams", "distribute"]; } def OMP_TargetTeamsDistributeParallelFor : @@ -1594,6 +1623,7 @@ def OMP_TargetTeamsDistributeParallelFor : let allowedOnceClauses = [ VersionedClause, ]; + let leafs = ["target", "teams", "distribute", "parallel", "for"]; } def OMP_TargetTeamsDistributeParallelDo : Directive<"target teams distribute parallel do"> { @@ -1628,6 +1658,7 @@ def OMP_TargetTeamsDistributeParallelDo : VersionedClause, VersionedClause ]; + let leafs = ["target", "teams", "distribute", "parallel", "do"]; } def OMP_TargetTeamsDistributeParallelForSimd : Directive<"target teams distribute parallel for simd"> { @@ -1666,6 +1697,7 @@ def OMP_TargetTeamsDistributeParallelForSimd : let allowedOnceClauses = [ VersionedClause, ]; + let leafs = ["target", "teams", "distribute", "parallel", "for", "simd"]; } def OMP_TargetTeamsDistributeParallelDoSimd : Directive<"target teams distribute parallel do simd"> { @@ -1704,6 +1736,7 @@ def OMP_TargetTeamsDistributeParallelDoSimd : VersionedClause, VersionedClause ]; + let leafs = ["target", "teams", "distribute", "parallel", "do", "simd"]; } def OMP_TargetTeamsDistributeSimd : Directive<"target teams distribute simd"> { @@ -1738,6 +1771,7 @@ def OMP_TargetTeamsDistributeSimd : VersionedClause, VersionedClause ]; + let leafs = ["target", "teams", "distribute", "simd"]; } def OMP_Allocate : Directive<"allocate"> { let allowedOnceClauses = [ @@ -1779,6 +1813,7 @@ def OMP_MasterTaskloop : Directive<"master taskloop"> { VersionedClause, VersionedClause ]; + let leafs = ["master", "taskloop"]; } def OMP_MaskedTaskloop : Directive<"masked taskloop"> { let allowedClauses = [ @@ -1801,6 +1836,7 @@ def OMP_MaskedTaskloop : Directive<"masked taskloop"> { VersionedClause, VersionedClause ]; + let leafs = ["masked", "taskloop"]; } def OMP_ParallelMasterTaskloop : Directive<"parallel master taskloop"> { @@ -1826,6 +1862,7 @@ def OMP_ParallelMasterTaskloop : VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "master", "taskloop"]; } def OMP_ParallelMaskedTaskloop : Directive<"parallel masked taskloop"> { @@ -1852,6 +1889,7 @@ def OMP_ParallelMaskedTaskloop : VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "masked", "taskloop"]; } def OMP_MasterTaskloopSimd : Directive<"master taskloop simd"> { let allowedClauses = [ @@ -1879,6 +1917,7 @@ def OMP_MasterTaskloopSimd : Directive<"master taskloop simd"> { VersionedClause, VersionedClause ]; + let leafs = ["master", "taskloop", "simd"]; } def OMP_MaskedTaskloopSimd : Directive<"masked taskloop simd"> { let allowedClauses = [ @@ -1907,6 +1946,7 @@ def OMP_MaskedTaskloopSimd : Directive<"masked taskloop simd"> { VersionedClause, VersionedClause ]; + let leafs = ["masked", "taskloop", "simd"]; } def OMP_ParallelMasterTaskloopSimd : Directive<"parallel master taskloop simd"> { @@ -1938,6 +1978,7 @@ def OMP_ParallelMasterTaskloopSimd : VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "master", "taskloop", "simd"]; } def OMP_ParallelMaskedTaskloopSimd : Directive<"parallel masked taskloop simd"> { @@ -1970,6 +2011,7 @@ def OMP_ParallelMaskedTaskloopSimd : VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "masked", "taskloop", "simd"]; } def OMP_Depobj : Directive<"depobj"> { let allowedClauses = [ @@ -2016,6 +2058,7 @@ def OMP_ParallelWorkshare : Directive<"parallel workshare"> { VersionedClause, VersionedClause ]; + let leafs = ["parallel", "workshare"]; } def OMP_Workshare : Directive<"workshare"> {} def OMP_EndDo : Directive<"end do"> { @@ -2102,6 +2145,7 @@ def OMP_teams_loop : Directive<"teams loop"> { VersionedClause, VersionedClause, ]; + let leafs = ["teams", "loop"]; } def OMP_target_teams_loop : Directive<"target teams loop"> { let allowedClauses = [ @@ -2131,6 +2175,7 @@ def OMP_target_teams_loop : Directive<"target teams loop"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "teams", "loop"]; } def OMP_parallel_loop : Directive<"parallel loop"> { let allowedClauses = [ @@ -2152,6 +2197,7 @@ def OMP_parallel_loop : Directive<"parallel loop"> { VersionedClause, VersionedClause, ]; + let leafs = ["parallel", "loop"]; } def OMP_target_parallel_loop : Directive<"target parallel loop"> { let allowedClauses = [ @@ -2183,11 +2229,13 @@ def OMP_target_parallel_loop : Directive<"target parallel loop"> { VersionedClause, VersionedClause, ]; + let leafs = ["target", "parallel", "loop"]; } def OMP_Metadirective : Directive<"metadirective"> { let allowedClauses = [VersionedClause]; let allowedOnceClauses = [VersionedClause]; } + def OMP_Unknown : Directive<"unknown"> { let isDefault = true; } diff --git a/llvm/include/llvm/TableGen/DirectiveEmitter.h b/llvm/include/llvm/TableGen/DirectiveEmitter.h index c86018715a48a..88fef74e298bf 100644 --- a/llvm/include/llvm/TableGen/DirectiveEmitter.h +++ b/llvm/include/llvm/TableGen/DirectiveEmitter.h @@ -121,6 +121,10 @@ class Directive : public BaseRecord { std::vector getRequiredClauses() const { return Def->getValueAsListOfDefs("requiredClauses"); } + + std::vector getLeafConstructNames() const { + return Def->getValueAsListOfStrings("leafs"); + } }; // Wrapper class that contains Clause's information defined in DirectiveBase.td diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp index b6aee665f8ee0..232b1a4e6f7b5 100644 --- a/llvm/utils/TableGen/DirectiveEmitter.cpp +++ b/llvm/utils/TableGen/DirectiveEmitter.cpp @@ -186,6 +186,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { if (DirLang.hasEnableBitmaskEnumInNamespace()) OS << "\n#include \"llvm/ADT/BitmaskEnum.h\"\n"; + OS << "#include \"llvm/ADT/SmallVector.h\"\n"; OS << "\n"; OS << "namespace llvm {\n"; @@ -231,6 +232,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { OS << "bool isAllowedClauseForDirective(Directive D, " << "Clause C, unsigned Version);\n"; OS << "\n"; + OS << "const llvm::SmallVector &getLeafConstructs(Directive D);\n"; if (EnumHelperFuncs.length() > 0) { OS << EnumHelperFuncs; OS << "\n"; @@ -435,6 +437,82 @@ static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang, OS << "}\n"; // End of function isAllowedClauseForDirective } +// Generate the getLeafConstructs function implementation. +static void GenerateGetLeafConstructs(const DirectiveLanguage &DirLang, + raw_ostream &OS) { + llvm::StringMap NameToRec; + for (Record *R : DirLang.getDirectives()) + NameToRec.insert(std::make_pair(BaseRecord(R).getName(), R)); + + auto getQualifiedName = [&](StringRef Formatted) -> std::string { + return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + + "::Directive::" + DirLang.getDirectivePrefix() + Formatted) + .str(); + }; + + // For each list of leafs, generate a static local object, then + // return a reference to that object for a given directive, e.g. + // + // static ListTy leafConstructs_A_B = { A, B }; + // static ListTy leafConstructs_C_D_E = { C, D, E }; + // switch (Dir) { + // case A_B: + // return leafConstructs_A_B; + // case C_D_E: + // return leafConstructs_C_D_E; + + // Map from a record that defines a directive to the name of the + // local object with the list of its leafs. + DenseMap ListNames; + + std::string DirectiveTypeName = + std::string("llvm::") + DirLang.getCppNamespace().str() + "::Directive"; + std::string DirectiveListTypeName = + std::string("llvm::SmallVector<") + DirectiveTypeName + ">"; + + // const Container &llvm::::GetLeafConstructs(llvm::::Directive Dir) + OS << "const " << DirectiveListTypeName + << " &llvm::" << DirLang.getCppNamespace() << "::getLeafConstructs(" + << DirectiveTypeName << " Dir) "; + OS << "{\n"; + + // Generate the locals. + for (auto &[_, R] : NameToRec) { + Directive Dir{R}; + + std::vector LeafNames = Dir.getLeafConstructNames(); + if (LeafNames.empty()) + continue; + + std::string ListName = "leafConstructs_" + Dir.getFormattedName(); + OS << " static " << DirectiveListTypeName << ' ' << ListName << " {\n"; + for (StringRef L : LeafNames) { + Directive LeafDir{NameToRec.at(L)}; + OS << " " << getQualifiedName(LeafDir.getFormattedName()) << ",\n"; + } + OS << " };\n"; + ListNames.insert(std::make_pair(R, std::move(ListName))); + } + + OS << " static " << DirectiveListTypeName << " nothing {};\n"; + + OS << '\n'; + OS << " switch (Dir) {\n"; + for (auto &[_, R] : NameToRec) { + auto F = ListNames.find(R); + if (F == ListNames.end()) + continue; + + Directive Dir{R}; + OS << " case " << getQualifiedName(Dir.getFormattedName()) << ":\n"; + OS << " return " << F->second << ";\n"; + } + OS << " default:\n"; + OS << " return nothing;\n"; + OS << " } // switch (Dir)\n"; + OS << "}\n"; +} + // Generate a simple enum set with the give clauses. static void GenerateClauseSet(const std::vector &Clauses, raw_ostream &OS, StringRef ClauseSetPrefix, @@ -876,6 +954,9 @@ void EmitDirectivesBasicImpl(const DirectiveLanguage &DirLang, // isAllowedClauseForDirective(Directive D, Clause C, unsigned Version) GenerateIsAllowedClause(DirLang, OS); + + // getLeafConstructs(Directive D) + GenerateGetLeafConstructs(DirLang, OS); } // Generate the implemenation section for the enumeration in the directive