Skip to content

Commit

Permalink
[Libomptarget] Remove RPCHandleTy indirection
Browse files Browse the repository at this point in the history
The 'RPCHandleTy' was intended to capture the intention that a specific
device owns its slot in the RPC server. However, this required creating
a temporary store to hold these pointers. This was causing really weird
spurious failure due to undefined behaviour in the order of library
teardown. For example, the x64 plugin would be torn down, set this to
some invalid memory, and then the CUDA plugin would crash. Rather than
spend the time to fully diagnose this problem I found it pertinent to
simply remove the failure mode.

This patch removes this indirection so now the usage of the RPC server
must always be done with the intended device. This just requires some
extra handling for the AMDGPU indirection where we need to store a
reference to the device.

Reviewed By: JonChesterfield

Differential Revision: https://reviews.llvm.org/D154971
  • Loading branch information
jhuber6 committed Jul 11, 2023
1 parent 14742f2 commit 8a0763f
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 70 deletions.
35 changes: 19 additions & 16 deletions openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,9 +520,9 @@ struct AMDGPUSignalTy {
}

/// Wait until the signal gets a zero value.
Error wait(const uint64_t ActiveTimeout = 0,
RPCHandleTy *RPCHandle = nullptr) const {
if (ActiveTimeout && !RPCHandle) {
Error wait(const uint64_t ActiveTimeout = 0, RPCServerTy *RPCServer = nullptr,
GenericDeviceTy *Device = nullptr) const {
if (ActiveTimeout && !RPCServer) {
hsa_signal_value_t Got = 1;
Got = hsa_signal_wait_scacquire(Signal, HSA_SIGNAL_CONDITION_EQ, 0,
ActiveTimeout, HSA_WAIT_STATE_ACTIVE);
Expand All @@ -531,12 +531,12 @@ struct AMDGPUSignalTy {
}

// If there is an RPC device attached to this stream we run it as a server.
uint64_t Timeout = RPCHandle ? 8192 : UINT64_MAX;
auto WaitState = RPCHandle ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED;
uint64_t Timeout = RPCServer ? 8192 : UINT64_MAX;
auto WaitState = RPCServer ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED;
while (hsa_signal_wait_scacquire(Signal, HSA_SIGNAL_CONDITION_EQ, 0,
Timeout, WaitState) != 0) {
if (RPCHandle)
if (auto Err = RPCHandle->runServer())
if (RPCServer && Device)
if (auto Err = RPCServer->runServer(*Device))
return Err;
}
return Plugin::success();
Expand Down Expand Up @@ -888,6 +888,9 @@ struct AMDGPUStreamTy {
/// The manager of signals to reuse signals.
AMDGPUSignalManagerTy &SignalManager;

/// A reference to the associated device.
GenericDeviceTy &Device;

/// Array of stream slots. Use std::deque because it can dynamically grow
/// without invalidating the already inserted elements. For instance, the
/// std::vector may invalidate the elements by reallocating the internal
Expand All @@ -907,7 +910,7 @@ struct AMDGPUStreamTy {
/// A pointer associated with an RPC server running on the given device. If
/// RPC is not being used this will be a null pointer. Otherwise, this
/// indicates that an RPC server is expected to be run on this stream.
RPCHandleTy *RPCHandle;
RPCServerTy *RPCServer;

/// Mutex to protect stream's management.
mutable std::mutex Mutex;
Expand Down Expand Up @@ -1064,8 +1067,8 @@ struct AMDGPUStreamTy {
/// Deinitialize the stream's signals.
Error deinit() { return Plugin::success(); }

/// Attach an RPC handle to this stream.
void setRPCHandle(RPCHandleTy *Handle) { RPCHandle = Handle; }
/// Attach an RPC server to this stream.
void setRPCServer(RPCServerTy *Server) { RPCServer = Server; }

/// Push a asynchronous kernel to the stream. The kernel arguments must be
/// placed in a special allocation for kernel args and must keep alive until
Expand Down Expand Up @@ -1281,8 +1284,8 @@ struct AMDGPUStreamTy {
return Plugin::success();

// Wait until all previous operations on the stream have completed.
if (auto Err =
Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, RPCHandle))
if (auto Err = Slots[last()].Signal->wait(StreamBusyWaitMicroseconds,
RPCServer, &Device))
return Err;

// Reset the stream and perform all pending post actions.
Expand Down Expand Up @@ -2529,9 +2532,9 @@ Error AMDGPUResourceRef<ResourceTy>::create(GenericDeviceTy &Device) {

AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device)
: Agent(Device.getAgent()), Queue(Device.getNextQueue()),
SignalManager(Device.getSignalManager()),
SignalManager(Device.getSignalManager()), Device(Device),
// Initialize the std::deque with some empty positions.
Slots(32), NextSlot(0), SyncCycle(0), RPCHandle(nullptr),
Slots(32), NextSlot(0), SyncCycle(0), RPCServer(nullptr),
StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()) {}

/// Class implementing the AMDGPU-specific functionalities of the global
Expand Down Expand Up @@ -2866,8 +2869,8 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
AMDGPUStreamTy &Stream = AMDGPUDevice.getStream(AsyncInfoWrapper);

// If this kernel requires an RPC server we attach its pointer to the stream.
if (GenericDevice.getRPCHandle())
Stream.setRPCHandle(GenericDevice.getRPCHandle());
if (GenericDevice.getRPCServer())
Stream.setRPCServer(GenericDevice.getRPCServer());

// Push the kernel launch into the stream.
return Stream.pushKernelLaunch(*this, AllArgs, NumThreads, NumBlocks,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ elseif(${LIBOMPTARGET_GPU_LIBC_SUPPORT})
find_library(llvmlibc_rpc_server NAMES llvmlibc_rpc_server
PATHS ${LIBOMPTARGET_LLVM_LIBRARY_DIR} NO_DEFAULT_PATH)
if(llvmlibc_rpc_server)
message(WARNING ${llvmlibc_rpc_server})
target_link_libraries(PluginInterface PRIVATE llvmlibc_rpc_server)
target_compile_definitions(PluginInterface PRIVATE LIBOMPTARGET_RPC_SUPPORT)
endif()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ GenericDeviceTy::GenericDeviceTy(int32_t DeviceId, int32_t NumDevices,
OMPX_InitialNumEvents("LIBOMPTARGET_NUM_INITIAL_EVENTS", 32),
DeviceId(DeviceId), GridValues(OMPGridValues),
PeerAccesses(NumDevices, PeerAccessState::PENDING), PeerAccessesLock(),
PinnedAllocs(*this), RPCHandle(nullptr) {
PinnedAllocs(*this), RPCServer(nullptr) {
#ifdef OMPT_SUPPORT
OmptInitialized.store(false);
// Bind the callbacks to this device's member functions
Expand Down Expand Up @@ -483,8 +483,8 @@ Error GenericDeviceTy::deinit() {
if (RecordReplay.isRecordingOrReplaying())
RecordReplay.deinit();

if (RPCHandle)
if (auto Err = RPCHandle->deinitDevice())
if (RPCServer)
if (auto Err = RPCServer->deinitDevice(*this))
return Err;

#ifdef OMPT_SUPPORT
Expand Down Expand Up @@ -599,10 +599,7 @@ Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin,
if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image))
return Err;

auto DeviceOrErr = Server.getDevice(*this);
if (!DeviceOrErr)
return DeviceOrErr.takeError();
RPCHandle = *DeviceOrErr;
RPCServer = &Server;
DP("Running an RPC server on device %d\n", getDeviceId());
return Plugin::success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
}

/// Get the RPC server running on this device.
RPCHandleTy *getRPCHandle() const { return RPCHandle; }
RPCServerTy *getRPCServer() const { return RPCServer; }

private:
/// Register offload entry for global variable.
Expand Down Expand Up @@ -857,7 +857,7 @@ struct GenericDeviceTy : public DeviceAllocatorTy {

/// A pointer to an RPC server instance attached to this device if present.
/// This is used to run the RPC server during task synchronization.
RPCHandleTy *RPCHandle;
RPCServerTy *RPCServer;

#ifdef OMPT_SUPPORT
/// OMPT callback functions
Expand Down
19 changes: 0 additions & 19 deletions openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ RPCServerTy::RPCServerTy(uint32_t NumDevices) {
// If this fails then something is catastrophically wrong, just exit.
if (rpc_status_t Err = rpc_init(NumDevices))
FATAL_MESSAGE(1, "Error initializing the RPC server: %d\n", Err);
Handles.resize(NumDevices);
#endif
}

Expand Down Expand Up @@ -118,28 +117,10 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
if (auto Err = Device.dataSubmit(ClientPtr, ClientBuffer,
rpc_get_client_size(), nullptr))
return Err;

Handles[DeviceId] = std::make_unique<RPCHandleTy>(*this, Device);
#endif
return Error::success();
}

llvm::Expected<RPCHandleTy *>
RPCServerTy::getDevice(plugin::GenericDeviceTy &Device) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
uint32_t DeviceId = Device.getDeviceId();
if (!Handles[DeviceId] || !rpc_get_buffer(DeviceId) ||
!rpc_get_client_buffer(DeviceId))
return plugin::Plugin::error(
"Attempt to get an RPC device while not initialized");

return Handles[DeviceId].get();
#else
return plugin::Plugin::error(
"Attempt to get an RPC device while not available");
#endif
}

Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
if (rpc_status_t Err = rpc_handle_server(Device.getDeviceId()))
Expand Down
23 changes: 0 additions & 23 deletions openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,6 @@ class DeviceImageTy;
/// these routines will perform no action.
struct RPCServerTy {
public:
/// A wrapper around a single instance of the RPC server for a given device.
/// This is provided to simplify ownership of the underlying device.
struct RPCHandleTy {
RPCHandleTy(RPCServerTy &Server, plugin::GenericDeviceTy &Device)
: Server(Server), Device(Device) {}

llvm::Error runServer() { return Server.runServer(Device); }

llvm::Error deinitDevice() { return Server.deinitDevice(Device); }

private:
RPCServerTy &Server;
plugin::GenericDeviceTy &Device;
};

RPCServerTy(uint32_t NumDevices);

/// Check if this device image is using an RPC server. This checks for the
Expand All @@ -63,9 +48,6 @@ struct RPCServerTy {
plugin::GenericGlobalHandlerTy &Handler,
plugin::DeviceImageTy &Image);

/// Gets a reference to this server for a specific device.
llvm::Expected<RPCHandleTy *> getDevice(plugin::GenericDeviceTy &Device);

/// Runs the RPC server associated with the \p Device until the pending work
/// is cleared.
llvm::Error runServer(plugin::GenericDeviceTy &Device);
Expand All @@ -75,13 +57,8 @@ struct RPCServerTy {
llvm::Error deinitDevice(plugin::GenericDeviceTy &Device);

~RPCServerTy();

private:
llvm::SmallVector<std::unique_ptr<RPCHandleTy>> Handles;
};

using RPCHandleTy = RPCServerTy::RPCHandleTy;

} // namespace llvm::omp::target

#endif
4 changes: 2 additions & 2 deletions openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,12 @@ struct CUDADeviceTy : public GenericDeviceTy {
CUresult Res;
// If we have an RPC server running on this device we will continuously
// query it for work rather than blocking.
if (!getRPCHandle()) {
if (!getRPCServer()) {
Res = cuStreamSynchronize(Stream);
} else {
do {
Res = cuStreamQuery(Stream);
if (auto Err = getRPCHandle()->runServer())
if (auto Err = getRPCServer()->runServer(*this))
return Err;
} while (Res == CUDA_ERROR_NOT_READY);
}
Expand Down

0 comments on commit 8a0763f

Please sign in to comment.