Skip to content

Commit

Permalink
Use custom kernel for extrema (#61)
Browse files Browse the repository at this point in the history
* Add extrema kernel

Implements the `_extrema` function to act as a kernel for the `extrema`
function. It effectively fuses all minimum, maximum, and position
computations into the same computational step. This should make `extrema`
significantly faster for Dask to compute, and the resulting graph should
be much simpler.

* Use `_extrema` kernel

Rewrite the `extrema` function to make use of the new `_extrema` kernel.
This results in a few changes and requires the `extrema` function to
take on some work that the other functions were doing for us, though the
performance improvements to `extrema` should be worth it.
  • Loading branch information
jakirkham committed Sep 17, 2018
1 parent ee4349b commit 5ad221c
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 11 deletions.
47 changes: 36 additions & 11 deletions dask_image/ndmeasure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,45 @@ def extrema(input, labels=None, index=None):
input, labels, index
)

min_lbl = minimum(
input, labels, index
)
max_lbl = maximum(
input, labels, index
out_dtype = numpy.dtype([
("min_val", input.dtype),
("min_pos", numpy.int),
("max_val", input.dtype),
("max_pos", numpy.int)
])
default = numpy.zeros((), out_dtype)[()]

extrema_lbl = labeled_comprehension(
input, labels, index,
_utils._extrema, out_dtype, default, pass_positions=True
)
min_pos_lbl = minimum_position(
input, labels, index
)
max_pos_lbl = maximum_position(
input, labels, index

extrema_lbl = {k: extrema_lbl[k] for k in extrema_lbl.dtype.names}

for pos_key in ["min_pos", "max_pos"]:
pos_1d = extrema_lbl[pos_key]
if not pos_1d.ndim:
pos_1d = pos_1d[None]

pos_nd = _utils._unravel_index(pos_1d, input.shape)

if index.ndim == 0:
pos_nd = dask.array.squeeze(pos_nd)
elif index.ndim > 1:
pos_nd = pos_nd.reshape(
(int(numpy.prod(pos_nd.shape[:-1])), pos_nd.shape[-1])
)

extrema_lbl[pos_key] = pos_nd

result = (
extrema_lbl["min_val"],
extrema_lbl["max_val"],
extrema_lbl["min_pos"],
extrema_lbl["max_pos"]
)

return min_lbl, max_lbl, min_pos_lbl, max_pos_lbl
return result


def histogram(input,
Expand Down
24 changes: 24 additions & 0 deletions dask_image/ndmeasure/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,30 @@ def _argmin(a, positions):
return positions[numpy.argmin(a)]


def _extrema(a, positions):
    """
    Compute the smallest and largest entries of ``a`` and where they sit.

    The element picked by ``numpy.argmin``/``numpy.argmax`` over ``a`` is
    looked up both in ``a`` (giving the value) and in ``positions`` (giving
    the matching position entry).

    Returns a structured scalar with the fields ``min_val``, ``min_pos``,
    ``max_val`` and ``max_pos``. Value fields use ``a.dtype``; position
    fields use ``positions.dtype``.
    """

    val_dtype = a.dtype
    pos_dtype = positions.dtype
    record = numpy.empty(
        (),
        dtype=numpy.dtype([
            ("min_val", val_dtype),
            ("min_pos", pos_dtype),
            ("max_val", val_dtype),
            ("max_pos", pos_dtype),
        ]),
    )

    # Every field is assigned below, so the uninitialised buffer from
    # ``numpy.empty`` never leaks into the returned record.
    lookups = (
        ("min_val", "min_pos", numpy.argmin),
        ("max_val", "max_pos", numpy.argmax),
    )
    for val_field, pos_field, arg_func in lookups:
        where = arg_func(a)
        record[val_field] = a[where]
        record[pos_field] = positions[where]

    # Index with an empty tuple to hand back a scalar, not a 0-d array.
    return record[()]


@dask.delayed
def _histogram(input,
min,
Expand Down

0 comments on commit 5ad221c

Please sign in to comment.