diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f47913ff03a2..56656b047c53 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3506,9 +3506,11 @@ def _roll(a, shift, axis): for x, i in zip(broadcast_to(shift, b_shape), np.broadcast_to(axis, b_shape)): i = _canonicalize_axis(i, a_ndim) - x = remainder(x, (a_shape[i] or 1)) + a_shape_i = array(a_shape[i], dtype=np.int32) + x = remainder(lax.convert_element_type(x, np.int32), + lax.max(a_shape_i, np.int32(1))) a = lax.concatenate((a, a), i) - a = lax.dynamic_slice_in_dim(a, array(a_shape[i]) - x, a_shape[i], axis=i) + a = lax.dynamic_slice_in_dim(a, a_shape_i - x, a_shape[i], axis=i) return a