Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 118 additions & 9 deletions orc-rt/include/orc-rt/Session.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,97 @@
#include "orc-rt/Error.h"
#include "orc-rt/ResourceManager.h"
#include "orc-rt/TaskDispatcher.h"
#include "orc-rt/WrapperFunction.h"
#include "orc-rt/move_only_function.h"

#include "orc-rt-c/CoreTypes.h"
#include "orc-rt-c/WrapperFunction.h"

#include <cassert>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <vector>

namespace orc_rt {

class Session;

inline orc_rt_SessionRef wrap(Session *S) noexcept {
return reinterpret_cast<orc_rt_SessionRef>(S);
}

inline Session *unwrap(orc_rt_SessionRef S) noexcept {
return reinterpret_cast<Session *>(S);
}

/// Represents an ORC executor Session.
class Session {
public:
using ErrorReporterFn = move_only_function<void(Error)>;
using OnShutdownCompleteFn = move_only_function<void()>;

using HandlerTag = void *;
using OnCallHandlerCompleteFn =
move_only_function<void(WrapperFunctionBuffer)>;

/// Provides access to the controller.
class ControllerAccess {
friend class Session;

public:
virtual ~ControllerAccess();

protected:
using HandlerTag = Session::HandlerTag;
using OnCallHandlerCompleteFn = Session::OnCallHandlerCompleteFn;

ControllerAccess(Session &S) : S(&S) {}

/// Called by the Session to disconnect the session with the Controller.
///
/// disconnect implementations must support concurrent entry on multiple
/// threads, and all calls must block until the disconnect operation is
/// complete.
///
/// Once disconnect completes, implementations should make no further
/// calls to the Session, and should ignore any calls from the session
/// (implementations are free to ignore any calls from the Session after
/// disconnect is called).
virtual void disconnect() = 0;

/// Report an error to the session.
void reportError(Error Err) {
assert(S && "Already disconnected");
S->reportError(std::move(Err));
}

/// Call the handler in the controller associated with the given tag.
virtual void callController(OnCallHandlerCompleteFn OnComplete,
HandlerTag T,
WrapperFunctionBuffer ArgBytes) = 0;

/// Send the result of the given wrapper function call to the controller.
virtual void sendWrapperResult(uint64_t CallId,
WrapperFunctionBuffer ResultBytes) = 0;

/// Ask the Session to run the given wrapper function.
///
/// Subclasses must not call this method after disconnect returns.
void handleWrapperCall(uint64_t CallId, orc_rt_WrapperFunction Fn,
WrapperFunctionBuffer ArgBytes) {
assert(S && "Already disconnected");
S->handleWrapperCall(CallId, Fn, std::move(ArgBytes));
}

private:
void doDisconnect() {
disconnect();
S = nullptr;
}
Session *S;
};

/// Create a session object. The ReportError function will be called to
/// report errors generated while serving JIT'd code, e.g. if a memory
/// management request cannot be fulfilled. (Error's within the JIT'd
Expand Down Expand Up @@ -69,9 +143,22 @@ class Session {
void waitForShutdown();

/// Add a ResourceManager to the session.
void addResourceManager(std::unique_ptr<ResourceManager> RM) {
std::scoped_lock<std::mutex> Lock(M);
ResourceMgrs.push_back(std::move(RM));
void addResourceManager(std::unique_ptr<ResourceManager> RM);

/// Set the ControllerAccess object.
void setController(std::shared_ptr<ControllerAccess> CA);

/// Disconnect the ControllerAccess object.
void detachFromController();

void callController(OnCallHandlerCompleteFn OnComplete, HandlerTag T,
WrapperFunctionBuffer ArgBytes) {
if (auto TmpCA = CA)
CA->callController(std::move(OnComplete), T, std::move(ArgBytes));
else
OnComplete(
WrapperFunctionBuffer::createOutOfBandError("no controller attached")
.release());
}

private:
Expand All @@ -85,21 +172,43 @@ class Session {
void shutdownNext(Error Err);
void shutdownComplete();

void handleWrapperCall(uint64_t CallId, orc_rt_WrapperFunction Fn,
WrapperFunctionBuffer ArgBytes) {
dispatch(makeGenericTask([=, ArgBytes = std::move(ArgBytes)]() mutable {
Fn(wrap(this), CallId, wrapperReturn, ArgBytes.release());
}));
}

void sendWrapperResult(uint64_t CallId, WrapperFunctionBuffer ResultBytes) {
if (auto TmpCA = CA)
TmpCA->sendWrapperResult(CallId, std::move(ResultBytes));
}

static void wrapperReturn(orc_rt_SessionRef S, uint64_t CallId,
orc_rt_WrapperFunctionBuffer ResultBytes);

std::unique_ptr<TaskDispatcher> Dispatcher;
std::shared_ptr<ControllerAccess> CA;
ErrorReporterFn ReportError;

std::mutex M;
std::vector<std::unique_ptr<ResourceManager>> ResourceMgrs;
std::unique_ptr<ShutdownInfo> SI;
};

inline orc_rt_SessionRef wrap(Session *S) noexcept {
return reinterpret_cast<orc_rt_SessionRef>(S);
}
class CallViaSession {
public:
CallViaSession(Session &S, Session::HandlerTag T) : S(S), T(T) {}

inline Session *unwrap(orc_rt_SessionRef S) noexcept {
return reinterpret_cast<Session *>(S);
}
void operator()(Session::OnCallHandlerCompleteFn &&HandleResult,
WrapperFunctionBuffer ArgBytes) {
S.callController(std::move(HandleResult), T, std::move(ArgBytes));
}

private:
Session &S;
Session::HandlerTag T;
};

} // namespace orc_rt

Expand Down
31 changes: 31 additions & 0 deletions orc-rt/lib/executor/Session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@

namespace orc_rt {

Session::ControllerAccess::~ControllerAccess() = default;

Session::~Session() { waitForShutdown(); }

void Session::shutdown(OnShutdownCompleteFn OnShutdownComplete) {
// Safe to call concurrently / redundantly.
detachFromController();

{
std::scoped_lock<std::mutex> Lock(M);
if (SI) {
Expand All @@ -38,6 +43,27 @@ void Session::waitForShutdown() {
SI->CompleteCV.wait(Lock, [&]() { return SI->Complete; });
}

void Session::addResourceManager(std::unique_ptr<ResourceManager> RM) {
std::scoped_lock<std::mutex> Lock(M);
assert(!SI && "addResourceManager called after shutdown");
ResourceMgrs.push_back(std::move(RM));
}

void Session::setController(std::shared_ptr<ControllerAccess> CA) {
assert(CA && "Cannot attach null controller");
std::scoped_lock<std::mutex> Lock(M);
assert(!this->CA && "Cannot re-attach controller");
assert(!SI && "Cannot attach controller after shutdown");
this->CA = std::move(CA);
}

void Session::detachFromController() {
if (auto TmpCA = CA) {
TmpCA->doDisconnect();
CA = nullptr;
}
}

void Session::shutdownNext(Error Err) {
if (Err)
reportError(std::move(Err));
Expand Down Expand Up @@ -72,4 +98,9 @@ void Session::shutdownComplete() {
SI->CompleteCV.notify_all();
}

void Session::wrapperReturn(orc_rt_SessionRef S, uint64_t CallId,
orc_rt_WrapperFunctionBuffer ResultBytes) {
unwrap(S)->sendWrapperResult(CallId, ResultBytes);
}

} // namespace orc_rt
Loading
Loading