diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 5e28e85b..6a11e059 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -39,10 +39,9 @@ ) from array_api_extra._lib._backends import NUMPY_VERSION, Backend from array_api_extra._lib._funcs import searchsorted as _funcs_searchsorted -from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal +from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal from array_api_extra._lib._utils._compat import ( array_namespace, - is_jax_namespace, is_torch_namespace, ) from array_api_extra._lib._utils._compat import device as get_device @@ -558,8 +557,6 @@ def test_complex(self, xp: ModuleType): expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128) xp_assert_close(actual, expect) - @pytest.mark.xfail_xp_backend(Backend.JAX_GPU, reason="jax#32296") - @pytest.mark.xfail_xp_backend(Backend.JAX, reason="jax#32296") def test_empty(self, xp: ModuleType): with warnings.catch_warnings(record=True): warnings.simplefilter("always", RuntimeWarning) @@ -1399,7 +1396,6 @@ def test_assume_unique(self, xp: ModuleType): @pytest.mark.parametrize("shape2", [(), (1,), (1, 1)]) def test_shapes( self, - request: pytest.FixtureRequest, assume_unique: bool, shape1: tuple[int, ...], shape2: tuple[int, ...], @@ -1408,26 +1404,18 @@ def test_shapes( x1 = xp.zeros(shape1) x2 = xp.zeros(shape2) - if is_jax_namespace(xp) and assume_unique and shape1 != (1,): - xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0") - actual = setdiff1d(x1, x2, assume_unique=assume_unique) xp_assert_equal(actual, xp.empty((0,))) @assume_unique @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp") - def test_python_scalar( - self, request: pytest.FixtureRequest, xp: ModuleType, assume_unique: bool - ): + def test_python_scalar(self, xp: ModuleType, assume_unique: bool): # Test no dtype promotion to xp.asarray(x2); use x1.dtype x1 = xp.asarray([3, 1, 2], dtype=xp.int16) x2 = 3 actual = setdiff1d(x1, x2, assume_unique=assume_unique) xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16)) - if is_jax_namespace(xp) and assume_unique: - xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0") - actual = setdiff1d(x2, x1, assume_unique=assume_unique) xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))