Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-assorted.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@
setdiff1d
sinc
union1d
unravel_index
```
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
setdiff1d,
sinc,
union1d,
unravel_index,
)
from ._lib._at import at
from ._lib._funcs import (
Expand Down Expand Up @@ -58,4 +59,5 @@
"sinc",
"testing",
"union1d",
"unravel_index",
]
54 changes: 54 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"pad",
"searchsorted",
"sinc",
"unravel_index",
]


Expand Down Expand Up @@ -1307,3 +1308,56 @@ def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
return xp.union1d(a, b)

return _funcs.union1d(a, b, xp=xp)


def unravel_index(
ind: Array,
shape: tuple[int, ...],
/,
*,
xp: ModuleType | None = None,
) -> tuple[Array, ...]:
"""
Convert a flat index or array of flat indices into a tuple of coordinate arrays.

Parameters
----------
ind : array
An integer array whose elements are indices into the flattened version
of an array of dimensions `shape`.

shape : tuple of ints
The shape to use for unraveling `indices`.

xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
tuple of array
A tuple of unraveled indices. Each array in the tuple has the same shape
as the `indices` array.

Examples
--------
>>> import array_api_extra as xpx
>>> import array_api_strict as xp
>>> xpx.unravel_index(xp.asarray([1, 2, 3, 4, 5]), (4, 3))
(
Array([0, 0, 1, 1, 1], dtype=array_api_strict.int64),
Array([1, 2, 0, 1, 2], dtype=array_api_strict.int64),
)
"""
if xp is None:
xp = array_namespace(ind)

if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_dask_namespace(xp)
or is_jax_namespace(xp)
or is_torch_namespace(xp)
):
return xp.unravel_index(ind, shape)

return _funcs.unravel_index(ind, shape)
10 changes: 10 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,3 +757,13 @@ def angle(z: Array, /, *, deg: bool = False, xp: ModuleType | None = None) -> Ar
if deg:
a = a * 180 / xp.pi
return a


def unravel_index(ind: Array, shape: tuple[int, ...], /) -> tuple[Array, ...]:
# numpydoc ignore=PR01,RT01
"""See docstring in `array_api_extra._delegation.py`."""
coords: list[Array] = []
for dim in reversed(shape):
coords.append(ind % dim)
ind = ind // dim
return tuple(reversed(coords))
49 changes: 49 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
setdiff1d,
sinc,
union1d,
unravel_index,
)
from array_api_extra import (
searchsorted as xpx_searchsorted,
Expand Down Expand Up @@ -1981,3 +1982,51 @@ def test_2d(self, xp: ModuleType):
def test_device(self, xp: ModuleType, device: Device):
a = xp.asarray([1 + 1j], device=device)
assert get_device(angle(a)) == device


class TestUnravelIndex:
def test_simple(self, xp: ModuleType):
ind = xp.asarray([22, 41, 37])
shape = (7, 6)
expected = (xp.asarray([3, 6, 6]), xp.asarray([4, 5, 1]))
res = unravel_index(ind, shape)
for res_arr, exp_arr in zip(res, expected, strict=True):
assert_equal(res_arr, exp_arr)

ind = xp.asarray([0, 1, 2, 3, 4, 5])
shape = (3, 2)
expected = (
xp.asarray([0, 0, 1, 1, 2, 2]),
xp.asarray([0, 1, 0, 1, 0, 1]),
)
res = unravel_index(ind, shape)
for res_arr, exp_arr in zip(res, expected, strict=True):
assert_equal(res_arr, exp_arr)

def test_indices_scalar(self, xp: ModuleType):
ind = xp.asarray(1621)
shape = (6, 7, 8, 9)
expected = (xp.asarray(3), xp.asarray(1), xp.asarray(4), xp.asarray(1))
res = unravel_index(ind, shape)
# a tuple of integers is expected
assert res == expected

def test_indices_2d(self, xp: ModuleType):
ind = xp.asarray([[1234], [5678]])
shape = (10, 10, 10, 10)
expected = (
xp.asarray([[1], [5]]),
xp.asarray([[2], [6]]),
xp.asarray([[3], [7]]),
xp.asarray([[4], [8]]),
)
res = unravel_index(ind, shape)
for res_arr, exp_arr in zip(res, expected, strict=True):
assert_equal(res_arr, exp_arr)

def test_device(self, xp: ModuleType, device: Device):
ind = xp.asarray([4, 1], device=device)
shape = (3, 2)
res = unravel_index(ind, shape)
for res_arr in res:
assert get_device(res_arr) == device