diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index afd7c8c5..bfefda80 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -226,13 +226,14 @@ def create_diagonal( if is_torch_namespace(xp): return xp.diag_embed(x, offset=offset, dim1=-2, dim2=-1) - if (is_dask_namespace(xp) or is_cupy_namespace(xp)) and x.ndim < 2: + if ( + is_dask_namespace(xp) + or is_cupy_namespace(xp) + or is_numpy_namespace(xp) + or is_jax_namespace(xp) + ) and (x.ndim < 2): return xp.diag(x, k=offset) - if (is_jax_namespace(xp) or is_numpy_namespace(xp)) and x.ndim < 3: - batch_dim, n = eager_shape(x)[:-1], eager_shape(x, -1)[0] + abs(offset) - return xp.reshape(xp.diag(x, k=offset), (*batch_dim, n, n)) - return _funcs.create_diagonal(x, offset=offset, xp=xp) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index bcb94037..cc8203e6 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -55,8 +55,6 @@ lazy_xp_function(setdiff1d, jax_jit=False) lazy_xp_function(sinc) -NestedFloatList = list[float] | list["NestedFloatList"] - class TestApplyWhere: @staticmethod @@ -711,6 +709,7 @@ def test_0d_raises(self, xp: ModuleType): (0, 1), (1, 0), (0, 0), + (2, 3), (4, 2, 1), (1, 1, 7), (0, 0, 1),