Skip to content

Commit

Permalink
Fix indexing with slices when the slice elements are jax.Array.
Browse files Browse the repository at this point in the history
This fixes a bug introduced in #18679, for the case when some
elements of the slice are `jax.Array`. We add a new test also.
  • Loading branch information
gnecula committed Dec 5, 2023
1 parent 7a3e214 commit ec46058
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
12 changes: 9 additions & 3 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -4938,9 +4938,15 @@ def _preprocess_slice(
# See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding
# "this is harder to get right than you may think"
# (from https://github.com/python/cpython/blob/939fc6d6eab9b7ea8c244d513610dbdd556503a7/Objects/sliceobject.c#L275)
def convert_to_index(d: DimSize) -> DimSize:
# Convert np.array and jax.Array to int, leave symbolic dimensions alone
try:
return operator.index(d)
except:
return d

# Must resolve statically if step is {<0, ==0, >0}
step = s.step if s.step is not None else 1
step = convert_to_index(s.step) if s.step is not None else 1
try:
if step == 0:
raise ValueError("slice step cannot be zero")
Expand Down Expand Up @@ -4975,12 +4981,12 @@ def clamp_index(i: DimSize, which: str):
if s.start is None:
start = 0 if step_gt_0 else axis_size - 1
else:
start = clamp_index(s.start, "start")
start = clamp_index(convert_to_index(s.start), "start")

if s.stop is None:
stop = axis_size if step_gt_0 else -1
else:
stop = clamp_index(s.stop, "stop")
stop = clamp_index(convert_to_index(s.stop), "stop")

gap = step if step_gt_0 else - step
distance = (stop - start) if step_gt_0 else (start - stop)
Expand Down
16 changes: 16 additions & 0 deletions tests/lax_numpy_indexing_test.py
Expand Up @@ -443,6 +443,22 @@ def testStaticIndexing(self, name, shape, dtype, indexer):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

def testStaticIndexingWithJaxArray(self):
shape = (10,)
indexer = slice(jnp.array(2, dtype=np.int32),
np.array(11, dtype=np.int32),
jnp.array(1, dtype=np.int32))
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, np.int32)]
np_fun = lambda x: np.asarray(x)[indexer]
jnp_fun = lambda x: jnp.asarray(x)[indexer]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
# Tests x.at[...].get(...) as well.
jnp_fun = lambda x: jnp.asarray(x).at[indexer].get()
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
funcname=["negative", "sin", "cos", "square", "sqrt", "log", "exp"],
)
Expand Down

0 comments on commit ec46058

Please sign in to comment.