68 changes: 33 additions & 35 deletions libc/utils/gpu/server/rpc_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,13 @@ struct Server {
Server(std::unique_ptr<rpc::Server<lane_size>> &&server)
: server(std::move(server)) {}

void reset(uint64_t port_count, void *buffer) {
std::visit([&](auto &server) { server->reset(port_count, buffer); },
server);
}

uint64_t allocation_size(uint64_t port_count) {
uint64_t ret = 0;
std::visit([&](auto &server) { ret = server->allocation_size(port_count); },
server);
return ret;
}

void *get_buffer_start() const {
void *ret = nullptr;
std::visit([&](auto &server) { ret = server->get_buffer_start(); }, server);
return ret;
}

rpc_status_t handle_server(
std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
std::unordered_map<rpc_opcode_t, void *> &callback_data) {
Expand Down Expand Up @@ -214,7 +203,9 @@ struct Server {

struct Device {
template <typename T>
Device(std::unique_ptr<T> &&server) : server(std::move(server)) {}
Device(uint32_t num_ports, void *buffer, std::unique_ptr<T> &&server)
: buffer(buffer), server(std::move(server)), client(num_ports, buffer) {}
void *buffer;
Server server;
rpc::Client client;
std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> callbacks;
Expand Down Expand Up @@ -254,6 +245,24 @@ rpc_status_t rpc_shutdown(void) {
return RPC_STATUS_SUCCESS;
}

template <uint32_t lane_size>
rpc_status_t server_init_impl(uint32_t device_id, uint64_t num_ports,
rpc_alloc_ty alloc, void *data) {
uint64_t size = rpc::Server<lane_size>::allocation_size(num_ports);
void *buffer = alloc(size, data);

if (!buffer)
return RPC_STATUS_ERROR;

state->devices[device_id] = std::make_unique<Device>(
num_ports, buffer,
std::make_unique<rpc::Server<lane_size>>(num_ports, buffer));
if (!state->devices[device_id])
return RPC_STATUS_ERROR;

return RPC_STATUS_SUCCESS;
}

rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
uint32_t lane_size, rpc_alloc_ty alloc,
void *data) {
Expand All @@ -265,31 +274,26 @@ rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
if (!state->devices[device_id]) {
switch (lane_size) {
case 1:
state->devices[device_id] =
std::make_unique<Device>(std::make_unique<rpc::Server<1>>());
if (rpc_status_t err =
server_init_impl<1>(device_id, num_ports, alloc, data))
return err;
break;
case 32:
state->devices[device_id] =
std::make_unique<Device>(std::make_unique<rpc::Server<32>>());
case 32: {
if (rpc_status_t err =
server_init_impl<32>(device_id, num_ports, alloc, data))
return err;
break;
}
case 64:
state->devices[device_id] =
std::make_unique<Device>(std::make_unique<rpc::Server<64>>());
if (rpc_status_t err =
server_init_impl<64>(device_id, num_ports, alloc, data))
return err;
break;
default:
return RPC_STATUS_INVALID_LANE_SIZE;
}
}

uint64_t size = state->devices[device_id]->server.allocation_size(num_ports);
void *buffer = alloc(size, data);

if (!buffer)
return RPC_STATUS_ERROR;

state->devices[device_id]->server.reset(num_ports, buffer);
state->devices[device_id]->client.reset(num_ports, buffer);

return RPC_STATUS_SUCCESS;
}

Expand All @@ -302,7 +306,7 @@ rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
if (!state->devices[device_id])
return RPC_STATUS_ERROR;

dealloc(rpc_get_buffer(device_id), data);
dealloc(state->devices[device_id]->buffer, data);
if (state->devices[device_id])
state->devices[device_id].release();

Expand Down Expand Up @@ -341,12 +345,6 @@ rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode,
return RPC_STATUS_SUCCESS;
}

void *rpc_get_buffer(uint32_t device_id) {
if (!state || device_id >= state->num_devices || !state->devices[device_id])
return nullptr;
return state->devices[device_id]->server.get_buffer_start();
}

const void *rpc_get_client_buffer(uint32_t device_id) {
if (!state || device_id >= state->num_devices || !state->devices[device_id])
return nullptr;
Expand Down
5 changes: 1 addition & 4 deletions libc/utils/gpu/server/rpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,8 @@ rpc_status_t rpc_handle_server(uint32_t device_id);
rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode,
rpc_opcode_callback_ty callback, void *data);

/// Obtain a pointer to the memory buffer used to run the RPC client and server.
void *rpc_get_buffer(uint32_t device_id);

/// Obtain a pointer to a local client buffer that can be copied directly to the
/// other process.
/// other process using the address stored at the rpc client symbol name.
const void *rpc_get_client_buffer(uint32_t device_id);

/// Returns the size of the client in bytes to be used for a memory copy.
Expand Down