diff --git a/include/rfl/Variant.hpp b/include/rfl/Variant.hpp index 6178f541..b6041823 100644 --- a/include/rfl/Variant.hpp +++ b/include/rfl/Variant.hpp @@ -14,6 +14,7 @@ #include "internal/nth_element_t.hpp" #include "internal/variant/find_max_size.hpp" #include "internal/variant/is_alternative_type.hpp" +#include "internal/variant/result_t.hpp" namespace rfl { @@ -24,8 +25,6 @@ class Variant { static constexpr unsigned long num_bytes_ = max_size_wrapper_.size_; - using LargestType = typename decltype(max_size_wrapper_)::Type; - using DataType = std::array; using IndexType = @@ -35,8 +34,11 @@ class Variant { static constexpr IndexType size_ = sizeof...(AlternativeTypes); + template + using result_t = internal::variant::result_t; + template - struct Index {}; + using Index = std::integral_constant; public: Variant() { @@ -146,14 +148,17 @@ class Variant { } template - auto visit(F& _f) { - using FirstAlternative = internal::nth_element_t<0, AlternativeTypes...>; - using ResultType = std::remove_cvref_t< - std::invoke_result_t, FirstAlternative&>>; + result_t visit(F& _f) { + using ResultType = result_t; if constexpr (std::is_same_v) { bool visited = false; do_visit_no_result(_f, &visited, std::make_integer_sequence()); + } else if constexpr (std::is_reference_v) { + std::remove_reference_t* res = nullptr; + do_visit_with_reference(_f, &res, + std::make_integer_sequence()); + return *res; } else { auto res = std::optional(); do_visit_with_result(_f, &res, @@ -163,14 +168,17 @@ class Variant { } template - auto visit(F& _f) const { - using FirstAlternative = internal::nth_element_t<0, AlternativeTypes...>; - using ResultType = std::remove_cvref_t< - std::invoke_result_t, FirstAlternative&>>; + result_t visit(F& _f) const { + using ResultType = result_t; if constexpr (std::is_same_v) { bool visited = false; do_visit_no_result(_f, &visited, std::make_integer_sequence()); + } else if constexpr (std::is_reference_v) { + std::remove_reference_t* res = nullptr; + do_visit_with_reference(_f, &res, + std::make_integer_sequence()); + return *res; } else { auto res = std::optional(); do_visit_with_result(_f, &res, @@ -180,14 +188,17 @@ class Variant { } template - auto visit(const F& _f) { - using FirstAlternative = internal::nth_element_t<0, AlternativeTypes...>; - using ResultType = std::remove_cvref_t< - std::invoke_result_t, FirstAlternative&>>; + result_t visit(const F& _f) { + using ResultType = std::remove_reference_t>; if constexpr (std::is_same_v) { bool visited = false; do_visit_no_result(_f, &visited, std::make_integer_sequence()); + } else if constexpr (std::is_reference_v) { + std::remove_reference_t* res = nullptr; + do_visit_with_reference(_f, &res, + std::make_integer_sequence()); + return *res; } else { auto res = std::optional(); do_visit_with_result(_f, &res, @@ -197,14 +208,17 @@ class Variant { } template - auto visit(const F& _f) const { - using FirstAlternative = internal::nth_element_t<0, AlternativeTypes...>; - using ResultType = std::remove_cvref_t< - std::invoke_result_t, FirstAlternative&>>; + result_t visit(const F& _f) const { + using ResultType = result_t; if constexpr (std::is_same_v) { bool visited = false; do_visit_no_result(_f, &visited, std::make_integer_sequence()); + } else if constexpr (std::is_reference_v) { + std::remove_reference_t* res = nullptr; + do_visit_with_reference(_f, &res, + std::make_integer_sequence()); + return *res; } else { auto res = std::optional(); do_visit_with_result(_f, &res, @@ -264,32 +278,6 @@ class Variant { (visit_one(_f, _visited, Index<_is>{}), ...); } - template - void do_visit_with_result(F& _f, std::optional* _res, - std::integer_sequence) { - auto visit_one = [this](const F& _f, - std::optional* _res, - Index<_i>) { - if (!*_res && index_ == _i) { - *_res = _f(get_alternative<_i>()); - } - }; - (visit_one(_f, _res, Index<_is>{}), ...); - } - - template - void do_visit_with_result(F& _f, std::optional* _res, - std::integer_sequence) const { - auto visit_one = [this](const F& _f, - std::optional* _res, - Index<_i>) { - if (!*_res && index_ == _i) { - *_res = _f(get_alternative<_i>()); - } - }; - (visit_one(_f, _res, Index<_is>{}), ...); - } - template void do_visit_no_result(const F& _f, bool* _visited, std::integer_sequence) { @@ -316,6 +304,32 @@ class Variant { (visit_one(_f, _visited, Index<_is>{}), ...); } + template + void do_visit_with_result(F& _f, std::optional* _res, + std::integer_sequence) { + auto visit_one = [this](const F& _f, + std::optional* _res, + Index<_i>) { + if (!*_res && index_ == _i) { + *_res = _f(get_alternative<_i>()); + } + }; + (visit_one(_f, _res, Index<_is>{}), ...); + } + + template + void do_visit_with_result(F& _f, std::optional* _res, + std::integer_sequence) const { + auto visit_one = [this](const F& _f, + std::optional* _res, + Index<_i>) { + if (!*_res && index_ == _i) { + *_res = _f(get_alternative<_i>()); + } + }; + (visit_one(_f, _res, Index<_is>{}), ...); + } + template void do_visit_with_result(const F& _f, std::optional* _res, std::integer_sequence) { @@ -342,6 +356,54 @@ class Variant { (visit_one(_f, _res, Index<_is>{}), ...); } + template + void do_visit_with_reference(F& _f, ResultType** _res, + std::integer_sequence) { + const auto visit_one = [this](const F& _f, ResultType** _res, + Index<_i>) { + if (!*_res && index_ == _i) { + *_res = &_f(get_alternative<_i>()); + } + }; + (visit_one(_f, _res, Index<_is>{}), ...); + } + + template + void do_visit_with_reference(F& _f, ResultType** _res, + std::integer_sequence) const { + const auto visit_one = [this](const F& _f, ResultType** _res, + Index<_i>) { + if (!*_res && index_ == _i) { + *_res = &_f(get_alternative<_i>()); + } + }; + (visit_one(_f, _res, Index<_is>{}), ...); + } + + template + void do_visit_with_reference(const F& _f, ResultType** _res, + std::integer_sequence) { + const auto visit_one = [this](const F& _f, ResultType** _res, + Index<_i>) { + if (!*_res && index_ == _i) { + *_res = &_f(get_alternative<_i>()); + } + }; + (visit_one(_f, _res, Index<_is>{}), ...); + } + + template + void do_visit_with_reference(const F& _f, ResultType** _res, + std::integer_sequence) const { + const auto visit_one = [this](const F& _f, ResultType** _res, + Index<_i>) { + if (!*_res && index_ == _i) { + *_res = &_f(get_alternative<_i>()); + } + }; + (visit_one(_f, _res, Index<_is>{}), ...); + } + template auto& get_alternative() noexcept { using CurrentType = internal::nth_element_t<_i, AlternativeTypes...>; @@ -376,7 +438,7 @@ class Variant { IndexType index_; /// The underlying data, can be any of the underlying types. - alignas(LargestType) DataType data_; + alignas(AlternativeTypes...) DataType data_; }; template diff --git a/include/rfl/internal/variant/result_t.hpp b/include/rfl/internal/variant/result_t.hpp new file mode 100644 index 00000000..a906c5af --- /dev/null +++ b/include/rfl/internal/variant/result_t.hpp @@ -0,0 +1,16 @@ +#ifndef RFL_INTERNAL_VARIANT_RESULT_T_HPP_ +#define RFL_INTERNAL_VARIANT_RESULT_T_HPP_ + +#include + +#include "../nth_element_t.hpp" + +namespace rfl::internal::variant { + +template +using result_t = std::remove_cv_t, internal::nth_element_t<0, AlternativeTypes...>&>>; + +} // namespace rfl::internal::variant + +#endif diff --git a/include/rfl/visit.hpp b/include/rfl/visit.hpp index a2f78618..2d98067d 100644 --- a/include/rfl/visit.hpp +++ b/include/rfl/visit.hpp @@ -8,6 +8,7 @@ #include "internal/StringLiteral.hpp" #include "internal/VisitTree.hpp" #include "internal/VisitorWrapper.hpp" +#include "internal/variant/result_t.hpp" namespace rfl { @@ -22,66 +23,84 @@ inline auto visit(const Visitor& _visitor, const Literal<_fields...> _literal, } template -inline auto visit(F& _f, Variant& _v) { +inline internal::variant::result_t visit( + F& _f, Variant& _v) { return _v.visit(_f); } template -inline auto visit(F& _f, Variant&& _v) { +inline internal::variant::result_t visit( + F& _f, Variant&& _v) { return _v.visit(_f); } template -inline auto visit(F& _f, const Variant& _v) { +inline internal::variant::result_t visit( + F& _f, const Variant& _v) { return _v.visit(_f); } template -inline auto visit(const F& _f, Variant& _v) { +inline internal::variant::result_t visit( + const F& _f, Variant& _v) { return _v.visit(_f); } template -inline auto visit(const F& _f, Variant&& _v) { +inline internal::variant::result_t visit( + const F& _f, Variant&& _v) { return _v.visit(_f); } template -inline auto visit(const F& _f, const Variant& _v) { +inline internal::variant::result_t visit( + const F& _f, const Variant& _v) { return _v.visit(_f); } -template -inline auto visit(F& _f, TaggedUnion<_discriminator, Args...>& _tagged_union) { +template +inline internal::variant::result_t visit( + F& _f, TaggedUnion<_discriminator, AlternativeTypes...>& _tagged_union) { return _tagged_union.variant().visit(_f); } -template -inline auto visit(F& _f, TaggedUnion<_discriminator, Args...>&& _tagged_union) { +template +inline internal::variant::result_t visit( + F& _f, TaggedUnion<_discriminator, AlternativeTypes...>&& _tagged_union) { return _tagged_union.variant().visit(_f); } -template -inline auto visit(F& _f, - const TaggedUnion<_discriminator, Args...>& _tagged_union) { +template +inline internal::variant::result_t visit( + F& _f, + const TaggedUnion<_discriminator, AlternativeTypes...>& _tagged_union) { return _tagged_union.variant().visit(_f); } -template -inline auto visit(const F& _f, - TaggedUnion<_discriminator, Args...>& _tagged_union) { +template +inline internal::variant::result_t visit( + const F& _f, + TaggedUnion<_discriminator, AlternativeTypes...>& _tagged_union) { return _tagged_union.variant().visit(_f); } -template -inline auto visit(const F& _f, - TaggedUnion<_discriminator, Args...>&& _tagged_union) { +template +inline internal::variant::result_t visit( + const F& _f, + TaggedUnion<_discriminator, AlternativeTypes...>&& _tagged_union) { return _tagged_union.variant().visit(_f); } -template -inline auto visit(const F& _f, - const TaggedUnion<_discriminator, Args...>& _tagged_union) { +template +inline internal::variant::result_t visit( + const F& _f, + const TaggedUnion<_discriminator, AlternativeTypes...>& _tagged_union) { return _tagged_union.variant().visit(_f); } diff --git a/tests/json/test_rfl_variant_visit_lvalues.cpp b/tests/json/test_rfl_variant_visit_lvalues.cpp new file mode 100644 index 00000000..3eecc815 --- /dev/null +++ b/tests/json/test_rfl_variant_visit_lvalues.cpp @@ -0,0 +1,30 @@ +#include +#include +#include +#include +#include +#include + +#include "write_and_read.hpp" + +namespace test_rfl_variant_visit_lvalues { + +struct Rectangle { + double height; + double width; +}; + +struct Square { + double width; +}; + +using Shapes = rfl::Variant; + +TEST(json, test_rfl_variant) { + const Shapes r = Rectangle{.height = 10, .width = 5}; + const auto get_width = [](const auto& _s) -> const double& { + return _s.width; + }; + EXPECT_EQ(rfl::visit(get_width, r), 5); +} +} // namespace test_rfl_variant_visit_lvalues diff --git a/tests/json/test_rfl_variant_visit_move_only.cpp b/tests/json/test_rfl_variant_visit_move_only.cpp new file mode 100644 index 00000000..5418ddec --- /dev/null +++ b/tests/json/test_rfl_variant_visit_move_only.cpp @@ -0,0 +1,32 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "write_and_read.hpp" + +namespace test_rfl_variant_visit_move_only { + +struct Rectangle { + double height; + std::unique_ptr width; +}; + +struct Square { + std::unique_ptr width; +}; + +using Shapes = rfl::Variant; + +TEST(json, test_rfl_variant_visit_move_only) { + const Shapes r = + Rectangle{.height = 10, .width = std::make_unique(5.0)}; + const auto get_width = [](const auto& _s) -> const std::unique_ptr& { + return _s.width; + }; + EXPECT_EQ(*rfl::visit(get_width, r), 5.0); +} +} // namespace test_rfl_variant_visit_move_only