diff --git a/flang/include/flang/Common/idioms.h b/flang/include/flang/Common/idioms.h index 84a8fd5be4cba..1a086162f1e2a 100644 --- a/flang/include/flang/Common/idioms.h +++ b/flang/include/flang/Common/idioms.h @@ -23,6 +23,7 @@ #error g++ >= 7.2 is required #endif +#include "visit.h" #include "llvm/Support/Compiler.h" #include #include @@ -49,8 +50,8 @@ using namespace std::literals::string_literals; namespace Fortran::common { // Helper templates for combining a list of lambdas into an anonymous -// struct for use with std::visit() on a std::variant<> sum type. -// E.g.: std::visit(visitors{ +// struct for use with common::visit() on a std::variant<> sum type. +// E.g.: common::visit(visitors{ // [&](const firstType &x) { ... }, // [&](const secondType &x) { ... }, // ... diff --git a/flang/include/flang/Common/template.h b/flang/include/flang/Common/template.h index f31b0afa97fb7..2a9958f74db38 100644 --- a/flang/include/flang/Common/template.h +++ b/flang/include/flang/Common/template.h @@ -118,14 +118,14 @@ template const A *GetPtrFromOptional(const std::optional &x) { // Copy a value from one variant type to another. The types allowed in the // source variant must all be allowed in the destination variant type. template TOV CopyVariant(const FROMV &u) { - return std::visit([](const auto &x) -> TOV { return {x}; }, u); + return common::visit([](const auto &x) -> TOV { return {x}; }, u); } // Move a value from one variant type to another. The types allowed in the // source variant must all be allowed in the destination variant type. template common::IfNoLvalue MoveVariant(FROMV &&u) { - return std::visit( + return common::visit( [](auto &&x) -> TOV { return {std::move(x)}; }, std::move(u)); } diff --git a/flang/include/flang/Common/unwrap.h b/flang/include/flang/Common/unwrap.h index b6ea4a1546096..edb343d77b537 100644 --- a/flang/include/flang/Common/unwrap.h +++ b/flang/include/flang/Common/unwrap.h @@ -12,6 +12,7 @@ #include "indirection.h" #include "reference-counted.h" #include "reference.h" +#include "visit.h" #include #include #include @@ -103,7 +104,7 @@ struct UnwrapperHelper { template static A *Unwrap(std::variant &u) { - return std::visit( + return common::visit( [](auto &x) -> A * { using Ty = std::decay_t(x))>; if constexpr (!std::is_const_v> || @@ -117,7 +118,7 @@ struct UnwrapperHelper { template static auto Unwrap(const std::variant &u) -> std::add_const_t * { - return std::visit( + return common::visit( [](const auto &x) -> std::add_const_t * { return Unwrap(x); }, u); } diff --git a/flang/include/flang/Common/visit.h b/flang/include/flang/Common/visit.h new file mode 100644 index 0000000000000..db68fdbbf1099 --- /dev/null +++ b/flang/include/flang/Common/visit.h @@ -0,0 +1,94 @@ +//===-- include/flang/Common/visit.h ----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// common::visit() is a drop-in replacement for std::visit() that reduces both +// compiler build time and compiler execution time modestly, and reduces +// compiler build memory requirements significantly (overall & maximum). +// It does not require redefinition of std::variant<>. +// +// The C++ standard mandates that std::visit be O(1), but most variants are +// small and O(logN) is faster in practice to compile and execute, avoiding +// the need to build a dispatch table. +// +// Define FLANG_USE_STD_VISIT to avoid this code and make common::visit() an +// alias for ::std::visit(). +// +// +// With GCC 9.3.0 on a Haswell x86 Ubuntu system, doing out-of-tree builds: +// Before: +// build: +// 6948.53user 212.48system 27:32.92elapsed 433%CPU +// (0avgtext+0avgdata 6429568maxresident)k +// 36181912inputs+8943720outputs (3613684major+97908699minor)pagefaults 0swaps +// execution of tests: +// 205.99user 26.05system 1:08.87elapsed 336%CPU +// (0avgtext+0avgdata 2671452maxresident)k +// 244432inputs+355464outputs (422major+8746468minor)pagefaults 0swaps +// After: +// build: +// 6651.91user 182.57system 25:15.73elapsed 450%CPU +// (0avgtext+0avgdata 6209296maxresident)k +// 17413480inputs+6376360outputs (1567210major+93068230minor)pagefaults 0swaps +// execution of tests: +// 201.42user 25.91system 1:04.68elapsed 351%CPU +// (0avgtext+0avgdata 2661424maxresident)k +// 238840inputs+295912outputs (428major+8489300minor)pagefaults 0swaps + +#ifndef FORTRAN_COMMON_VISIT_H_ +#define FORTRAN_COMMON_VISIT_H_ + +#include +#include + +namespace Fortran::common { +namespace log2visit { + +template +inline RESULT Log2VisitHelper( + VISITOR &&visitor, std::size_t which, VARIANT &&...u) { + if constexpr (LOW == HIGH) { + return visitor(std::get(std::forward(u))...); + } else { + static constexpr std::size_t mid{(HIGH + LOW) / 2}; + if (which <= mid) { + return Log2VisitHelper( + std::forward(visitor), which, std::forward(u)...); + } else { + return Log2VisitHelper<(mid + 1), HIGH, RESULT>( + std::forward(visitor), which, std::forward(u)...); + } + } +} + +template +inline auto visit(VISITOR &&visitor, VARIANT &&...u) + -> decltype(visitor(std::get<0>(std::forward(u))...)) { + using Result = decltype(visitor(std::get<0>(std::forward(u))...)); + if constexpr (sizeof...(u) == 1) { + static constexpr std::size_t high{ + (std::variant_size_v> * ...) - 1}; + return Log2VisitHelper<0, high, Result>(std::forward(visitor), + u.index()..., std::forward(u)...); + } else { + // TODO: figure out how to do multiple variant arguments + return ::std::visit( + std::forward(visitor), std::forward(u)...); + } +} + +} // namespace log2visit + +#ifdef FLANG_USE_STD_VISIT +using ::std::visit; +#else +using Fortran::common::log2visit::visit; +#endif + +} // namespace Fortran::common +#endif // FORTRAN_COMMON_VISIT_H_ diff --git a/flang/include/flang/Evaluate/expression.h b/flang/include/flang/Evaluate/expression.h index f24750e4f0944..90309affbe257 100644 --- a/flang/include/flang/Evaluate/expression.h +++ b/flang/include/flang/Evaluate/expression.h @@ -646,7 +646,7 @@ template <> class Relational { EVALUATE_UNION_CLASS_BOILERPLATE(Relational) static constexpr DynamicType GetType() { return Result::GetType(); } int Rank() const { - return std::visit([](const auto &x) { return x.Rank(); }, u); + return common::visit([](const auto &x) { return x.Rank(); }, u); } llvm::raw_ostream &AsFortran(llvm::raw_ostream &o) const; common::MapTemplate u; diff --git a/flang/include/flang/Evaluate/fold-designator.h b/flang/include/flang/Evaluate/fold-designator.h index 457e86d4fdad9..f246bd12020e0 100644 --- a/flang/include/flang/Evaluate/fold-designator.h +++ b/flang/include/flang/Evaluate/fold-designator.h @@ -67,7 +67,7 @@ class DesignatorFolder { template std::optional FoldDesignator(const Expr &expr) { - return std::visit( + return common::visit( [&](const auto &x) { return FoldDesignator(x, elementNumber_++); }, expr.u); } @@ -98,7 +98,7 @@ class DesignatorFolder { template std::optional FoldDesignator( const Expr &expr, ConstantSubscript which) { - return std::visit( + return common::visit( [&](const auto &x) { return FoldDesignator(x, which); }, expr.u); } @@ -110,14 +110,14 @@ class DesignatorFolder { template std::optional FoldDesignator( const Designator &designator, ConstantSubscript which) { - return std::visit( + return common::visit( [&](const auto &x) { return FoldDesignator(x, which); }, designator.u); } template std::optional FoldDesignator( const Designator> &designator, ConstantSubscript which) { - return std::visit( + return common::visit( common::visitors{ [&](const Substring &ss) { if (const auto *dataRef{ss.GetParentIf()}) { diff --git a/flang/include/flang/Evaluate/initial-image.h b/flang/include/flang/Evaluate/initial-image.h index 596b1f77d1790..fcc18835418a8 100644 --- a/flang/include/flang/Evaluate/initial-image.h +++ b/flang/include/flang/Evaluate/initial-image.h @@ -88,7 +88,7 @@ class InitialImage { template Result Add(ConstantSubscript offset, std::size_t bytes, const Expr &x, FoldingContext &c) { - return std::visit( + return common::visit( [&](const auto &y) { return Add(offset, bytes, y, c); }, x.u); } diff --git a/flang/include/flang/Evaluate/shape.h b/flang/include/flang/Evaluate/shape.h index 246f346dcc327..378c0d6734f40 100644 --- a/flang/include/flang/Evaluate/shape.h +++ b/flang/include/flang/Evaluate/shape.h @@ -167,7 +167,7 @@ class GetShapeHelper template MaybeExtentExpr GetArrayConstructorValueExtent( const ArrayConstructorValue &value) const { - return std::visit( + return common::visit( common::visitors{ [&](const Expr &x) -> MaybeExtentExpr { if (auto xShape{ diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index ae6772a871070..5523b1bf035b2 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -83,7 +83,7 @@ template bool IsAssumedRank(const Designator &designator) { } } template bool IsAssumedRank(const Expr &expr) { - return std::visit([](const auto &x) { return IsAssumedRank(x); }, expr.u); + return common::visit([](const auto &x) { return IsAssumedRank(x); }, expr.u); } template bool IsAssumedRank(const std::optional &x) { return x && IsAssumedRank(*x); @@ -100,7 +100,7 @@ template bool IsCoarray(const Designator &designator) { return false; } template bool IsCoarray(const Expr &expr) { - return std::visit([](const auto &x) { return IsCoarray(x); }, expr.u); + return common::visit([](const auto &x) { return IsCoarray(x); }, expr.u); } template bool IsCoarray(const std::optional &x) { return x && IsCoarray(*x); @@ -177,11 +177,11 @@ auto UnwrapExpr(B &x) -> common::Constify * { return UnwrapExpr(*expr); } } else if constexpr (std::is_same_v>) { - return std::visit([](auto &x) { return UnwrapExpr(x); }, x.u); + return common::visit([](auto &x) { return UnwrapExpr(x); }, x.u); } else if constexpr (!common::HasMember) { if constexpr (std::is_same_v>> || std::is_same_v::category>>>) { - return std::visit([](auto &x) { return UnwrapExpr(x); }, x.u); + return common::visit([](auto &x) { return UnwrapExpr(x); }, x.u); } } return nullptr; @@ -217,15 +217,17 @@ auto UnwrapConvertedExpr(B &x) -> common::Constify * { return UnwrapConvertedExpr(*expr); } } else if constexpr (std::is_same_v>) { - return std::visit([](auto &x) { return UnwrapConvertedExpr(x); }, x.u); + return common::visit( + [](auto &x) { return UnwrapConvertedExpr(x); }, x.u); } else if constexpr (!common::HasMember) { using Result = ResultType; if constexpr (std::is_same_v> || std::is_same_v>>) { - return std::visit([](auto &x) { return UnwrapConvertedExpr(x); }, x.u); + return common::visit( + [](auto &x) { return UnwrapConvertedExpr(x); }, x.u); } else if constexpr (std::is_same_v> || std::is_same_v>) { - return std::visit( + return common::visit( [](auto &x) { return UnwrapConvertedExpr(x); }, x.left().u); } } @@ -262,7 +264,7 @@ common::IfNoLvalue, A> ExtractDataRef( template std::optional ExtractDataRef( const Designator &d, bool intoSubstring = false) { - return std::visit( + return common::visit( [=](const auto &x) -> std::optional { if constexpr (common::HasMember) { return DataRef{x}; @@ -279,7 +281,7 @@ std::optional ExtractDataRef( template std::optional ExtractDataRef( const Expr &expr, bool intoSubstring = false) { - return std::visit( + return common::visit( [=](const auto &x) { return ExtractDataRef(x, intoSubstring); }, expr.u); } template @@ -328,7 +330,7 @@ bool IsArrayElement(const Expr &expr, bool intoSubstring = true, template std::optional ExtractNamedEntity(const A &x) { if (auto dataRef{ExtractDataRef(x, true)}) { - return std::visit( + return common::visit( common::visitors{ [](SymbolRef &&symbol) -> std::optional { return NamedEntity{symbol}; @@ -354,10 +356,10 @@ struct ExtractCoindexedObjectHelper { std::optional operator()(const CoarrayRef &x) const { return x; } template std::optional operator()(const Expr &expr) const { - return std::visit(*this, expr.u); + return common::visit(*this, expr.u); } std::optional operator()(const DataRef &dataRef) const { - return std::visit(*this, dataRef.u); + return common::visit(*this, dataRef.u); } std::optional operator()(const NamedEntity &named) const { if (const Component * component{named.UnwrapComponent()}) { @@ -449,7 +451,7 @@ Expr ConvertToType(Expr> &&x) { ConvertToType(std::move(x)), Expr{Constant{zero}}}}; } else if constexpr (FROMCAT == TypeCategory::Complex) { // Extract and convert the real component of a complex value - return std::visit( + return common::visit( [&](auto &&z) { using ZType = ResultType; using Part = typename ZType::Part; @@ -503,7 +505,7 @@ common::IfNoLvalue>, FROM> ConvertTo( template common::IfNoLvalue>, FROM> ConvertTo( const Expr> &to, FROM &&from) { - return std::visit( + return common::visit( [&](const auto &toKindExpr) { using KindExpr = std::decay_t; return AsCategoryExpr( @@ -515,7 +517,7 @@ common::IfNoLvalue>, FROM> ConvertTo( template common::IfNoLvalue, FROM> ConvertTo( const Expr &to, FROM &&from) { - return std::visit( + return common::visit( [&](const auto &toCatExpr) { return AsGenericExpr(ConvertTo(toCatExpr, std::move(from))); }, @@ -565,7 +567,7 @@ using SameKindExprs = template SameKindExprs AsSameKindExprs( Expr> &&x, Expr> &&y) { - return std::visit( + return common::visit( [&](auto &&kx, auto &&ky) -> SameKindExprs { using XTy = ResultType; using YTy = ResultType; @@ -626,7 +628,7 @@ Expr Combine(Expr &&x, Expr &&y) { template