Skip to content

Commit

Permalink
Ensure labeled_comprehension's default is 1D (#69)
Browse files Browse the repository at this point in the history
Create an `empty` 1-D scalar array of the expected return type of the
user function. Then store this `default` value into the 1-D scalar
array. This has the advantage of raising a `ValueError` if the `default`
value cannot be coerced into a scalar of the appropriate type and stored
into this array. It also handles some situations better like having a
1-D scalar array provided as the `default` value by ensuring the result
remains 1-D.
  • Loading branch information
jakirkham committed Sep 18, 2018
1 parent 4a79a70 commit c714522
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions dask_image/ndmeasure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,12 @@ def labeled_comprehension(input,
input, labels, index = _utils._norm_input_labels_index(
input, labels, index
)

out_dtype = numpy.dtype(out_dtype)
default = numpy.array([default], dtype=out_dtype)

default_1d = numpy.empty((1,), dtype=out_dtype)
default_1d[0] = default

pass_positions = bool(pass_positions)

lbl_mtch = _utils._get_label_matches(labels, index)
Expand All @@ -322,7 +326,7 @@ def labeled_comprehension(input,
lbl_mtch_i = lbl_mtch[i]
args_lbl_mtch_i = tuple(e[lbl_mtch_i] for e in args)
result[i] = _utils._labeled_comprehension_func(
func, out_dtype, default, *args_lbl_mtch_i
func, out_dtype, default_1d, *args_lbl_mtch_i
)

for i in _pycompat.irange(result.ndim - 1, -1, -1):
Expand Down

0 comments on commit c714522

Please sign in to comment.