Skip to content

Commit

Permalink
custom_prng: generalize indexing of PRNGKeyArray
Browse files Browse the repository at this point in the history
Co-authored-by: Roy Frostig <frostig@google.com>
  • Loading branch information
jakevdp and froystig committed Dec 20, 2021
1 parent 2a6147a commit 4d9e9b4
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 16 deletions.
11 changes: 6 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6124,18 +6124,19 @@ def _is_scalar(x):
return np.isscalar(x) or (isinstance(x, (np.ndarray, ndarray))
and np.ndim(x) == 0)

def _canonicalize_tuple_index(arr_ndim, idx):
def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
"""Helper to remove Ellipsis and add in the implicit trailing slice(None)."""
len_without_none = _sum(1 for e in idx if e is not None and e is not Ellipsis)
if len_without_none > arr_ndim:
msg = "Too many indices for array: {} non-None/Ellipsis indices for dim {}."
raise IndexError(msg.format(len_without_none, arr_ndim))
raise IndexError(
f"Too many indices for {array_name}: {len_without_none} "
f"non-None/Ellipsis indices for dim {arr_ndim}.")
ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
ellipsis_index = next(ellipses, None)
if ellipsis_index is not None:
if next(ellipses, None) is not None:
msg = "Multiple ellipses (...) not supported: {}."
raise IndexError(msg.format(list(map(type, idx))))
raise IndexError(
f"Multiple ellipses (...) not supported: {list(map(type, idx))}.")
colons = (slice(None),) * (arr_ndim - len_without_none)
idx = idx[:ellipsis_index] + colons + idx[ellipsis_index + 1:]
elif len_without_none < arr_ndim:
Expand Down
18 changes: 7 additions & 11 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
from jax._src.api import jit, vmap
from jax._src.lib import xla_client
from jax._src.lib import cuda_prng
from jax._src.numpy.lax_numpy import _register_stackable
from jax._src.numpy.lax_numpy import (
_canonicalize_tuple_index, _eliminate_deprecated_list_indexing,
_expand_bool_indices, _register_stackable)
import jax._src.pretty_printer as pp
from jax._src.util import prod

Expand Down Expand Up @@ -167,18 +169,12 @@ def __iter__(self) -> Iterator['PRNGKeyArray']:
return (PRNGKeyArray(self.impl, k) for k in iter(self._keys))

def __getitem__(self, idx) -> 'PRNGKeyArray':
if not isinstance(idx, tuple):
idx = (idx,)
if any(type(i) is not int for i in idx):
raise NotImplementedError(
'PRNGKeyArray only supports indexing with integer indices. '
f'Cannot index at {idx}')
base_ndim = len(self.impl.key_shape)
ndim = self._keys.ndim - base_ndim
if len(idx) > ndim:
raise IndexError(
f'too many indices for PRNGKeyArray: array is {ndim}-dimensional '
f'but {len(idx)} were indexed')
indexable_shape = self.impl.key_shape[:ndim]
idx = _eliminate_deprecated_list_indexing(idx)
idx = _expand_bool_indices(idx, indexable_shape)
idx = _canonicalize_tuple_index(ndim, idx, array_name='PRNGKeyArray')
return PRNGKeyArray(self.impl, self._keys[idx])

def _fold_in(self, data: int) -> 'PRNGKeyArray':
Expand Down
32 changes: 32 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,38 @@ def test_explicit_unsafe_rbg_key(self):
key = random.unsafe_rbg_key(42)
self.assertIs(key.impl, prng.unsafe_rbg_prng_impl)

def test_key_array_indexing_0d(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
key = random.PRNGKey(1701)
self.assertEqual(key.shape, ())
self.assertEqual(key[None].shape, (1,))
self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
lambda: key[0])

def test_key_array_indexing_nd(self):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
keys = vmap(vmap(random.PRNGKey))(jnp.arange(6).reshape((2, 3)))
self.assertEqual(keys.shape, (2, 3))
self.assertEqual(keys[0, 0].shape, ())
self.assertEqual(keys[0, 1].shape, ())
self.assertEqual(keys[0].shape, (3,))
self.assertEqual(keys[1, :].shape, (3,))
self.assertEqual(keys[:, 1].shape, (2,))
self.assertEqual(keys[None].shape, (1, 2, 3))
self.assertEqual(keys[None, None].shape, (1, 1, 2, 3))
self.assertEqual(keys[None, :, None].shape, (1, 2, 1, 3))
self.assertEqual(keys[None, None, None, 0, None, None, None, 1].shape,
(1,) * 6)
self.assertEqual(keys[..., 1:, None].shape, (2, 2, 1))
self.assertEqual(keys[None, 0, ..., 1, None].shape, (1, 1))
self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
lambda: keys[0, 1, 2])
self.assertRaisesRegex(IndexError, 'Too many indices for PRNGKeyArray.*',
lambda: keys[0, 1, None, 2])


@jtu.with_config(jax_numpy_rank_promotion="raise")
class LaxRandomTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 4d9e9b4

Please sign in to comment.