Pseudorandom Numbers in JAX
https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html 

In [None]:
import jax
import jax.numpy as jnp

In [None]:
import numpy as np

# Seed NumPy's global (legacy) generator, then peek at the first ten words of
# its MT19937 key array; as the cell's last expression, the slice is displayed.
np.random.seed(seed=0)
np.random.get_state()[1][0:10]

In [None]:
# NumPy's legacy API keeps one hidden, global generator whose state is updated
# after each call. get_state() returns
# ('MT19937', keys, pos, has_gauss, cached_gaussian).
# A draw typically only advances `pos`; the 624-word `keys` array is refilled
# only once it has been used up. Printing a prefix of `keys` alone therefore
# shows identical values after consecutive draws, so `pos` is printed too.
def show_mt_state():
    """Print the first three MT19937 key words and the current position."""
    state = np.random.get_state()
    print(f'{state[1][:3]} pos={state[2]}')


np.random.seed(0)
_ = np.random.uniform()
show_mt_state()
np.random.seed(0)
_a = np.random.uniform()
show_mt_state()  # matches the previous line: reseeding rewinds the state
_b = np.random.uniform()
show_mt_state()  # same key prefix, but `pos` has moved on


In [None]:
# Unlike NumPy, JAX keeps no global generator state. Random functions
# instead receive their state explicitly, as an argument called a *key*.
key = jax.random.PRNGKey(seed=0)
print(key)

In [None]:
# Sampling repeatedly with one unchanged key yields the same value each time.
for _ in range(3):
    print(f'{jax.random.normal(key)}')

In [None]:
# Reusing a key repeats its output, so split keys into fresh subkeys to get
# new randomness for every draw. Ideally each key feeds exactly one consumer
# (a sampler or a split).
# NOTE: for brevity, each pattern below samples from the root key and then
# also splits that same key, which technically reuses it; in real code, split
# first and sample only from the resulting subkeys.

# Pattern 1: keep a carry key, sample from a freshly split subkey.
print('\n---\n')
key = jax.random.PRNGKey(0)
print(f'{jax.random.normal(key)}')
key, subkey = jax.random.split(key)
print(f'{jax.random.normal(subkey)}')
del subkey  # consumed; the new `key` stays available to continue the chain
# ...

# OR
# Pattern 2: keep the second half of each split as the next key.
print('\n---\n')
key = jax.random.PRNGKey(0)
print(f'{jax.random.normal(key)}')
_, key = jax.random.split(key)
print(f'{jax.random.normal(key)}')
_, key = jax.random.split(key)
print(f'{jax.random.normal(key)}')

# OR
# Pattern 3: split several subkeys at once (the first one is discarded).
print('\n---\n')
key = jax.random.PRNGKey(0)
print(f'{jax.random.normal(key)}')
_, *keys = jax.random.split(key, num=3)
print(f'{jax.random.normal(keys[0])}')
print(f'{jax.random.normal(keys[1])}')