Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 32 additions & 53 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from dataclasses import dataclass
from functools import partial
from operator import add
from threading import Semaphore
from time import sleep
from typing import Any
from unittest import mock
Expand All @@ -37,7 +36,7 @@
import pytest
import yaml
from packaging.version import parse as parse_version
from tlz import concat, first, identity, isdistinct, merge, pluck, valmap
from tlz import concat, identity, isdistinct, merge, pluck, valmap

import dask
import dask.bag as db
Expand Down Expand Up @@ -3849,7 +3848,7 @@ def test_open_close_many_workers(loop, worker, count, repeat):
with cluster(nworkers=0, active_rpc_timeout=2) as (s, _):
gc.collect()
before = proc.num_fds()
done = Semaphore(0)
done = threading.Semaphore(0)
running = weakref.WeakKeyDictionary()
workers = set()
status = True
Expand Down Expand Up @@ -5409,63 +5408,43 @@ def test_quiet_quit_when_cluster_leaves(loop_in_thread):
assert not text


@gen_cluster([("127.0.0.1", 4)] * 2, client=True)
async def test_call_stack_future(c, s, a, b):
x = c.submit(slowdec, 1, delay=0.5)
future = c.submit(slowinc, 1, delay=0.5)
await asyncio.sleep(0.1)
results = await asyncio.gather(
c.call_stack(future), c.call_stack(keys=[future.key])
)
assert all(list(first(result.values())) == [future.key] for result in results)
assert results[0] == results[1]
result = results[0]
ts = a.state.tasks.get(future.key)
if ts is not None and ts.state == "executing":
w = a
else:
w = b
@gen_cluster(client=True)
async def test_call_stack(c, s, a, b):
e1, e2, e3, ew = Event(), Event(), Event(), Event()

assert list(result) == [w.address]
assert list(result[w.address]) == [future.key]
assert "slowinc" in str(result)
assert "slowdec" not in str(result)
def f(es: Event, ew: Event) -> None:
es.set()
ew.wait()

f1 = c.submit(f, e1, ew, key="f1", workers=[a.address])
f2 = c.submit(f, e2, ew, key="f2", workers=[b.address])
d3 = c.persist(delayed(f)(e3, ew, dask_key_name="d3"), workers=[b.address])
await e1.wait()
await e2.wait()
await e3.wait()

@gen_cluster([("127.0.0.1", 4)] * 2, client=True)
async def test_call_stack_all(c, s, a, b):
future = c.submit(slowinc, 1, delay=0.8)
while not a.state.executing_count and not b.state.executing_count:
await asyncio.sleep(0.01)
result = await c.call_stack()
w = a if a.state.executing_count else b
assert list(result) == [w.address]
assert list(result[w.address]) == [future.key]
assert "slowinc" in str(result)
# Test future or keys
r1a = await c.call_stack(f1)
r1b = await c.call_stack([f1])
r1c = await c.call_stack(keys=[f1.key])

assert r1a == r1b == r1c
assert r1a.keys() == {a.address}
assert r1a[a.address].keys() == {"f1"}
assert any("event.py" in frame for frame in r1a[a.address]["f1"])

@gen_cluster([("127.0.0.1", 4)] * 2, client=True)
async def test_call_stack_collections(c, s, a, b):
pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")

x = c.persist(da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5))
while not a.state.executing_count and not b.state.executing_count:
await asyncio.sleep(0.001)
result = await c.call_stack(x)
assert result
# test collection
r3 = await c.call_stack(d3)
assert r3.keys() == {b.address}
assert r3[b.address].keys() == {"d3"}

# test all
r4 = await c.call_stack()
assert r4.keys() == {a.address, b.address}
assert r4[a.address].keys() == {"f1"}
assert r4[b.address].keys() == {"f2", "d3"}

@gen_cluster([("127.0.0.1", 4)] * 2, client=True)
async def test_call_stack_collections_all(c, s, a, b):
pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")

x = c.persist(da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5))
while not a.state.executing_count and not b.state.executing_count:
await asyncio.sleep(0.001)
result = await c.call_stack()
assert result
await ew.set()


@pytest.mark.skipif(sys.version_info.minor == 11, reason="Profiler disabled")
Expand Down
Loading