diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index b46c67fec5..b33d558022 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -28,7 +28,6 @@ from dataclasses import dataclass from functools import partial from operator import add -from threading import Semaphore from time import sleep from typing import Any from unittest import mock @@ -37,7 +36,7 @@ import pytest import yaml from packaging.version import parse as parse_version -from tlz import concat, first, identity, isdistinct, merge, pluck, valmap +from tlz import concat, identity, isdistinct, merge, pluck, valmap import dask import dask.bag as db @@ -3849,7 +3848,7 @@ def test_open_close_many_workers(loop, worker, count, repeat): with cluster(nworkers=0, active_rpc_timeout=2) as (s, _): gc.collect() before = proc.num_fds() - done = Semaphore(0) + done = threading.Semaphore(0) running = weakref.WeakKeyDictionary() workers = set() status = True @@ -5409,63 +5408,43 @@ def test_quiet_quit_when_cluster_leaves(loop_in_thread): assert not text -@gen_cluster([("127.0.0.1", 4)] * 2, client=True) -async def test_call_stack_future(c, s, a, b): - x = c.submit(slowdec, 1, delay=0.5) - future = c.submit(slowinc, 1, delay=0.5) - await asyncio.sleep(0.1) - results = await asyncio.gather( - c.call_stack(future), c.call_stack(keys=[future.key]) - ) - assert all(list(first(result.values())) == [future.key] for result in results) - assert results[0] == results[1] - result = results[0] - ts = a.state.tasks.get(future.key) - if ts is not None and ts.state == "executing": - w = a - else: - w = b +@gen_cluster(client=True) +async def test_call_stack(c, s, a, b): + e1, e2, e3, ew = Event(), Event(), Event(), Event() - assert list(result) == [w.address] - assert list(result[w.address]) == [future.key] - assert "slowinc" in str(result) - assert "slowdec" not in str(result) + def f(es: Event, ew: Event) -> None: + es.set() + ew.wait() + f1 = c.submit(f, e1, ew, key="f1", workers=[a.address]) + f2 = c.submit(f, e2, ew, key="f2", workers=[b.address]) + d3 = c.persist(delayed(f)(e3, ew, dask_key_name="d3"), workers=[b.address]) + await e1.wait() + await e2.wait() + await e3.wait() -@gen_cluster([("127.0.0.1", 4)] * 2, client=True) -async def test_call_stack_all(c, s, a, b): - future = c.submit(slowinc, 1, delay=0.8) - while not a.state.executing_count and not b.state.executing_count: - await asyncio.sleep(0.01) - result = await c.call_stack() - w = a if a.state.executing_count else b - assert list(result) == [w.address] - assert list(result[w.address]) == [future.key] - assert "slowinc" in str(result) + # Test future or keys + r1a = await c.call_stack(f1) + r1b = await c.call_stack([f1]) + r1c = await c.call_stack(keys=[f1.key]) + assert r1a == r1b == r1c + assert r1a.keys() == {a.address} + assert r1a[a.address].keys() == {"f1"} + assert any("event.py" in frame for frame in r1a[a.address]["f1"]) -@gen_cluster([("127.0.0.1", 4)] * 2, client=True) -async def test_call_stack_collections(c, s, a, b): - pytest.importorskip("numpy") - da = pytest.importorskip("dask.array") - - x = c.persist(da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5)) - while not a.state.executing_count and not b.state.executing_count: - await asyncio.sleep(0.001) - result = await c.call_stack(x) - assert result + # test collection + r3 = await c.call_stack(d3) + assert r3.keys() == {b.address} + assert r3[b.address].keys() == {"d3"} + # test all + r4 = await c.call_stack() + assert r4.keys() == {a.address, b.address} + assert r4[a.address].keys() == {"f1"} + assert r4[b.address].keys() == {"f2", "d3"} -@gen_cluster([("127.0.0.1", 4)] * 2, client=True) -async def test_call_stack_collections_all(c, s, a, b): - pytest.importorskip("numpy") - da = pytest.importorskip("dask.array") - - x = c.persist(da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5)) - while not a.state.executing_count and not b.state.executing_count: - await asyncio.sleep(0.001) - result = await c.call_stack() - assert result + await ew.set() @pytest.mark.skipif(sys.version_info.minor == 11, reason="Profiler disabled")