Skip to content

Commit

Permalink
Avoid generating trivial gathers when reversing array
Browse files Browse the repository at this point in the history
  • Loading branch information
lgeiger committed May 11, 2022
1 parent 5bce808 commit f13b69c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
9 changes: 7 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3775,11 +3775,16 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
if stop is None or (not isinstance(stop, core.Tracer) and
core.greater_equal_dim(stop, x_shape[x_axis])):
stop = None
elif core.symbolic_equal_dim(step, -1):
step = -1
except (TypeError, core.InconclusiveDimensionOperation):
pass

# Handle slice(None)
if start is None and stop is None and step is None:
# Handle slice(None) and slice(None, None, -1)
if start is None and stop is None and (
step is None or isinstance(step, int) and step == -1):
if step == -1:
reversed_y_dims.append(collapsed_y_axis)
slice_shape.append(x_shape[x_axis])
gather_slice_shape.append(x_shape[x_axis])
offset_dims.append(collapsed_y_axis)
Expand Down
4 changes: 4 additions & 0 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,10 @@ def testTrivialGatherIsntGenerated(self):
jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)

jaxpr = jax.make_jaxpr(lambda x: x[::-1])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p)

def testIndexingEmptyDimension(self):
# Issue 2671: XLA error when indexing into dimension of size 0
x = jnp.ones((2, 0))
Expand Down

0 comments on commit f13b69c

Please sign in to comment.