123 changes: 52 additions & 71 deletions libc/utils/gpu/loader/amdgpu/amdhsa-loader.cpp
@@ -160,7 +160,7 @@ template <typename args_t>
hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
hsa_amd_memory_pool_t kernargs_pool,
hsa_amd_memory_pool_t coarsegrained_pool,
hsa_queue_t *queue, rpc_device_t device,
hsa_queue_t *queue, rpc::Server &server,
const LaunchParameters &params,
const char *kernel_name, args_t kernel_args,
bool print_resource_usage) {
@@ -170,37 +170,10 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
executable, kernel_name, &dev_agent, &symbol))
return err;

// Register RPC callbacks for the malloc and free functions on HSA.
auto tuple = std::make_tuple(dev_agent, coarsegrained_pool);
rpc_register_callback(
device, RPC_MALLOC,
[](rpc_port_t port, void *data) {
auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
auto &[dev_agent, pool] = *static_cast<decltype(tuple) *>(data);
uint64_t size = buffer->data[0];
void *dev_ptr = nullptr;
if (hsa_status_t err =
hsa_amd_memory_pool_allocate(pool, size,
/*flags=*/0, &dev_ptr))
dev_ptr = nullptr;
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, dev_ptr);
buffer->data[0] = reinterpret_cast<uintptr_t>(dev_ptr);
};
rpc_recv_and_send(port, malloc_handler, data);
},
&tuple);
rpc_register_callback(
device, RPC_FREE,
[](rpc_port_t port, void *data) {
auto free_handler = [](rpc_buffer_t *buffer, void *) {
if (hsa_status_t err = hsa_amd_memory_pool_free(
reinterpret_cast<void *>(buffer->data[0])))
handle_error(err);
};
rpc_recv_and_send(port, free_handler, data);
},
nullptr);

uint32_t wavefront_size = 0;
if (hsa_status_t err = hsa_agent_get_info(
dev_agent, HSA_AGENT_INFO_WAVEFRONT_SIZE, &wavefront_size))
handle_error(err);
// Retrieve different properties of the kernel symbol used for launch.
uint64_t kernel;
uint32_t args_size;
@@ -292,14 +265,38 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
hsa_signal_store_relaxed(queue->doorbell_signal, packet_id);

std::atomic<bool> finished = false;
std::thread server(
[](std::atomic<bool> *finished, rpc_device_t device) {
while (!*finished) {
if (rpc_status_t err = rpc_handle_server(device))
std::thread server_thread(
[](std::atomic<bool> *finished, rpc::Server *server,
uint32_t wavefront_size, hsa_agent_t dev_agent,
hsa_amd_memory_pool_t coarsegrained_pool) {
// Handlers for the RPC malloc and free opcodes, implemented via HSA.
auto malloc_handler = [&](size_t size) -> void * {
void *dev_ptr = nullptr;
if (hsa_status_t err =
hsa_amd_memory_pool_allocate(coarsegrained_pool, size,
/*flags=*/0, &dev_ptr))
dev_ptr = nullptr;
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, dev_ptr);
return dev_ptr;
};

auto free_handler = [](void *ptr) -> void {
if (hsa_status_t err =
hsa_amd_memory_pool_free(reinterpret_cast<void *>(ptr)))
handle_error(err);
};

uint32_t index = 0;
while (!*finished) {
if (wavefront_size == 32)
index =
handle_server<32>(*server, index, malloc_handler, free_handler);
else
index =
handle_server<64>(*server, index, malloc_handler, free_handler);
}
},
&finished, device);
&finished, &server, wavefront_size, dev_agent, coarsegrained_pool);

// Wait until the kernel has completed execution on the device. Periodically
// check the RPC client for work to be performed on the server.
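Both loaders now drive a templated handle_server helper from the host thread instead of the old rpc_handle_server() C entry point. That helper comes from the loaders' shared code, which is not part of this diff; the following is a plausible sketch reconstructed from the call sites above, so every name and signature in it is an assumption rather than the shipped implementation.

// Sketch only: service at most one open port per call, resuming the scan
// at `index` so ports are polled round-robin across invocations.
template <uint32_t num_lanes, typename Alloc, typename Free>
uint32_t handle_server(rpc::Server &server, uint32_t index, Alloc &&alloc,
                       Free &&dealloc) {
  auto port = server.try_open(num_lanes, index);
  if (!port)
    return 0;

  switch (port->get_opcode()) {
  case RPC_MALLOC:
    port->recv_and_send([&](rpc::Buffer *buffer, uint32_t) {
      buffer->data[0] = reinterpret_cast<uintptr_t>(alloc(buffer->data[0]));
    });
    break;
  case RPC_FREE:
    port->recv([&](rpc::Buffer *buffer, uint32_t) {
      dealloc(reinterpret_cast<void *>(buffer->data[0]));
    });
    break;
  default:
    // Anything the loader does not handle is forwarded to the libc server.
    libc_handle_rpc_port(&*port, num_lanes);
    break;
  }

  uint32_t next = port->get_index() + 1; // get_index() is assumed to exist
  port->close();
  return next;
}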
@@ -309,8 +306,8 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
;

finished = true;
if (server.joinable())
server.join();
if (server_thread.joinable())
server_thread.join();

// Destroy the resources acquired to launch the kernel and return.
if (hsa_status_t err = hsa_amd_memory_pool_free(args))
@@ -457,34 +454,22 @@ int load(int argc, const char **argv, const char **envp, void *image,
handle_error(err);

// Set up the RPC server.
auto tuple = std::make_tuple(dev_agent, finegrained_pool);
auto rpc_alloc = [](uint64_t size, void *data) {
auto &[dev_agent, finegrained_pool] = *static_cast<decltype(tuple) *>(data);
void *dev_ptr = nullptr;
if (hsa_status_t err = hsa_amd_memory_pool_allocate(finegrained_pool, size,
/*flags=*/0, &dev_ptr))
handle_error(err);
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, dev_ptr);
return dev_ptr;
};
rpc_device_t device;
if (rpc_status_t err = rpc_server_init(&device, RPC_MAXIMUM_PORT_COUNT,
wavefront_size, rpc_alloc, &tuple))
void *rpc_buffer;
if (hsa_status_t err = hsa_amd_memory_pool_allocate(
finegrained_pool,
rpc::Server::allocation_size(wavefront_size, rpc::MAX_PORT_COUNT),
/*flags=*/0, &rpc_buffer))
handle_error(err);
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, rpc_buffer);

// Register callbacks for the RPC unit tests.
if (wavefront_size == 32)
register_rpc_callbacks<32>(device);
else if (wavefront_size == 64)
register_rpc_callbacks<64>(device);
else
handle_error("Invalid wavefront size");
rpc::Server server(rpc::MAX_PORT_COUNT, rpc_buffer);
rpc::Client client(rpc::MAX_PORT_COUNT, rpc_buffer);

// Initialize the RPC client on the device by copying the local data to the
// device's internal pointer.
hsa_executable_symbol_t rpc_client_sym;
if (hsa_status_t err = hsa_executable_get_symbol_by_name(
executable, rpc_client_symbol_name, &dev_agent, &rpc_client_sym))
executable, "__llvm_libc_rpc_client", &dev_agent, &rpc_client_sym))
handle_error(err);

void *rpc_client_host;
Expand All @@ -507,19 +492,17 @@ int load(int argc, const char **argv, const char **envp, void *image,

void *rpc_client_buffer;
if (hsa_status_t err =
hsa_amd_memory_lock(const_cast<void *>(rpc_get_client_buffer(device)),
rpc_get_client_size(),
hsa_amd_memory_lock(&client, sizeof(rpc::Client),
/*agents=*/nullptr, 0, &rpc_client_buffer))
handle_error(err);

// Copy the RPC client buffer to the address pointed to by the symbol.
if (hsa_status_t err =
hsa_memcpy(*reinterpret_cast<void **>(rpc_client_host), dev_agent,
rpc_client_buffer, host_agent, rpc_get_client_size()))
rpc_client_buffer, host_agent, sizeof(rpc::Client)))
handle_error(err);

if (hsa_status_t err = hsa_amd_memory_unlock(
const_cast<void *>(rpc_get_client_buffer(device))))
if (hsa_status_t err = hsa_amd_memory_unlock(&client))
handle_error(err);
if (hsa_status_t err = hsa_amd_memory_pool_free(rpc_client_host))
handle_error(err);
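Note the shape of the bootstrap above: the loader constructs the rpc::Client on the host and copies the whole object, not just a buffer pointer, over the client global exported by the device image. The device side of that handshake is not in this diff; a sketch of what it presumably declares, with only the symbol name taken from the code above:

// Device-side sketch (assumed): the global each loader overwrites before
// launching a kernel. Since rpc::Client keeps its port count and buffer
// base inline and is trivially copyable, a sizeof(rpc::Client) memcpy is
// enough to hand the device a fully wired endpoint.
extern "C" rpc::Client __llvm_libc_rpc_client;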
@@ -571,15 +554,15 @@ int load(int argc, const char **argv, const char **envp, void *image,
LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1};
begin_args_t init_args = {argc, dev_argv, dev_envp};
if (hsa_status_t err = launch_kernel(dev_agent, executable, kernargs_pool,
coarsegrained_pool, queue, device,
coarsegrained_pool, queue, server,
single_threaded_params, "_begin.kd",
init_args, print_resource_usage))
handle_error(err);

start_args_t args = {argc, dev_argv, dev_envp, dev_ret};
if (hsa_status_t err = launch_kernel(
dev_agent, executable, kernargs_pool, coarsegrained_pool, queue,
device, params, "_start.kd", args, print_resource_usage))
server, params, "_start.kd", args, print_resource_usage))
handle_error(err);

void *host_ret;
@@ -598,14 +581,12 @@

end_args_t fini_args = {ret};
if (hsa_status_t err = launch_kernel(dev_agent, executable, kernargs_pool,
coarsegrained_pool, queue, device,
coarsegrained_pool, queue, server,
single_threaded_params, "_end.kd",
fini_args, print_resource_usage))
handle_error(err);

if (rpc_status_t err = rpc_server_shutdown(
device, [](void *ptr, void *) { hsa_amd_memory_pool_free(ptr); },
nullptr))
if (hsa_status_t err = hsa_amd_memory_pool_free(rpc_buffer))
handle_error(err);

// Free the memory allocated for the device.
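One detail worth calling out from the AMD changes above: server and client are plain views over a single shared allocation sized by rpc::Server::allocation_size(wavefront_size, rpc::MAX_PORT_COUNT), which is why shutdown is now a single hsa_amd_memory_pool_free. The sizing helper itself lives in shared/rpc.h; the sketch below is a guess at its shape from the call site, not the real formula.

// Hypothetical sizing logic: per-port mailbox state plus one lane-width
// packet of payload per port. The actual layout in shared/rpc.h may differ.
static uint64_t allocation_size(uint32_t lane_size, uint32_t num_ports) {
  uint64_t mailboxes = num_ports * 2 * sizeof(uint32_t); // inbox + outbox
  uint64_t payload = num_ports * lane_size * sizeof(rpc::Buffer);
  return mailboxes + payload;
}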
96 changes: 37 additions & 59 deletions libc/utils/gpu/loader/nvptx/nvptx-loader.cpp
@@ -167,10 +167,9 @@ void print_kernel_resources(CUmodule binary, const char *kernel_name) {
}

template <typename args_t>
CUresult launch_kernel(CUmodule binary, CUstream stream,
rpc_device_t rpc_device, const LaunchParameters &params,
const char *kernel_name, args_t kernel_args,
bool print_resource_usage) {
CUresult launch_kernel(CUmodule binary, CUstream stream, rpc::Server &server,
const LaunchParameters &params, const char *kernel_name,
args_t kernel_args, bool print_resource_usage) {
// look up the '_start' kernel in the loaded module.
CUfunction function;
if (CUresult err = cuModuleGetFunction(&function, binary, kernel_name))
@@ -181,60 +180,44 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
void *args_config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, &kernel_args,
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
CU_LAUNCH_PARAM_END};
if (print_resource_usage)
print_kernel_resources(binary, kernel_name);

// Initialize a non-blocking CUDA stream to allocate memory if needed. This
// needs to be done on a separate stream or else it will deadlock with the
// executing kernel.
// Initialize a non-blocking CUDA stream to allocate memory if needed.
// This needs to be done on a separate stream or else it will deadlock
// with the executing kernel.
CUstream memory_stream;
if (CUresult err = cuStreamCreate(&memory_stream, CU_STREAM_NON_BLOCKING))
handle_error(err);

// Register RPC callbacks for the malloc and free functions on HSA.
register_rpc_callbacks<32>(rpc_device);

rpc_register_callback(
rpc_device, RPC_MALLOC,
[](rpc_port_t port, void *data) {
auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
CUstream memory_stream = *static_cast<CUstream *>(data);
uint64_t size = buffer->data[0];
std::atomic<bool> finished = false;
std::thread server_thread(
[](std::atomic<bool> *finished, rpc::Server *server,
CUstream memory_stream) {
auto malloc_handler = [&](size_t size) -> void * {
CUdeviceptr dev_ptr;
if (CUresult err = cuMemAllocAsync(&dev_ptr, size, memory_stream))
dev_ptr = 0UL;

// Wait until the memory allocation is complete.
while (cuStreamQuery(memory_stream) == CUDA_ERROR_NOT_READY)
;
buffer->data[0] = static_cast<uintptr_t>(dev_ptr);
return reinterpret_cast<void *>(dev_ptr);
};
rpc_recv_and_send(port, malloc_handler, data);
},
&memory_stream);
rpc_register_callback(
rpc_device, RPC_FREE,
[](rpc_port_t port, void *data) {
auto free_handler = [](rpc_buffer_t *buffer, void *data) {
CUstream memory_stream = *static_cast<CUstream *>(data);
if (CUresult err = cuMemFreeAsync(
static_cast<CUdeviceptr>(buffer->data[0]), memory_stream))

auto free_handler = [&](void *ptr) -> void {
if (CUresult err = cuMemFreeAsync(reinterpret_cast<CUdeviceptr>(ptr),
memory_stream))
handle_error(err);
};
rpc_recv_and_send(port, free_handler, data);
},
&memory_stream);

if (print_resource_usage)
print_kernel_resources(binary, kernel_name);

std::atomic<bool> finished = false;
std::thread server(
[](std::atomic<bool> *finished, rpc_device_t device) {
uint32_t index = 0;
while (!*finished) {
if (rpc_status_t err = rpc_handle_server(device))
handle_error(err);
index =
handle_server<32>(*server, index, malloc_handler, free_handler);
}
},
&finished, rpc_device);
&finished, &server, memory_stream);

// Call the kernel with the given arguments.
if (CUresult err = cuLaunchKernel(
@@ -247,8 +230,8 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
handle_error(err);

finished = true;
if (server.joinable())
server.join();
if (server_thread.joinable())
server_thread.join();

return CUDA_SUCCESS;
}
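For context, the RPC_MALLOC and RPC_FREE traffic serviced by these server threads originates from device code opening a port on the client that the loader copied in. A sketch of that client side, based on the shared rpc.h interface; the open<opcode>() call and handler signatures are assumptions, not part of this diff:

// Device-side sketch: how a GPU thread asks the host server thread for
// device-accessible memory via the RPC_MALLOC opcode.
void *rpc_allocate(uint64_t size) {
  void *ptr = nullptr;
  auto port = __llvm_libc_rpc_client.open<RPC_MALLOC>();
  port.send_and_recv(
      [=](rpc::Buffer *buffer, uint32_t) { buffer->data[0] = size; },
      [&](rpc::Buffer *buffer, uint32_t) {
        ptr = reinterpret_cast<void *>(buffer->data[0]);
      });
  port.close();
  return ptr;
}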
@@ -318,44 +301,40 @@ int load(int argc, const char **argv, const char **envp, void *image,
handle_error(err);

uint32_t warp_size = 32;
auto rpc_alloc = [](uint64_t size, void *) -> void * {
void *dev_ptr;
if (CUresult err = cuMemAllocHost(&dev_ptr, size))
handle_error(err);
return dev_ptr;
};
rpc_device_t rpc_device;
if (rpc_status_t err = rpc_server_init(&rpc_device, RPC_MAXIMUM_PORT_COUNT,
warp_size, rpc_alloc, nullptr))
void *rpc_buffer = nullptr;
if (CUresult err = cuMemAllocHost(
&rpc_buffer,
rpc::Server::allocation_size(warp_size, rpc::MAX_PORT_COUNT)))
handle_error(err);
rpc::Server server(rpc::MAX_PORT_COUNT, rpc_buffer);
rpc::Client client(rpc::MAX_PORT_COUNT, rpc_buffer);

// Initialize the RPC client on the device by copying the local data to the
// device's internal pointer.
CUdeviceptr rpc_client_dev = 0;
uint64_t client_ptr_size = sizeof(void *);
if (CUresult err = cuModuleGetGlobal(&rpc_client_dev, &client_ptr_size,
binary, rpc_client_symbol_name))
binary, "__llvm_libc_rpc_client"))
handle_error(err);

CUdeviceptr rpc_client_host = 0;
if (CUresult err =
cuMemcpyDtoH(&rpc_client_host, rpc_client_dev, sizeof(void *)))
handle_error(err);
if (CUresult err =
cuMemcpyHtoD(rpc_client_host, rpc_get_client_buffer(rpc_device),
rpc_get_client_size()))
cuMemcpyHtoD(rpc_client_host, &client, sizeof(rpc::Client)))
handle_error(err);

LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1};
begin_args_t init_args = {argc, dev_argv, dev_envp};
if (CUresult err =
launch_kernel(binary, stream, rpc_device, single_threaded_params,
launch_kernel(binary, stream, server, single_threaded_params,
"_begin", init_args, print_resource_usage))
handle_error(err);

start_args_t args = {argc, dev_argv, dev_envp,
reinterpret_cast<void *>(dev_ret)};
if (CUresult err = launch_kernel(binary, stream, rpc_device, params, "_start",
if (CUresult err = launch_kernel(binary, stream, server, params, "_start",
args, print_resource_usage))
handle_error(err);

@@ -369,8 +348,8 @@

end_args_t fini_args = {host_ret};
if (CUresult err =
launch_kernel(binary, stream, rpc_device, single_threaded_params,
"_end", fini_args, print_resource_usage))
launch_kernel(binary, stream, server, single_threaded_params, "_end",
fini_args, print_resource_usage))
handle_error(err);

// Free the memory allocated for the device.
Expand All @@ -380,8 +359,7 @@ int load(int argc, const char **argv, const char **envp, void *image,
handle_error(err);
if (CUresult err = cuMemFreeHost(dev_argv))
handle_error(err);
if (rpc_status_t err = rpc_server_shutdown(
rpc_device, [](void *ptr, void *) { cuMemFreeHost(ptr); }, nullptr))
if (CUresult err = cuMemFreeHost(rpc_buffer))
handle_error(err);

// Destroy the context and the loaded binary.
94 changes: 1 addition & 93 deletions libc/utils/gpu/server/llvmlibc_rpc_server.h
@@ -15,99 +15,7 @@
extern "C" {
#endif

/// The maximum number of ports that can be opened for any server.
const uint64_t RPC_MAXIMUM_PORT_COUNT = 4096;

/// The symbol name associated with the client for use with the LLVM C library
/// implementation.
const char *const rpc_client_symbol_name = "__llvm_libc_rpc_client";

/// status codes.
typedef enum {
RPC_STATUS_SUCCESS = 0x0,
RPC_STATUS_CONTINUE = 0x1,
RPC_STATUS_ERROR = 0x1000,
RPC_STATUS_UNHANDLED_OPCODE = 0x1001,
RPC_STATUS_INVALID_LANE_SIZE = 0x1002,
} rpc_status_t;

/// A struct containing an opaque handle to an RPC port. This is what allows the
/// server to communicate with the client.
typedef struct rpc_port_s {
uint64_t handle;
uint32_t lane_size;
} rpc_port_t;

/// A fixed-size buffer containing the payload sent from the client.
typedef struct rpc_buffer_s {
uint64_t data[8];
} rpc_buffer_t;

/// An opaque handle to an RPC server that can be attached to a device.
typedef struct rpc_device_s {
uintptr_t handle;
} rpc_device_t;

/// A function used to allocate \p bytes for use by the RPC server and client.
/// The memory should support asynchronous and atomic access from both the
/// client and server.
typedef void *(*rpc_alloc_ty)(uint64_t size, void *data);

/// A function used to free the \p ptr previously allocated.
typedef void (*rpc_free_ty)(void *ptr, void *data);

/// A callback function provided with a \p port to communicate with the RPC
/// client. This will be called by the server to handle an opcode.
typedef void (*rpc_opcode_callback_ty)(rpc_port_t port, void *data);

/// A callback function to use the port to receive or send a \p buffer.
typedef void (*rpc_port_callback_ty)(rpc_buffer_t *buffer, void *data);

/// Initialize the server for a given device and return it in \p device.
rpc_status_t rpc_server_init(rpc_device_t *rpc_device, uint64_t num_ports,
uint32_t lane_size, rpc_alloc_ty alloc,
void *data);

/// Shut down the server for a given device.
rpc_status_t rpc_server_shutdown(rpc_device_t rpc_device, rpc_free_ty dealloc,
void *data);

/// Queries the RPC clients at least once and performs server-side work if there
/// are any active requests. Runs until all work on the server is completed.
rpc_status_t rpc_handle_server(rpc_device_t rpc_device);

/// Register a callback to handle an opcode from the RPC client. The associated
/// data must remain accessible as long as the user intends to handle the server
/// with this callback.
rpc_status_t rpc_register_callback(rpc_device_t rpc_device, uint32_t opcode,
rpc_opcode_callback_ty callback, void *data);

/// Obtain a pointer to a local client buffer that can be copied directly to the
/// other process using the address stored at the rpc client symbol name.
const void *rpc_get_client_buffer(rpc_device_t device);

/// Returns the size of the client in bytes to be used for a memory copy.
uint64_t rpc_get_client_size();

/// Use the \p port to send a buffer using the \p callback.
void rpc_send(rpc_port_t port, rpc_port_callback_ty callback, void *data);

/// Use the \p port to send \p bytes using the \p callback. The input is an
/// array of at least the configured lane size.
void rpc_send_n(rpc_port_t port, const void *const *src, uint64_t *size);

/// Use the \p port to receive a buffer using the \p callback.
void rpc_recv(rpc_port_t port, rpc_port_callback_ty callback, void *data);

/// Use the \p port to receive \p bytes using the \p callback. The input is an
/// array of at least the configured lane size. The \p alloc function allocates
/// memory for the received bytes.
void rpc_recv_n(rpc_port_t port, void **dst, uint64_t *size, rpc_alloc_ty alloc,
void *data);

/// Use the \p port to receive and send a buffer using the \p callback.
void rpc_recv_and_send(rpc_port_t port, rpc_port_callback_ty callback,
void *data);
int libc_handle_rpc_port(void *port, uint32_t num_lanes);

#ifdef __cplusplus
}
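With the C shim gone, libc_handle_rpc_port() is the only C entry point left; everything else this patch relies on moves to the C++ interface in shared/rpc.h. The surface the call sites in this diff depend on looks roughly like the sketch below, reconstructed rather than quoted, so treat every signature as an assumption:

// Inferred interface sketch; see shared/rpc.h for the real declarations.
namespace rpc {
inline constexpr uint32_t MAX_PORT_COUNT = 4096; // value assumed

struct Buffer {
  uint64_t data[8];
};

struct Server {
  struct Port; // one open connection: get_opcode, recv, recv_and_send, close
  Server(uint32_t num_ports, void *buffer); // non-owning view of the buffer
  static uint64_t allocation_size(uint32_t lane_size, uint32_t num_ports);
  // Returns an empty optional when no client has work pending.
  cpp::optional<Port> try_open(uint32_t lane_size, uint32_t start = 0);
};

struct Client {
  Client(uint32_t num_ports, void *buffer); // trivially copyable by design
};
} // namespace rpc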
348 changes: 101 additions & 247 deletions libc/utils/gpu/server/rpc_server.cpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions offload/plugins-nextgen/common/CMakeLists.txt
@@ -34,6 +34,7 @@ elseif(${LIBOMPTARGET_GPU_LIBC_SUPPORT})
# We may need to get the headers directly from the 'libc' source directory.
target_include_directories(PluginCommon PRIVATE
${CMAKE_SOURCE_DIR}/../libc/utils/gpu/server
${CMAKE_SOURCE_DIR}/../libc/
${CMAKE_SOURCE_DIR}/../libc/include)
endif()
endif()
2 changes: 1 addition & 1 deletion offload/plugins-nextgen/common/include/RPC.h
@@ -61,7 +61,7 @@ struct RPCServerTy {

private:
/// Array from this device's identifier to its attached devices.
llvm::SmallVector<uintptr_t> Handles;
llvm::SmallVector<void *> Buffers;
};

} // namespace llvm::omp::target
126 changes: 56 additions & 70 deletions offload/plugins-nextgen/common/src/RPC.cpp
@@ -12,24 +12,26 @@

#include "PluginInterface.h"

// TODO: This should be included unconditionally and cleaned up.
#if defined(LIBOMPTARGET_RPC_SUPPORT)
#include "llvm-libc-types/rpc_opcodes_t.h"
#include "include/llvm-libc-types/rpc_opcodes_t.h"
#include "llvmlibc_rpc_server.h"
#include "shared/rpc.h"
#endif

using namespace llvm;
using namespace omp;
using namespace target;

RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
: Handles(Plugin.getNumDevices()) {}
: Buffers(Plugin.getNumDevices()) {}

llvm::Expected<bool>
RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
plugin::GenericGlobalHandlerTy &Handler,
plugin::DeviceImageTy &Image) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
return Handler.isSymbolInImage(Device, Image, rpc_client_symbol_name);
return Handler.isSymbolInImage(Device, Image, "__llvm_libc_rpc_client");
#else
return false;
#endif
@@ -39,59 +41,18 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
plugin::GenericGlobalHandlerTy &Handler,
plugin::DeviceImageTy &Image) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
auto Alloc = [](uint64_t Size, void *Data) {
plugin::GenericDeviceTy &Device =
*reinterpret_cast<plugin::GenericDeviceTy *>(Data);
return Device.allocate(Size, nullptr, TARGET_ALLOC_HOST);
};
uint64_t NumPorts =
std::min(Device.requestedRPCPortCount(), RPC_MAXIMUM_PORT_COUNT);
rpc_device_t RPCDevice;
if (rpc_status_t Err = rpc_server_init(&RPCDevice, NumPorts,
Device.getWarpSize(), Alloc, &Device))
std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
void *RPCBuffer = Device.allocate(
rpc::Server::allocation_size(Device.getWarpSize(), NumPorts), nullptr,
TARGET_ALLOC_HOST);
if (!RPCBuffer)
return plugin::Plugin::error(
"Failed to initialize RPC server for device %d: %d",
Device.getDeviceId(), Err);

// Register a custom opcode handler to perform plugin specific allocation.
auto MallocHandler = [](rpc_port_t Port, void *Data) {
rpc_recv_and_send(
Port,
[](rpc_buffer_t *Buffer, void *Data) {
plugin::GenericDeviceTy &Device =
*reinterpret_cast<plugin::GenericDeviceTy *>(Data);
Buffer->data[0] = reinterpret_cast<uintptr_t>(Device.allocate(
Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE_NON_BLOCKING));
},
Data);
};
if (rpc_status_t Err =
rpc_register_callback(RPCDevice, RPC_MALLOC, MallocHandler, &Device))
return plugin::Plugin::error(
"Failed to register RPC malloc handler for device %d: %d\n",
Device.getDeviceId(), Err);

// Register a custom opcode handler to perform plugin specific deallocation.
auto FreeHandler = [](rpc_port_t Port, void *Data) {
rpc_recv(
Port,
[](rpc_buffer_t *Buffer, void *Data) {
plugin::GenericDeviceTy &Device =
*reinterpret_cast<plugin::GenericDeviceTy *>(Data);
Device.free(reinterpret_cast<void *>(Buffer->data[0]),
TARGET_ALLOC_DEVICE_NON_BLOCKING);
},
Data);
};
if (rpc_status_t Err =
rpc_register_callback(RPCDevice, RPC_FREE, FreeHandler, &Device))
return plugin::Plugin::error(
"Failed to register RPC free handler for device %d: %d\n",
Device.getDeviceId(), Err);
"Failed to initialize RPC server for device %d", Device.getDeviceId());

// Get the address of the RPC client from the device.
void *ClientPtr;
plugin::GlobalTy ClientGlobal(rpc_client_symbol_name, sizeof(void *));
plugin::GlobalTy ClientGlobal("__llvm_libc_rpc_client", sizeof(void *));
if (auto Err =
Handler.getGlobalMetadataFromDevice(Device, Image, ClientGlobal))
return Err;
@@ -100,38 +61,63 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
sizeof(void *), nullptr))
return Err;

const void *ClientBuffer = rpc_get_client_buffer(RPCDevice);
if (auto Err = Device.dataSubmit(ClientPtr, ClientBuffer,
rpc_get_client_size(), nullptr))
rpc::Client client(NumPorts, RPCBuffer);
if (auto Err =
Device.dataSubmit(ClientPtr, &client, sizeof(rpc::Client), nullptr))
return Err;
Handles[Device.getDeviceId()] = RPCDevice.handle;
Buffers[Device.getDeviceId()] = RPCBuffer;

return Error::success();

#endif
return Error::success();
}

Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
rpc_device_t RPCDevice{Handles[Device.getDeviceId()]};
if (rpc_status_t Err = rpc_handle_server(RPCDevice))
return plugin::Plugin::error(
"Error while running RPC server on device %d: %d", Device.getDeviceId(),
Err);
uint64_t NumPorts =
std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
rpc::Server Server(NumPorts, Buffers[Device.getDeviceId()]);

auto port = Server.try_open(Device.getWarpSize());
if (!port)
return Error::success();

int Status = rpc::SUCCESS;
switch (port->get_opcode()) {
case RPC_MALLOC: {
port->recv_and_send([&](rpc::Buffer *Buffer, uint32_t) {
Buffer->data[0] = reinterpret_cast<uintptr_t>(Device.allocate(
Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE_NON_BLOCKING));
});
break;
}
case RPC_FREE: {
port->recv([&](rpc::Buffer *Buffer, uint32_t) {
Device.free(reinterpret_cast<void *>(Buffer->data[0]),
TARGET_ALLOC_DEVICE_NON_BLOCKING);
});
break;
}
default:
// Let the `libc` library handle any other unhandled opcodes.
Status = libc_handle_rpc_port(&*port, Device.getWarpSize());
break;
}
port->close();

if (Status != rpc::SUCCESS)
return createStringError("RPC server given invalid opcode!");

return Error::success();
#endif
return Error::success();
}

Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
rpc_device_t RPCDevice{Handles[Device.getDeviceId()]};
auto Dealloc = [](void *Ptr, void *Data) {
plugin::GenericDeviceTy &Device =
*reinterpret_cast<plugin::GenericDeviceTy *>(Data);
Device.free(Ptr, TARGET_ALLOC_HOST);
};
if (rpc_status_t Err = rpc_server_shutdown(RPCDevice, Dealloc, &Device))
return plugin::Plugin::error(
"Failed to shut down RPC server for device %d: %d",
Device.getDeviceId(), Err);
Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST);
return Error::success();
#endif
return Error::success();
}
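Since try_open() hands back at most one port, runServer() above services a single request per call and returns immediately when no client is waiting, so the offload runtime has to invoke it repeatedly from a worker thread. A minimal sketch of such a driving loop, mirroring the loaders' server threads; the names here are hypothetical and the real thread management lives elsewhere in the plugin:

// Sketch: a plugin-side worker that polls runServer until the kernels
// that might issue RPC calls have finished.
std::atomic<bool> Finished{false};
std::thread Worker([&] {
  while (!Finished.load(std::memory_order_relaxed)) {
    if (llvm::Error Err = Server.runServer(Device)) {
      llvm::errs() << llvm::toString(std::move(Err)) << "\n";
      return;
    }
  }
});

// ... launch device kernels ...

Finished = true;
Worker.join();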