diff --git a/offload/plugins-nextgen/common/include/RPC.h b/offload/plugins-nextgen/common/include/RPC.h index d750ce30e74b0..7b031083647aa 100644 --- a/offload/plugins-nextgen/common/include/RPC.h +++ b/offload/plugins-nextgen/common/include/RPC.h @@ -80,7 +80,7 @@ struct RPCServerTy { std::thread Worker; /// A boolean indicating whether or not the worker thread should continue. - std::atomic Running; + std::atomic Running; /// The number of currently executing kernels across all devices that need /// the server thread to be running. diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp index 76ae0a2dd9c45..48c9b671c1a91 100644 --- a/offload/plugins-nextgen/common/src/PluginInterface.cpp +++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp @@ -1058,9 +1058,8 @@ Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin, if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image)) return Err; - if (!Server.Thread->Running.load(std::memory_order_acquire)) - if (auto Err = Server.startThread()) - return Err; + if (auto Err = Server.startThread()) + return Err; RPCServer = &Server; DP("Running an RPC server on device %d\n", getDeviceId()); @@ -1635,12 +1634,11 @@ Error GenericPluginTy::deinit() { if (GlobalHandler) delete GlobalHandler; - if (RPCServer && RPCServer->Thread->Running.load(std::memory_order_acquire)) + if (RPCServer) { if (Error Err = RPCServer->shutDown()) return Err; - - if (RPCServer) delete RPCServer; + } if (RecordReplay) delete RecordReplay; diff --git a/offload/plugins-nextgen/common/src/RPC.cpp b/offload/plugins-nextgen/common/src/RPC.cpp index e6750a540b391..4289f920c0e1e 100644 --- a/offload/plugins-nextgen/common/src/RPC.cpp +++ b/offload/plugins-nextgen/common/src/RPC.cpp @@ -99,18 +99,15 @@ static rpc::Status runServer(plugin::GenericDeviceTy &Device, void *Buffer) { } void RPCServerTy::ServerThread::startThread() { - assert(!Running.load(std::memory_order_relaxed) && - "Attempting to start thread that is already running"); - Running.store(true, std::memory_order_release); - Worker = std::thread([this]() { run(); }); + if (!Running.fetch_or(true, std::memory_order_acquire)) + Worker = std::thread([this]() { run(); }); } void RPCServerTy::ServerThread::shutDown() { - assert(Running.load(std::memory_order_relaxed) && - "Attempting to shut down a thread that is not running"); + if (!Running.fetch_and(false, std::memory_order_release)) + return; { std::lock_guard Lock(Mutex); - Running.store(false, std::memory_order_release); CV.notify_all(); } if (Worker.joinable())