diff --git a/libc/utils/gpu/loader/Loader.h b/libc/utils/gpu/loader/Loader.h index 9338038370197..9c7d328930c23 100644 --- a/libc/utils/gpu/loader/Loader.h +++ b/libc/utils/gpu/loader/Loader.h @@ -108,11 +108,11 @@ inline void handle_error(rpc_status_t) { } template -inline void register_rpc_callbacks(uint32_t device_id) { +inline void register_rpc_callbacks(rpc_device_t device) { static_assert(lane_size == 32 || lane_size == 64, "Invalid Lane size"); // Register the ping test for the `libc` tests. rpc_register_callback( - device_id, static_cast(RPC_TEST_INCREMENT), + device, static_cast(RPC_TEST_INCREMENT), [](rpc_port_t port, void *data) { rpc_recv_and_send( port, @@ -125,7 +125,7 @@ inline void register_rpc_callbacks(uint32_t device_id) { // Register the interface test callbacks. rpc_register_callback( - device_id, static_cast(RPC_TEST_INTERFACE), + device, static_cast(RPC_TEST_INTERFACE), [](rpc_port_t port, void *data) { uint64_t cnt = 0; bool end_with_recv; @@ -207,7 +207,7 @@ inline void register_rpc_callbacks(uint32_t device_id) { // Register the stream test handler. rpc_register_callback( - device_id, static_cast(RPC_TEST_STREAM), + device, static_cast(RPC_TEST_STREAM), [](rpc_port_t port, void *data) { uint64_t sizes[lane_size] = {0}; void *dst[lane_size] = {nullptr}; diff --git a/libc/utils/gpu/loader/amdgpu/Loader.cpp b/libc/utils/gpu/loader/amdgpu/Loader.cpp index e3911eda2bd82..35840c6910bd8 100644 --- a/libc/utils/gpu/loader/amdgpu/Loader.cpp +++ b/libc/utils/gpu/loader/amdgpu/Loader.cpp @@ -153,7 +153,8 @@ template hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable, hsa_amd_memory_pool_t kernargs_pool, hsa_amd_memory_pool_t coarsegrained_pool, - hsa_queue_t *queue, const LaunchParameters &params, + hsa_queue_t *queue, rpc_device_t device, + const LaunchParameters &params, const char *kernel_name, args_t kernel_args) { // Look up the '_start' kernel in the loaded executable. 
hsa_executable_symbol_t symbol; @@ -162,10 +163,9 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable, return err; // Register RPC callbacks for the malloc and free functions on HSA. - uint32_t device_id = 0; auto tuple = std::make_tuple(dev_agent, coarsegrained_pool); rpc_register_callback( - device_id, RPC_MALLOC, + device, RPC_MALLOC, [](rpc_port_t port, void *data) { auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void { auto &[dev_agent, pool] = *static_cast(data); @@ -182,7 +182,7 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable, }, &tuple); rpc_register_callback( - device_id, RPC_FREE, + device, RPC_FREE, [](rpc_port_t port, void *data) { auto free_handler = [](rpc_buffer_t *buffer, void *) { if (hsa_status_t err = hsa_amd_memory_pool_free( @@ -284,12 +284,12 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable, while (hsa_signal_wait_scacquire( packet->completion_signal, HSA_SIGNAL_CONDITION_EQ, 0, /*timeout_hint=*/1024, HSA_WAIT_STATE_ACTIVE) != 0) - if (rpc_status_t err = rpc_handle_server(device_id)) + if (rpc_status_t err = rpc_handle_server(device)) handle_error(err); // Handle the server one more time in case the kernel exited with a pending // send still in flight. - if (rpc_status_t err = rpc_handle_server(device_id)) + if (rpc_status_t err = rpc_handle_server(device)) handle_error(err); // Destroy the resources acquired to launch the kernel and return. @@ -342,8 +342,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, handle_error(err); // Obtain a single agent for the device and host to use the HSA memory model. - uint32_t num_devices = 1; - uint32_t device_id = 0; hsa_agent_t dev_agent; hsa_agent_t host_agent; if (hsa_status_t err = get_agent(&dev_agent)) @@ -433,8 +431,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, handle_error(err); // Set up the RPC server. 
- if (rpc_status_t err = rpc_init(num_devices)) - handle_error(err); auto tuple = std::make_tuple(dev_agent, finegrained_pool); auto rpc_alloc = [](uint64_t size, void *data) { auto &[dev_agent, finegrained_pool] = *static_cast(data); @@ -445,15 +441,16 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, hsa_amd_agents_allow_access(1, &dev_agent, nullptr, dev_ptr); return dev_ptr; }; - if (rpc_status_t err = rpc_server_init(device_id, RPC_MAXIMUM_PORT_COUNT, + rpc_device_t device; + if (rpc_status_t err = rpc_server_init(&device, RPC_MAXIMUM_PORT_COUNT, wavefront_size, rpc_alloc, &tuple)) handle_error(err); // Register callbacks for the RPC unit tests. if (wavefront_size == 32) - register_rpc_callbacks<32>(device_id); + register_rpc_callbacks<32>(device); else if (wavefront_size == 64) - register_rpc_callbacks<64>(device_id); + register_rpc_callbacks<64>(device); else handle_error("Invalid wavefront size"); @@ -483,10 +480,10 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, handle_error(err); void *rpc_client_buffer; - if (hsa_status_t err = hsa_amd_memory_lock( - const_cast(rpc_get_client_buffer(device_id)), - rpc_get_client_size(), - /*agents=*/nullptr, 0, &rpc_client_buffer)) + if (hsa_status_t err = + hsa_amd_memory_lock(const_cast(rpc_get_client_buffer(device)), + rpc_get_client_size(), + /*agents=*/nullptr, 0, &rpc_client_buffer)) handle_error(err); // Copy the RPC client buffer to the address pointed to by the symbol. 
@@ -496,7 +493,7 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, handle_error(err); if (hsa_status_t err = hsa_amd_memory_unlock( - const_cast(rpc_get_client_buffer(device_id)))) + const_cast(rpc_get_client_buffer(device)))) handle_error(err); if (hsa_status_t err = hsa_amd_memory_pool_free(rpc_client_host)) handle_error(err); @@ -549,13 +546,13 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, begin_args_t init_args = {argc, dev_argv, dev_envp}; if (hsa_status_t err = launch_kernel( dev_agent, executable, kernargs_pool, coarsegrained_pool, queue, - single_threaded_params, "_begin.kd", init_args)) + device, single_threaded_params, "_begin.kd", init_args)) handle_error(err); start_args_t args = {argc, dev_argv, dev_envp, dev_ret}; - if (hsa_status_t err = - launch_kernel(dev_agent, executable, kernargs_pool, - coarsegrained_pool, queue, params, "_start.kd", args)) + if (hsa_status_t err = launch_kernel(dev_agent, executable, kernargs_pool, + coarsegrained_pool, queue, device, + params, "_start.kd", args)) handle_error(err); void *host_ret; @@ -575,11 +572,11 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, end_args_t fini_args = {ret}; if (hsa_status_t err = launch_kernel( dev_agent, executable, kernargs_pool, coarsegrained_pool, queue, - single_threaded_params, "_end.kd", fini_args)) + device, single_threaded_params, "_end.kd", fini_args)) handle_error(err); if (rpc_status_t err = rpc_server_shutdown( - device_id, [](void *ptr, void *) { hsa_amd_memory_pool_free(ptr); }, + device, [](void *ptr, void *) { hsa_amd_memory_pool_free(ptr); }, nullptr)) handle_error(err); @@ -600,8 +597,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, if (hsa_status_t err = hsa_code_object_destroy(object)) handle_error(err); - if (rpc_status_t err = rpc_shutdown()) - handle_error(err); if (hsa_status_t err = hsa_shut_down()) handle_error(err); diff --git 
a/libc/utils/gpu/loader/nvptx/Loader.cpp b/libc/utils/gpu/loader/nvptx/Loader.cpp index 5388f287063b7..1818932f0a966 100644 --- a/libc/utils/gpu/loader/nvptx/Loader.cpp +++ b/libc/utils/gpu/loader/nvptx/Loader.cpp @@ -154,8 +154,8 @@ Expected get_ctor_dtor_array(const void *image, const size_t size, template CUresult launch_kernel(CUmodule binary, CUstream stream, - const LaunchParameters &params, const char *kernel_name, - args_t kernel_args) { + rpc_device_t rpc_device, const LaunchParameters &params, + const char *kernel_name, args_t kernel_args) { // look up the '_start' kernel in the loaded module. CUfunction function; if (CUresult err = cuModuleGetFunction(&function, binary, kernel_name)) @@ -175,11 +175,10 @@ CUresult launch_kernel(CUmodule binary, CUstream stream, handle_error(err); // Register RPC callbacks for the malloc and free functions on HSA. - uint32_t device_id = 0; - register_rpc_callbacks<32>(device_id); + register_rpc_callbacks<32>(rpc_device); rpc_register_callback( - device_id, RPC_MALLOC, + rpc_device, RPC_MALLOC, [](rpc_port_t port, void *data) { auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void { CUstream memory_stream = *static_cast(data); @@ -197,7 +196,7 @@ CUresult launch_kernel(CUmodule binary, CUstream stream, }, &memory_stream); rpc_register_callback( - device_id, RPC_FREE, + rpc_device, RPC_FREE, [](rpc_port_t port, void *data) { auto free_handler = [](rpc_buffer_t *buffer, void *data) { CUstream memory_stream = *static_cast(data); @@ -219,12 +218,12 @@ CUresult launch_kernel(CUmodule binary, CUstream stream, // Wait until the kernel has completed execution on the device. Periodically // check the RPC client for work to be performed on the server. 
while (cuStreamQuery(stream) == CUDA_ERROR_NOT_READY) - if (rpc_status_t err = rpc_handle_server(device_id)) + if (rpc_status_t err = rpc_handle_server(rpc_device)) handle_error(err); // Handle the server one more time in case the kernel exited with a pending // send still in flight. - if (rpc_status_t err = rpc_handle_server(device_id)) + if (rpc_status_t err = rpc_handle_server(rpc_device)) handle_error(err); return CUDA_SUCCESS; @@ -235,7 +234,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, if (CUresult err = cuInit(0)) handle_error(err); // Obtain the first device found on the system. - uint32_t num_devices = 1; uint32_t device_id = 0; CUdevice device; if (CUresult err = cuDeviceGet(&device, device_id)) @@ -294,9 +292,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, if (CUresult err = cuMemsetD32(dev_ret, 0, 1)) handle_error(err); - if (rpc_status_t err = rpc_init(num_devices)) - handle_error(err); - uint32_t warp_size = 32; auto rpc_alloc = [](uint64_t size, void *) -> void * { void *dev_ptr; @@ -304,7 +299,8 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, handle_error(err); return dev_ptr; }; - if (rpc_status_t err = rpc_server_init(device_id, RPC_MAXIMUM_PORT_COUNT, + rpc_device_t rpc_device; + if (rpc_status_t err = rpc_server_init(&rpc_device, RPC_MAXIMUM_PORT_COUNT, warp_size, rpc_alloc, nullptr)) handle_error(err); @@ -321,19 +317,20 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, cuMemcpyDtoH(&rpc_client_host, rpc_client_dev, sizeof(void *))) handle_error(err); if (CUresult err = - cuMemcpyHtoD(rpc_client_host, rpc_get_client_buffer(device_id), + cuMemcpyHtoD(rpc_client_host, rpc_get_client_buffer(rpc_device), rpc_get_client_size())) handle_error(err); LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1}; begin_args_t init_args = {argc, dev_argv, dev_envp}; - if (CUresult err = launch_kernel(binary, stream, single_threaded_params, - 
"_begin", init_args)) + if (CUresult err = launch_kernel(binary, stream, rpc_device, + single_threaded_params, "_begin", init_args)) handle_error(err); start_args_t args = {argc, dev_argv, dev_envp, reinterpret_cast(dev_ret)}; - if (CUresult err = launch_kernel(binary, stream, params, "_start", args)) + if (CUresult err = + launch_kernel(binary, stream, rpc_device, params, "_start", args)) handle_error(err); // Copy the return value back from the kernel and wait. @@ -345,8 +342,8 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, handle_error(err); end_args_t fini_args = {host_ret}; - if (CUresult err = launch_kernel(binary, stream, single_threaded_params, - "_end", fini_args)) + if (CUresult err = launch_kernel(binary, stream, rpc_device, + single_threaded_params, "_end", fini_args)) handle_error(err); // Free the memory allocated for the device. @@ -357,7 +354,7 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, if (CUresult err = cuMemFreeHost(dev_argv)) handle_error(err); if (rpc_status_t err = rpc_server_shutdown( - device_id, [](void *ptr, void *) { cuMemFreeHost(ptr); }, nullptr)) + rpc_device, [](void *ptr, void *) { cuMemFreeHost(ptr); }, nullptr)) handle_error(err); // Destroy the context and the loaded binary. 
@@ -365,7 +362,5 @@ int load(int argc, char **argv, char **envp, void *image, size_t size, handle_error(err); if (CUresult err = cuDevicePrimaryCtxRelease(device)) handle_error(err); - if (rpc_status_t err = rpc_shutdown()) - handle_error(err); return host_ret; } diff --git a/libc/utils/gpu/server/llvmlibc_rpc_server.h b/libc/utils/gpu/server/llvmlibc_rpc_server.h index b7f2a463b1f5c..b0cf2f916b385 100644 --- a/libc/utils/gpu/server/llvmlibc_rpc_server.h +++ b/libc/utils/gpu/server/llvmlibc_rpc_server.h @@ -27,10 +27,8 @@ typedef enum { RPC_STATUS_SUCCESS = 0x0, RPC_STATUS_CONTINUE = 0x1, RPC_STATUS_ERROR = 0x1000, - RPC_STATUS_OUT_OF_RANGE = 0x1001, - RPC_STATUS_UNHANDLED_OPCODE = 0x1002, - RPC_STATUS_INVALID_LANE_SIZE = 0x1003, - RPC_STATUS_NOT_INITIALIZED = 0x1004, + RPC_STATUS_UNHANDLED_OPCODE = 0x1001, + RPC_STATUS_INVALID_LANE_SIZE = 0x1002, } rpc_status_t; /// A struct containing an opaque handle to an RPC port. This is what allows the @@ -45,6 +43,11 @@ typedef struct rpc_buffer_s { uint64_t data[8]; } rpc_buffer_t; +/// An opaque handle to an RPC server that can be attached to a device. +typedef struct rpc_device_s { + uintptr_t handle; +} rpc_device_t; + /// A function used to allocate \p bytes for use by the RPC server and client. /// The memory should support asynchronous and atomic access from both the /// client and server. @@ -60,34 +63,28 @@ typedef void (*rpc_opcode_callback_ty)(rpc_port_t port, void *data); /// A callback function to use the port to receive or send a \p buffer. typedef void (*rpc_port_callback_ty)(rpc_buffer_t *buffer, void *data); -/// Initialize the rpc library for general use on \p num_devices. -rpc_status_t rpc_init(uint32_t num_devices); - -/// Shut down the rpc interface. -rpc_status_t rpc_shutdown(void); - -/// Initialize the server for a given device. -rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports, +/// Initialize the server for a given device and return it in \p rpc_device. 
+rpc_status_t rpc_server_init(rpc_device_t *rpc_device, uint64_t num_ports, uint32_t lane_size, rpc_alloc_ty alloc, void *data); /// Shut down the server for a given device. -rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc, +rpc_status_t rpc_server_shutdown(rpc_device_t rpc_device, rpc_free_ty dealloc, void *data); /// Queries the RPC clients at least once and performs server-side work if there /// are any active requests. Runs until all work on the server is completed. -rpc_status_t rpc_handle_server(uint32_t device_id); +rpc_status_t rpc_handle_server(rpc_device_t rpc_device); /// Register a callback to handle an opcode from the RPC client. The associated /// data must remain accessible as long as the user intends to handle the server /// with this callback. -rpc_status_t rpc_register_callback(uint32_t device_id, uint16_t opcode, +rpc_status_t rpc_register_callback(rpc_device_t rpc_device, uint16_t opcode, rpc_opcode_callback_ty callback, void *data); /// Obtain a pointer to a local client buffer that can be copied directly to the /// other process using the address stored at the rpc client symbol name. -const void *rpc_get_client_buffer(uint32_t device_id); +const void *rpc_get_client_buffer(rpc_device_t device); /// Returns the size of the client in bytes to be used for a memory copy. uint64_t rpc_get_client_size(); diff --git a/libc/utils/gpu/server/rpc_server.cpp b/libc/utils/gpu/server/rpc_server.cpp index 46ad98fa02cc5..fd306642fdcc4 100644 --- a/libc/utils/gpu/server/rpc_server.cpp +++ b/libc/utils/gpu/server/rpc_server.cpp @@ -248,127 +248,75 @@ struct Device { std::unordered_map callback_data; }; -// A struct containing all the runtime state required to run the RPC server. 
-struct State { - State(uint32_t num_devices) - : num_devices(num_devices), devices(num_devices), reference_count(0u) {} - uint32_t num_devices; - std::vector> devices; - std::atomic_uint32_t reference_count; -}; - -static std::mutex startup_mutex; - -static State *state; - -rpc_status_t rpc_init(uint32_t num_devices) { - std::scoped_lock lock(startup_mutex); - if (!state) - state = new State(num_devices); - - if (state->reference_count == std::numeric_limits::max()) - return RPC_STATUS_ERROR; - - state->reference_count++; - - return RPC_STATUS_SUCCESS; -} - -rpc_status_t rpc_shutdown(void) { - if (state && state->reference_count-- == 1) - delete state; - - return RPC_STATUS_SUCCESS; -} - -rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports, +rpc_status_t rpc_server_init(rpc_device_t *rpc_device, uint64_t num_ports, uint32_t lane_size, rpc_alloc_ty alloc, void *data) { - if (!state) - return RPC_STATUS_NOT_INITIALIZED; - if (device_id >= state->num_devices) - return RPC_STATUS_OUT_OF_RANGE; + if (!rpc_device) + return RPC_STATUS_ERROR; if (lane_size != 1 && lane_size != 32 && lane_size != 64) return RPC_STATUS_INVALID_LANE_SIZE; - if (!state->devices[device_id]) { - uint64_t size = rpc::Server::allocation_size(lane_size, num_ports); - void *buffer = alloc(size, data); + uint64_t size = rpc::Server::allocation_size(lane_size, num_ports); + void *buffer = alloc(size, data); - if (!buffer) - return RPC_STATUS_ERROR; + if (!buffer) + return RPC_STATUS_ERROR; - state->devices[device_id] = - std::make_unique(lane_size, num_ports, buffer); - if (!state->devices[device_id]) - return RPC_STATUS_ERROR; - } + Device *device = new Device(lane_size, num_ports, buffer); + if (!device) + return RPC_STATUS_ERROR; + rpc_device->handle = reinterpret_cast(device); return RPC_STATUS_SUCCESS; } -rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc, +rpc_status_t rpc_server_shutdown(rpc_device_t rpc_device, rpc_free_ty dealloc, void *data) { - if 
(!state) - return RPC_STATUS_NOT_INITIALIZED; - if (device_id >= state->num_devices) - return RPC_STATUS_OUT_OF_RANGE; - if (!state->devices[device_id]) + if (!rpc_device.handle) return RPC_STATUS_ERROR; - dealloc(state->devices[device_id]->buffer, data); - if (state->devices[device_id]) - state->devices[device_id].release(); + Device *device = reinterpret_cast(rpc_device.handle); + dealloc(device->buffer, data); + delete device; return RPC_STATUS_SUCCESS; } -rpc_status_t rpc_handle_server(uint32_t device_id) { - if (!state) - return RPC_STATUS_NOT_INITIALIZED; - if (device_id >= state->num_devices) - return RPC_STATUS_OUT_OF_RANGE; - if (!state->devices[device_id]) +rpc_status_t rpc_handle_server(rpc_device_t rpc_device) { + if (!rpc_device.handle) return RPC_STATUS_ERROR; + Device *device = reinterpret_cast(rpc_device.handle); uint32_t index = 0; for (;;) { - Device &device = *state->devices[device_id]; - rpc_status_t status = device.handle_server(index); + rpc_status_t status = device->handle_server(index); if (status != RPC_STATUS_CONTINUE) return status; } } -rpc_status_t rpc_register_callback(uint32_t device_id, uint16_t opcode, +rpc_status_t rpc_register_callback(rpc_device_t rpc_device, uint16_t opcode, rpc_opcode_callback_ty callback, void *data) { - if (!state) - return RPC_STATUS_NOT_INITIALIZED; - if (device_id >= state->num_devices) - return RPC_STATUS_OUT_OF_RANGE; - if (!state->devices[device_id]) + if (!rpc_device.handle) return RPC_STATUS_ERROR; - state->devices[device_id]->callbacks[opcode] = callback; - state->devices[device_id]->callback_data[opcode] = data; + Device *device = reinterpret_cast(rpc_device.handle); + + device->callbacks[opcode] = callback; + device->callback_data[opcode] = data; return RPC_STATUS_SUCCESS; } -const void *rpc_get_client_buffer(uint32_t device_id) { - if (!state || device_id >= state->num_devices || !state->devices[device_id]) +const void *rpc_get_client_buffer(rpc_device_t rpc_device) { + if (!rpc_device.handle) 
return nullptr; - return &state->devices[device_id]->client; + Device *device = reinterpret_cast(rpc_device.handle); + return &device->client; } uint64_t rpc_get_client_size() { return sizeof(rpc::Client); } -using ServerPort = std::variant; - -ServerPort get_port(rpc_port_t ref) { - return reinterpret_cast(ref.handle); -} - void rpc_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) { auto port = reinterpret_cast(ref.handle); port->send([=](rpc::Buffer *buffer) { diff --git a/openmp/libomptarget/plugins-nextgen/common/include/RPC.h b/openmp/libomptarget/plugins-nextgen/common/include/RPC.h index 2e39b3f299c88..b621cc0da4587 100644 --- a/openmp/libomptarget/plugins-nextgen/common/include/RPC.h +++ b/openmp/libomptarget/plugins-nextgen/common/include/RPC.h @@ -16,6 +16,7 @@ #ifndef OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_COMMON_RPC_H #define OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_COMMON_RPC_H +#include "llvm/ADT/DenseMap.h" #include "llvm/Support/Error.h" #include @@ -32,8 +33,6 @@ class DeviceImageTy; /// these routines will perform no action. struct RPCServerTy { public: - RPCServerTy(uint32_t NumDevices); - /// Check if this device image is using an RPC server. This checks for the /// precense of an externally visible symbol in the device image that will /// be present whenever RPC code is called. @@ -56,7 +55,9 @@ struct RPCServerTy { /// memory associated with the k llvm::Error deinitDevice(plugin::GenericDeviceTy &Device); - ~RPCServerTy(); +private: + /// Maps a device's identifier to the handle of its attached RPC server. 
+ llvm::SmallVector Handles; }; } // namespace llvm::omp::target diff --git a/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp index a4e6c93192159..55e2865d6aae4 100644 --- a/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp +++ b/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp @@ -1492,7 +1492,7 @@ Error GenericPluginTy::init() { GlobalHandler = createGlobalHandler(); assert(GlobalHandler && "Invalid global handler"); - RPCServer = new RPCServerTy(NumDevices); + RPCServer = new RPCServerTy(); assert(RPCServer && "Invalid RPC server"); return Plugin::success(); diff --git a/openmp/libomptarget/plugins-nextgen/common/src/RPC.cpp b/openmp/libomptarget/plugins-nextgen/common/src/RPC.cpp index f46b27701b5b9..fab0f6838f4a8 100644 --- a/openmp/libomptarget/plugins-nextgen/common/src/RPC.cpp +++ b/openmp/libomptarget/plugins-nextgen/common/src/RPC.cpp @@ -21,14 +21,6 @@ using namespace llvm; using namespace omp; using namespace target; -RPCServerTy::RPCServerTy(uint32_t NumDevices) { -#ifdef LIBOMPTARGET_RPC_SUPPORT - // If this fails then something is catastrophically wrong, just exit. 
- if (rpc_status_t Err = rpc_init(NumDevices)) - FATAL_MESSAGE(1, "Error initializing the RPC server: %d\n", Err); -#endif -} - llvm::Expected RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device, plugin::GenericGlobalHandlerTy &Handler, @@ -44,7 +36,6 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device, plugin::GenericGlobalHandlerTy &Handler, plugin::DeviceImageTy &Image) { #ifdef LIBOMPTARGET_RPC_SUPPORT - uint32_t DeviceId = Device.getDeviceId(); auto Alloc = [](uint64_t Size, void *Data) { plugin::GenericDeviceTy &Device = *reinterpret_cast(Data); @@ -52,10 +43,12 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device, }; uint64_t NumPorts = std::min(Device.requestedRPCPortCount(), RPC_MAXIMUM_PORT_COUNT); - if (rpc_status_t Err = rpc_server_init(DeviceId, NumPorts, + rpc_device_t RPCDevice; + if (rpc_status_t Err = rpc_server_init(&RPCDevice, NumPorts, Device.getWarpSize(), Alloc, &Device)) return plugin::Plugin::error( - "Failed to initialize RPC server for device %d: %d", DeviceId, Err); + "Failed to initialize RPC server for device %d: %d", + Device.getDeviceId(), Err); // Register a custom opcode handler to perform plugin specific allocation. auto MallocHandler = [](rpc_port_t Port, void *Data) { @@ -70,10 +63,10 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device, Data); }; if (rpc_status_t Err = - rpc_register_callback(DeviceId, RPC_MALLOC, MallocHandler, &Device)) + rpc_register_callback(RPCDevice, RPC_MALLOC, MallocHandler, &Device)) return plugin::Plugin::error( - "Failed to register RPC malloc handler for device %d: %d\n", DeviceId, - Err); + "Failed to register RPC malloc handler for device %d: %d\n", + Device.getDeviceId(), Err); // Register a custom opcode handler to perform plugin specific deallocation. 
auto FreeHandler = [](rpc_port_t Port, void *Data) { @@ -88,10 +81,10 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device, Data); }; if (rpc_status_t Err = - rpc_register_callback(DeviceId, RPC_FREE, FreeHandler, &Device)) + rpc_register_callback(RPCDevice, RPC_FREE, FreeHandler, &Device)) return plugin::Plugin::error( - "Failed to register RPC free handler for device %d: %d\n", DeviceId, - Err); + "Failed to register RPC free handler for device %d: %d\n", + Device.getDeviceId(), Err); // Get the address of the RPC client from the device. void *ClientPtr; @@ -104,17 +97,20 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device, sizeof(void *), nullptr)) return Err; - const void *ClientBuffer = rpc_get_client_buffer(DeviceId); + const void *ClientBuffer = rpc_get_client_buffer(RPCDevice); if (auto Err = Device.dataSubmit(ClientPtr, ClientBuffer, rpc_get_client_size(), nullptr)) return Err; + Handles.resize(Device.getDeviceId() + 1); + Handles[Device.getDeviceId()] = RPCDevice.handle; #endif return Error::success(); } Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) { #ifdef LIBOMPTARGET_RPC_SUPPORT - if (rpc_status_t Err = rpc_handle_server(Device.getDeviceId())) + rpc_device_t RPCDevice{Handles[Device.getDeviceId()]}; + if (rpc_status_t Err = rpc_handle_server(RPCDevice)) return plugin::Plugin::error( "Error while running RPC server on device %d: %d", Device.getDeviceId(), Err); @@ -124,22 +120,16 @@ Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) { Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) { #ifdef LIBOMPTARGET_RPC_SUPPORT + rpc_device_t RPCDevice{Handles[Device.getDeviceId()]}; auto Dealloc = [](void *Ptr, void *Data) { plugin::GenericDeviceTy &Device = *reinterpret_cast(Data); Device.free(Ptr, TARGET_ALLOC_HOST); }; - if (rpc_status_t Err = - rpc_server_shutdown(Device.getDeviceId(), Dealloc, &Device)) + if (rpc_status_t Err = rpc_server_shutdown(RPCDevice, Dealloc, &Device)) return 
plugin::Plugin::error( "Failed to shut down RPC server for device %d: %d", Device.getDeviceId(), Err); #endif return Error::success(); } - -RPCServerTy::~RPCServerTy() { -#ifdef LIBOMPTARGET_RPC_SUPPORT - rpc_shutdown(); -#endif -}