Skip to content

Commit

Permalink
improve documentation for ix_
Browse files Browse the repository at this point in the history
  • Loading branch information
selamw1 committed May 6, 2024
1 parent bb6aa12 commit 9caf59d
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3158,9 +3158,39 @@ def _i0_jvp(primals, tangents):
primal_out, tangent_out = jax.jvp(i0.fun, primals, tangents)
return primal_out, where(primals[0] == 0, 0.0, tangent_out)


@util.implements(np.ix_)
def ix_(*args: ArrayLike) -> tuple[Array, ...]:
"""Return a multi-dimensional grid (open mesh) from N one-dimensional sequences.
JAX implementation of :func:`numpy.ix_`.
Args:
*args: N one-dimensional arrays
Returns:
Tuple of Jax arrays forming an open mesh, each with N dimensions.
See Also:
- :obj:`jax.numpy.ogrid`
- :obj:`jax.numpy.mgrid`
- :func:`jax.numpy.meshgrid`
Example:
>>> rows = jnp.array([0, 2])
>>> cols = jnp.array([1, 3])
>>> open_mesh = jnp.ix_(rows, cols)
>>> open_mesh
(Array([[0],
[2]], dtype=int32), Array([[1, 3]], dtype=int32))
>>> [grid.shape for grid in open_mesh]
[(2, 1), (1, 2)]
>>> x = jnp.array([[10, 20, 30, 40],
... [50, 60, 70, 80],
... [90, 100, 110, 120],
... [130, 140, 150, 160]])
>>> x[open_mesh]
Array([[ 20, 40],
[100, 120]], dtype=int32)
"""
util.check_arraylike("ix", *args)
n = len(args)
output = []
Expand Down

0 comments on commit 9caf59d

Please sign in to comment.