Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][OpenMP] Main splitting functionality dev-complete #82003

Conversation

kparzysz
Copy link
Contributor

@kparzysz kparzysz commented Feb 16, 2024

This is still just the splitting part, it's not applied yet.

[flang][OpenMP] TableGen support for getting leaf constructs

Implement getLeafConstructs(D), which for a composite directive D will return the list of the constituent leaf directives.

[flang][OpenMP] Set OpenMP attributes in MLIR module in bbc before lowering

Right now attributes like OpenMP version or target attributes for offload are set after lowering in bbc. The flang frontend sets them before lowering, making them available in the lowering process.

This change sets them before lowering in bbc as well.

getOpenMPVersion

The new set of classes representing OpenMP classes mimics the
contents of parser::OmpClause, but differs in a few aspects:
- it can be easily created, copied, etc.
- is based on semantics::SomeExpr instead of parser objects.

The class `OmpObject` is represented by `omp::Object`, which contains
the symbol associated with the object, and semantics::MaybeExpr
representing the designator for the symbol reference.

This patch only introduces the new classes, they are not yet used
anywhere.
Temporarily rename old clause list to `clauses2`, old clause iterator
to `ClauseIterator2`.
Change `findUniqueClause` to iterate over `omp::Clause` objects,
modify all handlers to operate on 'omp::clause::xyz` equivalents.
…essor

Rename `findRepeatableClause` to `findRepeatableClause2`, and make the
new `findRepeatableClause` operate on new `omp::Clause` objects.

Leave `Map` unchanged, because it will require more changes for it to
work.
The related functions are `gatherDataOperandAddrAndBounds` and
`genBoundsOps`. The former is used in OpenACC as well, and it was
updated to pass evaluate::Expr instead of parser objects.

The difference in the test case comes from unfolded conversions
of index expressions, which are explicitly of type integer(kind=8).

Delete now unused `findRepeatableClause2` and `findClause2`.

Add `AsGenericExpr` that takes std::optional. It already returns optional
Expr. Making it accept an optional Expr as input would reduce the number
of necessary checks when handling frequent optional values in evaluator.
Remove `ClauseIterator2` and `clauses2` from ClauseProcessor.
[flang][OpenMP] TableGen support for getting leaf constructs

Implement getLeafConstructs(D), which for a composite directive D
will return the list of the constituent leaf directives.

[flang][OpenMP] Set OpenMP attributes in MLIR module in bbc before lowering

Right now attributes like OpenMP version or target attributes for offload
are set after lowering in bbc. The flang frontend sets them before lowering,
making them available in the lowering process.

This change sets them before lowering in bbc as well.

getOpenMPVersion
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp clang:openmp OpenMP related changes to Clang labels Feb 16, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 16, 2024

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-openmp

Author: Krzysztof Parzyszek (kparzysz)

Changes

[flang][OpenMP] TableGen support for getting leaf constructs

Implement getLeafConstructs(D), which for a composite directive D will return the list of the constituent leaf directives.

[flang][OpenMP] Set OpenMP attributes in MLIR module in bbc before lowering

Right now attributes like OpenMP version or target attributes for offload are set after lowering in bbc. The flang frontend sets them before lowering, making them available in the lowering process.

This change sets them before lowering in bbc as well.

getOpenMPVersion


Patch is 63.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82003.diff

6 Files Affected:

  • (modified) flang/lib/Lower/OpenMP.cpp (+1034-10)
  • (modified) flang/tools/bbc/bbc.cpp (+1-1)
  • (modified) llvm/include/llvm/Frontend/Directive/DirectiveBase.td (+4)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+54-6)
  • (modified) llvm/include/llvm/TableGen/DirectiveEmitter.h (+4)
  • (modified) llvm/utils/TableGen/DirectiveEmitter.cpp (+77)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index e45ab842b15556..ed6a0063848b18 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Support/CommandLine.h"
@@ -48,6 +49,29 @@ using DeclareTargetCapturePair =
 // Common helper functions
 //===----------------------------------------------------------------------===//
 
+static llvm::ArrayRef<llvm::omp::Directive> getWorksharing() {
+  static llvm::omp::Directive worksharing[] = {
+      llvm::omp::Directive::OMPD_do,     llvm::omp::Directive::OMPD_for,
+      llvm::omp::Directive::OMPD_scope,  llvm::omp::Directive::OMPD_sections,
+      llvm::omp::Directive::OMPD_single, llvm::omp::Directive::OMPD_workshare,
+  };
+  return worksharing;
+}
+
+static llvm::ArrayRef<llvm::omp::Directive> getWorksharingLoop() {
+  static llvm::omp::Directive worksharingLoop[] = {
+      llvm::omp::Directive::OMPD_do,
+      llvm::omp::Directive::OMPD_for,
+  };
+  return worksharingLoop;
+}
+
+static uint32_t getOpenMPVersion(const mlir::ModuleOp &mod) {
+  if (mlir::Attribute verAttr = mod->getAttr("omp.version"))
+    return llvm::cast<mlir::omp::VersionAttr>(verAttr).getVersion();
+  llvm_unreachable("Exoecting OpenMP version attribute in module");
+}
+
 static Fortran::semantics::Symbol *
 getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
   Fortran::semantics::Symbol *sym = nullptr;
@@ -166,6 +190,15 @@ struct SymDsgExtractor {
     return t;
   }
 
+  static semantics::Symbol *symbol_addr(const evaluate::SymbolRef &ref) {
+    // Symbols cannot be created after semantic checks, so all symbol
+    // pointers that are non-null must point to one of those pre-existing
+    // objects. Throughout the code, symbols are often pointed to by
+    // non-const pointers, so there is no harm in casting the constness
+    // away.
+    return const_cast<semantics::Symbol *>(&ref.get());
+  }
+
   template <typename T> //
   static SymDsg visit(T &&) {
     // Use this to see missing overloads:
@@ -175,19 +208,12 @@ struct SymDsgExtractor {
 
   template <typename T> //
   static SymDsg visit(const evaluate::Designator<T> &e) {
-    // Symbols cannot be created after semantic checks, so all symbol
-    // pointers that are non-null must point to one of those pre-existing
-    // objects. Throughout the code, symbols are often pointed to by
-    // non-const pointers, so there is no harm in casting the constness
-    // away.
-    return std::make_tuple(const_cast<semantics::Symbol *>(e.GetLastSymbol()),
+    return std::make_tuple(symbol_addr(*e.GetLastSymbol()),
                            evaluate::AsGenericExpr(AsRvalueRef(e)));
   }
 
   static SymDsg visit(const evaluate::ProcedureDesignator &e) {
-    // See comment above regarding const_cast.
-    return std::make_tuple(const_cast<semantics::Symbol *>(e.GetSymbol()),
-                           std::nullopt);
+    return std::make_tuple(symbol_addr(*e.GetSymbol()), std::nullopt);
   }
 
   template <typename T> //
@@ -313,6 +339,42 @@ std::optional<U> maybeApply(F &&func, const std::optional<T> &inp) {
   return std::move(func(*inp));
 }
 
+std::optional<Object>
+getBaseObject(const Object &object,
+              Fortran::semantics::SemanticsContext &semaCtx) {
+  // If it's just the symbol, then there is no base.
+  if (!object.dsg)
+    return std::nullopt;
+
+  auto maybeRef = evaluate::ExtractDataRef(*object.dsg);
+  if (!maybeRef)
+    return std::nullopt;
+
+  evaluate::DataRef ref = *maybeRef;
+
+  if (std::get_if<evaluate::SymbolRef>(&ref.u)) {
+    return std::nullopt;
+  } else if (auto *comp = std::get_if<evaluate::Component>(&ref.u)) {
+    const evaluate::DataRef &base = comp->base();
+    return Object{SymDsgExtractor::symbol_addr(base.GetLastSymbol()),
+                  evaluate::AsGenericExpr(SymDsgExtractor::AsRvalueRef(base))};
+  } else if (auto *arr = std::get_if<evaluate::ArrayRef>(&ref.u)) {
+    const evaluate::NamedEntity &base = arr->base();
+    evaluate::ExpressionAnalyzer ea{semaCtx};
+    if (auto *comp = base.UnwrapComponent()) {
+      return Object{
+          SymDsgExtractor::symbol_addr(comp->symbol()),
+          ea.Designate(evaluate::DataRef{SymDsgExtractor::AsRvalueRef(*comp)})};
+    } else if (base.UnwrapSymbolRef()) {
+      return std::nullopt;
+    }
+  } else {
+    assert(std::holds_alternative<evaluate::CoarrayRef>(ref.u));
+    llvm_unreachable("Coarray reference not supported at the moment");
+  }
+  return std::nullopt;
+}
+
 namespace clause {
 #ifdef EMPTY_CLASS
 #undef EMPTY_CLASS
@@ -1220,11 +1282,18 @@ struct Clause {
   clause::UnionOfAllClauses u;
 };
 
+template <typename Specific>
+Clause makeClause(llvm::omp::Clause id, Specific &&specific,
+                  parser::CharBlock source = {}) {
+  return Clause{source, id, specific};
+}
+
 Clause makeClause(const Fortran::parser::OmpClause &cls,
                   semantics::SemanticsContext &semaCtx) {
   return std::visit(
       [&](auto &&s) {
-        return Clause{cls.source, getClauseId(cls), clause::make(s, semaCtx)};
+        return makeClause(getClauseId(cls), clause::make(s, semaCtx),
+                          cls.source);
       },
       cls.u);
 }
@@ -1263,6 +1332,957 @@ static void gatherFuncAndVarSyms(
     symbolAndClause.emplace_back(clause, *object.sym);
 }
 
+//===----------------------------------------------------------------------===//
+// Directive decomposition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct DirectiveInfo {
+  llvm::omp::Directive id = llvm::omp::Directive::OMPD_unknown;
+  llvm::SmallVector<const omp::Clause *> clauses;
+};
+
+struct CompositeInfo {
+  CompositeInfo(const mlir::ModuleOp &modOp,
+                Fortran::semantics::SemanticsContext &semaCtx,
+                Fortran::lower::pft::Evaluation &ev,
+                llvm::omp::Directive compDir,
+                const Fortran::parser::OmpClauseList &clauseList);
+  using ClauseSet = std::set<const omp::Clause *>;
+
+  bool split();
+  void addClauseSymbols(const omp::Clause &clause);
+
+  DirectiveInfo *findDirective(llvm::omp::Directive dirId) {
+    for (DirectiveInfo &dir : leafs) {
+      if (dir.id == dirId)
+        return &dir;
+    }
+    return nullptr;
+  }
+  ClauseSet *findClauses(const omp::Object &object) {
+    if (auto found = syms.find(object.sym); found != syms.end())
+      return &found->second;
+    return nullptr;
+  }
+
+  Fortran::semantics::SemanticsContext &semaCtx;
+  const mlir::ModuleOp &mod;
+  Fortran::lower::pft::Evaluation &eval;
+
+  llvm::SmallVector<DirectiveInfo> leafs; // Ordered outer to inner.
+  omp::List<omp::Clause> clauses;
+  llvm::DenseMap<const Fortran::semantics::Symbol *, ClauseSet> syms;
+  llvm::DenseSet<const Fortran::semantics::Symbol *> mapBases;
+  // Storage for newly created clauses. Beware of invalidating addresses.
+  std::list<omp::Clause> extras;
+
+private:
+  void addClauseSymsToMap(const omp::Object &object, const omp::Clause *);
+  void addClauseSymsToMap(const omp::ObjectList &objects, const omp::Clause *);
+  void addClauseSymsToMap(const omp::SomeExpr &item, const omp::Clause *);
+  void addClauseSymsToMap(const omp::clause::Map &item, const omp::Clause *);
+
+  template <typename T>
+  void addClauseSymsToMap(const std::optional<T> &item, const omp::Clause *);
+  template <typename T>
+  void addClauseSymsToMap(const omp::List<T> &item, const omp::Clause *);
+  template <typename... T, size_t... Is>
+  void addClauseSymsToMap(const std::tuple<T...> &item, const omp::Clause *,
+                          std::index_sequence<Is...> = {});
+  template <typename T,
+            std::enable_if_t<std::is_enum_v<llvm::remove_cvref_t<T>>, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+  template <
+      typename T,
+      std::enable_if_t<llvm::remove_cvref_t<T>::EmptyTrait::value, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+  template <
+      typename T,
+      std::enable_if_t<llvm::remove_cvref_t<T>::WrapperTrait::value, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+  template <
+      typename T,
+      std::enable_if_t<llvm::remove_cvref_t<T>::TupleTrait::value, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+  template <
+      typename T,
+      std::enable_if_t<llvm::remove_cvref_t<T>::UnionTrait::value, int> = 0>
+  void addClauseSymsToMap(T &&item, const omp::Clause *);
+
+  // Apply a clause to the only directive that allows it. If there are no
+  // directives that allow it, or if there is more that one, do not apply
+  // anything and return false, otherwise return true.
+  bool applyToUnique(const omp::Clause *node);
+
+  // Apply a clause to the first directive in given range that allows it.
+  // If such a directive does not exist, return false, otherwise return true.
+  template <typename Iterator>
+  bool applyToFirst(const omp::Clause *node, const mlir::ModuleOp &mod,
+                    llvm::iterator_range<Iterator> range);
+
+  // Apply a clause to the innermost directive that allows it. If such a
+  // directive does not exist, return false, otherwise return true.
+  bool applyToInnermost(const omp::Clause *node);
+
+  // Apply a clause to the outermost directive that allows it. If such a
+  // directive does not exist, return false, otherwise return true.
+  bool applyToOutermost(const omp::Clause *node);
+
+  template <typename Predicate>
+  bool applyIf(const omp::Clause *node, Predicate shouldApply);
+
+  bool applyToAll(const omp::Clause *node);
+
+  template <typename Clause>
+  bool applyClause(Clause &&clause, const omp::Clause *node);
+
+  bool applyClause(const omp::clause::Collapse &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Private &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Firstprivate &clause,
+                   const omp::Clause *);
+  bool applyClause(const omp::clause::Lastprivate &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Shared &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Default &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::ThreadLimit &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Order &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Allocate &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Reduction &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::If &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Linear &clause, const omp::Clause *);
+  bool applyClause(const omp::clause::Nowait &clause, const omp::Clause *);
+};
+} // namespace
+
+CompositeInfo::CompositeInfo(const mlir::ModuleOp &modOp,
+                             Fortran::semantics::SemanticsContext &semaCtx,
+                             Fortran::lower::pft::Evaluation &ev,
+                             llvm::omp::Directive compDir,
+                             const Fortran::parser::OmpClauseList &clauseList)
+    : semaCtx(semaCtx), mod(modOp), eval(ev),
+      clauses(omp::makeList(clauseList, semaCtx)) {
+  for (llvm::omp::Directive dir : llvm::omp::getLeafConstructs(compDir))
+    leafs.push_back(DirectiveInfo{dir});
+
+  for (const omp::Clause &clause : clauses)
+    addClauseSymsToMap(clause, &clause);
+}
+
+[[maybe_unused]] static llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os, const DirectiveInfo &dirInfo) {
+  os << llvm::omp::getOpenMPDirectiveName(dirInfo.id);
+  for (auto [index, clause] : llvm::enumerate(dirInfo.clauses)) {
+    os << (index == 0 ? '\t' : ' ');
+    os << llvm::omp::getOpenMPClauseName(clause->id);
+  }
+  return os;
+}
+
+[[maybe_unused]] static llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os, const CompositeInfo &compInfo) {
+  for (const auto &[index, dirInfo] : llvm::enumerate(compInfo.leafs))
+    os << "leaf[" << index << "]: " << dirInfo << '\n';
+
+  os << "syms:\n";
+  for (const auto &[sym, clauses] : compInfo.syms) {
+    os << *sym << " -> {";
+    for (const auto *clause : clauses)
+      os << ' ' << llvm::omp::getOpenMPClauseName(clause->id);
+    os << " }\n";
+  }
+  os << "mapBases: {";
+  for (const auto &sym : compInfo.mapBases)
+    os << ' ' << *sym;
+  os << " }\n";
+  return os;
+}
+
+namespace detail {
+template <typename Container, typename Predicate>
+typename std::remove_reference_t<Container>::iterator
+find_unique(Container &&container, Predicate &&pred) {
+  auto first = std::find_if(container.begin(), container.end(), pred);
+  if (first == container.end())
+    return first;
+  auto second = std::find_if(std::next(first), container.end(), pred);
+  if (second == container.end())
+    return first;
+  return container.end();
+}
+} // namespace detail
+
+static Fortran::semantics::Symbol *
+getIterationVariableSymbol(const Fortran::lower::pft::Evaluation &eval) {
+  return eval.visit(Fortran::common::visitors{
+      [&](const Fortran::parser::DoConstruct &doLoop) {
+        if (const auto &maybeCtrl = doLoop.GetLoopControl()) {
+          using LoopControl = Fortran::parser::LoopControl;
+          if (auto *bounds = std::get_if<LoopControl::Bounds>(&maybeCtrl->u)) {
+            static_assert(
+                std::is_same_v<decltype(bounds->name),
+                               Fortran::parser::Scalar<Fortran::parser::Name>>);
+            return bounds->name.thing.symbol;
+          }
+        }
+        return static_cast<Fortran::semantics::Symbol *>(nullptr);
+      },
+      [](auto &&) {
+        return static_cast<Fortran::semantics::Symbol *>(nullptr);
+      },
+  });
+}
+
+void CompositeInfo::addClauseSymsToMap(const omp::Object &object,
+                                       const omp::Clause *node) {
+  syms[object.sym].insert(node);
+}
+
+void CompositeInfo::addClauseSymsToMap(const omp::ObjectList &objects,
+                                       const omp::Clause *node) {
+  for (auto &object : objects)
+    syms[object.sym].insert(node);
+}
+
+void CompositeInfo::addClauseSymsToMap(const omp::SomeExpr &expr,
+                                       const omp::Clause *node) {
+  // Nothing to do for expressions.
+}
+
+void CompositeInfo::addClauseSymsToMap(const omp::clause::Map &item,
+                                       const omp::Clause *node) {
+  auto &objects = std::get<omp::ObjectList>(item.t);
+  addClauseSymsToMap(objects, node);
+  for (auto &object : objects) {
+    if (auto base = omp::getBaseObject(object, semaCtx))
+      mapBases.insert(base->sym);
+  }
+}
+
+template <typename T>
+void CompositeInfo::addClauseSymsToMap(const std::optional<T> &item,
+                                       const omp::Clause *node) {
+  if (item)
+    addClauseSymsToMap(*item, node);
+}
+
+template <typename T>
+void CompositeInfo::addClauseSymsToMap(const omp::List<T> &item,
+                                       const omp::Clause *node) {
+  for (auto &s : item)
+    addClauseSymsToMap(s, node);
+}
+
+template <typename... T, size_t... Is>
+void CompositeInfo::addClauseSymsToMap(const std::tuple<T...> &item,
+                                       const omp::Clause *node,
+                                       std::index_sequence<Is...>) {
+  (void)node; // Silence strange warning from GCC.
+  (addClauseSymsToMap(std::get<Is>(item), node), ...);
+}
+
+template <typename T,
+          std::enable_if_t<std::is_enum_v<llvm::remove_cvref_t<T>>, int> = 0>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  // Nothing to do for enums.
+}
+
+template <typename T,
+          std::enable_if_t<llvm::remove_cvref_t<T>::EmptyTrait::value, int> = 0>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  // Nothing to do for an empty class.
+}
+
+template <
+    typename T,
+    std::enable_if_t<llvm::remove_cvref_t<T>::WrapperTrait::value, int> = 0>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  addClauseSymsToMap(item.v, node);
+}
+
+template <typename T,
+          std::enable_if_t<llvm::remove_cvref_t<T>::TupleTrait::value, int> = 0>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  constexpr size_t tuple_size =
+      std::tuple_size_v<llvm::remove_cvref_t<decltype(item.t)>>;
+  addClauseSymsToMap(item.t, node, std::make_index_sequence<tuple_size>{});
+}
+
+template <typename T,
+          std::enable_if_t<llvm::remove_cvref_t<T>::UnionTrait::value, int> = 0>
+void CompositeInfo::addClauseSymsToMap(T &&item, const omp::Clause *node) {
+  std::visit([&](auto &&s) { addClauseSymsToMap(s, node); }, item.u);
+}
+
+#if 1
+// Apply a clause to the only directive that allows it. If there are no
+// directives that allow it, or if there is more that one, do not apply
+// anything and return false, otherwise return true.
+bool CompositeInfo::applyToUnique(const omp::Clause *node) {
+  uint32_t version = getOpenMPVersion(mod);
+  auto unique = detail::find_unique(leafs, [=](const auto &dirInfo) {
+    return llvm::omp::isAllowedClauseForDirective(dirInfo.id, node->id,
+                                                  version);
+  });
+
+  if (unique != leafs.end()) {
+    unique->clauses.push_back(node);
+    return true;
+  }
+  return false;
+}
+
+// Apply a clause to the first directive in given range that allows it.
+// If such a directive does not exist, return false, otherwise return true.
+template <typename Iterator>
+bool CompositeInfo::applyToFirst(const omp::Clause *node,
+                                 const mlir::ModuleOp &mod,
+                                 llvm::iterator_range<Iterator> range) {
+  if (range.empty())
+    return false;
+
+  uint32_t version = getOpenMPVersion(mod);
+  for (DirectiveInfo &dir : range) {
+    if (!llvm::omp::isAllowedClauseForDirective(dir.id, node->id, version))
+      continue;
+    dir.clauses.push_back(node);
+    return true;
+  }
+  return false;
+}
+
+// Apply a clause to the innermost directive that allows it. If such a
+// directive does not exist, return false, otherwise return true.
+bool CompositeInfo::applyToInnermost(const omp::Clause *node) {
+  return applyToFirst(node, mod, llvm::reverse(leafs));
+}
+
+// Apply a clause to the outermost directive that allows it. If such a
+// directive does not exist, return false, otherwise return true.
+bool CompositeInfo::applyToOutermost(const omp::Clause *node) {
+  return applyToFirst(node, mod, llvm::iterator_range(leafs));
+}
+
+template <typename Predicate>
+bool CompositeInfo::applyIf(const omp::Clause *node, Predicate shouldApply) {
+  bool applied = false;
+  uint32_t version = getOpenMPVersion(mod);
+  for (DirectiveInfo &dir : leafs) {
+    if (!llvm::omp::isAllowedClauseForDirective(dir.id, node->id, version))
+      continue;
+    if (!shouldApply(dir))
+      continue;
+    dir.clauses.push_back(node);
+    applied = true;
+  }
+
+  return applied;
+}
+
+bool CompositeInfo::applyToAll(const omp::Clause *node) {
+  return applyIf(node, [](auto) { return true; });
+}
+
+template <typename Clause>
+bool CompositeInfo::applyClause(Clause &&clause, const omp::Clause *node) {
+  // The default behavior is to find the unique directive to which the
+  // given clause may be applied. If there are no such directives, or
+  // if there are multiple ones, flag an error.
+  // From "OpenMP Application Programming Interface", Version 5.2:
+  // S Some clauses are permitted only on a single leaf construct of the
+  // S combined or composite construct, in which case the effect is as if
+  // S the clause is applied to that specific construct. (p339, 31-33)
+  if (applyToUnique(node))
+    return true;
+
+  return false;
+}
+
+// COLLAPSE
+bool CompositeInfo::applyClause(const omp::clause::Collapse &clause,
+                                const omp::Clause *node) {
+  // Apply COLLAPSE to the innermost directive. If it's not one that
+  // allows it flag an error.
+  if (!leafs.empty()) {
+    DirectiveInfo &last = leafs.back();
+    uint32_t version = getOpenMPVersion(mod);
+
+    if (llvm::omp::isAllowedClauseForDirectiv...
[truncated]

@kparzysz kparzysz marked this pull request as draft February 16, 2024 15:51
@kparzysz
Copy link
Contributor Author

kparzysz commented Feb 16, 2024

This is a follow-up to the previous draft, based on the clause-representation stack.

Co-authored-by: Valentin Clement (バレンタイン クレメン) <clementval@gmail.com>
@kparzysz kparzysz force-pushed the users/kparzysz/spr/b06-dsp branch 2 times, most recently from b4351a1 to 158901f Compare February 23, 2024 14:44
@kparzysz kparzysz deleted the branch llvm:users/kparzysz/spr/b06-dsp March 15, 2024 21:42
@kparzysz kparzysz closed this Mar 15, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:openmp OpenMP related changes to Clang flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants