Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
atleast_nd,
cov,
expand_dims,
fill_diagonal,
isclose,
isin,
nan_to_num,
Expand Down Expand Up @@ -39,6 +40,7 @@
"create_diagonal",
"default_dtype",
"expand_dims",
"fill_diagonal",
"isclose",
"isin",
"kron",
Expand Down
53 changes: 53 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,3 +895,56 @@ def isin(
return xp.isin(a, b, assume_unique=assume_unique, invert=invert)

return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp)


def fill_diagonal(
a: Array,
val: Array | int | float,
*,
wrap: bool = False,
xp: ModuleType | None = None,
) -> None | Array:
"""
Fill the main diagonal of the given array `a` of any dimensionality >= 2.

For an array `a` with ``a.ndim >= 2``, the diagonal is the list of
values ``a[i, ..., i]`` with indices ``i`` all identical. This function
modifies the input array in-place without returning a value. However
specifically for JAX, a copy of the array `a` with the diagonal elements
overwritten is returned. This is because it is not possible to modify JAX's
immutable arrays in-place.

Parameters
----------
a : Array
Input array whose diagonal is to be filled. It should be at least 2-D.

val : Array | int | float
Value(s) to write on the diagonal. If `val` is a scalar, the value is
written along the diagonal. If `val` is an Array, the flattened `val`
is written along the diagonal.

wrap : bool, optional
Only applicable for NumPy and Cupy. For tall matrices the
diagonal is "wrapped" after N columns. Default: False.

xp : array_namespace, optional
The standard-compatible namespace for `a` and `val`. Default: infer.

Returns
-------
Array | None
For JAX a copy of the original array `a` is returned. For all other
cases the array `a` is modified in-place so None is returned.
"""
if xp is None:
xp = array_namespace(a, val)

if is_jax_namespace(xp):
return xp.fill_diagonal(a, val, inplace=False)

if is_numpy_namespace(xp) or is_cupy_namespace(xp):
xp.fill_diagonal(a, val, wrap=wrap)

_funcs.fill_diagonal(a, val, xp=xp)
return None
40 changes: 39 additions & 1 deletion src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@

from ._at import at
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
from ._utils._compat import (
array_namespace,
is_dask_namespace,
is_jax_array,
is_torch_namespace,
)
from ._utils._helpers import (
asarrays,
capabilities,
Expand Down Expand Up @@ -820,3 +825,36 @@ def isin( # numpydoc ignore=PR01,RT01
_helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp),
original_a_shape,
)


def fill_diagonal( # numpydoc ignore=PR01,RT01
a: Array,
val: Array | int | float,
*,
xp: ModuleType,
) -> None:
"""See docstring in `array_api_extra._delegation.py`."""
if a.ndim < 2:
msg = f"array `a` must be at least 2-d. Got array with shape {tuple(a.shape)}"
raise ValueError(msg)

a, val = asarrays(a, val, xp=xp)
min_rows_columns = min(x or 0 for x in a.shape)
if is_torch_namespace(xp):
val_size = math.prod(x or 0 for x in val.shape)
else:
val_size = val.size or 0
if val.ndim > 0 and val_size != min_rows_columns:
msg = (
"`val` needs to be a scalar or an array of the same size as the "
f"diagonal of `a` ({min_rows_columns}). Got {val.shape[0]}"
)
raise ValueError(msg)

if val.ndim == 0:
for i in range(min_rows_columns):
a[i, i] = val
else:
val = cast(Array, xp.reshape(val, (-1,)))
for i in range(min_rows_columns):
a[i, i] = val[i]
58 changes: 58 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
create_diagonal,
default_dtype,
expand_dims,
fill_diagonal,
isclose,
isin,
kron,
Expand All @@ -37,6 +38,9 @@
from array_api_extra._lib._utils._compat import (
device as get_device,
)
from array_api_extra._lib._utils._compat import (
is_jax_namespace,
)
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
from array_api_extra._lib._utils._typing import Array, Device
from array_api_extra.testing import lazy_xp_function
Expand Down Expand Up @@ -1529,3 +1533,57 @@ def test_kind(self, xp: ModuleType, library: Backend):
expected = xp.asarray([False, True, False, True])
res = isin(a, b, kind="sort")
xp_assert_equal(res, expected)


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="item assignment not supported")
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="numpy read only arrays")
class TestFillDiagonal:
def test_simple(self, xp: ModuleType):
a = xp.zeros((3, 3), dtype=xp.int64)
val = 5
expected = xp.asarray([[5, 0, 0], [0, 5, 0], [0, 0, 5]], dtype=xp.int64)
if is_jax_namespace(xp):
a = cast(Array, fill_diagonal(a, val))
else:
_ = fill_diagonal(a, val)
xp_assert_equal(a, expected)

def test_val_1d(self, xp: ModuleType):
a = xp.zeros((3, 3), dtype=xp.int64)
val = xp.asarray([1, 2, 3], dtype=xp.int64)
expected = xp.asarray([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=xp.int64)
if is_jax_namespace(xp):
a = cast(Array, fill_diagonal(a, val))
else:
_ = fill_diagonal(a, val)
xp_assert_equal(a, expected)

@pytest.mark.parametrize(
("a_shape", "expected"),
[
((4, 4), [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]),
(
(5, 4),
[[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4], [0, 0, 0, 0]],
),
],
)
def test_val_2d(self, xp: ModuleType, a_shape: tuple[int, int], expected: Array):
a = xp.zeros(a_shape, dtype=xp.int64)
val = xp.asarray([[1, 2], [3, 4]], dtype=xp.int64)
expected = xp.asarray(expected, dtype=xp.int64)
if is_jax_namespace(xp):
a = cast(Array, fill_diagonal(a, val))
else:
_ = fill_diagonal(a, val)
xp_assert_equal(a, expected)

@pytest.mark.parametrize("val_scalar", [True, False])
def test_device(self, xp: ModuleType, device: Device, val_scalar: bool):
a = xp.zeros((3, 3), dtype=xp.int64, device=device)
val = 5 if val_scalar else xp.asarray([1, 2, 3], dtype=xp.int64, device=device)
if is_jax_namespace(xp):
a = cast(Array, fill_diagonal(a, val))
else:
_ = fill_diagonal(a, val)
assert get_device(a) == device