diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index 468bd90d463..4855a82ca64 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -180,9 +180,10 @@ def f(i): @gen_cluster(client=True) -async def test_client_executor(c, s, a, b): +@pytest.mark.parametrize("separate_thread", [True, False]) +async def test_client_executor(c, s, a, b, separate_thread): def mysum(): - with worker_client() as c: + with worker_client(separate_thread=separate_thread) as c: with c.get_executor() as e: return sum(e.map(double, range(30))) @@ -339,3 +340,40 @@ def long_running(): assert len(res) == 2 assert res[a.address] > 25 assert res[b.address] > 25 + + +@pytest.mark.parametrize("separate_thread", [True, False]) +def test_sync_func_on_main_thread(client, separate_thread): + """A synchronous client running a worker that calls an async task which submits its own task to + the scheduler should not fail""" + # https://github.com/dask/distributed/issues/5513 + + async def inc(n): + return n + 1 + + async def f(): + with worker_client(separate_thread=separate_thread) as c: + m = c.submit(inc, 1) + return m + + res = client.submit(f) + assert res.exception() is None + + +@gen_cluster(client=True) +@pytest.mark.parametrize("separate_thread", [True, False]) +async def test_async_func_on_main_thread(c, s, a, b, separate_thread): + """An asynchronous client running a worker that calls an async task which submits its own task to + the scheduler should not fail""" + # https://github.com/dask/distributed/issues/5513 + + async def inc(n): + return n + 1 + + async def f(): + with worker_client(separate_thread=separate_thread) as client: + m = await client.submit(inc, 1) + return m + + res = await c.submit(f) + assert res == 2 diff --git a/distributed/worker_client.py b/distributed/worker_client.py index 5a775b38191..d76d03b796b 100644 --- a/distributed/worker_client.py +++ b/distributed/worker_client.py @@ -1,3 +1,4 @@ +import threading import warnings from contextlib import contextmanager @@ -50,7 +51,12 @@ def worker_client(timeout=None, separate_thread=True): worker = get_worker() client = get_client(timeout=timeout) - if separate_thread: + + # When passing the client an async function, it runs on the event loop + # in the main thread instead of a background thread. This causes secede() to fail, + is_main_thread = threading.current_thread() is threading.main_thread() + if not is_main_thread and separate_thread: + duration = time() - thread_state.start_time secede() # have this thread secede from the thread pool worker.loop.add_callback( @@ -63,7 +69,7 @@ def worker_client(timeout=None, separate_thread=True): yield client - if separate_thread: + if not is_main_thread and separate_thread: rejoin()