Skip to content

Commit

Permalink
Merge pull request #404 from hawkinsp/numpy
Browse files Browse the repository at this point in the history
Implement np.roll (#70).
  • Loading branch information
hawkinsp committed Feb 18, 2019
2 parents 70b13ce + f392920 commit 5180549
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
35 changes: 35 additions & 0 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1763,6 +1763,41 @@ def argsort(a, axis=-1, kind='quicksort', order=None):
return perm


@_wraps(onp.roll)
def roll(a, shift, axis=None):
a = asarray(a)
a_shape = shape(a)
if axis is None:
return lax.reshape(roll(ravel(a), shift, axis=0), a_shape)

a_ndim = len(a_shape)
if isinstance(shift, tuple):
if isinstance(axis, tuple):
if len(axis) != len(shift):
msg = "Mismatched lengths between shift ({}) and axis ({}) for np.roll."
raise ValueError(msg.format(len(shift), len(axis)))
axis = tuple(a for a in axis)
else:
axis = (axis,) * len(shift)
elif isinstance(axis, tuple):
shift = (shift,) * len(axis)
else:
shift = (shift,)
axis = (axis,)

for offset, i in zip(shift, axis):
i = _canonicalize_axis(i, a_ndim)
offset = offset % (a_shape[i] or 1)
slices = [slice(None)] * a_ndim
slices[i] = slice(None, -offset)
before = a[tuple(slices)]
slices[i] = slice(-offset, None)
after = a[tuple(slices)]
a = lax.concatenate((after, before), i)

return a


@_wraps(onp.take)
def take(a, indices, axis=None, out=None, mode=None):
if out:
Expand Down
23 changes: 23 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 +1202,29 @@ def testArgsortManually(self):
expected = onp.argsort(x)
self.assertAllClose(expected, ans, check_dtypes=False)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_shifts={}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype),
shifts, axis),
"rng": rng, "shape": shape, "dtype": dtype, "shifts": shifts,
"axis": axis}
for dtype in all_dtypes
for shape in [(3, 4), (3, 4, 5), (7, 4, 0)]
for shifts, axis in [
(3, None),
(1, 1),
((3,), (0,)),
((-2,), (-2,)),
((1, 2), (0, -1))
]
for rng in [jtu.rand_default()]))
def testRoll(self, shape, dtype, shifts, axis, rng):
args_maker = lambda: [rng(shape, dtype)]
lnp_op = lambda x: lnp.roll(x, shifts, axis=axis)
onp_op = lambda x: onp.roll(x, shifts, axis=axis)
self._CheckAgainstNumpy(lnp_op, onp_op, args_maker, check_dtypes=True)
self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_index={}_axis={}_mode={}".format(
jtu.format_shape_dtype_string(shape, dtype),
Expand Down

0 comments on commit 5180549

Please sign in to comment.