|
|
@@ -0,0 +1,173 @@ |
|
|
//===- RPC.h - Interface for remote procedure calls from the GPU ----------===// |
|
|
// |
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
|
|
// See https://llvm.org/LICENSE.txt for license information. |
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
|
|
// |
|
|
//===----------------------------------------------------------------------===// |
|
|
|
|
|
#include "RPC.h" |
|
|
#include "Debug.h" |
|
|
#include "PluginInterface.h" |
|
|
|
|
|
// This header file may be present in-tree or from an LLVM installation. The |
|
|
// installed version lives alongside the GPU headers so we do not want to |
|
|
// include it directly. |
|
|
#if __has_include(<gpu-none-llvm/rpc_server.h>) |
|
|
#include <gpu-none-llvm/rpc_server.h> |
|
|
#elif defined(LIBOMPTARGET_RPC_SUPPORT) |
|
|
#include <rpc_server.h> |
|
|
#endif |
|
|
|
|
|
using namespace llvm; |
|
|
using namespace omp; |
|
|
using namespace target; |
|
|
|
|
|
RPCServerTy::RPCServerTy(uint32_t NumDevices) { |
|
|
#ifdef LIBOMPTARGET_RPC_SUPPORT |
|
|
// If this fails then something is catastrophically wrong, just exit. |
|
|
if (rpc_status_t Err = rpc_init(NumDevices)) |
|
|
FATAL_MESSAGE(1, "Error initializing the RPC server: %d\n", Err); |
|
|
Handles.resize(NumDevices); |
|
|
#endif |
|
|
} |
|
|
|
|
|
llvm::Expected<bool> |
|
|
RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device, |
|
|
plugin::GenericGlobalHandlerTy &Handler, |
|
|
plugin::DeviceImageTy &Image) { |
|
|
#ifdef LIBOMPTARGET_RPC_SUPPORT |
|
|
void *ClientPtr; |
|
|
plugin::GlobalTy Global(rpc_client_symbol_name, sizeof(void *), &ClientPtr); |
|
|
if (auto Err = Handler.readGlobalFromImage(Device, Image, Global)) { |
|
|
llvm::consumeError(std::move(Err)); |
|
|
return false; |
|
|
} |
|
|
|
|
|
return true; |
|
|
#else |
|
|
return false; |
|
|
#endif |
|
|
} |
|
|
|
|
|
Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device, |
|
|
plugin::GenericGlobalHandlerTy &Handler, |
|
|
plugin::DeviceImageTy &Image) { |
|
|
#ifdef LIBOMPTARGET_RPC_SUPPORT |
|
|
uint32_t DeviceId = Device.getDeviceId(); |
|
|
auto Alloc = [](uint64_t Size, void *Data) { |
|
|
plugin::GenericDeviceTy &Device = |
|
|
*reinterpret_cast<plugin::GenericDeviceTy *>(Data); |
|
|
return Device.allocate(Size, nullptr, TARGET_ALLOC_HOST); |
|
|
}; |
|
|
// TODO: Allow the device to declare its requested port count. |
|
|
if (rpc_status_t Err = rpc_server_init(DeviceId, RPC_MAXIMUM_PORT_COUNT, |
|
|
Device.getWarpSize(), Alloc, &Device)) |
|
|
return plugin::Plugin::error( |
|
|
"Failed to initialize RPC server for device %d: %d", DeviceId, Err); |
|
|
|
|
|
// Register a custom opcode handler to perform plugin specific allocation. |
|
|
// FIXME: We need to make sure this uses asynchronous allocations on CUDA. |
|
|
auto MallocHandler = [](rpc_port_t Port, void *Data) { |
|
|
rpc_recv_and_send( |
|
|
Port, |
|
|
[](rpc_buffer_t *Buffer, void *Data) { |
|
|
plugin::GenericDeviceTy &Device = |
|
|
*reinterpret_cast<plugin::GenericDeviceTy *>(Data); |
|
|
Buffer->data[0] = reinterpret_cast<uintptr_t>( |
|
|
Device.allocate(Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE)); |
|
|
}, |
|
|
Data); |
|
|
}; |
|
|
if (rpc_status_t Err = |
|
|
rpc_register_callback(DeviceId, RPC_MALLOC, MallocHandler, &Device)) |
|
|
return plugin::Plugin::error( |
|
|
"Failed to register RPC malloc handler for device %d: %d\n", DeviceId, |
|
|
Err); |
|
|
|
|
|
// Register a custom opcode handler to perform plugin specific deallocation. |
|
|
auto FreeHandler = [](rpc_port_t Port, void *Data) { |
|
|
rpc_recv( |
|
|
Port, |
|
|
[](rpc_buffer_t *Buffer, void *Data) { |
|
|
plugin::GenericDeviceTy &Device = |
|
|
*reinterpret_cast<plugin::GenericDeviceTy *>(Data); |
|
|
Device.free(reinterpret_cast<void *>(Buffer->data[0]), |
|
|
TARGET_ALLOC_DEVICE); |
|
|
}, |
|
|
Data); |
|
|
}; |
|
|
if (rpc_status_t Err = |
|
|
rpc_register_callback(DeviceId, RPC_FREE, FreeHandler, &Device)) |
|
|
return plugin::Plugin::error( |
|
|
"Failed to register RPC free handler for device %d: %d\n", DeviceId, |
|
|
Err); |
|
|
|
|
|
// Get the address of the RPC client from the device. |
|
|
void *ClientPtr; |
|
|
plugin::GlobalTy ClientGlobal(rpc_client_symbol_name, sizeof(void *)); |
|
|
if (auto Err = |
|
|
Handler.getGlobalMetadataFromDevice(Device, Image, ClientGlobal)) |
|
|
return Err; |
|
|
|
|
|
if (auto Err = Device.dataRetrieve(&ClientPtr, ClientGlobal.getPtr(), |
|
|
sizeof(void *), nullptr)) |
|
|
return Err; |
|
|
|
|
|
const void *ClientBuffer = rpc_get_client_buffer(DeviceId); |
|
|
if (auto Err = Device.dataSubmit(ClientPtr, ClientBuffer, |
|
|
rpc_get_client_size(), nullptr)) |
|
|
return Err; |
|
|
|
|
|
Handles[DeviceId] = std::make_unique<RPCHandleTy>(*this, Device); |
|
|
#endif |
|
|
return Error::success(); |
|
|
} |
|
|
|
|
|
llvm::Expected<RPCHandleTy *> |
|
|
RPCServerTy::getDevice(plugin::GenericDeviceTy &Device) { |
|
|
#ifdef LIBOMPTARGET_RPC_SUPPORT |
|
|
uint32_t DeviceId = Device.getDeviceId(); |
|
|
if (!Handles[DeviceId] || !rpc_get_buffer(DeviceId) || |
|
|
!rpc_get_client_buffer(DeviceId)) |
|
|
return plugin::Plugin::error( |
|
|
"Attempt to get an RPC device while not initialized"); |
|
|
|
|
|
return Handles[DeviceId].get(); |
|
|
#else |
|
|
return plugin::Plugin::error( |
|
|
"Attempt to get an RPC device while not available"); |
|
|
#endif |
|
|
} |
|
|
|
|
|
Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) { |
|
|
#ifdef LIBOMPTARGET_RPC_SUPPORT |
|
|
if (rpc_status_t Err = rpc_handle_server(Device.getDeviceId())) |
|
|
return plugin::Plugin::error( |
|
|
"Error while running RPC server on device %d: %d", Device.getDeviceId(), |
|
|
Err); |
|
|
#endif |
|
|
return Error::success(); |
|
|
} |
|
|
|
|
|
Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) { |
|
|
#ifdef LIBOMPTARGET_RPC_SUPPORT |
|
|
auto Dealloc = [](void *Ptr, void *Data) { |
|
|
plugin::GenericDeviceTy &Device = |
|
|
*reinterpret_cast<plugin::GenericDeviceTy *>(Data); |
|
|
Device.free(Ptr, TARGET_ALLOC_HOST); |
|
|
}; |
|
|
if (rpc_status_t Err = |
|
|
rpc_server_shutdown(Device.getDeviceId(), Dealloc, &Device)) |
|
|
return plugin::Plugin::error( |
|
|
"Failed to shut down RPC server for device %d: %d", |
|
|
Device.getDeviceId(), Err); |
|
|
#endif |
|
|
return Error::success(); |
|
|
} |
|
|
|
|
|
RPCServerTy::~RPCServerTy() { |
|
|
#ifdef LIBOMPTARGET_RPC_SUPPORT |
|
|
rpc_shutdown(); |
|
|
#endif |
|
|
} |