Skip to content

Commit

Permalink
jnp.unravel_index: avoid overflow for large dimension sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 4, 2022
1 parent a8c6742 commit 58320e2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
22 changes: 11 additions & 11 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,22 +808,22 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'):

_UNRAVEL_INDEX_DOC = """\
Unlike numpy's implementation of unravel_index, negative indices are accepted
and out-of-bounds indices are clipped.
and out-of-bounds indices are clipped into the valid range.
"""

@_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
def unravel_index(indices, shape):
_check_arraylike("unravel_index", indices)
sizes = append(array(shape), 1)
cumulative_sizes = cumprod(sizes[::-1])[::-1]
total_size = cumulative_sizes[0]
# Clip so raveling and unraveling an oob index will not change the behavior
clipped_indices = clip(indices, -total_size, total_size - 1)
# Add enough trailing dims to avoid conflict with clipped_indices
cumulative_sizes = expand_dims(cumulative_sizes, range(1, 1 + _ndim(indices)))
clipped_indices = expand_dims(clipped_indices, axis=0)
idx = clipped_indices % cumulative_sizes[:-1] // cumulative_sizes[1:]
return tuple(idx)
shape = atleast_1d(shape)
if shape.ndim != 1:
raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")
out_indices = [None] * len(shape)
for i, s in reversed(list(enumerate(shape))):
indices, out_indices[i] = divmod(indices, s)
oob_pos = indices > 0
oob_neg = indices < -1
return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i))
for s, i in zip(shape, out_indices))

@_wraps(np.resize)
@partial(jit, static_argnames=('new_shape',))
Expand Down
5 changes: 4 additions & 1 deletion tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4221,7 +4221,10 @@ def jnp_fun(a, c):
(5, (2, 1, 3)),
(0, ()),
(np.array([0, 1, 2]), (2, 2)),
(np.array([[[0, 1], [2, 3]]]), (2, 2)))
(np.array([[[0, 1], [2, 3]]]), (2, 2)),
# regression test for https://github.com/google/jax/issues/10540
(np.arange(5), (201_996, 201_996)), # prod(shape) overflows int32.
)
def testUnravelIndex(self, flat_index, shape):
args_maker = lambda: (flat_index, shape)
np_fun = jtu.with_jax_dtype_defaults(np.unravel_index, use_defaults=not hasattr(flat_index, 'dtype'))
Expand Down

0 comments on commit 58320e2

Please sign in to comment.