Skip to content

Commit

Permalink
Update the use of fixtures in distributed tests
Browse files Browse the repository at this point in the history
Our CI was failing with the 2.6 release.  This fixes a bug there, and
also modernizes how we test to be more in line with the dask.distributed
library.

Also, the pre-commit commit for black was giving me trouble.  I've
downgraded this to the commit that we use in distributed, which seems to
work well.
  • Loading branch information
mrocklin committed Oct 16, 2019
1 parent dc70324 commit 1aae769
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 92 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/python/black
rev: 73bd7038fbefdb1c6a61fa1edf16ff61613050a5
rev: cad4138050b86d1c8570b926883e32f7465c2880
hooks:
- id: black
language_version: python3.7
Expand Down
168 changes: 77 additions & 91 deletions dask/tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,15 @@
from dask.delayed import Delayed
from dask.utils import tmpdir, get_named_args
from distributed import futures_of
from distributed.client import wait, Client
from distributed.utils_test import gen_cluster, inc, cluster, loop # noqa F401
from distributed.client import wait
from distributed.utils_test import ( # noqa F401
gen_cluster,
inc,
cluster,
cluster_fixture,
loop,
client as c,
)


if "should_check_state" in get_named_args(gen_cluster):
Expand All @@ -41,65 +48,59 @@ def test_persist(c, s, a, b):
assert y2.key in a.data or y2.key in b.data


def test_persist_nested(loop):
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop):
a = delayed(1) + 5
b = a + 1
c = a + 2
result = persist({"a": a, "b": [1, 2, b]}, (c, 2), 4, [5])
assert isinstance(result[0]["a"], Delayed)
assert isinstance(result[0]["b"][2], Delayed)
assert isinstance(result[1][0], Delayed)
def test_persist_nested(c):
a = delayed(1) + 5
b = a + 1
c = a + 2
result = persist({"a": a, "b": [1, 2, b]}, (c, 2), 4, [5])
assert isinstance(result[0]["a"], Delayed)
assert isinstance(result[0]["b"][2], Delayed)
assert isinstance(result[1][0], Delayed)

sol = ({"a": 6, "b": [1, 2, 7]}, (8, 2), 4, [5])
assert compute(*result) == sol
sol = ({"a": 6, "b": [1, 2, 7]}, (8, 2), 4, [5])
assert compute(*result) == sol

res = persist([a, b], c, 4, [5], traverse=False)
assert res[0][0] is a
assert res[0][1] is b
assert res[1].compute() == 8
assert res[2:] == (4, [5])
res = persist([a, b], c, 4, [5], traverse=False)
assert res[0][0] is a
assert res[0][1] is b
assert res[1].compute() == 8
assert res[2:] == (4, [5])


def test_futures_to_delayed_dataframe(loop):
def test_futures_to_delayed_dataframe(c):
pd = pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
df = pd.DataFrame({"x": [1, 2, 3]})
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as c:
futures = c.scatter([df, df])
ddf = dd.from_delayed(futures)
dd.utils.assert_eq(ddf.compute(), pd.concat([df, df], axis=0))

with pytest.raises(TypeError):
ddf = dd.from_delayed([1, 2])
futures = c.scatter([df, df])
ddf = dd.from_delayed(futures)
dd.utils.assert_eq(ddf.compute(), pd.concat([df, df], axis=0))

with pytest.raises(TypeError):
ddf = dd.from_delayed([1, 2])


def test_futures_to_delayed_bag(loop):
def test_futures_to_delayed_bag(c):
db = pytest.importorskip("dask.bag")
L = [1, 2, 3]
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as c:
futures = c.scatter([L, L])
b = db.from_delayed(futures)
assert list(b) == L + L

futures = c.scatter([L, L])
b = db.from_delayed(futures)
assert list(b) == L + L


def test_futures_to_delayed_array(loop):
def test_futures_to_delayed_array(c):
da = pytest.importorskip("dask.array")
from dask.array.utils import assert_eq

np = pytest.importorskip("numpy")
x = np.arange(5)
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as c:
futures = c.scatter([x, x])
A = da.concatenate(
[da.from_delayed(f, shape=x.shape, dtype=x.dtype) for f in futures],
axis=0,
)
assert_eq(A.compute(), np.concatenate([x, x], axis=0))

futures = c.scatter([x, x])
A = da.concatenate(
[da.from_delayed(f, shape=x.shape, dtype=x.dtype) for f in futures], axis=0
)
assert_eq(A.compute(), np.concatenate([x, x], axis=0))


@gen_cluster(client=True)
Expand All @@ -114,12 +115,10 @@ def test_local_get_with_distributed_active(c, s, a, b):
assert not s.tasks # scheduler hasn't done anything


def test_to_hdf_distributed(loop):
def test_to_hdf_distributed(c):
from ..dataframe.io.tests.test_hdf import test_to_hdf

with cluster() as (s, [a, b]):
with distributed.Client(s["address"], loop=loop):
test_to_hdf()
test_to_hdf()


@pytest.mark.parametrize(
Expand All @@ -136,12 +135,10 @@ def test_to_hdf_distributed(loop):
),
],
)
def test_to_hdf_scheduler_distributed(npartitions, loop):
def test_to_hdf_scheduler_distributed(npartitions, c):
from ..dataframe.io.tests.test_hdf import test_to_hdf_schedulers

with cluster() as (s, [a, b]):
with distributed.Client(s["address"], loop=loop):
test_to_hdf_schedulers(None, npartitions)
test_to_hdf_schedulers(None, npartitions)


@gen_cluster(client=True)
Expand All @@ -156,58 +153,47 @@ def test_serializable_groupby_agg(c, s, a, b):
yield c.compute(result)


def test_futures_in_graph(loop):
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as c:
x, y = delayed(1), delayed(2)
xx = delayed(add)(x, x)
yy = delayed(add)(y, y)
xxyy = delayed(add)(xx, yy)
def test_futures_in_graph(c):
x, y = delayed(1), delayed(2)
xx = delayed(add)(x, x)
yy = delayed(add)(y, y)
xxyy = delayed(add)(xx, yy)

xxyy2 = c.persist(xxyy)
xxyy3 = delayed(add)(xxyy2, 10)
xxyy2 = c.persist(xxyy)
xxyy3 = delayed(add)(xxyy2, 10)

assert (
xxyy3.compute(scheduler="dask.distributed") == ((1 + 1) + (2 + 2)) + 10
)
assert xxyy3.compute(scheduler="dask.distributed") == ((1 + 1) + (2 + 2)) + 10


def test_zarr_distributed_roundtrip(loop):
def test_zarr_distributed_roundtrip():
da = pytest.importorskip("dask.array")
pytest.importorskip("zarr")
assert_eq = da.utils.assert_eq
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop):
with tmpdir() as d:
a = da.zeros((3, 3), chunks=(1, 1))
a.to_zarr(d)
a2 = da.from_zarr(d)
assert_eq(a, a2)
assert a2.chunks == a.chunks

with tmpdir() as d:
a = da.zeros((3, 3), chunks=(1, 1))
a.to_zarr(d)
a2 = da.from_zarr(d)
assert_eq(a, a2)
assert a2.chunks == a.chunks


def test_zarr_in_memory_distributed_err(loop):
def test_zarr_in_memory_distributed_err(c):
da = pytest.importorskip("dask.array")
zarr = pytest.importorskip("zarr")
with cluster() as (s, [a, b]):
with Client(
s["address"], loop=loop, client_kwargs={"set_as_default": True}
) as c:
with pytest.raises(RuntimeError):
c = (1, 1)
a = da.ones((3, 3), chunks=c)
z = zarr.zeros_like(a, chunks=c)
a.to_zarr(z)


def test_scheduler_equals_client(loop):
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as client:
x = delayed(lambda: 1)()
assert x.compute(scheduler=client) == 1
assert client.run_on_scheduler(
lambda dask_scheduler: dask_scheduler.story(x.key)
)

c = (1, 1)
a = da.ones((3, 3), chunks=c)
z = zarr.zeros_like(a, chunks=c)

with pytest.raises(RuntimeError):
a.to_zarr(z)


def test_scheduler_equals_client(c):
x = delayed(lambda: 1)()
assert x.compute(scheduler=c) == 1
assert c.run_on_scheduler(lambda dask_scheduler: dask_scheduler.story(x.key))


@gen_cluster(client=True)
Expand Down

0 comments on commit 1aae769

Please sign in to comment.