159 changes: 72 additions & 87 deletions libc/src/__support/RPC/rpc.h

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions libc/src/__support/RPC/rpc_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "rpc_client.h"
#include "rpc.h"

namespace __llvm_libc {
Expand Down
3 changes: 1 addition & 2 deletions libc/src/gpu/rpc_reset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ namespace __llvm_libc {
// shared buffer.
LLVM_LIBC_FUNCTION(void, rpc_reset,
                   (unsigned int num_ports, void *rpc_shared_buffer)) {
  // Re-point the global RPC client at `rpc_shared_buffer` with `num_ports`
  // ports. Only the two-argument form is issued; the older call that also
  // forwarded the lane size duplicated the reset and is dropped.
  __llvm_libc::rpc::client.reset(num_ports, rpc_shared_buffer);
}

} // namespace __llvm_libc
1 change: 0 additions & 1 deletion libc/startup/gpu/amdgpu/start.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ _begin(int argc, char **argv, char **env, void *rpc_shared_buffer) {
// We need to set up the RPC client first in case any of the constructors
// require it.
__llvm_libc::rpc::client.reset(__llvm_libc::rpc::DEFAULT_PORT_COUNT,
__llvm_libc::gpu::get_lane_size(),
rpc_shared_buffer);

// We want the fini array callbacks to be run after other atexit
Expand Down
1 change: 0 additions & 1 deletion libc/startup/gpu/nvptx/start.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ _begin(int argc, char **argv, char **env, void *rpc_shared_buffer) {
// We need to set up the RPC client first in case any of the constructors
// require it.
__llvm_libc::rpc::client.reset(__llvm_libc::rpc::DEFAULT_PORT_COUNT,
__llvm_libc::gpu::get_lane_size(),
rpc_shared_buffer);

// We want the fini array callbacks to be run after other atexit
Expand Down
267 changes: 185 additions & 82 deletions libc/utils/gpu/server/Server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <memory>
#include <mutex>
#include <unordered_map>
#include <variant>
#include <vector>

using namespace __llvm_libc;

Expand All @@ -22,81 +24,51 @@ static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),

static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::DEFAULT_PORT_COUNT,
"Incorrect maximum port count");
// Per-device bookkeeping: the RPC server instance plus the user-registered
// handlers that are consulted for opcodes the server does not handle itself.
struct Device {
// The RPC server that services this device.
rpc::Server server;
// Handler to invoke for a given opcode.
std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> callbacks;
// Opaque pointer passed as the second argument to the handler registered for
// the same opcode.
std::unordered_map<rpc_opcode_t, void *> callback_data;
};

// A struct containing all the runtime state required to run the RPC server.
struct State {
// Allocates one default-constructed Device per device and starts with no
// outstanding references.
State(uint32_t num_devices)
: num_devices(num_devices),
devices(std::unique_ptr<Device[]>(new Device[num_devices])),
reference_count(0u) {}
// Number of entries in `devices`; device ids are bounds-checked against it.
uint32_t num_devices;
// Array of per-device state, indexed by device id.
std::unique_ptr<Device[]> devices;
// Incremented by rpc_init and decremented by rpc_shutdown.
std::atomic_uint32_t reference_count;
};

static std::mutex startup_mutex;

static State *state;

// Creates the global server state on first use and takes one reference to it.
//
// Returns RPC_STATUS_ERROR when the reference count is already at its maximum
// value, RPC_STATUS_SUCCESS otherwise. `num_devices` is only consulted when
// the state does not exist yet.
rpc_status_t rpc_init(uint32_t num_devices) {
  std::scoped_lock<decltype(startup_mutex)> guard(startup_mutex);

  if (state == nullptr)
    state = new State(num_devices);

  // Refuse to wrap the counter around.
  const uint32_t limit = std::numeric_limits<uint32_t>::max();
  if (state->reference_count.load() == limit)
    return RPC_STATUS_ERROR;

  state->reference_count.fetch_add(1);
  return RPC_STATUS_SUCCESS;
}

// Drops one reference to the global server state and destroys it when the
// last reference goes away. Always returns RPC_STATUS_SUCCESS.
//
// NOTE(review): rpc_init guards `state` with startup_mutex while this function
// does not — confirm that callers serialize init and shutdown.
rpc_status_t rpc_shutdown(void) {
  // Nothing to tear down if the state was never created or is already gone.
  if (!state)
    return RPC_STATUS_SUCCESS;

  // Post-decrement yields the previous count, so 1 means this was the last
  // outstanding reference.
  if (state->reference_count-- == 1) {
    delete state;
    // Clear the global so a later rpc_init allocates fresh state instead of
    // touching the freed object.
    state = nullptr;
  }

  return RPC_STATUS_SUCCESS;
}

rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
uint32_t lane_size, rpc_alloc_ty alloc,
void *data) {
if (device_id >= state->num_devices)
return RPC_STATUS_OUT_OF_RANGE;

uint64_t buffer_size =
__llvm_libc::rpc::Server::allocation_size(num_ports, lane_size);
void *buffer = alloc(buffer_size, data);

if (!buffer)
return RPC_STATUS_ERROR;

state->devices[device_id].server.reset(num_ports, lane_size, buffer);
// The client needs to support different lane sizes for the SIMT model. Because
// of this we need to select between the possible sizes that the client can use.
struct Server {
// Takes ownership of a concrete rpc::Server instantiation. The lane size is
// deduced from the argument and determines which variant alternative is held.
template <uint32_t lane_size>
Server(std::unique_ptr<rpc::Server<lane_size>> &&impl)
    : server(std::move(impl)) {}

return RPC_STATUS_SUCCESS;
}

rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
void *data) {
if (device_id >= state->num_devices)
return RPC_STATUS_OUT_OF_RANGE;
// Forwards reset(port_count, buffer) to whichever server instantiation is
// currently held.
void reset(uint64_t port_count, void *buffer) {
  auto forward_reset = [&](auto &impl) { impl->reset(port_count, buffer); };
  std::visit(forward_reset, server);
}

dealloc(rpc_get_buffer(device_id), data);
// Asks the held server instantiation how large a buffer it needs for
// `port_count` ports.
uint64_t allocation_size(uint64_t port_count) {
  return std::visit(
      [&](auto &impl) -> uint64_t {
        return impl->allocation_size(port_count);
      },
      server);
}

return RPC_STATUS_SUCCESS;
}
// Returns the buffer start reported by the held server instantiation.
void *get_buffer_start() const {
  return std::visit(
      [](auto &impl) -> void * { return impl->get_buffer_start(); }, server);
}

rpc_status_t rpc_handle_server(uint32_t device_id) {
if (device_id >= state->num_devices)
return RPC_STATUS_OUT_OF_RANGE;
// Dispatches to the lane-size specific handle_server overload for the server
// instantiation that is currently held and returns its status unchanged.
rpc_status_t handle_server(
    std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
    std::unordered_map<rpc_opcode_t, void *> &callback_data) {
  return std::visit(
      [&](auto &impl) -> rpc_status_t {
        return handle_server(*impl, callbacks, callback_data);
      },
      server);
}

for (;;) {
auto port = state->devices[device_id].server.try_open();
private:
template <uint32_t lane_size>
rpc_status_t handle_server(
rpc::Server<lane_size> &server,
std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
std::unordered_map<rpc_opcode_t, void *> &callback_data) {
auto port = server.try_open();
if (!port)
return RPC_STATUS_SUCCESS;

Expand Down Expand Up @@ -175,21 +147,133 @@ rpc_status_t rpc_handle_server(uint32_t device_id) {
break;
}
default: {
auto handler = state->devices[device_id].callbacks.find(
static_cast<rpc_opcode_t>(port->get_opcode()));
auto handler =
callbacks.find(static_cast<rpc_opcode_t>(port->get_opcode()));

// We error out on an unhandled opcode.
if (handler == state->devices[device_id].callbacks.end())
if (handler == callbacks.end())
return RPC_STATUS_UNHANDLED_OPCODE;

// Invoke the registered callback with a reference to the port.
void *data = state->devices[device_id].callback_data.at(
static_cast<rpc_opcode_t>(port->get_opcode()));
rpc_port_t port_ref{reinterpret_cast<uint64_t>(&*port)};
void *data =
callback_data.at(static_cast<rpc_opcode_t>(port->get_opcode()));
rpc_port_t port_ref{reinterpret_cast<uint64_t>(&*port), lane_size};
(handler->second)(port_ref, data);
}
}
port->close();
return RPC_STATUS_CONTINUE;
}

// Owns the active server instantiation. There is one alternative for each
// lane size that rpc_server_init accepts (1, 32 and 64).
std::variant<std::unique_ptr<rpc::Server<1>>,
std::unique_ptr<rpc::Server<32>>,
std::unique_ptr<rpc::Server<64>>>
server;
};

// Per-device bookkeeping: the lane-size erased RPC server plus the
// user-registered handlers consulted for opcodes the server does not handle
// itself.
struct Device {
  // Takes ownership of a concrete rpc::Server instantiation and wraps it.
  template <typename T>
  Device(std::unique_ptr<T> &&impl) : server(std::move(impl)) {}

  // Server servicing this device.
  Server server;
  // Handler to invoke for a given opcode.
  std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> callbacks;
  // Opaque pointer passed to the handler registered for the same opcode.
  std::unordered_map<rpc_opcode_t, void *> callback_data;
};

// A struct containing all the runtime state required to run the RPC server.
struct State {
// Sizes `devices` to hold one (initially empty) slot per device and starts
// with no outstanding references.
State(uint32_t num_devices)
: num_devices(num_devices), devices(num_devices), reference_count(0u) {}
// Number of device slots; device ids are bounds-checked against it.
uint32_t num_devices;
// Per-device state indexed by device id. A slot stays null until
// rpc_server_init creates the Device for it.
std::vector<std::unique_ptr<Device>> devices;
// Incremented by rpc_init and decremented by rpc_shutdown.
std::atomic_uint32_t reference_count;
};

static std::mutex startup_mutex;

static State *state;

// Creates the global server state on first use and takes one reference to it.
//
// Returns RPC_STATUS_ERROR when the reference count is already at its maximum
// value, RPC_STATUS_SUCCESS otherwise. `num_devices` is only consulted when
// the state does not exist yet.
rpc_status_t rpc_init(uint32_t num_devices) {
  std::scoped_lock<decltype(startup_mutex)> guard(startup_mutex);

  if (state == nullptr)
    state = new State(num_devices);

  // Refuse to wrap the counter around.
  const uint32_t limit = std::numeric_limits<uint32_t>::max();
  if (state->reference_count.load() == limit)
    return RPC_STATUS_ERROR;

  state->reference_count.fetch_add(1);
  return RPC_STATUS_SUCCESS;
}

// Drops one reference to the global server state and destroys it when the
// last reference goes away. Always returns RPC_STATUS_SUCCESS.
//
// NOTE(review): rpc_init guards `state` with startup_mutex while this function
// does not — confirm that callers serialize init and shutdown.
rpc_status_t rpc_shutdown(void) {
  // Nothing to tear down if the state was never created or is already gone.
  if (!state)
    return RPC_STATUS_SUCCESS;

  // Post-decrement yields the previous count, so 1 means this was the last
  // outstanding reference.
  if (state->reference_count-- == 1) {
    delete state;
    // Clear the global so a later rpc_init allocates fresh state instead of
    // touching the freed object.
    state = nullptr;
  }

  return RPC_STATUS_SUCCESS;
}

// Sets up the RPC server for `device_id`.
//
// On first use for a device the server is instantiated for `lane_size`
// (1, 32 or 64); any other value yields RPC_STATUS_INVALID_LANE_SIZE. The
// buffer is then obtained from `alloc(size, data)` and handed to the server.
//
// Returns RPC_STATUS_OUT_OF_RANGE for an unknown device id, RPC_STATUS_ERROR
// if `alloc` returns null, and RPC_STATUS_SUCCESS otherwise.
rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
                             uint32_t lane_size, rpc_alloc_ty alloc,
                             void *data) {
  if (device_id >= state->num_devices)
    return RPC_STATUS_OUT_OF_RANGE;

  auto &slot = state->devices[device_id];

  // An already created device keeps its server; `lane_size` only matters the
  // first time the slot is populated.
  if (!slot) {
    if (lane_size == 1)
      slot = std::make_unique<Device>(std::make_unique<rpc::Server<1>>());
    else if (lane_size == 32)
      slot = std::make_unique<Device>(std::make_unique<rpc::Server<32>>());
    else if (lane_size == 64)
      slot = std::make_unique<Device>(std::make_unique<rpc::Server<64>>());
    else
      return RPC_STATUS_INVALID_LANE_SIZE;
  }

  const uint64_t size = slot->server.allocation_size(num_ports);
  void *buffer = alloc(size, data);
  if (buffer == nullptr)
    return RPC_STATUS_ERROR;

  slot->server.reset(num_ports, buffer);
  return RPC_STATUS_SUCCESS;
}

// Hands the RPC buffer of `device_id` back through `dealloc(buffer, data)` and
// destroys the per-device server state.
//
// Returns RPC_STATUS_OUT_OF_RANGE for an unknown device id, RPC_STATUS_ERROR
// if the device has no initialized server, and RPC_STATUS_SUCCESS otherwise.
rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
                                 void *data) {
  if (device_id >= state->num_devices)
    return RPC_STATUS_OUT_OF_RANGE;
  if (!state->devices[device_id])
    return RPC_STATUS_ERROR;

  // The buffer has to be returned before the Device is destroyed, because
  // rpc_get_buffer reads it from the device's server.
  dealloc(rpc_get_buffer(device_id), data);

  // reset() destroys the Device. release() would merely give up ownership and
  // leak the object.
  state->devices[device_id].reset();

  return RPC_STATUS_SUCCESS;
}

// Services the RPC server of `device_id` until it reports something other
// than RPC_STATUS_CONTINUE, and returns that status.
//
// Returns RPC_STATUS_OUT_OF_RANGE for an unknown device id and
// RPC_STATUS_ERROR if the device has no initialized server.
rpc_status_t rpc_handle_server(uint32_t device_id) {
  if (device_id >= state->num_devices)
    return RPC_STATUS_OUT_OF_RANGE;
  if (!state->devices[device_id])
    return RPC_STATUS_ERROR;

  rpc_status_t status = RPC_STATUS_CONTINUE;
  while (status == RPC_STATUS_CONTINUE) {
    // Looked up on every pass, exactly like the polling loop it replaces.
    Device &device = *state->devices[device_id];
    status =
        device.server.handle_server(device.callbacks, device.callback_data);
  }
  return status;
}

Expand All @@ -198,22 +282,41 @@ rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode,
void *data) {
if (device_id >= state->num_devices)
return RPC_STATUS_OUT_OF_RANGE;
if (!state->devices[device_id])
return RPC_STATUS_ERROR;

state->devices[device_id].callbacks[opcode] = callback;
state->devices[device_id].callback_data[opcode] = data;
state->devices[device_id]->callbacks[opcode] = callback;
state->devices[device_id]->callback_data[opcode] = data;
return RPC_STATUS_SUCCESS;
}

// Returns the start of the RPC buffer used by the server of `device_id`, or
// nullptr if the id is out of range or the device has no initialized server.
void *rpc_get_buffer(uint32_t device_id) {
  if (device_id >= state->num_devices)
    return nullptr;

  // The slot is an owning pointer that stays null until the server has been
  // initialized, so it must be checked before it is dereferenced.
  const auto &device = state->devices[device_id];
  if (!device)
    return nullptr;

  return device->server.get_buffer_start();
}

// Runs `callback(buffer, data)` through recv_and_send on the port referenced
// by `ref`.
//
// `ref.handle` is reinterpreted as a pointer to the Port type that matches
// `ref.lane_size` (1, 32 or 64). Any other lane size is ignored, because no
// Port type can be recovered for it.
void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback,
                       void *data) {
  // Shared body for every lane size; only the Port type differs.
  auto run = [=](auto *port) {
    port->recv_and_send([=](rpc::Buffer *buffer) {
      callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
    });
  };

  switch (ref.lane_size) {
  case 1:
    run(reinterpret_cast<rpc::Server<1>::Port *>(ref.handle));
    break;
  case 32:
    run(reinterpret_cast<rpc::Server<32>::Port *>(ref.handle));
    break;
  case 64:
    run(reinterpret_cast<rpc::Server<64>::Port *>(ref.handle));
    break;
  default:
    break;
  }
}
3 changes: 3 additions & 0 deletions libc/utils/gpu/server/Server.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,18 @@ const uint64_t RPC_MAXIMUM_PORT_COUNT = 64;
/// status codes.
typedef enum {
/// The call completed; also reported when the server has no port to service.
RPC_STATUS_SUCCESS = 0x0,
/// One port was serviced and the server should be polled again. The handling
/// loop keeps going while it sees this value.
RPC_STATUS_CONTINUE = 0x1,
/// Generic failure, e.g. a failed allocation or a device whose server has not
/// been initialized.
RPC_STATUS_ERROR = 0x1000,
/// The device id is not below the number of devices passed to rpc_init.
RPC_STATUS_OUT_OF_RANGE = 0x1001,
/// No callback is registered for the opcode received from the client.
RPC_STATUS_UNHANDLED_OPCODE = 0x1002,
/// The requested lane size is not one of the supported values (1, 32, 64).
RPC_STATUS_INVALID_LANE_SIZE = 0x1003,
} rpc_status_t;

/// A struct containing an opaque handle to an RPC port. This is what allows the
/// server to communicate with the client.
typedef struct rpc_port_s {
/// Address of the server-side port object, stored as an integer.
uint64_t handle;
/// Lane size of the server that opened the port; used to recover the concrete
/// port type from `handle`.
uint32_t lane_size;
} rpc_port_t;

/// A fixed-size buffer containing the payload sent from the client.
Expand Down