Skip to content

Commit

Permalink
Implement np.intersect1d (#3726)
Browse files Browse the repository at this point in the history
* Implement np.intersect1d

* Add jitable helper to function

* Fix argsort failing tests

* Fix linter errors
  • Loading branch information
aldragan0 committed Jul 13, 2020
1 parent 9da9156 commit 0d81e98
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Expand Up @@ -153,6 +153,7 @@ Not every function in NumPy is implemented; contributions are welcome!
iscomplex
isfinite
isin
intersect1d
isinf
isnan
isneginf
Expand Down
2 changes: 1 addition & 1 deletion jax/numpy/__init__.py
Expand Up @@ -37,7 +37,7 @@
fmod, frexp, full, full_like, function, gcd, geomspace, gradient, greater,
greater_equal, hamming, hanning, heaviside, histogram, histogram_bin_edges,
hsplit, hstack, hypot, identity, iinfo, imag,
indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer,
indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer, intersect1d,
isclose, iscomplex, iscomplexobj, isfinite, isin, isinf, isnan, isneginf,
isposinf, isreal, isrealobj, isscalar, issubdtype, issubsctype, iterable,
ix_, kaiser, kron, lcm, ldexp, left_shift, less, less_equal, linspace,
Expand Down
51 changes: 51 additions & 0 deletions jax/numpy/lax_numpy.py
Expand Up @@ -1243,6 +1243,57 @@ def in1d(ar1, ar2, assume_unique=False, invert=False):
else:
return (ar1[:, None] == ar2).any(-1)

@partial(jit, static_argnums=2)
def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
"""
Helper function for intersect1d which is jit-able
"""
ar = concatenate((ar1, ar2))

if return_indices:
indices = argsort(ar)
aux = ar[indices]
else:
aux = sort(ar)

mask = aux[1:] == aux[:-1]
if return_indices:
return aux, mask, indices
else:
return aux, mask

@_wraps(np.intersect1d)
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):

if not assume_unique:
if return_indices:
ar1, ind1 = unique(ar1, return_index=True)
ar2, ind2 = unique(ar2, return_index=True)
else:
ar1 = unique(ar1)
ar2 = unique(ar2)
else:
ar1 = ravel(ar1)
ar2 = ravel(ar2)

if return_indices:
aux, mask, aux_sort_indices = _intersect1d_sorted_mask(ar1, ar2, return_indices)
else:
aux, mask = _intersect1d_sorted_mask(ar1, ar2, return_indices)

int1d = aux[:-1][mask]

if return_indices:
ar1_indices = aux_sort_indices[:-1][mask]
ar2_indices = aux_sort_indices[1:][mask] - ar1.size
if not assume_unique:
ar1_indices = ind1[ar1_indices]
ar2_indices = ind2[ar2_indices]

return int1d, ar1_indices, ar2_indices
else:
return int1d


@_wraps(np.isin, lax_description="""
In the JAX version, the `assume_unique` argument is not referenced.
Expand Down
22 changes: 22 additions & 0 deletions tests/lax_numpy_test.py
Expand Up @@ -957,6 +957,28 @@ def testIn1d(self, element_shape, test_shape, dtype, invert):
self._CompileAndCheck(jnp_fun, args_maker)


@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}_assume_unique={}_return_indices={}".format(
jtu.format_shape_dtype_string(shape1, dtype1),
jtu.format_shape_dtype_string(shape2, dtype2),
assume_unique,
return_indices),
"shape1": shape1, "dtype1": dtype1, "shape2": shape2, "dtype2": dtype2,
"assume_unique": assume_unique, "return_indices": return_indices}
for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16]
for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16]
for shape1 in all_shapes
for shape2 in all_shapes
for assume_unique in [False, True]
for return_indices in [False, True]))
def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, return_indices):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)]
jnp_fun = lambda ar1, ar2: jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices)
np_fun = lambda ar1, ar2: np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)


@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}".format(
jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
Expand Down

0 comments on commit 0d81e98

Please sign in to comment.