Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a6f6c93
Draft
cakedev0 Oct 21, 2025
d30bcbf
Merge remote-tracking branch 'upstream/main' into quantile
cakedev0 Oct 21, 2025
dc236da
revert changes to renovate.json
cakedev0 Oct 21, 2025
f92fc4b
revert changes to renovate.json
cakedev0 Oct 21, 2025
06e370a
untested implem; limited to method="linear"; trying to mimic numpy be…
cakedev0 Oct 21, 2025
dc7a1e5
remove unused imports
cakedev0 Oct 21, 2025
98fe39f
draft version with some tests that are passing
cakedev0 Oct 22, 2025
034c064
linting: fix pyright
cakedev0 Oct 22, 2025
05ffb7b
linting: fix mypy
cakedev0 Oct 22, 2025
89d8410
fixed linting
cakedev0 Oct 22, 2025
19fa6ea
WIP: adding support for weights
cakedev0 Oct 22, 2025
fa789fc
Weighted quantile; nan-policy; everything mostly works
cakedev0 Oct 23, 2025
1d8fef7
linting: pyright & mypy
cakedev0 Oct 23, 2025
26804fe
linting: ruff
cakedev0 Oct 23, 2025
3611708
linting & cleanup
cakedev0 Oct 23, 2025
7160bae
fix tests for numpy 1.x
cakedev0 Oct 23, 2025
0b2cb9b
working on coverage
cakedev0 Oct 23, 2025
3226659
working on coverage
cakedev0 Oct 23, 2025
c395b84
more coverage
cakedev0 Oct 23, 2025
07f7007
fix test
cakedev0 Oct 23, 2025
e319529
second attempt to fix test
cakedev0 Oct 23, 2025
8ab7d62
more validation
cakedev0 Oct 23, 2025
1b48267
some more tests
cakedev0 Oct 23, 2025
c71351f
Fix typo in err msg
cakedev0 Oct 25, 2025
ce55335
avoid sorting a; just sort the weights
cakedev0 Oct 25, 2025
0353767
Merge branch 'quantile' of github.com:cakedev0/array-api-extra into q…
cakedev0 Oct 25, 2025
7c18a82
return max values when all weights are null
cakedev0 Oct 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
one_hot,
pad,
partition,
quantile,
sinc,
)
from ._lib._at import at
Expand Down Expand Up @@ -48,6 +49,7 @@
"one_hot",
"pad",
"partition",
"quantile",
"setdiff1d",
"sinc",
]
275 changes: 273 additions & 2 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from types import ModuleType
from typing import Literal

from ._lib import _funcs
from ._lib import _funcs, _quantile
from ._lib._backends import NUMPY_VERSION
from ._lib._utils._compat import (
array_namespace,
is_cupy_namespace,
Expand Down Expand Up @@ -768,7 +769,7 @@ def argpartition(
Axis along which to partition. The default is ``-1`` (the last axis).
If ``None``, the flattened array is used.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
The standard-compatible namespace for `a`. Default: infer.

Returns
-------
Expand Down Expand Up @@ -895,3 +896,273 @@ def isin(
return xp.isin(a, b, assume_unique=assume_unique, invert=invert)

return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp)


def quantile(
a: Array,
q: float | Array,
/,
axis: int | None = None,
method: str = "linear",
keepdims: bool = False,
nan_policy: str = "propagate",
*,
weights: Array | None = None,
xp: ModuleType | None = None,
) -> Array:
"""
Compute the q-th quantile of the data along the specified axis.

Parameters
----------
a : array_like of real numbers
Input array or object that can be converted to an array.
q : array_like of float
Probability or sequence of probabilities of the quantiles to compute.
Values must be between 0 and 1 inclusive.
axis : {int, tuple of int, None}, optional
Axis or axes along which the quantiles are computed. The default is
to compute the quantile(s) along a flattened version of the array.
method : str, optional
This parameter specifies the method to use for estimating the
quantile. There are many different methods.
The recommended options, numbered as they appear in [1]_, are:

1. 'inverted_cdf'
2. 'averaged_inverted_cdf'
3. 'closest_observation'
4. 'interpolated_inverted_cdf'
5. 'hazen'
6. 'weibull'
7. 'linear' (default)
8. 'median_unbiased'
9. 'normal_unbiased'

The first three methods are discontinuous.
Only 'linear', 'inverted_cdf' and 'averaged_inverted_cdf' are implemented.

keepdims : bool, optional
If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option, the
result will broadcast correctly against the original array `a`.

nan_policy : str, optional
'propagate' (default) or 'omit'.
'omit' is supported only when `weights` are provided.

weights : array_like, optional
An array of weights associated with the values in `a`. Each value in
`a` contributes to the quantile according to its associated weight.
The weights array can either be 1-D (in which case its length must be
the size of `a` along the given axis) or of the same shape as `a`.
If `weights=None`, then all data in `a` are assumed to have a
weight equal to one.
Only `method="inverted_cdf"` or `method="averaged_inverted_cdf"`
support weights. See the notes for more details.

xp : array_namespace, optional
The standard-compatible namespace for `a` and `q`. Default: infer.

Returns
-------
scalar or ndarray
If `q` is a single probability and `axis=None`, then the result
is a scalar. If multiple probability levels are given, first axis
of the result corresponds to the quantiles. The other axes are
the axes that remain after the reduction of `a`. If the input
contains integers or floats smaller than ``float64``, the output
data-type is ``float64``. Otherwise, the output data-type is the
same as that of the input. If `out` is specified, that array is
returned instead.

Notes
-----
Given a sample `a` from an underlying distribution, `quantile` provides a
nonparametric estimate of the inverse cumulative distribution function.

By default, this is done by interpolating between adjacent elements in
``y``, a sorted copy of `a`::

(1-g)*y[j] + g*y[j+1]

where the index ``j`` and coefficient ``g`` are the integral and
fractional components of ``q * (n-1)``, and ``n`` is the number of
elements in the sample.

This is a special case of Equation 1 of H&F [1]_. More generally,

- ``j = (q*n + m - 1) // 1``, and
- ``g = (q*n + m - 1) % 1``,

where ``m`` may be defined according to several different conventions.
The preferred convention may be selected using the ``method`` parameter:

=============================== =============== ===============
``method`` number in H&F ``m``
=============================== =============== ===============
``interpolated_inverted_cdf`` 4 ``0``
``hazen`` 5 ``1/2``
``weibull`` 6 ``q``
``linear`` (default) 7 ``1 - q``
``median_unbiased`` 8 ``q/3 + 1/3``
``normal_unbiased`` 9 ``q/4 + 3/8``
=============================== =============== ===============

Note that indices ``j`` and ``j + 1`` are clipped to the range ``0`` to
``n - 1`` when the results of the formula would be outside the allowed
range of non-negative indices. The ``- 1`` in the formulas for ``j`` and
``g`` accounts for Python's 0-based indexing.

The table above includes only the estimators from H&F that are continuous
functions of probability `q` (estimators 4-9). NumPy also provides the
three discontinuous estimators from H&F (estimators 1-3), where ``j`` is
defined as above, ``m`` is defined as follows, and ``g`` is a function
of the real-valued ``index = q*n + m - 1`` and ``j``.

1. ``inverted_cdf``: ``m = 0`` and ``g = int(index - j > 0)``
2. ``averaged_inverted_cdf``: ``m = 0`` and
``g = (1 + int(index - j > 0)) / 2``
3. ``closest_observation``: ``m = -1/2`` and
``g = 1 - int((index == j) & (j%2 == 1))``

**Weighted quantiles:**
More formally, the quantile at probability level :math:`q` of a cumulative
distribution function :math:`F(y)=P(Y \\leq y)` with probability measure
:math:`P` is defined as any number :math:`x` that fulfills the
*coverage conditions*

.. math:: P(Y < x) \\leq q \\quad\\text{and}\\quad P(Y \\leq x) \\geq q

with random variable :math:`Y\\sim P`.
Sample quantiles, the result of `quantile`, provide nonparametric
estimation of the underlying population counterparts, represented by the
unknown :math:`F`, given a data vector `a` of length ``n``.

Some of the estimators above arise when one considers :math:`F` as the
empirical distribution function of the data, i.e.
:math:`F(y) = \\frac{1}{n} \\sum_i 1_{a_i \\leq y}`.
Then, different methods correspond to different choices of :math:`x` that
fulfill the above coverage conditions. Methods that follow this approach
are ``inverted_cdf`` and ``averaged_inverted_cdf``.

For weighted quantiles, the coverage conditions still hold. The
empirical cumulative distribution is simply replaced by its weighted
version, i.e.
:math:`P(Y \\leq t) = \\frac{1}{\\sum_i w_i} \\sum_i w_i 1_{x_i \\leq t}`.

References
----------
.. [1] R. J. Hyndman and Y. Fan,
"Sample quantiles in statistical packages,"
The American Statistician, 50(4), pp. 361-365, 1996
"""
if xp is None:
xp = array_namespace(a)
if is_pydata_sparse_namespace(xp):
msg = "Sparse backend not supported"
raise ValueError(msg)

methods = {"linear", "inverted_cdf", "averaged_inverted_cdf"}
if method not in methods:
msg = f"`method` must be one of {methods}"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sort methods to get a deterministic output?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's deterministic already. But do you mean declaring methods in the sorted order?

Like this: methods = {"averaged_inverted_cdf", "inverted_cdf", "linear"}

raise ValueError(msg)
nan_policies = {"propagate", "omit"}
if nan_policy not in nan_policies:
msg = f"`nan_policy` must be one of {nan_policies}"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

raise ValueError(msg)

a = xp.asarray(a)
if not xp.isdtype(a.dtype, ("integral", "real floating")):
msg = "`a` must have real dtype."
raise ValueError(msg)
if not xp.isdtype(xp.asarray(q).dtype, "real floating"):
msg = "`q` must have real floating dtype."
raise ValueError(msg)
weights = None if weights is None else xp.asarray(weights)

ndim = a.ndim
if ndim < 1:
msg = "`a` must be at least 1-dimensional."
raise TypeError(msg)
if axis is not None and ((axis >= ndim) or (axis < -ndim)):
msg = "`axis` is not compatible with the dimension of `a`."
raise ValueError(msg)
if weights is None:
if nan_policy != "propagate":
msg = "When `weights` aren't provided, `nan_policy` must be 'propagate'."
raise ValueError(msg)
else:
if method not in {"inverted_cdf", "averaged_inverted_cdf"}:
msg = f"`method` '{method}' not supported with weights."

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only method x/ y support weights?

raise ValueError(msg)
if not xp.isdtype(weights.dtype, ("integral", "real floating")):
msg = "`weights` must have real dtype."
raise ValueError(msg)
if ndim > 2:
msg = "When weights are provided, dimension of `a` must be 1 or 2."
raise ValueError(msg)
if a.shape != weights.shape:
if axis is None:
msg = "Axis must be specified when shapes of `a` and ̀ weights` differ."
raise TypeError(msg)
if weights.shape != eager_shape(a, axis):
msg = (
"Shape of weights must be consistent with shape"
" of a along specified axis."
)
raise ValueError(msg)
if axis is None and ndim == 2:
msg = "Axis must be specified when `a` and ̀ weights` are 2d."
raise ValueError(msg)

# Align result dtype with what numpy does:
dtype = xp.result_type(
xp.float64 if xp.isdtype(a.dtype, "integral") else a,
xp.asarray(q),
xp.float64, # at least float64
)
device = get_device(a)
a = xp.asarray(a, dtype=dtype, device=device)
q_arr = xp.asarray(q, dtype=dtype, device=device)
# TODO: cast weights here? Assert weights are on the same device as `a`?

if xp.any((q_arr > 1) | (q_arr < 0) | xp.isnan(q_arr)):
msg = "`q` values must be in the range [0, 1]"
raise ValueError(msg)

# Delegate when possible.
# Note: No delegation for dask: I couldn't make it work.
basic_case = method == "linear" and weights is None

np_2 = NUMPY_VERSION >= (2, 0)
np_handles_weights = np_2 and nan_policy == "propagate" and method == "inverted_cdf"
if weights is None:
if is_numpy_namespace(xp) and (basic_case or np_2):
quantile = xp.quantile if nan_policy == "propagate" else xp.nanquantile
return quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims)
elif is_numpy_namespace(xp) and np_handles_weights:
# TODO: call nanquantile for nan_policy == "omit" once
# https://github.com/numpy/numpy/issues/29709 is fixed
return xp.quantile(
a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights
)

jax_or_cupy = is_jax_namespace(xp) or is_cupy_namespace(xp)
if jax_or_cupy and basic_case and nan_policy == "propagate":
return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims)
if is_torch_namespace(xp) and basic_case:
quantile = xp.quantile if nan_policy == "propagate" else xp.nanquantile
return quantile(a, q_arr, dim=axis, interpolation=method, keepdim=keepdims)

# Otherwise call our implementation (will sort data)
return _quantile.quantile(
# XXX: I'm not sure we want to support dask, it seems uterly slow...
a,
q_arr,
axis=axis,
method=method,
keepdims=keepdims,
nan_policy=nan_policy,
weights=weights,
xp=xp,
)
Loading