From 34c4757a8db6d67ce1fd6f2011ed26ce983c7180 Mon Sep 17 00:00:00 2001 From: Lang Hames Date: Wed, 26 Nov 2025 10:51:52 +1100 Subject: [PATCH] [orc-rt] Add ControllerAccess interface. ControllerAccess provides an abstract interface for bidirectional RPC between the executor (running JIT'd code) and the controller (containing the llvm::orc::ExecutionSession). ControllerAccess implementations are expected to implement IPC / RPC using a concrete communication method (shared memory, pipes, sockets, native system IPC, etc). Calls from executor to controller are made via callController, with "handler tags" (addresses in the executor) specifying the target handler in the controller. A handler must be associated in the controller with the given tag for the call to succeed. This ensures that only registered entry points in the controller can be used, and avoids leaking controller addresses into the executor. Calls in both directions are to "wrapper functions" that take a buffer of bytes as input and return a buffer of bytes as output. In the ORC runtime these must be `orc_rt_WrapperFunction`s (see Session::handleWrapperCall). The interpretation of the byte buffers is up to the wrapper functions: the ORC runtime imposes no restrictions on how the bytes are to be interpreted. ControllerAccess objects may be detached from the Session prior to Session shutdown, in which case no further calls may be made in either direction, and any pending results (from calls made that haven't returned yet) should return errors. If the ControllerAccess class is still attached at Session shutdown time it will be detached as part of the shutdown process. The ControllerAccess::disconnect method must support concurrent entry on multiple threads, and all callers must block until they can guarantee that no further calls will be received or accepted. --- orc-rt/include/orc-rt/Session.h | 127 +++++++++++++- orc-rt/lib/executor/Session.cpp | 31 ++++ orc-rt/unittests/SessionTest.cpp | 279 +++++++++++++++++++++++++++++++ 3 files changed, 428 insertions(+), 9 deletions(-) diff --git a/orc-rt/include/orc-rt/Session.h b/orc-rt/include/orc-rt/Session.h index 529aac6a2fadd..dffaf6ec3cbd9 100644 --- a/orc-rt/include/orc-rt/Session.h +++ b/orc-rt/include/orc-rt/Session.h @@ -16,10 +16,13 @@ #include "orc-rt/Error.h" #include "orc-rt/ResourceManager.h" #include "orc-rt/TaskDispatcher.h" +#include "orc-rt/WrapperFunction.h" #include "orc-rt/move_only_function.h" #include "orc-rt-c/CoreTypes.h" +#include "orc-rt-c/WrapperFunction.h" +#include #include #include #include @@ -27,12 +30,83 @@ namespace orc_rt { +class Session; + +inline orc_rt_SessionRef wrap(Session *S) noexcept { + return reinterpret_cast(S); +} + +inline Session *unwrap(orc_rt_SessionRef S) noexcept { + return reinterpret_cast(S); +} + /// Represents an ORC executor Session. class Session { public: using ErrorReporterFn = move_only_function; using OnShutdownCompleteFn = move_only_function; + using HandlerTag = void *; + using OnCallHandlerCompleteFn = + move_only_function; + + /// Provides access to the controller. + class ControllerAccess { + friend class Session; + + public: + virtual ~ControllerAccess(); + + protected: + using HandlerTag = Session::HandlerTag; + using OnCallHandlerCompleteFn = Session::OnCallHandlerCompleteFn; + + ControllerAccess(Session &S) : S(&S) {} + + /// Called by the Session to disconnect the session with the Controller. + /// + /// disconnect implementations must support concurrent entry on multiple + /// threads, and all calls must block until the disconnect operation is + /// complete. + /// + /// Once disconnect completes, implementations should make no further + /// calls to the Session, and should ignore any calls from the session + /// (implementations are free to ignore any calls from the Session after + /// disconnect is called). + virtual void disconnect() = 0; + + /// Report an error to the session. + void reportError(Error Err) { + assert(S && "Already disconnected"); + S->reportError(std::move(Err)); + } + + /// Call the handler in the controller associated with the given tag. + virtual void callController(OnCallHandlerCompleteFn OnComplete, + HandlerTag T, + WrapperFunctionBuffer ArgBytes) = 0; + + /// Send the result of the given wrapper function call to the controller. + virtual void sendWrapperResult(uint64_t CallId, + WrapperFunctionBuffer ResultBytes) = 0; + + /// Ask the Session to run the given wrapper function. + /// + /// Subclasses must not call this method after disconnect returns. + void handleWrapperCall(uint64_t CallId, orc_rt_WrapperFunction Fn, + WrapperFunctionBuffer ArgBytes) { + assert(S && "Already disconnected"); + S->handleWrapperCall(CallId, Fn, std::move(ArgBytes)); + } + + private: + void doDisconnect() { + disconnect(); + S = nullptr; + } + Session *S; + }; + /// Create a session object. The ReportError function will be called to /// report errors generated while serving JIT'd code, e.g. if a memory /// management request cannot be fulfilled. (Error's within the JIT'd @@ -69,9 +143,22 @@ class Session { void waitForShutdown(); /// Add a ResourceManager to the session. - void addResourceManager(std::unique_ptr RM) { - std::scoped_lock Lock(M); - ResourceMgrs.push_back(std::move(RM)); + void addResourceManager(std::unique_ptr RM); + + /// Set the ControllerAccess object. + void setController(std::shared_ptr CA); + + /// Disconnect the ControllerAccess object. + void detachFromController(); + + void callController(OnCallHandlerCompleteFn OnComplete, HandlerTag T, + WrapperFunctionBuffer ArgBytes) { + if (auto TmpCA = CA) + CA->callController(std::move(OnComplete), T, std::move(ArgBytes)); + else + OnComplete( + WrapperFunctionBuffer::createOutOfBandError("no controller attached") + .release()); } private: @@ -85,7 +172,23 @@ class Session { void shutdownNext(Error Err); void shutdownComplete(); + void handleWrapperCall(uint64_t CallId, orc_rt_WrapperFunction Fn, + WrapperFunctionBuffer ArgBytes) { + dispatch(makeGenericTask([=, ArgBytes = std::move(ArgBytes)]() mutable { + Fn(wrap(this), CallId, wrapperReturn, ArgBytes.release()); + })); + } + + void sendWrapperResult(uint64_t CallId, WrapperFunctionBuffer ResultBytes) { + if (auto TmpCA = CA) + TmpCA->sendWrapperResult(CallId, std::move(ResultBytes)); + } + + static void wrapperReturn(orc_rt_SessionRef S, uint64_t CallId, + orc_rt_WrapperFunctionBuffer ResultBytes); + std::unique_ptr Dispatcher; + std::shared_ptr CA; ErrorReporterFn ReportError; std::mutex M; @@ -93,13 +196,19 @@ class Session { std::unique_ptr SI; }; -inline orc_rt_SessionRef wrap(Session *S) noexcept { - return reinterpret_cast(S); -} +class CallViaSession { +public: + CallViaSession(Session &S, Session::HandlerTag T) : S(S), T(T) {} -inline Session *unwrap(orc_rt_SessionRef S) noexcept { - return reinterpret_cast(S); -} + void operator()(Session::OnCallHandlerCompleteFn &&HandleResult, + WrapperFunctionBuffer ArgBytes) { + S.callController(std::move(HandleResult), T, std::move(ArgBytes)); + } + +private: + Session &S; + Session::HandlerTag T; +}; } // namespace orc_rt diff --git a/orc-rt/lib/executor/Session.cpp b/orc-rt/lib/executor/Session.cpp index 3ee9bad60c5b9..1123a3128d844 100644 --- a/orc-rt/lib/executor/Session.cpp +++ b/orc-rt/lib/executor/Session.cpp @@ -14,9 +14,14 @@ namespace orc_rt { +Session::ControllerAccess::~ControllerAccess() = default; + Session::~Session() { waitForShutdown(); } void Session::shutdown(OnShutdownCompleteFn OnShutdownComplete) { + // Safe to call concurrently / redundantly. + detachFromController(); + { std::scoped_lock Lock(M); if (SI) { @@ -38,6 +43,27 @@ void Session::waitForShutdown() { SI->CompleteCV.wait(Lock, [&]() { return SI->Complete; }); } +void Session::addResourceManager(std::unique_ptr RM) { + std::scoped_lock Lock(M); + assert(!SI && "addResourceManager called after shutdown"); + ResourceMgrs.push_back(std::move(RM)); +} + +void Session::setController(std::shared_ptr CA) { + assert(CA && "Cannot attach null controller"); + std::scoped_lock Lock(M); + assert(!this->CA && "Cannot re-attach controller"); + assert(!SI && "Cannot attach controller after shutdown"); + this->CA = std::move(CA); +} + +void Session::detachFromController() { + if (auto TmpCA = CA) { + TmpCA->doDisconnect(); + CA = nullptr; + } +} + void Session::shutdownNext(Error Err) { if (Err) reportError(std::move(Err)); @@ -72,4 +98,9 @@ void Session::shutdownComplete() { SI->CompleteCV.notify_all(); } +void Session::wrapperReturn(orc_rt_SessionRef S, uint64_t CallId, + orc_rt_WrapperFunctionBuffer ResultBytes) { + unwrap(S)->sendWrapperResult(CallId, ResultBytes); +} + } // namespace orc_rt diff --git a/orc-rt/unittests/SessionTest.cpp b/orc-rt/unittests/SessionTest.cpp index d08326d269a82..2b822192a6de3 100644 --- a/orc-rt/unittests/SessionTest.cpp +++ b/orc-rt/unittests/SessionTest.cpp @@ -11,11 +11,13 @@ //===----------------------------------------------------------------------===// #include "orc-rt/Session.h" +#include "orc-rt/SPSWrapperFunction.h" #include "orc-rt/ThreadPoolTaskDispatcher.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include #include #include #include @@ -75,11 +77,185 @@ class EnqueueingDispatcher : public TaskDispatcher { OnShutdownRun(); } + /// Run up to NumTasks (arbitrarily many if NumTasks == std::nullopt) tasks + /// from the front of the queue, returning the number actually run. + static size_t + runTasksFromFront(std::deque> &Tasks, + std::optional NumTasks = std::nullopt) { + size_t NumRun = 0; + + while (!Tasks.empty() && (!NumTasks || NumRun != *NumTasks)) { + auto T = std::move(Tasks.front()); + Tasks.pop_front(); + T->run(); + ++NumRun; + } + + return NumRun; + } + private: std::deque> &Tasks; OnShutdownRunFn OnShutdownRun; }; +class MockControllerAccess : public Session::ControllerAccess { +public: + MockControllerAccess(Session &SS) : Session::ControllerAccess(SS), SS(SS) {} + + void disconnect() override { + std::unique_lock Lock(M); + Shutdown = true; + ShutdownCV.wait(Lock, [this]() { return Shutdown && Outstanding == 0; }); + } + + void callController(OnCallHandlerCompleteFn OnComplete, HandlerTag T, + WrapperFunctionBuffer ArgBytes) override { + // Simulate a call to the controller by dispatching a task to run the + // requested function. + size_t CId; + { + std::scoped_lock Lock(M); + if (Shutdown) + return; + CId = CallId++; + Pending[CId] = std::move(OnComplete); + ++Outstanding; + } + + SS.dispatch(makeGenericTask([this, CId, OnComplete = std::move(OnComplete), + T, ArgBytes = std::move(ArgBytes)]() mutable { + auto Fn = reinterpret_cast(T); + Fn(reinterpret_cast(this), CId, wfReturn, + ArgBytes.release()); + })); + + bool Notify = false; + { + std::scoped_lock Lock(M); + if (--Outstanding == 0 && Shutdown) + Notify = true; + } + if (Notify) + ShutdownCV.notify_all(); + } + + void sendWrapperResult(uint64_t CallId, + WrapperFunctionBuffer ResultBytes) override { + // Respond to a simulated call by the controller. + OnCallHandlerCompleteFn OnComplete; + { + std::scoped_lock Lock(M); + if (Shutdown) { + assert(Pending.empty() && "Shut down but results still pending?"); + return; + } + auto I = Pending.find(CallId); + assert(I != Pending.end()); + OnComplete = std::move(I->second); + Pending.erase(I); + ++Outstanding; + } + + SS.dispatch( + makeGenericTask([OnComplete = std::move(OnComplete), + ResultBytes = std::move(ResultBytes)]() mutable { + OnComplete(std::move(ResultBytes)); + })); + + bool Notify = false; + { + std::scoped_lock Lock(M); + if (--Outstanding == 0 && Shutdown) + Notify = true; + } + if (Notify) + ShutdownCV.notify_all(); + } + + void callFromController(OnCallHandlerCompleteFn OnComplete, + orc_rt_WrapperFunction Fn, + WrapperFunctionBuffer ArgBytes) { + size_t CId = 0; + bool BailOut = false; + { + std::scoped_lock Lock(M); + if (!Shutdown) { + CId = CallId++; + Pending[CId] = std::move(OnComplete); + ++Outstanding; + } else + BailOut = true; + } + if (BailOut) + return OnComplete(WrapperFunctionBuffer::createOutOfBandError( + "Controller disconnected")); + + handleWrapperCall(CId, Fn, std::move(ArgBytes)); + + bool Notify = false; + { + std::scoped_lock Lock(M); + if (--Outstanding == 0 && Shutdown) + Notify = true; + } + + if (Notify) + ShutdownCV.notify_all(); + } + + /// Simulate start of outstanding operation. + void incOutstanding() { + std::scoped_lock Lock(M); + ++Outstanding; + } + + /// Simulate end of outstanding operation. + void decOutstanding() { + bool Notify = false; + { + std::scoped_lock Lock(M); + if (--Outstanding == 0 && Shutdown) + Notify = true; + } + if (Notify) + ShutdownCV.notify_all(); + } + +private: + static void wfReturn(orc_rt_SessionRef S, uint64_t CallId, + orc_rt_WrapperFunctionBuffer ResultBytes) { + // Abuse "session" to refer to the ControllerAccess object. + // We can just re-use sendFunctionResult for this. + reinterpret_cast(S)->sendWrapperResult(CallId, + ResultBytes); + } + + Session &SS; + + std::mutex M; + bool Shutdown = false; + size_t Outstanding = 0; + size_t CallId = 0; + std::unordered_map Pending; + std::condition_variable ShutdownCV; +}; + +class CallViaMockControllerAccess { +public: + CallViaMockControllerAccess(MockControllerAccess &CA, + orc_rt_WrapperFunction Fn) + : CA(CA), Fn(Fn) {} + void operator()(Session::OnCallHandlerCompleteFn OnComplete, + WrapperFunctionBuffer ArgBytes) { + CA.callFromController(std::move(OnComplete), Fn, std::move(ArgBytes)); + } + +private: + MockControllerAccess &CA; + orc_rt_WrapperFunction Fn; +}; + // Non-overloaded version of cantFail: allows easy construction of // move_only_functionss. static void noErrors(Error Err) { cantFail(std::move(Err)); } @@ -185,3 +361,106 @@ TEST(SessionTest, ExpectedShutdownSequence) { EXPECT_TRUE(SessionShutdownComplete); } + +TEST(ControllerAccessTest, Basics) { + // Test that we can set the ControllerAccess implementation and still shut + // down as expected. + std::deque> Tasks; + Session S(std::make_unique(Tasks), noErrors); + auto CA = std::make_shared(S); + S.setController(CA); + + EnqueueingDispatcher::runTasksFromFront(Tasks); + + S.waitForShutdown(); +} + +static void add_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId, + orc_rt_WrapperFunctionReturn Return, + orc_rt_WrapperFunctionBuffer ArgBytes) { + SPSWrapperFunction::handle( + S, CallId, Return, ArgBytes, + [](move_only_function Return, int32_t X, int32_t Y) { + Return(X + Y); + }); +} + +TEST(ControllerAccessTest, ValidCallToController) { + // Simulate a call to a controller handler. + std::deque> Tasks; + Session S(std::make_unique(Tasks), noErrors); + auto CA = std::make_shared(S); + S.setController(CA); + + int32_t Result = 0; + SPSWrapperFunction::call( + CallViaSession(S, reinterpret_cast(add_sps_wrapper)), + [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); + + EnqueueingDispatcher::runTasksFromFront(Tasks); + + EXPECT_EQ(Result, 42); + + S.waitForShutdown(); +} + +TEST(ControllerAccessTest, CallToControllerBeforeAttach) { + // Expect calls to the controller prior to attaching to fail. + std::deque> Tasks; + Session S(std::make_unique(Tasks), noErrors); + + Error Err = Error::success(); + SPSWrapperFunction::call( + CallViaSession(S, reinterpret_cast(add_sps_wrapper)), + [&](Expected R) { + ErrorAsOutParameter _(Err); + Err = R.takeError(); + }, + 41, 1); + + EXPECT_EQ(toString(std::move(Err)), "no controller attached"); + + S.waitForShutdown(); +} + +TEST(ControllerAccessTest, CallToControllerAfterDetach) { + // Expect calls to the controller prior to attaching to fail. + std::deque> Tasks; + Session S(std::make_unique(Tasks), noErrors); + auto CA = std::make_shared(S); + S.setController(CA); + + S.detachFromController(); + + Error Err = Error::success(); + SPSWrapperFunction::call( + CallViaSession(S, reinterpret_cast(add_sps_wrapper)), + [&](Expected R) { + ErrorAsOutParameter _(Err); + Err = R.takeError(); + }, + 41, 1); + + EXPECT_EQ(toString(std::move(Err)), "no controller attached"); + + S.waitForShutdown(); +} + +TEST(ControllerAccessTest, CallFromController) { + // Simulate a call from the controller. + std::deque> Tasks; + Session S(std::make_unique(Tasks), noErrors); + auto CA = std::make_shared(S); + S.setController(CA); + + int32_t Result = 0; + SPSWrapperFunction::call( + CallViaMockControllerAccess(*CA, add_sps_wrapper), + [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); + + EnqueueingDispatcher::runTasksFromFront(Tasks); + + EXPECT_EQ(Result, 42); + + S.waitForShutdown(); +}