Skip to content

Commit

Permalink
Test labeled_comprehension with object type (#76)
Browse files Browse the repository at this point in the history
Includes a test using `labeled_comprehension` with functions that return
`object` type results. While this is also being used by `histogram`
currently, it makes sense to have a test of `labeled_comprehension`
for this case as well in the event that `histogram`'s implementation
changes. The ability to return `object` type results is very handy for
using a variety of functions with `labeled_comprehension`. So it is good
to be sure that `labeled_comprehension` retains the behavior regardless
of other changes.
  • Loading branch information
jakirkham committed Sep 30, 2018
1 parent b7ca885 commit 21230de
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions tests/test_dask_image/test_ndmeasure/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,59 @@ def func_max_argmax(val, pos):
assert e_a_r.dtype == e_d_r.dtype
assert e_a_r.shape == e_d_r.shape
assert np.allclose(np.array(e_a_r), np.array(e_d_r), equal_nan=True)


@pytest.mark.parametrize(
"shape, chunks, ind", [
((15, 16), (4, 5), None),
((5, 6, 4), (2, 3, 2), None),
((15, 16), (4, 5), 0),
((15, 16), (4, 5), 1),
((15, 16), (4, 5), [1]),
((15, 16), (4, 5), [1, 2]),
((5, 6, 4), (2, 3, 2), [1, 2]),
((15, 16), (4, 5), [1, 100]),
((5, 6, 4), (2, 3, 2), [1, 100]),
]
)
def test_labeled_comprehension_object(shape, chunks, ind):
a = np.random.random(shape)
d = da.from_array(a, chunks=chunks)

lbls = np.zeros(a.shape, dtype=np.int64)
lbls += (
(a < 0.5).astype(lbls.dtype) +
(a < 0.25).astype(lbls.dtype) +
(a < 0.125).astype(lbls.dtype) +
(a < 0.0625).astype(lbls.dtype)
)
d_lbls = da.from_array(lbls, chunks=d.chunks)

dtype = np.dtype(object)

default = None

def func_min_max(val):
return np.array([np.min(val), np.max(val)])

a_r = spnd.labeled_comprehension(
a, lbls, ind, func_min_max, dtype, default, False
)

d_r = dask_image.ndmeasure.labeled_comprehension(
d, d_lbls, ind, func_min_max, dtype, default, False
)

if ind is None or np.isscalar(ind):
if a_r is None:
assert d_r.compute() is None
else:
np.allclose(a_r, d_r.compute(), equal_nan=True)
else:
assert a_r.dtype == d_r.dtype
assert a_r.shape == d_r.shape
for i in it.product(*[irange(_) for _ in a_r.shape]):
if a_r[i] is None:
assert d_r[i].compute() is None
else:
assert np.allclose(a_r[i], d_r[i].compute(), equal_nan=True)

0 comments on commit 21230de

Please sign in to comment.