Skip to content

Commit

Permalink
Merge pull request #21267 from jakevdp:quantile-warning
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 634556534
  • Loading branch information
jax authors committed May 16, 2024
2 parents 9dd98dc + bfbde5e commit bc19f7f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 36 deletions.
73 changes: 41 additions & 32 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
_broadcast_to, check_arraylike, _complex_elem_type,
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements)
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg
from jax._src.util import (
canonicalize_axis as _canonicalize_axis, maybe_named_axis,
NumpyComplexWarning)
Expand Down Expand Up @@ -755,43 +755,45 @@ def cumulative_sum(
return out

# Quantiles

# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@implements(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
check_arraylike("quantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.quantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
if interpolation is not None:
if not isinstance(interpolation, DeprecatedArg):
warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, False)
"Use 'method=' instead.", DeprecationWarning, stacklevel=2)
method = interpolation
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False)

# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@implements(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
check_arraylike("nanquantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
if interpolation is not None:
if not isinstance(interpolation, DeprecatedArg):
warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, True)
"Use 'method=' instead.", DeprecationWarning, stacklevel=2)
method = interpolation
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True)

def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
"'midpoint', or 'nearest'")
method: str, keepdims: bool, squash_nans: bool) -> Array:
if method not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'")
a, = promote_dtypes_inexact(a)
keepdim = []
if dtypes.issubdtype(a.dtype, np.complexfloating):
Expand Down Expand Up @@ -890,50 +892,57 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
high_weight = lax.broadcast_in_dim(high_weight, high_value.shape,
broadcast_dimensions=(0,))

if interpolation == "linear":
if method == "linear":
result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight),
lax.mul(high_value.astype(q.dtype), high_weight))
elif interpolation == "lower":
elif method == "lower":
result = low_value
elif interpolation == "higher":
elif method == "higher":
result = high_value
elif interpolation == "nearest":
elif method == "nearest":
pred = lax.le(high_weight, _lax_const(high_weight, 0.5))
result = lax.select(pred, low_value, high_value)
elif interpolation == "midpoint":
elif method == "midpoint":
result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
else:
raise ValueError(f"interpolation={interpolation!r} not recognized")
raise ValueError(f"{method=!r} not recognized")
if keepdims and keepdim:
if q_ndim > 0:
keepdim = [np.shape(q)[0], *keepdim]
result = result.reshape(keepdim)
return lax.convert_element_type(result, a.dtype)

# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@implements(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
check_arraylike("percentile", a, q)
q, = promote_dtypes_inexact(q)
if not isinstance(interpolation, DeprecatedArg):
warnings.warn("The interpolation= argument to 'percentile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning, stacklevel=2)
method = interpolation
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, method=method, keepdims=keepdims)
method=method, keepdims=keepdims)

# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@implements(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
check_arraylike("nanpercentile", a, q)
q = ufuncs.true_divide(q, 100.0)
if not isinstance(interpolation, DeprecatedArg):
warnings.warn("The interpolation= argument to 'nanpercentile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning, stacklevel=2)
method = interpolation
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, method=method,
keepdims=keepdims)
method=method, keepdims=keepdims)

@implements(np.median, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
Expand Down
8 changes: 4 additions & 4 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -610,14 +610,14 @@ def nanmin(a: ArrayLike, axis: _Axis = ..., out: None = ...,
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, tuple[int, ...]]] = ...,
out: None = ..., overwrite_input: builtins.bool = ..., method: str = ...,
keepdims: builtins.bool = ..., interpolation: None = ...) -> Array: ...
keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ...
def nanprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
out: None = ...,
keepdims: builtins.bool = ..., initial: Optional[ArrayLike] = ...,
where: Optional[ArrayLike] = ...) -> Array: ...
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ...,
out: None = ..., overwrite_input: builtins.bool = ..., method: str = ...,
keepdims: builtins.bool = ..., interpolation: None = ...) -> Array: ...
keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ...
def nanstd(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...,
ddof: int = ..., keepdims: builtins.bool = ...,
where: Optional[ArrayLike] = ...) -> Array: ...
Expand Down Expand Up @@ -660,7 +660,7 @@ def partition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ...
def percentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, tuple[int, ...]]] = ...,
out: None = ..., overwrite_input: builtins.bool = ..., method: str = ...,
keepdims: builtins.bool = ..., interpolation: None = ...) -> Array: ...
keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ...
def permute_dims(x: ArrayLike, /, axes: tuple[int, ...]) -> Array: ...
pi: float
def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]],
Expand Down Expand Up @@ -695,7 +695,7 @@ def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ...
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = ...,
out: None = ..., overwrite_input: builtins.bool = ..., method: str = ...,
keepdims: builtins.bool = ..., interpolation: None = ...) -> Array: ...
keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ...
r_: _RClass
def rad2deg(x: ArrayLike, /) -> Array: ...
radians = deg2rad
Expand Down

0 comments on commit bc19f7f

Please sign in to comment.