diff --git a/flang/include/flang/Common/enum-set.h b/flang/include/flang/Common/enum-set.h index e048c66a393d0..ce1129474f8e7 100644 --- a/flang/include/flang/Common/enum-set.h +++ b/flang/include/flang/Common/enum-set.h @@ -217,6 +217,16 @@ template class EnumSet { private: bitsetType bitset_{}; }; + +namespace detail { +template struct IsEnumSetTest { + static constexpr bool value{false}; +}; +template struct IsEnumSetTest> { + static constexpr bool value{true}; +}; +} // namespace detail +template constexpr bool IsEnumSet{detail::IsEnumSetTest::value}; } // namespace Fortran::common template diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h index 32fcd4182bed7..ed6512be14496 100644 --- a/flang/include/flang/Parser/dump-parse-tree.h +++ b/flang/include/flang/Parser/dump-parse-tree.h @@ -14,10 +14,12 @@ #include "parse-tree.h" #include "tools.h" #include "unparse.h" +#include "flang/Common/enum-set.h" #include "flang/Common/idioms.h" #include "flang/Common/indirection.h" #include "flang/Support/Fortran.h" #include "llvm/Frontend/OpenMP/OMP.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Support/raw_ostream.h" #include #include @@ -35,6 +37,19 @@ class ParseTreeDumper { : out_(out), asFortran_{asFortran} {} static constexpr const char *GetNodeName(const char *) { return "char *"; } + + template + static std::string GetMemberNames(const common::EnumSet &x) { + llvm::ListSeparator sep; + std::string s; + llvm::raw_string_ostream stream(s); + x.IterateOverMembers([&](E e) { stream << sep << T::EnumToString(e); }); + return stream.str(); + } +#define NODE_ENUMSET(T, S) \ + static std::string GetNodeName(const T::S &x) { \ + return #S " = {"s + GetMemberNames(x) + "}"s; \ + } #define NODE_NAME(T, N) \ static constexpr const char *GetNodeName(const T &) { return N; } #define NODE_ENUM(T, E) \ @@ -572,7 +587,8 @@ class ParseTreeDumper { NODE_ENUM(OmpDeviceTypeClause, DeviceTypeDescription) NODE(parser, OmpDirectiveName) NODE(parser, OmpDirectiveSpecification) - NODE_ENUM(OmpDirectiveSpecification, Flags) + NODE_ENUM(OmpDirectiveSpecification, Flag) + NODE_ENUMSET(OmpDirectiveSpecification, Flags) NODE(parser, OmpDoacross) NODE(OmpDoacross, Sink) NODE(OmpDoacross, Source) diff --git a/flang/include/flang/Parser/parse-tree-visitor.h b/flang/include/flang/Parser/parse-tree-visitor.h index af1d34ae804f3..7ebce671c5fd1 100644 --- a/flang/include/flang/Parser/parse-tree-visitor.h +++ b/flang/include/flang/Parser/parse-tree-visitor.h @@ -10,6 +10,7 @@ #define FORTRAN_PARSER_PARSE_TREE_VISITOR_H_ #include "parse-tree.h" +#include "flang/Common/enum-set.h" #include "flang/Common/visit.h" #include #include @@ -41,7 +42,7 @@ struct ParseTreeVisitorLookupScope { // Default case for visitation of non-class data members, strings, and // any other non-decomposable values. template - static std::enable_if_t || + static std::enable_if_t || common::IsEnumSet || std::is_same_v || std::is_same_v> Walk(const A &x, V &visitor) { if (visitor.Pre(x)) { diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h index 003d11721908e..e5d3d3f2a7d5b 100644 --- a/flang/include/flang/Parser/parse-tree.h +++ b/flang/include/flang/Parser/parse-tree.h @@ -22,6 +22,7 @@ #include "format-specification.h" #include "message.h" #include "provenance.h" +#include "flang/Common/enum-set.h" #include "flang/Common/idioms.h" #include "flang/Common/indirection.h" #include "flang/Common/reference.h" @@ -4975,7 +4976,9 @@ struct OmpClauseList { // --- Directives and constructs struct OmpDirectiveSpecification { - ENUM_CLASS(Flags, None, DeprecatedSyntax); + ENUM_CLASS(Flag, DeprecatedSyntax, CrossesLabelDo) + using Flags = common::EnumSet; + TUPLE_CLASS_BOILERPLATE(OmpDirectiveSpecification); const OmpDirectiveName &DirName() const { return std::get(t); diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index b033206d90c41..bd259a9c6e01d 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -1633,7 +1633,8 @@ TYPE_PARSER( maybe(Parser{}), maybe(parenthesized( OmpArgumentListParser{})), - pure(OmpDirectiveSpecification::Flags::DeprecatedSyntax)))) || + pure(OmpDirectiveSpecification::Flags( + {OmpDirectiveSpecification::Flag::DeprecatedSyntax}))))) || // Parse DECLARE_VARIANT individually, because the "[base:]variant" // argument will conflict with DECLARE_REDUCTION's "ident:types...". predicated(Parser{}, @@ -1643,13 +1644,13 @@ TYPE_PARSER( maybe(parenthesized(OmpArgumentListParser< llvm::omp::Directive::OMPD_declare_variant>{})), maybe(Parser{}), - pure(OmpDirectiveSpecification::Flags::None))) || + pure(OmpDirectiveSpecification::Flags()))) || // Parse the standard syntax: directive [(arguments)] [clauses] sourced(construct( // sourced(OmpDirectiveNameParser{}), maybe(parenthesized(OmpArgumentListParser<>{})), maybe(Parser{}), - pure(OmpDirectiveSpecification::Flags::None)))) + pure(OmpDirectiveSpecification::Flags())))) static bool IsStandaloneOrdered(const OmpDirectiveSpecification &dirSpec) { // An ORDERED construct is standalone if it has DOACROSS or DEPEND clause. diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp index 3854d33d46d48..8e9c7d04bc522 100644 --- a/flang/lib/Parser/unparse.cpp +++ b/flang/lib/Parser/unparse.cpp @@ -2142,7 +2142,7 @@ class UnparseVisitor { Walk(std::get(x.t)); auto flags{std::get(x.t)}; - if (flags == OmpDirectiveSpecification::Flags::DeprecatedSyntax) { + if (flags.test(OmpDirectiveSpecification::Flag::DeprecatedSyntax)) { if (x.DirId() == llvm::omp::Directive::OMPD_flush) { // FLUSH clause arglist unparseClauses(); @@ -2539,8 +2539,8 @@ class UnparseVisitor { void Unparse(const OpenMPInteropConstruct &x) { BeginOpenMP(); Word("!$OMP INTEROP"); - using Flags = OmpDirectiveSpecification::Flags; - if (std::get(x.v.t) == Flags::DeprecatedSyntax) { + auto flags{std::get(x.v.t)}; + if (flags.test(OmpDirectiveSpecification::Flag::DeprecatedSyntax)) { Walk("(", std::get>(x.v.t), ")"); Walk(" ", std::get>(x.v.t)); } else { @@ -2679,8 +2679,8 @@ class UnparseVisitor { void Unparse(const OpenMPFlushConstruct &x) { BeginOpenMP(); Word("!$OMP FLUSH"); - using Flags = OmpDirectiveSpecification::Flags; - if (std::get(x.v.t) == Flags::DeprecatedSyntax) { + auto flags{std::get(x.v.t)}; + if (flags.test(OmpDirectiveSpecification::Flag::DeprecatedSyntax)) { Walk("(", std::get>(x.v.t), ")"); Walk(" ", std::get>(x.v.t)); } else { diff --git a/flang/lib/Semantics/canonicalize-do.cpp b/flang/lib/Semantics/canonicalize-do.cpp index 409195d5960b4..a0a6f8d870f6e 100644 --- a/flang/lib/Semantics/canonicalize-do.cpp +++ b/flang/lib/Semantics/canonicalize-do.cpp @@ -92,8 +92,11 @@ class CanonicalizationOfDoLoops { [&](common::Indirection &construct) { // If the body of the OpenMP construct ends with a label, // treat the label as ending the construct itself. - CanonicalizeIfMatch( - block, stack, i, omp::GetFinalLabel(construct.value())); + OpenMPConstruct &omp{construct.value()}; + if (CanonicalizeIfMatch( + block, stack, i, omp::GetFinalLabel(omp))) { + MarkOpenMPConstruct(omp); + } }, }, executableConstruct->u); @@ -103,12 +106,12 @@ class CanonicalizationOfDoLoops { private: template - void CanonicalizeIfMatch(Block &originalBlock, std::vector &stack, + bool CanonicalizeIfMatch(Block &originalBlock, std::vector &stack, Block::iterator &i, Statement &statement) { - CanonicalizeIfMatch(originalBlock, stack, i, statement.label); + return CanonicalizeIfMatch(originalBlock, stack, i, statement.label); } - void CanonicalizeIfMatch(Block &originalBlock, std::vector &stack, + bool CanonicalizeIfMatch(Block &originalBlock, std::vector &stack, Block::iterator &i, std::optional