diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 7c05552a..07f7d552 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -11,6 +11,7 @@ one_hot, pad, partition, + setdiff1d, sinc, ) from ._lib._at import at @@ -21,7 +22,6 @@ default_dtype, kron, nunique, - setdiff1d, ) from ._lib._lazy import lazy_apply diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index e9a943c0..e09f1f4a 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -19,6 +19,7 @@ from ._lib._utils._typing import Array, DType __all__ = [ + "atleast_nd", "cov", "expand_dims", "isclose", @@ -29,6 +30,55 @@ ] +def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: + """ + Recursively expand the dimension of an array to at least `ndim`. + + Parameters + ---------- + x : array + Input array. + ndim : int + The minimum number of dimensions for the result. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + array + An array with ``res.ndim`` >= `ndim`. + If ``x.ndim`` >= `ndim`, `x` is returned. + If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes + until ``res.ndim`` equals `ndim`. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> x = xp.asarray([1]) + >>> xpx.atleast_nd(x, ndim=3, xp=xp) + Array([[[1]]], dtype=array_api_strict.int64) + + >>> x = xp.asarray([[[1, 2], + ... [3, 4]]]) + >>> xpx.atleast_nd(x, ndim=1, xp=xp) is x + True + """ + if xp is None: + xp = array_namespace(x) + + if 1 <= ndim <= 3 and ( + is_numpy_namespace(xp) + or is_jax_namespace(xp) + or is_dask_namespace(xp) + or is_cupy_namespace(xp) + or is_torch_namespace(xp) + ): + return getattr(xp, f"atleast_{ndim}d")(x) + + return _funcs.atleast_nd(x, ndim=ndim, xp=xp) + + def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: """ Estimate a covariance matrix. @@ -197,55 +247,6 @@ def expand_dims( return _funcs.expand_dims(a, axis=axis, xp=xp) -def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: - """ - Recursively expand the dimension of an array to at least `ndim`. - - Parameters - ---------- - x : array - Input array. - ndim : int - The minimum number of dimensions for the result. - xp : array_namespace, optional - The standard-compatible namespace for `x`. Default: infer. - - Returns - ------- - array - An array with ``res.ndim`` >= `ndim`. - If ``x.ndim`` >= `ndim`, `x` is returned. - If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes - until ``res.ndim`` equals `ndim`. - - Examples - -------- - >>> import array_api_strict as xp - >>> import array_api_extra as xpx - >>> x = xp.asarray([1]) - >>> xpx.atleast_nd(x, ndim=3, xp=xp) - Array([[[1]]], dtype=array_api_strict.int64) - - >>> x = xp.asarray([[[1, 2], - ... [3, 4]]]) - >>> xpx.atleast_nd(x, ndim=1, xp=xp) is x - True - """ - if xp is None: - xp = array_namespace(x) - - if 1 <= ndim <= 3 and ( - is_numpy_namespace(xp) - or is_jax_namespace(xp) - or is_dask_namespace(xp) - or is_cupy_namespace(xp) - or is_torch_namespace(xp) - ): - return getattr(xp, f"atleast_{ndim}d")(x) - - return _funcs.atleast_nd(x, ndim=ndim, xp=xp) - - def isclose( a: Array | complex, b: Array | complex, @@ -553,6 +554,59 @@ def pad( return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) +def setdiff1d( + x1: Array | complex, + x2: Array | complex, + /, + *, + assume_unique: bool = False, + xp: ModuleType | None = None, +) -> Array: + """ + Find the set difference of two arrays. + + Return the unique values in `x1` that are not in `x2`. + + Parameters + ---------- + x1 : array | int | float | complex | bool + Input array. + x2 : array + Input comparison array. + assume_unique : bool + If ``True``, the input arrays are both assumed to be unique, which + can speed up the calculation. Default is ``False``. + xp : array_namespace, optional + The standard-compatible namespace for `x1` and `x2`. Default: infer. + + Returns + ------- + array + 1D array of values in `x1` that are not in `x2`. The result + is sorted when `assume_unique` is ``False``, but otherwise only sorted + if the input is sorted. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + + >>> x1 = xp.asarray([1, 2, 3, 2, 4, 1]) + >>> x2 = xp.asarray([3, 4, 5, 6]) + >>> xpx.setdiff1d(x1, x2, xp=xp) + Array([1, 2], dtype=array_api_strict.int64) + """ + + if xp is None: + xp = array_namespace(x1, x2) + + if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp): + x1, x2 = asarrays(x1, x2, xp=xp) + return xp.setdiff1d(x1, x2, assume_unique=assume_unique) + + return _funcs.setdiff1d(x1, x2, assume_unique=assume_unique, xp=xp) + + def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: r""" Return the normalized sinc function. diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index aed38f8b..fb124a6e 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -715,44 +715,10 @@ def setdiff1d( /, *, assume_unique: bool = False, - xp: ModuleType | None = None, -) -> Array: - """ - Find the set difference of two arrays. - - Return the unique values in `x1` that are not in `x2`. - - Parameters - ---------- - x1 : array | int | float | complex | bool - Input array. - x2 : array - Input comparison array. - assume_unique : bool - If ``True``, the input arrays are both assumed to be unique, which - can speed up the calculation. Default is ``False``. - xp : array_namespace, optional - The standard-compatible namespace for `x1` and `x2`. Default: infer. - - Returns - ------- - array - 1D array of values in `x1` that are not in `x2`. The result - is sorted when `assume_unique` is ``False``, but otherwise only sorted - if the input is sorted. - - Examples - -------- - >>> import array_api_strict as xp - >>> import array_api_extra as xpx + xp: ModuleType, +) -> Array: # numpydoc ignore=PR01,RT01 + """See docstring in `array_api_extra._delegation.py`.""" - >>> x1 = xp.asarray([1, 2, 3, 2, 4, 1]) - >>> x2 = xp.asarray([3, 4, 5, 6]) - >>> xpx.setdiff1d(x1, x2, xp=xp) - Array([1, 2], dtype=array_api_strict.int64) - """ - if xp is None: - xp = array_namespace(x1, x2) # https://github.com/microsoft/pyright/issues/10103 x1_, x2_ = asarrays(x1, x2, xp=xp) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 92e794ed..a120e559 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -33,10 +33,9 @@ sinc, ) from array_api_extra._lib._backends import NUMPY_VERSION, Backend -from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal -from array_api_extra._lib._utils._compat import ( - device as get_device, -) +from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal +from array_api_extra._lib._utils._compat import device as get_device +from array_api_extra._lib._utils._compat import is_jax_namespace from array_api_extra._lib._utils._helpers import eager_shape, ndindex from array_api_extra._lib._utils._typing import Array, Device from array_api_extra.testing import lazy_xp_function @@ -1264,6 +1263,7 @@ def test_assume_unique(self, xp: ModuleType): @pytest.mark.parametrize("shape2", [(), (1,), (1, 1)]) def test_shapes( self, + request: pytest.FixtureRequest, assume_unique: bool, shape1: tuple[int, ...], shape2: tuple[int, ...], @@ -1271,18 +1271,27 @@ def test_shapes( ): x1 = xp.zeros(shape1) x2 = xp.zeros(shape2) + + if is_jax_namespace(xp) and assume_unique and shape1 != (1,): + xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0") + actual = setdiff1d(x1, x2, assume_unique=assume_unique) xp_assert_equal(actual, xp.empty((0,))) @assume_unique @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp") - def test_python_scalar(self, xp: ModuleType, assume_unique: bool): + def test_python_scalar( + self, request: pytest.FixtureRequest, xp: ModuleType, assume_unique: bool + ): # Test no dtype promotion to xp.asarray(x2); use x1.dtype x1 = xp.asarray([3, 1, 2], dtype=xp.int16) x2 = 3 actual = setdiff1d(x1, x2, assume_unique=assume_unique) xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16)) + if is_jax_namespace(xp) and assume_unique: + xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0") + actual = setdiff1d(x2, x1, assume_unique=assume_unique) xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))