Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions orc-rt/include/orc-rt/SPSWrapperFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ template <typename... SPSArgTs> struct WFSPSHelper {
template <typename... Ts>
using DeserializableTuple_t = typename DeserializableTuple<Ts...>::type;

template <typename T> static T fromSerializable(T &&Arg) noexcept {
return Arg;
template <typename T> static T &&fromSerializable(T &&Arg) noexcept {
return std::forward<T>(Arg);
}

static Error fromSerializable(SPSSerializableError Err) noexcept {
Expand Down Expand Up @@ -86,7 +86,10 @@ template <typename... SPSArgTs> struct WFSPSHelper {
decltype(Args)>::deserialize(IB, Args))
return std::nullopt;
return std::apply(
[](auto &&...A) { return ArgTuple(fromSerializable(A)...); },
[](auto &&...A) {
return std::optional<ArgTuple>(std::in_place,
std::move(fromSerializable(A))...);
},
std::move(Args));
}
};
Expand Down
27 changes: 22 additions & 5 deletions orc-rt/include/orc-rt/WrapperFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,23 @@ struct WFHandlerTraitsImpl {
static_assert(std::is_void_v<RetT>,
"Async wrapper function handler must return void");
typedef ReturnT YieldType;
typedef std::tuple<ArgTs...> ArgTupleType;
typedef std::tuple<std::decay_t<ArgTs>...> ArgTupleType;

// Forwards arguments based on the parameter types of the handler.
template <typename FnT> class ForwardArgsAsRequested {
public:
ForwardArgsAsRequested(FnT &&Fn) : Fn(std::move(Fn)) {}
void operator()(ArgTs &...Args) { Fn(std::forward<ArgTs>(Args)...); }

private:
FnT Fn;
};

template <typename FnT>
static ForwardArgsAsRequested<std::decay_t<FnT>>
forwardArgsAsRequested(FnT &&Fn) {
return ForwardArgsAsRequested<std::decay_t<FnT>>(std::forward<FnT>(Fn));
}
};

template <typename C>
Expand Down Expand Up @@ -244,10 +260,11 @@ struct WrapperFunction {

if (auto Args =
S.arguments().template deserialize<ArgTuple>(std::move(ArgBytes)))
std::apply(bind_front(std::forward<Handler>(H),
detail::StructuredYield<RetTupleType, Serializer>(
Session, CallCtx, Return, std::move(S))),
std::move(*Args));
std::apply(HandlerTraits::forwardArgsAsRequested(bind_front(
std::forward<Handler>(H),
detail::StructuredYield<RetTupleType, Serializer>(
Session, CallCtx, Return, std::move(S)))),
*Args);
else
Return(Session, CallCtx,
WrapperFunctionBuffer::createOutOfBandError(
Expand Down
79 changes: 79 additions & 0 deletions orc-rt/unittests/SPSWrapperFunctionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
//
//===----------------------------------------------------------------------===//

#include "CommonTestUtils.h"

#include "orc-rt/SPSWrapperFunction.h"
#include "orc-rt/WrapperFunction.h"
#include "orc-rt/move_only_function.h"
Expand Down Expand Up @@ -218,3 +220,80 @@ TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedFailureCase) {

EXPECT_EQ(ErrMsg, "N is not a multiple of 2");
}

template <size_t N> struct SPSOpCounter {};

namespace orc_rt {
template <size_t N>
class SPSSerializationTraits<SPSOpCounter<N>, OpCounter<N>> {
public:
static size_t size(const OpCounter<N> &O) { return 0; }
static bool serialize(SPSOutputBuffer &OB, const OpCounter<N> &O) {
return true;
}
static bool deserialize(SPSInputBuffer &OB, OpCounter<N> &O) { return true; }
};
} // namespace orc_rt

static void
handle_with_reference_types_sps_wrapper(orc_rt_SessionRef Session,
void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
SPSWrapperFunction<void(
SPSOpCounter<0>, SPSOpCounter<1>, SPSOpCounter<2>,
SPSOpCounter<3>)>::handle(Session, CallCtx, Return, ArgBytes,
[](move_only_function<void()> Return,
OpCounter<0>, OpCounter<1> &,
const OpCounter<2> &,
OpCounter<3> &&) { Return(); });
}

TEST(SPSWrapperFunctionUtilsTest, TestHandlerWithReferences) {
// Test that we can handle by-value, by-ref, by-const-ref, and by-rvalue-ref
// arguments, and that we generate the expected number of moves.
OpCounter<0>::reset();
OpCounter<1>::reset();
OpCounter<2>::reset();
OpCounter<3>::reset();

bool DidRun = false;
SPSWrapperFunction<void(SPSOpCounter<0>, SPSOpCounter<1>, SPSOpCounter<2>,
SPSOpCounter<3>)>::
call(
DirectCaller(nullptr, handle_with_reference_types_sps_wrapper),
[&](Error R) {
cantFail(std::move(R));
DidRun = true;
},
OpCounter<0>(), OpCounter<1>(), OpCounter<2>(), OpCounter<3>());

EXPECT_TRUE(DidRun);

// We expect two default constructions for each parameter: one for the
// argument to call, and one for the object to deserialize into.
EXPECT_EQ(OpCounter<0>::defaultConstructions(), 2U);
EXPECT_EQ(OpCounter<1>::defaultConstructions(), 2U);
EXPECT_EQ(OpCounter<2>::defaultConstructions(), 2U);
EXPECT_EQ(OpCounter<3>::defaultConstructions(), 2U);

// Pass-by-value: we expect two moves (one for SPS transparent conversion,
// one to copy the value to the parameter), and no copies.
EXPECT_EQ(OpCounter<0>::moves(), 2U);
EXPECT_EQ(OpCounter<0>::copies(), 0U);

// Pass-by-lvalue-reference: we expect one move (for SPS transparent
// conversion), no copies.
EXPECT_EQ(OpCounter<1>::moves(), 1U);
EXPECT_EQ(OpCounter<1>::copies(), 0U);

// Pass-by-const-lvalue-reference: we expect one move (for SPS transparent
// conversion), no copies.
EXPECT_EQ(OpCounter<2>::moves(), 1U);
EXPECT_EQ(OpCounter<2>::copies(), 0U);

// Pass-by-rvalue-reference: we expect one move (for SPS transparent
// conversion), no copies.
EXPECT_EQ(OpCounter<3>::moves(), 1U);
EXPECT_EQ(OpCounter<3>::copies(), 0U);
}
Loading