diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 87ef6986..7c05552a 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -6,6 +6,7 @@ cov, expand_dims, isclose, + isin, nan_to_num, one_hot, pad, @@ -39,6 +40,7 @@ "default_dtype", "expand_dims", "isclose", + "isin", "kron", "lazy_apply", "nan_to_num", diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 7f467366..e9a943c0 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -836,3 +836,62 @@ def argpartition( # kth is not small compared to x.size return _funcs.argpartition(a, kth, axis=axis, xp=xp) + + +def isin( + a: Array, + b: Array, + /, + *, + assume_unique: bool = False, + invert: bool = False, + kind: str | None = None, + xp: ModuleType | None = None, +) -> Array: + """ + Determine whether each element in `a` is present in `b`. + + Return a boolean array of the same shape as `a` that is True for elements + that are in `b` and False otherwise. + + Parameters + ---------- + a : array + Input elements. + b : array + The elements against which to test each element of `a`. + assume_unique : bool, optional + If True, the input arrays are both assumed to be unique which can speed + up the calculation. Default: False. + invert : bool, optional + If True, the values in the returned array are inverted. Default: False. + kind : str | None, optional + The algorithm or method to use. This will not affect the final result, + but will affect the speed and memory use. + For NumPy the options are {None, "sort", "table"}. + For Jax the mapped parameter is instead `method` and the options are + {"compare_all", "binary_search", "sort", and "auto" (default)} + For CuPy, Dask, Torch and the default case this parameter is not present and + thus ignored. Default: None. + xp : array_namespace, optional + The standard-compatible namespace for `a` and `b`. Default: infer. + + Returns + ------- + array + An array having the same shape as that of `a` that is True for elements + that are in `b` and False otherwise. + """ + if xp is None: + xp = array_namespace(a, b) + + if is_numpy_namespace(xp): + return xp.isin(a, b, assume_unique=assume_unique, invert=invert, kind=kind) + if is_jax_namespace(xp): + if kind is None: + kind = "auto" + return xp.isin(a, b, assume_unique=assume_unique, invert=invert, method=kind) + if is_cupy_namespace(xp) or is_torch_namespace(xp) or is_dask_namespace(xp): + return xp.isin(a, b, assume_unique=assume_unique, invert=invert) + + return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index fe52305f..aed38f8b 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -801,3 +801,22 @@ def argpartition( # numpydoc ignore=PR01,RT01 ) -> Array: """See docstring in `array_api_extra._delegation.py`.""" return xp.argsort(x, axis=axis, stable=False) + + +def isin( # numpydoc ignore=PR01,RT01 + a: Array, + b: Array, + /, + *, + assume_unique: bool = False, + invert: bool = False, + xp: ModuleType, +) -> Array: + """See docstring in `array_api_extra._delegation.py`.""" + original_a_shape = a.shape + a = xp.reshape(a, (-1,)) + b = xp.reshape(b, (-1,)) + return xp.reshape( + _helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp), + original_a_shape, + ) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index a80dd3e0..92e794ed 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -22,6 +22,7 @@ default_dtype, expand_dims, isclose, + isin, kron, nan_to_num, nunique, @@ -888,7 +889,7 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool): b = xp.asarray([1e-9, 1e-4, xp.nan], device=device) res = isclose(a, b, equal_nan=equal_nan) assert get_device(res) == device - + def test_array_on_device_with_scalar(self, xp: ModuleType, device: Device): a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device) b = 1 @@ -1476,3 +1477,55 @@ def test_nd(self, xp: ModuleType, ndim: int): @override def test_input_validation(self, xp: ModuleType): self._test_input_validation(xp) + + +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse") +class TestIsIn: + def test_simple(self, xp: ModuleType, library: Backend): + if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24): + pytest.xfail("NumPy <1.24 has no kind kwarg in isin") + + b = xp.asarray([1, 2, 3, 4]) + + # `a` with 1 dimension + a = xp.asarray([1, 3, 6, 10]) + expected = xp.asarray([True, True, False, False]) + res = isin(a, b) + xp_assert_equal(res, expected) + + # `a` with 2 dimensions + a = xp.asarray([[0, 2], [4, 6]]) + expected = xp.asarray([[False, True], [True, False]]) + res = isin(a, b) + xp_assert_equal(res, expected) + + def test_device(self, xp: ModuleType, device: Device, library: Backend): + if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24): + pytest.xfail("NumPy <1.24 has no kind kwarg in isin") + + a = xp.asarray([1, 3, 6], device=device) + b = xp.asarray([1, 2, 3], device=device) + assert get_device(isin(a, b)) == device + + def test_assume_unique_and_invert( + self, xp: ModuleType, device: Device, library: Backend + ): + if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24): + pytest.xfail("NumPy <1.24 has no kind kwarg in isin") + + a = xp.asarray([0, 3, 6, 10], device=device) + b = xp.asarray([1, 2, 3, 10], device=device) + expected = xp.asarray([True, False, True, False]) + res = isin(a, b, assume_unique=True, invert=True) + assert get_device(res) == device + xp_assert_equal(res, expected) + + def test_kind(self, xp: ModuleType, library: Backend): + if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24): + pytest.xfail("NumPy <1.24 has no kind kwarg in isin") + + a = xp.asarray([0, 3, 6, 10]) + b = xp.asarray([1, 2, 3, 10]) + expected = xp.asarray([False, True, False, True]) + res = isin(a, b, kind="sort") + xp_assert_equal(res, expected)