diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 7c05552a..935a6e9b 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -12,6 +12,7 @@ pad, partition, sinc, + union1d, ) from ._lib._at import at from ._lib._funcs import ( @@ -50,4 +51,5 @@ "partition", "setdiff1d", "sinc", + "union1d", ] diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index e9a943c0..f78b7980 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -895,3 +895,37 @@ def isin( return xp.isin(a, b, assume_unique=assume_unique, invert=invert) return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp) + + +def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: + """ + Find the union of two arrays. + + Return the unique, sorted array of values that are in either of the two + input arrays. + + Parameters + ---------- + a, b : Array + Input arrays. They are flattened internally if they are not already 1D. + + xp : array_namespace, optional + The standard-compatible namespace for `a` and `b`. Default: infer. + + Returns + ------- + Array + Unique, sorted union of the input arrays. + """ + if xp is None: + xp = array_namespace(a, b) + + if ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_dask_namespace(xp) + or is_jax_namespace(xp) + ): + return xp.union1d(a, b) + + return _funcs.union1d(a, b, xp=xp) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index aed38f8b..34df9bc7 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -820,3 +820,11 @@ def isin( # numpydoc ignore=PR01,RT01 _helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp), original_a_shape, ) + + +def union1d(a: Array, b: Array, /, *, xp: ModuleType) -> Array: + # numpydoc ignore=PR01,RT01 + """See docstring in `array_api_extra._delegation.py`.""" + a = xp.reshape(a, (-1,)) + b = xp.reshape(b, (-1,)) + return xp.unique_values(xp.concat([a, b])) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 92e794ed..624cd55e 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -31,6 +31,7 @@ partition, setdiff1d, sinc, + union1d, ) from array_api_extra._lib._backends import NUMPY_VERSION, Backend from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal @@ -1529,3 +1530,37 @@ def test_kind(self, xp: ModuleType, library: Backend): expected = xp.asarray([False, True, False, True]) res = isin(a, b, kind="sort") xp_assert_equal(res, expected) + + +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="unique_values returns arrays") +@pytest.mark.skip_xp_backend( + Backend.ARRAY_API_STRICTEST, + reason="data_dependent_shapes flag for unique_values is disabled", +) +class TestUnion1d: + def test_simple(self, xp: ModuleType): + a = xp.asarray([-1, 1, 0]) + b = xp.asarray([2, -2, 0]) + expected = xp.asarray([-2, -1, 0, 1, 2]) + res = union1d(a, b) + xp_assert_equal(res, expected) + + def test_2d(self, xp: ModuleType): + a = xp.asarray([[-1, 1, 0], [1, 2, 0]]) + b = xp.asarray([[1, 0, 1], [-2, -1, 0]]) + expected = xp.asarray([-2, -1, 0, 1, 2]) + res = union1d(a, b) + xp_assert_equal(res, expected) + + def test_3d(self, xp: ModuleType): + a = xp.asarray([[[-1, 0], [1, 2]], [[-1, 0], [1, 2]]]) + b = xp.asarray([[[0, 1], [-1, 2]], [[1, -2], [0, 2]]]) + expected = xp.asarray([-2, -1, 0, 1, 2]) + res = union1d(a, b) + xp_assert_equal(res, expected) + + @pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device") + def test_device(self, xp: ModuleType, device: Device): + a = xp.asarray([-1, 1, 0], device=device) + b = xp.asarray([2, -2, 0], device=device) + assert get_device(union1d(a, b)) == device