Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions orc-rt/include/orc-rt-c/WrapperFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ typedef struct {
* Asynchronous return function for an orc-rt wrapper function.
*/
typedef void (*orc_rt_WrapperFunctionReturn)(
orc_rt_SessionRef Session, uint64_t CallId,
orc_rt_SessionRef S, uint64_t CallId,
orc_rt_WrapperFunctionBuffer ResultBytes);

/**
Expand All @@ -65,8 +65,7 @@ typedef void (*orc_rt_WrapperFunctionReturn)(
* CallId holds a pointer to the context object for this particular call.
* Return holds a pointer to the return function.
*/
typedef void (*orc_rt_WrapperFunction)(orc_rt_SessionRef Session,
uint64_t CallId,
typedef void (*orc_rt_WrapperFunction)(orc_rt_SessionRef S, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes);

Expand Down
4 changes: 2 additions & 2 deletions orc-rt/include/orc-rt/SPSWrapperFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@ template <typename SPSSig> struct SPSWrapperFunction {
}

template <typename Handler>
static void handle(orc_rt_SessionRef Session, uint64_t CallId,
static void handle(orc_rt_SessionRef S, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return,
WrapperFunctionBuffer ArgBytes, Handler &&H) {
WrapperFunction::handle(Session, CallId, Return, std::move(ArgBytes),
WrapperFunction::handle(S, CallId, Return, std::move(ArgBytes),
WrapperFunctionSPSSerializer<SPSSig>(),
std::forward<Handler>(H));
}
Expand Down
16 changes: 8 additions & 8 deletions orc-rt/include/orc-rt/SimpleNativeMemoryMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,21 +114,21 @@ class SimpleNativeMemoryMap : public ResourceManager {
} // namespace orc_rt

ORC_RT_SPS_INTERFACE void orc_rt_SimpleNativeMemoryMap_reserve_sps_wrapper(
orc_rt_SessionRef Session, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes);
orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes);

ORC_RT_SPS_INTERFACE void
orc_rt_SimpleNativeMemoryMap_releaseMultiple_sps_wrapper(
orc_rt_SessionRef Session, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes);
orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes);

ORC_RT_SPS_INTERFACE void orc_rt_SimpleNativeMemoryMap_initialize_sps_wrapper(
orc_rt_SessionRef Session, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes);
orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes);

ORC_RT_SPS_INTERFACE void
orc_rt_SimpleNativeMemoryMap_deinitializeMultiple_sps_wrapper(
orc_rt_SessionRef Session, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes);
orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes);

#endif // ORC_RT_SIMPLENATIVEMEMORYMAP_H
51 changes: 24 additions & 27 deletions orc-rt/include/orc-rt/WrapperFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,15 @@ using WFHandlerTraits = CallableTraitsHelper<WFHandlerTraitsImpl, C>;

template <typename Serializer> class StructuredYieldBase {
public:
StructuredYieldBase(orc_rt_SessionRef Session, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return, Serializer &&S)
: Session(Session), CallId(CallId), Return(Return),
S(std::forward<Serializer>(S)) {}
StructuredYieldBase(orc_rt_SessionRef S, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return, Serializer &&Z)
: S(S), CallId(CallId), Return(Return), Z(std::forward<Serializer>(Z)) {}

protected:
orc_rt_SessionRef Session;
orc_rt_SessionRef S;
uint64_t CallId;
orc_rt_WrapperFunctionReturn Return;
std::decay_t<Serializer> S;
std::decay_t<Serializer> Z;
};

template <typename RetT, typename Serializer> class StructuredYield;
Expand All @@ -157,10 +156,10 @@ class StructuredYield<std::tuple<RetT>, Serializer>
public:
using StructuredYieldBase<Serializer>::StructuredYieldBase;
void operator()(RetT &&R) {
if (auto ResultBytes = this->S.result().serialize(std::forward<RetT>(R)))
this->Return(this->Session, this->CallId, ResultBytes->release());
if (auto ResultBytes = this->Z.result().serialize(std::forward<RetT>(R)))
this->Return(this->S, this->CallId, ResultBytes->release());
else
this->Return(this->Session, this->CallId,
this->Return(this->S, this->CallId,
WrapperFunctionBuffer::createOutOfBandError(
"Could not serialize wrapper function result data")
.release());
Expand All @@ -173,8 +172,7 @@ class StructuredYield<std::tuple<>, Serializer>
public:
using StructuredYieldBase<Serializer>::StructuredYieldBase;
void operator()() {
this->Return(this->Session, this->CallId,
WrapperFunctionBuffer().release());
this->Return(this->S, this->CallId, WrapperFunctionBuffer().release());
}
};

Expand Down Expand Up @@ -251,12 +249,12 @@ struct WrapperFunction {
///
///
/// static void adder_add_async_sps_wrapper(
/// orc_rt_SessionRef Session, uint64_t CallId,
/// orc_rt_SessionRef S, uint64_t CallId,
/// orc_rt_WrapperFunctionReturn Return,
/// orc_rt_WrapperFunctionBuffer ArgBytes) {
/// using SPSSig = SPSString(SPSExecutorAddr, int32_t, bool);
/// SPSWrapperFunction<SPSSig>::handle(
/// Session, CallId, Return, ArgBytes,
/// S, CallId, Return, ArgBytes,
/// WrapperFunction::handleWithAsyncMethod(&MyClass::myMethod));
/// }
/// @endcode
Expand Down Expand Up @@ -313,12 +311,12 @@ struct WrapperFunction {
///
///
/// static void adder_add_sync_sps_wrapper(
/// orc_rt_SessionRef Session, uint64_t CallId,
/// orc_rt_SessionRef S, uint64_t CallId,
/// orc_rt_WrapperFunctionReturn Return,
/// orc_rt_WrapperFunctionBuffer ArgBytes) {
/// using SPSSig = SPSString(SPSExecutorAddr, int32_t, bool);
/// SPSWrapperFunction<SPSSig>::handle(
/// Session, CallId, Return, ArgBytes,
/// S, CallId, Return, ArgBytes,
/// WrapperFunction::handleWithSyncMethod(&Adder::addSync));
/// }
/// @endcode
Expand All @@ -336,7 +334,7 @@ struct WrapperFunction {
/// given Caller object.
template <typename Caller, typename Serializer, typename ResultHandler,
typename... ArgTs>
static void call(Caller &&C, Serializer &&S, ResultHandler &&RH,
static void call(Caller &&C, Serializer &&Z, ResultHandler &&RH,
ArgTs &&...Args) {
typedef CallableArgInfo<ResultHandler> ResultHandlerTraits;
static_assert(std::is_void_v<typename ResultHandlerTraits::return_type>,
Expand All @@ -346,16 +344,15 @@ struct WrapperFunction {
"Result-handler should have exactly one argument");
typedef typename ResultHandlerTraits::args_tuple_type ResultTupleType;

if (auto ArgBytes = S.arguments().serialize(std::forward<ArgTs>(Args)...)) {
if (auto ArgBytes = Z.arguments().serialize(std::forward<ArgTs>(Args)...)) {
C(
[RH = std::move(RH),
S = std::move(S)](orc_rt_SessionRef Session,
WrapperFunctionBuffer ResultBytes) mutable {
[RH = std::move(RH), Z = std::move(Z)](
orc_rt_SessionRef S, WrapperFunctionBuffer ResultBytes) mutable {
if (const char *ErrMsg = ResultBytes.getOutOfBandError())
RH(make_error<StringError>(ErrMsg));
else
RH(detail::ResultDeserializer<ResultTupleType, Serializer>::
deserialize(std::move(ResultBytes), S));
deserialize(std::move(ResultBytes), Z));
},
std::move(*ArgBytes));
} else
Expand All @@ -368,9 +365,9 @@ struct WrapperFunction {
/// This utility deserializes and serializes arguments and return values
/// (using the given Serializer), and calls the given handler.
template <typename Serializer, typename Handler>
static void handle(orc_rt_SessionRef Session, uint64_t CallId,
static void handle(orc_rt_SessionRef S, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return,
WrapperFunctionBuffer ArgBytes, Serializer &&S,
WrapperFunctionBuffer ArgBytes, Serializer &&Z,
Handler &&H) {
typedef detail::WFHandlerTraits<Handler> HandlerTraits;
typedef typename HandlerTraits::ArgTupleType ArgTuple;
Expand All @@ -380,16 +377,16 @@ struct WrapperFunction {
typedef typename CallableArgInfo<Yield>::args_tuple_type RetTupleType;

if (ArgBytes.getOutOfBandError())
return Return(Session, CallId, ArgBytes.release());
return Return(S, CallId, ArgBytes.release());

if (auto Args = S.arguments().template deserialize<ArgTuple>(ArgBytes))
if (auto Args = Z.arguments().template deserialize<ArgTuple>(ArgBytes))
std::apply(HandlerTraits::forwardArgsAsRequested(bind_front(
std::forward<Handler>(H),
detail::StructuredYield<RetTupleType, Serializer>(
Session, CallId, Return, std::move(S)))),
S, CallId, Return, std::move(Z)))),
*Args);
else
Return(Session, CallId,
Return(S, CallId,
WrapperFunctionBuffer::createOutOfBandError(
"Could not deserialize wrapper function arg data")
.release());
Expand Down
20 changes: 8 additions & 12 deletions orc-rt/lib/executor/SimpleNativeMemoryMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,45 +367,41 @@ Error SimpleNativeMemoryMap::recordDeallocActions(
}

ORC_RT_SPS_INTERFACE void orc_rt_SimpleNativeMemoryMap_reserve_sps_wrapper(
orc_rt_SessionRef Session, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return,
orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
using Sig = SPSExpected<SPSExecutorAddr>(SPSExecutorAddr, SPSSize);
SPSWrapperFunction<Sig>::handle(
Session, CallId, Return, ArgBytes,
S, CallId, Return, ArgBytes,
WrapperFunction::handleWithAsyncMethod(&SimpleNativeMemoryMap::reserve));
}

ORC_RT_SPS_INTERFACE void
orc_rt_SimpleNativeMemoryMap_releaseMultiple_sps_wrapper(
orc_rt_SessionRef Session, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return,
orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
using Sig = SPSError(SPSExecutorAddr, SPSSequence<SPSExecutorAddr>);
SPSWrapperFunction<Sig>::handle(Session, CallId, Return, ArgBytes,
SPSWrapperFunction<Sig>::handle(S, CallId, Return, ArgBytes,
WrapperFunction::handleWithAsyncMethod(
&SimpleNativeMemoryMap::releaseMultiple));
}

ORC_RT_SPS_INTERFACE void orc_rt_SimpleNativeMemoryMap_initialize_sps_wrapper(
orc_rt_SessionRef Session, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return,
orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
using Sig = SPSExpected<SPSExecutorAddr>(
SPSExecutorAddr, SPSSimpleNativeMemoryMapInitializeRequest);
SPSWrapperFunction<Sig>::handle(Session, CallId, Return, ArgBytes,
SPSWrapperFunction<Sig>::handle(S, CallId, Return, ArgBytes,
WrapperFunction::handleWithAsyncMethod(
&SimpleNativeMemoryMap::initialize));
}

ORC_RT_SPS_INTERFACE void
orc_rt_SimpleNativeMemoryMap_deinitializeMultiple_sps_wrapper(
orc_rt_SessionRef Session, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return,
orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
using Sig = SPSError(SPSExecutorAddr, SPSSequence<SPSExecutorAddr>);
SPSWrapperFunction<Sig>::handle(
Session, CallId, Return, ArgBytes,
S, CallId, Return, ArgBytes,
WrapperFunction::handleWithAsyncMethod(
&SimpleNativeMemoryMap::deinitializeMultiple));
}
Expand Down
18 changes: 8 additions & 10 deletions orc-rt/unittests/DirectCaller.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,24 @@ class DirectCaller {
class DirectResultSender {
public:
virtual ~DirectResultSender() {}
virtual void send(orc_rt_SessionRef Session,
virtual void send(orc_rt_SessionRef S,
orc_rt::WrapperFunctionBuffer ResultBytes) = 0;
static void send(orc_rt_SessionRef Session, uint64_t CallId,
static void send(orc_rt_SessionRef S, uint64_t CallId,
orc_rt_WrapperFunctionBuffer ResultBytes) {
std::unique_ptr<DirectResultSender>(
reinterpret_cast<DirectResultSender *>(
static_cast<uintptr_t>(CallId)))
->send(Session, ResultBytes);
->send(S, ResultBytes);
}
};

template <typename ImplFn>
class DirectResultSenderImpl : public DirectResultSender {
public:
DirectResultSenderImpl(ImplFn &&Fn) : Fn(std::forward<ImplFn>(Fn)) {}
void send(orc_rt_SessionRef Session,
void send(orc_rt_SessionRef S,
orc_rt::WrapperFunctionBuffer ResultBytes) override {
Fn(Session, std::move(ResultBytes));
Fn(S, std::move(ResultBytes));
}

private:
Expand All @@ -52,21 +52,19 @@ class DirectCaller {
}

public:
DirectCaller(orc_rt_SessionRef Session, orc_rt_WrapperFunction Fn)
: Session(Session), Fn(Fn) {}
DirectCaller(orc_rt_SessionRef S, orc_rt_WrapperFunction Fn) : S(S), Fn(Fn) {}

template <typename HandleResultFn>
void operator()(HandleResultFn &&HandleResult,
orc_rt::WrapperFunctionBuffer ArgBytes) {
auto DR =
makeDirectResultSender(std::forward<HandleResultFn>(HandleResult));
Fn(Session,
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(DR.release())),
Fn(S, static_cast<uint64_t>(reinterpret_cast<uintptr_t>(DR.release())),
DirectResultSender::send, ArgBytes.release());
}

private:
orc_rt_SessionRef Session;
orc_rt_SessionRef S;
orc_rt_WrapperFunction Fn;
};

Expand Down
Loading
Loading