Add random.binomial and random.multinomial #13327

Open
carlosgmartin opened this issue Nov 19, 2022 · 6 comments
Labels: enhancement (New feature or request); P3 (no schedule) (We have no plan to work on this and, if it is unassigned, we would be happy to review a PR)

Comments

@carlosgmartin
Contributor

carlosgmartin commented Nov 19, 2022

Add JAX counterparts of numpy.random.binomial and numpy.random.multinomial to the jax.random package. See #480 (comment) for context. A current workaround is to use the JAX substrate of TensorFlow Probability:

from jax import random, numpy as jnp
from tensorflow_probability.substrates import jax as tfp

def binomial(key, n, p, shape=()):
    return tfp.distributions.Binomial(n, probs=p).sample(
        seed=key,
        sample_shape=shape,
    )

def multinomial(key, n, p, shape=()):
    return tfp.distributions.Multinomial(n, probs=p).sample(
        seed=key,
        sample_shape=shape,
    )

key = random.PRNGKey(0)

key, subkey = random.split(key)
print(binomial(subkey, 9, .8, [2, 5]))

key, subkey = random.split(key)
print(multinomial(subkey, 9, jnp.array([.7, .2, .1]), [4]))

Output:

[[7. 8. 8. 7. 5.]
 [5. 8. 9. 8. 5.]]
[[7. 1. 1.]
 [4. 3. 2.]
 [5. 3. 1.]
 [4. 4. 1.]]
@carlosgmartin carlosgmartin added the enhancement New feature or request label Nov 19, 2022
@zhangqiaorjc
Member

@sharadmv, is it possible to port tfp's implementation to jax.random?

@sharadmv
Member

Possibly, yes. The implementation is fairly complex though, and makes some accelerator-specific tradeoffs IIRC. cc: @srvasude @brianwa84.

@jakevdp jakevdp added the P3 (no schedule) We have no plan to work on this and, if it is unassigned, we would be happy to review a PR label Nov 28, 2022
@hylkedonker

hylkedonker commented Dec 6, 2023

Any plans to add random.multinomial now that random.binomial has been merged?
The multinomial sampler could be implemented either as a sequence of conditional binomial draws or by repeated categorical draws.
See also the Wikipedia entry on multinomial variate generation.
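
The sequential approach can be sketched with jax.random.binomial: each category's count is drawn from a binomial over the trials still unassigned, using that category's probability renormalized by the mass not yet consumed. This is only an illustrative sketch, not a proposed jax.random implementation. The helper name `multinomial_sequential` is hypothetical, and it assumes `p` is a 1-D probability vector (nonnegative, sums to 1) and returns a single sample.

```python
import jax
import jax.numpy as jnp


def multinomial_sequential(key, n, p):
    """Hypothetical helper: one multinomial sample via conditional binomials.

    Assumes p is a 1-D probability vector (nonnegative, sums to 1).
    """
    k = p.shape[0]
    keys = jax.random.split(key, k - 1)
    remaining_n = jnp.asarray(n, dtype=jnp.float32)    # trials not yet assigned
    remaining_p = jnp.asarray(1.0, dtype=jnp.float32)  # probability mass not yet used
    counts = []
    for i in range(k - 1):
        # Probability of category i given that categories 0..i-1 were not chosen.
        cond_p = jnp.where(remaining_p > 0, p[i] / remaining_p, 0.0)
        cond_p = jnp.clip(cond_p, 0.0, 1.0)  # guard against rounding error
        c = jax.random.binomial(keys[i], remaining_n, cond_p)
        counts.append(c)
        remaining_n = remaining_n - c
        remaining_p = remaining_p - p[i]
    counts.append(remaining_n)  # the last category takes whatever is left
    return jnp.stack(counts)


p = jnp.array([0.7, 0.2, 0.1])
sample = multinomial_sequential(jax.random.PRNGKey(0), 9, p)
print(sample)  # three counts that sum to 9
```

This needs only k - 1 binomial draws regardless of n, which is why it tends to be preferred when n is large. Batches of samples could be obtained by applying jax.vmap over split keys.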

@jakevdp
Collaborator

jakevdp commented Dec 6, 2023

Thanks for reaching out – I don't know of anyone working on this currently.

@andportnoy
Contributor

andportnoy commented Jan 28, 2024

A workaround if you want pure JAX is to take the log of your probability vector (nonnegative, sums to 1) and pass it as the logits:

jax.random.categorical(key, jnp.log(p))

(Based on the "contract" of jax.random.categorical:

logits: Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.)

That can't be the most efficient way to sample though...
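
To turn this workaround into multinomial counts, one option is to draw n categorical samples and tally them per category. The sketch below assumes `n` is a static Python int and `p` is a 1-D probability vector; the helper name `multinomial_from_categorical` is hypothetical, not part of jax.random.

```python
import jax
import jax.numpy as jnp


def multinomial_from_categorical(key, n, p, shape=()):
    """Hypothetical helper: multinomial counts from n categorical draws.

    Assumes n is a static Python int and p is a 1-D probability vector.
    """
    logits = jnp.log(p)  # softmax(log p) recovers p when p sums to 1
    # One categorical draw per trial: result has shape (*shape, n).
    draws = jax.random.categorical(key, logits, shape=(*shape, n))
    # One-hot encode each draw, then sum over the trial axis to get counts.
    one_hot = jax.nn.one_hot(draws, p.shape[-1], dtype=jnp.int32)
    return one_hot.sum(axis=-2)


p = jnp.array([0.7, 0.2, 0.1])
counts = multinomial_from_categorical(jax.random.PRNGKey(0), 9, p, shape=(4,))
print(counts)  # shape (4, 3); each row sums to 9
```

The cost in time and memory grows with n times the number of categories per sample, which is consistent with the concern that this is not the most efficient way to sample.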

@brianwa84
Contributor

brianwa84 commented Jan 28, 2024 via email
