112 changes: 30 additions & 82 deletions libc/utils/gpu/server/rpc_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,127 +248,75 @@ struct Device {
std::unordered_map<uint16_t, void *> callback_data;
};

// A struct containing all the runtime state required to run the RPC server.
struct State {
State(uint32_t num_devices)
: num_devices(num_devices), devices(num_devices), reference_count(0u) {}
uint32_t num_devices;
std::vector<std::unique_ptr<Device>> devices;
std::atomic_uint32_t reference_count;
};

static std::mutex startup_mutex;

static State *state;

rpc_status_t rpc_init(uint32_t num_devices) {
std::scoped_lock<decltype(startup_mutex)> lock(startup_mutex);
if (!state)
state = new State(num_devices);

if (state->reference_count == std::numeric_limits<uint32_t>::max())
return RPC_STATUS_ERROR;

state->reference_count++;

return RPC_STATUS_SUCCESS;
}

rpc_status_t rpc_shutdown(void) {
if (state && state->reference_count-- == 1)
delete state;

return RPC_STATUS_SUCCESS;
}

rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
rpc_status_t rpc_server_init(rpc_device_t *rpc_device, uint64_t num_ports,
uint32_t lane_size, rpc_alloc_ty alloc,
void *data) {
if (!state)
return RPC_STATUS_NOT_INITIALIZED;
if (device_id >= state->num_devices)
return RPC_STATUS_OUT_OF_RANGE;
if (!rpc_device)
return RPC_STATUS_ERROR;
if (lane_size != 1 && lane_size != 32 && lane_size != 64)
return RPC_STATUS_INVALID_LANE_SIZE;

if (!state->devices[device_id]) {
uint64_t size = rpc::Server::allocation_size(lane_size, num_ports);
void *buffer = alloc(size, data);
uint64_t size = rpc::Server::allocation_size(lane_size, num_ports);
void *buffer = alloc(size, data);

if (!buffer)
return RPC_STATUS_ERROR;
if (!buffer)
return RPC_STATUS_ERROR;

state->devices[device_id] =
std::make_unique<Device>(lane_size, num_ports, buffer);
if (!state->devices[device_id])
return RPC_STATUS_ERROR;
}
Device *device = new Device(lane_size, num_ports, buffer);
if (!device)
return RPC_STATUS_ERROR;

rpc_device->handle = reinterpret_cast<uintptr_t>(device);
return RPC_STATUS_SUCCESS;
}

rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
rpc_status_t rpc_server_shutdown(rpc_device_t rpc_device, rpc_free_ty dealloc,
void *data) {
if (!state)
return RPC_STATUS_NOT_INITIALIZED;
if (device_id >= state->num_devices)
return RPC_STATUS_OUT_OF_RANGE;
if (!state->devices[device_id])
if (!rpc_device.handle)
return RPC_STATUS_ERROR;

dealloc(state->devices[device_id]->buffer, data);
if (state->devices[device_id])
state->devices[device_id].release();
Device *device = reinterpret_cast<Device *>(rpc_device.handle);
dealloc(device->buffer, data);
delete device;

return RPC_STATUS_SUCCESS;
}

rpc_status_t rpc_handle_server(uint32_t device_id) {
if (!state)
return RPC_STATUS_NOT_INITIALIZED;
if (device_id >= state->num_devices)
return RPC_STATUS_OUT_OF_RANGE;
if (!state->devices[device_id])
rpc_status_t rpc_handle_server(rpc_device_t rpc_device) {
if (!rpc_device.handle)
return RPC_STATUS_ERROR;

Device *device = reinterpret_cast<Device *>(rpc_device.handle);
uint32_t index = 0;
for (;;) {
Device &device = *state->devices[device_id];
rpc_status_t status = device.handle_server(index);
rpc_status_t status = device->handle_server(index);
if (status != RPC_STATUS_CONTINUE)
return status;
}
}

rpc_status_t rpc_register_callback(uint32_t device_id, uint16_t opcode,
rpc_status_t rpc_register_callback(rpc_device_t rpc_device, uint16_t opcode,
rpc_opcode_callback_ty callback,
void *data) {
if (!state)
return RPC_STATUS_NOT_INITIALIZED;
if (device_id >= state->num_devices)
return RPC_STATUS_OUT_OF_RANGE;
if (!state->devices[device_id])
if (!rpc_device.handle)
return RPC_STATUS_ERROR;

state->devices[device_id]->callbacks[opcode] = callback;
state->devices[device_id]->callback_data[opcode] = data;
Device *device = reinterpret_cast<Device *>(rpc_device.handle);

device->callbacks[opcode] = callback;
device->callback_data[opcode] = data;
return RPC_STATUS_SUCCESS;
}

const void *rpc_get_client_buffer(uint32_t device_id) {
if (!state || device_id >= state->num_devices || !state->devices[device_id])
const void *rpc_get_client_buffer(rpc_device_t rpc_device) {
if (!rpc_device.handle)
return nullptr;
return &state->devices[device_id]->client;
Device *device = reinterpret_cast<Device *>(rpc_device.handle);
return &device->client;
}

uint64_t rpc_get_client_size() { return sizeof(rpc::Client); }

using ServerPort = std::variant<rpc::Server::Port *>;

ServerPort get_port(rpc_port_t ref) {
return reinterpret_cast<rpc::Server::Port *>(ref.handle);
}

void rpc_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
port->send([=](rpc::Buffer *buffer) {
Expand Down
7 changes: 4 additions & 3 deletions openmp/libomptarget/plugins-nextgen/common/include/RPC.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#ifndef OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_COMMON_RPC_H
#define OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_COMMON_RPC_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Error.h"

#include <cstdint>
Expand All @@ -32,8 +33,6 @@ class DeviceImageTy;
/// these routines will perform no action.
struct RPCServerTy {
public:
RPCServerTy(uint32_t NumDevices);

/// Check if this device image is using an RPC server. This checks for the
/// precense of an externally visible symbol in the device image that will
/// be present whenever RPC code is called.
Expand All @@ -56,7 +55,9 @@ struct RPCServerTy {
/// memory associated with the k
llvm::Error deinitDevice(plugin::GenericDeviceTy &Device);

~RPCServerTy();
private:
/// Array from this device's identifier to its attached devices.
llvm::SmallVector<uintptr_t> Handles;
};

} // namespace llvm::omp::target
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1492,7 +1492,7 @@ Error GenericPluginTy::init() {
GlobalHandler = createGlobalHandler();
assert(GlobalHandler && "Invalid global handler");

RPCServer = new RPCServerTy(NumDevices);
RPCServer = new RPCServerTy();
assert(RPCServer && "Invalid RPC server");

return Plugin::success();
Expand Down
44 changes: 17 additions & 27 deletions openmp/libomptarget/plugins-nextgen/common/src/RPC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,6 @@ using namespace llvm;
using namespace omp;
using namespace target;

RPCServerTy::RPCServerTy(uint32_t NumDevices) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
// If this fails then something is catastrophically wrong, just exit.
if (rpc_status_t Err = rpc_init(NumDevices))
FATAL_MESSAGE(1, "Error initializing the RPC server: %d\n", Err);
#endif
}

llvm::Expected<bool>
RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
plugin::GenericGlobalHandlerTy &Handler,
Expand All @@ -44,18 +36,19 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
plugin::GenericGlobalHandlerTy &Handler,
plugin::DeviceImageTy &Image) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
uint32_t DeviceId = Device.getDeviceId();
auto Alloc = [](uint64_t Size, void *Data) {
plugin::GenericDeviceTy &Device =
*reinterpret_cast<plugin::GenericDeviceTy *>(Data);
return Device.allocate(Size, nullptr, TARGET_ALLOC_HOST);
};
uint64_t NumPorts =
std::min(Device.requestedRPCPortCount(), RPC_MAXIMUM_PORT_COUNT);
if (rpc_status_t Err = rpc_server_init(DeviceId, NumPorts,
rpc_device_t RPCDevice;
if (rpc_status_t Err = rpc_server_init(&RPCDevice, NumPorts,
Device.getWarpSize(), Alloc, &Device))
return plugin::Plugin::error(
"Failed to initialize RPC server for device %d: %d", DeviceId, Err);
"Failed to initialize RPC server for device %d: %d",
Device.getDeviceId(), Err);

// Register a custom opcode handler to perform plugin specific allocation.
auto MallocHandler = [](rpc_port_t Port, void *Data) {
Expand All @@ -70,10 +63,10 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
Data);
};
if (rpc_status_t Err =
rpc_register_callback(DeviceId, RPC_MALLOC, MallocHandler, &Device))
rpc_register_callback(RPCDevice, RPC_MALLOC, MallocHandler, &Device))
return plugin::Plugin::error(
"Failed to register RPC malloc handler for device %d: %d\n", DeviceId,
Err);
"Failed to register RPC malloc handler for device %d: %d\n",
Device.getDeviceId(), Err);

// Register a custom opcode handler to perform plugin specific deallocation.
auto FreeHandler = [](rpc_port_t Port, void *Data) {
Expand All @@ -88,10 +81,10 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
Data);
};
if (rpc_status_t Err =
rpc_register_callback(DeviceId, RPC_FREE, FreeHandler, &Device))
rpc_register_callback(RPCDevice, RPC_FREE, FreeHandler, &Device))
return plugin::Plugin::error(
"Failed to register RPC free handler for device %d: %d\n", DeviceId,
Err);
"Failed to register RPC free handler for device %d: %d\n",
Device.getDeviceId(), Err);

// Get the address of the RPC client from the device.
void *ClientPtr;
Expand All @@ -104,17 +97,20 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
sizeof(void *), nullptr))
return Err;

const void *ClientBuffer = rpc_get_client_buffer(DeviceId);
const void *ClientBuffer = rpc_get_client_buffer(RPCDevice);
if (auto Err = Device.dataSubmit(ClientPtr, ClientBuffer,
rpc_get_client_size(), nullptr))
return Err;
Handles.resize(Device.getDeviceId() + 1);
Handles[Device.getDeviceId()] = RPCDevice.handle;
#endif
return Error::success();
}

Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
if (rpc_status_t Err = rpc_handle_server(Device.getDeviceId()))
rpc_device_t RPCDevice{Handles[Device.getDeviceId()]};
if (rpc_status_t Err = rpc_handle_server(RPCDevice))
return plugin::Plugin::error(
"Error while running RPC server on device %d: %d", Device.getDeviceId(),
Err);
Expand All @@ -124,22 +120,16 @@ Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) {

Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
rpc_device_t RPCDevice{Handles[Device.getDeviceId()]};
auto Dealloc = [](void *Ptr, void *Data) {
plugin::GenericDeviceTy &Device =
*reinterpret_cast<plugin::GenericDeviceTy *>(Data);
Device.free(Ptr, TARGET_ALLOC_HOST);
};
if (rpc_status_t Err =
rpc_server_shutdown(Device.getDeviceId(), Dealloc, &Device))
if (rpc_status_t Err = rpc_server_shutdown(RPCDevice, Dealloc, &Device))
return plugin::Plugin::error(
"Failed to shut down RPC server for device %d: %d",
Device.getDeviceId(), Err);
#endif
return Error::success();
}

RPCServerTy::~RPCServerTy() {
#ifdef LIBOMPTARGET_RPC_SUPPORT
rpc_shutdown();
#endif
}