In [1]:
import numpy as np

np.random.seed(0)


def print_truncated_random_state():
    """To avoid spamming the outputs, print only part of the state."""
    full_random_state = np.random.get_state()
    print(str(full_random_state)[:460], "...")


print_truncated_random_state()

('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...


In [2]:
np.random.seed(0)
print_truncated_random_state()

('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...


In [3]:
_ = np.random.uniform()
print_truncated_random_state()

('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...


In [4]:
np.random.seed(0)
print(np.random.uniform(size=3))

[0.5488135  0.71518937 0.60276338]


In [5]:
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))

individually: [0.5488135  0.71518937 0.60276338]
all at once:  [0.5488135  0.71518937 0.60276338]


## Random numbers in JAX

JAX’s random number generation differs from NumPy’s in important ways, because NumPy’s PRNG design makes it hard to guarantee several desirable properties at the same time. Specifically, in JAX we want random number generation to be:

1. reproducible,

2. parallelizable,

3. vectorizable.

The rest of this section explains why. First, we focus on the implications of a PRNG design based on a global state. Consider the code:

In [6]:
import numpy as np

np.random.seed(0)


def bar():
    return np.random.uniform()


def baz():
    return np.random.uniform()


def foo():
    return bar() + 2 * baz()


print(foo())

1.9791922366721637


The output of this code can only satisfy requirement #1 if we assume a predictable order of execution for `bar()` and `baz()`. This is not a problem in NumPy, which always evaluates code in the order defined by the Python interpreter. In JAX, however, this is more problematic: for efficient execution, we want the JIT compiler to be free to reorder, elide, and fuse various operations in the function we define. Further, when executing in multi-device environments, execution efficiency would be hampered by the need for each process to synchronize a global state.
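To make the order dependence concrete, here is a minimal NumPy-only sketch. It reuses the `bar()` and `baz()` definitions from the cell above and evaluates them in both orders from the same seed; the variable names `bar_first` and `baz_first` are introduced here purely for illustration.

```python
import numpy as np


def bar():
    return np.random.uniform()


def baz():
    return np.random.uniform()


# Order 1: bar() draws from the global state first, then baz().
np.random.seed(0)
b = bar()
z = baz()
bar_first = b + 2 * z

# Order 2: same seed, but baz() draws first.
np.random.seed(0)
z = baz()
b = bar()
baz_first = b + 2 * z

print("bar() first:", bar_first)
print("baz() first:", baz_first)
print("same result:", bar_first == baz_first)
```

With a shared global state, the two functions are only distinguishable by the order in which they happen to run, so any reordering changes the value of the expression.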

To sidestep this issue, JAX has no implicit global random state and instead tracks state explicitly via a random `key`:

In [7]:
from jax import random

key = random.key(42)
print(key)

Array((), dtype=key<fry>) overlaying:
[ 0 42]


In [8]:
print(random.normal(key))
print(random.normal(key))

-0.18471177
-0.18471177
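The two identical outputs above follow from the key fully determining the result: random functions read a key but never modify it. The sketch below checks this directly and also inspects the raw integers behind the key with `random.key_data`. It assumes the default key implementation, the one whose data printed as `[ 0 42]` above.

```python
from jax import random

key = random.key(42)

# Passing the same key twice reproduces the same draw.
first = random.normal(key)
second = random.normal(key)
same_draw = bool(first == second)
print("same draw:", same_draw)

# key_data() exposes the raw integers that the typed key wraps.
raw = random.key_data(key)
print("raw key data:", raw)
```

Because a reused key yields the same sample, every independent draw needs its own key.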


In [9]:
for i in range(3):
    new_key, subkey = random.split(key)
    del key  # The old key is consumed by split() -- we must never use it again.

    val = random.normal(subkey)
    del subkey  # The subkey is consumed by normal().

    print(f"draw {i}: {val}")
    key = new_key  # new_key is safe to use in the next iteration.

draw 0: 1.369469404220581
draw 1: -0.19947023689746857
draw 2: -2.298278331756592


In [10]:
# Reassigning `key` in the same statement discards the old key automatically.
key, subkey = random.split(key)

# split() can create as many keys as you need, not just 2:
key, *forty_two_subkeys = random.split(key, num=43)
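With explicit keys, the `foo()` example from earlier can be written so that its result no longer depends on execution order. The following is a sketch, not a prescribed pattern: the key parameters and the names `bar_key` and `baz_key` are introduced here for illustration.

```python
import jax
from jax import random


def bar(key):
    return random.uniform(key)


def baz(key):
    return random.uniform(key)


def foo(key):
    # Each callee gets its own subkey, so the result is the same
    # regardless of which of the two calls is evaluated first.
    bar_key, baz_key = random.split(key)
    return bar(bar_key) + 2 * baz(baz_key)


key = random.key(0)
eager = foo(key)
jitted = jax.jit(foo)(key)  # The compiler is free to reorder bar() and baz().

print("eager: ", eager)
print("jitted:", jitted)
```

Since each function's randomness is determined only by the key it receives, the eager and JIT-compiled versions should agree.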

## Lack of sequential equivalence
Another difference between NumPy’s and JAX’s random modules relates to the sequential equivalence guarantee illustrated by the NumPy example above, where drawing values one at a time or all at once yields the same sequence.

Like NumPy, JAX’s random module allows sampling vectors of numbers. However, JAX does not provide a sequential equivalence guarantee, because doing so would interfere with vectorization on SIMD hardware (requirement #3 above).

In the example below, sampling 3 values from a normal distribution individually using three subkeys gives a different result from passing a single key and specifying `shape=(3,)`:

In [11]:
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))

individually: [-0.04838832  0.10796154 -1.2226542 ]
all at once:  [ 0.18693547 -1.2806505  -1.5593132 ]


In [12]:
import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))

vectorized: [-0.04838832  0.10796154 -1.2226542 ]
