From e5c5ac27561cbed3d0a188fc1ea2802b127415d1 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Wed, 15 Oct 2025 12:59:49 +0500 Subject: [PATCH 1/6] ENH: Add support for "isin" --- src/array_api_extra/__init__.py | 2 + src/array_api_extra/_delegation.py | 59 ++++++++++++++++++++++++++++++ src/array_api_extra/_lib/_funcs.py | 22 +++++++++++ tests/test_funcs.py | 13 ++++++- 4 files changed, 95 insertions(+), 1 deletion(-) diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 87ef6986..7c05552a 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -6,6 +6,7 @@ cov, expand_dims, isclose, + isin, nan_to_num, one_hot, pad, @@ -39,6 +40,7 @@ "default_dtype", "expand_dims", "isclose", + "isin", "kron", "lazy_apply", "nan_to_num", diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 7f467366..e9168b12 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -836,3 +836,62 @@ def argpartition( # kth is not small compared to x.size return _funcs.argpartition(a, kth, axis=axis, xp=xp) + + +def isin( + a: Array, + b: Array, + /, + *, + assume_unique: bool = False, + invert: bool = False, + kind: str | None = None, + xp: ModuleType | None = None, +) -> Array: + """ + Determine whether each element in `a` is present in `b`. + + Return a boolean array of the same shape as `a` that is True for elements + that are in `b` and False otherwise. + + Parameters + ---------- + a : array_like + Input elements. + b : array_like + The elements against which to test each element of `a`. + assume_unique : bool, optional + If True, the input arrays are both assumed to be unique which can speed + up the calculation. Default: False. + invert : bool, optional + If True, the values in the returned array are inverted. Default: False. + kind : str | None, optional + The algorithm or method to use. This will not affect the final result, + but will affect the speed and memory use. + For Numpy the options are {None, "sort", "table"}. + For Jax the mapped parameter is instead `method` and the options are + {"compare_all", "binary_search", "sort", and "auto" (default)} + For Cupy, Dask, Torch and the default case this parameter is not present and + thus ignored. Default: None. + xp : array_namespace, optional + The standard-compatible namespace for `a` and `b`. Default: infer. + + Returns + ------- + array + An array having the same shape as that of `a` that is True for elements + that are in `b` and False otherwise. + """ + if xp is None: + xp = array_namespace(a, b) + + if is_numpy_namespace(xp): + return xp.isin(a, b, assume_unique=assume_unique, invert=invert, kind=kind) + if is_jax_namespace(xp): + if kind is None: + kind = "auto" + return xp.isin(a, b, assume_unique=assume_unique, invert=invert, method=kind) + if is_cupy_namespace(xp) or is_torch_namespace(xp) or is_dask_namespace(xp): + return xp.isin(a, b, assume_unique=assume_unique, invert=invert) + + return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index fe52305f..11720111 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -801,3 +801,25 @@ def argpartition( # numpydoc ignore=PR01,RT01 ) -> Array: """See docstring in `array_api_extra._delegation.py`.""" return xp.argsort(x, axis=axis, stable=False) + + +def isin( # numpydoc ignore=PR01,RT01 + a: Array, + b: Array, + /, + *, + assume_unique: bool = False, + invert: bool = False, + xp: ModuleType | None = None, +) -> Array: + """See docstring in `array_api_extra._delegation.py`.""" + if xp is None: + xp = array_namespace(a, b) + + original_a_shape = a.shape + a = xp.reshape(a, (-1,)) + b = xp.reshape(b, (-1,)) + return xp.reshape( + _helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp), + original_a_shape, + ) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index a80dd3e0..6ac508fa 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -22,6 +22,7 @@ default_dtype, expand_dims, isclose, + isin, kron, nan_to_num, nunique, @@ -888,7 +889,7 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool): b = xp.asarray([1e-9, 1e-4, xp.nan], device=device) res = isclose(a, b, equal_nan=equal_nan) assert get_device(res) == device - + def test_array_on_device_with_scalar(self, xp: ModuleType, device: Device): a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device) b = 1 @@ -1476,3 +1477,13 @@ def test_nd(self, xp: ModuleType, ndim: int): @override def test_input_validation(self, xp: ModuleType): self._test_input_validation(xp) + + +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse") +class TestIsIn: + def test_simple(self, xp: ModuleType): + a = xp.asarray([[0, 2], [4, 6]]) + b = xp.asarray([1, 2, 3, 4]) + expected = xp.asarray([[False, True], [True, False]]) + res = isin(a, b) + xp_assert_equal(res, expected) From f95f6bd3dc39362c3bfa9ba582b7eba368e392b7 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 16 Oct 2025 11:43:49 +0500 Subject: [PATCH 2/6] PR suggestions and more tests --- tests/test_funcs.py | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 6ac508fa..04bc5046 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1481,9 +1481,40 @@ def test_input_validation(self, xp: ModuleType): @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse") class TestIsIn: - def test_simple(self, xp: ModuleType): - a = xp.asarray([[0, 2], [4, 6]]) + def test_simple(self, xp: ModuleType, library: Backend): + if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24): + pytest.xfail("NumPy <1.24 has no kind kwarg in isin") + b = xp.asarray([1, 2, 3, 4]) + + # `a` with 1 dimension + a = xp.asarray([1, 3, 6, 10]) + expected = xp.asarray([True, True, False, False]) + res = isin(a, b) + xp_assert_equal(res, expected) + + # `a` with 2 dimensions + a = xp.asarray([[0, 2], [4, 6]]) expected = xp.asarray([[False, True], [True, False]]) res = isin(a, b) xp_assert_equal(res, expected) + + def test_device(self, xp: ModuleType, device: Device): + a = xp.asarray([1, 3, 6], device=device) + b = xp.asarray([1, 2, 3], device=device) + assert get_device(isin(a, b)) == device + + def test_assume_unique_and_invert(self, xp: ModuleType, device: Device): + a = xp.asarray([0, 3, 6, 10], device=device) + b = xp.asarray([1, 2, 3, 10], device=device) + expected = xp.asarray([True, False, True, False]) + res = isin(a, b, assume_unique=True, invert=True) + assert get_device(res) == device + xp_assert_equal(res, expected) + + def test_kind(self, xp: ModuleType): + a = xp.asarray([0, 3, 6, 10]) + b = xp.asarray([1, 2, 3, 10]) + expected = xp.asarray([False, True, False, True]) + res = isin(a, b, kind="sort") + xp_assert_equal(res, expected) From 63f9dfde1b0ca812f2e1d53b143c1b77cce8b6be Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 16 Oct 2025 11:48:05 +0500 Subject: [PATCH 3/6] Skip numpy < 1.24 in other tests as well --- tests/test_funcs.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 04bc5046..92e794ed 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1486,7 +1486,7 @@ def test_simple(self, xp: ModuleType, library: Backend): pytest.xfail("NumPy <1.24 has no kind kwarg in isin") b = xp.asarray([1, 2, 3, 4]) - + # `a` with 1 dimension a = xp.asarray([1, 3, 6, 10]) expected = xp.asarray([True, True, False, False]) @@ -1499,20 +1499,31 @@ def test_simple(self, xp: ModuleType, library: Backend): res = isin(a, b) xp_assert_equal(res, expected) - def test_device(self, xp: ModuleType, device: Device): + def test_device(self, xp: ModuleType, device: Device, library: Backend): + if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24): + pytest.xfail("NumPy <1.24 has no kind kwarg in isin") + a = xp.asarray([1, 3, 6], device=device) b = xp.asarray([1, 2, 3], device=device) assert get_device(isin(a, b)) == device - def test_assume_unique_and_invert(self, xp: ModuleType, device: Device): + def test_assume_unique_and_invert( + self, xp: ModuleType, device: Device, library: Backend + ): + if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24): + pytest.xfail("NumPy <1.24 has no kind kwarg in isin") + a = xp.asarray([0, 3, 6, 10], device=device) b = xp.asarray([1, 2, 3, 10], device=device) expected = xp.asarray([True, False, True, False]) res = isin(a, b, assume_unique=True, invert=True) assert get_device(res) == device xp_assert_equal(res, expected) - - def test_kind(self, xp: ModuleType): + + def test_kind(self, xp: ModuleType, library: Backend): + if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24): + pytest.xfail("NumPy <1.24 has no kind kwarg in isin") + a = xp.asarray([0, 3, 6, 10]) b = xp.asarray([1, 2, 3, 10]) expected = xp.asarray([False, True, False, True]) From eb8b30cc3a31165894aed9f821fcf8947c76d994 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 16 Oct 2025 11:50:42 +0500 Subject: [PATCH 4/6] Remove unneeded condition --- src/array_api_extra/_lib/_funcs.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 11720111..aed38f8b 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -803,19 +803,16 @@ def argpartition( # numpydoc ignore=PR01,RT01 return xp.argsort(x, axis=axis, stable=False) -def isin( # numpydoc ignore=PR01,RT01 +def isin( # numpydoc ignore=PR01,RT01 a: Array, b: Array, /, *, assume_unique: bool = False, invert: bool = False, - xp: ModuleType | None = None, + xp: ModuleType, ) -> Array: """See docstring in `array_api_extra._delegation.py`.""" - if xp is None: - xp = array_namespace(a, b) - original_a_shape = a.shape a = xp.reshape(a, (-1,)) b = xp.reshape(b, (-1,)) From 3fff8812c4dec791a76b0d7ce913e0f1c4fc6199 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 16 Oct 2025 11:51:23 +0500 Subject: [PATCH 5/6] Update --- src/array_api_extra/_delegation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index e9168b12..65d5e9be 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -856,9 +856,9 @@ def isin( Parameters ---------- - a : array_like + a : array Input elements. - b : array_like + b : array The elements against which to test each element of `a`. assume_unique : bool, optional If True, the input arrays are both assumed to be unique which can speed @@ -868,7 +868,7 @@ def isin( kind : str | None, optional The algorithm or method to use. This will not affect the final result, but will affect the speed and memory use. - For Numpy the options are {None, "sort", "table"}. + For NumPy the options are {None, "sort", "table"}. For Jax the mapped parameter is instead `method` and the options are {"compare_all", "binary_search", "sort", and "auto" (default)} For Cupy, Dask, Torch and the default case this parameter is not present and From 80922b763e8ec798763c69162f03d61946ed2645 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 16 Oct 2025 11:53:48 +0500 Subject: [PATCH 6/6] Minor docstring update --- src/array_api_extra/_delegation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 65d5e9be..e9a943c0 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -871,7 +871,7 @@ def isin( For NumPy the options are {None, "sort", "table"}. For Jax the mapped parameter is instead `method` and the options are {"compare_all", "binary_search", "sort", and "auto" (default)} - For Cupy, Dask, Torch and the default case this parameter is not present and + For CuPy, Dask, Torch and the default case this parameter is not present and thus ignored. Default: None. xp : array_namespace, optional The standard-compatible namespace for `a` and `b`. Default: infer.