Skip to content

Commit

Permalink
Broadcast arrays manually in categorical sampling.
Browse files Browse the repository at this point in the history
  • Loading branch information
amol-mandhane authored and Amol Mandhane committed Jul 7, 2021
1 parent 56087dc commit f945982
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
5 changes: 4 additions & 1 deletion jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,7 +1306,10 @@ def categorical(key: jnp.ndarray,
_check_shape("categorical", shape, batch_shape)

sample_shape = shape[:len(shape)-len(batch_shape)]
return jnp.argmax(gumbel(key, sample_shape + logits.shape, logits.dtype) + logits, axis=axis)
return jnp.argmax(
gumbel(key, sample_shape + logits.shape, logits.dtype) +
lax.expand_dims(logits, tuple(range(len(sample_shape)))),
axis=axis)


def laplace(key: jnp.ndarray,
Expand Down
1 change: 1 addition & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def testBernoulli(self, p, dtype):
]
for sample_shape in [(10000,), (5000, 2)]
for dtype in jtu.dtypes.floating))
@jtu.disable_implicit_rank_promotion
def testCategorical(self, p, axis, dtype, sample_shape):
key = random.PRNGKey(0)
p = np.array(p, dtype=dtype)
Expand Down

0 comments on commit f945982

Please sign in to comment.