Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions distributed/tests/test_worker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,10 @@ def f(i):


@gen_cluster(client=True)
async def test_client_executor(c, s, a, b):
@pytest.mark.parametrize("separate_thread", [True, False])
async def test_client_executor(c, s, a, b, separate_thread):
def mysum():
with worker_client() as c:
with worker_client(separate_thread=separate_thread) as c:
with c.get_executor() as e:
return sum(e.map(double, range(30)))

Expand Down Expand Up @@ -339,3 +340,40 @@ def long_running():
assert len(res) == 2
assert res[a.address] > 25
assert res[b.address] > 25


@pytest.mark.parametrize("separate_thread", [True, False])
def test_sync_func_on_main_thread(client, separate_thread):
"""A synchronous client running a worker that calls an async task which submits its own task to
the scheduler should not fail"""
# https://github.com/dask/distributed/issues/5513

async def inc(n):
return n + 1

async def f():
with worker_client(separate_thread=separate_thread) as c:
m = c.submit(inc, 1)
return m

res = client.submit(f)
assert res.exception() is None


@gen_cluster(client=True)
@pytest.mark.parametrize("separate_thread", [True, False])
async def test_async_func_on_main_thread(c, s, a, b, separate_thread):
"""An asynchronous client running a worker that calls an async task which submits its own task to
the scheduler should not fail"""
# https://github.com/dask/distributed/issues/5513

async def inc(n):
return n + 1

async def f():
with worker_client(separate_thread=separate_thread) as client:
m = await client.submit(inc, 1)
return m

res = await c.submit(f)
assert res == 2
10 changes: 8 additions & 2 deletions distributed/worker_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import threading
import warnings
from contextlib import contextmanager

Expand Down Expand Up @@ -50,7 +51,12 @@ def worker_client(timeout=None, separate_thread=True):

worker = get_worker()
client = get_client(timeout=timeout)
if separate_thread:

# When passing the client an async function, it runs on the event loop
# in the main thread instead of a background thread. This causes secede() to fail,
is_main_thread = threading.current_thread() is threading.main_thread()
if not is_main_thread and separate_thread:

duration = time() - thread_state.start_time
secede() # have this thread secede from the thread pool
worker.loop.add_callback(
Expand All @@ -63,7 +69,7 @@ def worker_client(timeout=None, separate_thread=True):

yield client

if separate_thread:
if not is_main_thread and separate_thread:
rejoin()


Expand Down