Skip to content

Commit

Permalink
BUG: Fix median and quantile NaT handling
Browse files Browse the repository at this point in the history
Note that this doesn't mean that rounding is correct at least for
quantiles, so there is some dubious about it being a good idea to
use this.

But, it does fix the issue, and I the `copyto` solution seems rather
good to me, the only thing that isn't ideal is the `supports_nan`
definition itself.

Closes gh-20376
  • Loading branch information
seberg committed May 17, 2023
1 parent 0200e4a commit 6ac4d6d
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 27 deletions.
41 changes: 22 additions & 19 deletions numpy/lib/function_base.py
Expand Up @@ -3943,8 +3943,10 @@ def _median(a, axis=None, out=None, overwrite_input=False):
kth = [szh - 1, szh]
else:
kth = [(sz - 1) // 2]
# Check if the array contains any nan's
if np.issubdtype(a.dtype, np.inexact):

# We have to check for NaNs (as of writing 'M' doesn't actually work).
supports_nans = np.issubdtype(a.dtype, np.inexact) or a.dtype.kind in 'Mm'
if supports_nans:
kth.append(-1)

if overwrite_input:
Expand Down Expand Up @@ -3975,8 +3977,7 @@ def _median(a, axis=None, out=None, overwrite_input=False):
# Use mean in both odd and even case to coerce data type,
# using out array if needed.
rout = mean(part[indexer], axis=axis, out=out)
# Check if the array contains any nan's
if np.issubdtype(a.dtype, np.inexact) and sz > 0:
if supports_nans and sz > 0:
# If nans are possible, warn and replace by nans like mean would.
rout = np.lib.utils._median_nancheck(part, rout, axis)

Expand Down Expand Up @@ -4784,9 +4785,9 @@ def _quantile(
values_count = arr.shape[axis]
# The dimensions of `q` are prepended to the output shape, so we need the
# axis being sampled from `arr` to be last.
DATA_AXIS = 0
if axis != DATA_AXIS: # But moveaxis is slow, so only call it if axis!=0.
arr = np.moveaxis(arr, axis, destination=DATA_AXIS)

if axis != 0: # But moveaxis is slow, so only call it if necessary.
arr = np.moveaxis(arr, axis, destination=0)
# --- Computation of indexes
# Index where to find the value in the sorted array.
# Virtual because it is a floating point value, not an valid index.
Expand All @@ -4799,12 +4800,16 @@ def _quantile(
f"{_QuantileMethods.keys()}") from None
virtual_indexes = method["get_virtual_index"](values_count, quantiles)
virtual_indexes = np.asanyarray(virtual_indexes)

supports_nans = (
np.issubdtype(arr.dtype, np.inexact) or arr.dtype.kind in 'Mm')

if np.issubdtype(virtual_indexes.dtype, np.integer):
# No interpolation needed, take the points along axis
if np.issubdtype(arr.dtype, np.inexact):
if supports_nans:
# may contain nan, which would sort to the end
arr.partition(concatenate((virtual_indexes.ravel(), [-1])), axis=0)
slices_having_nans = np.isnan(arr[-1])
slices_having_nans = np.isnan(arr[-1, ...])
else:
# cannot contain nan
arr.partition(virtual_indexes.ravel(), axis=0)
Expand All @@ -4820,16 +4825,14 @@ def _quantile(
previous_indexes.ravel(),
next_indexes.ravel(),
))),
axis=DATA_AXIS)
if np.issubdtype(arr.dtype, np.inexact):
slices_having_nans = np.isnan(
take(arr, indices=-1, axis=DATA_AXIS)
)
axis=0)
if supports_nans:
slices_having_nans = np.isnan(arr[-1, ...])
else:
slices_having_nans = None
# --- Get values from indexes
previous = np.take(arr, previous_indexes, axis=DATA_AXIS)
next = np.take(arr, next_indexes, axis=DATA_AXIS)
previous = arr[previous_indexes]
next = arr[next_indexes]
# --- Linear interpolation
gamma = _get_gamma(virtual_indexes, previous_indexes, method)
result_shape = virtual_indexes.shape + (1,) * (arr.ndim - 1)
Expand All @@ -4840,10 +4843,10 @@ def _quantile(
out=out)
if np.any(slices_having_nans):
if result.ndim == 0 and out is None:
# can't write to a scalar
result = arr.dtype.type(np.nan)
# can't write to a scalar, but indexing will be correct
result = arr[-1]
else:
result[..., slices_having_nans] = np.nan
np.copyto(result, arr[-1, ...], where=slices_having_nans)
return result


Expand Down
38 changes: 38 additions & 0 deletions numpy/lib/tests/test_function_base.py
Expand Up @@ -3537,6 +3537,25 @@ def test_nan_q(self):
with pytest.raises(ValueError, match="Percentiles must be in"):
np.percentile([1, 2, 3, 4.0], q)

@pytest.mark.parametrize("dtype", ["m8[D]", "M8[s]"])
@pytest.mark.parametrize("pos", [0, 23, 10])
def test_nat_basic(self, dtype, pos):
# TODO: Note that times have dubious rounding as of fixing NaTs!
# NaT and NaN should behave the same, do basic tests for NaT:
a = np.arange(0, 24, dtype=dtype)
a[pos] = "NaT"
res = np.percentile(a, 30)
assert res.dtype == dtype
assert np.isnat(res)
res = np.percentile(a, [30, 60])
assert res.dtype == dtype
assert np.isnat(res).all()

a = np.arange(0, 24*3, dtype=dtype).reshape(-1, 3)
a[pos, 1] = "NaT"
res = np.percentile(a, 30, axis=0)
assert_array_equal(np.isnat(res), [False, True, False])


quantile_methods = [
'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
Expand Down Expand Up @@ -4072,6 +4091,25 @@ def test_keepdims_out(self, axis):
assert result is out
assert_equal(result.shape, shape_out)

@pytest.mark.parametrize("dtype", ["m8[s]"])
@pytest.mark.parametrize("pos", [0, 23, 10])
def test_nat_behavior(self, dtype, pos):
# TODO: Median does not support Datetime, due to `mean`.
# NaT and NaN should behave the same, do basic tests for NaT.
a = np.arange(0, 24, dtype=dtype)
a[pos] = "NaT"
res = np.median(a)
assert res.dtype == dtype
assert np.isnat(res)
res = np.percentile(a, [30, 60])
assert res.dtype == dtype
assert np.isnat(res).all()

a = np.arange(0, 24*3, dtype=dtype).reshape(-1, 3)
a[pos, 1] = "NaT"
res = np.median(a, axis=0)
assert_array_equal(np.isnat(res), [False, True, False])


class TestAdd_newdoc_ufunc:

Expand Down
22 changes: 14 additions & 8 deletions numpy/lib/utils.py
Expand Up @@ -1101,17 +1101,23 @@ def _median_nancheck(data, result, axis):
"""
if data.size == 0:
return result
n = np.isnan(data.take(-1, axis=axis))
# masked NaN values are ok
potential_nans = data.take(-1, axis=axis)
n = np.isnan(potential_nans)
# masked NaN values are ok, although for masked the copyto may fail for
# unmasked ones (this was always broken) when the result is a scalar.
if np.ma.isMaskedArray(n):
n = n.filled(False)
if np.count_nonzero(n.ravel()) > 0:
# Without given output, it is possible that the current result is a
# numpy scalar, which is not writeable. If so, just return nan.
if isinstance(result, np.generic):
return data.dtype.type(np.nan)

result[n] = np.nan
if not n.any():
return result

# Without given output, it is possible that the current result is a
# numpy scalar, which is not writeable. If so, just return nan.
if isinstance(result, np.generic):
return potential_nans

# Otherwise copy NaNs (if there are any)
np.copyto(result, potential_nans, where=n)
return result

def _opt_info():
Expand Down

0 comments on commit 6ac4d6d

Please sign in to comment.