Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 62 additions & 4 deletions wish/python/src/wish_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <nanobind/stl/function.h>
#include <nanobind/stl/string.h>

#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>

Expand Down Expand Up @@ -56,6 +58,12 @@ struct TlsClientPy {
nb::object on_message_cb;
std::shared_ptr<WishHandlerRef> handler_ref;

// Tracks whether Run() is currently executing. Used by tls_clear to
// wait for the event loop to exit before it clears the callbacks.
std::atomic<bool> running{false};
std::mutex stopped_mu;
std::condition_variable stopped_cv;

TlsClientPy(const std::string& ca, const std::string& cert,
const std::string& key, const std::string& host, int port)
: client(ca, cert, key, host, port) {};
Expand All @@ -67,6 +75,10 @@ struct PlainClientPy {
nb::object on_message_cb;
std::shared_ptr<WishHandlerRef> handler_ref;

std::atomic<bool> running{false};
std::mutex stopped_mu;
std::condition_variable stopped_cv;

PlainClientPy(const std::string& host, int port)
: client(host, port) {}
};
Expand All @@ -84,8 +96,29 @@ static int tls_traverse(PyObject* self, visitproc visit, void* arg) {

static int tls_clear(PyObject* self) {
TlsClientPy* w = nb::inst_ptr<TlsClientPy>(nb::handle(self));
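// The C++ callbacks hold a lambda that captures &*w, so they must be dropped
// before on_open_cb / on_message_cb are invalidated. They can only be cleared
// safely once the event loop is no longer running, hence the steps below.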

// 1. Ask the event loop to stop.
//
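// Stop() is issued from this (GC) thread while Run() may be executing on
// another one. event_base_loopexit is only safe to call across threads when
// libevent's locking support was enabled (e.g. evthread_use_pthreads())
// before the event_base was created.
// NOTE(review): confirm the client enables libevent threading support.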
w->client.Stop();

// 2. Release the GIL and wait for Run() to return.
//
// Without this step there is a data race:
// - Event loop thread reads on_open_ / on_message_ and then blocks
// waiting to acquire the GIL (nb::gil_scoped_acquire).
// - GC thread (holding the GIL) writes those same std::function
// objects via SetOnOpen({}) etc. → UB.
// By releasing the GIL here we let any in-flight callback finish, after
// which event_base_dispatch returns and Run() signals stopped_cv.
{
PyThreadState* ts = PyEval_SaveThread(); // release GIL
std::unique_lock<std::mutex> lk(w->stopped_mu);
w->stopped_cv.wait(lk, [w] { return !w->running.load(std::memory_order_acquire); });
PyEval_RestoreThread(ts); // reacquire GIL
}

// 3. Event loop has exited; mutations are now single-threaded and safe.
w->client.SetOnOpen({});
w->client.SetOnClose({});
w->client.SetOnMessage({});
Expand Down Expand Up @@ -113,6 +146,17 @@ static int plain_traverse(PyObject* self, visitproc visit, void* arg) {

static int plain_clear(PyObject* self) {
PlainClientPy* w = nb::inst_ptr<PlainClientPy>(nb::handle(self));

// Same Stop-then-wait pattern as tls_clear. See comments there.
w->client.Stop();

{
PyThreadState* ts = PyEval_SaveThread();
std::unique_lock<std::mutex> lk(w->stopped_mu);
w->stopped_cv.wait(lk, [w] { return !w->running.load(std::memory_order_acquire); });
PyEval_RestoreThread(ts);
}

w->client.SetOnOpen({});
w->client.SetOnClose({});
w->client.SetOnMessage({});
Expand Down Expand Up @@ -194,7 +238,14 @@ NB_MODULE(wish_ext, m) {
}
});
})
.def("run", [](TlsClientPy& self) { self.client.Run(); }, nb::call_guard<nb::gil_scoped_release>())
.def("run", [](TlsClientPy& self) {
    // Scope object that publishes the "running" state for the duration of
    // Run(). The destructor clears the flag under stopped_mu and wakes every
    // waiter on stopped_cv, so the state is reset on every exit path --
    // including an exception escaping Run(). Without that, a waiter blocked
    // on stopped_cv until `running` becomes false would never be released.
    struct RunningScope {
        TlsClientPy& owner;

        explicit RunningScope(TlsClientPy& o) : owner(o) {
            owner.running.store(true, std::memory_order_release);
        }

        ~RunningScope() {
            {
                // Flip the flag while holding the mutex so a waiter cannot
                // test the predicate and then miss the notification.
                std::lock_guard<std::mutex> lk(owner.stopped_mu);
                owner.running.store(false, std::memory_order_release);
            }
            owner.stopped_cv.notify_all();
        }

        RunningScope(const RunningScope&) = delete;
        RunningScope& operator=(const RunningScope&) = delete;
    };

    const RunningScope scope{self};
    // Blocks until the client's loop returns; the GIL is released for the
    // whole call by the gil_scoped_release call guard below.
    self.client.Run();
}, nb::call_guard<nb::gil_scoped_release>())
.def("stop", [](TlsClientPy& self) { self.client.Stop(); });

// ---- PlainClient ------------------------------------------------------
Expand Down Expand Up @@ -244,6 +295,13 @@ NB_MODULE(wish_ext, m) {
}
});
})
.def("run", [](PlainClientPy& self) { self.client.Run(); }, nb::call_guard<nb::gil_scoped_release>())
.def("run", [](PlainClientPy& self) {
    // Scope object that publishes the "running" state for the duration of
    // Run(). The destructor clears the flag under stopped_mu and wakes every
    // waiter on stopped_cv, so the state is reset on every exit path --
    // including an exception escaping Run(). Without that, a waiter blocked
    // on stopped_cv until `running` becomes false would never be released.
    struct RunningScope {
        PlainClientPy& owner;

        explicit RunningScope(PlainClientPy& o) : owner(o) {
            owner.running.store(true, std::memory_order_release);
        }

        ~RunningScope() {
            {
                // Flip the flag while holding the mutex so a waiter cannot
                // test the predicate and then miss the notification.
                std::lock_guard<std::mutex> lk(owner.stopped_mu);
                owner.running.store(false, std::memory_order_release);
            }
            owner.stopped_cv.notify_all();
        }

        RunningScope(const RunningScope&) = delete;
        RunningScope& operator=(const RunningScope&) = delete;
    };

    const RunningScope scope{self};
    // Blocks until the client's loop returns; the GIL is released for the
    // whole call by the gil_scoped_release call guard below.
    self.client.Run();
}, nb::call_guard<nb::gil_scoped_release>())
.def("stop", [](PlainClientPy& self) { self.client.Stop(); });
}