Skip to content

Commit

Permalink
[JAX] Update users of jax.ops.index... functions, which are depreca…
Browse files Browse the repository at this point in the history
…ted.

* replace uses of `jax.ops.index[...]` with `jax.numpy.index_exp[...]`, which is a standard NumPy function that does the same thing.
* remove some redundant uses of `jax.ops.index[...]`, where the expression is passed directly to an indexed accessor function like `.at[...]`.
* update some remaining users of `jax.ops.index_update(x, jax.ops.index[idx], y)` to use the `x.at[idx].set(y)` APIs.

PiperOrigin-RevId: 404395250
  • Loading branch information
hawkinsp authored and jax authors committed Oct 19, 2021
1 parent b09501f commit 9fee130
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
13 changes: 7 additions & 6 deletions jax/_src/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ class _Indexable(object):
"""Helper object for building indexes for indexed update functions.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`.
Prefer the use of :attr:`jax.numpy.ndarray.at`. If an explicit index
is needed, use :func:`jax.numpy.index_exp`.
This is a singleton object that overrides the :code:`__getitem__` method
to return the index it is passed.
Expand Down Expand Up @@ -171,7 +172,7 @@ def index_add(x: Array,
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_add(x, jax.ops.index[2:4, 3:], 6.)
>>> jax.ops.index_add(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 7., 7., 7.],
Expand Down Expand Up @@ -223,7 +224,7 @@ def index_mul(x: Array,
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_mul(x, jax.ops.index[2:4, 3:], 6.)
>>> jax.ops.index_mul(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.],
Expand Down Expand Up @@ -273,7 +274,7 @@ def index_min(x: Array,
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_min(x, jax.ops.index[2:4, 3:], 0.)
>>> jax.ops.index_min(x, jnp.index_exp[2:4, 3:], 0.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 0., 0., 0.],
Expand Down Expand Up @@ -322,7 +323,7 @@ def index_max(x: Array,
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_max(x, jax.ops.index[2:4, 3:], 6.)
>>> jax.ops.index_max(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.],
Expand Down Expand Up @@ -372,7 +373,7 @@ def index_update(x: Array,
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_update(x, jax.ops.index[::2, 3:], 6.)
>>> jax.ops.index_update(x, jnp.index_exp[::2, 3:], 6.)
DeviceArray([[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.],
Expand Down
2 changes: 1 addition & 1 deletion tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2917,7 +2917,7 @@ def testSearchsorted(self, ashape, vshape, side, dtype):
for dtype in default_dtypes
))
def testDigitize(self, xshape, binshape, right, reverse, dtype):
order = jax.ops.index[::-1] if reverse else jax.ops.index[:]
order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:]
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]]
np_fun = lambda x, bins: np.digitize(x, bins, right=right)
Expand Down

0 comments on commit 9fee130

Please sign in to comment.