94 changes: 56 additions & 38 deletions libc/utils/gpu/loader/nvptx/Loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
//===----------------------------------------------------------------------===//

#include "Loader.h"
#include "Server.h"

#include "cuda.h"

Expand Down Expand Up @@ -43,11 +42,6 @@ static void handle_error(CUresult err) {
exit(1);
}

static void handle_error(const char *msg) {
fprintf(stderr, "%s\n", msg);
exit(EXIT_FAILURE);
}

// Gets the names of all the globals that contain functions to initialize or
// deinitialize. We need to do this manually because the NVPTX toolchain does
// not contain the necessary binary manipulation tools.
Expand Down Expand Up @@ -181,21 +175,37 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
if (CUresult err = cuStreamCreate(&memory_stream, CU_STREAM_NON_BLOCKING))
handle_error(err);

auto allocator = [&](uint64_t size) -> void * {
CUdeviceptr dev_ptr;
if (CUresult err = cuMemAllocAsync(&dev_ptr, size, memory_stream))
handle_error(err);

// Wait until the memory allocation is complete.
while (cuStreamQuery(memory_stream) == CUDA_ERROR_NOT_READY)
;
return reinterpret_cast<void *>(dev_ptr);
};
auto deallocator = [&](void *ptr) -> void {
if (CUresult err =
cuMemFreeAsync(reinterpret_cast<CUdeviceptr>(ptr), memory_stream))
handle_error(err);
};
  // Register RPC callbacks for the malloc and free functions on the device.
uint32_t device_id = 0;
rpc_register_callback(
device_id, RPC_MALLOC,
[](rpc_port_t port, void *data) {
auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
CUstream memory_stream = *static_cast<CUstream *>(data);
uint64_t size = buffer->data[0];
CUdeviceptr dev_ptr;
if (CUresult err = cuMemAllocAsync(&dev_ptr, size, memory_stream))
handle_error(err);

// Wait until the memory allocation is complete.
while (cuStreamQuery(memory_stream) == CUDA_ERROR_NOT_READY)
;
};
rpc_recv_and_send(port, malloc_handler, data);
},
&memory_stream);
rpc_register_callback(
device_id, RPC_FREE,
[](rpc_port_t port, void *data) {
auto free_handler = [](rpc_buffer_t *buffer, void *data) {
CUstream memory_stream = *static_cast<CUstream *>(data);
if (CUresult err = cuMemFreeAsync(
static_cast<CUdeviceptr>(buffer->data[0]), memory_stream))
handle_error(err);
};
rpc_recv_and_send(port, free_handler, data);
},
&memory_stream);

// Call the kernel with the given arguments.
if (CUresult err = cuLaunchKernel(
Expand All @@ -207,23 +217,26 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
// Wait until the kernel has completed execution on the device. Periodically
// check the RPC client for work to be performed on the server.
while (cuStreamQuery(stream) == CUDA_ERROR_NOT_READY)
handle_server(allocator, deallocator);
if (rpc_status_t err = rpc_handle_server(device_id))
handle_error(err);

// Handle the server one more time in case the kernel exited with a pending
// send still in flight.
handle_server(allocator, deallocator);
if (rpc_status_t err = rpc_handle_server(device_id))
handle_error(err);

return CUDA_SUCCESS;
}

int load(int argc, char **argv, char **envp, void *image, size_t size,
const LaunchParameters &params) {

if (CUresult err = cuInit(0))
handle_error(err);
// Obtain the first device found on the system.
uint32_t num_devices = 1;
uint32_t device_id = 0;
CUdevice device;
if (CUresult err = cuDeviceGet(&device, 0))
if (CUresult err = cuDeviceGet(&device, device_id))
handle_error(err);

// Initialize the CUDA context and claim it for this execution.
Expand Down Expand Up @@ -279,22 +292,24 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
if (CUresult err = cuMemsetD32(dev_ret, 0, 1))
handle_error(err);

uint64_t port_size = __llvm_libc::rpc::DEFAULT_PORT_COUNT;
uint32_t warp_size = 32;

uint64_t rpc_shared_buffer_size =
__llvm_libc::rpc::Server::allocation_size(port_size, warp_size);
void *rpc_shared_buffer = allocator(rpc_shared_buffer_size);

if (!rpc_shared_buffer)
handle_error("Failed to allocate memory the RPC client / server.");
if (rpc_status_t err = rpc_init(num_devices))
handle_error(err);

// Initialize the RPC server's buffer for host-device communication.
server.reset(port_size, warp_size, rpc_shared_buffer);
uint32_t warp_size = 32;
auto rpc_alloc = [](uint64_t size, void *) -> void * {
void *dev_ptr;
if (CUresult err = cuMemAllocHost(&dev_ptr, size))
handle_error(err);
return dev_ptr;
};
if (rpc_status_t err = rpc_server_init(device_id, RPC_MAXIMUM_PORT_COUNT,
warp_size, rpc_alloc, nullptr))
handle_error(err);

LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1};
  // Call the kernel to begin execution of the device program.
begin_args_t init_args = {argc, dev_argv, dev_envp, rpc_shared_buffer};
begin_args_t init_args = {argc, dev_argv, dev_envp,
rpc_get_buffer(device_id)};
if (CUresult err = launch_kernel(binary, stream, single_threaded_params,
"_begin", init_args))
handle_error(err);
Expand Down Expand Up @@ -324,13 +339,16 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
handle_error(err);
if (CUresult err = cuMemFreeHost(dev_argv))
handle_error(err);
if (CUresult err = cuMemFreeHost(rpc_shared_buffer))
if (rpc_status_t err = rpc_server_shutdown(
device_id, [](void *ptr, void *) { cuMemFreeHost(ptr); }, nullptr))
handle_error(err);

// Destroy the context and the loaded binary.
if (CUresult err = cuModuleUnload(binary))
handle_error(err);
if (CUresult err = cuDevicePrimaryCtxRelease(device))
handle_error(err);
if (rpc_status_t err = rpc_shutdown())
handle_error(err);
return host_ret;
}
6 changes: 6 additions & 0 deletions libc/utils/gpu/server/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Static library implementing the host-side RPC server used by the GPU loaders.
add_library(rpc_server STATIC Server.cpp)

# Include the RPC implementation from libc.
add_dependencies(rpc_server libc.src.__support.RPC.rpc)
target_include_directories(rpc_server PRIVATE ${LIBC_SOURCE_DIR})
target_include_directories(rpc_server PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
219 changes: 219 additions & 0 deletions libc/utils/gpu/server/Server.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
//===-- Shared memory RPC server instantiation ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Server.h"

#include "src/__support/RPC/rpc.h"

#include <atomic>
#include <cstdio>
#include <cstdlib>
#include <limits>
#include <memory>
#include <mutex>
#include <unordered_map>

using namespace __llvm_libc;

// The C interface types must match the layout of the internal rpc:: types so
// we can reinterpret_cast between them at the ABI boundary.
static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
              "Buffer size mismatch");

static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::DEFAULT_PORT_COUNT,
              "Incorrect maximum port count");
// Per-device RPC state: the server instance plus the user-registered handlers
// for opcodes that the generic server loop does not service itself.
struct Device {
  rpc::Server server;
  // Maps an opcode to the callback invoked when that opcode arrives.
  std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> callbacks;
  // Opaque user data passed to the callback registered for each opcode.
  std::unordered_map<rpc_opcode_t, void *> callback_data;
};

// A struct containing all the runtime state required to run the RPC server.
struct State {
  State(uint32_t num_devices)
      : num_devices(num_devices),
        devices(std::unique_ptr<Device[]>(new Device[num_devices])),
        reference_count(0u) {}
  // Number of devices this state can service.
  uint32_t num_devices;
  // One Device entry per device id, indexed directly.
  std::unique_ptr<Device[]> devices;
  // Count of rpc_init calls not yet matched by an rpc_shutdown.
  std::atomic_uint32_t reference_count;
};

// Serializes creation of the global state in rpc_init.
static std::mutex startup_mutex;

// Global state shared by all devices; created lazily on the first rpc_init.
static State *state;

/// Initialize the RPC library for \p num_devices devices. May be called more
/// than once; each call adds a reference dropped by rpc_shutdown.
rpc_status_t rpc_init(uint32_t num_devices) {
  std::scoped_lock<decltype(startup_mutex)> lock(startup_mutex);

  // The first caller creates the shared state; later callers only add a
  // reference to it.
  if (state == nullptr)
    state = new State(num_devices);

  // Refuse to overflow the reference counter.
  if (state->reference_count == std::numeric_limits<uint32_t>::max())
    return RPC_STATUS_ERROR;

  ++state->reference_count;
  return RPC_STATUS_SUCCESS;
}

/// Drop one reference taken by rpc_init and destroy the global state when the
/// last reference is released.
rpc_status_t rpc_shutdown(void) {
  // Take the same lock as rpc_init so a concurrent init/shutdown pair cannot
  // race on the state pointer.
  std::scoped_lock<decltype(startup_mutex)> lock(startup_mutex);

  // Shutting down without a matching rpc_init is an error, not a null deref.
  if (!state)
    return RPC_STATUS_ERROR;

  if (state->reference_count-- == 1) {
    delete state;
    // Reset the pointer so a later rpc_init re-creates the state instead of
    // using a dangling pointer.
    state = nullptr;
  }

  return RPC_STATUS_SUCCESS;
}

/// Initialize the server for device \p device_id with \p num_ports ports of
/// \p lane_size lanes, using \p alloc (with user \p data) to obtain the
/// shared communication buffer.
rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
                             uint32_t lane_size, rpc_alloc_ty alloc,
                             void *data) {
  if (device_id >= state->num_devices)
    return RPC_STATUS_OUT_OF_RANGE;

  // Size the shared buffer for the requested port/lane configuration.
  const uint64_t buffer_size =
      __llvm_libc::rpc::Server::allocation_size(num_ports, lane_size);
  void *shared_buffer = alloc(buffer_size, data);
  if (!shared_buffer)
    return RPC_STATUS_ERROR;

  state->devices[device_id].server.reset(num_ports, lane_size, shared_buffer);
  return RPC_STATUS_SUCCESS;
}

/// Shut down the server for device \p device_id, releasing its shared buffer
/// through \p dealloc with the user-provided \p data.
rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
                                 void *data) {
  if (device_id >= state->num_devices)
    return RPC_STATUS_OUT_OF_RANGE;

  // Free the communication buffer that rpc_server_init allocated.
  void *shared_buffer = rpc_get_buffer(device_id);
  dealloc(shared_buffer, data);

  return RPC_STATUS_SUCCESS;
}

// Service every in-flight RPC request from device `device_id`, returning once
// no client holds an open port. Built-in opcodes (stream writes, exit, test
// hooks, noop) are handled inline; any other opcode dispatches to a callback
// registered via rpc_register_callback.
rpc_status_t rpc_handle_server(uint32_t device_id) {
  if (device_id >= state->num_devices)
    return RPC_STATUS_OUT_OF_RANGE;

  // Keep draining ports until try_open reports no pending work.
  for (;;) {
    auto port = state->devices[device_id].server.try_open();
    if (!port)
      return RPC_STATUS_SUCCESS;

    switch (port->get_opcode()) {
    case rpc::Opcode::WRITE_TO_STREAM:
    case rpc::Opcode::WRITE_TO_STDERR:
    case rpc::Opcode::WRITE_TO_STDOUT: {
      // Per-lane scratch state for the incoming strings and target streams.
      uint64_t sizes[rpc::MAX_LANE_SIZE] = {0};
      void *strs[rpc::MAX_LANE_SIZE] = {nullptr};
      FILE *files[rpc::MAX_LANE_SIZE] = {nullptr};
      // For arbitrary streams the client first sends the FILE pointer it
      // holds; stdout/stderr are implied by the opcode instead.
      if (port->get_opcode() == rpc::Opcode::WRITE_TO_STREAM)
        port->recv([&](rpc::Buffer *buffer, uint32_t id) {
          files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
        });
      // Receive one string per active lane, allocating storage on demand.
      port->recv_n(strs, sizes, [&](uint64_t size) { return new char[size]; });
      port->send([&](rpc::Buffer *buffer, uint32_t id) {
        FILE *file = port->get_opcode() == rpc::Opcode::WRITE_TO_STDOUT
                         ? stdout
                         : (port->get_opcode() == rpc::Opcode::WRITE_TO_STDERR
                                ? stderr
                                : files[id]);
        int ret = fwrite(strs[id], sizes[id], 1, file);
        // NOTE(review): fwrite returns the number of complete items written
        // (0 or 1 here) and is never negative, so `ret >= 0` always holds and
        // a failed write is still reported as sizes[id] — confirm intended.
        reinterpret_cast<int *>(buffer->data)[0] = ret >= 0 ? sizes[id] : ret;
      });
      // Release the per-lane string allocations made in recv_n above.
      for (uint64_t i = 0; i < rpc::MAX_LANE_SIZE; ++i) {
        if (strs[i])
          delete[] reinterpret_cast<uint8_t *>(strs[i]);
      }
      break;
    }
    case rpc::Opcode::EXIT: {
      // Terminate the host process with the exit code sent by the device.
      port->recv([](rpc::Buffer *buffer) {
        exit(reinterpret_cast<uint32_t *>(buffer->data)[0]);
      });
      break;
    }
    // TODO: Move handling of these test cases to the loader implementation.
    case rpc::Opcode::TEST_INCREMENT: {
      // Echo the received value back incremented by one.
      port->recv_and_send([](rpc::Buffer *buffer) {
        reinterpret_cast<uint64_t *>(buffer->data)[0] += 1;
      });
      break;
    }
    case rpc::Opcode::TEST_INTERFACE: {
      // Exercises a fixed recv/send handshake; the first value sent by the
      // client selects whether the exchange ends with a recv or a send.
      uint64_t cnt = 0;
      bool end_with_recv;
      port->recv([&](rpc::Buffer *buffer) { end_with_recv = buffer->data[0]; });
      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
      if (end_with_recv)
        port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
      else
        port->send(
            [&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
      break;
    }
    case rpc::Opcode::TEST_STREAM: {
      // Echo the per-lane payload straight back to the client.
      uint64_t sizes[rpc::MAX_LANE_SIZE] = {0};
      void *dst[rpc::MAX_LANE_SIZE] = {nullptr};
      port->recv_n(dst, sizes, [](uint64_t size) { return new char[size]; });
      port->send_n(dst, sizes);
      for (uint64_t i = 0; i < rpc::MAX_LANE_SIZE; ++i) {
        if (dst[i])
          delete[] reinterpret_cast<uint8_t *>(dst[i]);
      }
      break;
    }
    case rpc::Opcode::NOOP: {
      // Consume the packet without acting on it.
      port->recv([](rpc::Buffer *buffer) {});
      break;
    }
    default: {
      auto handler = state->devices[device_id].callbacks.find(
          static_cast<rpc_opcode_t>(port->get_opcode()));

      // We error out on an unhandled opcode.
      if (handler == state->devices[device_id].callbacks.end())
        return RPC_STATUS_UNHANDLED_OPCODE;

      // Invoke the registered callback with a reference to the port.
      void *data = state->devices[device_id].callback_data.at(
          static_cast<rpc_opcode_t>(port->get_opcode()));
      rpc_port_t port_ref{reinterpret_cast<uint64_t>(&*port)};
      (handler->second)(port_ref, data);
    }
    }
    port->close();
  }
}

/// Register \p callback (with user \p data) to service \p opcode on device
/// \p device_id. The data must outlive any rpc_handle_server call that may
/// dispatch to this callback.
rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode,
                                   rpc_opcode_callback_ty callback,
                                   void *data) {
  if (device_id >= state->num_devices)
    return RPC_STATUS_OUT_OF_RANGE;

  // Record the handler and its context side by side, keyed by opcode.
  Device &device = state->devices[device_id];
  device.callbacks[opcode] = callback;
  device.callback_data[opcode] = data;
  return RPC_STATUS_SUCCESS;
}

/// Return the start of the shared buffer used by device \p device_id's server,
/// or nullptr when the device id is out of range.
void *rpc_get_buffer(uint32_t device_id) {
  return device_id < state->num_devices
             ? state->devices[device_id].server.get_buffer_start()
             : nullptr;
}

/// Perform a combined receive-then-send on the port referenced by \p ref,
/// letting \p callback (with user \p data) rewrite the buffer in place.
void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback,
                       void *data) {
  // The opaque handle is the address of the underlying server port.
  auto *raw_port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
  raw_port->recv_and_send([callback, data](rpc::Buffer *buffer) {
    callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
  });
}
102 changes: 102 additions & 0 deletions libc/utils/gpu/server/Server.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
//===-- Shared memory RPC server instantiation ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_UTILS_GPU_SERVER_RPC_SERVER_H
#define LLVM_LIBC_UTILS_GPU_SERVER_RPC_SERVER_H

#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

/// The maximum number of ports that can be opened for any server.
const uint64_t RPC_MAXIMUM_PORT_COUNT = 64;

// TODO: Move these to a header exported by the C library.
/// Opcodes the server services natively; other values dispatch to callbacks
/// registered through rpc_register_callback.
typedef enum : uint16_t {
  RPC_NOOP = 0,
  RPC_EXIT = 1,
  RPC_WRITE_TO_STDOUT = 2,
  RPC_WRITE_TO_STDERR = 3,
  RPC_WRITE_TO_STREAM = 4,
  RPC_MALLOC = 5,
  RPC_FREE = 6,
} rpc_opcode_t;

/// Status codes returned by the RPC server interface.
typedef enum {
  RPC_STATUS_SUCCESS = 0x0,
  RPC_STATUS_ERROR = 0x1000,
  RPC_STATUS_OUT_OF_RANGE = 0x1001,
  RPC_STATUS_UNHANDLED_OPCODE = 0x1002,
} rpc_status_t;

/// A struct containing an opaque handle to an RPC port. This is what allows the
/// server to communicate with the client.
typedef struct rpc_port_s {
  uint64_t handle;
} rpc_port_t;

/// A fixed-size buffer containing the payload sent from the client.
typedef struct rpc_buffer_s {
  uint64_t data[8];
} rpc_buffer_t;

/// A function used to allocate \p bytes for use by the RPC server and client.
/// The memory should support asynchronous and atomic access from both the
/// client and server.
typedef void *(*rpc_alloc_ty)(uint64_t size, void *data);

/// A function used to free the \p ptr previously allocated.
typedef void (*rpc_free_ty)(void *ptr, void *data);

/// A callback function provided with a \p port to communicate with the RPC
/// client. This will be called by the server to handle an opcode.
typedef void (*rpc_opcode_callback_ty)(rpc_port_t port, void *data);

/// A callback function to use the port to receive or send a \p buffer.
typedef void (*rpc_port_callback_ty)(rpc_buffer_t *buffer, void *data);

/// Initialize the rpc library for general use on \p num_devices.
rpc_status_t rpc_init(uint32_t num_devices);

/// Shut down the rpc interface.
rpc_status_t rpc_shutdown(void);

/// Initialize the server for a given device.
rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
                             uint32_t lane_size, rpc_alloc_ty alloc,
                             void *data);

/// Shut down the server for a given device.
rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
                                 void *data);

/// Queries the RPC clients at least once and performs server-side work if there
/// are any active requests. Runs until all work on the server is completed.
rpc_status_t rpc_handle_server(uint32_t device_id);

/// Register a callback to handle an opcode from the RPC client. The associated
/// data must remain accessible as long as the user intends to handle the server
/// with this callback.
rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode,
                                   rpc_opcode_callback_ty callback, void *data);

/// Obtain a pointer to the memory buffer used to run the RPC client and server.
void *rpc_get_buffer(uint32_t device_id);

/// Use the \p port to receive and send a buffer using the \p callback.
void rpc_recv_and_send(rpc_port_t port, rpc_port_callback_ty callback,
                       void *data);

#ifdef __cplusplus
}
#endif

#endif