Skip to content

Commit

Permalink
support worker_client() in async tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed May 18, 2023
1 parent 7a5b4e2 commit 6c483b1
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
18 changes: 18 additions & 0 deletions distributed/tests/test_worker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,24 @@ def func(x):
assert len([id for id in s.clients if id.lower().startswith("client")]) == 1


@gen_cluster(client=True)
async def test_submit_from_worker_async(c, s, a, b):
async def func(x):
with worker_client() as c:
x = c.submit(inc, x)
y = c.submit(double, x)
return await x + await y

x, y = c.map(func, [10, 20])
xx, yy = await c.gather([x, y])

assert xx == 10 + 1 + (10 + 1) * 2
assert yy == 20 + 1 + (20 + 1) * 2

assert len(s.transition_log) > 10
assert len([id for id in s.clients if id.lower().startswith("client")]) == 1


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
async def test_scatter_from_worker(c, s, a, b):
def func():
Expand Down
40 changes: 22 additions & 18 deletions distributed/worker_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import contextlib
import warnings
from contextlib import contextmanager

import dask

Expand All @@ -11,7 +11,7 @@
from distributed.worker_state_machine import SecedeEvent


@contextmanager
@contextlib.contextmanager
def worker_client(timeout=None, separate_thread=True):
"""Get client for this thread
Expand Down Expand Up @@ -53,22 +53,26 @@ def worker_client(timeout=None, separate_thread=True):

worker = get_worker()
client = get_client(timeout=timeout)
if separate_thread:
duration = time() - thread_state.start_time
secede() # have this thread secede from the thread pool
worker.loop.add_callback(
worker.handle_stimulus,
SecedeEvent(
key=thread_state.key,
compute_duration=duration,
stimulus_id=f"worker-client-secede-{time()}",
),
)

yield client

if separate_thread:
rejoin()
with contextlib.ExitStack() as stack:
if separate_thread:
try:
thread_state.start_time
except AttributeError: # not in a synchronous task, can't secede
pass
else:
duration = time() - thread_state.start_time
secede() # have this thread secede from the thread pool
stack.callback(rejoin)
worker.loop.add_callback(
worker.handle_stimulus,
SecedeEvent(
key=thread_state.key,
compute_duration=duration,
stimulus_id=f"worker-client-secede-{time()}",
),
)

yield client


def local_client(*args, **kwargs):
Expand Down

0 comments on commit 6c483b1

Please sign in to comment.