Skip to content

Commit

Permalink
Refactor out _unravel_index function (#60)
Browse files Browse the repository at this point in the history
* Add a function unravel indices in Dask Arrays

Provides a utility function in `ndmeasure` to handle unraveling a Dask
Array of 1-D indices into multiple dimensions. Uses a simple kernel
function to perform the unraveling on 1-D NumPy Array chunks and coerce
them into a format that Dask can handle. This kernel function is then
used in a Dask Array `map_blocks` call to handle conversion of the full
Dask Array of 1-D indices into N-D indices.

* Test the unravel indices utility function

Make sure that our unravel indices utility function behaves roughly like
NumPy's. They are not a 1-to-1 match as we prefer to return a single
Dask Array instead of a `tuple` of Dask Arrays, which is better suited
for our use case.

* Use the unravel indices utility function

Instead of handling the unraveling of a Dask Array of indices through
mathematical operations in Dask (which creates a more complex graph),
Simply use the utility function, which simply uses `map_blocks` on each
chunk to get the N-D indices simply. This makes it easier for Dask to
potentially fuse this task with other tasks. Also keeps the Dask graph
pretty clean.
  • Loading branch information
jakirkham committed Sep 7, 2018
1 parent 2490065 commit d439e77
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 14 deletions.
16 changes: 2 additions & 14 deletions dask_image/ndmeasure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,7 @@ def maximum_position(input, labels=None, index=None):
if not max_1dpos_lbl.ndim:
max_1dpos_lbl = max_1dpos_lbl[None]

max_pos_lbl = []
max_1dpos_lbl_rem = max_1dpos_lbl
for i in _pycompat.irange(input.ndim):
d = int(numpy.prod(input.shape[i + 1:]))
max_pos_lbl.append(max_1dpos_lbl_rem // d)
max_1dpos_lbl_rem %= d
max_pos_lbl = dask.array.stack(max_pos_lbl, axis=1)
max_pos_lbl = _utils._unravel_index(max_1dpos_lbl, input.shape)

if index.shape == tuple():
max_pos_lbl = dask.array.squeeze(max_pos_lbl)
Expand Down Expand Up @@ -537,13 +531,7 @@ def minimum_position(input, labels=None, index=None):
if not min_1dpos_lbl.ndim:
min_1dpos_lbl = min_1dpos_lbl[None]

min_pos_lbl = []
min_1dpos_lbl_rem = min_1dpos_lbl
for i in _pycompat.irange(input.ndim):
d = int(numpy.prod(input.shape[i + 1:]))
min_pos_lbl.append(min_1dpos_lbl_rem // d)
min_1dpos_lbl_rem %= d
min_pos_lbl = dask.array.stack(min_pos_lbl, axis=1)
min_pos_lbl = _utils._unravel_index(min_1dpos_lbl, input.shape)

if index.shape == tuple():
min_pos_lbl = dask.array.squeeze(min_pos_lbl)
Expand Down
27 changes: 27 additions & 0 deletions dask_image/ndmeasure/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,33 @@ def _ravel_shape_indices(dimensions, dtype=int, chunks=None):
return indices


def _unravel_index_kernel(indices, func_kwargs):
return numpy.stack(numpy.unravel_index(indices, **func_kwargs), axis=1)


def _unravel_index(indices, dims, order='C'):
"""
Unravels the indices like NumPy's ``unravel_index``.
Uses NumPy's ``unravel_index`` on Dask Array blocks.
"""

if dims and indices.size:
unraveled_indices = indices.map_blocks(
_unravel_index_kernel,
dtype=numpy.intp,
chunks=indices.chunks + ((len(dims),),),
new_axis=1,
func_kwargs={"dims": dims, "order": order}
)
else:
unraveled_indices = dask.array.empty(
(0, len(dims)), dtype=numpy.intp, chunks=1
)

return unraveled_indices


def _argmax(a, positions):
"""
Find original array position corresponding to the maximum.
Expand Down
22 changes: 22 additions & 0 deletions tests/test_dask_image/test_ndmeasure/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,25 @@ def test___ravel_shape_indices(shape, chunks):
)

dau.assert_eq(d, a)


@pytest.mark.parametrize(
"nindices, shape, order", [
(0, (15,), 'C'),
(1, (15,), 'C'),
(3, (15,), 'C'),
(3, (15,), 'F'),
(2, (15, 16), 'C'),
(2, (15, 16), 'F'),
]
)
def test__unravel_index(nindices, shape, order):
findices = np.random.randint(np.prod(shape, dtype=int), size=nindices)
d_findices = da.from_array(findices, chunks=1)

indices = np.stack(np.unravel_index(findices, shape, order), axis=1)
d_indices = dask_image.ndmeasure._utils._unravel_index(
d_findices, shape, order
)

dau.assert_eq(d_indices, indices)

0 comments on commit d439e77

Please sign in to comment.