66 changes: 61 additions & 5 deletions libc/utils/gpu/loader/amdgpu/Loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#include "Loader.h"

#include "src/__support/RPC/rpc.h"

#include <hsa/hsa.h>
#include <hsa/hsa_ext_amd.h>

Expand All @@ -31,8 +33,35 @@ struct kernel_args_t {
int argc;
void *argv;
void *ret;
void *inbox;
void *outbox;
void *buffer;
};

static __llvm_libc::rpc::Server server;

/// Queries the RPC client at least once and performs server-side work if there
/// are any active requests.
void handle_server() {
while (server.run(
[&](__llvm_libc::rpc::Buffer *buffer) {
switch (static_cast<__llvm_libc::rpc::Opcode>(buffer->data[0])) {
case __llvm_libc::rpc::Opcode::PRINT_TO_STDERR: {
fputs(reinterpret_cast<const char *>(&buffer->data[1]), stderr);
break;
}
case __llvm_libc::rpc::Opcode::EXIT: {
exit(buffer->data[1]);
break;
}
default:
return;
};
},
[](__llvm_libc::rpc::Buffer *buffer) {}))
;
}

/// Print the error code and exit if \p code indicates an error.
static void handle_error(hsa_status_t code) {
if (code == HSA_STATUS_SUCCESS || code == HSA_STATUS_INFO_BREAK)
Expand Down Expand Up @@ -278,13 +307,36 @@ int load(int argc, char **argv, void *image, size_t size) {
handle_error(err);
hsa_amd_memory_fill(dev_ret, 0, sizeof(int));

// Allocate finegrained memory for the RPC server and client to share.
void *server_inbox;
void *server_outbox;
void *buffer;
if (hsa_status_t err = hsa_amd_memory_pool_allocate(
finegrained_pool, sizeof(__llvm_libc::cpp::Atomic<int>),
/*flags=*/0, &server_inbox))
handle_error(err);
if (hsa_status_t err = hsa_amd_memory_pool_allocate(
finegrained_pool, sizeof(__llvm_libc::cpp::Atomic<int>),
/*flags=*/0, &server_outbox))
handle_error(err);
if (hsa_status_t err = hsa_amd_memory_pool_allocate(
finegrained_pool, sizeof(__llvm_libc::rpc::Buffer),
/*flags=*/0, &buffer))
handle_error(err);
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, server_inbox);
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, server_outbox);
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, buffer);

// Initialie all the arguments (explicit and implicit) to zero, then set the
// explicit arguments to the values created above.
std::memset(args, 0, args_size);
kernel_args_t *kernel_args = reinterpret_cast<kernel_args_t *>(args);
kernel_args->argc = argc;
kernel_args->argv = dev_argv;
kernel_args->ret = dev_ret;
kernel_args->inbox = server_outbox;
kernel_args->outbox = server_inbox;
kernel_args->buffer = buffer;

// Obtain a packet from the queue.
uint64_t packet_id = hsa_queue_add_write_index_relaxed(queue, 1);
Expand Down Expand Up @@ -316,6 +368,9 @@ int load(int argc, char **argv, void *image, size_t size) {
hsa_signal_create(1, 0, nullptr, &packet->completion_signal))
handle_error(err);

// Initialize the RPC server's buffer for host-device communication.
server.reset(server_inbox, server_outbox, buffer);

// Initialize the packet header and set the doorbell signal to begin execution
// by the HSA runtime.
uint16_t header =
Expand All @@ -326,11 +381,12 @@ int load(int argc, char **argv, void *image, size_t size) {
__ATOMIC_RELEASE);
hsa_signal_store_relaxed(queue->doorbell_signal, packet_id);

// Wait until the kernel has completed execution on the device.
while (hsa_signal_wait_scacquire(packet->completion_signal,
HSA_SIGNAL_CONDITION_EQ, 0, UINT64_MAX,
HSA_WAIT_STATE_ACTIVE) != 0)
;
// Wait until the kernel has completed execution on the device. Periodically
// check the RPC client for work to be performed on the server.
while (hsa_signal_wait_scacquire(
packet->completion_signal, HSA_SIGNAL_CONDITION_EQ, 0,
/*timeout_hint=*/1024, HSA_WAIT_STATE_ACTIVE) != 0)
handle_server();

// Create a memory signal and copy the return value back from the device into
// a new buffer.
Expand Down