From 7e6f10b7fa073c7a8a91b76809a0cf266651de20 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 23 Oct 2025 18:06:36 +0500 Subject: [PATCH 1/3] ENH Add union1d --- src/array_api_extra/__init__.py | 2 ++ src/array_api_extra/_delegation.py | 34 ++++++++++++++++++++++++++++++ src/array_api_extra/_lib/_funcs.py | 8 +++++++ tests/test_funcs.py | 34 ++++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+) diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 7c05552a..935a6e9b 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -12,6 +12,7 @@ pad, partition, sinc, + union1d, ) from ._lib._at import at from ._lib._funcs import ( @@ -50,4 +51,5 @@ "partition", "setdiff1d", "sinc", + "union1d", ] diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index e9a943c0..f78b7980 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -895,3 +895,37 @@ def isin( return xp.isin(a, b, assume_unique=assume_unique, invert=invert) return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp) + + +def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: + """ + Find the union of two arrays. + + Return the unique, sorted array of values that are in either of the two + input arrays. + + Parameters + ---------- + a, b : Array + Input arrays. They are flattened internally if they are not already 1D. + + xp : array_namespace, optional + The standard-compatible namespace for `a` and `b`. Default: infer. + + Returns + ------- + Array + Unique, sorted union of the input arrays. + """ + if xp is None: + xp = array_namespace(a, b) + + if ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_dask_namespace(xp) + or is_jax_namespace(xp) + ): + return xp.union1d(a, b) + + return _funcs.union1d(a, b, xp=xp) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index aed38f8b..34df9bc7 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -820,3 +820,11 @@ def isin( # numpydoc ignore=PR01,RT01 _helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp), original_a_shape, ) + + +def union1d(a: Array, b: Array, /, *, xp: ModuleType) -> Array: + # numpydoc ignore=PR01,RT01 + """See docstring in `array_api_extra._delegation.py`.""" + a = xp.reshape(a, (-1,)) + b = xp.reshape(b, (-1,)) + return xp.unique_values(xp.concat([a, b])) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 92e794ed..46e83bc3 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -31,6 +31,7 @@ partition, setdiff1d, sinc, + union1d, ) from array_api_extra._lib._backends import NUMPY_VERSION, Backend from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal @@ -1529,3 +1530,36 @@ def test_kind(self, xp: ModuleType, library: Backend): expected = xp.asarray([False, True, False, True]) res = isin(a, b, kind="sort") xp_assert_equal(res, expected) + + +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="unique_values returns arrays") +@pytest.mark.skip_xp_backend( + Backend.ARRAY_API_STRICTEST, + reason="data_dependent_shapes flag for unique_values is disabled", +) +class TestUnion1d: + def test_simple(self, xp: ModuleType): + a = xp.asarray([-1, 1, 0]) + b = xp.asarray([2, -2, 0]) + expected = xp.asarray([-2, -1, 0, 1, 2]) + res = union1d(a, b) + xp_assert_equal(res, expected) + + def test_2d(self, xp: ModuleType): + a = xp.asarray([[-1, 1, 0], [1, 2, 0]]) + b = xp.asarray([[1, 0, 1], [-2, -1, 0]]) + expected = xp.asarray([-2, -1, 0, 1, 2]) + res = union1d(a, b) + xp_assert_equal(res, expected) + + def test_3d(self, xp: ModuleType): + a = xp.asarray([[[-1, 0], [1, 2]], [[-1, 0], [1, 2]]]) + b = xp.asarray([[[0, 1], [-1, 2]], [[1, -2], [0, 2]]]) + expected = xp.asarray([-2, -1, 0, 1, 2]) + res = union1d(a, b) + xp_assert_equal(res, expected) + + def test_device(self, xp: ModuleType, device: Device): + a = xp.asarray([-1, 1, 0]) + b = xp.asarray([2, -2, 0]) + assert get_device(union1d(a, b)) == device From 2a17b8bffbd7e3be99d282aa1bb49b2e3f6bd08f Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 23 Oct 2025 18:12:10 +0500 Subject: [PATCH 2/3] Fix device test --- tests/test_funcs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 46e83bc3..15461b5e 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1532,11 +1532,11 @@ def test_kind(self, xp: ModuleType, library: Backend): xp_assert_equal(res, expected) -@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="unique_values returns arrays") @pytest.mark.skip_xp_backend( Backend.ARRAY_API_STRICTEST, reason="data_dependent_shapes flag for unique_values is disabled", ) +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="unique_values returns arrays") class TestUnion1d: def test_simple(self, xp: ModuleType): a = xp.asarray([-1, 1, 0]) @@ -1560,6 +1560,6 @@ def test_3d(self, xp: ModuleType): xp_assert_equal(res, expected) def test_device(self, xp: ModuleType, device: Device): - a = xp.asarray([-1, 1, 0]) - b = xp.asarray([2, -2, 0]) + a = xp.asarray([-1, 1, 0], device=device) + b = xp.asarray([2, -2, 0], device=device) assert get_device(union1d(a, b)) == device From 1c5da0c08921e28ebcc2f3960450635c556b6b0b Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 23 Oct 2025 18:25:15 +0500 Subject: [PATCH 3/3] Fix tests --- tests/test_funcs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 15461b5e..624cd55e 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1532,11 +1532,11 @@ def test_kind(self, xp: ModuleType, library: Backend): xp_assert_equal(res, expected) +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="unique_values returns arrays") @pytest.mark.skip_xp_backend( Backend.ARRAY_API_STRICTEST, reason="data_dependent_shapes flag for unique_values is disabled", ) -@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="unique_values returns arrays") class TestUnion1d: def test_simple(self, xp: ModuleType): a = xp.asarray([-1, 1, 0]) @@ -1559,6 +1559,7 @@ def test_3d(self, xp: ModuleType): res = union1d(a, b) xp_assert_equal(res, expected) + @pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device") def test_device(self, xp: ModuleType, device: Device): a = xp.asarray([-1, 1, 0], device=device) b = xp.asarray([2, -2, 0], device=device)