diff --git a/orc-rt/include/orc-rt/SimplePackedSerialization.h b/orc-rt/include/orc-rt/SimplePackedSerialization.h index b01e017d1529b..95e10bc437cfb 100644 --- a/orc-rt/include/orc-rt/SimplePackedSerialization.h +++ b/orc-rt/include/orc-rt/SimplePackedSerialization.h @@ -513,12 +513,6 @@ template <> class SPSSerializationTraits { /// SPS tag type for errors. class SPSError; -/// SPS tag type for expecteds, which are either a T or a string representing -/// an error. -template class SPSExpected; - -namespace detail { - /// Helper type for serializing Errors. /// /// llvm::Errors are move-only, and not inspectable except by consuming them. @@ -529,139 +523,141 @@ namespace detail { /// The SPSSerializableError type is a helper that can be /// constructed from an llvm::Error, but inspected more than once. struct SPSSerializableError { - bool HasError = false; - std::string ErrMsg; -}; - -/// Helper type for serializing Expecteds. -/// -/// See SPSSerializableError for more details. -/// -// FIXME: Use std::variant for storage once we have c++17. -template struct SPSSerializableExpected { - bool HasValue = false; - T Value{}; - std::string ErrMsg; -}; - -inline SPSSerializableError toSPSSerializable(Error Err) { - if (Err) - return {true, toString(std::move(Err))}; - return {false, {}}; -} - -inline Error fromSPSSerializable(SPSSerializableError BSE) { - if (BSE.HasError) - return make_error(BSE.ErrMsg); - return Error::success(); -} - -template -SPSSerializableExpected toSPSSerializable(Expected E) { - if (E) - return {true, std::move(*E), {}}; - else - return {false, {}, toString(E.takeError())}; -} + SPSSerializableError() = default; + SPSSerializableError(Error Err) { + if (Err) + Msg = toString(std::move(Err)); + } -template -Expected fromSPSSerializable(SPSSerializableExpected BSE) { - if (BSE.HasValue) - return std::move(BSE.Value); - else - return make_error(BSE.ErrMsg); -} + Error toError() { + if (Msg) + return make_error(std::move(*Msg)); + return Error::success(); + } -} // namespace detail + std::optional Msg; +}; -/// Serialize to a SPSError from a detail::SPSSerializableError. -template <> -class SPSSerializationTraits { +template <> class SPSSerializationTraits { public: - static size_t size(const detail::SPSSerializableError &BSE) { - size_t Size = SPSArgList::size(BSE.HasError); - if (BSE.HasError) - Size += SPSArgList::size(BSE.ErrMsg); - return Size; + static size_t size(const SPSSerializableError &E) { + if (E.Msg) + return SPSArgList::size(true, *E.Msg); + else + return SPSArgList::size(false); } - static bool serialize(SPSOutputBuffer &OB, - const detail::SPSSerializableError &BSE) { - if (!SPSArgList::serialize(OB, BSE.HasError)) + static bool serialize(SPSOutputBuffer &OB, const SPSSerializableError &E) { + if (E.Msg) + return SPSArgList::serialize(OB, true, *E.Msg); + else + return SPSArgList::serialize(OB, false); + } + + static bool deserialize(SPSInputBuffer &IB, SPSSerializableError &E) { + bool HasError = false; + if (!SPSArgList::deserialize(IB, HasError)) return false; - if (BSE.HasError) - if (!SPSArgList::serialize(OB, BSE.ErrMsg)) + if (HasError) { + std::string Msg; + if (!SPSArgList::deserialize(IB, Msg)) return false; + E.Msg = std::move(Msg); + } else + E.Msg = std::nullopt; return true; } +}; - static bool deserialize(SPSInputBuffer &IB, - detail::SPSSerializableError &BSE) { - if (!SPSArgList::deserialize(IB, BSE.HasError)) - return false; +/// SPS tag type for expecteds, which are either a T or a string representing +/// an error. +template class SPSExpected; - if (!BSE.HasError) - return true; +/// Helper type for serializing Expecteds. +/// +/// See SPSSerializableError for more details. +template struct SPSSerializableExpected { + SPSSerializableExpected() = default; + SPSSerializableExpected(Expected E) { + if (E) + Val = decltype(Val)(std::in_place_index<0>, std::move(*E)); + else + Val = decltype(Val)(std::in_place_index<1>, toString(E.takeError())); + } + SPSSerializableExpected(Error E) { + assert(E && "Cannot create Expected from Error::success()"); + Val = decltype(Val)(std::in_place_index<1>, toString(std::move(E))); + } - return SPSArgList::deserialize(IB, BSE.ErrMsg); + Expected toExpected() { + if (Val.index() == 0) + return Expected(std::move(std::get<0>(Val))); + return Expected(make_error(std::move(std::get<1>(Val)))); } + + std::variant Val{std::in_place_index<0>, T()}; }; -/// Serialize to a SPSExpected from a -/// detail::SPSSerializableExpected. +template +SPSSerializableExpected toSPSSerializableExpected(Expected E) { + return std::move(E); +} + +template +SPSSerializableExpected toSPSSerializableExpected(Error E) { + return std::move(E); +} + template -class SPSSerializationTraits, - detail::SPSSerializableExpected> { +class SPSSerializationTraits, SPSSerializableExpected> { public: - static size_t size(const detail::SPSSerializableExpected &BSE) { - size_t Size = SPSArgList::size(BSE.HasValue); - if (BSE.HasValue) - Size += SPSArgList::size(BSE.Value); + static size_t size(const SPSSerializableExpected &E) { + if (E.Val.index() == 0) + return SPSArgList::size(true, std::get<0>(E.Val)); else - Size += SPSArgList::size(BSE.ErrMsg); - return Size; + return SPSArgList::size(false, std::get<1>(E.Val)); } static bool serialize(SPSOutputBuffer &OB, - const detail::SPSSerializableExpected &BSE) { - if (!SPSArgList::serialize(OB, BSE.HasValue)) - return false; - - if (BSE.HasValue) - return SPSArgList::serialize(OB, BSE.Value); - - return SPSArgList::serialize(OB, BSE.ErrMsg); + const SPSSerializableExpected &E) { + if (E.Val.index() == 0) + return SPSArgList::serialize(OB, true, std::get<0>(E.Val)); + else + return SPSArgList::serialize(OB, false, + std::get<1>(E.Val)); } - static bool deserialize(SPSInputBuffer &IB, - detail::SPSSerializableExpected &BSE) { - if (!SPSArgList::deserialize(IB, BSE.HasValue)) + static bool deserialize(SPSInputBuffer &IB, SPSSerializableExpected &E) { + bool HasValue = false; + if (!SPSArgList::deserialize(IB, HasValue)) return false; - - if (BSE.HasValue) - return SPSArgList::deserialize(IB, BSE.Value); - - return SPSArgList::deserialize(IB, BSE.ErrMsg); + if (HasValue) { + T Val; + if (!SPSArgList::deserialize(IB, Val)) + return false; + E.Val = decltype(E.Val){std::in_place_index<0>, std::move(Val)}; + } else { + std::string Msg; + if (!SPSArgList::deserialize(IB, Msg)) + return false; + E.Val = decltype(E.Val){std::in_place_index<1>, std::move(Msg)}; + } + return true; } }; -/// Serialize to a SPSExpected from a detail::SPSSerializableError. +/// Serialize to a SPSExpected from a SPSSerializableError. template -class SPSSerializationTraits, - detail::SPSSerializableError> { +class SPSSerializationTraits, SPSSerializableError> { public: - static size_t size(const detail::SPSSerializableError &BSE) { - assert(BSE.HasError && "Cannot serialize expected from a success value"); - return SPSArgList::size(false) + - SPSArgList::size(BSE.ErrMsg); + static size_t size(const SPSSerializableError &SE) { + assert(SE.Msg && "Cannot serialize expected from a success value"); + return SPSArgList::size(false, *SE.Msg); } - static bool serialize(SPSOutputBuffer &OB, - const detail::SPSSerializableError &BSE) { - assert(BSE.HasError && "Cannot serialize expected from a success value"); - if (!SPSArgList::serialize(OB, false)) - return false; - return SPSArgList::serialize(OB, BSE.ErrMsg); + static bool serialize(SPSOutputBuffer &OB, const SPSSerializableError &SE) { + assert(SE.Msg && "Cannot serialize expected from a success value"); + return SPSArgList::serialize(OB, false, *SE.Msg); } }; @@ -674,9 +670,7 @@ class SPSSerializationTraits, T> { } static bool serialize(SPSOutputBuffer &OB, const T &Value) { - if (!SPSArgList::serialize(OB, true)) - return false; - return SPSArgList::serialize(Value); + return SPSArgList::serialize(OB, true, Value); } }; diff --git a/orc-rt/unittests/SimplePackedSerializationTest.cpp b/orc-rt/unittests/SimplePackedSerializationTest.cpp index 6c58503f0797a..9ccedef69628f 100644 --- a/orc-rt/unittests/SimplePackedSerializationTest.cpp +++ b/orc-rt/unittests/SimplePackedSerializationTest.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "orc-rt/SimplePackedSerialization.h" + #include "SimplePackedSerializationTestUtils.h" #include "gtest/gtest.h" @@ -182,3 +183,77 @@ TEST(SimplePackedSerializationTest, ArgListSerialization) { EXPECT_EQ(Arg2, ArgOut2); EXPECT_EQ(Arg3, ArgOut3); } + +TEST(SimplePackedSerializationTest, SerializeErrorSuccess) { + auto B = spsSerialize>( + SPSSerializableError(Error::success())); + if (!B) { + ADD_FAILURE() << "Unexpected failure to serialize error-success value"; + return; + } + SPSSerializableError SE; + if (!spsDeserialize>(*B, SE)) { + ADD_FAILURE() << "Unexpected failure to deserialize error-success value"; + return; + } + + auto E = SE.toError(); + EXPECT_FALSE(!!E); // Expect non-error, i.e. Error::success(). +} + +TEST(SimplePackedSerializationTest, SerializeErrorFailure) { + auto B = spsSerialize>( + SPSSerializableError(make_error("test error message"))); + if (!B) { + ADD_FAILURE() << "Unexpected failure to serialize error-failure value"; + return; + } + SPSSerializableError SE; + if (!spsDeserialize>(*B, SE)) { + ADD_FAILURE() << "Unexpected failure to deserialize error-failure value"; + return; + } + + EXPECT_EQ(toString(SE.toError()), std::string("test error message")); +} + +TEST(SimplePackedSerializationTest, SerializeExpectedSuccess) { + auto B = spsSerialize>>( + toSPSSerializableExpected(Expected(42U))); + if (!B) { + ADD_FAILURE() << "Unexpected failure to serialize expected-success value"; + return; + } + SPSSerializableExpected SE; + if (!spsDeserialize>>(*B, SE)) { + ADD_FAILURE() << "Unexpected failure to deserialize expected-success value"; + return; + } + + auto E = SE.toExpected(); + if (E) + EXPECT_EQ(*E, 42U); + else + ADD_FAILURE() << "Unexpected failure value"; +} + +TEST(SimplePackedSerializationTest, SerializeExpectedFailure) { + auto B = spsSerialize>>( + toSPSSerializableExpected( + make_error("test error message"))); + if (!B) { + ADD_FAILURE() << "Unexpected failure to serialize expected-failure value"; + return; + } + SPSSerializableExpected SE; + if (!spsDeserialize>>(*B, SE)) { + ADD_FAILURE() << "Unexpected failure to deserialize expected-failure value"; + return; + } + + auto E = SE.toExpected(); + if (E) + ADD_FAILURE() << "Unexpected failure value"; + else + EXPECT_EQ(toString(E.takeError()), std::string("test error message")); +} diff --git a/orc-rt/unittests/SimplePackedSerializationTestUtils.h b/orc-rt/unittests/SimplePackedSerializationTestUtils.h index 5468045f5fbe7..7bfa37b6d4bda 100644 --- a/orc-rt/unittests/SimplePackedSerializationTestUtils.h +++ b/orc-rt/unittests/SimplePackedSerializationTestUtils.h @@ -10,10 +10,29 @@ #define ORC_RT_UNITTEST_SIMPLEPACKEDSERIALIZATIONTESTUTILS_H #include "orc-rt/SimplePackedSerialization.h" +#include "orc-rt/WrapperFunction.h" #include "gtest/gtest.h" +#include + +template +static inline std::optional +spsSerialize(const ArgTs &...Args) { + auto B = orc_rt::WrapperFunctionBuffer::allocate(SPSTraitsT::size(Args...)); + orc_rt::SPSOutputBuffer OB(B.data(), B.size()); + if (!SPSTraitsT::serialize(OB, Args...)) + return std::nullopt; + return B; +} + +template +static bool spsDeserialize(orc_rt::WrapperFunctionBuffer &B, ArgTs &...Args) { + orc_rt::SPSInputBuffer IB(B.data(), B.size()); + return SPSTraitsT::deserialize(IB, Args...); +} + template -static void blobSerializationRoundTrip(const T &Value) { +static inline void blobSerializationRoundTrip(const T &Value) { using BST = orc_rt::SPSSerializationTraits; size_t Size = BST::size(Value);