In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability as tfp

In [2]:
# Benchmark configuration shared by every cell below:
#   key -- fixed PRNG key so all samplers see the same randomness source
#   k   -- number of categories
#   p   -- uniform probability vector over the k categories
#   n   -- number of draws per sampler call
key = jax.random.PRNGKey(0)
k = 5
p = np.full(k, 1.0 / k)
n = 10_000



In [3]:
def sample_jax_random(key, p, n):
    """Draw ``n`` category indices with probabilities ``p`` using ``jax.random``.

    ``jax.random.categorical`` is parameterised by log-probabilities (logits),
    so the probability vector is mapped through ``log`` first.

    Args:
        key: JAX PRNG key.
        p: 1-D array of category probabilities.
        n: number of samples; the result has shape ``(n,)``.

    Returns:
        Integer array of shape ``(n,)`` holding sampled category indices.
    """
    logits = jnp.log(p)
    return jax.random.categorical(key, logits=logits, shape=(n,))


print(sample_jax_random(key, p, n))

%timeit sample_jax_random(key, p, n)

[2 1 2 ... 0 0 3]
12 ms ± 556 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [4]:
sample_jax_random_jit = jax.jit(sample_jax_random, static_argnums=(2,))

print(sample_jax_random_jit(key, p, n))

%timeit sample_jax_random_jit(key, p, n)

[2 1 2 ... 0 0 3]
3.07 ms ± 597 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
def sample_tfp_jax(key, p, n):
    """Draw ``n`` category indices with probabilities ``p`` via TFP's JAX substrate.

    Args:
        key: JAX PRNG key, passed to TFP as the sampling seed.
        p: 1-D array of category probabilities.
        n: number of samples; the result has shape ``(n,)``.

    Returns:
        Array of shape ``(n,)`` holding sampled category indices.
    """
    # The JAX substrate is exposed as `tfp.substrates.jax`; the older
    # `tfp.experimental.substrates` path is deprecated and may be missing in
    # newer TFP releases. Fall back to it only for TFP versions that predate
    # the move.
    try:
        tfd = tfp.substrates.jax.distributions
    except AttributeError:
        tfd = tfp.experimental.substrates.jax.distributions
    cat_dist = tfd.Categorical(probs=p)
    return cat_dist.sample(sample_shape=(n,), seed=key)


print(sample_tfp_jax(key, p, n))

%timeit sample_tfp_jax(key, p, n)

[2 2 0 ... 1 0 3]
76.2 ms ± 8.72 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
sample_tfp_jax_jit = jax.jit(sample_tfp_jax, static_argnums=(2,))

print(sample_tfp_jax_jit(key, p, n))

%timeit sample_tfp_jax_jit(key, p, n)

[2 2 0 ... 1 0 3]
3.91 ms ± 929 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
