diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b8882acf81a3..38fde42b5b6d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4938,9 +4938,15 @@ def _preprocess_slice( # See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding # "this is harder to get right than you may think" # (from https://github.com/python/cpython/blob/939fc6d6eab9b7ea8c244d513610dbdd556503a7/Objects/sliceobject.c#L275) + def convert_to_index(d: DimSize) -> DimSize: + # Convert np.array and jax.Array to int, leave symbolic dimensions alone + try: + return operator.index(d) + except: + return d # Must resolve statically if step is {<0, ==0, >0} - step = s.step if s.step is not None else 1 + step = convert_to_index(s.step) if s.step is not None else 1 try: if step == 0: raise ValueError("slice step cannot be zero") @@ -4975,12 +4981,12 @@ def clamp_index(i: DimSize, which: str): if s.start is None: start = 0 if step_gt_0 else axis_size - 1 else: - start = clamp_index(s.start, "start") + start = clamp_index(convert_to_index(s.start), "start") if s.stop is None: stop = axis_size if step_gt_0 else -1 else: - stop = clamp_index(s.stop, "stop") + stop = clamp_index(convert_to_index(s.stop), "stop") gap = step if step_gt_0 else - step distance = (stop - start) if step_gt_0 else (start - stop) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index d465a3befd90..5e2707f80d5b 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -443,6 +443,22 @@ def testStaticIndexing(self, name, shape, dtype, indexer): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + def testStaticIndexingWithJaxArray(self): + shape = (10,) + indexer = slice(jnp.array(2, dtype=np.int32), + np.array(11, dtype=np.int32), + jnp.array(1, dtype=np.int32)) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, np.int32)] + np_fun = lambda x: np.asarray(x)[indexer] + jnp_fun = lambda x: jnp.asarray(x)[indexer] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + # Tests x.at[...].get(...) as well. + jnp_fun = lambda x: jnp.asarray(x).at[indexer].get() + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( funcname=["negative", "sin", "cos", "square", "sqrt", "log", "exp"], )