From 51ade21baf381a9b84ae80bfcc530ad084c95310 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 1 Oct 2025 16:53:04 +0200 Subject: [PATCH 01/13] initial work with a small test; fails for sparse backend as it has not argsort; --- src/array_api_extra/__init__.py | 11 ++- src/array_api_extra/_delegation.py | 107 +++++++++++++++++++++++++++++ src/array_api_extra/_lib/_funcs.py | 20 ++++++ tests/test_funcs.py | 17 +++++ 4 files changed, 154 insertions(+), 1 deletion(-) diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 0b4bd5e2..1d3c0fac 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,6 +1,13 @@ """Extra array functions built on top of the array API standard.""" -from ._delegation import isclose, nan_to_num, one_hot, pad +from ._delegation import ( + argpartition, + isclose, + nan_to_num, + one_hot, + pad, + partition, +) from ._lib._at import at from ._lib._funcs import ( apply_where, @@ -23,6 +30,7 @@ __all__ = [ "__version__", "apply_where", + "argpartition", "at", "atleast_nd", "broadcast_shapes", @@ -37,6 +45,7 @@ "nunique", "one_hot", "pad", + "partition", "setdiff1d", "sinc", ] diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 2c061e36..01747395 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -326,3 +326,110 @@ def pad( return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) + + +def partition( + a: Array, + kth: int, + *, + xp: ModuleType | None = None, +) -> Array: + """ + Return a partitioned copy of an array. + + Parameters + ---------- + a : 1-dimensional array + Input array. + kth : int + Element index to partition by. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + partitioned_array + Array of the same type and shape as a. + """ + # Validate inputs. + if xp is None: + xp = array_namespace(a) + if a.ndim != 1: + msg = "only 1-dimensional arrays are currently supported" + raise NotImplementedError(msg) + + # Delegate where possible. + if is_numpy_namespace(xp) or is_cupy_namespace(xp): + return xp.partition(a, kth) + if is_jax_namespace(xp): + from jax import numpy + + return numpy.partition(a, kth) + + # Use top-k when possible: + if is_torch_namespace(xp): + from torch import topk + + a_left, indices_left = topk(a, kth, largest=False, sorted=False) + mask_right = xp.ones(a.shape, dtype=bool) + mask_right[indices_left] = False + return xp.concat((a_left, a[mask_right])) + # Note: dask topk/argtopk sort the return values, so it's + # not much more efficient than sorting everything when + # kth is not small compared to x.size + + return _funcs.partition(a, kth, xp=xp) + + +def argpartition( + a: Array, + kth: int, + *, + xp: ModuleType | None = None, +) -> Array: + """ + Perform an indirect partition along the given axis. + + Parameters + ---------- + a : 1-dimensional array + Input array. + kth : int + Element index to partition by. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + index_array + Array of indices that partition `a` along the specified axis. + """ + # Validate inputs. + if xp is None: + xp = array_namespace(a) + if a.ndim != 1: + msg = "only 1-dimensional arrays are currently supported" + raise NotImplementedError(msg) + + # Delegate where possible. + if is_numpy_namespace(xp) or is_cupy_namespace(xp): + return xp.argpartition(a, kth) + if is_jax_namespace(xp): + from jax import numpy + + return numpy.argpartition(a, kth) + + # Use top-k when possible: + if is_torch_namespace(xp): + from torch import topk + + _, indices = topk(a, kth, largest=False, sorted=False) + mask = xp.ones(a.shape, dtype=bool) + mask[indices] = False + indices_above = xp.arange(a.shape[0])[mask] + return xp.concat((indices, indices_above)) + # Note: dask topk/argtopk sort the return values, so it's + # not much more efficient than sorting everything when + # kth is not small compared to x.size + + return _funcs.argpartition(a, kth, xp=xp) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index f61affe5..a399a6df 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -1029,3 +1029,23 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)), ) return xp.sin(y) / y + + +def partition( # numpydoc ignore=PR01,RT01 + x: Array, + kth: int, # noqa: ARG001 + *, + xp: ModuleType, +) -> Array: + """See docstring in `array_api_extra._delegation.py`.""" + return xp.sort(x, stable=False) + + +def argpartition( # numpydoc ignore=PR01,RT01 + x: Array, + kth: int, # noqa: ARG001 + *, + xp: ModuleType, +) -> Array: + """See docstring in `array_api_extra._delegation.py`.""" + return xp.argsort(x, stable=False) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 90813ecb..eba0bcc9 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -12,6 +12,7 @@ from array_api_extra import ( apply_where, + argpartition, at, atleast_nd, broadcast_shapes, @@ -25,6 +26,7 @@ nunique, one_hot, pad, + partition, setdiff1d, sinc, ) @@ -1298,3 +1300,18 @@ def test_device(self, xp: ModuleType, device: Device): def test_xp(self, xp: ModuleType): xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0)) + + +class TestPartition: + def test_basic(self, xp: ModuleType): + # Using 0-dimensional array + rng = np.random.default_rng(2847) + + for _ in range(100): + n = rng.integers(1, 1000) + x = xp.asarray(rng.random(size=n)) + k = int(rng.integers(1, n - 1)) + y = partition(x, k) + assert xp.max(y[:k]) <= xp.min(y[k:]) + y = x[argpartition(x, k)] + assert xp.max(y[:k]) <= xp.min(y[k:]) From 81b8ac3a5ed76573054022d70a0162c2c1de288d Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 1 Oct 2025 21:48:34 +0200 Subject: [PATCH 02/13] Support for multi-dimensional arrays --- src/array_api_extra/_delegation.py | 112 +++++++++++++++++++++-------- src/array_api_extra/_lib/_funcs.py | 8 ++- tests/test_funcs.py | 84 ++++++++++++++++++---- 3 files changed, 158 insertions(+), 46 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 01747395..d34631b6 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -331,6 +331,8 @@ def pad( def partition( a: Array, kth: int, + /, + axis: int | None = -1, *, xp: ModuleType | None = None, ) -> Array: @@ -343,6 +345,9 @@ def partition( Input array. kth : int Element index to partition by. + axis : int, optional + Axis along which to partition. The default is -1 (the last axis). + If None, the flattened array is used. xp : array_namespace, optional The standard-compatible namespace for `x`. Default: infer. @@ -354,36 +359,61 @@ def partition( # Validate inputs. if xp is None: xp = array_namespace(a) - if a.ndim != 1: - msg = "only 1-dimensional arrays are currently supported" - raise NotImplementedError(msg) + if a.ndim < 1: + msg = "`a` must be at least 1-dimensional" + raise TypeError(msg) + if axis is None: + return partition(xp.reshape(a, -1), kth, axis=0, xp=xp) + size = a.shape[axis] + if size is None: + msg = "Array dimensions must be known" + raise ValueError(msg) + if not (0 <= kth < size): + msg = f"kth(={kth}) out of bounds [0 {size})" + raise ValueError(msg) # Delegate where possible. - if is_numpy_namespace(xp) or is_cupy_namespace(xp): - return xp.partition(a, kth) - if is_jax_namespace(xp): - from jax import numpy - - return numpy.partition(a, kth) + if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp): + return xp.partition(a, kth, axis=axis) # Use top-k when possible: if is_torch_namespace(xp): - from torch import topk + if not (axis == -1 or axis == a.ndim - 1): + a = xp.transpose(a, axis, -1) - a_left, indices_left = topk(a, kth, largest=False, sorted=False) + # Get smallest `kth` elements along axis + kth += 1 # HACK: we use a non-specified behavior of torch.topk: + # in `a_left`, the element in the last position is the max + a_left, indices = xp.topk(a, kth, dim=-1, largest=False, sorted=False) + + # Build a mask to remove the selected elements mask_right = xp.ones(a.shape, dtype=bool) - mask_right[indices_left] = False - return xp.concat((a_left, a[mask_right])) + mask_right.scatter_(dim=-1, index=indices, value=False) + + # Remaining elements along axis + a_right = a[mask_right] # 1-d array + + # Reshape. This is valid only because we work on the last axis + a_right = xp.reshape(a_right, shape=(*a.shape[:-1], -1)) + + # Concatenate the two parts along axis + partitioned_array = xp.cat((a_left, a_right), dim=-1) + if not (axis == -1 or axis == a.ndim - 1): + partitioned_array = xp.transpose(partitioned_array, axis, -1) + return partitioned_array + # Note: dask topk/argtopk sort the return values, so it's # not much more efficient than sorting everything when # kth is not small compared to x.size - return _funcs.partition(a, kth, xp=xp) + return _funcs.partition(a, kth, axis=axis, xp=xp) def argpartition( a: Array, kth: int, + /, + axis: int | None = -1, *, xp: ModuleType | None = None, ) -> Array: @@ -392,10 +422,13 @@ def argpartition( Parameters ---------- - a : 1-dimensional array + a : Array Input array. kth : int Element index to partition by. + axis : int, optional + Axis along which to partition. The default is -1 (the last axis). + If None, the flattened array is used. xp : array_namespace, optional The standard-compatible namespace for `x`. Default: infer. @@ -407,29 +440,46 @@ def argpartition( # Validate inputs. if xp is None: xp = array_namespace(a) - if a.ndim != 1: - msg = "only 1-dimensional arrays are currently supported" - raise NotImplementedError(msg) + if a.ndim < 1: + msg = "`a` must be at least 1-dimensional" + raise TypeError(msg) + if axis is None: + return partition(xp.reshape(a, -1), kth, axis=0, xp=xp) + size = a.shape[axis] + if size is None: + msg = "Array dimensions must be known" + raise ValueError(msg) + if not (0 <= kth < size): + msg = f"kth(={kth}) out of bounds [0 {size})" + raise ValueError(msg) # Delegate where possible. - if is_numpy_namespace(xp) or is_cupy_namespace(xp): - return xp.argpartition(a, kth) - if is_jax_namespace(xp): - from jax import numpy - - return numpy.argpartition(a, kth) + if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp): + return xp.argpartition(a, kth, axis=axis) # Use top-k when possible: if is_torch_namespace(xp): - from torch import topk + # see `partition` above for commented details of those steps: + if not (axis == -1 or axis == a.ndim - 1): + a = xp.transpose(a, axis, -1) + + kth += 1 # HACK + _, indices_left = xp.topk(a, kth, dim=-1, largest=False, sorted=False) + + mask_right = xp.ones(a.shape, dtype=bool) + mask_right.scatter_(dim=-1, index=indices_left, value=False) + + indices_right = xp.nonzero(mask_right)[-1] + indices_right = xp.reshape(indices_right, shape=(*a.shape[:-1], -1)) + + # Concatenate the two parts along axis + index_array = xp.cat((indices_left, indices_right), dim=-1) + if not (axis == -1 or axis == a.ndim - 1): + index_array = xp.transpose(index_array, axis, -1) + return index_array - _, indices = topk(a, kth, largest=False, sorted=False) - mask = xp.ones(a.shape, dtype=bool) - mask[indices] = False - indices_above = xp.arange(a.shape[0])[mask] - return xp.concat((indices, indices_above)) # Note: dask topk/argtopk sort the return values, so it's # not much more efficient than sorting everything when # kth is not small compared to x.size - return _funcs.argpartition(a, kth, xp=xp) + return _funcs.argpartition(a, kth, axis=axis, xp=xp) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index a399a6df..c0e5a7a8 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -1034,18 +1034,22 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: def partition( # numpydoc ignore=PR01,RT01 x: Array, kth: int, # noqa: ARG001 + /, + axis: int = -1, *, xp: ModuleType, ) -> Array: """See docstring in `array_api_extra._delegation.py`.""" - return xp.sort(x, stable=False) + return xp.sort(x, axis=axis, stable=False) def argpartition( # numpydoc ignore=PR01,RT01 x: Array, kth: int, # noqa: ARG001 + /, + axis: int = -1, *, xp: ModuleType, ) -> Array: """See docstring in `array_api_extra._delegation.py`.""" - return xp.argsort(x, stable=False) + return xp.argsort(x, axis=axis, stable=False) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index eba0bcc9..dd1db4a2 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -9,6 +9,7 @@ import pytest from hypothesis import given from hypothesis import strategies as st +from typing_extensions import override from array_api_extra import ( apply_where, @@ -32,7 +33,12 @@ ) from array_api_extra._lib._backends import NUMPY_VERSION, Backend from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal -from array_api_extra._lib._utils._compat import device as get_device +from array_api_extra._lib._utils._compat import ( + device as get_device, +) +from array_api_extra._lib._utils._compat import ( + is_pydata_sparse_namespace, +) from array_api_extra._lib._utils._helpers import eager_shape, ndindex from array_api_extra._lib._utils._typing import Array, Device from array_api_extra.testing import lazy_xp_function @@ -1303,15 +1309,67 @@ def test_xp(self, xp: ModuleType): class TestPartition: - def test_basic(self, xp: ModuleType): - # Using 0-dimensional array - rng = np.random.default_rng(2847) - - for _ in range(100): - n = rng.integers(1, 1000) - x = xp.asarray(rng.random(size=n)) - k = int(rng.integers(1, n - 1)) - y = partition(x, k) - assert xp.max(y[:k]) <= xp.min(y[k:]) - y = x[argpartition(x, k)] - assert xp.max(y[:k]) <= xp.min(y[k:]) + @classmethod + def _assert_valid_partition(cls, x: Array, k: int, xp: ModuleType, axis: int = -1): + if x.ndim != 1 and axis == 0: + assert isinstance(x.shape[1], int) + for i in range(x.shape[1]): + cls._assert_valid_partition(x[:, i, ...], k, xp, axis=0) + elif x.ndim != 1: + axis = axis - 1 if axis != -1 else -1 + assert isinstance(x.shape[0], int) + for i in range(x.shape[0]): + cls._assert_valid_partition(x[i, ...], k, xp, axis=axis) + else: + if k > 0: + assert xp.max(x[:k]) <= x[k] + assert x[k] <= xp.min(x[k:]) + + @classmethod + def _partition( + cls, + x: Array, + k: int, + xp: ModuleType, # noqa: ARG003 + axis: int | None = -1, + ): + return partition(x, k, axis=axis) + + def test_1d(self, xp: ModuleType): + rng = np.random.default_rng() + for n in [2, 3, 4, 5, 7, 10, 20, 50, 100, 1_000]: + k = int(rng.integers(n)) + x = xp.asarray(rng.integers(n, size=n)) + self._assert_valid_partition(self._partition(x, k, xp), k, xp) + x = xp.asarray(rng.random(n)) + self._assert_valid_partition(self._partition(x, k, xp), k, xp) + + @pytest.mark.parametrize("ndim", [2, 3, 4, 5]) + def test_nd(self, xp: ModuleType, ndim: int): + rng = np.random.default_rng() + + for n in [2, 3, 5, 10, 20, 100]: + base_shape = [int(v) for v in rng.integers(1, 4, size=ndim)] + k = int(rng.integers(n)) + + for i in range(ndim): + shape = base_shape[:] + shape[i] = n + x = xp.asarray(rng.integers(n, size=tuple(shape))) + y = self._partition(x, k, xp, axis=i) + self._assert_valid_partition(y, k, xp, axis=i) + + +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort") +class TestArgpartition(TestPartition): + @classmethod + @override + def _partition(cls, x: Array, k: int, xp: ModuleType, axis: int | None = -1): + if is_pydata_sparse_namespace(xp): + pytest.xfail(reason="Sparse backend has no argsort") + indices = argpartition(x, k, axis=axis) + if x.ndim == 1: + return x[indices] + if not hasattr(xp, "take_along_axis"): + pytest.skip("TODO: find an alternative to take_along_axis") + return xp.take_along_axis(x, indices, axis=axis) From 45121c522627b44ba5068eaf1d7260a7f8936175 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 1 Oct 2025 21:57:08 +0200 Subject: [PATCH 03/13] Test input validation --- tests/test_funcs.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index dd1db4a2..db612acc 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1359,6 +1359,12 @@ def test_nd(self, xp: ModuleType, ndim: int): y = self._partition(x, k, xp, axis=i) self._assert_valid_partition(y, k, xp, axis=i) + def test_input_validation(self, xp: ModuleType): + with pytest.raises(TypeError): + _ = self._partition(xp.asarray(1), 1, xp) + with pytest.raises(ValueError, match="out of bounds"): + _ = self._partition(xp.asarray([1, 2]), 3, xp) + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort") class TestArgpartition(TestPartition): From 74c509ff8abb06500a26acfb087e782372c798c4 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Oct 2025 12:00:49 +0200 Subject: [PATCH 04/13] adress PR comments --- src/array_api_extra/_delegation.py | 12 +++--------- src/array_api_extra/_lib/_utils/_helpers.py | 13 +++++++++++-- tests/test_helpers.py | 3 +++ 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index d34631b6..9d32cf23 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -15,7 +15,7 @@ is_torch_namespace, ) from ._lib._utils._compat import device as get_device -from ._lib._utils._helpers import asarrays +from ._lib._utils._helpers import asarrays, eager_shape from ._lib._utils._typing import Array, DType __all__ = ["isclose", "nan_to_num", "one_hot", "pad"] @@ -364,10 +364,7 @@ def partition( raise TypeError(msg) if axis is None: return partition(xp.reshape(a, -1), kth, axis=0, xp=xp) - size = a.shape[axis] - if size is None: - msg = "Array dimensions must be known" - raise ValueError(msg) + (size,) = eager_shape(a, axis) if not (0 <= kth < size): msg = f"kth(={kth}) out of bounds [0 {size})" raise ValueError(msg) @@ -445,10 +442,7 @@ def argpartition( raise TypeError(msg) if axis is None: return partition(xp.reshape(a, -1), kth, axis=0, xp=xp) - size = a.shape[axis] - if size is None: - msg = "Array dimensions must be known" - raise ValueError(msg) + (size,) = eager_shape(a, axis) if not (0 <= kth < size): msg = f"kth(={kth}) out of bounds [0 {size})" raise ValueError(msg) diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index 6dd94a38..fbe986a1 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -250,7 +250,7 @@ def ndindex(*x: int) -> Generator[tuple[int, ...]]: yield *i, j -def eager_shape(x: Array, /) -> tuple[int, ...]: +def eager_shape(x: Array, /, axis: int | None = None) -> tuple[int, ...]: """ Return shape of an array. Raise if shape is not fully defined. @@ -258,6 +258,8 @@ def eager_shape(x: Array, /) -> tuple[int, ...]: ---------- x : Array Input array. + axis : int, optional + If provided, only returns the tuple (shape[axis],). Returns ------- @@ -265,7 +267,14 @@ def eager_shape(x: Array, /) -> tuple[int, ...]: Shape of the array. """ shape = x.shape - # Dask arrays uses non-standard NaN instead of None + if axis is not None: + s = shape[axis] + # Dask arrays uses non-standard NaN instead of None + if s is None or math.isnan(s): + msg = f"Unsupported lazy shape for axis {axis}" + raise TypeError(msg) + return (s,) + if any(s is None or math.isnan(s) for s in shape): msg = "Unsupported lazy shape" raise TypeError(msg) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 77ba8cd8..74ad3a19 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -182,11 +182,14 @@ def test_eager_shape(xp: ModuleType, library: Backend): # Lazy arrays, like Dask, have an eager shape until you slice them with # a lazy boolean mask assert eager_shape(a) == a.shape == (3,) + assert eager_shape(a, axis=0) == a.shape == (3,) b = a[a > 2] if library is Backend.DASK: with pytest.raises(TypeError, match="Unsupported lazy shape"): _ = eager_shape(b) + with pytest.raises(TypeError, match="Unsupported lazy shape"): + _ = eager_shape(b, axis=0) # FIXME can't test use case for None in the shape until we add support for # other lazy backends else: From 6efc73a12e0daab714133a840c1278029ad91310 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Oct 2025 12:32:13 +0200 Subject: [PATCH 05/13] improved tests & coverage --- src/array_api_extra/_delegation.py | 4 +- tests/test_funcs.py | 87 ++++++++++++++++++------------ 2 files changed, 55 insertions(+), 36 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 9d32cf23..37aeec63 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -363,7 +363,7 @@ def partition( msg = "`a` must be at least 1-dimensional" raise TypeError(msg) if axis is None: - return partition(xp.reshape(a, -1), kth, axis=0, xp=xp) + return partition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp) (size,) = eager_shape(a, axis) if not (0 <= kth < size): msg = f"kth(={kth}) out of bounds [0 {size})" @@ -441,7 +441,7 @@ def argpartition( msg = "`a` must be at least 1-dimensional" raise TypeError(msg) if axis is None: - return partition(xp.reshape(a, -1), kth, axis=0, xp=xp) + return argpartition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp) (size,) = eager_shape(a, axis) if not (0 <= kth < size): msg = f"kth(={kth}) out of bounds [0 {size})" diff --git a/tests/test_funcs.py b/tests/test_funcs.py index db612acc..c229e594 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1310,41 +1310,52 @@ def test_xp(self, xp: ModuleType): class TestPartition: @classmethod - def _assert_valid_partition(cls, x: Array, k: int, xp: ModuleType, axis: int = -1): - if x.ndim != 1 and axis == 0: - assert isinstance(x.shape[1], int) - for i in range(x.shape[1]): - cls._assert_valid_partition(x[:, i, ...], k, xp, axis=0) - elif x.ndim != 1: + def _assert_valid_partition( + cls, + x_np: np.ndarray | None, + k: int, + y: Array, + xp: ModuleType, + axis: int | None = -1, + ): + """ + x : input array + k : int + y : output array returned by the partition function to test + """ + if x_np is not None: + assert y.shape == np.partition(x_np, k, axis=axis).shape + if y.ndim != 1 and axis == 0: + assert isinstance(y.shape[1], int) + for i in range(y.shape[1]): + cls._assert_valid_partition(None, k, y[:, i, ...], xp, axis=0) + elif y.ndim != 1: + assert axis is not None axis = axis - 1 if axis != -1 else -1 - assert isinstance(x.shape[0], int) - for i in range(x.shape[0]): - cls._assert_valid_partition(x[i, ...], k, xp, axis=axis) + assert isinstance(y.shape[0], int) + for i in range(y.shape[0]): + cls._assert_valid_partition(None, k, y[i, ...], xp, axis=axis) else: if k > 0: - assert xp.max(x[:k]) <= x[k] - assert x[k] <= xp.min(x[k:]) + assert xp.max(y[:k]) <= y[k] + assert y[k] <= xp.min(y[k:]) @classmethod - def _partition( - cls, - x: Array, - k: int, - xp: ModuleType, # noqa: ARG003 - axis: int | None = -1, - ): - return partition(x, k, axis=axis) + def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1): + return partition(xp.asarray(x), k, axis=axis) def test_1d(self, xp: ModuleType): rng = np.random.default_rng() for n in [2, 3, 4, 5, 7, 10, 20, 50, 100, 1_000]: k = int(rng.integers(n)) - x = xp.asarray(rng.integers(n, size=n)) - self._assert_valid_partition(self._partition(x, k, xp), k, xp) - x = xp.asarray(rng.random(n)) - self._assert_valid_partition(self._partition(x, k, xp), k, xp) - - @pytest.mark.parametrize("ndim", [2, 3, 4, 5]) + x1 = rng.integers(n, size=n) + y = self._partition(x1, k, xp) + self._assert_valid_partition(x1, k, y, xp) + x2 = rng.random(n) + y = self._partition(x2, k, xp) + self._assert_valid_partition(x2, k, y, xp) + + @pytest.mark.parametrize("ndim", [2, 3, 4]) def test_nd(self, xp: ModuleType, ndim: int): rng = np.random.default_rng() @@ -1355,27 +1366,35 @@ def test_nd(self, xp: ModuleType, ndim: int): for i in range(ndim): shape = base_shape[:] shape[i] = n - x = xp.asarray(rng.integers(n, size=tuple(shape))) + x = rng.integers(n, size=tuple(shape)) y = self._partition(x, k, xp, axis=i) - self._assert_valid_partition(y, k, xp, axis=i) + self._assert_valid_partition(x, k, y, xp, axis=i) + + z = rng.random(tuple(base_shape)) + k = int(rng.integers(z.size)) + y = self._partition(z, k, xp, axis=None) + self._assert_valid_partition(z, k, y, xp, axis=None) def test_input_validation(self, xp: ModuleType): with pytest.raises(TypeError): - _ = self._partition(xp.asarray(1), 1, xp) + _ = self._partition(np.asarray(1), 1, xp) with pytest.raises(ValueError, match="out of bounds"): - _ = self._partition(xp.asarray([1, 2]), 3, xp) + _ = self._partition(np.asarray([1, 2]), 3, xp) @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort") class TestArgpartition(TestPartition): @classmethod @override - def _partition(cls, x: Array, k: int, xp: ModuleType, axis: int | None = -1): + def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1): if is_pydata_sparse_namespace(xp): pytest.xfail(reason="Sparse backend has no argsort") - indices = argpartition(x, k, axis=axis) - if x.ndim == 1: - return x[indices] + arr = xp.asarray(x) + indices = argpartition(arr, k, axis=axis) + if axis is None: + arr = xp.reshape(arr, shape=(-1,)) + if arr.ndim == 1: + return arr[indices] if not hasattr(xp, "take_along_axis"): pytest.skip("TODO: find an alternative to take_along_axis") - return xp.take_along_axis(x, indices, axis=axis) + return xp.take_along_axis(arr, indices, axis=axis) From 6e69083077f238277419a9365f00bbcc73f82d87 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Oct 2025 13:23:14 +0200 Subject: [PATCH 06/13] fix xfail thingy --- src/array_api_extra/_delegation.py | 3 +++ tests/test_funcs.py | 35 ++++++++++++++++++++++-------- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 37aeec63..14ddb094 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -437,6 +437,9 @@ def argpartition( # Validate inputs. if xp is None: xp = array_namespace(a) + if is_pydata_sparse_namespace(xp): + msg = "Not implemented for sparse backend" + raise NotImplementedError(msg) if a.ndim < 1: msg = "`a` must be at least 1-dimensional" raise TypeError(msg) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index c229e594..ca8b19fc 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -36,9 +36,6 @@ from array_api_extra._lib._utils._compat import ( device as get_device, ) -from array_api_extra._lib._utils._compat import ( - is_pydata_sparse_namespace, -) from array_api_extra._lib._utils._helpers import eager_shape, ndindex from array_api_extra._lib._utils._typing import Array, Device from array_api_extra.testing import lazy_xp_function @@ -1344,7 +1341,7 @@ def _assert_valid_partition( def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1): return partition(xp.asarray(x), k, axis=axis) - def test_1d(self, xp: ModuleType): + def _test_1d(self, xp: ModuleType): rng = np.random.default_rng() for n in [2, 3, 4, 5, 7, 10, 20, 50, 100, 1_000]: k = int(rng.integers(n)) @@ -1355,8 +1352,7 @@ def test_1d(self, xp: ModuleType): y = self._partition(x2, k, xp) self._assert_valid_partition(x2, k, y, xp) - @pytest.mark.parametrize("ndim", [2, 3, 4]) - def test_nd(self, xp: ModuleType, ndim: int): + def _test_nd(self, xp: ModuleType, ndim: int): rng = np.random.default_rng() for n in [2, 3, 5, 10, 20, 100]: @@ -1375,20 +1371,28 @@ def test_nd(self, xp: ModuleType, ndim: int): y = self._partition(z, k, xp, axis=None) self._assert_valid_partition(z, k, y, xp, axis=None) - def test_input_validation(self, xp: ModuleType): + def _test_input_validation(self, xp: ModuleType): with pytest.raises(TypeError): _ = self._partition(np.asarray(1), 1, xp) with pytest.raises(ValueError, match="out of bounds"): _ = self._partition(np.asarray([1, 2]), 3, xp) + def test_1d(self, xp: ModuleType): + self._test_1d(xp) + + @pytest.mark.parametrize("ndim", [2, 3, 4]) + def test_nd(self, xp: ModuleType, ndim: int): + self._test_nd(xp, ndim) + + def test_input_validation(self, xp: ModuleType): + self._test_input_validation(xp) + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no argsort") class TestArgpartition(TestPartition): @classmethod @override def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1): - if is_pydata_sparse_namespace(xp): - pytest.xfail(reason="Sparse backend has no argsort") arr = xp.asarray(x) indices = argpartition(arr, k, axis=axis) if axis is None: @@ -1398,3 +1402,16 @@ def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1 if not hasattr(xp, "take_along_axis"): pytest.skip("TODO: find an alternative to take_along_axis") return xp.take_along_axis(arr, indices, axis=axis) + + @override + def test_1d(self, xp: ModuleType): + self._test_1d(xp) + + @pytest.mark.parametrize("ndim", [2, 3, 4]) + @override + def test_nd(self, xp: ModuleType, ndim: int): + self._test_nd(xp, ndim) + + @override + def test_input_validation(self, xp: ModuleType): + self._test_input_validation(xp) From 259f93d5881afeffe943f35947d6fef99f4e24e1 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Oct 2025 14:03:35 +0200 Subject: [PATCH 07/13] add in err msg: no argsort --- src/array_api_extra/_delegation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 14ddb094..93f9f4b3 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -438,7 +438,7 @@ def argpartition( if xp is None: xp = array_namespace(a) if is_pydata_sparse_namespace(xp): - msg = "Not implemented for sparse backend" + msg = "Not implemented for sparse backend: no argsort" raise NotImplementedError(msg) if a.ndim < 1: msg = "`a` must be at least 1-dimensional" From c88e93e5f114e3955ebacebcce6e7efccf330171 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Oct 2025 17:31:19 +0200 Subject: [PATCH 08/13] Docstring improvements --- src/array_api_extra/_delegation.py | 32 +++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 93f9f4b3..470c6021 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -339,22 +339,33 @@ def partition( """ Return a partitioned copy of an array. - Parameters + Creates a copy of the array and partially sorts it in such a way that the value + of the element in k-th position is in the position it would be in a sorted array. + In the output array, all elements smaller than the k-th element are located to + the left of this element and all equal or greater are located to its right. + The ordering of the elements in the two partitions on the either side of + the k-th element in the output array is undefined. + ---------- - a : 1-dimensional array + a : Array Input array. kth : int Element index to partition by. axis : int, optional - Axis along which to partition. The default is -1 (the last axis). - If None, the flattened array is used. + Axis along which to partition. The default is ``-1`` (the last axis). + If ``None``, the flattened array is used. xp : array_namespace, optional The standard-compatible namespace for `x`. Default: infer. Returns ------- partitioned_array - Array of the same type and shape as a. + Array of the same type and shape as `a`. + + Notes: + If `xp` implements `partition` or an equivalent method (e.g. topk for torch), + complexity will likely be O(n). + If not, this function simply calls `xp.sort` and complexity is O(n log n). """ # Validate inputs. if xp is None: @@ -416,6 +427,8 @@ def argpartition( ) -> Array: """ Perform an indirect partition along the given axis. + It returns an array of indices of the same shape as `a` that + index data along the given axis in partitioned order. Parameters ---------- @@ -424,8 +437,8 @@ def argpartition( kth : int Element index to partition by. axis : int, optional - Axis along which to partition. The default is -1 (the last axis). - If None, the flattened array is used. + Axis along which to partition. The default is ``-1`` (the last axis). + If ``None``, the flattened array is used. xp : array_namespace, optional The standard-compatible namespace for `x`. Default: infer. @@ -433,6 +446,11 @@ def argpartition( ------- index_array Array of indices that partition `a` along the specified axis. + + Notes: + If `xp` implements `argpartition` or an equivalent method (e.g. topk for torch), + complexity will likely be O(n). + If not, this function simply calls `xp.argsort` and complexity is O(n log n). """ # Validate inputs. if xp is None: From 579b3bcea3cfeb90fda99126622eb1fa42d9b6de Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Oct 2025 17:35:30 +0200 Subject: [PATCH 09/13] rewrite of the torch logic --- src/array_api_extra/_delegation.py | 87 ++++++++++++++++++------------ 1 file changed, 52 insertions(+), 35 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 470c6021..6069628a 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -346,6 +346,12 @@ def partition( The ordering of the elements in the two partitions on the either side of the k-th element in the output array is undefined. + Notes: + If `xp` implements `partition` or an equivalent method (e.g. topk for torch), + complexity will likely be O(n). + If not, this function simply calls `xp.sort` and complexity is O(n log n). + + Parameters ---------- a : Array Input array. @@ -361,11 +367,6 @@ def partition( ------- partitioned_array Array of the same type and shape as `a`. - - Notes: - If `xp` implements `partition` or an equivalent method (e.g. topk for torch), - complexity will likely be O(n). - If not, this function simply calls `xp.sort` and complexity is O(n log n). """ # Validate inputs. if xp is None: @@ -389,26 +390,32 @@ def partition( if not (axis == -1 or axis == a.ndim - 1): a = xp.transpose(a, axis, -1) - # Get smallest `kth` elements along axis - kth += 1 # HACK: we use a non-specified behavior of torch.topk: - # in `a_left`, the element in the last position is the max - a_left, indices = xp.topk(a, kth, dim=-1, largest=False, sorted=False) + out = xp.empty_like(a) + ranks = xp.arange(a.shape[-1]).expand_as(a) + + split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True) + del indices - # Build a mask to remove the selected elements - mask_right = xp.ones(a.shape, dtype=bool) - mask_right.scatter_(dim=-1, index=indices, value=False) + # fill the left-side of the partition + mask_src = a < split_value + n_left = mask_src.sum(dim=-1, keepdim=True) + mask_dest = ranks < n_left + out[mask_dest] = a[mask_src] - # Remaining elements along axis - a_right = a[mask_right] # 1-d array + # fill the middle of the partition + mask_src = a == split_value + n_left += mask_src.sum(dim=-1, keepdim=True) + mask_dest ^= ranks < n_left + out[mask_dest] = a[mask_src] - # Reshape. This is valid only because we work on the last axis - a_right = xp.reshape(a_right, shape=(*a.shape[:-1], -1)) + # fill the right-side of the partition + mask_src = a > split_value + mask_dest = ranks >= n_left + out[mask_dest] = a[mask_src] - # Concatenate the two parts along axis - partitioned_array = xp.cat((a_left, a_right), dim=-1) if not (axis == -1 or axis == a.ndim - 1): - partitioned_array = xp.transpose(partitioned_array, axis, -1) - return partitioned_array + out = xp.transpose(out, axis, -1) + return out # Note: dask topk/argtopk sort the return values, so it's # not much more efficient than sorting everything when @@ -427,9 +434,15 @@ def argpartition( ) -> Array: """ Perform an indirect partition along the given axis. + It returns an array of indices of the same shape as `a` that index data along the given axis in partitioned order. + Notes: + If `xp` implements `argpartition` or an equivalent method (e.g. topk for torch), + complexity will likely be O(n). + If not, this function simply calls `xp.argsort` and complexity is O(n log n). + Parameters ---------- a : Array @@ -446,11 +459,6 @@ def argpartition( ------- index_array Array of indices that partition `a` along the specified axis. - - Notes: - If `xp` implements `argpartition` or an equivalent method (e.g. topk for torch), - complexity will likely be O(n). - If not, this function simply calls `xp.argsort` and complexity is O(n log n). """ # Validate inputs. if xp is None: @@ -478,20 +486,29 @@ def argpartition( if not (axis == -1 or axis == a.ndim - 1): a = xp.transpose(a, axis, -1) - kth += 1 # HACK - _, indices_left = xp.topk(a, kth, dim=-1, largest=False, sorted=False) + ranks = xp.arange(a.shape[-1]).expand_as(a) + out = xp.empty_like(ranks) + + split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True) + del indices + + mask_src = a < split_value + n_left = mask_src.sum(dim=-1, keepdim=True) + mask_dest = ranks < n_left + out[mask_dest] = ranks[mask_src] - mask_right = xp.ones(a.shape, dtype=bool) - mask_right.scatter_(dim=-1, index=indices_left, value=False) + mask_src = a == split_value + n_left += mask_src.sum(dim=-1, keepdim=True) + mask_dest ^= ranks < n_left + out[mask_dest] = ranks[mask_src] - indices_right = xp.nonzero(mask_right)[-1] - indices_right = xp.reshape(indices_right, shape=(*a.shape[:-1], -1)) + mask_src = a > split_value + mask_dest = ranks >= n_left + out[mask_dest] = ranks[mask_src] - # Concatenate the two parts along axis - index_array = xp.cat((indices_left, indices_right), dim=-1) if not (axis == -1 or axis == a.ndim - 1): - index_array = xp.transpose(index_array, axis, -1) - return index_array + out = xp.transpose(out, axis, -1) + return out # Note: dask topk/argtopk sort the return values, so it's # not much more efficient than sorting everything when From c2827da81ee64d2a3474ea6486da98bd34bd9d4c Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Oct 2025 17:47:52 +0200 Subject: [PATCH 10/13] dask support in argpart tests --- tests/test_funcs.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index ca8b19fc..1ab146cf 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1316,7 +1316,7 @@ def _assert_valid_partition( axis: int | None = -1, ): """ - x : input array + x_np : input array k : int y : output array returned by the partition function to test """ @@ -1397,11 +1397,31 @@ def _partition(cls, x: np.ndarray, k: int, xp: ModuleType, axis: int | None = -1 indices = argpartition(arr, k, axis=axis) if axis is None: arr = xp.reshape(arr, shape=(-1,)) + return arr[indices] + if arr.ndim == 1: + return arr[indices] + return cls._take_along_axis(arr, indices, axis=axis, xp=xp) + + @classmethod + def _take_along_axis(cls, arr: Array, indices: Array, axis: int, xp: ModuleType): + if hasattr(xp, "take_along_axis"): + return xp.take_along_axis(arr, indices, axis=axis) if arr.ndim == 1: return arr[indices] - if not hasattr(xp, "take_along_axis"): - pytest.skip("TODO: find an alternative to take_along_axis") - return xp.take_along_axis(arr, indices, axis=axis) + if axis == 0: + assert isinstance(arr.shape[1], int) + arrs = [] + for i in range(arr.shape[1]): + arrs.append(cls._take_along_axis(arr[:, i, ...], indices[:, i, ...], + axis=0, xp=xp)) + return xp.stack(arrs, axis=1) + axis = axis - 1 if axis != -1 else -1 + assert isinstance(arr.shape[0], int) + arrs = [] + for i in range(arr.shape[0]): + arrs.append(cls._take_along_axis(arr[i, ...], indices[i, ...], + axis=axis, xp=xp)) + return xp.stack(arrs, axis=0) @override def test_1d(self, xp: ModuleType): From b2a567e6fc0525b9d7fdb158f9d8dbcdc278fcaf Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Fri, 3 Oct 2025 16:36:42 +0200 Subject: [PATCH 11/13] Fix docstring format Co-authored-by: Lucas Colley --- src/array_api_extra/_delegation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 807ecf35..66022e5c 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -529,9 +529,9 @@ def partition( the k-th element in the output array is undefined. Notes: - If `xp` implements `partition` or an equivalent method (e.g. topk for torch), + If `xp` implements ``partition`` or an equivalent function (e.g. ``topk`` for torch), complexity will likely be O(n). - If not, this function simply calls `xp.sort` and complexity is O(n log n). + If not, this function simply calls ``xp.sort`` and complexity is O(n log n). Parameters ---------- From ac19a2393e94462f0f0209d85e272efe18fbbb38 Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 3 Oct 2025 16:50:24 +0200 Subject: [PATCH 12/13] numpy doc: notes format --- src/array_api_extra/_delegation.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 66022e5c..dc333cc7 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -528,11 +528,6 @@ def partition( The ordering of the elements in the two partitions on the either side of the k-th element in the output array is undefined. - Notes: - If `xp` implements ``partition`` or an equivalent function (e.g. ``topk`` for torch), - complexity will likely be O(n). - If not, this function simply calls ``xp.sort`` and complexity is O(n log n). - Parameters ---------- a : Array @@ -549,6 +544,12 @@ def partition( ------- partitioned_array Array of the same type and shape as `a`. + + Notes + ----- + If `xp` implements ``partition`` or an equivalent function + (e.g. ``topk`` for torch), complexity will likely be O(n). + If not, this function simply calls ``xp.sort`` and complexity is O(n log n). """ # Validate inputs. if xp is None: @@ -620,11 +621,6 @@ def argpartition( It returns an array of indices of the same shape as `a` that index data along the given axis in partitioned order. - Notes: - If `xp` implements `argpartition` or an equivalent method (e.g. topk for torch), - complexity will likely be O(n). - If not, this function simply calls `xp.argsort` and complexity is O(n log n). - Parameters ---------- a : Array @@ -641,6 +637,12 @@ def argpartition( ------- index_array Array of indices that partition `a` along the specified axis. + + Notes + ----- + If `xp` implements ``argpartition`` or an equivalent function + e.g. ``topk`` for torch), complexity will likely be O(n). + If not, this function simply calls ``xp.argsort`` and complexity is O(n log n). """ # Validate inputs. if xp is None: From 1084835a6d7ac9ceb1f5cb3b66a029b086c3daad Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 3 Oct 2025 16:57:57 +0200 Subject: [PATCH 13/13] comment about peak memory --- src/array_api_extra/_delegation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 5710ff18..7f467366 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -714,7 +714,7 @@ def partition( ranks = xp.arange(a.shape[-1]).expand_as(a) split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True) - del indices + del indices # indices won't be used => del ASAP to reduce peak memory usage # fill the left-side of the partition mask_src = a < split_value @@ -811,7 +811,7 @@ def argpartition( out = xp.empty_like(ranks) split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True) - del indices + del indices # indices won't be used => del ASAP to reduce peak memory usage mask_src = a < split_value n_left = mask_src.sum(dim=-1, keepdim=True)