Skip to content

Commit

Permalink
Adjust extrema's internal structured type handling (#77)
Browse files Browse the repository at this point in the history
* Covert `extrema`'s component types to `dtype`s

Make sure that each component type in `extrema`'s structured `dtype` is
converted to a `dtype`.

* Use a 1-D singleton array in `extrema`'s kernel

Instead of using a 0-D array in `extrema`'s kernel function,
use a 1-D singleton array. As we have generally had better luck working
with 1-D singleton structured arrays with Dask, it makes sense to use
these consistently internally as well. Also avoids some hacky tricks to
extract the value from the array in favor of just selecting the value.

* Make `extrema`'s default with 1-D singleton array

Instead of using a 0-D array to construct `extrema`'s default value, use
a 1-D singleton array. As we have generally had better luck working with
1-D singleton structured arrays with Dask, it makes sense to use these
consistently internally as well. Also avoids some hacky tricks to
extract the value from the array in favor of just selecting the value.

* Reorder `extrema`'s internal structured `dtype`

Adjust the order of the `extrema`'s internal structured `dtype` to match
the order of the result returned.

* Make a mapping of `extrema`'s structured `dtype`

Define a mapping using an `OrderedDict` of the names and types within
the structured `dtype`. Use this to build the structured `dtype`. Also
use this to breakout the specific results after running `extrema`'s
kernel with `labeled_comprehension`.

* Bind `dtype` in `extrema` kernel function

Instead of reconstructing the `dtype` in `extrema`'s kernel function,
bind this in advance using `functools.partial` so that it can be passed
through to the `extrema` kernel function unchanged. Thus avoiding the
need to define this in two different places and potentially getting them
messed up.
  • Loading branch information
jakirkham committed Oct 1, 2018
1 parent 21230de commit c9fd2c4
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 21 deletions.
25 changes: 13 additions & 12 deletions dask_image/ndmeasure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
__email__ = "kirkhamj@janelia.hhmi.org"


import collections
import functools
import itertools
from warnings import warn
Expand Down Expand Up @@ -98,20 +99,25 @@ def extrema(input, labels=None, index=None):
input, labels, index
)

out_dtype = numpy.dtype([
type_mapping = collections.OrderedDict([
("min_val", input.dtype),
("min_pos", numpy.int),
("max_val", input.dtype),
("max_pos", numpy.int)
("min_pos", numpy.dtype(numpy.int)),
("max_pos", numpy.dtype(numpy.int))
])
default = numpy.zeros((), out_dtype)[()]
out_dtype = numpy.dtype(list(type_mapping.items()))

default_1d = numpy.zeros((1,), dtype=out_dtype)

func = functools.partial(_utils._extrema, dtype=out_dtype)
extrema_lbl = labeled_comprehension(
input, labels, index,
_utils._extrema, out_dtype, default, pass_positions=True
func, out_dtype, default_1d[0], pass_positions=True
)

extrema_lbl = {k: extrema_lbl[k] for k in extrema_lbl.dtype.names}
extrema_lbl = collections.OrderedDict([
(k, extrema_lbl[k]) for k in type_mapping.keys()
])

for pos_key in ["min_pos", "max_pos"]:
pos_1d = extrema_lbl[pos_key]
Expand All @@ -129,12 +135,7 @@ def extrema(input, labels=None, index=None):

extrema_lbl[pos_key] = pos_nd

result = (
extrema_lbl["min_val"],
extrema_lbl["max_val"],
extrema_lbl["min_pos"],
extrema_lbl["max_pos"]
)
result = tuple(extrema_lbl.values())

return result

Expand Down
12 changes: 3 additions & 9 deletions dask_image/ndmeasure/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,12 @@ def _argmin(a, positions):
return positions[numpy.argmin(a)]


def _extrema(a, positions):
def _extrema(a, positions, dtype):
"""
Find minimum and maximum as well as positions for both.
"""

dtype = numpy.dtype([
("min_val", a.dtype),
("min_pos", positions.dtype),
("max_val", a.dtype),
("max_pos", positions.dtype)
])
result = numpy.empty((), dtype=dtype)
result = numpy.empty((1,), dtype=dtype)

int_min_pos = numpy.argmin(a)
result["min_val"] = a[int_min_pos]
Expand All @@ -149,7 +143,7 @@ def _extrema(a, positions):
result["max_val"] = a[int_max_pos]
result["max_pos"] = positions[int_max_pos]

return result[()]
return result[0]


def _histogram(input,
Expand Down

0 comments on commit c9fd2c4

Please sign in to comment.