From 1ea5d81a05c2e19c834ec11a059f1f6486770006 Mon Sep 17 00:00:00 2001 From: Lang Hames Date: Thu, 4 Sep 2025 21:21:40 +1000 Subject: [PATCH] [orc-rt] Introduce WrapperFunction APIs. Introduces the following key APIs: `orc_rt_WrapperFunction` defines the signature of an ORC asynchronous wrapper function: ``` typedef void (*orc_rt_WrapperFunctionReturn)( orc_rt_SessionRef Session, void *CallCtx, orc_rt_WrapperFunctionBuffer ResultBytes); typedef void (*orc_rt_WrapperFunction)(orc_rt_SessionRef Session, void *CallCtx, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes); ``` A wrapper function takes a reference to the session object, a context pointer for the call being made, and a pointer to an orc_rt_WrapperFunctionReturn function that can be used to send the result bytes. The `orc_rt::WrapperFunction` utility simplifies the writing of wrapper functions whose arguments and return values are serialized/deserialized using an abstract serialization utility. The `orc_rt::SPSWrapperFunction` utility provides a specialized version of `orc_rt::WrapperFunction` that uses SPS serialization. --- orc-rt/include/CMakeLists.txt | 2 + orc-rt/include/orc-rt-c/CoreTypes.h | 28 ++++ orc-rt/include/orc-rt-c/WrapperFunction.h | 20 +++ orc-rt/include/orc-rt/SPSWrapperFunction.h | 89 +++++++++++ orc-rt/include/orc-rt/WrapperFunction.h | 160 ++++++++++++++++++++ orc-rt/unittests/CMakeLists.txt | 1 + orc-rt/unittests/SPSWrapperFunctionTest.cpp | 109 +++++++++++++ 7 files changed, 409 insertions(+) create mode 100644 orc-rt/include/orc-rt-c/CoreTypes.h create mode 100644 orc-rt/include/orc-rt/SPSWrapperFunction.h create mode 100644 orc-rt/unittests/SPSWrapperFunctionTest.cpp diff --git a/orc-rt/include/CMakeLists.txt b/orc-rt/include/CMakeLists.txt index 07a7e52061d6c..67fe060c4b25b 100644 --- a/orc-rt/include/CMakeLists.txt +++ b/orc-rt/include/CMakeLists.txt @@ -1,4 +1,5 @@ set(ORC_RT_HEADERS + orc-rt-c/CoreTyspe.h orc-rt-c/ExternC.h orc-rt-c/WrapperFunction.h orc-rt-c/orc-rt.h @@ -13,6 +14,7 @@ set(ORC_RT_HEADERS orc-rt/RTTI.h orc-rt/WrapperFunction.h orc-rt/SimplePackedSerialization.h + orc-rt/SPSWrapperFunction.h orc-rt/bind.h orc-rt/bit.h orc-rt/move_only_function.h diff --git a/orc-rt/include/orc-rt-c/CoreTypes.h b/orc-rt/include/orc-rt-c/CoreTypes.h new file mode 100644 index 0000000000000..9b3fdbea41498 --- /dev/null +++ b/orc-rt/include/orc-rt-c/CoreTypes.h @@ -0,0 +1,28 @@ +/*===-- CoreTypes.h - Essential types for the ORC Runtime C APIs --*- C -*-===*\ +|* *| +|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *| +|* Exceptions. *| +|* See https://llvm.org/LICENSE.txt for license information. *| +|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| +|* *| +|*===----------------------------------------------------------------------===*| +|* *| +|* Defines core types for the ORC runtime. *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifndef ORC_RT_C_CORETYPES_H +#define ORC_RT_C_CORETYPES_H + +#include "orc-rt-c/ExternC.h" + +ORC_RT_C_EXTERN_C_BEGIN + +/** + * A reference to an orc_rt::Session instance. + */ +typedef struct orc_rt_OpaqueSession *orc_rt_SessionRef; + +ORC_RT_C_EXTERN_C_END + +#endif /* ORC_RT_C_CORETYPES_H */ diff --git a/orc-rt/include/orc-rt-c/WrapperFunction.h b/orc-rt/include/orc-rt-c/WrapperFunction.h index b7dbc16978233..34bcdeffef9ee 100644 --- a/orc-rt/include/orc-rt-c/WrapperFunction.h +++ b/orc-rt/include/orc-rt-c/WrapperFunction.h @@ -14,6 +14,7 @@ #ifndef ORC_RT_C_WRAPPERFUNCTION_H #define ORC_RT_C_WRAPPERFUNCTION_H +#include "orc-rt-c/CoreTypes.h" #include "orc-rt-c/ExternC.h" #include @@ -49,6 +50,25 @@ typedef struct { size_t Size; } orc_rt_WrapperFunctionBuffer; +/** + * Asynchronous return function for an orc-rt wrapper function. + */ +typedef void (*orc_rt_WrapperFunctionReturn)( + orc_rt_SessionRef Session, void *CallCtx, + orc_rt_WrapperFunctionBuffer ResultBytes); + +/** + * orc-rt wrapper function prototype. + * + * ArgBytes contains the serialized arguments for the wrapper function. + * Session holds a reference to the session object. + * CallCtx holds a pointer to the context object for this particular call. + * Return holds a pointer to the return function. + */ +typedef void (*orc_rt_WrapperFunction)(orc_rt_SessionRef Session, void *CallCtx, + orc_rt_WrapperFunctionReturn Return, + orc_rt_WrapperFunctionBuffer ArgBytes); + /** * Zero-initialize an orc_rt_WrapperFunctionBuffer. */ diff --git a/orc-rt/include/orc-rt/SPSWrapperFunction.h b/orc-rt/include/orc-rt/SPSWrapperFunction.h new file mode 100644 index 0000000000000..d08176f676289 --- /dev/null +++ b/orc-rt/include/orc-rt/SPSWrapperFunction.h @@ -0,0 +1,89 @@ +//===--- SPSWrapperFunction.h -- SPS-serializing Wrapper utls ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities for calling / handling wrapper functions that use SPS +// serialization. +// +//===----------------------------------------------------------------------===// + +#ifndef ORC_RT_SPSWRAPPERFUNCTION_H +#define ORC_RT_SPSWRAPPERFUNCTION_H + +#include "orc-rt/SimplePackedSerialization.h" +#include "orc-rt/WrapperFunction.h" + +namespace orc_rt { +namespace detail { + +template struct WFSPSSerializer { + template + std::optional operator()(const ArgTs &...Args) { + auto R = + WrapperFunctionBuffer::allocate(SPSArgList::size(Args...)); + SPSOutputBuffer OB(R.data(), R.size()); + if (!SPSArgList::serialize(OB, Args...)) + return std::nullopt; + return std::move(R); + } +}; + +template struct WFSPSDeserializer { + template + bool operator()(WrapperFunctionBuffer &ArgBytes, ArgTs &...Args) { + assert(!ArgBytes.getOutOfBandError() && + "Should not attempt to deserialize out-of-band error"); + SPSInputBuffer IB(ArgBytes.data(), ArgBytes.size()); + return SPSArgList::deserialize(IB, Args...); + } +}; + +} // namespace detail + +template struct WrapperFunctionSPSSerializer; + +template +struct WrapperFunctionSPSSerializer { + static detail::WFSPSSerializer argumentSerializer() noexcept { + return {}; + } + static detail::WFSPSDeserializer + argumentDeserializer() noexcept { + return {}; + } + static detail::WFSPSSerializer resultSerializer() noexcept { + return {}; + } + static detail::WFSPSDeserializer resultDeserializer() noexcept { + return {}; + } +}; + +/// Provides call and handle utilities to simplify writing and invocation of +/// wrapper functions that use SimplePackedSerialization to serialize and +/// deserialize their arguments and return values. +template struct SPSWrapperFunction { + template + static void call(Caller &&C, ResultHandler &&RH, ArgTs &&...Args) { + WrapperFunction::call( + std::forward(C), WrapperFunctionSPSSerializer(), + std::forward(RH), std::forward(Args)...); + } + + template + static void handle(orc_rt_SessionRef Session, void *CallCtx, + orc_rt_WrapperFunctionReturn Return, + WrapperFunctionBuffer ArgBytes, Handler &&H) { + WrapperFunction::handle(Session, CallCtx, Return, std::move(ArgBytes), + WrapperFunctionSPSSerializer(), + std::forward(H)); + } +}; + +} // namespace orc_rt + +#endif // ORC_RT_SPSWRAPPERFUNCTION_H diff --git a/orc-rt/include/orc-rt/WrapperFunction.h b/orc-rt/include/orc-rt/WrapperFunction.h index eb64cf64450e7..24b149cbe15f3 100644 --- a/orc-rt/include/orc-rt/WrapperFunction.h +++ b/orc-rt/include/orc-rt/WrapperFunction.h @@ -14,6 +14,8 @@ #define ORC_RT_WRAPPERFUNCTION_H #include "orc-rt-c/WrapperFunction.h" +#include "orc-rt/Error.h" +#include "orc-rt/bind.h" #include @@ -98,6 +100,164 @@ class WrapperFunctionBuffer { orc_rt_WrapperFunctionBuffer B; }; +namespace detail { + +template +struct WFCallableTraits + : public WFCallableTraits< + decltype(&std::remove_cv_t>::operator())> { +}; + +template struct WFCallableTraits { + typedef void HeadArgType; +}; + +template +struct WFCallableTraits { + typedef ArgT HeadArgType; + typedef std::tuple TailArgTuple; +}; + +template +struct WFCallableTraits + : public WFCallableTraits {}; + +template +struct WFCallableTraits + : public WFCallableTraits {}; + +template class StructuredYieldBase { +public: + StructuredYieldBase(orc_rt_SessionRef Session, void *CallCtx, + orc_rt_WrapperFunctionReturn Return, Serializer &&S) + : Session(Session), CallCtx(CallCtx), Return(Return), + S(std::forward(S)) {} + +protected: + orc_rt_SessionRef Session; + void *CallCtx; + orc_rt_WrapperFunctionReturn Return; + std::decay_t S; +}; + +template +class StructuredYield : public StructuredYieldBase { +public: + using StructuredYieldBase::StructuredYieldBase; + void operator()(RetT &&R) { + if (auto ResultBytes = this->S.resultSerializer()(std::forward(R))) + this->Return(this->Session, this->CallCtx, ResultBytes->release()); + else + this->Return(this->Session, this->CallCtx, + WrapperFunctionBuffer::createOutOfBandError( + "Could not serialize wrapper function result data") + .release()); + } +}; + +template +class StructuredYield + : public StructuredYieldBase { +public: + using StructuredYieldBase::StructuredYieldBase; + void operator()() { + this->Return(this->Session, this->CallCtx, + WrapperFunctionBuffer().release()); + } +}; + +template struct ResultDeserializer; + +template +struct ResultDeserializer, Serializer> { + static Expected deserialize(WrapperFunctionBuffer ResultBytes, + Serializer &S) { + T Val; + if (S.resultDeserializer()(ResultBytes, Val)) + return std::move(Val); + else + return make_error("Could not deserialize result"); + } +}; + +template struct ResultDeserializer { + static Error deserialize(WrapperFunctionBuffer ResultBytes, Serializer &S) { + assert(ResultBytes.empty()); + return Error::success(); + } +}; + +} // namespace detail + +/// Provides call and handle utilities to simplify writing and invocation of +/// wrapper functions in C++. +struct WrapperFunction { + + /// Make a call to a wrapper function. + /// + /// This utility serializes and deserializes arguments and return values + /// (using the given Serializer), and calls the wrapper function via the + /// given Caller object. + template + static void call(Caller &&C, Serializer &&S, ResultHandler &&RH, + ArgTs &&...Args) { + typedef detail::WFCallableTraits ResultHandlerTraits; + static_assert( + std::tuple_size_v == 0, + "Expected one argument to result-handler"); + typedef typename ResultHandlerTraits::HeadArgType ResultType; + + if (auto ArgBytes = S.argumentSerializer()(std::forward(Args)...)) { + C( + [RH = std::move(RH), + S = std::move(S)](orc_rt_SessionRef Session, + WrapperFunctionBuffer ResultBytes) mutable { + if (const char *ErrMsg = ResultBytes.getOutOfBandError()) + RH(make_error(ErrMsg)); + else + RH(detail::ResultDeserializer< + ResultType, Serializer>::deserialize(std::move(ResultBytes), + S)); + }, + std::move(*ArgBytes)); + } else + RH(make_error( + "Could not serialize wrapper function call arguments")); + } + + /// Simplifies implementation of wrapper functions in C++. + /// + /// This utility deserializes and serializes arguments and return values + /// (using the given Serializer), and calls the given handler. + template + static void handle(orc_rt_SessionRef Session, void *CallCtx, + orc_rt_WrapperFunctionReturn Return, + WrapperFunctionBuffer ArgBytes, Serializer &&S, + Handler &&H) { + typedef detail::WFCallableTraits HandlerTraits; + typedef typename HandlerTraits::HeadArgType Yield; + typedef typename HandlerTraits::TailArgTuple ArgTuple; + typedef typename detail::WFCallableTraits::HeadArgType RetType; + + if (ArgBytes.getOutOfBandError()) + return Return(Session, CallCtx, ArgBytes.release()); + + ArgTuple Args; + if (std::apply(bind_front(S.argumentDeserializer(), std::move(ArgBytes)), + Args)) + std::apply(bind_front(std::forward(H), + detail::StructuredYield( + Session, CallCtx, Return, std::move(S))), + std::move(Args)); + else + Return(Session, CallCtx, + WrapperFunctionBuffer::createOutOfBandError( + "Could not deserialize wrapper function arg data") + .release()); + } +}; + } // namespace orc_rt #endif // ORC_RT_WRAPPERFUNCTION_H diff --git a/orc-rt/unittests/CMakeLists.txt b/orc-rt/unittests/CMakeLists.txt index 55e089a539725..7bf53ca9826e2 100644 --- a/orc-rt/unittests/CMakeLists.txt +++ b/orc-rt/unittests/CMakeLists.txt @@ -22,6 +22,7 @@ add_orc_rt_unittest(CoreTests MemoryFlagsTest.cpp RTTITest.cpp SimplePackedSerializationTest.cpp + SPSWrapperFunctionTest.cpp WrapperFunctionBufferTest.cpp bind-test.cpp bit-test.cpp diff --git a/orc-rt/unittests/SPSWrapperFunctionTest.cpp b/orc-rt/unittests/SPSWrapperFunctionTest.cpp new file mode 100644 index 0000000000000..919ec2cebd69b --- /dev/null +++ b/orc-rt/unittests/SPSWrapperFunctionTest.cpp @@ -0,0 +1,109 @@ +//===-- SPSWrapperFunctionTest.cpp ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Test SPSWrapperFunction and associated utilities. +// +//===----------------------------------------------------------------------===// + +#include "orc-rt/SPSWrapperFunction.h" +#include "orc-rt/WrapperFunction.h" +#include "orc-rt/move_only_function.h" + +#include "gtest/gtest.h" + +using namespace orc_rt; + +/// Make calls and call result handlers directly on the current thread. +class DirectCaller { +private: + class DirectResultSender { + public: + virtual ~DirectResultSender() {} + virtual void send(orc_rt_SessionRef Session, + WrapperFunctionBuffer ResultBytes) = 0; + static void send(orc_rt_SessionRef Session, void *CallCtx, + orc_rt_WrapperFunctionBuffer ResultBytes) { + std::unique_ptr( + reinterpret_cast(CallCtx)) + ->send(Session, ResultBytes); + } + }; + + template + class DirectResultSenderImpl : public DirectResultSender { + public: + DirectResultSenderImpl(ImplFn &&Fn) : Fn(std::forward(Fn)) {} + void send(orc_rt_SessionRef Session, + WrapperFunctionBuffer ResultBytes) override { + Fn(Session, std::move(ResultBytes)); + } + + private: + std::decay_t Fn; + }; + + template + static std::unique_ptr + makeDirectResultSender(ImplFn &&Fn) { + return std::make_unique>( + std::forward(Fn)); + } + +public: + DirectCaller(orc_rt_SessionRef Session, orc_rt_WrapperFunction Fn) + : Session(Session), Fn(Fn) {} + + template + void operator()(HandleResultFn &&HandleResult, + WrapperFunctionBuffer ArgBytes) { + auto DR = + makeDirectResultSender(std::forward(HandleResult)); + Fn(Session, reinterpret_cast(DR.release()), + DirectResultSender::send, ArgBytes.release()); + } + +private: + orc_rt_SessionRef Session; + orc_rt_WrapperFunction Fn; +}; + +static void void_noop_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, + orc_rt_WrapperFunctionReturn Return, + orc_rt_WrapperFunctionBuffer ArgBytes) { + SPSWrapperFunction::handle( + Session, CallCtx, Return, ArgBytes, + [](move_only_function Return) { Return(); }); +} + +TEST(SPSWrapperFunctionUtilsTest, TestVoidNoop) { + bool Ran = false; + SPSWrapperFunction::call(DirectCaller(nullptr, void_noop_sps_wrapper), + [&](Error Err) { + cantFail(std::move(Err)); + Ran = true; + }); + EXPECT_TRUE(Ran); +} + +static void add_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx, + orc_rt_WrapperFunctionReturn Return, + orc_rt_WrapperFunctionBuffer ArgBytes) { + SPSWrapperFunction::handle( + Session, CallCtx, Return, ArgBytes, + [](move_only_function Return, int32_t X, int32_t Y) { + Return(X + Y); + }); +} + +TEST(SPSWrapperFunctionUtilsTest, TestAdd) { + int32_t Result = 0; + SPSWrapperFunction::call( + DirectCaller(nullptr, add_sps_wrapper), + [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); + EXPECT_EQ(Result, 42); +}