Skip to content

Commit

Permalink
Handle structured dtype in labeled_comprehension (#66)
Browse files Browse the repository at this point in the history
* Type cast record values in `labeled_comprehension`

Make sure that type casting within `labeled_comprehension` and
associated utility functions handles record (a.k.a. structured) values
in addition to typical scalar types.

* Scalar arrays in _labeled_comprehension_delayed

Instead of squeezing NumPy arrays into scalars when returning results
from `_labeled_comprehension_delayed`, actually just return NumPy 0-D
scalar arrays. These behave a little better in some cases (e.g. when a
function returns a scalar array).

* Use 1-D singleton array in `labeled_comprehension`

To better handle structured type values that might be returned from the
custom user function, coerce all results returned by
`labeled_comprehension` into 1-D singleton arrays. As Dask Arrays seem
to do a better job of handling structured types when they have some
shape to them, this addresses returning structured types.

* Adjust `labeled_comprehension` result handling

Preallocate an array to store the result in
`_labeled_comprehension_delayed`. Makes sure that writing into the
`default` value doesn't happen by accident. Also makes sure that only
scalar values are handled from custom user functions. As not enforcing
these constraints may result in the computation breaking later in more
confusing ways, explicit handling of them in
`_labeled_comprehension_delayed` should point out to the user where they
may have gone wrong.

* Test `labeled_comprehension` with structured types

Include a test using `labeled_comprehension` with structured dtypes.
  • Loading branch information
jakirkham committed Sep 17, 2018
1 parent 3dcf9f3 commit ee4349b
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 6 deletions.
4 changes: 2 additions & 2 deletions dask_image/ndmeasure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def labeled_comprehension(input,
input, labels, index
)
out_dtype = numpy.dtype(out_dtype)
default = out_dtype.type(default)
default = numpy.array([default], dtype=out_dtype)
pass_positions = bool(pass_positions)

lbl_mtch = _utils._get_label_matches(labels, index)
Expand Down Expand Up @@ -305,7 +305,7 @@ def labeled_comprehension(input,
for j in index_ranges_i:
result2[j] = dask.array.stack(result[j].tolist(), axis=0)
result = result2
result = result[()]
result = result[()][..., 0]

return result

Expand Down
12 changes: 8 additions & 4 deletions dask_image/ndmeasure/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,17 @@ def _labeled_comprehension_delayed(func,
computation should not occur.
"""

result = numpy.empty((1,), dtype=out_dtype)

if a.size:
if positions is None:
return out_dtype.type(func(a))
result[0] = func(a)
else:
return out_dtype.type(func(a, positions))
result[0] = func(a, positions)
else:
return default
result[0] = default[0]

return result


def _labeled_comprehension_func(func,
Expand All @@ -166,6 +170,6 @@ def _labeled_comprehension_func(func,

return dask.array.from_delayed(
_labeled_comprehension_delayed(func, out_dtype, default, a, positions),
tuple(),
(1,),
out_dtype
)
67 changes: 67 additions & 0 deletions tests/test_dask_image/test_ndmeasure/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,70 @@ def func(val, pos=None):
assert a_cm.dtype == d_cm.dtype
assert a_cm.shape == d_cm.shape
assert np.allclose(np.array(a_cm), np.array(d_cm), equal_nan=True)


@pytest.mark.parametrize(
"shape, chunks, ind", [
((15, 16), (4, 5), None),
((5, 6, 4), (2, 3, 2), None),
((15, 16), (4, 5), 0),
((15, 16), (4, 5), 1),
((15, 16), (4, 5), [1]),
((15, 16), (4, 5), [1, 2]),
((5, 6, 4), (2, 3, 2), [1, 2]),
((15, 16), (4, 5), [1, 100]),
((5, 6, 4), (2, 3, 2), [1, 100]),
]
)
def test_labeled_comprehension_struct(shape, chunks, ind):
a = np.random.random(shape)
d = da.from_array(a, chunks=chunks)

lbls = np.zeros(a.shape, dtype=np.int64)
lbls += (
(a < 0.5).astype(lbls.dtype) +
(a < 0.25).astype(lbls.dtype) +
(a < 0.125).astype(lbls.dtype) +
(a < 0.0625).astype(lbls.dtype)
)
d_lbls = da.from_array(lbls, chunks=d.chunks)

dtype = np.dtype([("val", np.float64), ("pos", np.int)])

default = np.array((np.nan, -1), dtype=dtype)

def func_max(val):
return val[np.argmax(val)]

def func_argmax(val, pos):
return pos[np.argmax(val)]

def func_max_argmax(val, pos):
result = np.empty((), dtype=dtype)

i = np.argmax(val)

result["val"] = val[i]
result["pos"] = pos[i]

return result[()]

a_max = spnd.labeled_comprehension(
a, lbls, ind, func_max, dtype["val"], default["val"], False
)
a_argmax = spnd.labeled_comprehension(
a, lbls, ind, func_argmax, dtype["pos"], default["pos"], True
)

d_max_argmax = dask_image.ndmeasure.labeled_comprehension(
d, d_lbls, ind, func_max_argmax, dtype, default, True
)
d_max = d_max_argmax["val"]
d_argmax = d_max_argmax["pos"]

assert dtype == d_max_argmax.dtype

for e_a_r, e_d_r in zip([a_max, a_argmax], [d_max, d_argmax]):
assert e_a_r.dtype == e_d_r.dtype
assert e_a_r.shape == e_d_r.shape
assert np.allclose(np.array(e_a_r), np.array(e_d_r), equal_nan=True)

0 comments on commit ee4349b

Please sign in to comment.