Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
pad,
partition,
sinc,
union1d,
)
from ._lib._at import at
from ._lib._funcs import (
Expand Down Expand Up @@ -50,4 +51,5 @@
"partition",
"setdiff1d",
"sinc",
"union1d",
]
34 changes: 34 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,3 +895,37 @@ def isin(
return xp.isin(a, b, assume_unique=assume_unique, invert=invert)

return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp)


def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
Find the union of two arrays.

Return the unique, sorted array of values that are in either of the two
input arrays.

Parameters
----------
a, b : Array
Input arrays. They are flattened internally if they are not already 1D.

xp : array_namespace, optional
The standard-compatible namespace for `a` and `b`. Default: infer.

Returns
-------
Array
Unique, sorted union of the input arrays.
"""
if xp is None:
xp = array_namespace(a, b)

if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_dask_namespace(xp)
or is_jax_namespace(xp)
):
return xp.union1d(a, b)

return _funcs.union1d(a, b, xp=xp)
8 changes: 8 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,3 +820,11 @@ def isin( # numpydoc ignore=PR01,RT01
_helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp),
original_a_shape,
)


def union1d(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
# numpydoc ignore=PR01,RT01
"""See docstring in `array_api_extra._delegation.py`."""
a = xp.reshape(a, (-1,))
b = xp.reshape(b, (-1,))
return xp.unique_values(xp.concat([a, b]))
35 changes: 35 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
partition,
setdiff1d,
sinc,
union1d,
)
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
Expand Down Expand Up @@ -1529,3 +1530,37 @@ def test_kind(self, xp: ModuleType, library: Backend):
expected = xp.asarray([False, True, False, True])
res = isin(a, b, kind="sort")
xp_assert_equal(res, expected)


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="unique_values returns arrays")
@pytest.mark.skip_xp_backend(
Backend.ARRAY_API_STRICTEST,
reason="data_dependent_shapes flag for unique_values is disabled",
)
class TestUnion1d:
def test_simple(self, xp: ModuleType):
a = xp.asarray([-1, 1, 0])
b = xp.asarray([2, -2, 0])
expected = xp.asarray([-2, -1, 0, 1, 2])
res = union1d(a, b)
xp_assert_equal(res, expected)

def test_2d(self, xp: ModuleType):
a = xp.asarray([[-1, 1, 0], [1, 2, 0]])
b = xp.asarray([[1, 0, 1], [-2, -1, 0]])
expected = xp.asarray([-2, -1, 0, 1, 2])
res = union1d(a, b)
xp_assert_equal(res, expected)

def test_3d(self, xp: ModuleType):
a = xp.asarray([[[-1, 0], [1, 2]], [[-1, 0], [1, 2]]])
b = xp.asarray([[[0, 1], [-1, 2]], [[1, -2], [0, 2]]])
expected = xp.asarray([-2, -1, 0, 1, 2])
res = union1d(a, b)
xp_assert_equal(res, expected)

@pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device")
def test_device(self, xp: ModuleType, device: Device):
a = xp.asarray([-1, 1, 0], device=device)
b = xp.asarray([2, -2, 0], device=device)
assert get_device(union1d(a, b)) == device