Skip to content

Commit 6ba1e87

Browse files
ENH: add partition and argpartition functions (#449)
Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>
1 parent 9ac3e41 commit 6ba1e87

File tree

6 files changed

+376
-4
lines changed

6 files changed

+376
-4
lines changed

src/array_api_extra/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Extra array functions built on top of the array API standard."""
22

33
from ._delegation import (
4+
argpartition,
45
atleast_nd,
56
cov,
67
expand_dims,
78
isclose,
89
nan_to_num,
910
one_hot,
1011
pad,
12+
partition,
1113
sinc,
1214
)
1315
from ._lib._at import at
@@ -28,6 +30,7 @@
2830
__all__ = [
2931
"__version__",
3032
"apply_where",
33+
"argpartition",
3134
"at",
3235
"atleast_nd",
3336
"broadcast_shapes",
@@ -42,6 +45,7 @@
4245
"nunique",
4346
"one_hot",
4447
"pad",
48+
"partition",
4549
"setdiff1d",
4650
"sinc",
4751
]

src/array_api_extra/_delegation.py

Lines changed: 192 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
is_torch_namespace,
1616
)
1717
from ._lib._utils._compat import device as get_device
18-
from ._lib._utils._helpers import asarrays
18+
from ._lib._utils._helpers import asarrays, eager_shape
1919
from ._lib._utils._typing import Array, DType
2020

2121
__all__ = [
@@ -645,3 +645,194 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
645645
return xp.sinc(x)
646646

647647
return _funcs.sinc(x, xp=xp)
648+
649+
650+
def partition(
651+
a: Array,
652+
kth: int,
653+
/,
654+
axis: int | None = -1,
655+
*,
656+
xp: ModuleType | None = None,
657+
) -> Array:
658+
"""
659+
Return a partitioned copy of an array.
660+
661+
Creates a copy of the array and partially sorts it in such a way that the value
662+
of the element in k-th position is in the position it would be in a sorted array.
663+
In the output array, all elements smaller than the k-th element are located to
664+
the left of this element and all equal or greater are located to its right.
665+
The ordering of the elements in the two partitions on the either side of
666+
the k-th element in the output array is undefined.
667+
668+
Parameters
669+
----------
670+
a : Array
671+
Input array.
672+
kth : int
673+
Element index to partition by.
674+
axis : int, optional
675+
Axis along which to partition. The default is ``-1`` (the last axis).
676+
If ``None``, the flattened array is used.
677+
xp : array_namespace, optional
678+
The standard-compatible namespace for `x`. Default: infer.
679+
680+
Returns
681+
-------
682+
partitioned_array
683+
Array of the same type and shape as `a`.
684+
685+
Notes
686+
-----
687+
If `xp` implements ``partition`` or an equivalent function
688+
(e.g. ``topk`` for torch), complexity will likely be O(n).
689+
If not, this function simply calls ``xp.sort`` and complexity is O(n log n).
690+
"""
691+
# Validate inputs.
692+
if xp is None:
693+
xp = array_namespace(a)
694+
if a.ndim < 1:
695+
msg = "`a` must be at least 1-dimensional"
696+
raise TypeError(msg)
697+
if axis is None:
698+
return partition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
699+
(size,) = eager_shape(a, axis)
700+
if not (0 <= kth < size):
701+
msg = f"kth(={kth}) out of bounds [0 {size})"
702+
raise ValueError(msg)
703+
704+
# Delegate where possible.
705+
if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
706+
return xp.partition(a, kth, axis=axis)
707+
708+
# Use top-k when possible:
709+
if is_torch_namespace(xp):
710+
if not (axis == -1 or axis == a.ndim - 1):
711+
a = xp.transpose(a, axis, -1)
712+
713+
out = xp.empty_like(a)
714+
ranks = xp.arange(a.shape[-1]).expand_as(a)
715+
716+
split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True)
717+
del indices # indices won't be used => del ASAP to reduce peak memory usage
718+
719+
# fill the left-side of the partition
720+
mask_src = a < split_value
721+
n_left = mask_src.sum(dim=-1, keepdim=True)
722+
mask_dest = ranks < n_left
723+
out[mask_dest] = a[mask_src]
724+
725+
# fill the middle of the partition
726+
mask_src = a == split_value
727+
n_left += mask_src.sum(dim=-1, keepdim=True)
728+
mask_dest ^= ranks < n_left
729+
out[mask_dest] = a[mask_src]
730+
731+
# fill the right-side of the partition
732+
mask_src = a > split_value
733+
mask_dest = ranks >= n_left
734+
out[mask_dest] = a[mask_src]
735+
736+
if not (axis == -1 or axis == a.ndim - 1):
737+
out = xp.transpose(out, axis, -1)
738+
return out
739+
740+
# Note: dask topk/argtopk sort the return values, so it's
741+
# not much more efficient than sorting everything when
742+
# kth is not small compared to x.size
743+
744+
return _funcs.partition(a, kth, axis=axis, xp=xp)
745+
746+
747+
def argpartition(
748+
a: Array,
749+
kth: int,
750+
/,
751+
axis: int | None = -1,
752+
*,
753+
xp: ModuleType | None = None,
754+
) -> Array:
755+
"""
756+
Perform an indirect partition along the given axis.
757+
758+
It returns an array of indices of the same shape as `a` that
759+
index data along the given axis in partitioned order.
760+
761+
Parameters
762+
----------
763+
a : Array
764+
Input array.
765+
kth : int
766+
Element index to partition by.
767+
axis : int, optional
768+
Axis along which to partition. The default is ``-1`` (the last axis).
769+
If ``None``, the flattened array is used.
770+
xp : array_namespace, optional
771+
The standard-compatible namespace for `x`. Default: infer.
772+
773+
Returns
774+
-------
775+
index_array
776+
Array of indices that partition `a` along the specified axis.
777+
778+
Notes
779+
-----
780+
If `xp` implements ``argpartition`` or an equivalent function
781+
e.g. ``topk`` for torch), complexity will likely be O(n).
782+
If not, this function simply calls ``xp.argsort`` and complexity is O(n log n).
783+
"""
784+
# Validate inputs.
785+
if xp is None:
786+
xp = array_namespace(a)
787+
if is_pydata_sparse_namespace(xp):
788+
msg = "Not implemented for sparse backend: no argsort"
789+
raise NotImplementedError(msg)
790+
if a.ndim < 1:
791+
msg = "`a` must be at least 1-dimensional"
792+
raise TypeError(msg)
793+
if axis is None:
794+
return argpartition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
795+
(size,) = eager_shape(a, axis)
796+
if not (0 <= kth < size):
797+
msg = f"kth(={kth}) out of bounds [0 {size})"
798+
raise ValueError(msg)
799+
800+
# Delegate where possible.
801+
if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
802+
return xp.argpartition(a, kth, axis=axis)
803+
804+
# Use top-k when possible:
805+
if is_torch_namespace(xp):
806+
# see `partition` above for commented details of those steps:
807+
if not (axis == -1 or axis == a.ndim - 1):
808+
a = xp.transpose(a, axis, -1)
809+
810+
ranks = xp.arange(a.shape[-1]).expand_as(a)
811+
out = xp.empty_like(ranks)
812+
813+
split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True)
814+
del indices # indices won't be used => del ASAP to reduce peak memory usage
815+
816+
mask_src = a < split_value
817+
n_left = mask_src.sum(dim=-1, keepdim=True)
818+
mask_dest = ranks < n_left
819+
out[mask_dest] = ranks[mask_src]
820+
821+
mask_src = a == split_value
822+
n_left += mask_src.sum(dim=-1, keepdim=True)
823+
mask_dest ^= ranks < n_left
824+
out[mask_dest] = ranks[mask_src]
825+
826+
mask_src = a > split_value
827+
mask_dest = ranks >= n_left
828+
out[mask_dest] = ranks[mask_src]
829+
830+
if not (axis == -1 or axis == a.ndim - 1):
831+
out = xp.transpose(out, axis, -1)
832+
return out
833+
834+
# Note: dask topk/argtopk sort the return values, so it's
835+
# not much more efficient than sorting everything when
836+
# kth is not small compared to x.size
837+
838+
return _funcs.argpartition(a, kth, axis=axis, xp=xp)

src/array_api_extra/_lib/_funcs.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,3 +777,27 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
777777
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
778778
)
779779
return xp.sin(y) / y
780+
781+
782+
def partition( # numpydoc ignore=PR01,RT01
783+
x: Array,
784+
kth: int, # noqa: ARG001
785+
/,
786+
axis: int = -1,
787+
*,
788+
xp: ModuleType,
789+
) -> Array:
790+
"""See docstring in `array_api_extra._delegation.py`."""
791+
return xp.sort(x, axis=axis, stable=False)
792+
793+
794+
def argpartition( # numpydoc ignore=PR01,RT01
795+
x: Array,
796+
kth: int, # noqa: ARG001
797+
/,
798+
axis: int = -1,
799+
*,
800+
xp: ModuleType,
801+
) -> Array:
802+
"""See docstring in `array_api_extra._delegation.py`."""
803+
return xp.argsort(x, axis=axis, stable=False)

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,22 +250,31 @@ def ndindex(*x: int) -> Generator[tuple[int, ...]]:
250250
yield *i, j
251251

252252

253-
def eager_shape(x: Array, /) -> tuple[int, ...]:
253+
def eager_shape(x: Array, /, axis: int | None = None) -> tuple[int, ...]:
254254
"""
255255
Return shape of an array. Raise if shape is not fully defined.
256256
257257
Parameters
258258
----------
259259
x : Array
260260
Input array.
261+
axis : int, optional
262+
If provided, only returns the tuple (shape[axis],).
261263
262264
Returns
263265
-------
264266
tuple[int, ...]
265267
Shape of the array.
266268
"""
267269
shape = x.shape
268-
# Dask arrays uses non-standard NaN instead of None
270+
if axis is not None:
271+
s = shape[axis]
272+
# Dask arrays uses non-standard NaN instead of None
273+
if s is None or math.isnan(s):
274+
msg = f"Unsupported lazy shape for axis {axis}"
275+
raise TypeError(msg)
276+
return (s,)
277+
269278
if any(s is None or math.isnan(s) for s in shape):
270279
msg = "Unsupported lazy shape"
271280
raise TypeError(msg)

0 commit comments

Comments
 (0)