Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import (
argpartition,
atleast_nd,
cov,
expand_dims,
isclose,
nan_to_num,
one_hot,
pad,
partition,
sinc,
)
from ._lib._at import at
Expand All @@ -28,6 +30,7 @@
__all__ = [
"__version__",
"apply_where",
"argpartition",
"at",
"atleast_nd",
"broadcast_shapes",
Expand All @@ -42,6 +45,7 @@
"nunique",
"one_hot",
"pad",
"partition",
"setdiff1d",
"sinc",
]
193 changes: 192 additions & 1 deletion src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
is_torch_namespace,
)
from ._lib._utils._compat import device as get_device
from ._lib._utils._helpers import asarrays
from ._lib._utils._helpers import asarrays, eager_shape
from ._lib._utils._typing import Array, DType

__all__ = [
Expand Down Expand Up @@ -645,3 +645,194 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
return xp.sinc(x)

return _funcs.sinc(x, xp=xp)


def partition(
    a: Array,
    kth: int,
    /,
    axis: int | None = -1,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Return a partitioned copy of an array.

    Creates a copy of the array and partially sorts it in such a way that the value
    of the element in k-th position is in the position it would be in a sorted array.
    In the output array, all elements smaller than the k-th element are located to
    the left of this element and all equal or greater are located to its right.
    The ordering of the elements in the two partitions on either side of
    the k-th element in the output array is undefined.

    Parameters
    ----------
    a : Array
        Input array. Must be at least 1-dimensional.
    kth : int
        Element index to partition by. Must satisfy ``0 <= kth < size``, where
        ``size`` is the length of `a` along `axis`.
    axis : int, optional
        Axis along which to partition. The default is ``-1`` (the last axis).
        If ``None``, the flattened array is used.
    xp : array_namespace, optional
        The standard-compatible namespace for `a`. Default: infer.

    Returns
    -------
    partitioned_array
        Array of the same type and shape as `a`.

    Raises
    ------
    TypeError
        If `a` has fewer than one dimension.
    ValueError
        If `kth` is outside ``[0, size)``.

    Notes
    -----
    If `xp` implements ``partition`` or an equivalent function
    (e.g. ``kthvalue`` for torch), complexity will likely be O(n).
    If not, this function simply calls ``xp.sort`` and complexity is O(n log n).
    """
    # Validate inputs.
    if xp is None:
        xp = array_namespace(a)
    if a.ndim < 1:
        msg = "`a` must be at least 1-dimensional"
        raise TypeError(msg)
    if axis is None:
        # Flatten, then partition the resulting 1-D array along its only axis.
        return partition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
    (size,) = eager_shape(a, axis)
    if not (0 <= kth < size):
        msg = f"kth(={kth}) out of bounds [0 {size})"
        raise ValueError(msg)

    # Delegate where possible.
    if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
        return xp.partition(a, kth, axis=axis)

    # Torch has no `partition`; build one from `kthvalue` and boolean masks.
    if is_torch_namespace(xp):
        # The algorithm below works along the last axis, so move `axis` there
        # first and swap it back at the end.
        needs_swap = not (axis == -1 or axis == a.ndim - 1)
        if needs_swap:
            a = xp.transpose(a, axis, -1)

        out = xp.empty_like(a)
        # Position of every element along the last axis. It must live on the
        # same device as `a`, otherwise the comparisons and masked writes
        # below fail for tensors that are not on the default device.
        ranks = xp.arange(a.shape[-1], device=a.device).expand_as(a)

        split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True)
        del indices  # indices won't be used => del ASAP to reduce peak memory usage

        # NOTE(review): a NaN entry satisfies none of `<`, `==`, `>` below, so
        # inputs containing NaN are presumably unsupported on this path - confirm.

        # fill the left-side of the partition
        mask_src = a < split_value
        n_left = mask_src.sum(dim=-1, keepdim=True)
        mask_dest = ranks < n_left
        out[mask_dest] = a[mask_src]

        # fill the middle of the partition
        mask_src = a == split_value
        n_left += mask_src.sum(dim=-1, keepdim=True)
        # XOR with the previous mask keeps only the newly covered positions,
        # i.e. those between the old and the new value of `n_left`.
        mask_dest ^= ranks < n_left
        out[mask_dest] = a[mask_src]

        # fill the right-side of the partition
        mask_src = a > split_value
        mask_dest = ranks >= n_left
        out[mask_dest] = a[mask_src]

        if needs_swap:
            out = xp.transpose(out, axis, -1)
        return out

    # Note: dask topk/argtopk sort the return values, so it's
    # not much more efficient than sorting everything when
    # kth is not small compared to x.size

    return _funcs.partition(a, kth, axis=axis, xp=xp)


def argpartition(
    a: Array,
    kth: int,
    /,
    axis: int | None = -1,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Perform an indirect partition along the given axis.

    It returns an array of indices of the same shape as `a` that
    index data along the given axis in partitioned order.

    Parameters
    ----------
    a : Array
        Input array. Must be at least 1-dimensional.
    kth : int
        Element index to partition by. Must satisfy ``0 <= kth < size``, where
        ``size`` is the length of `a` along `axis`.
    axis : int, optional
        Axis along which to partition. The default is ``-1`` (the last axis).
        If ``None``, the flattened array is used.
    xp : array_namespace, optional
        The standard-compatible namespace for `a`. Default: infer.

    Returns
    -------
    index_array
        Array of indices that partition `a` along the specified axis.

    Raises
    ------
    NotImplementedError
        If `xp` is the pydata sparse namespace.
    TypeError
        If `a` has fewer than one dimension.
    ValueError
        If `kth` is outside ``[0, size)``.

    Notes
    -----
    If `xp` implements ``argpartition`` or an equivalent function
    (e.g. ``kthvalue`` for torch), complexity will likely be O(n).
    If not, this function simply calls ``xp.argsort`` and complexity is O(n log n).
    """
    # Validate inputs.
    if xp is None:
        xp = array_namespace(a)
    if is_pydata_sparse_namespace(xp):
        msg = "Not implemented for sparse backend: no argsort"
        raise NotImplementedError(msg)
    if a.ndim < 1:
        msg = "`a` must be at least 1-dimensional"
        raise TypeError(msg)
    if axis is None:
        # Flatten, then partition the resulting 1-D array along its only axis.
        return argpartition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
    (size,) = eager_shape(a, axis)
    if not (0 <= kth < size):
        msg = f"kth(={kth}) out of bounds [0 {size})"
        raise ValueError(msg)

    # Delegate where possible.
    if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
        return xp.argpartition(a, kth, axis=axis)

    # Torch has no `argpartition`; build one from `kthvalue` and boolean masks.
    if is_torch_namespace(xp):
        # Work along the last axis; swap `axis` there now and back at the end.
        needs_swap = not (axis == -1 or axis == a.ndim - 1)
        if needs_swap:
            a = xp.transpose(a, axis, -1)

        # Position of every element along the last axis; these are the values
        # scattered into the result. Created on the device of `a` so that the
        # comparisons and masked writes below do not mix devices.
        ranks = xp.arange(a.shape[-1], device=a.device).expand_as(a)
        out = xp.empty_like(ranks)

        split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True)
        del indices  # indices won't be used => del ASAP to reduce peak memory usage

        # NOTE(review): a NaN entry satisfies none of `<`, `==`, `>` below, so
        # inputs containing NaN are presumably unsupported on this path - confirm.

        # indices of the elements strictly below the k-th value go first
        mask_src = a < split_value
        n_left = mask_src.sum(dim=-1, keepdim=True)
        mask_dest = ranks < n_left
        out[mask_dest] = ranks[mask_src]

        # then indices of the elements equal to the k-th value
        mask_src = a == split_value
        n_left += mask_src.sum(dim=-1, keepdim=True)
        # XOR with the previous mask keeps only the newly covered positions,
        # i.e. those between the old and the new value of `n_left`.
        mask_dest ^= ranks < n_left
        out[mask_dest] = ranks[mask_src]

        # finally indices of the elements strictly above the k-th value
        mask_src = a > split_value
        mask_dest = ranks >= n_left
        out[mask_dest] = ranks[mask_src]

        if needs_swap:
            out = xp.transpose(out, axis, -1)
        return out

    # Note: dask topk/argtopk sort the return values, so it's
    # not much more efficient than sorting everything when
    # kth is not small compared to x.size

    return _funcs.argpartition(a, kth, axis=axis, xp=xp)
24 changes: 24 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,3 +777,27 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
)
return xp.sin(y) / y


def partition(  # numpydoc ignore=PR01,RT01
    x: Array,
    kth: int,  # noqa: ARG001
    /,
    axis: int = -1,
    *,
    xp: ModuleType,
) -> Array:
    """
    Sort-based fallback; see docstring in `array_api_extra._delegation.py`.

    `kth` is accepted for signature compatibility but never read: the whole
    of `x` is sorted along `axis`.
    """
    # Explicitly request a non-stable sort from the backend.
    fully_sorted = xp.sort(x, axis=axis, stable=False)
    return fully_sorted


def argpartition(  # numpydoc ignore=PR01,RT01
    x: Array,
    kth: int,  # noqa: ARG001
    /,
    axis: int = -1,
    *,
    xp: ModuleType,
) -> Array:
    """
    Argsort-based fallback; see docstring in `array_api_extra._delegation.py`.

    `kth` is accepted for signature compatibility but never read: the indices
    of a full sort of `x` along `axis` are returned.
    """
    # Explicitly request a non-stable argsort from the backend.
    sorting_indices = xp.argsort(x, axis=axis, stable=False)
    return sorting_indices
13 changes: 11 additions & 2 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,22 +250,31 @@ def ndindex(*x: int) -> Generator[tuple[int, ...]]:
yield *i, j


def eager_shape(x: Array, /) -> tuple[int, ...]:
def eager_shape(x: Array, /, axis: int | None = None) -> tuple[int, ...]:
"""
Return shape of an array. Raise if shape is not fully defined.

Parameters
----------
x : Array
Input array.
axis : int, optional
If provided, only returns the tuple (shape[axis],).

Returns
-------
tuple[int, ...]
Shape of the array.
"""
shape = x.shape
# Dask arrays uses non-standard NaN instead of None
if axis is not None:
s = shape[axis]
# Dask arrays use non-standard NaN instead of None
if s is None or math.isnan(s):
msg = f"Unsupported lazy shape for axis {axis}"
raise TypeError(msg)
return (s,)

if any(s is None or math.isnan(s) for s in shape):
msg = "Unsupported lazy shape"
raise TypeError(msg)
Expand Down
Loading