Skip to content

Commit

Permalink
Use 1-D structured array fields for position-based kernels in `ndmeas…
Browse files Browse the repository at this point in the history
…ure` (#79)

* Use N-D structured array field in `center_of_mass`

Instead of using multiple structured array fields to pass back the
different coordinate components of the center of mass, use one 1-D field
with the size being the number of components. This simplifies the array
handling a bit and avoids the need to split apart and restack the array.

* Use N-D structured array in `maximum_position`

Rewrite `maximum_position` to use a structured `dtype` with a 1-D field.
This allows the unraveled position to be determined in the
`maximum_position` kernel function and returned. With this change there
is no need to handle unraveling of the position in the Dask graph. Only
the structured field need be selected, which is a constant-time
operation. This simplifies the Dask graph generated by
`maximum_position` and improves computation performance as well.

* Use N-D structured array in `minimum_position`

Rewrite `minimum_position` to use a structured `dtype` with a 1-D
field.  This allows the unraveled position to be determined in the
`minimum_position` kernel function and returned. With this change there
is no need to handle unraveling of the position in the Dask graph. Only
the structured field need be selected, which is a constant-time
operation. This simplifies the Dask graph generated by
`minimum_position` and improves computation performance as well.

* Use N-D structured array fields in `extrema`

Make use of N-D structured array fields in `extrema`'s kernel function
to pass back N-D coordinate positions. This avoids having an explicit
unraveling step added to the Dask graph after `extrema`'s kernel
function.

* Drop `_unravel_index` from `ndmeasure._utils`

This function is no longer needed, as unraveling of indices now occurs
within the kernel functions that return positions. Hence it is dropped
along with its tests.
  • Loading branch information
jakirkham committed Oct 1, 2018
1 parent f9aea09 commit b470898
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 96 deletions.
65 changes: 31 additions & 34 deletions dask_image/ndmeasure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ def center_of_mass(input, labels=None, index=None):
# This only matters if index is some array.
index = index.T

type_mapping = collections.OrderedDict([
(("%i" % i), input.dtype) for i in _pycompat.irange(input.ndim)
])
out_dtype = numpy.dtype(list(type_mapping.items()))

out_dtype = numpy.dtype([("com", input.dtype, (input.ndim,))])
default_1d = numpy.full((1,), numpy.nan, dtype=out_dtype)

func = functools.partial(
Expand All @@ -64,8 +60,7 @@ def center_of_mass(input, labels=None, index=None):
input, labels, index,
func, out_dtype, default_1d[0], pass_positions=True
)

com_lbl = dask.array.stack([com_lbl[k] for k in type_mapping], axis=-1)
com_lbl = com_lbl["com"]

return com_lbl

Expand Down Expand Up @@ -96,32 +91,28 @@ def extrema(input, labels=None, index=None):
input, labels, index
)

type_mapping = collections.OrderedDict([
out_dtype = numpy.dtype([
("min_val", input.dtype),
("max_val", input.dtype),
("min_pos", numpy.dtype(numpy.int)),
("max_pos", numpy.dtype(numpy.int))
("min_pos", numpy.dtype(numpy.int), input.ndim),
("max_pos", numpy.dtype(numpy.int), input.ndim)
])
out_dtype = numpy.dtype(list(type_mapping.items()))

default_1d = numpy.zeros((1,), dtype=out_dtype)

func = functools.partial(_utils._extrema, dtype=out_dtype)
func = functools.partial(
_utils._extrema, shape=input.shape, dtype=out_dtype
)
extrema_lbl = labeled_comprehension(
input, labels, index,
func, out_dtype, default_1d[0], pass_positions=True
)

extrema_lbl = collections.OrderedDict([
(k, extrema_lbl[k]) for k in type_mapping.keys()
(k, extrema_lbl[k])
for k in ["min_val", "max_val", "min_pos", "max_pos"]
])

for pos_key in ["min_pos", "max_pos"]:
pos_1d = extrema_lbl[pos_key]
if not pos_1d.ndim:
pos_1d = pos_1d[None]

pos_nd = _utils._unravel_index(pos_1d, input.shape)
pos_nd = extrema_lbl[pos_key]

if index.ndim == 0:
pos_nd = dask.array.squeeze(pos_nd)
Expand Down Expand Up @@ -396,14 +387,17 @@ def maximum_position(input, labels=None, index=None):
if index.shape:
index = index.flatten()

max_1dpos_lbl = labeled_comprehension(
input, labels, index, _utils._argmax, int, 0, pass_positions=True
)

if not max_1dpos_lbl.ndim:
max_1dpos_lbl = max_1dpos_lbl[None]
out_dtype = numpy.dtype([("pos", int, (input.ndim,))])
default_1d = numpy.zeros((1,), dtype=out_dtype)

max_pos_lbl = _utils._unravel_index(max_1dpos_lbl, input.shape)
func = functools.partial(
_utils._argmax, shape=input.shape, dtype=out_dtype
)
max_pos_lbl = labeled_comprehension(
input, labels, index,
func, out_dtype, default_1d[0], pass_positions=True
)
max_pos_lbl = max_pos_lbl["pos"]

if index.shape == tuple():
max_pos_lbl = dask.array.squeeze(max_pos_lbl)
Expand Down Expand Up @@ -542,14 +536,17 @@ def minimum_position(input, labels=None, index=None):
if index.shape:
index = index.flatten()

min_1dpos_lbl = labeled_comprehension(
input, labels, index, _utils._argmin, int, 0, pass_positions=True
)

if not min_1dpos_lbl.ndim:
min_1dpos_lbl = min_1dpos_lbl[None]
out_dtype = numpy.dtype([("pos", int, (input.ndim,))])
default_1d = numpy.zeros((1,), dtype=out_dtype)

min_pos_lbl = _utils._unravel_index(min_1dpos_lbl, input.shape)
func = functools.partial(
_utils._argmin, shape=input.shape, dtype=out_dtype
)
min_pos_lbl = labeled_comprehension(
input, labels, index,
func, out_dtype, default_1d[0], pass_positions=True
)
min_pos_lbl = min_pos_lbl["pos"]

if index.shape == tuple():
min_pos_lbl = dask.array.squeeze(min_pos_lbl)
Expand Down
65 changes: 25 additions & 40 deletions dask_image/ndmeasure/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,56 +85,37 @@ def _ravel_shape_indices(dimensions, dtype=int, chunks=None):
return indices


def _unravel_index_kernel(indices, func_kwargs):
nd_indices = numpy.unravel_index(indices, **func_kwargs)
nd_indices = numpy.stack(nd_indices, axis=indices.ndim)
return nd_indices


def _unravel_index(indices, dims, order='C'):
def _argmax(a, positions, shape, dtype):
"""
Unravels the indices like NumPy's ``unravel_index``.
Uses NumPy's ``unravel_index`` on Dask Array blocks.
Find original array position corresponding to the maximum.
"""

if dims and indices.size:
unraveled_indices = indices.map_blocks(
_unravel_index_kernel,
dtype=numpy.intp,
chunks=indices.chunks + ((len(dims),),),
new_axis=indices.ndim,
func_kwargs={"dims": dims, "order": order}
)
else:
unraveled_indices = dask.array.empty(
(0, len(dims)), dtype=numpy.intp, chunks=1
)

return unraveled_indices

result = numpy.empty((1,), dtype=dtype)

def _argmax(a, positions):
"""
Find original array position corresponding to the maximum.
"""
pos_nd = numpy.unravel_index(positions[numpy.argmax(a)], shape)
for i, pos_nd_i in enumerate(pos_nd):
result["pos"][0, i] = pos_nd_i

return positions[numpy.argmax(a)]
return result[0]


def _argmin(a, positions):
def _argmin(a, positions, shape, dtype):
"""
Find original array position corresponding to the minimum.
"""

return positions[numpy.argmin(a)]
result = numpy.empty((1,), dtype=dtype)

pos_nd = numpy.unravel_index(positions[numpy.argmin(a)], shape)
for i, pos_nd_i in enumerate(pos_nd):
result["pos"][0, i] = pos_nd_i

return result[0]


def _center_of_mass(a, positions, shape, dtype):
"""
Find the center of mass for each ROI.
Package the result in a structured array with each field as an index.
"""

result = numpy.empty((1,), dtype=dtype)
Expand All @@ -145,25 +126,29 @@ def _center_of_mass(a, positions, shape, dtype):
a_wt_i = numpy.empty_like(a)
for i, pos_nd_i in enumerate(positions_nd):
a_wt_sum_i = numpy.multiply(a, pos_nd_i, out=a_wt_i).sum()
result[("%i" % i)] = a_wt_sum_i / a_sum
result["com"][0, i] = a_wt_sum_i / a_sum

return result[0]


def _extrema(a, positions, dtype):
def _extrema(a, positions, shape, dtype):
"""
Find minimum and maximum as well as positions for both.
"""

result = numpy.empty((1,), dtype=dtype)

int_min_pos = numpy.argmin(a)
result["min_val"] = a[int_min_pos]
result["min_pos"] = positions[int_min_pos]

int_max_pos = numpy.argmax(a)

result["min_val"] = a[int_min_pos]
result["max_val"] = a[int_max_pos]
result["max_pos"] = positions[int_max_pos]

min_pos_nd = numpy.unravel_index(positions[int_min_pos], shape)
max_pos_nd = numpy.unravel_index(positions[int_max_pos], shape)
for i in range(len(shape)):
result["min_pos"][0, i] = min_pos_nd[i]
result["max_pos"][0, i] = max_pos_nd[i]

return result[0]

Expand Down
22 changes: 0 additions & 22 deletions tests/test_dask_image/test_ndmeasure/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,25 +152,3 @@ def test___ravel_shape_indices(shape, chunks):
)

dau.assert_eq(d, a)


@pytest.mark.parametrize(
"nindices, shape, order", [
(0, (15,), 'C'),
(1, (15,), 'C'),
(3, (15,), 'C'),
(3, (15,), 'F'),
(2, (15, 16), 'C'),
(2, (15, 16), 'F'),
]
)
def test__unravel_index(nindices, shape, order):
findices = np.random.randint(np.prod(shape, dtype=int), size=nindices)
d_findices = da.from_array(findices, chunks=1)

indices = np.stack(np.unravel_index(findices, shape, order), axis=1)
d_indices = dask_image.ndmeasure._utils._unravel_index(
d_findices, shape, order
)

dau.assert_eq(d_indices, indices)

0 comments on commit b470898

Please sign in to comment.