78 changes: 23 additions & 55 deletions libc/startup/gpu/nvptx/start.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ namespace __llvm_libc {

static cpp::Atomic<uint32_t> lock = 0;

static cpp::Atomic<uint32_t> count = 0;

extern "C" {
// Nvidia's 'nvlink' linker does not provide these symbols. We instead need
// to manually create them and update the globals in the loader implementation.
Expand All @@ -31,10 +29,6 @@ uintptr_t *__fini_array_end [[gnu::visibility("protected")]];
using InitCallback = void(int, char **, char **);
using FiniCallback = void(void);

static uint64_t get_grid_size() {
return gpu::get_num_threads() * gpu::get_num_blocks();
}

static void call_init_array_callbacks(int argc, char **argv, char **env) {
size_t init_array_size = __init_array_end - __init_array_start;
for (size_t i = 0; i < init_array_size; ++i)
Expand All @@ -47,59 +41,33 @@ static void call_fini_array_callbacks() {
reinterpret_cast<FiniCallback *>(__fini_array_start[i])();
}

// TODO: Put this in a separate kernel and call it with one thread.
void initialize(int argc, char **argv, char **env, void *in, void *out,
                void *buffer) {
  // We need a single GPU thread to perform the initialization of the global
  // constructors and data. We simply mask off all but a single thread and
  // execute.
  // Every launched thread bumps 'count' so thread (0, 0) can later tell when
  // the whole grid has entered this function.
  count.fetch_add(1, cpp::MemoryOrder::RELAXED);
  if (gpu::get_thread_id() == 0 && gpu::get_block_id() == 0) {
    // We need to set up the RPC client first in case any of the constructors
    // require it.
    rpc::client.reset(&lock, in, out, buffer);

    // We want the fini array callbacks to be run after other atexit
    // callbacks are run. So, we register them before running the init
    // array callbacks as they can potentially register their own atexit
    // callbacks.
    // FIXME: The function pointer escaping this TU causes warnings.
    __llvm_libc::atexit(&call_fini_array_callbacks);
    call_init_array_callbacks(argc, argv, env);
  }

  // We wait until every single thread launched on the GPU has seen the
  // initialization code. This will get very, very slow for high thread counts,
  // but for testing purposes it is unlikely to matter.
  while (count.load(cpp::MemoryOrder::RELAXED) != get_grid_size())
    rpc::sleep_briefly();
  gpu::sync_threads();
}

// TODO: Put this in a separate kernel and call it with one thread.
void finalize(int retval) {
  // We wait until every single thread launched on the GPU has finished
  // executing and reached the finalize region.
  // Each thread decrements the same counter 'initialize' incremented; when it
  // reaches zero the whole grid has arrived here.
  count.fetch_sub(1, cpp::MemoryOrder::RELAXED);
  while (count.load(cpp::MemoryOrder::RELAXED) != 0)
    rpc::sleep_briefly();
  gpu::sync_threads();
  if (gpu::get_thread_id() == 0 && gpu::get_block_id() == 0) {
    // Only a single thread should call `exit` here, the rest should gracefully
    // return from the kernel. This is so only one thread calls the destructors
    // registered with 'atexit' above.
    __llvm_libc::exit(retval);
  }
}

} // namespace __llvm_libc

extern "C" [[gnu::visibility("protected"), clang::nvptx_kernel]] void
_start(int argc, char **argv, char **envp, int *ret, void *in, void *out,
void *buffer) {
__llvm_libc::initialize(argc, argv, envp, in, out, buffer);
_begin(int argc, char **argv, char **env, void *in, void *out, void *buffer) {
// We need to set up the RPC client first in case any of the constructors
// require it.
__llvm_libc::rpc::client.reset(__llvm_libc::gpu::get_lane_size(),
&__llvm_libc::lock, in, out, buffer);

// We want the fini array callbacks to be run after other atexit
// callbacks are run. So, we register them before running the init
// array callbacks as they can potentially register their own atexit
// callbacks.
__llvm_libc::atexit(&__llvm_libc::call_fini_array_callbacks);
__llvm_libc::call_init_array_callbacks(argc, argv, env);
}

extern "C" [[gnu::visibility("protected"), clang::nvptx_kernel]] void
_start(int argc, char **argv, char **envp, int *ret) {
  // Every active thread that the user launched this kernel with runs 'main'.
  // Each thread's result is OR'd into the shared return slot so that any
  // nonzero exit status is preserved across the grid.
  int status = main(argc, argv, envp);
  __atomic_fetch_or(ret, status, __ATOMIC_RELAXED);
}

__llvm_libc::finalize(*ret);
extern "C" [[gnu::visibility("protected"), clang::nvptx_kernel]] void
_end(int retval) {
  // To finish the execution we invoke all the callbacks registered via
  // 'atexit' and then exit with the appropriate return value.
  __llvm_libc::exit(retval);
}
8 changes: 6 additions & 2 deletions libc/test/integration/startup/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@ add_integration_test(
libc.src.__support.RPC.rpc_client
libc.src.__support.GPU.utils
LOADER_ARGS
--blocks 16
--threads 1
--blocks-x 2
--blocks-y 2
--blocks-z 2
--threads-x 4
--threads-y 4
--threads-z 4
)

add_integration_test(
Expand Down
15 changes: 14 additions & 1 deletion libc/test/integration/startup/gpu/rpc_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
using namespace __llvm_libc;

static void test_add_simple() {
uint32_t num_additions = 1000 + 10 * gpu::get_block_id_x();
uint32_t num_additions =
10 + 10 * gpu::get_thread_id() + 10 * gpu::get_block_id();
uint64_t cnt = 0;
for (uint32_t i = 0; i < num_additions; ++i) {
rpc::Client::Port port = rpc::client.open(rpc::TEST_INCREMENT);
Expand All @@ -29,8 +30,20 @@ static void test_add_simple() {
ASSERT_TRUE(cnt == num_additions && "Incorrect sum");
}

// Test to ensure that the RPC mechanism doesn't hang on divergence.
static void test_noop(uint8_t data) {
  // Open a port for the no-op opcode, push a single byte through it, and
  // release the port again.
  auto port = rpc::client.open(rpc::NOOP);
  port.send([data](rpc::Buffer *buffer) { buffer->data[0] = data; });
  port.close();
}

TEST_MAIN(int argc, char **argv, char **envp) {
  test_add_simple();

  // Deliberately diverge: odd-numbered threads send one payload and
  // even-numbered threads another, exercising the RPC path under divergence.
  const bool is_odd_thread = gpu::get_thread_id() % 2 != 0;
  test_noop(is_odd_thread ? 1 : 2);

  return 0;
}
28 changes: 28 additions & 0 deletions libc/utils/gpu/loader/Loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,40 @@ struct LaunchParameters {
uint32_t num_blocks_z;
};

/// The arguments to the '_begin' kernel.
/// NOTE(review): the field order and types must mirror the device-side
/// '_begin' signature exactly, since the loader copies this struct verbatim
/// into the kernel argument buffer.
struct begin_args_t {
  int argc;      // Number of command-line arguments.
  void *argv;    // Device copy of the argument vector.
  void *envp;    // Device copy of the environment vector.
  void *inbox;   // RPC inbox as seen from the device.
  void *outbox;  // RPC outbox as seen from the device.
  void *buffer;  // Shared RPC packet buffer.
};

/// The arguments to the '_start' kernel.
/// NOTE(review): must stay layout-compatible with the device-side '_start'
/// signature; the loader copies this struct verbatim into kernel arguments.
struct start_args_t {
  int argc;    // Number of command-line arguments.
  void *argv;  // Device copy of the argument vector.
  void *envp;  // Device copy of the environment vector.
  void *ret;   // Device pointer receiving main's return value.
};

/// The arguments to the '_end' kernel.
struct end_args_t {
  // NOTE(review): despite the name, visible call sites initialize this field
  // with the program's return value (e.g. `end_args_t fini_args = {ret};`),
  // matching '_end(int retval)'. Consider renaming to 'retval'.
  int argc;
};

/// Generic interface to load the \p image and launch execution of the _start
/// kernel on the target device. Copies \p argc and \p argv to the device.
/// Returns the final value of the `main` function on the device.
int load(int argc, char **argv, char **evnp, void *image, size_t size,
const LaunchParameters &params);

/// Round \p val up to the nearest multiple of \p align.
template <typename V, typename A> inline V align_up(V val, A align) {
  const V step = V(align);
  const V remainder = val % step;
  return remainder == V(0) ? val : val + (step - remainder);
}

/// Copy the system's argument vector to GPU memory allocated using \p alloc.
template <typename Allocator>
void *copy_argument_vector(int argc, char **argv, Allocator alloc) {
Expand Down
23 changes: 13 additions & 10 deletions libc/utils/gpu/loader/Server.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,19 @@ void handle_server() {

switch (port->get_opcode()) {
case __llvm_libc::rpc::Opcode::PRINT_TO_STDERR: {
uint64_t str_size;
char *str = nullptr;
port->recv_n([&](uint64_t size) {
str_size = size;
str = new char[size];
return str;
uint64_t str_size[__llvm_libc::rpc::MAX_LANE_SIZE] = {0};
char *strs[__llvm_libc::rpc::MAX_LANE_SIZE] = {nullptr};
port->recv_n([&](uint64_t size, uint32_t id) {
str_size[id] = size;
strs[id] = new char[size];
return strs[id];
});
fwrite(str, str_size, 1, stderr);
delete[] str;
for (uint64_t i = 0; i < __llvm_libc::rpc::MAX_LANE_SIZE; ++i) {
if (strs[i]) {
fwrite(strs[i], str_size[i], 1, stderr);
delete[] strs[i];
}
}
break;
}
case __llvm_libc::rpc::Opcode::EXIT: {
Expand All @@ -54,8 +58,7 @@ void handle_server() {
break;
}
default:
port->recv([](__llvm_libc::rpc::Buffer *) { /* no-op */ });
return;
port->recv([](__llvm_libc::rpc::Buffer *buffer) {});
}
port->close();
}
Expand Down
256 changes: 135 additions & 121 deletions libc/utils/gpu/loader/amdgpu/Loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,6 @@
#include <cstring>
#include <utility>

/// The name of the kernel we will launch. All AMDHSA kernels end with '.kd'.
constexpr const char *KERNEL_START = "_start.kd";

/// The arguments to the '_start' kernel.
struct kernel_args_t {
int argc;
void *argv;
void *envp;
void *ret;
void *inbox;
void *outbox;
void *buffer;
};

/// Print the error code and exit if \p code indicates an error.
static void handle_error(hsa_status_t code) {
if (code == HSA_STATUS_SUCCESS || code == HSA_STATUS_INFO_BREAK)
Expand Down Expand Up @@ -145,6 +131,105 @@ hsa_status_t get_agent_memory_pool(hsa_agent_t agent,
return iterate_agent_memory_pools(agent, cb);
}

/// Look up \p kernel_name in \p executable, marshal \p kernel_args into a
/// kernarg buffer, dispatch the kernel on \p queue with the grid described by
/// \p params, and block until it completes while servicing RPC requests.
/// NOTE(review): early failures 'return err' while later ones call
/// handle_error (which exits the process) -- confirm this asymmetry is
/// intentional.
template <typename args_t>
hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
                           hsa_amd_memory_pool_t kernargs_pool,
                           hsa_queue_t *queue, const LaunchParameters &params,
                           const char *kernel_name, args_t kernel_args) {
  // Look up the requested kernel in the loaded executable.
  hsa_executable_symbol_t symbol;
  if (hsa_status_t err = hsa_executable_get_symbol_by_name(
          executable, kernel_name, &dev_agent, &symbol))
    return err;

  // Retrieve different properties of the kernel symbol used for launch.
  uint64_t kernel;
  uint32_t args_size;
  uint32_t group_size;
  uint32_t private_size;

  std::pair<hsa_executable_symbol_info_t, void *> symbol_infos[] = {
      {HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT, &kernel},
      {HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, &args_size},
      {HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE, &group_size},
      {HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE, &private_size}};

  for (auto &[info, value] : symbol_infos)
    if (hsa_status_t err = hsa_executable_symbol_get_info(symbol, info, value))
      return err;

  // Allocate space for the kernel arguments on the host and allow the GPU
  // agent to access it.
  void *args;
  if (hsa_status_t err = hsa_amd_memory_pool_allocate(kernargs_pool, args_size,
                                                      /*flags=*/0, &args))
    handle_error(err);
  hsa_amd_agents_allow_access(1, &dev_agent, nullptr, args);

  // Initialize all the arguments (explicit and implicit) to zero, then set
  // the explicit arguments to the values created above.
  std::memset(args, 0, args_size);
  std::memcpy(args, &kernel_args, sizeof(args_t));

  // Obtain a packet from the queue. Spin until a slot is free if the queue
  // has wrapped around.
  uint64_t packet_id = hsa_queue_add_write_index_relaxed(queue, 1);
  while (packet_id - hsa_queue_load_read_index_scacquire(queue) >= queue->size)
    ;

  // The queue size is a power of two, so masking the id yields the slot.
  const uint32_t mask = queue->size - 1;
  hsa_kernel_dispatch_packet_t *packet =
      static_cast<hsa_kernel_dispatch_packet_t *>(queue->base_address) +
      (packet_id & mask);

  // Set up the dispatch packet with the grid and workgroup sizes requested in
  // 'params'. The dimensionality field counts only the axes whose total
  // extent is greater than one (x is always counted).
  std::memset(packet, 0, sizeof(hsa_kernel_dispatch_packet_t));
  packet->setup = (1 + (params.num_blocks_y * params.num_threads_y != 1) +
                   (params.num_blocks_z * params.num_threads_z != 1))
                  << HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
  packet->workgroup_size_x = params.num_threads_x;
  packet->workgroup_size_y = params.num_threads_y;
  packet->workgroup_size_z = params.num_threads_z;
  packet->grid_size_x = params.num_blocks_x * params.num_threads_x;
  packet->grid_size_y = params.num_blocks_y * params.num_threads_y;
  packet->grid_size_z = params.num_blocks_z * params.num_threads_z;
  packet->private_segment_size = private_size;
  packet->group_segment_size = group_size;
  packet->kernel_object = kernel;
  packet->kernarg_address = args;

  // Create a signal to indicate when this packet has been completed.
  if (hsa_status_t err =
          hsa_signal_create(1, 0, nullptr, &packet->completion_signal))
    handle_error(err);

  // Initialize the packet header and set the doorbell signal to begin
  // execution by the HSA runtime. The release store publishes the fully
  // written packet before the runtime can observe the valid header.
  uint16_t setup = packet->setup;
  uint16_t header =
      (HSA_PACKET_TYPE_KERNEL_DISPATCH << HSA_PACKET_HEADER_TYPE) |
      (HSA_FENCE_SCOPE_SYSTEM << HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE) |
      (HSA_FENCE_SCOPE_SYSTEM << HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE);
  __atomic_store_n(&packet->header, header | (setup << 16), __ATOMIC_RELEASE);
  hsa_signal_store_relaxed(queue->doorbell_signal, packet_id);

  // Wait until the kernel has completed execution on the device. Periodically
  // check the RPC client for work to be performed on the server.
  while (hsa_signal_wait_scacquire(
             packet->completion_signal, HSA_SIGNAL_CONDITION_EQ, 0,
             /*timeout_hint=*/1024, HSA_WAIT_STATE_ACTIVE) != 0)
    handle_server();

  // Destroy the resources acquired to launch the kernel and return.
  if (hsa_status_t err = hsa_amd_memory_pool_free(args))
    handle_error(err);
  if (hsa_status_t err = hsa_signal_destroy(packet->completion_signal))
    handle_error(err);

  return HSA_STATUS_SUCCESS;
}

int load(int argc, char **argv, char **envp, void *image, size_t size,
const LaunchParameters &params) {
// Initialize the HSA runtime used to communicate with the device.
Expand All @@ -169,18 +254,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
if (hsa_status_t err = get_agent<HSA_DEVICE_TYPE_CPU>(&host_agent))
handle_error(err);

// Obtain a queue with the minimum (power of two) size, used to send commands
// to the HSA runtime and launch execution on the device.
uint64_t queue_size;
if (hsa_status_t err = hsa_agent_get_info(
dev_agent, HSA_AGENT_INFO_QUEUE_MIN_SIZE, &queue_size))
handle_error(err);
hsa_queue_t *queue = nullptr;
if (hsa_status_t err =
hsa_queue_create(dev_agent, queue_size, HSA_QUEUE_TYPE_SINGLE,
nullptr, nullptr, UINT32_MAX, UINT32_MAX, &queue))
handle_error(err);

// Load the code object's ISA information and executable data segments.
hsa_code_object_t object;
if (hsa_status_t err = hsa_code_object_deserialize(image, size, "", &object))
Expand Down Expand Up @@ -228,36 +301,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
dev_agent, &coarsegrained_pool))
handle_error(err);

// Look up the '_start' kernel in the loaded executable.
hsa_executable_symbol_t symbol;
if (hsa_status_t err = hsa_executable_get_symbol_by_name(
executable, KERNEL_START, &dev_agent, &symbol))
handle_error(err);

// Retrieve different properties of the kernel symbol used for launch.
uint64_t kernel;
uint32_t args_size;
uint32_t group_size;
uint32_t private_size;

std::pair<hsa_executable_symbol_info_t, void *> symbol_infos[] = {
{HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT, &kernel},
{HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, &args_size},
{HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE, &group_size},
{HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE, &private_size}};

for (auto &[info, value] : symbol_infos)
if (hsa_status_t err = hsa_executable_symbol_get_info(symbol, info, value))
handle_error(err);

// Allocate space for the kernel arguments on the host and allow the GPU agent
// to access it.
void *args;
if (hsa_status_t err = hsa_amd_memory_pool_allocate(kernargs_pool, args_size,
/*flags=*/0, &args))
handle_error(err);
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, args);

// Allocate fine-grained memory on the host to hold the pointer array for the
// copied argv and allow the GPU agent to access it.
auto allocator = [&](uint64_t size) -> void * {
Expand Down Expand Up @@ -287,6 +330,10 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
hsa_amd_memory_fill(dev_ret, 0, sizeof(int));

// Allocate finegrained memory for the RPC server and client to share.
uint32_t wavefront_size = 0;
if (hsa_status_t err = hsa_agent_get_info(
dev_agent, HSA_AGENT_INFO_WAVEFRONT_SIZE, &wavefront_size))
handle_error(err);
void *server_inbox;
void *server_outbox;
void *buffer;
Expand All @@ -299,76 +346,43 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
/*flags=*/0, &server_outbox))
handle_error(err);
if (hsa_status_t err = hsa_amd_memory_pool_allocate(
finegrained_pool, sizeof(__llvm_libc::rpc::Buffer),
finegrained_pool,
align_up(sizeof(__llvm_libc::rpc::Header) +
(wavefront_size * sizeof(__llvm_libc::rpc::Buffer)),
alignof(__llvm_libc::rpc::Packet)),
/*flags=*/0, &buffer))
handle_error(err);
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, server_inbox);
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, server_outbox);
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, buffer);

// Initialize all the arguments (explicit and implicit) to zero, then set the
// explicit arguments to the values created above.
std::memset(args, 0, args_size);
kernel_args_t *kernel_args = reinterpret_cast<kernel_args_t *>(args);
kernel_args->argc = argc;
kernel_args->argv = dev_argv;
kernel_args->envp = dev_envp;
kernel_args->ret = dev_ret;
kernel_args->inbox = server_outbox;
kernel_args->outbox = server_inbox;
kernel_args->buffer = buffer;

// Obtain a packet from the queue.
uint64_t packet_id = hsa_queue_add_write_index_relaxed(queue, 1);
while (packet_id - hsa_queue_load_read_index_scacquire(queue) >= queue_size)
;

const uint32_t mask = queue_size - 1;
hsa_kernel_dispatch_packet_t *packet =
(hsa_kernel_dispatch_packet_t *)queue->base_address + (packet_id & mask);

// Set up the packet for execution on the device. We currently only launch
// with one thread on the device, forcing the rest of the wavefront to be
// masked off.
std::memset(packet, 0, sizeof(hsa_kernel_dispatch_packet_t));
packet->setup = (1 + (params.num_blocks_y * params.num_threads_y != 1) +
(params.num_blocks_z * params.num_threads_z != 1))
<< HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
packet->workgroup_size_x = params.num_threads_x;
packet->workgroup_size_y = params.num_threads_y;
packet->workgroup_size_z = params.num_threads_z;
packet->grid_size_x = params.num_blocks_x * params.num_threads_x;
packet->grid_size_y = params.num_blocks_y * params.num_threads_y;
packet->grid_size_z = params.num_blocks_z * params.num_threads_z;
packet->private_segment_size = private_size;
packet->group_segment_size = group_size;
packet->kernel_object = kernel;
packet->kernarg_address = args;
// Initialize the RPC server's buffer for host-device communication.
server.reset(wavefront_size, &lock, server_inbox, server_outbox, buffer);

// Create a signal to indicate when this packet has been completed.
// Obtain a queue with the minimum (power of two) size, used to send commands
// to the HSA runtime and launch execution on the device.
uint64_t queue_size;
if (hsa_status_t err = hsa_agent_get_info(
dev_agent, HSA_AGENT_INFO_QUEUE_MIN_SIZE, &queue_size))
handle_error(err);
hsa_queue_t *queue = nullptr;
if (hsa_status_t err =
hsa_signal_create(1, 0, nullptr, &packet->completion_signal))
hsa_queue_create(dev_agent, queue_size, HSA_QUEUE_TYPE_MULTI, nullptr,
nullptr, UINT32_MAX, UINT32_MAX, &queue))
handle_error(err);

// Initialize the RPC server's buffer for host-device communication.
server.reset(&lock, server_inbox, server_outbox, buffer);

// Initialize the packet header and set the doorbell signal to begin execution
// by the HSA runtime.
uint16_t header =
(HSA_PACKET_TYPE_KERNEL_DISPATCH << HSA_PACKET_HEADER_TYPE) |
(HSA_FENCE_SCOPE_SYSTEM << HSA_PACKET_HEADER_ACQUIRE_FENCE_SCOPE) |
(HSA_FENCE_SCOPE_SYSTEM << HSA_PACKET_HEADER_RELEASE_FENCE_SCOPE);
__atomic_store_n(&packet->header, header | (packet->setup << 16),
__ATOMIC_RELEASE);
hsa_signal_store_relaxed(queue->doorbell_signal, packet_id);
LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1};
begin_args_t init_args = {argc, dev_argv, dev_envp,
server_outbox, server_inbox, buffer};
if (hsa_status_t err =
launch_kernel(dev_agent, executable, kernargs_pool, queue,
single_threaded_params, "_begin.kd", init_args))
handle_error(err);

// Wait until the kernel has completed execution on the device. Periodically
// check the RPC client for work to be performed on the server.
while (hsa_signal_wait_scacquire(
packet->completion_signal, HSA_SIGNAL_CONDITION_EQ, 0,
/*timeout_hint=*/1024, HSA_WAIT_STATE_ACTIVE) != 0)
handle_server();
start_args_t args = {argc, dev_argv, dev_envp, dev_ret};
if (hsa_status_t err = launch_kernel(dev_agent, executable, kernargs_pool,
queue, params, "_start.kd", args))
handle_error(err);

// Create a memory signal and copy the return value back from the device into
// a new buffer.
Expand All @@ -395,9 +409,13 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
// Save the return value and perform basic clean-up.
int ret = *static_cast<int *>(host_ret);

// Free the memory allocated for the device.
if (hsa_status_t err = hsa_amd_memory_pool_free(args))
end_args_t fini_args = {ret};
if (hsa_status_t err =
launch_kernel(dev_agent, executable, kernargs_pool, queue,
single_threaded_params, "_end.kd", fini_args))
handle_error(err);

// Free the memory allocated for the device.
if (hsa_status_t err = hsa_amd_memory_pool_free(dev_argv))
handle_error(err);
if (hsa_status_t err = hsa_amd_memory_pool_free(dev_ret))
Expand All @@ -413,10 +431,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,

if (hsa_status_t err = hsa_signal_destroy(memory_signal))
handle_error(err);

if (hsa_status_t err = hsa_signal_destroy(packet->completion_signal))
handle_error(err);

if (hsa_status_t err = hsa_queue_destroy(queue))
handle_error(err);

Expand Down
95 changes: 52 additions & 43 deletions libc/utils/gpu/loader/nvptx/Loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,6 @@
using namespace llvm;
using namespace object;

/// The arguments to the '_start' kernel.
struct kernel_args_t {
int argc;
void *argv;
void *envp;
void *ret;
void *inbox;
void *outbox;
void *buffer;
};

static void handle_error(CUresult err) {
if (err == CUDA_SUCCESS)
return;
Expand Down Expand Up @@ -170,6 +159,36 @@ Expected<void *> get_ctor_dtor_array(const void *image, const size_t size,
return dev_memory;
}

/// Look up \p kernel_name in the loaded module \p binary, launch it on
/// \p stream with the grid described by \p params and the packed argument
/// struct \p kernel_args, and block until it completes while servicing RPC
/// requests. NOTE(review): failures call handle_error (which does not
/// return), so the CUresult return value is always CUDA_SUCCESS -- confirm
/// this is intentional.
template <typename args_t>
CUresult launch_kernel(CUmodule binary, CUstream stream,
                       const LaunchParameters &params, const char *kernel_name,
                       args_t kernel_args) {
  // Look up the requested kernel in the loaded module.
  CUfunction function;
  if (CUresult err = cuModuleGetFunction(&function, binary, kernel_name))
    handle_error(err);

  // Pass the argument struct to the kernel as a raw buffer; its layout must
  // match the device-side kernel signature.
  uint64_t args_size = sizeof(args_t);
  void *args_config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, &kernel_args,
                         CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
                         CU_LAUNCH_PARAM_END};

  // Call the kernel with the given arguments.
  if (CUresult err = cuLaunchKernel(
          function, params.num_blocks_x, params.num_blocks_y,
          params.num_blocks_z, params.num_threads_x, params.num_threads_y,
          params.num_threads_z, 0, stream, nullptr, args_config))
    handle_error(err);

  // Wait until the kernel has completed execution on the device. Periodically
  // check the RPC client for work to be performed on the server.
  while (cuStreamQuery(stream) == CUDA_ERROR_NOT_READY)
    handle_server();

  return CUDA_SUCCESS;
}

int load(int argc, char **argv, char **envp, void *image, size_t size,
const LaunchParameters &params) {

Expand Down Expand Up @@ -197,11 +216,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
if (CUresult err = cuModuleLoadDataEx(&binary, image, 0, nullptr, nullptr))
handle_error(err);

// look up the '_start' kernel in the loaded module.
CUfunction function;
if (CUresult err = cuModuleGetFunction(&function, binary, "_start"))
handle_error(err);

// Allocate pinned memory on the host to hold the pointer array for the
// copied argv and allow the GPU device to access it.
auto allocator = [&](uint64_t size) -> void * {
Expand Down Expand Up @@ -232,41 +246,31 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
if (CUresult err = cuMemsetD32(dev_ret, 0, 1))
handle_error(err);

uint32_t warp_size = 32;
void *server_inbox = allocator(sizeof(__llvm_libc::cpp::Atomic<int>));
void *server_outbox = allocator(sizeof(__llvm_libc::cpp::Atomic<int>));
void *buffer = allocator(sizeof(__llvm_libc::rpc::Buffer));
void *buffer =
allocator(align_up(sizeof(__llvm_libc::rpc::Header) +
(warp_size * sizeof(__llvm_libc::rpc::Buffer)),
alignof(__llvm_libc::rpc::Packet)));
if (!server_inbox || !server_outbox || !buffer)
handle_error("Failed to allocate memory the RPC client / server.");

// Set up the arguments to the '_start' kernel on the GPU.
uint64_t args_size = sizeof(kernel_args_t);
kernel_args_t args;
std::memset(&args, 0, args_size);
args.argc = argc;
args.argv = dev_argv;
args.envp = dev_envp;
args.ret = reinterpret_cast<void *>(dev_ret);
args.inbox = server_outbox;
args.outbox = server_inbox;
args.buffer = buffer;
void *args_config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, &args,
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
CU_LAUNCH_PARAM_END};

// Initialize the RPC server's buffer for host-device communication.
server.reset(&lock, server_inbox, server_outbox, buffer);

// Call the kernel with the given arguments.
if (CUresult err = cuLaunchKernel(
function, params.num_blocks_x, params.num_blocks_y,
params.num_blocks_z, params.num_threads_x, params.num_threads_y,
params.num_threads_z, 0, stream, nullptr, args_config))
server.reset(warp_size, &lock, server_inbox, server_outbox, buffer);

LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1};
// Call the '_begin' kernel with a single thread to set up the RPC client and
// run the global constructors before 'main' executes.
begin_args_t init_args = {argc, dev_argv, dev_envp,
server_outbox, server_inbox, buffer};
if (CUresult err = launch_kernel(binary, stream, single_threaded_params,
"_begin", init_args))
handle_error(err);

// Wait until the kernel has completed execution on the device. Periodically
// check the RPC client for work to be performed on the server.
while (cuStreamQuery(stream) == CUDA_ERROR_NOT_READY)
handle_server();
start_args_t args = {argc, dev_argv, dev_envp,
reinterpret_cast<void *>(dev_ret)};
if (CUresult err = launch_kernel(binary, stream, params, "_start", args))
handle_error(err);

// Copy the return value back from the kernel and wait.
int host_ret = 0;
Expand All @@ -276,6 +280,11 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
if (CUresult err = cuStreamSynchronize(stream))
handle_error(err);

end_args_t fini_args = {host_ret};
if (CUresult err = launch_kernel(binary, stream, single_threaded_params,
"_end", fini_args))
handle_error(err);

// Free the memory allocated for the device.
if (CUresult err = cuMemFreeHost(*memory_or_err))
handle_error(err);
Expand Down