Skip to content

Commit

Permalink
pythonGH-110829: Ensure Thread.join() joins the OS thread (python#110848
Browse files Browse the repository at this point in the history
)

Joining a thread now ensures the underlying OS thread has exited. This is required for safer fork() in multi-threaded processes.

---------

Co-authored-by: blurb-it[bot] <43283697+blurb-it[bot]@users.noreply.github.com>
  • Loading branch information
pitrou and blurb-it[bot] committed Nov 4, 2023
1 parent a28a396 commit 0e9c364
Show file tree
Hide file tree
Showing 14 changed files with 671 additions and 98 deletions.
1 change: 1 addition & 0 deletions Include/cpython/pthread_stubs.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ PyAPI_FUNC(int) pthread_create(pthread_t *restrict thread,
void *(*start_routine)(void *),
void *restrict arg);
PyAPI_FUNC(int) pthread_detach(pthread_t thread);
PyAPI_FUNC(int) pthread_join(pthread_t thread, void** value_ptr);
PyAPI_FUNC(pthread_t) pthread_self(void);
PyAPI_FUNC(int) pthread_exit(void *retval) __attribute__ ((__noreturn__));
PyAPI_FUNC(int) pthread_attr_init(pthread_attr_t *attr);
Expand Down
42 changes: 42 additions & 0 deletions Include/internal/pycore_pythread.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,48 @@ PyAPI_FUNC(PyLockStatus) PyThread_acquire_lock_timed_with_retries(
PyThread_type_lock,
PY_TIMEOUT_T microseconds);

typedef unsigned long long PyThread_ident_t;
typedef Py_uintptr_t PyThread_handle_t;

#define PY_FORMAT_THREAD_IDENT_T "llu"
#define Py_PARSE_THREAD_IDENT_T "K"

PyAPI_FUNC(PyThread_ident_t) PyThread_get_thread_ident_ex(void);

/* Thread joining APIs.
*
* These APIs have a strict contract:
* - Either PyThread_join_thread or PyThread_detach_thread must be called
* exactly once with the given handle.
* - Calling neither PyThread_join_thread nor PyThread_detach_thread results
* in a resource leak until the end of the process.
* - Any other usage, such as calling both PyThread_join_thread and
* PyThread_detach_thread, or calling them more than once (including
* simultaneously), results in undefined behavior.
*/
PyAPI_FUNC(int) PyThread_start_joinable_thread(void (*func)(void *),
void *arg,
PyThread_ident_t* ident,
PyThread_handle_t* handle);
/*
* Join a thread started with `PyThread_start_joinable_thread`.
* This function cannot be interrupted. It returns 0 on success,
* a non-zero value on failure.
*/
PyAPI_FUNC(int) PyThread_join_thread(PyThread_handle_t);
/*
* Detach a thread started with `PyThread_start_joinable_thread`, such
* that its resources are relased as soon as it exits.
* This function cannot be interrupted. It returns 0 on success,
* a non-zero value on failure.
*/
PyAPI_FUNC(int) PyThread_detach_thread(PyThread_handle_t);

/*
* Obtain the new thread ident and handle in a forked child process.
*/
PyAPI_FUNC(void) PyThread_update_thread_after_fork(PyThread_ident_t* ident,
PyThread_handle_t* handle);

#ifdef __cplusplus
}
Expand Down
3 changes: 3 additions & 0 deletions Lib/test/_test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2693,6 +2693,9 @@ def test_make_pool(self):
p.join()

def test_terminate(self):
if self.TYPE == 'threads':
self.skipTest("Threads cannot be terminated")

# Simulate slow tasks which take "forever" to complete
p = self.Pool(3)
args = [support.LONG_TIMEOUT for i in range(10_000)]
Expand Down
3 changes: 3 additions & 0 deletions Lib/test/audit-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,9 @@ def __call__(self):
i = _thread.start_new_thread(test_func(), ())
lock.acquire()

handle = _thread.start_joinable_thread(test_func())
handle.join()


def test_threading_abort():
# Ensures that aborting PyThreadState_New raises the correct exception
Expand Down
2 changes: 2 additions & 0 deletions Lib/test/test_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ def test_threading(self):
expected = [
("_thread.start_new_thread", "(<test_func>, (), None)"),
("test.test_func", "()"),
("_thread.start_joinable_thread", "(<test_func>,)"),
("test.test_func", "()"),
]

self.assertEqual(actual, expected)
Expand Down
6 changes: 3 additions & 3 deletions Lib/test/test_concurrent_futures/test_process_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ def test_python_finalization_error(self):

context = self.get_context()

# gh-109047: Mock the threading.start_new_thread() function to inject
# gh-109047: Mock the threading.start_joinable_thread() function to inject
# RuntimeError: simulate the error raised during Python finalization.
# Block the second creation: create _ExecutorManagerThread, but block
# QueueFeederThread.
orig_start_new_thread = threading._start_new_thread
orig_start_new_thread = threading._start_joinable_thread
nthread = 0
def mock_start_new_thread(func, *args):
nonlocal nthread
Expand All @@ -208,7 +208,7 @@ def mock_start_new_thread(func, *args):
nthread += 1
return orig_start_new_thread(func, *args)

with support.swap_attr(threading, '_start_new_thread',
with support.swap_attr(threading, '_start_joinable_thread',
mock_start_new_thread):
executor = self.executor_type(max_workers=2, mp_context=context)
with executor:
Expand Down
126 changes: 126 additions & 0 deletions Lib/test/test_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,132 @@ def task():
f"Exception ignored in thread started by {task!r}")
self.assertIsNotNone(cm.unraisable.exc_traceback)

def test_join_thread(self):
finished = []

def task():
time.sleep(0.05)
finished.append(thread.get_ident())

with threading_helper.wait_threads_exit():
handle = thread.start_joinable_thread(task)
handle.join()
self.assertEqual(len(finished), 1)
self.assertEqual(handle.ident, finished[0])

def test_join_thread_already_exited(self):
def task():
pass

with threading_helper.wait_threads_exit():
handle = thread.start_joinable_thread(task)
time.sleep(0.05)
handle.join()

def test_join_several_times(self):
def task():
pass

with threading_helper.wait_threads_exit():
handle = thread.start_joinable_thread(task)
handle.join()
with self.assertRaisesRegex(ValueError, "not joinable"):
handle.join()

def test_joinable_not_joined(self):
handle_destroyed = thread.allocate_lock()
handle_destroyed.acquire()

def task():
handle_destroyed.acquire()

with threading_helper.wait_threads_exit():
handle = thread.start_joinable_thread(task)
del handle
handle_destroyed.release()

def test_join_from_self(self):
errors = []
handles = []
start_joinable_thread_returned = thread.allocate_lock()
start_joinable_thread_returned.acquire()
task_tried_to_join = thread.allocate_lock()
task_tried_to_join.acquire()

def task():
start_joinable_thread_returned.acquire()
try:
handles[0].join()
except Exception as e:
errors.append(e)
finally:
task_tried_to_join.release()

with threading_helper.wait_threads_exit():
handle = thread.start_joinable_thread(task)
handles.append(handle)
start_joinable_thread_returned.release()
# Can still join after joining failed in other thread
task_tried_to_join.acquire()
handle.join()

assert len(errors) == 1
with self.assertRaisesRegex(RuntimeError, "Cannot join current thread"):
raise errors[0]

def test_detach_from_self(self):
errors = []
handles = []
start_joinable_thread_returned = thread.allocate_lock()
start_joinable_thread_returned.acquire()
thread_detached = thread.allocate_lock()
thread_detached.acquire()

def task():
start_joinable_thread_returned.acquire()
try:
handles[0].detach()
except Exception as e:
errors.append(e)
finally:
thread_detached.release()

with threading_helper.wait_threads_exit():
handle = thread.start_joinable_thread(task)
handles.append(handle)
start_joinable_thread_returned.release()
thread_detached.acquire()
with self.assertRaisesRegex(ValueError, "not joinable"):
handle.join()

assert len(errors) == 0

def test_detach_then_join(self):
lock = thread.allocate_lock()
lock.acquire()

def task():
lock.acquire()

with threading_helper.wait_threads_exit():
handle = thread.start_joinable_thread(task)
# detach() returns even though the thread is blocked on lock
handle.detach()
# join() then cannot be called anymore
with self.assertRaisesRegex(ValueError, "not joinable"):
handle.join()
lock.release()

def test_join_then_detach(self):
def task():
pass

with threading_helper.wait_threads_exit():
handle = thread.start_joinable_thread(task)
handle.join()
with self.assertRaisesRegex(ValueError, "not joinable"):
handle.detach()


class Barrier:
def __init__(self, num_threads):
Expand Down
47 changes: 44 additions & 3 deletions Lib/test/test_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,16 +376,16 @@ def test_limbo_cleanup(self):
# Issue 7481: Failure to start thread should cleanup the limbo map.
def fail_new_thread(*args):
raise threading.ThreadError()
_start_new_thread = threading._start_new_thread
threading._start_new_thread = fail_new_thread
_start_joinable_thread = threading._start_joinable_thread
threading._start_joinable_thread = fail_new_thread
try:
t = threading.Thread(target=lambda: None)
self.assertRaises(threading.ThreadError, t.start)
self.assertFalse(
t in threading._limbo,
"Failed to cleanup _limbo map on failure of Thread.start().")
finally:
threading._start_new_thread = _start_new_thread
threading._start_joinable_thread = _start_joinable_thread

def test_finalize_running_thread(self):
# Issue 1402: the PyGILState_Ensure / _Release functions may be called
Expand Down Expand Up @@ -482,6 +482,47 @@ def test_enumerate_after_join(self):
finally:
sys.setswitchinterval(old_interval)

def test_join_from_multiple_threads(self):
# Thread.join() should be thread-safe
errors = []

def worker():
time.sleep(0.005)

def joiner(thread):
try:
thread.join()
except Exception as e:
errors.append(e)

for N in range(2, 20):
threads = [threading.Thread(target=worker)]
for i in range(N):
threads.append(threading.Thread(target=joiner,
args=(threads[0],)))
for t in threads:
t.start()
time.sleep(0.01)
for t in threads:
t.join()
if errors:
raise errors[0]

def test_join_with_timeout(self):
lock = _thread.allocate_lock()
lock.acquire()

def worker():
lock.acquire()

thread = threading.Thread(target=worker)
thread.start()
thread.join(timeout=0.01)
assert thread.is_alive()
lock.release()
thread.join()
assert not thread.is_alive()

def test_no_refcycle_through_target(self):
class RunSelfFunction(object):
def __init__(self, should_raise):
Expand Down

0 comments on commit 0e9c364

Please sign in to comment.