# Fully Controllable Random Number Generation

In this section, we'll explore how to write programs whose random number generation is fully deterministic, conditioned on a single starting random number generation key.

But first, let's explore what happens when we use NumPy's vanilla random number generation protocol to generate numbers.

```python
import numpy as onp  # original numpy
```

Let's draw a random number from a Gaussian in NumPy.

```python
onp.random.seed(42)  # set NumPy's global random state
a = onp.random.normal()
a
```

And for good measure, let's draw another one.

```python
b = onp.random.normal()  # the global state has advanced, so b differs from a
b
```

This is intuitive behaviour, because we expect that each time we call on a random number generator, we should get back a different number from before.

However, this behaviour is a problem when we are debugging, because debugging relies on being able to reproduce a program's behaviour exactly. A program might _stochastically_ hit an error that we are then unable to reproduce. NumPy's random number generator relies on global state, so the value of each draw depends on every call made before it, and the generator doesn't behave in a _fully_ controllable fashion.
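To make the global-state problem concrete, here is a small sketch that uses only NumPy's legacy global API. The helper `draw_two` is a hypothetical name chosen for illustration. Re-seeding reproduces the draws, but a single extra call to the generator, such as one added while debugging, shifts every draw that follows.

```python
import numpy as onp


def draw_two(extra_call=False):
    """Hypothetical helper: seed NumPy's global state, then make two draws."""
    onp.random.seed(42)
    if extra_call:
        # An extra draw, e.g. one added while debugging, advances the global state.
        onp.random.normal()
    return onp.random.normal(), onp.random.normal()


run1 = draw_two()
run2 = draw_two()
run3 = draw_two(extra_call=True)

# Re-seeding reproduces the draws...
assert run1 == run2
# ...but the extra call shifts the sequence: run3 starts with run1's second draw.
assert run3[0] == run1[1]
```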

How, then, can we get "the best of both worlds": random number generation that still gives us varied draws, yet is fully controllable?

JAX's developers achieved this with pseudo-random number generators that require a pseudo-random number generation (PRNG) key to be passed in explicitly, rather than relying on global state. The same key always deterministically produces the same draw. Let's see that in action:

```python
from jax import random

key = random.PRNGKey(42)

a = random.normal(key=key)
a
```

To show you that passing in the same key gives us the same values as before:

```python
b = random.normal(key=key)  # same key, so b == a
b
```

That should already be a stark difference from what you're used to with vanilla NumPy, and it is the crucial difference between JAX's `random` module and NumPy's. The two modules otherwise offer similar sampling functions, but this difference exists for good reason: it gives us explicit, rather than merely implicit, reproducibility over our stochastic programs, even within the same session.
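Here is a quick sketch of the same idea, checked with assertions rather than by eye. It assumes JAX's default configuration, under which `random.PRNGKey` returns a raw key array of two `uint32` values. The `shape` argument of `random.normal` appears only to show that array-valued draws are just as reproducible.

```python
from jax import random

# Under JAX's default configuration, this is a raw key: an array of two uint32 values.
key = random.PRNGKey(42)

a = random.normal(key)
b = random.normal(key)  # same key, same draw

vec1 = random.normal(key, shape=(3,))
vec2 = random.normal(key, shape=(3,))  # array-valued draws are reproducible too
```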

How do we get a new draw? Well, we can either create a new key manually, or we can programmatically split the key into two, and use one of the newly split keys to generate a new random number. Let's see that in action:

```python
k1, k2 = random.split(key)
c = random.normal(key=k2)  # a child key gives a fresh draw
c
```

```python
k3, k4, k5 = random.split(k2, num=3)
d = random.normal(key=k3)
d
```

By splitting a key into two, three, or even 1000 parts, we derive new child keys from a parent key, and each child key generates different random numbers from the same random number generating function.
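A common idiom, used throughout the JAX documentation, is to split off a fresh subkey each time a draw is needed and keep the other half for future splits. That way no key is ever consumed twice. A minimal sketch:

```python
from jax import random

key = random.PRNGKey(42)

# Keep `key` for future splits; consume `subkey` for the draw.
key, subkey = random.split(key)
first = random.normal(subkey)

key, subkey = random.split(key)
second = random.normal(subkey)

# Splitting many ways at once gives one raw key per row.
many_keys = random.split(key, num=1000)
```

The earlier cells reuse `key` after drawing from it. That is fine for a demonstration, but in real programs it is safer to treat any key used for a draw as spent.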

Let's explore how we can use this in the generation of a Gaussian random walk.

## Example: Simulating a Gaussian random walk

A Gaussian random walk is one where each new point is drawn from a Gaussian centred on the previous point.

Does that loop structure sound familiar? Well... yeah, it sounds like a classic `lax.scan` setup!
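To see why `lax.scan` fits, it helps to look at a plain-Python loop with the same semantics. A step function takes a carry and one element of the input, and returns a new carry plus a value to record. The sketch below checks such a loop against `lax.scan` on a running total. The names `scan_reference` and `running_total` are hypothetical.

```python
import jax.numpy as jnp
from jax import lax


def scan_reference(f, init, xs):
    """Plain-Python loop with the same semantics as lax.scan (for illustration only)."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)


def running_total(total, x):
    # (carry for the next step, value to record for this step)
    return total + x, total


xs = jnp.array([1.0, 2.0, 3.0])
final_ref, hist_ref = scan_reference(running_total, jnp.float32(0.0), xs)
final_scan, hist_scan = lax.scan(running_total, jnp.float32(0.0), xs)
# Both give a final carry of 6.0 and a recorded history of [0.0, 1.0, 3.0].
```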

Here's how we might set it up.

Firstly, JAX's `random.normal` function doesn't let us specify the location and scale. It only gives us draws from a unit Gaussian. We can work around this, because any unit Gaussian draw can be shifted and scaled into a draw from $N(\mu, \sigma^2)$ by multiplying it by $\sigma$ and adding $\mu$.
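Here is a sketch of the shift-and-scale trick using a helper named `normal_with_loc_scale`, which is a hypothetical name. With enough draws, the sample mean and standard deviation should land close to $\mu$ and $\sigma$.

```python
from jax import random


def normal_with_loc_scale(key, mu, sigma, shape=()):
    """Hypothetical helper: draw from N(mu, sigma^2) by shifting and scaling unit Gaussian draws."""
    return mu + sigma * random.normal(key, shape=shape)


samples = normal_with_loc_scale(random.PRNGKey(0), mu=3.0, sigma=2.0, shape=(100_000,))
```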

To generate a 1000-step random walk, we split the key 1000 ways and use `lax.scan` to scan a Gaussian-drawing function across the keys, which gives us 1000 distinct draws. At each step, we add the new draw to the previous position of the walk.

The step function returns the tuple `(new_gaussian, old_gaussian)`. `lax.scan` passes the first element (the "carry") into the next iteration and stacks the second element into the accumulated history of the walk.
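The source of `generate_new_gaussian` isn't reproduced in this section. The following is only a minimal sketch consistent with the description above, not the `dl_workshop` implementation. `generate_new_gaussian_sketch` is a hypothetical name.

```python
import jax.numpy as jnp
from jax import lax, random


def generate_new_gaussian_sketch(old_gaussian, key):
    """Hypothetical step function in the (carry, x) -> (carry, y) form lax.scan expects."""
    new_gaussian = old_gaussian + random.normal(key)
    # First element: carry for the next step. Second element: stacked into the history.
    return new_gaussian, old_gaussian


keys = random.split(random.PRNGKey(42), num=1000)
final, history = lax.scan(generate_new_gaussian_sketch, jnp.float32(0.0), keys)
```

The carry is recorded *before* it is updated. As a result, the first entry of `history` is the starting value `0.`, and `final` is one step beyond the last recorded position.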

```python
from dl_workshop.jax_idioms import generate_new_gaussian

# IPython syntax: `??` displays the function's source code
generate_new_gaussian??
```

```python
from jax import lax

keys = random.split(key, num=1000)  # one key per step
final, result = lax.scan(generate_new_gaussian, 0., keys)  # the walk starts at 0.
result
```

```python
import matplotlib.pyplot as plt

plt.plot(result)
```

Looks like we did it! That definitely looks like a proper Gaussian random walk to me. Let's encapsulate this inside a function generator (a function that returns the random walk function), because next we're going to generate multiple realizations of the Gaussian random walk.
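This is again a hypothetical sketch rather than the `dl_workshop` code. A function generator might close over the number of steps and return a function of a single key:

```python
import jax.numpy as jnp
from jax import lax, random


def _walk_step(position, key):
    """One random-walk step: (carry, x) -> (new carry, value to record)."""
    return position + random.normal(key), position


def make_gaussian_random_walk_func_sketch(num_steps):
    """Hypothetical sketch: return a function mapping one PRNG key to a num_steps-long walk."""

    def gaussian_random_walk(key):
        keys = random.split(key, num_steps)  # one subkey per step
        final, history = lax.scan(_walk_step, jnp.float32(0.0), keys)
        return final, history

    return gaussian_random_walk


grw_100_steps = make_gaussian_random_walk_func_sketch(100)
final, history = grw_100_steps(random.PRNGKey(0))
```

Closing over `num_steps` keeps it a plain Python integer, which `random.split` needs as a fixed count. It also leaves the returned function with a single key argument that `vmap` can map over.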

```python
from dl_workshop.jax_idioms import make_gaussian_random_walk_func

# IPython syntax: `??` displays the function's source code
make_gaussian_random_walk_func??
```

Now, what if we wanted to generate multiple realizations of the Gaussian random walk? Each realization is independent of the others, so this is just a vanilla for-loop over keys, which brings us directly to `vmap`!
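Here is a small sketch of that equivalence. It uses a `walk` function with a fixed number of steps; both names are hypothetical. Stacking the results of a Python for-loop over keys gives the same array as `vmap`, which vectorizes the loop over the leading axis of `keys`.

```python
import jax.numpy as jnp
from jax import lax, random, vmap

NUM_STEPS = 50  # hypothetical walk length for this sketch


def walk(key):
    """Hypothetical single-realization random walk driven entirely by `key`."""

    def step(position, k):
        return position + random.normal(k), position

    _, history = lax.scan(step, jnp.float32(0.0), random.split(key, NUM_STEPS))
    return history


keys = random.split(random.PRNGKey(0), 8)

# The "vanilla for-loop": one realization per key, stacked along a new leading axis.
looped = jnp.stack([walk(k) for k in keys])

# The same computation, vectorized by vmap over the leading axis of `keys`.
vmapped = vmap(walk)(keys)
```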

```python
from jax import vmap

num_realizations = 200
keys = random.split(key, num_realizations)
grw_1000_steps = make_gaussian_random_walk_func(1000)
final, trajectories = vmap(grw_1000_steps)(keys)  # one walk per key
```

```python
trajectories.shape
```

We did it! We have 200 trajectories of a 1000-step Gaussian random walk. Notice also how nicely the program is structured: each layer of abstraction corresponds to a new array axis. The inner `lax.scan` works along the time axis (1000 steps), and the outer `vmap` adds the realization axis (200 walks), so `trajectories` has shape `(200, 1000)`. This onion-like layering gives the program a very _natural_ structure for the problem at hand.

Enough proselytizing from me. Let's visualize the trajectories to make sure they genuinely are Gaussian random walks (GRWs).

```python
import seaborn as sns

fig, ax = plt.subplots()

for trajectory in trajectories[0:20]:  # plot the first 20 walks
    ax.plot(trajectory)
sns.despine()
```

Now, note that if you re-run the entire program from top to bottom, you will get _exactly the same plot_. This is what we mean by "reproducible".

NumPy's global-state approach is only fragilely reproducible. A seeded program reproduces its draws only if every call to the generator happens in exactly the same order, so adding a single extra draw, or running notebook cells out of order, changes every draw that follows. With JAX's random number generation paradigm, each draw is determined entirely by the key passed to it. A program's random draws are therefore reproducible down to the exact sequence, as long as the seed(s) controlling the program are identical (and, in practice, the software environment is the same). When an error shows up in a program whose stochastic components are controlled by hand-set seeds, that error can be reproduced on demand.

For those who have worked with stochastic programs before, this is an extremely desirable property. It means we can reliably debug our programs, which is crucial when working with probabilistic models.

Also notice that we finally wrote our first productive for-loop, but only to plot something, not to perform any calculations :).
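The following sketch illustrates the reproducibility argument and mirrors the NumPy behaviour shown at the start of this section. The hypothetical `program` function builds a walk with `jnp.cumsum` over unit Gaussian draws, an alternative vectorized way to build a Gaussian random walk, and it gives the walk its own key. Each draw depends only on the key passed to it, so an extra draw elsewhere in the program no longer disturbs the walk.

```python
import jax.numpy as jnp
from jax import random


def program(seed, extra_call=False):
    """Hypothetical stochastic program whose randomness flows only through explicit keys."""
    key = random.PRNGKey(seed)
    key_debug, key_walk = random.split(key)
    if extra_call:
        # An extra draw elsewhere in the program, e.g. added while debugging.
        random.normal(key_debug)
    # A Gaussian random walk as the cumulative sum of unit Gaussian steps.
    return jnp.cumsum(random.normal(key_walk, shape=(1000,)))
```

Calling `program(42)` twice gives bit-identical walks, and `program(42, extra_call=True)` gives the same walk again. A different seed, such as `program(43)`, gives a different one.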