From 68fd52116c0c2f18165d5422a7e02ea5575d393d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Gauthier-Clerc?= Date: Sun, 7 Dec 2025 21:39:18 +0100 Subject: [PATCH 1/2] fix bug with delegate create_diagonals. --- src/array_api_extra/_delegation.py | 11 ++++++----- tests/test_funcs.py | 1 + 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index afd7c8c5..bfefda80 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -226,13 +226,14 @@ def create_diagonal( if is_torch_namespace(xp): return xp.diag_embed(x, offset=offset, dim1=-2, dim2=-1) - if (is_dask_namespace(xp) or is_cupy_namespace(xp)) and x.ndim < 2: + if ( + is_dask_namespace(xp) + or is_cupy_namespace(xp) + or is_numpy_namespace(xp) + or is_jax_namespace(xp) + ) and (x.ndim < 2): return xp.diag(x, k=offset) - if (is_jax_namespace(xp) or is_numpy_namespace(xp)) and x.ndim < 3: - batch_dim, n = eager_shape(x)[:-1], eager_shape(x, -1)[0] + abs(offset) - return xp.reshape(xp.diag(x, k=offset), (*batch_dim, n, n)) - return _funcs.create_diagonal(x, offset=offset, xp=xp) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index bcb94037..2d089f1b 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -711,6 +711,7 @@ def test_0d_raises(self, xp: ModuleType): (0, 1), (1, 0), (0, 0), + (2, 3), (4, 2, 1), (1, 1, 7), (0, 0, 1), From 2880e20104df42c36f965067540971c88865ba05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Gauthier-Clerc?= Date: Sun, 7 Dec 2025 21:40:39 +0100 Subject: [PATCH 2/2] remove unused type. --- tests/test_funcs.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 2d089f1b..cc8203e6 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -55,8 +55,6 @@ lazy_xp_function(setdiff1d, jax_jit=False) lazy_xp_function(sinc) -NestedFloatList = list[float] | list["NestedFloatList"] - class TestApplyWhere: @staticmethod