Skip to content

Commit

Permalink
Do not generate trivial gathers when indexing entire axis
Browse files Browse the repository at this point in the history
  • Loading branch information
lgeiger committed Apr 10, 2022
1 parent 5fdad0e commit aac41ab
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
14 changes: 9 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -3706,11 +3706,15 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
# Normalize the slice to use None when possible
start, stop, step = i.start, i.stop, i.step
try:
if ((step is None or core.symbolic_equal_dim(step, 1)) and
stop is not None and core.symbolic_equal_dim(stop, x_shape[x_axis])):
# The following is a useful special case with shape polymorphism
stop = None
except TypeError:
if step is None or core.symbolic_equal_dim(step, 1):
step = None
if step is None:
if start is None or core.symbolic_equal_dim(start, 0):
start = None
if stop is None or (not isinstance(stop, core.Tracer) and
core.greater_equal_dim(stop, x_shape[x_axis])):
stop = None
except (TypeError, core.InconclusiveDimensionOperation):
pass

# Handle slice(None)
Expand Down
5 changes: 5 additions & 0 deletions tests/lax_numpy_indexing_test.py
Expand Up @@ -849,6 +849,11 @@ def testTrivialGatherIsntGenerated(self):
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertNotIn('gather', str(jaxpr))

jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)

def testIndexingEmptyDimension(self):
# Issue 2671: XLA error when indexing into dimension of size 0
x = jnp.ones((2, 0))
Expand Down

0 comments on commit aac41ab

Please sign in to comment.