From 5d17a66c8332ef53c6b42fa833da7fd083fe4f33 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Mon, 27 Oct 2025 18:41:21 +0500 Subject: [PATCH] ENH: new function fill_diagonal --- src/array_api_extra/__init__.py | 2 ++ src/array_api_extra/_delegation.py | 53 +++++++++++++++++++++++++++ src/array_api_extra/_lib/_funcs.py | 40 ++++++++++++++++++++- tests/test_funcs.py | 58 ++++++++++++++++++++++++++++++ 4 files changed, 152 insertions(+), 1 deletion(-) diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 7c05552a..8df5793e 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -5,6 +5,7 @@ atleast_nd, cov, expand_dims, + fill_diagonal, isclose, isin, nan_to_num, @@ -39,6 +40,7 @@ "create_diagonal", "default_dtype", "expand_dims", + "fill_diagonal", "isclose", "isin", "kron", diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index e9a943c0..dec4d5cc 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -895,3 +895,56 @@ def isin( return xp.isin(a, b, assume_unique=assume_unique, invert=invert) return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp) + + +def fill_diagonal( + a: Array, + val: Array | int | float, + *, + wrap: bool = False, + xp: ModuleType | None = None, +) -> None | Array: + """ + Fill the main diagonal of the given array `a` of any dimensionality >= 2. + + For an array `a` with ``a.ndim >= 2``, the diagonal is the list of + values ``a[i, ..., i]`` with indices ``i`` all identical. This function + modifies the input array in-place without returning a value. However + specifically for JAX, a copy of the array `a` with the diagonal elements + overwritten is returned. This is because it is not possible to modify JAX's + immutable arrays in-place. + + Parameters + ---------- + a : Array + Input array whose diagonal is to be filled. It should be at least 2-D. + + val : Array | int | float + Value(s) to write on the diagonal. If `val` is a scalar, the value is + written along the diagonal. If `val` is an Array, the flattened `val` + is written along the diagonal. + + wrap : bool, optional + Only applicable for NumPy and Cupy. For tall matrices the + diagonal is "wrapped" after N columns. Default: False. + + xp : array_namespace, optional + The standard-compatible namespace for `a` and `val`. Default: infer. + + Returns + ------- + Array | None + For JAX a copy of the original array `a` is returned. For all other + cases the array `a` is modified in-place so None is returned. + """ + if xp is None: + xp = array_namespace(a, val) + + if is_jax_namespace(xp): + return xp.fill_diagonal(a, val, inplace=False) + + if is_numpy_namespace(xp) or is_cupy_namespace(xp): + xp.fill_diagonal(a, val, wrap=wrap) + + _funcs.fill_diagonal(a, val, xp=xp) + return None diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index aed38f8b..ad6a6724 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -8,7 +8,12 @@ from ._at import at from ._utils import _compat, _helpers -from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array +from ._utils._compat import ( + array_namespace, + is_dask_namespace, + is_jax_array, + is_torch_namespace, +) from ._utils._helpers import ( asarrays, capabilities, @@ -820,3 +825,36 @@ def isin( # numpydoc ignore=PR01,RT01 _helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp), original_a_shape, ) + + +def fill_diagonal( # numpydoc ignore=PR01,RT01 + a: Array, + val: Array | int | float, + *, + xp: ModuleType, +) -> None: + """See docstring in `array_api_extra._delegation.py`.""" + if a.ndim < 2: + msg = f"array `a` must be at least 2-d. Got array with shape {tuple(a.shape)}" + raise ValueError(msg) + + a, val = asarrays(a, val, xp=xp) + min_rows_columns = min(x or 0 for x in a.shape) + if is_torch_namespace(xp): + val_size = math.prod(x or 0 for x in val.shape) + else: + val_size = val.size or 0 + if val.ndim > 0 and val_size != min_rows_columns: + msg = ( + "`val` needs to be a scalar or an array of the same size as the " + f"diagonal of `a` ({min_rows_columns}). Got {val.shape[0]}" + ) + raise ValueError(msg) + + if val.ndim == 0: + for i in range(min_rows_columns): + a[i, i] = val + else: + val = cast(Array, xp.reshape(val, (-1,))) + for i in range(min_rows_columns): + a[i, i] = val[i] diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 92e794ed..309d9b58 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -21,6 +21,7 @@ create_diagonal, default_dtype, expand_dims, + fill_diagonal, isclose, isin, kron, @@ -37,6 +38,9 @@ from array_api_extra._lib._utils._compat import ( device as get_device, ) +from array_api_extra._lib._utils._compat import ( + is_jax_namespace, +) from array_api_extra._lib._utils._helpers import eager_shape, ndindex from array_api_extra._lib._utils._typing import Array, Device from array_api_extra.testing import lazy_xp_function @@ -1529,3 +1533,57 @@ def test_kind(self, xp: ModuleType, library: Backend): expected = xp.asarray([False, True, False, True]) res = isin(a, b, kind="sort") xp_assert_equal(res, expected) + + +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="item assignment not supported") +@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="numpy read only arrays") +class TestFillDiagonal: + def test_simple(self, xp: ModuleType): + a = xp.zeros((3, 3), dtype=xp.int64) + val = 5 + expected = xp.asarray([[5, 0, 0], [0, 5, 0], [0, 0, 5]], dtype=xp.int64) + if is_jax_namespace(xp): + a = cast(Array, fill_diagonal(a, val)) + else: + _ = fill_diagonal(a, val) + xp_assert_equal(a, expected) + + def test_val_1d(self, xp: ModuleType): + a = xp.zeros((3, 3), dtype=xp.int64) + val = xp.asarray([1, 2, 3], dtype=xp.int64) + expected = xp.asarray([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=xp.int64) + if is_jax_namespace(xp): + a = cast(Array, fill_diagonal(a, val)) + else: + _ = fill_diagonal(a, val) + xp_assert_equal(a, expected) + + @pytest.mark.parametrize( + ("a_shape", "expected"), + [ + ((4, 4), [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]), + ( + (5, 4), + [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4], [0, 0, 0, 0]], + ), + ], + ) + def test_val_2d(self, xp: ModuleType, a_shape: tuple[int, int], expected: Array): + a = xp.zeros(a_shape, dtype=xp.int64) + val = xp.asarray([[1, 2], [3, 4]], dtype=xp.int64) + expected = xp.asarray(expected, dtype=xp.int64) + if is_jax_namespace(xp): + a = cast(Array, fill_diagonal(a, val)) + else: + _ = fill_diagonal(a, val) + xp_assert_equal(a, expected) + + @pytest.mark.parametrize("val_scalar", [True, False]) + def test_device(self, xp: ModuleType, device: Device, val_scalar: bool): + a = xp.zeros((3, 3), dtype=xp.int64, device=device) + val = 5 if val_scalar else xp.asarray([1, 2, 3], dtype=xp.int64, device=device) + if is_jax_namespace(xp): + a = cast(Array, fill_diagonal(a, val)) + else: + _ = fill_diagonal(a, val) + assert get_device(a) == device