Skip to content

Commit

Permalink
Implement jnp.ravel_multi_index() (google#4313)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 17, 2020
1 parent 8a4ee3d commit e0af77f
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Expand Up @@ -292,6 +292,7 @@ Not every function in NumPy is implemented; contributions are welcome!
rad2deg
radians
ravel
ravel_multi_index
real
reciprocal
remainder
Expand Down
2 changes: 1 addition & 1 deletion jax/numpy/__init__.py
Expand Up @@ -53,7 +53,7 @@
object_, ones, ones_like, operator_name, outer, packbits, pad, percentile,
pi, piecewise, polyadd, polyder, polymul, polysub, polyval, positive, power,
prod, product, promote_types, ptp, quantile,
rad2deg, radians, ravel, real, reciprocal, remainder, repeat, reshape,
rad2deg, radians, ravel, ravel_multi_index, real, reciprocal, remainder, repeat, reshape,
result_type, right_shift, rint, roll, rollaxis, rot90, round, row_stack,
save, savez, searchsorted, select, set_printoptions, shape, sign, signbit,
signedinteger, sin, sinc, single, sinh, size, sometrue, sort, sort_complex, split, sqrt,
Expand Down
35 changes: 35 additions & 0 deletions jax/numpy/lax_numpy.py
Expand Up @@ -1176,6 +1176,41 @@ def ravel(a, order="C"):
return reshape(a, (size(a),), order)


@_wraps(np.ravel_multi_index)
def ravel_multi_index(multi_index, dims, mode='raise', order='C'):
assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
dims = tuple(core.concrete_or_error(int, d, "in `dims` argument of ravel_multi_index().") for d in dims)
for index in multi_index:
_check_arraylike("ravel_multi_index", index)
if mode == 'raise':
core.concrete_or_error(array, index,
"The error occurred because ravel_multi_index was jit-compiled"
" with mode='raise'. Use mode='wrap' or mode='clip' instead.")
if not issubdtype(_dtype(index), integer):
raise TypeError("only int indices permitted")
if mode == "raise":
if _any(any((i < 0) | (i >= d)) for i, d in zip(multi_index, dims)):
raise ValueError("invalid entry in coordinates array")
elif mode == "clip":
multi_index = [clip(i, 0, d - 1) for i, d in zip(multi_index, dims)]
elif mode == "wrap":
multi_index = [i % d for i, d in zip(multi_index, dims)]
else:
raise ValueError(f"invalid mode={mode!r}. Expected 'raise', 'wrap', or 'clip'")

if order == "F":
strides = np.cumprod((1,) + dims[:-1])
elif order == "C":
strides = np.cumprod((1,) + dims[1:][::-1])[::-1]
else:
raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'")

result = 0
for i, s in zip(multi_index, strides):
result = result + i * s
return result


_UNRAVEL_INDEX_DOC = """\
Unlike numpy's implementation of unravel_index, negative indices are accepted
and out-of-bounds indices are clipped.
Expand Down
43 changes: 43 additions & 0 deletions tests/lax_numpy_test.py
Expand Up @@ -2754,6 +2754,49 @@ def testRavel(self):
args_maker = lambda: [rng.randn(3, 4).astype("float32")]
self._CompileAndCheck(lambda x: x.ravel(), args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_order={}_mode={}".format(
shape, order, mode),
"shape": shape, "order": order, "mode": mode}
for shape in nonempty_nonscalar_array_shapes
for order in ['C', 'F']
for mode in ['wrap', 'clip', 'raise']))
def testRavelMultiIndex(self, shape, order, mode):
# generate indices in each dimension with a few out of bounds.
rngs = [jtu.rand_int(self.rng(), low=-1, high=dim + 1)
for dim in shape]
# generate multi_indices of different dimensions that broadcast.
args_maker = lambda: [tuple(rng(ndim * (3,), jnp.int_)
for ndim, rng in enumerate(rngs))]
def np_fun(x):
try:
return np.ravel_multi_index(x, shape, order=order, mode=mode)
except ValueError as err:
if str(err).startswith('invalid entry'):
# sentinel indicating expected error.
return -999
else:
raise

def jnp_fun(x):
try:
return jnp.ravel_multi_index(x, shape, order=order, mode=mode)
except ValueError as err:
if str(err).startswith('invalid entry'):
# sentinel indicating expected error.
return -999
else:
raise
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
if mode == 'raise':
msg = ("The error occurred because ravel_multi_index was jit-compiled "
"with mode='raise'. Use mode='wrap' or mode='clip' instead.")
with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg):
jax.jit(jnp_fun)(*args_maker())
else:
self._CompileAndCheck(jnp_fun, args_maker)


@parameterized.parameters(
(0, (2, 1, 3)),
(5, (2, 1, 3)),
Expand Down

0 comments on commit e0af77f

Please sign in to comment.