Skip to content

Commit

Permalink
[libc] Change RPC interface to not use device ids (#87087)
Browse files Browse the repository at this point in the history
Summary:
The current implementation of RPC tied everything to device IDs and
forced us to do init / shutdown to manage some global state. This turned
out to be a bad idea in situations where we want to track multiple
heterogeneous devices that may report the same device ID in the same
process.

This patch changes the interface to instead create an opaque handle to
the internal device and simply allocates it via `new`. The user will
then take this device and store it to interface with the attached
device. This interface puts the burden of tracking the device identifier
to mapped devices onto the user, but in return heavily simplifies the
implementation.
  • Loading branch information
jhuber6 committed Mar 29, 2024
1 parent bdb60e6 commit a1a8bb1
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 181 deletions.
8 changes: 4 additions & 4 deletions libc/utils/gpu/loader/Loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ inline void handle_error(rpc_status_t) {
}

template <uint32_t lane_size>
inline void register_rpc_callbacks(uint32_t device_id) {
inline void register_rpc_callbacks(rpc_device_t device) {
static_assert(lane_size == 32 || lane_size == 64, "Invalid Lane size");
// Register the ping test for the `libc` tests.
rpc_register_callback(
device_id, static_cast<rpc_opcode_t>(RPC_TEST_INCREMENT),
device, static_cast<rpc_opcode_t>(RPC_TEST_INCREMENT),
[](rpc_port_t port, void *data) {
rpc_recv_and_send(
port,
Expand All @@ -125,7 +125,7 @@ inline void register_rpc_callbacks(uint32_t device_id) {

// Register the interface test callbacks.
rpc_register_callback(
device_id, static_cast<rpc_opcode_t>(RPC_TEST_INTERFACE),
device, static_cast<rpc_opcode_t>(RPC_TEST_INTERFACE),
[](rpc_port_t port, void *data) {
uint64_t cnt = 0;
bool end_with_recv;
Expand Down Expand Up @@ -207,7 +207,7 @@ inline void register_rpc_callbacks(uint32_t device_id) {

// Register the stream test handler.
rpc_register_callback(
device_id, static_cast<rpc_opcode_t>(RPC_TEST_STREAM),
device, static_cast<rpc_opcode_t>(RPC_TEST_STREAM),
[](rpc_port_t port, void *data) {
uint64_t sizes[lane_size] = {0};
void *dst[lane_size] = {nullptr};
Expand Down
47 changes: 21 additions & 26 deletions libc/utils/gpu/loader/amdgpu/Loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ template <typename args_t>
hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
hsa_amd_memory_pool_t kernargs_pool,
hsa_amd_memory_pool_t coarsegrained_pool,
hsa_queue_t *queue, const LaunchParameters &params,
hsa_queue_t *queue, rpc_device_t device,
const LaunchParameters &params,
const char *kernel_name, args_t kernel_args) {
// Look up the '_start' kernel in the loaded executable.
hsa_executable_symbol_t symbol;
Expand All @@ -162,10 +163,9 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
return err;

// Register RPC callbacks for the malloc and free functions on HSA.
uint32_t device_id = 0;
auto tuple = std::make_tuple(dev_agent, coarsegrained_pool);
rpc_register_callback(
device_id, RPC_MALLOC,
device, RPC_MALLOC,
[](rpc_port_t port, void *data) {
auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
auto &[dev_agent, pool] = *static_cast<decltype(tuple) *>(data);
Expand All @@ -182,7 +182,7 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
},
&tuple);
rpc_register_callback(
device_id, RPC_FREE,
device, RPC_FREE,
[](rpc_port_t port, void *data) {
auto free_handler = [](rpc_buffer_t *buffer, void *) {
if (hsa_status_t err = hsa_amd_memory_pool_free(
Expand Down Expand Up @@ -284,12 +284,12 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
while (hsa_signal_wait_scacquire(
packet->completion_signal, HSA_SIGNAL_CONDITION_EQ, 0,
/*timeout_hint=*/1024, HSA_WAIT_STATE_ACTIVE) != 0)
if (rpc_status_t err = rpc_handle_server(device_id))
if (rpc_status_t err = rpc_handle_server(device))
handle_error(err);

// Handle the server one more time in case the kernel exited with a pending
// send still in flight.
if (rpc_status_t err = rpc_handle_server(device_id))
if (rpc_status_t err = rpc_handle_server(device))
handle_error(err);

// Destroy the resources acquired to launch the kernel and return.
Expand Down Expand Up @@ -342,8 +342,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
handle_error(err);

// Obtain a single agent for the device and host to use the HSA memory model.
uint32_t num_devices = 1;
uint32_t device_id = 0;
hsa_agent_t dev_agent;
hsa_agent_t host_agent;
if (hsa_status_t err = get_agent<HSA_DEVICE_TYPE_GPU>(&dev_agent))
Expand Down Expand Up @@ -433,8 +431,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
handle_error(err);

// Set up the RPC server.
if (rpc_status_t err = rpc_init(num_devices))
handle_error(err);
auto tuple = std::make_tuple(dev_agent, finegrained_pool);
auto rpc_alloc = [](uint64_t size, void *data) {
auto &[dev_agent, finegrained_pool] = *static_cast<decltype(tuple) *>(data);
Expand All @@ -445,15 +441,16 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, dev_ptr);
return dev_ptr;
};
if (rpc_status_t err = rpc_server_init(device_id, RPC_MAXIMUM_PORT_COUNT,
rpc_device_t device;
if (rpc_status_t err = rpc_server_init(&device, RPC_MAXIMUM_PORT_COUNT,
wavefront_size, rpc_alloc, &tuple))
handle_error(err);

// Register callbacks for the RPC unit tests.
if (wavefront_size == 32)
register_rpc_callbacks<32>(device_id);
register_rpc_callbacks<32>(device);
else if (wavefront_size == 64)
register_rpc_callbacks<64>(device_id);
register_rpc_callbacks<64>(device);
else
handle_error("Invalid wavefront size");

Expand Down Expand Up @@ -483,10 +480,10 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
handle_error(err);

void *rpc_client_buffer;
if (hsa_status_t err = hsa_amd_memory_lock(
const_cast<void *>(rpc_get_client_buffer(device_id)),
rpc_get_client_size(),
/*agents=*/nullptr, 0, &rpc_client_buffer))
if (hsa_status_t err =
hsa_amd_memory_lock(const_cast<void *>(rpc_get_client_buffer(device)),
rpc_get_client_size(),
/*agents=*/nullptr, 0, &rpc_client_buffer))
handle_error(err);

// Copy the RPC client buffer to the address pointed to by the symbol.
Expand All @@ -496,7 +493,7 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
handle_error(err);

if (hsa_status_t err = hsa_amd_memory_unlock(
const_cast<void *>(rpc_get_client_buffer(device_id))))
const_cast<void *>(rpc_get_client_buffer(device))))
handle_error(err);
if (hsa_status_t err = hsa_amd_memory_pool_free(rpc_client_host))
handle_error(err);
Expand Down Expand Up @@ -549,13 +546,13 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
begin_args_t init_args = {argc, dev_argv, dev_envp};
if (hsa_status_t err = launch_kernel(
dev_agent, executable, kernargs_pool, coarsegrained_pool, queue,
single_threaded_params, "_begin.kd", init_args))
device, single_threaded_params, "_begin.kd", init_args))
handle_error(err);

start_args_t args = {argc, dev_argv, dev_envp, dev_ret};
if (hsa_status_t err =
launch_kernel(dev_agent, executable, kernargs_pool,
coarsegrained_pool, queue, params, "_start.kd", args))
if (hsa_status_t err = launch_kernel(dev_agent, executable, kernargs_pool,
coarsegrained_pool, queue, device,
params, "_start.kd", args))
handle_error(err);

void *host_ret;
Expand All @@ -575,11 +572,11 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
end_args_t fini_args = {ret};
if (hsa_status_t err = launch_kernel(
dev_agent, executable, kernargs_pool, coarsegrained_pool, queue,
single_threaded_params, "_end.kd", fini_args))
device, single_threaded_params, "_end.kd", fini_args))
handle_error(err);

if (rpc_status_t err = rpc_server_shutdown(
device_id, [](void *ptr, void *) { hsa_amd_memory_pool_free(ptr); },
device, [](void *ptr, void *) { hsa_amd_memory_pool_free(ptr); },
nullptr))
handle_error(err);

Expand All @@ -600,8 +597,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
if (hsa_status_t err = hsa_code_object_destroy(object))
handle_error(err);

if (rpc_status_t err = rpc_shutdown())
handle_error(err);
if (hsa_status_t err = hsa_shut_down())
handle_error(err);

Expand Down
39 changes: 17 additions & 22 deletions libc/utils/gpu/loader/nvptx/Loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ Expected<void *> get_ctor_dtor_array(const void *image, const size_t size,

template <typename args_t>
CUresult launch_kernel(CUmodule binary, CUstream stream,
const LaunchParameters &params, const char *kernel_name,
args_t kernel_args) {
rpc_device_t rpc_device, const LaunchParameters &params,
const char *kernel_name, args_t kernel_args) {
// look up the '_start' kernel in the loaded module.
CUfunction function;
if (CUresult err = cuModuleGetFunction(&function, binary, kernel_name))
Expand All @@ -175,11 +175,10 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
handle_error(err);

// Register RPC callbacks for the malloc and free functions on HSA.
uint32_t device_id = 0;
register_rpc_callbacks<32>(device_id);
register_rpc_callbacks<32>(rpc_device);

rpc_register_callback(
device_id, RPC_MALLOC,
rpc_device, RPC_MALLOC,
[](rpc_port_t port, void *data) {
auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
CUstream memory_stream = *static_cast<CUstream *>(data);
Expand All @@ -197,7 +196,7 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
},
&memory_stream);
rpc_register_callback(
device_id, RPC_FREE,
rpc_device, RPC_FREE,
[](rpc_port_t port, void *data) {
auto free_handler = [](rpc_buffer_t *buffer, void *data) {
CUstream memory_stream = *static_cast<CUstream *>(data);
Expand All @@ -219,12 +218,12 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
// Wait until the kernel has completed execution on the device. Periodically
// check the RPC client for work to be performed on the server.
while (cuStreamQuery(stream) == CUDA_ERROR_NOT_READY)
if (rpc_status_t err = rpc_handle_server(device_id))
if (rpc_status_t err = rpc_handle_server(rpc_device))
handle_error(err);

// Handle the server one more time in case the kernel exited with a pending
// send still in flight.
if (rpc_status_t err = rpc_handle_server(device_id))
if (rpc_status_t err = rpc_handle_server(rpc_device))
handle_error(err);

return CUDA_SUCCESS;
Expand All @@ -235,7 +234,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
if (CUresult err = cuInit(0))
handle_error(err);
// Obtain the first device found on the system.
uint32_t num_devices = 1;
uint32_t device_id = 0;
CUdevice device;
if (CUresult err = cuDeviceGet(&device, device_id))
Expand Down Expand Up @@ -294,17 +292,15 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
if (CUresult err = cuMemsetD32(dev_ret, 0, 1))
handle_error(err);

if (rpc_status_t err = rpc_init(num_devices))
handle_error(err);

uint32_t warp_size = 32;
auto rpc_alloc = [](uint64_t size, void *) -> void * {
void *dev_ptr;
if (CUresult err = cuMemAllocHost(&dev_ptr, size))
handle_error(err);
return dev_ptr;
};
if (rpc_status_t err = rpc_server_init(device_id, RPC_MAXIMUM_PORT_COUNT,
rpc_device_t rpc_device;
if (rpc_status_t err = rpc_server_init(&rpc_device, RPC_MAXIMUM_PORT_COUNT,
warp_size, rpc_alloc, nullptr))
handle_error(err);

Expand All @@ -321,19 +317,20 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
cuMemcpyDtoH(&rpc_client_host, rpc_client_dev, sizeof(void *)))
handle_error(err);
if (CUresult err =
cuMemcpyHtoD(rpc_client_host, rpc_get_client_buffer(device_id),
cuMemcpyHtoD(rpc_client_host, rpc_get_client_buffer(rpc_device),
rpc_get_client_size()))
handle_error(err);

LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1};
begin_args_t init_args = {argc, dev_argv, dev_envp};
if (CUresult err = launch_kernel(binary, stream, single_threaded_params,
"_begin", init_args))
if (CUresult err = launch_kernel(binary, stream, rpc_device,
single_threaded_params, "_begin", init_args))
handle_error(err);

start_args_t args = {argc, dev_argv, dev_envp,
reinterpret_cast<void *>(dev_ret)};
if (CUresult err = launch_kernel(binary, stream, params, "_start", args))
if (CUresult err =
launch_kernel(binary, stream, rpc_device, params, "_start", args))
handle_error(err);

// Copy the return value back from the kernel and wait.
Expand All @@ -345,8 +342,8 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
handle_error(err);

end_args_t fini_args = {host_ret};
if (CUresult err = launch_kernel(binary, stream, single_threaded_params,
"_end", fini_args))
if (CUresult err = launch_kernel(binary, stream, rpc_device,
single_threaded_params, "_end", fini_args))
handle_error(err);

// Free the memory allocated for the device.
Expand All @@ -357,15 +354,13 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
if (CUresult err = cuMemFreeHost(dev_argv))
handle_error(err);
if (rpc_status_t err = rpc_server_shutdown(
device_id, [](void *ptr, void *) { cuMemFreeHost(ptr); }, nullptr))
rpc_device, [](void *ptr, void *) { cuMemFreeHost(ptr); }, nullptr))
handle_error(err);

// Destroy the context and the loaded binary.
if (CUresult err = cuModuleUnload(binary))
handle_error(err);
if (CUresult err = cuDevicePrimaryCtxRelease(device))
handle_error(err);
if (rpc_status_t err = rpc_shutdown())
handle_error(err);
return host_ret;
}
29 changes: 13 additions & 16 deletions libc/utils/gpu/server/llvmlibc_rpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ typedef enum {
RPC_STATUS_SUCCESS = 0x0,
RPC_STATUS_CONTINUE = 0x1,
RPC_STATUS_ERROR = 0x1000,
RPC_STATUS_OUT_OF_RANGE = 0x1001,
RPC_STATUS_UNHANDLED_OPCODE = 0x1002,
RPC_STATUS_INVALID_LANE_SIZE = 0x1003,
RPC_STATUS_NOT_INITIALIZED = 0x1004,
RPC_STATUS_UNHANDLED_OPCODE = 0x1001,
RPC_STATUS_INVALID_LANE_SIZE = 0x1002,
} rpc_status_t;

/// A struct containing an opaque handle to an RPC port. This is what allows the
Expand All @@ -45,6 +43,11 @@ typedef struct rpc_buffer_s {
uint64_t data[8];
} rpc_buffer_t;

/// An opaque handle to an RPC server that can be attached to a device.
/// NOTE(review): per the commit description this wraps a pointer to an
/// internal server object allocated with `new`; the user owns the mapping
/// from physical devices to these handles — confirm against the server impl.
typedef struct rpc_device_s {
  uintptr_t handle; // Opaque identifier; not meaningful to the caller.
} rpc_device_t;

/// A function used to allocate \p bytes for use by the RPC server and client.
/// The memory should support asynchronous and atomic access from both the
/// client and server.
Expand All @@ -60,34 +63,28 @@ typedef void (*rpc_opcode_callback_ty)(rpc_port_t port, void *data);
/// A callback function to use the port to receive or send a \p buffer.
typedef void (*rpc_port_callback_ty)(rpc_buffer_t *buffer, void *data);

/// Initialize the rpc library for general use on \p num_devices.
rpc_status_t rpc_init(uint32_t num_devices);

/// Shut down the rpc interface.
rpc_status_t rpc_shutdown(void);

/// Initialize the server for a given device.
rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
/// Initialize the server for a given device and return it in \p device.
rpc_status_t rpc_server_init(rpc_device_t *rpc_device, uint64_t num_ports,
uint32_t lane_size, rpc_alloc_ty alloc,
void *data);

/// Shut down the server for a given device.
rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
rpc_status_t rpc_server_shutdown(rpc_device_t rpc_device, rpc_free_ty dealloc,
void *data);

/// Queries the RPC clients at least once and performs server-side work if there
/// are any active requests. Runs until all work on the server is completed.
rpc_status_t rpc_handle_server(uint32_t device_id);
rpc_status_t rpc_handle_server(rpc_device_t rpc_device);

/// Register a callback to handle an opcode from the RPC client. The associated
/// data must remain accessible as long as the user intends to handle the server
/// with this callback.
rpc_status_t rpc_register_callback(uint32_t device_id, uint16_t opcode,
rpc_status_t rpc_register_callback(rpc_device_t rpc_device, uint16_t opcode,
rpc_opcode_callback_ty callback, void *data);

/// Obtain a pointer to a local client buffer that can be copied directly to the
/// other process using the address stored at the rpc client symbol name.
const void *rpc_get_client_buffer(uint32_t device_id);
const void *rpc_get_client_buffer(rpc_device_t device);

/// Returns the size of the client in bytes to be used for a memory copy.
uint64_t rpc_get_client_size();
Expand Down

0 comments on commit a1a8bb1

Please sign in to comment.