Skip to content

Commit

Permalink
Add tests for P2P barrier fusion (#7845)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait committed May 23, 2023
1 parent e887fde commit 755f768
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 4 deletions.
24 changes: 23 additions & 1 deletion distributed/shuffle/tests/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,29 @@ async def test_rechunk_2d(c, s, *ws):
async def test_rechunk_4d(c, s, *ws):
    """Rechunk a random 4d array with ``method="p2p"`` and verify the result.

    The array is built with ``(5, 5)`` chunks along every axis and rechunked
    so that only the last axis remains split. The rechunked array must report
    the requested chunks and compute to the same values as the input.

    See Also
    --------
    dask.array.tests.test_rechunk.test_rechunk_4d
    """
    source_chunks = ((5, 5),) * 4
    expected = np.random.uniform(0, 1, 10000).reshape((10,) * 4)
    original = da.from_array(expected, chunks=source_chunks)

    # This has been altered to return >1 output partition
    target_chunks = ((10,), (10,), (10,), (8, 2))

    rechunked = rechunk(original, chunks=target_chunks, method="p2p")
    assert rechunked.chunks == target_chunks

    # The graph is computed twice on purpose of the original test: once
    # discarding the result, once comparing it against the input data.
    await c.compute(rechunked)
    result = await c.compute(rechunked)
    assert np.all(result == expected)


@gen_cluster(client=True)
async def test_rechunk_with_single_output_chunk_raises(c, s, *ws):
"""See distributed#7816
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_4d
Expand All @@ -212,7 +235,6 @@ async def test_rechunk_4d(c, s, *ws):
RuntimeError, "rechunk_transfer failed", RuntimeError, "Barrier task"
):
await c.compute(x2)
# assert np.all(await c.compute(x2) == a)


@gen_cluster(client=True)
Expand Down
43 changes: 40 additions & 3 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,13 @@
invoke_annotation_chaos,
)
from distributed.utils import Deadline
from distributed.utils_test import cluster, gen_cluster, gen_test, wait_for_state
from distributed.utils_test import (
cluster,
gen_cluster,
gen_test,
raises_with_cause,
wait_for_state,
)
from distributed.worker_state_machine import TaskState as WorkerTaskState

try:
Expand Down Expand Up @@ -104,16 +110,21 @@ async def test_minimal_version(c, s, a, b):
await c.compute(dd.shuffle.shuffle(df, "x", shuffle="p2p"))


@pytest.mark.parametrize("npartitions", [None, 1, 20])
@gen_cluster(client=True)
async def test_basic_integration(c, s, a, b, lose_annotations):
async def test_basic_integration(c, s, a, b, lose_annotations, npartitions):
await invoke_annotation_chaos(lose_annotations, c)
df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-01-10",
dtypes={"x": float, "y": float},
freq="10 s",
)
out = dd.shuffle.shuffle(df, "x", shuffle="p2p")
out = dd.shuffle.shuffle(df, "x", shuffle="p2p", npartitions=npartitions)
if npartitions is None:
assert out.npartitions == df.npartitions
else:
assert out.npartitions == npartitions
x, y = c.compute([df.x.size, out.x.size])
x = await x
y = await y
Expand All @@ -124,6 +135,32 @@ async def test_basic_integration(c, s, a, b, lose_annotations):
await clean_scheduler(s)


@pytest.mark.parametrize("npartitions", [None, 1, 20])
@gen_cluster(client=True)
async def test_shuffle_with_array_conversion(c, s, a, b, lose_annotations, npartitions):
    """Compute the ``.values`` of a P2P-shuffled timeseries dataframe.

    For ``npartitions=1`` the computation is currently expected to fail with
    a ``RuntimeError`` caused by the barrier task (see distributed#7816); for
    the other parametrizations it is expected to succeed. Workers and
    scheduler are cleaned up afterwards in either case.
    """
    await invoke_annotation_chaos(lose_annotations, c)

    timeseries = dask.datasets.timeseries(
        start="2000-01-01",
        end="2000-01-10",
        dtypes={"x": float, "y": float},
        freq="10 s",
    )
    shuffled = dd.shuffle.shuffle(
        timeseries, "x", shuffle="p2p", npartitions=npartitions
    )
    out = shuffled.values

    if npartitions != 1:
        await c.compute(out)
    else:
        # FIXME: distributed#7816
        with raises_with_cause(
            RuntimeError, "shuffle_transfer failed", RuntimeError, "Barrier task"
        ):
            await c.compute(out)

    for worker in (a, b):
        await clean_worker(worker)
    await clean_scheduler(s)


def test_shuffle_before_categorize(loop_in_thread):
"""Regression test for https://github.com/dask/distributed/issues/7615"""
with cluster() as (s, [a, b]), Client(s["address"], loop=loop_in_thread) as c:
Expand Down

0 comments on commit 755f768

Please sign in to comment.