Skip to content

Commit

Permalink
[jax2tf] Fix for jnp.roll with shape polymorphism
Browse files Browse the repository at this point in the history
There was a partial fix before, in #13470, but it was incomplete
and the x64 mode was not handled properly.

There are no tests added here; this was discovered by running the
tests with --jax2tf_default_experimental_native_lowering, which
will become default soon.
  • Loading branch information
gnecula committed Dec 9, 2022
1 parent 942aa7a commit 86a70ab
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -3506,9 +3506,11 @@ def _roll(a, shift, axis):
for x, i in zip(broadcast_to(shift, b_shape),
np.broadcast_to(axis, b_shape)):
i = _canonicalize_axis(i, a_ndim)
x = remainder(x, (a_shape[i] or 1))
a_shape_i = array(a_shape[i], dtype=np.int32)
x = remainder(lax.convert_element_type(x, np.int32),
lax.max(a_shape_i, np.int32(1)))
a = lax.concatenate((a, a), i)
a = lax.dynamic_slice_in_dim(a, array(a_shape[i]) - x, a_shape[i], axis=i)
a = lax.dynamic_slice_in_dim(a, a_shape_i - x, a_shape[i], axis=i)
return a


Expand Down

0 comments on commit 86a70ab

Please sign in to comment.