From b918ed917b128c6baa4d33d4a79e9f19f71172cd Mon Sep 17 00:00:00 2001 From: Lang Hames Date: Sun, 5 Oct 2025 16:59:35 +1100 Subject: [PATCH] [orc-rt] WrapperFunction::handle: add by-ref args, minimize temporaries. This adds support for WrapperFunction::handle handlers that take their arguments by reference, rather than by value. This commit also reduces the number of temporary objects created to support SPS-transparent conversion in SPSWrapperFunction. --- orc-rt/include/orc-rt/SPSWrapperFunction.h | 9 ++- orc-rt/include/orc-rt/WrapperFunction.h | 27 +++++-- orc-rt/unittests/SPSWrapperFunctionTest.cpp | 79 +++++++++++++++++++++ 3 files changed, 107 insertions(+), 8 deletions(-) diff --git a/orc-rt/include/orc-rt/SPSWrapperFunction.h b/orc-rt/include/orc-rt/SPSWrapperFunction.h index 14a3d8e3d6ad6..3ed3295731780 100644 --- a/orc-rt/include/orc-rt/SPSWrapperFunction.h +++ b/orc-rt/include/orc-rt/SPSWrapperFunction.h @@ -57,8 +57,8 @@ template struct WFSPSHelper { template using DeserializableTuple_t = typename DeserializableTuple::type; - template static T fromSerializable(T &&Arg) noexcept { - return Arg; + template static T &&fromSerializable(T &&Arg) noexcept { + return std::forward(Arg); } static Error fromSerializable(SPSSerializableError Err) noexcept { @@ -86,7 +86,10 @@ template struct WFSPSHelper { decltype(Args)>::deserialize(IB, Args)) return std::nullopt; return std::apply( - [](auto &&...A) { return ArgTuple(fromSerializable(A)...); }, + [](auto &&...A) { + return std::optional(std::in_place, + std::move(fromSerializable(A))...); + }, std::move(Args)); } }; diff --git a/orc-rt/include/orc-rt/WrapperFunction.h b/orc-rt/include/orc-rt/WrapperFunction.h index ca165db7188b4..47e770f0bfbf7 100644 --- a/orc-rt/include/orc-rt/WrapperFunction.h +++ b/orc-rt/include/orc-rt/WrapperFunction.h @@ -111,7 +111,23 @@ struct WFHandlerTraitsImpl { static_assert(std::is_void_v, "Async wrapper function handler must return void"); typedef ReturnT YieldType; - typedef std::tuple ArgTupleType; + typedef std::tuple...> ArgTupleType; + + // Forwards arguments based on the parameter types of the handler. + template class ForwardArgsAsRequested { + public: + ForwardArgsAsRequested(FnT &&Fn) : Fn(std::move(Fn)) {} + void operator()(ArgTs &...Args) { Fn(std::forward(Args)...); } + + private: + FnT Fn; + }; + + template + static ForwardArgsAsRequested> + forwardArgsAsRequested(FnT &&Fn) { + return ForwardArgsAsRequested>(std::forward(Fn)); + } }; template @@ -244,10 +260,11 @@ struct WrapperFunction { if (auto Args = S.arguments().template deserialize(std::move(ArgBytes))) - std::apply(bind_front(std::forward(H), - detail::StructuredYield( - Session, CallCtx, Return, std::move(S))), - std::move(*Args)); + std::apply(HandlerTraits::forwardArgsAsRequested(bind_front( + std::forward(H), + detail::StructuredYield( + Session, CallCtx, Return, std::move(S)))), + *Args); else Return(Session, CallCtx, WrapperFunctionBuffer::createOutOfBandError( diff --git a/orc-rt/unittests/SPSWrapperFunctionTest.cpp b/orc-rt/unittests/SPSWrapperFunctionTest.cpp index c0c86ff8715ce..32aaa61639dbb 100644 --- a/orc-rt/unittests/SPSWrapperFunctionTest.cpp +++ b/orc-rt/unittests/SPSWrapperFunctionTest.cpp @@ -10,6 +10,8 @@ // //===----------------------------------------------------------------------===// +#include "CommonTestUtils.h" + #include "orc-rt/SPSWrapperFunction.h" #include "orc-rt/WrapperFunction.h" #include "orc-rt/move_only_function.h" @@ -218,3 +220,80 @@ TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedFailureCase) { EXPECT_EQ(ErrMsg, "N is not a multiple of 2"); } + +template struct SPSOpCounter {}; + +namespace orc_rt { +template +class SPSSerializationTraits, OpCounter> { +public: + static size_t size(const OpCounter &O) { return 0; } + static bool serialize(SPSOutputBuffer &OB, const OpCounter &O) { + return true; + } + static bool deserialize(SPSInputBuffer &OB, OpCounter &O) { return true; } +}; +} // namespace orc_rt + +static void +handle_with_reference_types_sps_wrapper(orc_rt_SessionRef Session, + void *CallCtx, + orc_rt_WrapperFunctionReturn Return, + orc_rt_WrapperFunctionBuffer ArgBytes) { + SPSWrapperFunction, SPSOpCounter<1>, SPSOpCounter<2>, + SPSOpCounter<3>)>::handle(Session, CallCtx, Return, ArgBytes, + [](move_only_function Return, + OpCounter<0>, OpCounter<1> &, + const OpCounter<2> &, + OpCounter<3> &&) { Return(); }); +} + +TEST(SPSWrapperFunctionUtilsTest, TestHandlerWithReferences) { + // Test that we can handle by-value, by-ref, by-const-ref, and by-rvalue-ref + // arguments, and that we generate the expected number of moves. + OpCounter<0>::reset(); + OpCounter<1>::reset(); + OpCounter<2>::reset(); + OpCounter<3>::reset(); + + bool DidRun = false; + SPSWrapperFunction, SPSOpCounter<1>, SPSOpCounter<2>, + SPSOpCounter<3>)>:: + call( + DirectCaller(nullptr, handle_with_reference_types_sps_wrapper), + [&](Error R) { + cantFail(std::move(R)); + DidRun = true; + }, + OpCounter<0>(), OpCounter<1>(), OpCounter<2>(), OpCounter<3>()); + + EXPECT_TRUE(DidRun); + + // We expect two default constructions for each parameter: one for the + // argument to call, and one for the object to deserialize into. + EXPECT_EQ(OpCounter<0>::defaultConstructions(), 2U); + EXPECT_EQ(OpCounter<1>::defaultConstructions(), 2U); + EXPECT_EQ(OpCounter<2>::defaultConstructions(), 2U); + EXPECT_EQ(OpCounter<3>::defaultConstructions(), 2U); + + // Pass-by-value: we expect two moves (one for SPS transparent conversion, + // one to copy the value to the parameter), and no copies. + EXPECT_EQ(OpCounter<0>::moves(), 2U); + EXPECT_EQ(OpCounter<0>::copies(), 0U); + + // Pass-by-lvalue-reference: we expect one move (for SPS transparent + // conversion), no copies. + EXPECT_EQ(OpCounter<1>::moves(), 1U); + EXPECT_EQ(OpCounter<1>::copies(), 0U); + + // Pass-by-const-lvalue-reference: we expect one move (for SPS transparent + // conversion), no copies. + EXPECT_EQ(OpCounter<2>::moves(), 1U); + EXPECT_EQ(OpCounter<2>::copies(), 0U); + + // Pass-by-rvalue-reference: we expect one move (for SPS transparent + // conversion), no copies. + EXPECT_EQ(OpCounter<3>::moves(), 1U); + EXPECT_EQ(OpCounter<3>::copies(), 0U); +}