Expand Up
@@ -14,6 +14,8 @@
#include < memory>
#include < mutex>
#include < unordered_map>
#include < variant>
#include < vector>
using namespace __llvm_libc ;
Expand All
@@ -22,81 +24,51 @@ static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
static_assert (RPC_MAXIMUM_PORT_COUNT == rpc::DEFAULT_PORT_COUNT,
" Incorrect maximum port count" );
struct Device {
rpc::Server server;
std::unordered_map<rpc_opcode_t , rpc_opcode_callback_ty> callbacks;
std::unordered_map<rpc_opcode_t , void *> callback_data;
};
// A struct containing all the runtime state required to run the RPC server.
struct State {
State (uint32_t num_devices)
: num_devices(num_devices),
devices (std::unique_ptr<Device[]>(new Device[num_devices])),
reference_count(0u ) {}
uint32_t num_devices;
std::unique_ptr<Device[]> devices;
std::atomic_uint32_t reference_count;
};
static std::mutex startup_mutex;
static State *state;
rpc_status_t rpc_init (uint32_t num_devices) {
std::scoped_lock<decltype (startup_mutex)> lock (startup_mutex);
if (!state)
state = new State (num_devices);
if (state->reference_count == std::numeric_limits<uint32_t >::max ())
return RPC_STATUS_ERROR;
state->reference_count ++;
return RPC_STATUS_SUCCESS;
}
rpc_status_t rpc_shutdown (void ) {
if (state->reference_count -- == 1 )
delete state;
return RPC_STATUS_SUCCESS;
}
rpc_status_t rpc_server_init (uint32_t device_id, uint64_t num_ports,
uint32_t lane_size, rpc_alloc_ty alloc,
void *data) {
if (device_id >= state->num_devices )
return RPC_STATUS_OUT_OF_RANGE;
uint64_t buffer_size =
__llvm_libc::rpc::Server::allocation_size (num_ports, lane_size);
void *buffer = alloc (buffer_size, data);
if (!buffer)
return RPC_STATUS_ERROR;
state->devices [device_id].server .reset (num_ports, lane_size, buffer);
// The client needs to support different lane sizes for the SIMT model. Because
// of this we need to select between the possible sizes that the client can use.
struct Server {
template <uint32_t lane_size>
Server (std::unique_ptr<rpc::Server<lane_size>> &&server)
: server(std::move(server)) {}
return RPC_STATUS_SUCCESS;
}
rpc_status_t rpc_server_shutdown (uint32_t device_id, rpc_free_ty dealloc,
void *data) {
if (device_id >= state->num_devices )
return RPC_STATUS_OUT_OF_RANGE;
void reset (uint64_t port_count, void *buffer) {
std::visit ([&](auto &server) { server->reset (port_count, buffer); },
server);
}
dealloc (rpc_get_buffer (device_id), data);
uint64_t allocation_size (uint64_t port_count) {
uint64_t ret = 0 ;
std::visit ([&](auto &server) { ret = server->allocation_size (port_count); },
server);
return ret;
}
return RPC_STATUS_SUCCESS;
}
void *get_buffer_start () const {
void *ret = nullptr ;
std::visit ([&](auto &server) { ret = server->get_buffer_start (); }, server);
return ret;
}
rpc_status_t rpc_handle_server (uint32_t device_id) {
if (device_id >= state->num_devices )
return RPC_STATUS_OUT_OF_RANGE;
rpc_status_t handle_server (
std::unordered_map<rpc_opcode_t , rpc_opcode_callback_ty> &callbacks,
std::unordered_map<rpc_opcode_t , void *> &callback_data) {
rpc_status_t ret = RPC_STATUS_SUCCESS;
std::visit (
[&](auto &server) {
ret = handle_server (*server, callbacks, callback_data);
},
server);
return ret;
}
for (;;) {
auto port = state->devices [device_id].server .try_open ();
private:
template <uint32_t lane_size>
rpc_status_t handle_server (
rpc::Server<lane_size> &server,
std::unordered_map<rpc_opcode_t , rpc_opcode_callback_ty> &callbacks,
std::unordered_map<rpc_opcode_t , void *> &callback_data) {
auto port = server.try_open ();
if (!port)
return RPC_STATUS_SUCCESS;
Expand Down
Expand Up
@@ -175,21 +147,133 @@ rpc_status_t rpc_handle_server(uint32_t device_id) {
break ;
}
default : {
auto handler = state-> devices [device_id]. callbacks . find (
static_cast <rpc_opcode_t >(port->get_opcode ()));
auto handler =
callbacks. find ( static_cast <rpc_opcode_t >(port->get_opcode ()));
// We error out on an unhandled opcode.
if (handler == state-> devices [device_id]. callbacks .end ())
if (handler == callbacks.end ())
return RPC_STATUS_UNHANDLED_OPCODE;
// Invoke the registered callback with a reference to the port.
void *data = state-> devices [device_id]. callback_data . at (
static_cast <rpc_opcode_t >(port->get_opcode ()));
rpc_port_t port_ref{reinterpret_cast <uint64_t >(&*port)};
void *data =
callback_data. at ( static_cast <rpc_opcode_t >(port->get_opcode ()));
rpc_port_t port_ref{reinterpret_cast <uint64_t >(&*port), lane_size };
(handler->second )(port_ref, data);
}
}
port->close ();
return RPC_STATUS_CONTINUE;
}
std::variant<std::unique_ptr<rpc::Server<1 >>,
std::unique_ptr<rpc::Server<32 >>,
std::unique_ptr<rpc::Server<64 >>>
server;
};
struct Device {
template <typename T>
Device (std::unique_ptr<T> &&server) : server(std::move(server)) {}
Server server;
std::unordered_map<rpc_opcode_t , rpc_opcode_callback_ty> callbacks;
std::unordered_map<rpc_opcode_t , void *> callback_data;
};
// A struct containing all the runtime state required to run the RPC server.
struct State {
State (uint32_t num_devices)
: num_devices(num_devices), devices(num_devices), reference_count(0u ) {}
uint32_t num_devices;
std::vector<std::unique_ptr<Device>> devices;
std::atomic_uint32_t reference_count;
};
static std::mutex startup_mutex;
static State *state;
rpc_status_t rpc_init (uint32_t num_devices) {
std::scoped_lock<decltype (startup_mutex)> lock (startup_mutex);
if (!state)
state = new State (num_devices);
if (state->reference_count == std::numeric_limits<uint32_t >::max ())
return RPC_STATUS_ERROR;
state->reference_count ++;
return RPC_STATUS_SUCCESS;
}
rpc_status_t rpc_shutdown (void ) {
if (state->reference_count -- == 1 )
delete state;
return RPC_STATUS_SUCCESS;
}
rpc_status_t rpc_server_init (uint32_t device_id, uint64_t num_ports,
uint32_t lane_size, rpc_alloc_ty alloc,
void *data) {
if (device_id >= state->num_devices )
return RPC_STATUS_OUT_OF_RANGE;
if (!state->devices [device_id]) {
switch (lane_size) {
case 1 :
state->devices [device_id] =
std::make_unique<Device>(std::make_unique<rpc::Server<1 >>());
break ;
case 32 :
state->devices [device_id] =
std::make_unique<Device>(std::make_unique<rpc::Server<32 >>());
break ;
case 64 :
state->devices [device_id] =
std::make_unique<Device>(std::make_unique<rpc::Server<64 >>());
break ;
default :
return RPC_STATUS_INVALID_LANE_SIZE;
}
}
uint64_t size = state->devices [device_id]->server .allocation_size (num_ports);
void *buffer = alloc (size, data);
if (!buffer)
return RPC_STATUS_ERROR;
state->devices [device_id]->server .reset (num_ports, buffer);
return RPC_STATUS_SUCCESS;
}
rpc_status_t rpc_server_shutdown (uint32_t device_id, rpc_free_ty dealloc,
void *data) {
if (device_id >= state->num_devices )
return RPC_STATUS_OUT_OF_RANGE;
if (!state->devices [device_id])
return RPC_STATUS_ERROR;
dealloc (rpc_get_buffer (device_id), data);
if (state->devices [device_id])
state->devices [device_id].release ();
return RPC_STATUS_SUCCESS;
}
rpc_status_t rpc_handle_server (uint32_t device_id) {
if (device_id >= state->num_devices )
return RPC_STATUS_OUT_OF_RANGE;
if (!state->devices [device_id])
return RPC_STATUS_ERROR;
for (;;) {
auto &device = *state->devices [device_id];
rpc_status_t status =
device.server .handle_server (device.callbacks , device.callback_data );
if (status != RPC_STATUS_CONTINUE)
return status;
}
}
Expand All
@@ -198,22 +282,41 @@ rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode,
void *data) {
if (device_id >= state->num_devices )
return RPC_STATUS_OUT_OF_RANGE;
if (!state->devices [device_id])
return RPC_STATUS_ERROR;
state->devices [device_id]. callbacks [opcode] = callback;
state->devices [device_id]. callback_data [opcode] = data;
state->devices [device_id]-> callbacks [opcode] = callback;
state->devices [device_id]-> callback_data [opcode] = data;
return RPC_STATUS_SUCCESS;
}
void *rpc_get_buffer (uint32_t device_id) {
if (device_id >= state->num_devices )
return nullptr ;
return state->devices [device_id].server .get_buffer_start ();
if (!state->devices [device_id])
return nullptr ;
return state->devices [device_id]->server .get_buffer_start ();
}
void rpc_recv_and_send (rpc_port_t ref, rpc_port_callback_ty callback,
void *data) {
rpc::Server::Port *port = reinterpret_cast <rpc::Server::Port *>(ref.handle );
port->recv_and_send ([=](rpc::Buffer *buffer) {
callback (reinterpret_cast <rpc_buffer_t *>(buffer), data);
});
if (ref.lane_size == 1 ) {
rpc::Server<1 >::Port *port =
reinterpret_cast <rpc::Server<1 >::Port *>(ref.handle );
port->recv_and_send ([=](rpc::Buffer *buffer) {
callback (reinterpret_cast <rpc_buffer_t *>(buffer), data);
});
} else if (ref.lane_size == 32 ) {
rpc::Server<32 >::Port *port =
reinterpret_cast <rpc::Server<32 >::Port *>(ref.handle );
port->recv_and_send ([=](rpc::Buffer *buffer) {
callback (reinterpret_cast <rpc_buffer_t *>(buffer), data);
});
} else if (ref.lane_size == 64 ) {
rpc::Server<64 >::Port *port =
reinterpret_cast <rpc::Server<64 >::Port *>(ref.handle );
port->recv_and_send ([=](rpc::Buffer *buffer) {
callback (reinterpret_cast <rpc_buffer_t *>(buffer), data);
});
}
}