-
Couldn't load subscription status.
- Fork 16
ENH: add quantile function with weights support
#494
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
cakedev0
wants to merge
27
commits into
data-apis:main
Choose a base branch
from
cakedev0:quantile
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+742
−3
Open
Changes from all commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
a6f6c93
Draft
cakedev0 d30bcbf
Merge remote-tracking branch 'upstream/main' into quantile
cakedev0 dc236da
revert changes to renovate.json
cakedev0 f92fc4b
revert changes to renovate.json
cakedev0 06e370a
untested implem; limited to method="linear"; trying to mimic numpy be…
cakedev0 dc7a1e5
remove unused imports
cakedev0 98fe39f
draft version with some tests that are passing
cakedev0 034c064
linting: fix pyright
cakedev0 05ffb7b
linting: fix mypy
cakedev0 89d8410
fixed linting
cakedev0 19fa6ea
WIP: adding support for weights
cakedev0 fa789fc
Weighted quantile; nan-policy; everything mostly works
cakedev0 1d8fef7
linting: pyright & mypy
cakedev0 26804fe
linting: ruff
cakedev0 3611708
linting & cleanup
cakedev0 7160bae
fix tests for numpy 1.x
cakedev0 0b2cb9b
working on coverage
cakedev0 3226659
working on coverage
cakedev0 c395b84
more coverage
cakedev0 07f7007
fix test
cakedev0 e319529
second attempt to fix test
cakedev0 8ab7d62
more validation
cakedev0 1b48267
some more tests
cakedev0 c71351f
Fix typo in err msg
cakedev0 ce55335
avoid sorting a; just sort the weights
cakedev0 0353767
Merge branch 'quantile' of github.com:cakedev0/array-api-extra into q…
cakedev0 7c18a82
return max values when all weights are null
cakedev0 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,8 @@ | |
| from types import ModuleType | ||
| from typing import Literal | ||
|
|
||
| from ._lib import _funcs | ||
| from ._lib import _funcs, _quantile | ||
| from ._lib._backends import NUMPY_VERSION | ||
| from ._lib._utils._compat import ( | ||
| array_namespace, | ||
| is_cupy_namespace, | ||
|
|
@@ -768,7 +769,7 @@ def argpartition( | |
| Axis along which to partition. The default is ``-1`` (the last axis). | ||
| If ``None``, the flattened array is used. | ||
| xp : array_namespace, optional | ||
| The standard-compatible namespace for `x`. Default: infer. | ||
| The standard-compatible namespace for `a`. Default: infer. | ||
|
|
||
| Returns | ||
| ------- | ||
|
|
@@ -895,3 +896,273 @@ def isin( | |
| return xp.isin(a, b, assume_unique=assume_unique, invert=invert) | ||
|
|
||
| return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp) | ||
|
|
||
|
|
||
| def quantile( | ||
| a: Array, | ||
| q: float | Array, | ||
| /, | ||
| axis: int | None = None, | ||
| method: str = "linear", | ||
| keepdims: bool = False, | ||
| nan_policy: str = "propagate", | ||
| *, | ||
| weights: Array | None = None, | ||
| xp: ModuleType | None = None, | ||
| ) -> Array: | ||
| """ | ||
| Compute the q-th quantile of the data along the specified axis. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| a : array_like of real numbers | ||
| Input array or object that can be converted to an array. | ||
| q : array_like of float | ||
| Probability or sequence of probabilities of the quantiles to compute. | ||
| Values must be between 0 and 1 inclusive. | ||
| axis : {int, tuple of int, None}, optional | ||
| Axis or axes along which the quantiles are computed. The default is | ||
| to compute the quantile(s) along a flattened version of the array. | ||
| method : str, optional | ||
| This parameter specifies the method to use for estimating the | ||
| quantile. There are many different methods. | ||
| The recommended options, numbered as they appear in [1]_, are: | ||
|
|
||
| 1. 'inverted_cdf' | ||
| 2. 'averaged_inverted_cdf' | ||
| 3. 'closest_observation' | ||
| 4. 'interpolated_inverted_cdf' | ||
| 5. 'hazen' | ||
| 6. 'weibull' | ||
| 7. 'linear' (default) | ||
| 8. 'median_unbiased' | ||
| 9. 'normal_unbiased' | ||
|
|
||
| The first three methods are discontinuous. | ||
| Only 'linear', 'inverted_cdf' and 'averaged_inverted_cdf' are implemented. | ||
|
|
||
| keepdims : bool, optional | ||
| If this is set to True, the axes which are reduced are left in | ||
| the result as dimensions with size one. With this option, the | ||
| result will broadcast correctly against the original array `a`. | ||
|
|
||
| nan_policy : str, optional | ||
| 'propagate' (default) or 'omit'. | ||
| 'omit' is supported only when `weights` are provided. | ||
|
|
||
| weights : array_like, optional | ||
| An array of weights associated with the values in `a`. Each value in | ||
| `a` contributes to the quantile according to its associated weight. | ||
| The weights array can either be 1-D (in which case its length must be | ||
| the size of `a` along the given axis) or of the same shape as `a`. | ||
| If `weights=None`, then all data in `a` are assumed to have a | ||
| weight equal to one. | ||
| Only `method="inverted_cdf"` or `method="averaged_inverted_cdf"` | ||
| support weights. See the notes for more details. | ||
|
|
||
| xp : array_namespace, optional | ||
| The standard-compatible namespace for `a` and `q`. Default: infer. | ||
|
|
||
| Returns | ||
| ------- | ||
| scalar or ndarray | ||
| If `q` is a single probability and `axis=None`, then the result | ||
| is a scalar. If multiple probability levels are given, first axis | ||
| of the result corresponds to the quantiles. The other axes are | ||
| the axes that remain after the reduction of `a`. If the input | ||
| contains integers or floats smaller than ``float64``, the output | ||
| data-type is ``float64``. Otherwise, the output data-type is the | ||
| same as that of the input. If `out` is specified, that array is | ||
| returned instead. | ||
|
|
||
| Notes | ||
| ----- | ||
| Given a sample `a` from an underlying distribution, `quantile` provides a | ||
| nonparametric estimate of the inverse cumulative distribution function. | ||
|
|
||
| By default, this is done by interpolating between adjacent elements in | ||
| ``y``, a sorted copy of `a`:: | ||
|
|
||
| (1-g)*y[j] + g*y[j+1] | ||
|
|
||
| where the index ``j`` and coefficient ``g`` are the integral and | ||
| fractional components of ``q * (n-1)``, and ``n`` is the number of | ||
| elements in the sample. | ||
|
|
||
| This is a special case of Equation 1 of H&F [1]_. More generally, | ||
|
|
||
| - ``j = (q*n + m - 1) // 1``, and | ||
| - ``g = (q*n + m - 1) % 1``, | ||
|
|
||
| where ``m`` may be defined according to several different conventions. | ||
| The preferred convention may be selected using the ``method`` parameter: | ||
|
|
||
| =============================== =============== =============== | ||
| ``method`` number in H&F ``m`` | ||
| =============================== =============== =============== | ||
| ``interpolated_inverted_cdf`` 4 ``0`` | ||
| ``hazen`` 5 ``1/2`` | ||
| ``weibull`` 6 ``q`` | ||
| ``linear`` (default) 7 ``1 - q`` | ||
| ``median_unbiased`` 8 ``q/3 + 1/3`` | ||
| ``normal_unbiased`` 9 ``q/4 + 3/8`` | ||
| =============================== =============== =============== | ||
|
|
||
| Note that indices ``j`` and ``j + 1`` are clipped to the range ``0`` to | ||
| ``n - 1`` when the results of the formula would be outside the allowed | ||
| range of non-negative indices. The ``- 1`` in the formulas for ``j`` and | ||
| ``g`` accounts for Python's 0-based indexing. | ||
|
|
||
| The table above includes only the estimators from H&F that are continuous | ||
| functions of probability `q` (estimators 4-9). NumPy also provides the | ||
| three discontinuous estimators from H&F (estimators 1-3), where ``j`` is | ||
| defined as above, ``m`` is defined as follows, and ``g`` is a function | ||
| of the real-valued ``index = q*n + m - 1`` and ``j``. | ||
|
|
||
| 1. ``inverted_cdf``: ``m = 0`` and ``g = int(index - j > 0)`` | ||
| 2. ``averaged_inverted_cdf``: ``m = 0`` and | ||
| ``g = (1 + int(index - j > 0)) / 2`` | ||
| 3. ``closest_observation``: ``m = -1/2`` and | ||
| ``g = 1 - int((index == j) & (j%2 == 1))`` | ||
|
|
||
| **Weighted quantiles:** | ||
| More formally, the quantile at probability level :math:`q` of a cumulative | ||
| distribution function :math:`F(y)=P(Y \\leq y)` with probability measure | ||
| :math:`P` is defined as any number :math:`x` that fulfills the | ||
| *coverage conditions* | ||
|
|
||
| .. math:: P(Y < x) \\leq q \\quad\\text{and}\\quad P(Y \\leq x) \\geq q | ||
|
|
||
| with random variable :math:`Y\\sim P`. | ||
| Sample quantiles, the result of `quantile`, provide nonparametric | ||
| estimation of the underlying population counterparts, represented by the | ||
| unknown :math:`F`, given a data vector `a` of length ``n``. | ||
|
|
||
| Some of the estimators above arise when one considers :math:`F` as the | ||
| empirical distribution function of the data, i.e. | ||
| :math:`F(y) = \\frac{1}{n} \\sum_i 1_{a_i \\leq y}`. | ||
| Then, different methods correspond to different choices of :math:`x` that | ||
| fulfill the above coverage conditions. Methods that follow this approach | ||
| are ``inverted_cdf`` and ``averaged_inverted_cdf``. | ||
|
|
||
| For weighted quantiles, the coverage conditions still hold. The | ||
| empirical cumulative distribution is simply replaced by its weighted | ||
| version, i.e. | ||
| :math:`P(Y \\leq t) = \\frac{1}{\\sum_i w_i} \\sum_i w_i 1_{x_i \\leq t}`. | ||
|
|
||
| References | ||
| ---------- | ||
| .. [1] R. J. Hyndman and Y. Fan, | ||
| "Sample quantiles in statistical packages," | ||
| The American Statistician, 50(4), pp. 361-365, 1996 | ||
| """ | ||
| if xp is None: | ||
| xp = array_namespace(a) | ||
| if is_pydata_sparse_namespace(xp): | ||
| msg = "Sparse backend not supported" | ||
| raise ValueError(msg) | ||
|
|
||
| methods = {"linear", "inverted_cdf", "averaged_inverted_cdf"} | ||
| if method not in methods: | ||
| msg = f"`method` must be one of {methods}" | ||
| raise ValueError(msg) | ||
| nan_policies = {"propagate", "omit"} | ||
| if nan_policy not in nan_policies: | ||
| msg = f"`nan_policy` must be one of {nan_policies}" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
| raise ValueError(msg) | ||
|
|
||
| a = xp.asarray(a) | ||
| if not xp.isdtype(a.dtype, ("integral", "real floating")): | ||
| msg = "`a` must have real dtype." | ||
| raise ValueError(msg) | ||
| if not xp.isdtype(xp.asarray(q).dtype, "real floating"): | ||
| msg = "`q` must have real floating dtype." | ||
| raise ValueError(msg) | ||
| weights = None if weights is None else xp.asarray(weights) | ||
|
|
||
| ndim = a.ndim | ||
| if ndim < 1: | ||
| msg = "`a` must be at least 1-dimensional." | ||
| raise TypeError(msg) | ||
| if axis is not None and ((axis >= ndim) or (axis < -ndim)): | ||
| msg = "`axis` is not compatible with the dimension of `a`." | ||
| raise ValueError(msg) | ||
| if weights is None: | ||
| if nan_policy != "propagate": | ||
| msg = "When `weights` aren't provided, `nan_policy` must be 'propagate'." | ||
| raise ValueError(msg) | ||
| else: | ||
| if method not in {"inverted_cdf", "averaged_inverted_cdf"}: | ||
| msg = f"`method` '{method}' not supported with weights." | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| raise ValueError(msg) | ||
| if not xp.isdtype(weights.dtype, ("integral", "real floating")): | ||
| msg = "`weights` must have real dtype." | ||
| raise ValueError(msg) | ||
| if ndim > 2: | ||
| msg = "When weights are provided, dimension of `a` must be 1 or 2." | ||
| raise ValueError(msg) | ||
| if a.shape != weights.shape: | ||
| if axis is None: | ||
| msg = "Axis must be specified when shapes of `a` and ̀ weights` differ." | ||
| raise TypeError(msg) | ||
| if weights.shape != eager_shape(a, axis): | ||
| msg = ( | ||
| "Shape of weights must be consistent with shape" | ||
| " of a along specified axis." | ||
| ) | ||
| raise ValueError(msg) | ||
| if axis is None and ndim == 2: | ||
| msg = "Axis must be specified when `a` and ̀ weights` are 2d." | ||
| raise ValueError(msg) | ||
|
|
||
| # Align result dtype with what numpy does: | ||
| dtype = xp.result_type( | ||
| xp.float64 if xp.isdtype(a.dtype, "integral") else a, | ||
| xp.asarray(q), | ||
| xp.float64, # at least float64 | ||
| ) | ||
| device = get_device(a) | ||
| a = xp.asarray(a, dtype=dtype, device=device) | ||
| q_arr = xp.asarray(q, dtype=dtype, device=device) | ||
| # TODO: cast weights here? Assert weights are on the same device as `a`? | ||
|
|
||
| if xp.any((q_arr > 1) | (q_arr < 0) | xp.isnan(q_arr)): | ||
| msg = "`q` values must be in the range [0, 1]" | ||
| raise ValueError(msg) | ||
|
|
||
| # Delegate when possible. | ||
| # Note: No delegation for dask: I couldn't make it work. | ||
| basic_case = method == "linear" and weights is None | ||
|
|
||
| np_2 = NUMPY_VERSION >= (2, 0) | ||
| np_handles_weights = np_2 and nan_policy == "propagate" and method == "inverted_cdf" | ||
| if weights is None: | ||
| if is_numpy_namespace(xp) and (basic_case or np_2): | ||
| quantile = xp.quantile if nan_policy == "propagate" else xp.nanquantile | ||
| return quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims) | ||
| elif is_numpy_namespace(xp) and np_handles_weights: | ||
| # TODO: call nanquantile for nan_policy == "omit" once | ||
| # https://github.com/numpy/numpy/issues/29709 is fixed | ||
| return xp.quantile( | ||
| a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights | ||
| ) | ||
|
|
||
| jax_or_cupy = is_jax_namespace(xp) or is_cupy_namespace(xp) | ||
| if jax_or_cupy and basic_case and nan_policy == "propagate": | ||
| return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims) | ||
| if is_torch_namespace(xp) and basic_case: | ||
| quantile = xp.quantile if nan_policy == "propagate" else xp.nanquantile | ||
| return quantile(a, q_arr, dim=axis, interpolation=method, keepdim=keepdims) | ||
|
|
||
| # Otherwise call our implementation (will sort data) | ||
| return _quantile.quantile( | ||
| # XXX: I'm not sure we want to support dask, it seems uterly slow... | ||
| a, | ||
| q_arr, | ||
| axis=axis, | ||
| method=method, | ||
| keepdims=keepdims, | ||
| nan_policy=nan_policy, | ||
| weights=weights, | ||
| xp=xp, | ||
| ) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sort methods to get a deterministic output?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's deterministic already. But do you mean declaring methods in the sorted order?
Like this:
methods = {"averaged_inverted_cdf", "inverted_cdf", "linear"}