diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h index e164f63aa189b..6dac64ecbaf08 100644 --- a/flang/include/flang/Parser/openmp-utils.h +++ b/flang/include/flang/Parser/openmp-utils.h @@ -237,6 +237,80 @@ struct OmpAllocateInfo { OmpAllocateInfo SplitOmpAllocate(const OmpAllocateDirective &x); +namespace detail { +template struct ConstIf { + using type = std::conditional_t, T>; +}; + +template +using ConstIfT = typename ConstIf::type; +} // namespace detail + +template struct LoopRange { + using QualBlock = detail::ConstIfT; + using QualReference = decltype(std::declval().front()); + using QualPointer = std::remove_reference_t *; + + LoopRange(QualBlock &x) { Initialize(x); } + LoopRange(QualReference x); + + LoopRange(detail::ConstIfT &x) + : LoopRange(std::get(x.t)) {} + LoopRange(detail::ConstIfT &x) + : LoopRange(std::get(x.t)) {} + + size_t size() const { return items.size(); } + bool empty() const { return items.size() == 0; } + + struct iterator; + + iterator begin(); + iterator end(); + +private: + void Initialize(QualBlock &body); + + std::vector items; +}; + +template LoopRange(T &x) -> LoopRange>; + +template struct LoopRange::iterator { + QualReference operator*() { return **at; } + + bool operator==(const iterator &other) const { return at == other.at; } + bool operator!=(const iterator &other) const { return at != other.at; } + + iterator &operator++() { + ++at; + return *this; + } + iterator &operator--() { + --at; + return *this; + } + iterator &operator++(int); + iterator &operator--(int); + +private: + friend struct LoopRange; + typename decltype(LoopRange::items)::iterator at; +}; + +template inline auto LoopRange::begin() -> iterator { + iterator x; + x.at = items.begin(); + return x; +} + +template inline auto LoopRange::end() -> iterator { + iterator x; + x.at = items.end(); + return x; +} + +using ConstLoopRange = LoopRange; + } // namespace Fortran::parser::omp #endif // FORTRAN_PARSER_OPENMP_UTILS_H diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp index 4c38917e87d29..1593b19d6b372 100644 --- a/flang/lib/Parser/openmp-utils.cpp +++ b/flang/lib/Parser/openmp-utils.cpp @@ -205,4 +205,39 @@ OmpAllocateInfo SplitOmpAllocate(const OmpAllocateDirective &x) { return info; } +template LoopRange::LoopRange(QualReference x) { + if (auto *doLoop{Unwrap(x)}) { + Initialize(std::get(doLoop->t)); + } else if (auto *omp{Unwrap(x)}) { + Initialize(std::get(omp->t)); + } +} + +template void LoopRange::Initialize(QualBlock &body) { + using QualIterator = decltype(std::declval().begin()); + auto makeRange{[](auto &container) { + return llvm::make_range(container.begin(), container.end()); + }}; + + std::vector> nest{makeRange(body)}; + do { + auto at{nest.back().begin()}; + auto end{nest.back().end()}; + nest.pop_back(); + while (at != end) { + if (auto *block{Unwrap(*at)}) { + nest.push_back(llvm::make_range(std::next(at), end)); + nest.push_back(makeRange(std::get(block->t))); + break; + } else if (Unwrap(*at) || Unwrap(*at)) { + items.push_back(&*at); + } + ++at; + } + } while (!nest.empty()); +} + +template struct LoopRange; +template struct LoopRange; + } // namespace Fortran::parser::omp