jax.ops package
===============

.. currentmodule:: jax.ops

.. automodule:: jax.ops


Indexed update operators
------------------------

JAX is intended to be used with a functional style of programming. JAX arrays are immutable, so JAX does not support NumPy-style in-place indexed assignment such as ``x[idx] = y``. Instead, JAX provides pure alternatives, namely :func:`jax.ops.index_update` and its relatives, which return an updated copy of the array.

.. autosummary::
  :toctree: _autosummary

    index
    index_update
    index_add
    index_mul
    index_min
    index_max
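
The following is a minimal sketch of the difference. It assumes a recent JAX release and uses the ``x.at[idx]`` syntax described below for the functional update. Indexed assignment changes a NumPy array in place, while a JAX array rejects the same statement with a ``TypeError``. The pure alternative returns a new array and leaves its input untouched.

```python
import numpy as np
import jax.numpy as jnp

# NumPy arrays are mutable: indexed assignment changes the array in place.
a_np = np.zeros(4)
a_np[1] = 10.0

# JAX arrays are immutable: the same statement raises a TypeError.
a_jax = jnp.zeros(4)
try:
    a_jax[1] = 10.0
    raised = False
except TypeError:
    raised = True

# The pure alternative builds and returns a new array; a_jax is not modified.
b_jax = a_jax.at[1].set(10.0)
```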


Syntactic sugar for indexed update operators
--------------------------------------------

JAX also provides an alternate syntax for these indexed update operators. Specifically, JAX ndarray types have a property ``at``, which can be used as follows (where ``idx`` can be an arbitrary index expression):

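====================  ==================================================
Alternate syntax      Equivalent expression
====================  ==================================================
``x.at[idx].set(y)``  ``jax.ops.index_update(x, jax.ops.index[idx], y)``
``x.at[idx].add(y)``  ``jax.ops.index_add(x, jax.ops.index[idx], y)``
``x.at[idx].mul(y)``  ``jax.ops.index_mul(x, jax.ops.index[idx], y)``
``x.at[idx].min(y)``  ``jax.ops.index_min(x, jax.ops.index[idx], y)``
``x.at[idx].max(y)``  ``jax.ops.index_max(x, jax.ops.index[idx], y)``
====================  ==================================================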
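
As a rough illustration of the five ``at`` methods listed above, the sketch below applies each one with a different kind of index expression. It assumes a JAX version whose arrays provide all five methods. The last line shows one behavior worth knowing: when ``idx`` names the same location more than once, ``add`` applies every update, whereas NumPy's ``x[idx] += y`` applies only one of them.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

# idx can be any NumPy-style index: an integer, a slice, an integer array, ...
s = x.at[0].set(10.0)                  # element 0 replaced by 10
a = x.at[1:3].add(1.0)                 # elements 1 and 2 incremented by 1
m = x.at[jnp.array([0, 4])].mul(2.0)   # elements 0 and 4 doubled
lo = x.at[2:].min(3.5)                 # elements 2.. clipped from above at 3.5
hi = x.at[:2].max(1.5)                 # elements 0..1 clipped from below at 1.5

# Repeated indices: every update is applied, so x[0] becomes 1 + 1 + 1 = 3.
dup = x.at[jnp.array([0, 0])].add(1.0)
```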

Note that none of these expressions modify the original ``x``; instead, they return a modified copy of ``x``.
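
Because each expression returns a copy, you must use or rebind the result. If you discard it, the array stays as it was. The sketch below shows this. It then threads the updated array through ``jax.lax.fori_loop`` inside ``jax.jit``, which is one common way to write a sequence of updates in this style. ``fill_squares`` is a hypothetical helper name chosen for illustration, and the sketch assumes a recent JAX release.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(3)

# The update returns a new array; discarding the result leaves x as it was.
x.at[0].set(1.0)
unchanged = x

# To "modify" x, rebind the name to the returned copy.
x = x.at[0].set(1.0)

# Hypothetical helper: each loop iteration returns a new array, which
# fori_loop passes on to the next iteration as the loop carry.
def fill_squares(n):
    def body(i, acc):
        return acc.at[i].set(i * i)
    return jax.lax.fori_loop(0, n, body, jnp.zeros(n, dtype=jnp.int32))

# n sets the output shape, so it is marked static for jit.
squares = jax.jit(fill_squares, static_argnums=0)(5)
```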

Other operators
---------------

.. autosummary::
  :toctree: _autosummary

    segment_sum
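
The following is a small example of :func:`jax.ops.segment_sum`, called with its ``data``, ``segment_ids`` and ``num_segments`` arguments. Entries of ``data`` that share a segment id are summed into one output slot. For comparison, the sketch also computes the same totals with ``x.at[idx].add(y)``, which relies on ``add`` applying every update at a repeated index. This is offered as an equivalent computation, not as a description of how the library implements ``segment_sum``.

```python
import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
segment_ids = jnp.array([0, 0, 1, 1, 1, 2])

# Sum the entries of data that share a segment id. num_segments fixes the
# output length; a static value is needed when calling this under jax.jit.
totals = ops.segment_sum(data, segment_ids, num_segments=3)

# Equivalent computation as an indexed add with repeated indices.
via_at = jnp.zeros(3).at[segment_ids].add(data)
```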