Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,9 @@
)
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
from array_api_extra._lib._funcs import searchsorted as _funcs_searchsorted
from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
from array_api_extra._lib._utils._compat import (
array_namespace,
is_jax_namespace,
is_torch_namespace,
)
from array_api_extra._lib._utils._compat import device as get_device
Expand Down Expand Up @@ -558,8 +557,6 @@ def test_complex(self, xp: ModuleType):
expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128)
xp_assert_close(actual, expect)

@pytest.mark.xfail_xp_backend(Backend.JAX_GPU, reason="jax#32296")
@pytest.mark.xfail_xp_backend(Backend.JAX, reason="jax#32296")
def test_empty(self, xp: ModuleType):
with warnings.catch_warnings(record=True):
warnings.simplefilter("always", RuntimeWarning)
Expand Down Expand Up @@ -1399,7 +1396,6 @@ def test_assume_unique(self, xp: ModuleType):
@pytest.mark.parametrize("shape2", [(), (1,), (1, 1)])
def test_shapes(
self,
request: pytest.FixtureRequest,
assume_unique: bool,
shape1: tuple[int, ...],
shape2: tuple[int, ...],
Expand All @@ -1408,26 +1404,18 @@ def test_shapes(
x1 = xp.zeros(shape1)
x2 = xp.zeros(shape2)

if is_jax_namespace(xp) and assume_unique and shape1 != (1,):
xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0")

actual = setdiff1d(x1, x2, assume_unique=assume_unique)
xp_assert_equal(actual, xp.empty((0,)))

@assume_unique
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
def test_python_scalar(
self, request: pytest.FixtureRequest, xp: ModuleType, assume_unique: bool
):
def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
# Test no dtype promotion to xp.asarray(x2); use x1.dtype
x1 = xp.asarray([3, 1, 2], dtype=xp.int16)
x2 = 3
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))

if is_jax_namespace(xp) and assume_unique:
xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0")

actual = setdiff1d(x2, x1, assume_unique=assume_unique)
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))

Expand Down
Loading