diff --git a/orc-rt/include/orc-rt/Session.h b/orc-rt/include/orc-rt/Session.h index 529aac6a2fadd..dffaf6ec3cbd9 100644 --- a/orc-rt/include/orc-rt/Session.h +++ b/orc-rt/include/orc-rt/Session.h @@ -16,10 +16,13 @@ #include "orc-rt/Error.h" #include "orc-rt/ResourceManager.h" #include "orc-rt/TaskDispatcher.h" +#include "orc-rt/WrapperFunction.h" #include "orc-rt/move_only_function.h" #include "orc-rt-c/CoreTypes.h" +#include "orc-rt-c/WrapperFunction.h" +#include #include #include #include @@ -27,12 +30,83 @@ namespace orc_rt { +class Session; + +inline orc_rt_SessionRef wrap(Session *S) noexcept { + return reinterpret_cast(S); +} + +inline Session *unwrap(orc_rt_SessionRef S) noexcept { + return reinterpret_cast(S); +} + /// Represents an ORC executor Session. class Session { public: using ErrorReporterFn = move_only_function; using OnShutdownCompleteFn = move_only_function; + using HandlerTag = void *; + using OnCallHandlerCompleteFn = + move_only_function; + + /// Provides access to the controller. + class ControllerAccess { + friend class Session; + + public: + virtual ~ControllerAccess(); + + protected: + using HandlerTag = Session::HandlerTag; + using OnCallHandlerCompleteFn = Session::OnCallHandlerCompleteFn; + + ControllerAccess(Session &S) : S(&S) {} + + /// Called by the Session to disconnect the session with the Controller. + /// + /// disconnect implementations must support concurrent entry on multiple + /// threads, and all calls must block until the disconnect operation is + /// complete. + /// + /// Once disconnect completes, implementations should make no further + /// calls to the Session, and should ignore any calls from the session + /// (implementations are free to ignore any calls from the Session after + /// disconnect is called). + virtual void disconnect() = 0; + + /// Report an error to the session. + void reportError(Error Err) { + assert(S && "Already disconnected"); + S->reportError(std::move(Err)); + } + + /// Call the handler in the controller associated with the given tag. + virtual void callController(OnCallHandlerCompleteFn OnComplete, + HandlerTag T, + WrapperFunctionBuffer ArgBytes) = 0; + + /// Send the result of the given wrapper function call to the controller. + virtual void sendWrapperResult(uint64_t CallId, + WrapperFunctionBuffer ResultBytes) = 0; + + /// Ask the Session to run the given wrapper function. + /// + /// Subclasses must not call this method after disconnect returns. + void handleWrapperCall(uint64_t CallId, orc_rt_WrapperFunction Fn, + WrapperFunctionBuffer ArgBytes) { + assert(S && "Already disconnected"); + S->handleWrapperCall(CallId, Fn, std::move(ArgBytes)); + } + + private: + void doDisconnect() { + disconnect(); + S = nullptr; + } + Session *S; + }; + /// Create a session object. The ReportError function will be called to /// report errors generated while serving JIT'd code, e.g. if a memory /// management request cannot be fulfilled. (Error's within the JIT'd @@ -69,9 +143,22 @@ class Session { void waitForShutdown(); /// Add a ResourceManager to the session. - void addResourceManager(std::unique_ptr RM) { - std::scoped_lock Lock(M); - ResourceMgrs.push_back(std::move(RM)); + void addResourceManager(std::unique_ptr RM); + + /// Set the ControllerAccess object. + void setController(std::shared_ptr CA); + + /// Disconnect the ControllerAccess object. + void detachFromController(); + + void callController(OnCallHandlerCompleteFn OnComplete, HandlerTag T, + WrapperFunctionBuffer ArgBytes) { + if (auto TmpCA = CA) + CA->callController(std::move(OnComplete), T, std::move(ArgBytes)); + else + OnComplete( + WrapperFunctionBuffer::createOutOfBandError("no controller attached") + .release()); } private: @@ -85,7 +172,23 @@ class Session { void shutdownNext(Error Err); void shutdownComplete(); + void handleWrapperCall(uint64_t CallId, orc_rt_WrapperFunction Fn, + WrapperFunctionBuffer ArgBytes) { + dispatch(makeGenericTask([=, ArgBytes = std::move(ArgBytes)]() mutable { + Fn(wrap(this), CallId, wrapperReturn, ArgBytes.release()); + })); + } + + void sendWrapperResult(uint64_t CallId, WrapperFunctionBuffer ResultBytes) { + if (auto TmpCA = CA) + TmpCA->sendWrapperResult(CallId, std::move(ResultBytes)); + } + + static void wrapperReturn(orc_rt_SessionRef S, uint64_t CallId, + orc_rt_WrapperFunctionBuffer ResultBytes); + std::unique_ptr Dispatcher; + std::shared_ptr CA; ErrorReporterFn ReportError; std::mutex M; @@ -93,13 +196,19 @@ class Session { std::unique_ptr SI; }; -inline orc_rt_SessionRef wrap(Session *S) noexcept { - return reinterpret_cast(S); -} +class CallViaSession { +public: + CallViaSession(Session &S, Session::HandlerTag T) : S(S), T(T) {} -inline Session *unwrap(orc_rt_SessionRef S) noexcept { - return reinterpret_cast(S); -} + void operator()(Session::OnCallHandlerCompleteFn &&HandleResult, + WrapperFunctionBuffer ArgBytes) { + S.callController(std::move(HandleResult), T, std::move(ArgBytes)); + } + +private: + Session &S; + Session::HandlerTag T; +}; } // namespace orc_rt diff --git a/orc-rt/lib/executor/Session.cpp b/orc-rt/lib/executor/Session.cpp index 3ee9bad60c5b9..1123a3128d844 100644 --- a/orc-rt/lib/executor/Session.cpp +++ b/orc-rt/lib/executor/Session.cpp @@ -14,9 +14,14 @@ namespace orc_rt { +Session::ControllerAccess::~ControllerAccess() = default; + Session::~Session() { waitForShutdown(); } void Session::shutdown(OnShutdownCompleteFn OnShutdownComplete) { + // Safe to call concurrently / redundantly. + detachFromController(); + { std::scoped_lock Lock(M); if (SI) { @@ -38,6 +43,27 @@ void Session::waitForShutdown() { SI->CompleteCV.wait(Lock, [&]() { return SI->Complete; }); } +void Session::addResourceManager(std::unique_ptr RM) { + std::scoped_lock Lock(M); + assert(!SI && "addResourceManager called after shutdown"); + ResourceMgrs.push_back(std::move(RM)); +} + +void Session::setController(std::shared_ptr CA) { + assert(CA && "Cannot attach null controller"); + std::scoped_lock Lock(M); + assert(!this->CA && "Cannot re-attach controller"); + assert(!SI && "Cannot attach controller after shutdown"); + this->CA = std::move(CA); +} + +void Session::detachFromController() { + if (auto TmpCA = CA) { + TmpCA->doDisconnect(); + CA = nullptr; + } +} + void Session::shutdownNext(Error Err) { if (Err) reportError(std::move(Err)); @@ -72,4 +98,9 @@ void Session::shutdownComplete() { SI->CompleteCV.notify_all(); } +void Session::wrapperReturn(orc_rt_SessionRef S, uint64_t CallId, + orc_rt_WrapperFunctionBuffer ResultBytes) { + unwrap(S)->sendWrapperResult(CallId, ResultBytes); +} + } // namespace orc_rt diff --git a/orc-rt/unittests/SessionTest.cpp b/orc-rt/unittests/SessionTest.cpp index d08326d269a82..2b822192a6de3 100644 --- a/orc-rt/unittests/SessionTest.cpp +++ b/orc-rt/unittests/SessionTest.cpp @@ -11,11 +11,13 @@ //===----------------------------------------------------------------------===// #include "orc-rt/Session.h" +#include "orc-rt/SPSWrapperFunction.h" #include "orc-rt/ThreadPoolTaskDispatcher.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include #include #include #include @@ -75,11 +77,185 @@ class EnqueueingDispatcher : public TaskDispatcher { OnShutdownRun(); } + /// Run up to NumTasks (arbitrarily many if NumTasks == std::nullopt) tasks + /// from the front of the queue, returning the number actually run. + static size_t + runTasksFromFront(std::deque> &Tasks, + std::optional NumTasks = std::nullopt) { + size_t NumRun = 0; + + while (!Tasks.empty() && (!NumTasks || NumRun != *NumTasks)) { + auto T = std::move(Tasks.front()); + Tasks.pop_front(); + T->run(); + ++NumRun; + } + + return NumRun; + } + private: std::deque> &Tasks; OnShutdownRunFn OnShutdownRun; }; +class MockControllerAccess : public Session::ControllerAccess { +public: + MockControllerAccess(Session &SS) : Session::ControllerAccess(SS), SS(SS) {} + + void disconnect() override { + std::unique_lock Lock(M); + Shutdown = true; + ShutdownCV.wait(Lock, [this]() { return Shutdown && Outstanding == 0; }); + } + + void callController(OnCallHandlerCompleteFn OnComplete, HandlerTag T, + WrapperFunctionBuffer ArgBytes) override { + // Simulate a call to the controller by dispatching a task to run the + // requested function. + size_t CId; + { + std::scoped_lock Lock(M); + if (Shutdown) + return; + CId = CallId++; + Pending[CId] = std::move(OnComplete); + ++Outstanding; + } + + SS.dispatch(makeGenericTask([this, CId, OnComplete = std::move(OnComplete), + T, ArgBytes = std::move(ArgBytes)]() mutable { + auto Fn = reinterpret_cast(T); + Fn(reinterpret_cast(this), CId, wfReturn, + ArgBytes.release()); + })); + + bool Notify = false; + { + std::scoped_lock Lock(M); + if (--Outstanding == 0 && Shutdown) + Notify = true; + } + if (Notify) + ShutdownCV.notify_all(); + } + + void sendWrapperResult(uint64_t CallId, + WrapperFunctionBuffer ResultBytes) override { + // Respond to a simulated call by the controller. + OnCallHandlerCompleteFn OnComplete; + { + std::scoped_lock Lock(M); + if (Shutdown) { + assert(Pending.empty() && "Shut down but results still pending?"); + return; + } + auto I = Pending.find(CallId); + assert(I != Pending.end()); + OnComplete = std::move(I->second); + Pending.erase(I); + ++Outstanding; + } + + SS.dispatch( + makeGenericTask([OnComplete = std::move(OnComplete), + ResultBytes = std::move(ResultBytes)]() mutable { + OnComplete(std::move(ResultBytes)); + })); + + bool Notify = false; + { + std::scoped_lock Lock(M); + if (--Outstanding == 0 && Shutdown) + Notify = true; + } + if (Notify) + ShutdownCV.notify_all(); + } + + void callFromController(OnCallHandlerCompleteFn OnComplete, + orc_rt_WrapperFunction Fn, + WrapperFunctionBuffer ArgBytes) { + size_t CId = 0; + bool BailOut = false; + { + std::scoped_lock Lock(M); + if (!Shutdown) { + CId = CallId++; + Pending[CId] = std::move(OnComplete); + ++Outstanding; + } else + BailOut = true; + } + if (BailOut) + return OnComplete(WrapperFunctionBuffer::createOutOfBandError( + "Controller disconnected")); + + handleWrapperCall(CId, Fn, std::move(ArgBytes)); + + bool Notify = false; + { + std::scoped_lock Lock(M); + if (--Outstanding == 0 && Shutdown) + Notify = true; + } + + if (Notify) + ShutdownCV.notify_all(); + } + + /// Simulate start of outstanding operation. + void incOutstanding() { + std::scoped_lock Lock(M); + ++Outstanding; + } + + /// Simulate end of outstanding operation. + void decOutstanding() { + bool Notify = false; + { + std::scoped_lock Lock(M); + if (--Outstanding == 0 && Shutdown) + Notify = true; + } + if (Notify) + ShutdownCV.notify_all(); + } + +private: + static void wfReturn(orc_rt_SessionRef S, uint64_t CallId, + orc_rt_WrapperFunctionBuffer ResultBytes) { + // Abuse "session" to refer to the ControllerAccess object. + // We can just re-use sendFunctionResult for this. + reinterpret_cast(S)->sendWrapperResult(CallId, + ResultBytes); + } + + Session &SS; + + std::mutex M; + bool Shutdown = false; + size_t Outstanding = 0; + size_t CallId = 0; + std::unordered_map Pending; + std::condition_variable ShutdownCV; +}; + +class CallViaMockControllerAccess { +public: + CallViaMockControllerAccess(MockControllerAccess &CA, + orc_rt_WrapperFunction Fn) + : CA(CA), Fn(Fn) {} + void operator()(Session::OnCallHandlerCompleteFn OnComplete, + WrapperFunctionBuffer ArgBytes) { + CA.callFromController(std::move(OnComplete), Fn, std::move(ArgBytes)); + } + +private: + MockControllerAccess &CA; + orc_rt_WrapperFunction Fn; +}; + // Non-overloaded version of cantFail: allows easy construction of // move_only_functionss. static void noErrors(Error Err) { cantFail(std::move(Err)); } @@ -185,3 +361,106 @@ TEST(SessionTest, ExpectedShutdownSequence) { EXPECT_TRUE(SessionShutdownComplete); } + +TEST(ControllerAccessTest, Basics) { + // Test that we can set the ControllerAccess implementation and still shut + // down as expected. + std::deque> Tasks; + Session S(std::make_unique(Tasks), noErrors); + auto CA = std::make_shared(S); + S.setController(CA); + + EnqueueingDispatcher::runTasksFromFront(Tasks); + + S.waitForShutdown(); +} + +static void add_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId, + orc_rt_WrapperFunctionReturn Return, + orc_rt_WrapperFunctionBuffer ArgBytes) { + SPSWrapperFunction::handle( + S, CallId, Return, ArgBytes, + [](move_only_function Return, int32_t X, int32_t Y) { + Return(X + Y); + }); +} + +TEST(ControllerAccessTest, ValidCallToController) { + // Simulate a call to a controller handler. + std::deque> Tasks; + Session S(std::make_unique(Tasks), noErrors); + auto CA = std::make_shared(S); + S.setController(CA); + + int32_t Result = 0; + SPSWrapperFunction::call( + CallViaSession(S, reinterpret_cast(add_sps_wrapper)), + [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); + + EnqueueingDispatcher::runTasksFromFront(Tasks); + + EXPECT_EQ(Result, 42); + + S.waitForShutdown(); +} + +TEST(ControllerAccessTest, CallToControllerBeforeAttach) { + // Expect calls to the controller prior to attaching to fail. + std::deque> Tasks; + Session S(std::make_unique(Tasks), noErrors); + + Error Err = Error::success(); + SPSWrapperFunction::call( + CallViaSession(S, reinterpret_cast(add_sps_wrapper)), + [&](Expected R) { + ErrorAsOutParameter _(Err); + Err = R.takeError(); + }, + 41, 1); + + EXPECT_EQ(toString(std::move(Err)), "no controller attached"); + + S.waitForShutdown(); +} + +TEST(ControllerAccessTest, CallToControllerAfterDetach) { + // Expect calls to the controller prior to attaching to fail. + std::deque> Tasks; + Session S(std::make_unique(Tasks), noErrors); + auto CA = std::make_shared(S); + S.setController(CA); + + S.detachFromController(); + + Error Err = Error::success(); + SPSWrapperFunction::call( + CallViaSession(S, reinterpret_cast(add_sps_wrapper)), + [&](Expected R) { + ErrorAsOutParameter _(Err); + Err = R.takeError(); + }, + 41, 1); + + EXPECT_EQ(toString(std::move(Err)), "no controller attached"); + + S.waitForShutdown(); +} + +TEST(ControllerAccessTest, CallFromController) { + // Simulate a call from the controller. + std::deque> Tasks; + Session S(std::make_unique(Tasks), noErrors); + auto CA = std::make_shared(S); + S.setController(CA); + + int32_t Result = 0; + SPSWrapperFunction::call( + CallViaMockControllerAccess(*CA, add_sps_wrapper), + [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); + + EnqueueingDispatcher::runTasksFromFront(Tasks); + + EXPECT_EQ(Result, 42); + + S.waitForShutdown(); +}