Skip to content

Commit

Permalink
Test storing and keeping results after
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed Dec 9, 2017
1 parent 1dba2ff commit 84a0b38
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions dask/array/tests/test_array_core.py
Expand Up @@ -1192,6 +1192,7 @@ def make_target(key):
atd = delayed(make_target)('at')
btd = delayed(make_target)('bt')

# test not keeping result
st = store([a, b], [atd, btd])

at = targs['at']
Expand All @@ -1201,6 +1202,20 @@ def make_target(key):
assert_eq(at, a)
assert_eq(bt, b)

# test keeping result
st = store([a, b], [atd, btd], keep=True)

at = targs['at']
bt = targs['bt']

assert st is not None
assert isinstance(st, tuple)
assert all([isinstance(v, np.ndarray) for v in st])
assert_eq(at, a)
assert_eq(bt, b)
assert_eq(st[0], a)
assert_eq(st[1], b)

pytest.raises(ValueError, lambda: store([a], [at, bt]))
pytest.raises(ValueError, lambda: store(at, at))
pytest.raises(ValueError, lambda: store([at, bt], [at, bt]))
Expand Down Expand Up @@ -1240,6 +1255,29 @@ def test_store_regions():
assert not (bt == 3).all() and not ( bt == 0 ).all()
assert not (at == 2).all() and not ( at == 0 ).all()

# Single region (keep result):
at = np.zeros(shape=(8, 4, 6))
bt = np.zeros(shape=(8, 4, 6))
v = store([a, b], [at, bt], regions=region, compute=False, keep=True)
assert isinstance(v, tuple)
assert all([isinstance(e, da.Array) for e in v])
assert (at == 0).all() and (bt[region] == 0).all()

ar, br = v
assert ar.dtype == a.dtype
assert br.dtype == b.dtype
assert ar.shape == a.shape
assert br.shape == b.shape
assert ar.chunks == a.chunks
assert br.chunks == b.chunks

ar, br = da.compute(ar, br)
assert (at[region] == 2).all() and (bt[region] == 3).all()
assert not (bt == 3).all() and not ( bt == 0 ).all()
assert not (at == 2).all() and not ( at == 0 ).all()
assert (br == 3).all()
assert (ar == 2).all()

# Multiple regions:
at = np.zeros(shape=(8, 4, 6))
bt = np.zeros(shape=(8, 4, 6))
Expand All @@ -1251,6 +1289,29 @@ def test_store_regions():
assert not (bt == 3).all() and not ( bt == 0 ).all()
assert not (at == 2).all() and not ( at == 0 ).all()

# Multiple regions (keep result):
at = np.zeros(shape=(8, 4, 6))
bt = np.zeros(shape=(8, 4, 6))
v = store([a, b], [at, bt], regions=[region, region], compute=False, keep=True)
assert isinstance(v, tuple)
assert all([isinstance(e, da.Array) for e in v])
assert (at == 0).all() and (bt[region] == 0).all()

ar, br = v
assert ar.dtype == a.dtype
assert br.dtype == b.dtype
assert ar.shape == a.shape
assert br.shape == b.shape
assert ar.chunks == a.chunks
assert br.chunks == b.chunks

ar, br = da.compute(ar, br)
assert (at[region] == 2).all() and (bt[region] == 3).all()
assert not (bt == 3).all() and not ( bt == 0 ).all()
assert not (at == 2).all() and not ( at == 0 ).all()
assert (br == 3).all()
assert (ar == 2).all()


def test_store_compute_false():
d = da.ones((4, 4), chunks=(2, 2))
Expand Down

0 comments on commit 84a0b38

Please sign in to comment.