Skip to content

Commit

Permalink
Implement jax.numpy.argpartition
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Feb 8, 2023
1 parent f71a55c commit 4fbaee5
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 22 deletions.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Expand Up @@ -68,6 +68,7 @@ namespace; they are listed below.
arctanh
argmax
argmin
argpartition
argsort
argwhere
around
Expand Down
61 changes: 40 additions & 21 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -3529,11 +3529,17 @@ def msort(a):


@_wraps(np.partition, lax_description="""
The jax version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`.
The JAX version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
you're only accessing the top or bottom k values of the output, it may be more
efficient to call :func:`jax.lax.top_k` directly.
The JAX version differs from the NumPy version in the treatment of NaN entries;
NaNs which have the negative bit set are sorted to the beginning of the array.
""")
@partial(jit, static_argnames=['kth', 'axis'])
def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
# TODO(jakevdp): handle NaN values like numpy.
_check_arraylike("partition", a)
arr = asarray(a)
if issubdtype(arr.dtype, np.complexfloating):
Expand All @@ -3548,6 +3554,38 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
return swapaxes(out, -1, axis)


@_wraps(np.argpartition, lax_description="""
The JAX version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
you're only accessing the top or bottom k values of the output, it may be more
efficient to call :func:`jax.lax.top_k` directly.
The JAX version differs from the NumPy version in the treatment of NaN entries;
NaNs which have the negative bit set are sorted to the beginning of the array.
""")
@partial(jit, static_argnames=['kth', 'axis'])
def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
  # Index-returning counterpart of `partition`: the result along `axis` is the
  # concatenation of the `kth + 1` indices selected by top_k on the negated
  # input, followed by every index not selected in that first step.
  #
  # Args:
  #   a: array-like input; complex dtypes raise NotImplementedError.
  #   kth: static integer pivot position along `axis`.
  #   axis: static integer axis to partition along (default: last axis).
  # Returns:
  #   An array of indices with the same shape as `a`.
  #
  # TODO(jakevdp): handle NaN values like numpy.
  # Report validation errors under this function's own name, not `partition`.
  _check_arraylike("argpartition", a)
  arr = asarray(a)
  if issubdtype(arr.dtype, np.complexfloating):
    raise NotImplementedError("jnp.argpartition for complex dtype is not implemented.")
  axis = _canonicalize_axis(axis, arr.ndim)
  # NOTE(review): reuses the axis helper on `kth` with the axis length as bound,
  # presumably to wrap negative `kth` values; confirm against the helper.
  kth = _canonicalize_axis(kth, arr.shape[axis])

  # top_k is applied along the trailing dimension, so move `axis` there first.
  arr = swapaxes(arr, axis, -1)
  # Largest entries of the negated array, i.e. the `kth + 1` bottom entries.
  bottom_ind = lax.top_k(-arr, kth + 1)[1]

  # To avoid issues with duplicate values, we compute the top indices via a proxy:
  # an array of ones with zeros written at the already-selected positions, so
  # that top_k over the proxy picks out the remaining positions.
  set_to_zero = lambda a, i: a.at[i].set(0)
  # One vmap per leading dimension so the 1D update is applied row by row.
  for _ in range(arr.ndim - 1):
    set_to_zero = jax.vmap(set_to_zero)
  proxy = set_to_zero(ones(arr.shape), bottom_ind)
  top_ind = lax.top_k(proxy, arr.shape[-1] - kth - 1)[1]
  out = lax.concatenate([bottom_ind, top_ind], dimension=arr.ndim - 1)
  # Restore the caller's axis order.
  return swapaxes(out, -1, axis)


@partial(jit, static_argnums=(2,))
def _roll(a, shift, axis):
a_shape = shape(a)
Expand Down Expand Up @@ -4947,19 +4985,6 @@ def _notimplemented_flat(self):
raise NotImplementedError("JAX DeviceArrays do not implement the arr.flat property: "
"consider arr.flatten() instead.")

### track unimplemented functions

# Banner passed as the `lax_description` of every placeholder built below.
_NOT_IMPLEMENTED_DESC = """
*** This function is not yet implemented by jax.numpy, and will raise NotImplementedError ***
"""

def _not_implemented(fun, module=None):
  # Build a placeholder for `fun` that raises NotImplementedError on every call.
  #
  # Args:
  #   fun: the function being stood in for; it is formatted into the error message.
  #   module: forwarded unchanged to `_wraps`.
  # Returns:
  #   The placeholder callable, decorated via `_wraps` with `update_doc=False`
  #   and the "not yet implemented" banner as its description.
  decorate = _wraps(fun, module=module, update_doc=False,
                    lax_description=_NOT_IMPLEMENTED_DESC)

  def wrapped(*args, **kwargs):
    # Arguments are accepted and ignored; the call always fails.
    raise NotImplementedError("Numpy function {} not yet implemented".format(fun))

  return decorate(wrapped)


@_wraps(np.place, lax_description="""
Numpy function :func:`numpy.place` is not available in JAX and will raise a
Expand Down Expand Up @@ -5086,12 +5111,6 @@ def _deepcopy(self: Array, memo: Any) -> Array:
"ravel", "repeat", "sort", "squeeze", "std", "sum",
"swapaxes", "take", "trace", "var"]

# These methods are mentioned explicitly by nondiff_methods, so we create
# _not_implemented implementations of them here rather than in __init__.py.
# TODO(phawkins): implement these.
argpartition = _not_implemented(np.argpartition)


# Experimental support for NumPy's module dispatch with NEP-37.
# Currently requires https://github.com/seberg/numpy-dispatch
_JAX_ARRAY_TYPES = (device_array.DeviceArray, core.Tracer, ArrayImpl)
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Expand Up @@ -174,6 +174,7 @@
nan_to_num as nan_to_num,
nanargmax as nanargmax,
nanargmin as nanargmin,
argpartition as argpartition,
nanmedian as nanmedian,
nanpercentile as nanpercentile,
nanquantile as nanquantile,
Expand Down
44 changes: 43 additions & 1 deletion tests/lax_numpy_test.py
Expand Up @@ -3599,17 +3599,58 @@ def testPartition(self, shape, dtype, axis, kth):
jnp_output = jnp.partition(arg, axis=axis, kth=kth)
np_output = np.partition(arg, axis=axis, kth=kth)

# Assert that pivot point is equal
# Assert that pivot point is equal:
self.assertArraysEqual(
lax.index_in_dim(jnp_output, axis=axis, index=kth),
lax.index_in_dim(np_output, axis=axis, index=kth))

# Assert remaining values are correctly partitioned:
self.assertArraysEqual(
lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis),
lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis))
self.assertArraysEqual(
lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))

@jtu.sample_product(
  [{'shape': shape, 'axis': axis, 'kth': kth}
   for shape in nonzerodim_shapes
   for axis in range(-len(shape), len(shape))
   for kth in range(-shape[axis], shape[axis])],
  dtype=default_dtypes,
)
def testArgpartition(self, shape, dtype, axis, kth):
  # Compares jnp.argpartition with np.argpartition for every (axis, kth) pair,
  # including negative values of both.
  rng = jtu.rand_default(self.rng())
  arg = rng(shape, dtype)

  jnp_output = jnp.argpartition(arg, axis=axis, kth=kth)
  np_output = np.argpartition(arg, axis=axis, kth=kth)

  # Assert that all indices are present
  self.assertArraysEqual(jnp.sort(jnp_output, axis), np.sort(np_output, axis), check_dtypes=False)

  # Non-negative equivalents of `axis` and `kth`. Without this, kth == -1 makes
  # `kth + 1` equal to 0 and the "above the pivot" slice below would span the
  # whole axis instead of being empty.
  pos_axis = range(arg.ndim)[axis]
  pos_kth = range(shape[axis])[kth]

  # Because JAX & numpy may treat duplicates differently, we must compare values
  # rather than indices.
  getvals = lambda x, ind: x[ind]
  # Batch the 1D gather over every dimension except the partitioned one.
  for ax in range(arg.ndim):
    if ax != pos_axis:
      getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax)
  jnp_values = getvals(arg, jnp_output)
  np_values = getvals(arg, np_output)

  # Assert that pivot point is equal:
  self.assertArraysEqual(
      lax.index_in_dim(jnp_values, axis=axis, index=pos_kth),
      lax.index_in_dim(np_values, axis=axis, index=pos_kth))

  # Assert remaining values are correctly partitioned:
  self.assertArraysEqual(
      lax.sort(lax.slice_in_dim(jnp_values, start_index=0, limit_index=pos_kth, axis=axis), dimension=axis),
      lax.sort(lax.slice_in_dim(np_values, start_index=0, limit_index=pos_kth, axis=axis), dimension=axis))
  self.assertArraysEqual(
      lax.sort(lax.slice_in_dim(jnp_values, start_index=pos_kth + 1, limit_index=shape[axis], axis=axis), dimension=axis),
      lax.sort(lax.slice_in_dim(np_values, start_index=pos_kth + 1, limit_index=shape[axis], axis=axis), dimension=axis))

@jtu.sample_product(
[dict(shifts=shifts, axis=axis)
for shifts, axis in [
Expand Down Expand Up @@ -5183,6 +5224,7 @@ def testWrappedSignaturesMatch(self):

# TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
unsupported_params = {
'argpartition': ['kind', 'order'],
'asarray': ['like'],
'broadcast_to': ['subok'],
'clip': ['kwargs'],
Expand Down

0 comments on commit 4fbaee5

Please sign in to comment.