diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 06837eee..b459499c 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -87,14 +87,17 @@ def test_argmin(x, data): ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected) +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=())) +def test_nonzero_zerodim_error(x): + with pytest.raises(Exception): + xp.nonzero(x) + + @pytest.mark.data_dependent_shapes -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1, min_side=1))) def test_nonzero(x): out = xp.nonzero(x) - if x.ndim == 0: - assert len(out) == 1, f"{len(out)=}, but should be 1 for 0-dimensional arrays" - else: - assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}" + assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}" out_size = math.prod(out[0].shape) for i in range(len(out)): assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1" diff --git a/numpy-skips.txt b/numpy-skips.txt index 6ade0e7f..0c6f39ae 100644 --- a/numpy-skips.txt +++ b/numpy-skips.txt @@ -21,6 +21,9 @@ array_api_tests/test_array_object.py::test_getitem # missing copy arg array_api_tests/test_signatures.py::test_func_signature[reshape] +# does not (yet) raise an exception for zero-dimensional inputs to nonzero +array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error + # https://github.com/numpy/numpy/issues/21211 array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] # https://github.com/numpy/numpy/issues/21213