Skip to content

Commit

Permalink
implement jnp.expand_dims and jnp.stack for PRNGKeyArrays
Browse files Browse the repository at this point in the history
Also:
* fix `jnp.concatenate` and `jnp.append` for PRNGKeyArrays
* add `ndim` property to PRNGKeyArrays
* minor fix to `lax.expand_dims` with duplicate dimensions
  • Loading branch information
froystig committed Feb 17, 2022
1 parent c49fb9c commit 0f7904f
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 10 deletions.
2 changes: 1 addition & 1 deletion jax/_src/lax/lax.py
Expand Up @@ -1188,7 +1188,7 @@ def squeeze(array: Array, dimensions: Sequence[int]) -> Array:

def expand_dims(array: Array, dimensions: Sequence[int]) -> Array:
"""Insert any number of size 1 dimensions into an array."""
ndim_out = np.ndim(array) + len(dimensions)
ndim_out = np.ndim(array) + len(set(dimensions))
dims_set = frozenset(canonicalize_axis(i, ndim_out) for i in dimensions)
result_shape = list(np.shape(array))
for i in sorted(dims_set):
Expand Down
9 changes: 6 additions & 3 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -1865,8 +1865,11 @@ def _squeeze(a, axis):

@_wraps(np.expand_dims)
def expand_dims(a, axis: Union[int, Sequence[int]]):
_check_arraylike("expand_dims", a)
return lax.expand_dims(a, _ensure_index_tuple(axis))
_stackable(a) or _check_arraylike("expand_dims", a)
axis = _ensure_index_tuple(axis)
if hasattr(a, "expand_dims"):
return a.expand_dims(axis)
return lax.expand_dims(a, axis)


@_wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC)
Expand Down Expand Up @@ -3393,7 +3396,7 @@ def stack(arrays, axis: int = 0, out=None):
axis = _canonicalize_axis(axis, arrays.ndim)
return concatenate(expand_dims(arrays, axis + 1), axis=axis)
else:
_check_arraylike("stack", *arrays)
_stackable(*arrays) or _check_arraylike("stack", *arrays)
shape0 = shape(arrays[0])
axis = _canonicalize_axis(axis, len(shape0) + 1)
new_arrays = []
Expand Down
18 changes: 14 additions & 4 deletions jax/_src/prng.py
Expand Up @@ -34,7 +34,7 @@
_canonicalize_tuple_index, _eliminate_deprecated_list_indexing,
_expand_bool_indices, _register_stackable)
import jax._src.pretty_printer as pp
from jax._src.util import prod
from jax._src.util import canonicalize_axis, prod


UINT_DTYPES = {
Expand Down Expand Up @@ -154,6 +154,10 @@ def _shape(self):
base_ndim = len(self.impl.key_shape)
return self._keys.shape[:-base_ndim]

@property
def ndim(self):
return len(self.shape)

def _is_scalar(self):
base_ndim = len(self.impl.key_shape)
return self._keys.ndim == base_ndim
Expand Down Expand Up @@ -191,14 +195,20 @@ def reshape(self, newshape, order=None):
return PRNGKeyArray(self.impl, reshaped_keys)

def concatenate(self, key_arrs, axis):
axis = axis % len(self.shape)
axis = canonicalize_axis(axis, self.ndim)
arrs = [self._keys, *[k._keys for k in key_arrs]]
return PRNGKeyArray(self.impl, jnp.stack(arrs, axis))
return PRNGKeyArray(self.impl, jnp.concatenate(arrs, axis))

def broadcast_to(self, shape):
new_shape = tuple(shape)+(self._keys.shape[-1],)
new_shape = (*shape, *self.impl.key_shape)
return PRNGKeyArray(self.impl, jnp.broadcast_to(self._keys, new_shape))

def expand_dims(self, dimensions: Sequence[int]):
# follows lax.expand_dims, not jnp.expand_dims, so dimensions is a sequence
ndim_out = self.ndim + len(set(dimensions))
dimensions = [canonicalize_axis(d, ndim_out) for d in dimensions]
return PRNGKeyArray(self.impl, lax.expand_dims(self._keys, dimensions))

def __repr__(self):
arr_shape = self._shape
pp_keys = pp.text('shape = ') + pp.text(str(arr_shape))
Expand Down
20 changes: 18 additions & 2 deletions tests/random_test.py
Expand Up @@ -1335,13 +1335,20 @@ def test_concatenate(self):
key = random.PRNGKey(123)
keys = random.split(key, 2)
keys = jnp.concatenate([keys, keys, keys], axis=0)
self.assertEqual(keys.shape, (3, 2))
self.assertEqual(keys.shape, (6,))

def test_broadcast_to(self):
key = random.PRNGKey(123)
keys = jnp.broadcast_to(key, (3,))
self.assertEqual(keys.shape, (3,))

def test_expand_dims(self):
key = random.PRNGKey(123)
keys = random.split(key, 6)
keys = jnp.reshape(keys, (2, 3))
keys = jnp.expand_dims(keys, 1)
self.assertEqual(keys.shape, (2, 1, 3))

def test_broadcast_arrays(self):
key = random.PRNGKey(123)
keys = jax.random.split(key, 3)
Expand All @@ -1351,7 +1358,9 @@ def test_broadcast_arrays(self):
def test_append(self):
key = random.PRNGKey(123)
keys = jnp.append(key, key)
self.assertEqual(keys.shape, (2, 1))
self.assertEqual(keys.shape, (2,))
keys = jnp.append(keys, keys)
self.assertEqual(keys.shape, (4,))

def test_ravel(self):
key = random.PRNGKey(123)
Expand All @@ -1360,6 +1369,13 @@ def test_ravel(self):
keys = jnp.ravel(keys)
self.assertEqual(keys.shape, (4,))

def test_stack(self):
key = random.PRNGKey(123)
keys = jax.random.split(key, 2)
keys = jnp.stack([keys, keys, keys], axis=0)
self.assertEqual(keys.shape, (3, 2))


def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
raise SkipTest('sampler only implemented for default RNG')

Expand Down

0 comments on commit 0f7904f

Please sign in to comment.