In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

## Introduction

In this notebook, we'll introduce you to some idiomatic JAX tools that will help you write performant numerical array programs. The main takeaways you should get from this notebook are how you can:

1. Replace slow Python for loop constructs with fast, just-in-time compiled JAX loop constructs,
2. Create deterministic random numbers (and yes, this is not an oxymoron!) for reproducibility,
3. Freely mix and match these tools through the idea of _composable transforms_.

Contrasts with what we might be used to doing are the most effective way to teach and learn. So in each section, we'll be explicit about exactly what we're replacing when we write these numerical array programs. In doing so, my hope is that you'll see very clearly how structuring your array programs in a composable and atomic fashion lets you take advantage of JAX's composable function transforms to write fast, compiled functions. And for good measure, we'll contrast this against pure Python programs, so you can witness for yourself how powerful JAX's ideas are... and appreciate how much effort has gone into making the whole thing NumPy compatible!

## Prerequisites

To get the most out of this notebook, you need only be familiar with the NumPy API and with writing functions. An appreciation of `functools.partial` will help a bit, because we use it a lot in writing JAX programs. However, I know that not everybody has prior experience with `partial`-ed functions, so we will introduce the idea mid-way, in a _just-in-time_ fashion as well.

If you've gone through `tutorial.ipynb`, the main tutorial notebook for this repository, then you'll have some appreciation of JAX's composable transforms. You'll also have seen how we wrote some loops in there, and you will hopefully appreciate how much faster things run when we use JAX's looping constructs instead.



## Replacing simple for-loops with `vmap`

The first JAX thing we will look at is the `vmap` function. What does `vmap` do? From the [JAX docs on `vmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.vmap):

> Vectorizing map. Creates a function which maps fun over argument axes.

What does that mean? Well, let's take a look at a few classic examples.

### Mapping an elementwise function over an array's leading axis

The first example is mapping a function over an array axis. The simplest example, which is a bit trivial, is doing elementwise application of a function. Say we have uniformly spaced numbers from 0 to 1 in an array:

In [None]:
import jax.numpy as np
from jax import vmap
from time import time

arr = np.linspace(0, 1, 10000)
arr

If we wanted to apply an exponential transform on every element, the "dumb", pure Python way to do so is to write a for-loop:

In [None]:
start = time()
new_arr = []
for element in arr:
    new_arr.append(np.exp(element))
new_arr = np.array(new_arr)
end = time()
print(f"{end - start:.2f} seconds")
new_arr

Because `np.exp` is an elementwise function (a `ufunc` in NumPy terms), we can call `np.exp` on `arr` directly:

In [None]:
start = time()
new_arr = np.exp(arr)
end = time()
print(f"{end - start:.4f} seconds")
new_arr

As you can see, this is much faster.

This, incidentally, is equivalent to using `vmap` to map the function across all elements in the array:

In [None]:
start = time()
new_arr = vmap(np.exp)(arr)
end = time()
print(f"{end - start:.4f} seconds")
new_arr

It's a bit slower here. What we gain from using `vmap`, though, is the ability to write the mapped function as if the leading (first) axis of its inputs did not exist. To see that, we're going to look at another example.

### Mapping a row-wise function across an array's leading axis

In this example, let's say we have a matrix of values that we measured in an experiment. We took `n_samples` measurements of `3` unique properties, giving us a matrix of shape `(n_samples, 3)`. If we needed the sum of each row, we could do the following in pure NumPy:

In [None]:
def row_sum(data):
    """Given one dataset, calculate row-wise sum of data."""
    return np.sum(data, axis=1)

data = np.array([
    [1, 3, 1,],
    [3, 5, 1,],
    [1, 2, 5,],
    [7, 1, 3,],
    [11, 2, 3,],
])

start = time()
result = row_sum(data)
end = time()
print(f"{end - start:.4f} seconds")
result

This gives us the correct answer... but we had to worry about the `axis` argument, which is a bit irritating. Instead, we could first transform `np.sum` into a `vmap`-ed function that is mapped across the leading axis of `data`:

In [None]:
def row_sum_one_data(data):
    """Given one dataset, calculate row-wise sum of data."""
    return vmap(np.sum)(data)

start = time()
result = row_sum_one_data(data)
end = time()
print(f"{end - start:.4f} seconds")
result

This gives us exactly the same result. The syntax does take some time to get used to, but it expresses more explicitly the idea that _`np.sum` is applied to each row, and the leading axis is simply mapped over_.

Now, let's say we had multiple datasets for which we wanted to calculate the row-wise sum. How would we do this in pure NumPy?

Well, let's first create this dataset.

In [None]:
data2 = np.array([
    [1, 3, 7,],
    [3, 5, 11,],
    [3, 2, 5,],
    [7, 5, 3,],
    [11, 5, 3,],
])

combined_data = np.stack([data, data2])  # stack along a new leading axis -> shape (2, 5, 3)
combined_data.shape

Our shapes tell us that we have 2 stacks of data, each with 5 rows and 3 columns.

Since we want row-wise summation but also want to preserve the 2 stacks of data, we now have to think about which axis to collapse:

In [None]:
np.sum(combined_data, axis=2)

This works, but we now have a "magic number" (`axis=2`) in our program. We can eliminate it by instead `vmap`-ing `row_sum_one_data` across the `combined_data` array:

In [None]:
def row_sum_all_data(data):
    return vmap(row_sum_one_data)(data)
    
row_sum_all_data(combined_data)

And voilà, just like that, the magic number is gone from our program, and the hierarchical structure of our functions is a bit more explicit:

- The elementary function, `np.sum`, operates on a single row.
- We map the elementary function across all rows of a single dataset, giving us a higher-level function that calculates row-wise sums for a single dataset, `row_sum_one_data`.
- We then map `row_sum_one_data` across all of the datasets that have been stacked together in a single 3D array.
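The layering above can also be written in one line by nesting `vmap` calls. Here is a minimal sketch on a made-up `(2, 5, 3)` array (not the notebook's data), checked against the `axis=2` version:

```python
import jax.numpy as np
from jax import vmap

# Two made-up 5x3 datasets stacked into shape (2, 5, 3)
combined = np.arange(30).reshape(2, 5, 3)

# Inner vmap: sum each row of one dataset  -> shape (5,)
# Outer vmap: repeat that for every dataset -> shape (2, 5)
nested = vmap(vmap(np.sum))(combined)

# The "magic number" version gives the same answer
with_axis = np.sum(combined, axis=2)
print(nested.shape)  # (2, 5)
```

Each `vmap` peels off exactly one leading axis, so the nesting depth equals the number of axes being mapped over.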

### Mapping a function over two arrays simultaneously

Let's look at another example. Say we are given two arrays and want to multiply them elementwise. For example:

In [None]:
a1 = np.array([1, 2, 3, 4,])
a2 = np.array([2, 3, 4, 5,])

As the NumPy-idiomatic option, we could do:

In [None]:
a1 * a2

Another option is to define a function called `multiply`, which multiplies two scalars together and returns another scalar, and then apply it to each pair of elements in a `zip` of the two arrays. This is the _extremely_ naive way of handling the problem:

In [None]:
result = []

def multiply(a, b):
    return a * b

for val1, val2 in zip(a1, a2):
    result.append(multiply(val1, val2))
np.array(result)

On the other hand, if we treat `multiply` as the elementary operation, we can `vmap` it over both arrays at once and multiply them pairwise:

In [None]:
vmap(multiply)(a1, a2)

As before, we don't have to think about the leading array axis of either array. Once again, we broke the problem down into its elementary components and then leveraged `vmap` to build _out_ the program to do what we wanted. (This general pattern will keep showing up!)

In general, `vmap`-ing over the _leading_ array axis is the idiomatic thing to do with JAX. It's possible to `vmap` over other axes, but that is not the default. The implication is that we are nudged towards writing programs that begin, at their core, with an "elementary" function that operates "elementwise", where an "element" is not necessarily an array element but whatever unit makes sense for the problem. We then progressively `vmap` these functions outwards over our array data structures.
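`vmap`'s `in_axes` argument is how you opt out of the leading-axis default. In this short sketch, `in_axes=(0, None)` maps over the first argument and holds the second fixed, and `in_axes=1` maps over columns instead of rows:

```python
import jax.numpy as np
from jax import vmap

def multiply(a, b):
    return a * b

a1 = np.array([1, 2, 3, 4])

# Map over a1's leading axis; None means "don't map", so 10 is reused for every element
scaled = vmap(multiply, in_axes=(0, None))(a1, 10)

# Map over axis 1 (columns) instead of the default axis 0 (rows)
m = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
col_sums = vmap(np.sum, in_axes=1)(m)

print(scaled)    # [10 20 30 40]
print(col_sums)  # [3 5 7]
```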

### Exercise 1: `vmap`-ing a dot product over square matrices

Let's try getting some practice with the following exercises.

The first one is to `vmap` a dot product of a square matrix against itself across a stack of square matrices.

An example square matrix called `sq_matrix` is provided for you to jog your memory on how dot products work if you need to.

In [None]:
from jax import random

key = random.PRNGKey(42)
data = random.normal(key, shape=(11, 5, 5))
sq_matrix = random.normal(key, shape=(5, 5))

vmap(np.dot)(data, data).shape
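To check the exercise solution, you can compare it against `np.matmul`, which already broadcasts over a leading batch axis. A sketch that re-creates `data` so the cell stands alone:

```python
import jax.numpy as np
from jax import random, vmap

key = random.PRNGKey(42)
data = random.normal(key, shape=(11, 5, 5))

# vmap version: one 5x5 @ 5x5 dot product per matrix in the stack
batched = vmap(np.dot)(data, data)

# Reference: matmul treats the leading axis as a batch axis
reference = np.matmul(data, data)

print(batched.shape)  # (11, 5, 5)
```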

### Exercise 2: Constructing a more complex program

We're going to try our hand at constructing a program that first calculates a cumulative product vector for each row in a dataset, then sums each of those vectors (i.e., across the columns) to get one number per row, and finally applies this same operation across all of the datasets stacked together. This one is a bit more challenging!

To help you along here, the shape of the data are such:

- There are 11 stacks of data.
- Each stack of data has 31 rows, and 7 columns.

The result of this program should have shape `(11, 31)`: 11 stacks and 31 rows, where each value is the sum of that row's cumulative product vector.

To get this answer right, no magic numbers are allowed (e.g. for accessing particular axes). At least two `vmap`s are necessary here.

In [None]:
data = random.normal(key, shape=(11, 31, 7))

def row_wise_cumprod(row):
    return np.cumprod(row)

def dataset_wise_sum_cumprod(data):
    row_cumprods = vmap(row_wise_cumprod)(data)
    return vmap(np.sum)(row_cumprods)

vmap(dataset_wise_sum_cumprod)(data).shape
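One way to sanity-check your solution is to compare it with the axis-based NumPy version, which uses exactly the magic numbers we are trying to avoid. This sketch re-declares the solution so the cell stands alone:

```python
import jax.numpy as np
from jax import random, vmap

key = random.PRNGKey(42)
data = random.normal(key, shape=(11, 31, 7))

def row_wise_cumprod(row):
    return np.cumprod(row)

def dataset_wise_sum_cumprod(dataset):
    # (31, 7) -> cumulative products (31, 7) -> one sum per row (31,)
    return vmap(np.sum)(vmap(row_wise_cumprod)(dataset))

vmapped = vmap(dataset_wise_sum_cumprod)(data)

# Reference with explicit axes: cumprod along columns, then sum along columns
reference = np.sum(np.cumprod(data, axis=-1), axis=-1)
```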

## Partially evaluating a function

We're going to take a quick detour and look at the idea of "partially evaluating a function". This is important because it lets us construct functions that meet the requirements of `vmap`, `lax.scan`, and other JAX constructs, i.e. functions with the correct signature, while still giving us the flexibility to bake in whatever extra values the function needs to work correctly.

There are two ways to do this: you can either use `functools.partial`, or you can use function closures. Let's see how to do this.

### Partially evaluating a function using `functools.partial`

For simplicity's sake, let's explore the idea using a function that adds two numbers:

In [None]:
def add(a, b):
    return a + b

Now, let's say we wanted to fix `b` to the value `3`, thus generating an `add_three` function. We can do this in two ways. The first is with `functools.partial`:

In [None]:
from functools import partial

add_three = partial(add, b=3)

We can now call `add_three` on any value of `a`:

In [None]:
add_three(20)

If we inspect the function `add_three`:

In [None]:
add_three?

We see that `add_three` accepts one _positional_ argument, `a`, and that `b` is now a keyword argument with a default of `3`.

What if we wanted to fix `a` to `3` instead?

In [None]:
add_three_v2 = partial(add, a=3)
add_three_v2?

Notice how the function signature has now changed: `a` has been set, while `b` has not (and both are now keyword-only). This has implications for how we use the function.

Calling the function this way will error out:

```python
>>> add_three_v2(3)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-109-e78f540eb25e> in <module>
----> 1 add_three_v2(3)

TypeError: add() got multiple values for argument 'a'
```

That is because an argument passed with no keyword is interpreted as the first positional argument, `a`, which, as you can see, has already been set.

On the other hand, calling the function this way will not:

In [None]:
add_three_v2(b=3)
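Outside of IPython, where `?` isn't available, `inspect.signature` shows the same information. A small sketch reproducing both partials and the failing call:

```python
import inspect
from functools import partial

def add(a, b):
    return a + b

add_three = partial(add, b=3)
add_three_v2 = partial(add, a=3)

print(inspect.signature(add_three))     # (a, *, b=3)
print(inspect.signature(add_three_v2))  # (*, a=3, b)

# Passing 3 positionally collides with the already-bound keyword `a`
try:
    add_three_v2(3)
except TypeError as err:
    message = str(err)

print(add_three_v2(b=3))  # 6
```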

### Partially evaluating a function using closures

Another pattern we can use is the closure. A closure is an inner function that captures ("closes over") variables from the outer function that encloses it; the outer function then returns the inner one. Confused? Let me illustrate:

In [None]:
def closing_function(a):
    def closed_function(b):
        return a + b
    return closed_function

Using this pattern, we can rewrite `add_three` using closures:

In [None]:
def make_add_something(value):
    def closed_function(b):
        return b + value
    return closed_function

add_three_v3 = make_add_something(3)
add_three_v3(5)

In [None]:
add_three_v3?

Now, you'll notice that the signature of `add_three_v3` exactly matches that of the closed function.

When writing array programs using JAX, this is the key design pattern you'll want to implement: Always return a function that has the function signature that you need.
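To make that concrete, here is a sketch in which both techniques produce a single-argument function that `vmap` can map over directly:

```python
from functools import partial

import jax.numpy as np
from jax import vmap

def add(a, b):
    return a + b

def make_add_something(value):
    def closed_function(b):
        return b + value
    return closed_function

xs = np.arange(5)

via_partial = vmap(partial(add, b=3))(xs)      # b is fixed, a is mapped over
via_closure = vmap(make_add_something(3))(xs)  # value is captured, b is mapped over

print(via_partial)  # [3 4 5 6 7]
print(via_closure)  # [3 4 5 6 7]
```

Either way, the fixed value is just a Python constant from `vmap`'s point of view, and only the remaining argument is mapped.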

Naming things is one of the hardest activities in programming, because we are giving categorical names to things, and their category isn't always clear. Fret not: here is the pattern I'll give you:

```python
def some_function_generator(argument1, argument2, keyword_argument1=default_value1):
    """To simplify things, just name the closing function <some_function>_generator."""
    def inner(arg1, arg2, kwarg1=default_value1):
        """This function should follow the API that is needed, e.g. `(carry, x)` for `lax.scan`.

        It can freely use `argument1`, `argument2`, and `keyword_argument1` from the enclosing scope.
        """
        return something
    return inner
```


## Eliminating for-loops that have carry-over using `lax.scan`

We are now going to see how we can eliminate for-loops that have carry-over using `lax.scan`.

From the JAX docs, `lax.scan` replaces a for-loop with carry-over:

> Scan a function over leading array axes while carrying along state.
> 
> ...
> 
> ```python
> def scan(f, init, xs, length=None):
>     if xs is None:
>         xs = [None] * length
>     carry = init
>     ys = []
>     for x in xs:
>         carry, y = f(carry, x)
>         ys.append(y)
>     return carry, np.stack(ys)
> ```

A key requirement of the function `f` is that it must take exactly two positional arguments, one for `carry` and one for `x`, and return a `(carry, y)` pair. You'll see how we can use `functools.partial` (or a closure) to construct functions with this signature from functions that take additional arguments.
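As a preview, here is a minimal sketch of that idea. The three-argument `ema_step` (an exponential moving average update) is a hypothetical example, not something from this notebook; `partial` fixes `alpha` so that what remains has the `(carry, x)` signature `lax.scan` expects:

```python
from functools import partial

import jax.numpy as np
from jax import lax

def ema_step(carry, x, alpha):
    """Hypothetical example: exponential moving average with an extra `alpha` argument."""
    new_carry = alpha * x + (1 - alpha) * carry
    return new_carry, new_carry  # (carry-over, accumulated)

# Fix alpha, leaving a (carry, x) function
step = partial(ema_step, alpha=0.5)

xs = np.array([2.0, 4.0, 4.0, 8.0])
final, smoothed = lax.scan(step, np.array(0.0), xs)
print(smoothed)  # [1.    2.5   3.25  5.625]
```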

Let's see some concrete examples of this in action.

### Updating a variable with new info on each loop iteration

One classic case where we might use a for-loop is the cumulative sum or product. Here, each iteration combines the current element with the running result carried over from the previous iteration. Let's see it in action for the cumulative sum:

In [None]:
a = np.array([1, 2, 3, 5, 7, 11, 13, 17])

result = []
res = 0
for el in a:
    res += el
    result.append(res)
np.array(result)

This is identical to the cumulative sum:

In [None]:
np.cumsum(a)

Now, let's write it using `lax.scan`, so we can see the pattern in action:

In [None]:
from jax import lax
def scanfunc(res, el):
    res = res + el
    return res, res  # ("carryover", "accumulated")

result_init = 0
final, result = lax.scan(scanfunc, result_init, a)
result

As you can see, the scanned function has to return two things:

- One object that gets carried over to the next loop (`carryover`), and
- Another object that gets "accumulated" into an array (`accumulated`).

The initial value, `result_init`, is passed into `scanfunc` as `res` on the first call. On each subsequent call, the carry-over value returned by the previous call is passed back into `scanfunc` as the new `res`.
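To see that `lax.scan` really follows the pure-Python semantics quoted from the docs, here is a sketch that runs both on the same cumulative-sum function. `scan_reference` is our own transcription of that quoted pseudocode, not a JAX API:

```python
import jax.numpy as np
from jax import lax

def scan_reference(f, init, xs):
    """Pure-Python version of lax.scan's semantics, transcribed from the JAX docs."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)  # carry feeds the next call; y is accumulated
        ys.append(y)
    return carry, np.stack(ys)

def scanfunc(res, el):
    res = res + el
    return res, res  # ("carryover", "accumulated")

a = np.array([1, 2, 3, 5, 7, 11, 13, 17])
init = np.zeros((), dtype=a.dtype)  # scalar carry with the same dtype as `a`

final_ref, result_ref = scan_reference(scanfunc, init, a)
final_jax, result_jax = lax.scan(scanfunc, init, a)
print(result_jax)  # [ 1  3  6 11 18 29 42 59]
```

Note that the final carry (59) equals the last accumulated value here only because `scanfunc` returns the same object twice.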

### Exercise 1: Simulating compound interest

We can use `lax.scan` to generate data that simulates the generation of wealth by compound interest. Here's an implementation using a plain vanilla for-loop:

In [None]:
wealth_record = []
starting_wealth = 100.
interest_factor = 1.01

prev_wealth = starting_wealth
for t in range(100):
    new_wealth = prev_wealth * interest_factor
    wealth_record.append(prev_wealth)
    prev_wealth = new_wealth

np.array(wealth_record)

Now try implementing it in a `lax.scan` form:

In [None]:
starting_wealth = 100.
interest_factor = 1.01

timesteps = np.arange(100)

def make_wealth_at_time_func(interest_factor):
    def wealth_at_time(prev_wealth, t):
        # `t` is the scanned-over timestep; the update itself does not need it
        new_wealth = prev_wealth * interest_factor
        return new_wealth, prev_wealth  # (carry-over, accumulated)
    return wealth_at_time

wealth_func = make_wealth_at_time_func(interest_factor)

final, result = lax.scan(wealth_func, init=starting_wealth, xs=timesteps)
result

The two are equivalent, so we know we have the `lax.scan` implementation right.
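Rather than eyeballing the two arrays, we can also check the scan against the closed-form answer: after `t` steps, wealth equals `starting_wealth * interest_factor ** t`. A self-contained sketch:

```python
import jax.numpy as np
from jax import lax

starting_wealth = 100.0
interest_factor = 1.01
timesteps = np.arange(100)

def make_wealth_at_time_func(interest_factor):
    def wealth_at_time(prev_wealth, t):
        new_wealth = prev_wealth * interest_factor
        return new_wealth, prev_wealth  # accumulate the wealth *before* this step's growth
    return wealth_at_time

final, result = lax.scan(
    make_wealth_at_time_func(interest_factor), np.array(starting_wealth), timesteps
)

# Closed form for the wealth recorded at step t
closed_form = starting_wealth * interest_factor ** timesteps
```

A relative tolerance is used when comparing because repeated float32 multiplication and a direct power round slightly differently.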

### Exercise 2: Compose `vmap` and `lax.scan` together

That was one simulation of wealth generation by compound interest for one individual. Now, let's simulate the wealth generation for different starting wealth levels (you may choose any 300 starting points however you'd like). To do so, you'll likely want to start with a function that accepts a scalar starting wealth and generates the simulated time series from there, and then `vmap` that function across multiple starting points (which is an array itself).

In [None]:
def make_simulation_func(timesteps):
    def inner(starting_wealth):
        final, result = lax.scan(wealth_func, init=starting_wealth, xs=timesteps)
        return final, result
    return inner

simulation_func = make_simulation_func(timesteps=np.arange(200))
starting_wealth = np.arange(300).astype(float)

final, growth = vmap(simulation_func)(starting_wealth)
growth
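A quick way to confirm the composition did what we intended is to check the shapes (one row per starting point, one column per timestep) and compare against the closed form, broadcast over both axes. A self-contained sketch:

```python
import jax.numpy as np
from jax import lax, vmap

interest_factor = 1.01

def wealth_at_time(prev_wealth, t):
    new_wealth = prev_wealth * interest_factor
    return new_wealth, prev_wealth

def make_simulation_func(timesteps):
    def inner(starting_wealth):
        # One scan = one individual's wealth trajectory
        return lax.scan(wealth_at_time, init=starting_wealth, xs=timesteps)
    return inner

timesteps = np.arange(200)
starting_wealth = np.arange(300).astype(np.float32)
final, growth = vmap(make_simulation_func(timesteps))(starting_wealth)

# Closed form: starting_wealth[i] * interest_factor ** t for every (i, t)
closed_form = starting_wealth[:, None] * interest_factor ** timesteps[None, :]
print(final.shape, growth.shape)  # (300,) (300, 200)
```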

## Fully Reproducible Random Number Generation

In this section, we'll explore how to create programs that use random number generation in a fashion that is fully deterministic conditioned on a single starting random number generation key.

But first, let's explore what happens when we use NumPy's vanilla random number generation protocol to generate numbers.

In [None]:
import numpy as onp  # original numpy

Let's draw a random number from a Gaussian in NumPy.

In [None]:
onp.random.seed(42)
a = onp.random.normal()
a

And for good measure, let's draw another one.

In [None]:
b = onp.random.normal()
b

This is intuitive behaviour, because we expect that each time we call on a random number generator, we should get back a different number from before.

However, this behaviour is problematic when we are trying to debug programs, because debugging relies on being able to reproduce a failure exactly. _Stochastically_, we might hit a setting where our program errors out, and then be unable to reproduce it, because we are relying on a random number generator that depends on hidden global state and hence doesn't behave in a _fully_ controllable fashion.

How then can we get "the best of both worlds": random number generation that is controllable?

The way JAX's developers approached this is to use pseudo-random number generators that require an explicit pseudo-random number generation key to be passed in, rather than relying on global state. The same key always produces the same draw. Let's see that in action:

In [None]:
key = random.PRNGKey(42)

a = random.normal(key=key)
a

To show you that passing in the same key gives us the same values as before:

In [None]:
b = random.normal(key=key)
b

That should already be a stark difference from what you're used to with vanilla NumPy, and it is the crucial difference between JAX's random module and NumPy's random module. Everything else is very similar, but this difference exists for good reason: it gives us explicit, rather than merely implicit, reproducibility over our stochastic programs within the same session.

How do we get a new draw? Well, we can either create a new key manually, or we can programmatically split the key into two, and use one of the newly split keys to generate a new random number. Let's see that in action:

In [None]:
k1, k2 = random.split(key)
c = random.normal(key=k2)
c

In [None]:
k3, k4, k5 = random.split(k2, num=3)
d = random.normal(key=k3)
d

By splitting a key into two, three, or even 1000 parts, we get new keys derived from the parent key, and each of them makes the same random number generating function produce a different draw.
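Splitting is itself deterministic: the same parent key always produces the same children. Because the resulting keys are stacked along a leading axis, we can also `vmap` a sampler over them. A short sketch (this assumes the classic `random.PRNGKey`, whose keys are plain `uint32` arrays that can be compared with `==`):

```python
from jax import random, vmap

key = random.PRNGKey(42)

# Same parent key -> identical set of 1000 child keys
keys_a = random.split(key, num=1000)
keys_b = random.split(key, num=1000)

# One unit-Gaussian draw per child key, mapped over the leading axis
draws = vmap(random.normal)(keys_a)
print(keys_a.shape[0], draws.shape)  # 1000 (1000,)
```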

Let's explore how we can use this in the generation of a Gaussian random walk.

### Simulating a Gaussian random walk

A Gaussian random walk is one where we start at some point, then draw the next point from a Gaussian centred on the current point, and repeat.

Does that loop structure sound familiar? Well... yeah, it sounds like a classic `lax.scan` setup!

Here's how we might set it up.

Firstly, JAX's `random.normal` function doesn't allow us to specify the location and scale, and only gives us a draw from a unit Gaussian. We can work around this, because any unit Gaussian draw can be shifted and scaled to a $N(\mu, \sigma)$ by multiplying the draw by $\sigma$ and adding $\mu$. 
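The shift-and-scale trick in code. `normal_with_loc_scale` is a hypothetical helper written for illustration, not a function in `jax.random`:

```python
import jax.numpy as np
from jax import random

def normal_with_loc_scale(key, mu, sigma, shape=()):
    """Hypothetical helper: draw from N(mu, sigma) by scaling and shifting a unit Gaussian."""
    return mu + sigma * random.normal(key, shape=shape)

key = random.PRNGKey(42)
samples = normal_with_loc_scale(key, mu=5.0, sigma=2.0, shape=(100_000,))
print(samples.mean(), samples.std())  # close to 5.0 and 2.0
```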

To get a random walk of length 1000, we can split the key 1000 ways and use `lax.scan` to scan a new Gaussian generator across the keys, giving us 1000 unique draws. At each step, we add the new unit-Gaussian draw to the old value.

We return the tuple `(new_gaussian, old_gaussian)` because we want the new Gaussian passed into the next iteration, while accumulating the history of old Gaussians.

In [None]:
def generate_new_gaussian(old_gaussian, key):
    new_gaussian = random.normal(key) + old_gaussian
    return new_gaussian, old_gaussian

keys = random.split(key, num=1000)
final, result = lax.scan(generate_new_gaussian, 0., keys)
result

In [None]:
import matplotlib.pyplot as plt

plt.plot(result)

Looks like we did it! Definitely looks like a proper Gaussian random walk to me. Let's encapsulate this inside a function generator, because the next thing we're going to do is to generate multiple realizations of the Gaussian random walk.

In [None]:
def make_gaussian_random_walk_func(num_steps):
    def gaussian_random_walk(key):
        keys = random.split(key, num=num_steps)
        final, result = lax.scan(generate_new_gaussian, 0., keys)
        return final, result
    return gaussian_random_walk

Now, what if we wanted to generate multiple realizations of the Gaussian random walk? Does this sound familiar? If so... yeah, it's a vanilla for-loop, which directly brings us to `vmap`!

In [None]:
num_realizations = 200
keys = random.split(key, num_realizations)
grw_1000_steps = make_gaussian_random_walk_func(1000)
final, trajectories = vmap(grw_1000_steps)(keys)

In [None]:
trajectories.shape

We did it! We have 200 trajectories of a 1000-step Gaussian random walk. Notice also how the program is structured very nicely: Each layer of abstraction in the program corresponds to a new axis dimension along which we are working. The onion layering of the program has very _natural_ structure for the problem at hand.

Enough proselytizing from me, let's visualize the Gaussian random walk to make sure it genuinely is a GRW.

In [None]:
import seaborn as sns

fig, ax = plt.subplots()

for trajectory in trajectories[0:20]:
    ax.plot(trajectory)
sns.despine()

Now, note that if you were to re-run the entire program from top to bottom, you would get _exactly the same plot_. This is what we mean by "reproducible". Array programs that draw from a global random state are harder to reproduce: the numbers you get depend on every random call made before, so such programs are often only "kind of" reproducible, in a statistical sense across many runs. With JAX's random number generation paradigm, a program is reproducible down to the exact sequence of random draws, as long as the keys controlling it are identical. So when an error shows up in a program whose stochastic components are controlled by hand-set keys, that error can be reproduced. For those who have tried working with stochastic programs before, this is an extremely desirable property: it means we can reliably debug our programs, which is absolutely crucial when working with probabilistic models.

Also notice how we finally wrote our first productive for-loop -- but it was only to plot something, not for some form of calculations :).
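The reproducibility claim is easy to check directly. In this self-contained sketch, the same key gives an identical trajectory and a different key gives a different one. It starts the walk from `np.zeros(())` instead of `0.` so that the carry is already a float32 array:

```python
import jax.numpy as np
from jax import lax, random

def generate_new_gaussian(old_gaussian, key):
    new_gaussian = random.normal(key) + old_gaussian
    return new_gaussian, old_gaussian

def make_gaussian_random_walk_func(num_steps):
    def gaussian_random_walk(key):
        keys = random.split(key, num=num_steps)
        return lax.scan(generate_new_gaussian, np.zeros(()), keys)
    return gaussian_random_walk

walk = make_gaussian_random_walk_func(100)
_, run1 = walk(random.PRNGKey(42))
_, run2 = walk(random.PRNGKey(42))  # same key: identical trajectory
_, run3 = walk(random.PRNGKey(43))  # different key: different trajectory
```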

## What do we gain doing this kind of composition?

To get a handle on what kind of gains we get, I'm going to compare three versions of the same program: one written in pure Python, one built from composed `lax.scan` and `vmap`, and a JIT-compiled version of the latter.

### Writing a Gaussian random walk in pure Python

Let's start with a pure Python implementation of a Gaussian random walk, leveraging vanilla NumPy's random module for API convenience only (and not for performance).

In [None]:
def gaussian_random_walk_python(num_realizations, num_timesteps):
    rws = []
    for i in range(num_realizations):
        rw = []
        prev_draw = 0
        for t in range(num_timesteps):
            prev_draw = onp.random.normal(loc=prev_draw)
            rw.append(prev_draw)
        rws.append(rw)
    return rws

In [None]:
N_REALIZATIONS = 500
N_TIMESTEPS = 10_000
start = time()
trajectories_python = gaussian_random_walk_python(N_REALIZATIONS, N_TIMESTEPS)
end = time()
print(f"{end - start:.2f} seconds")

In [None]:
for trajectory in trajectories_python[:20]:
    plt.plot(trajectory)
sns.despine()

### Comparison against our JAX program

Let's now compare the program against the version we wrote above.

In [None]:
def gaussian_random_walk_jax(num_realizations, num_timesteps):
    keys = random.split(key, num_realizations)
    grw_k_steps = make_gaussian_random_walk_func(num_timesteps)
    final, trajectories = vmap(grw_k_steps)(keys)
    return final, trajectories

In [None]:
start = time()
final_jax, trajectories_jax = gaussian_random_walk_jax(N_REALIZATIONS, N_TIMESTEPS)
trajectories_jax.block_until_ready()
end = time()
print(f"{end - start:.2f} seconds")

In [None]:
for trajectory in trajectories_jax[:20]:
    plt.plot(trajectory)
sns.despine()

### Compare against a JIT-compiled version of our JAX program

Now we're going to JIT-compile our Gaussian Random Walk function and see how long it takes for the program to run.

In [None]:
from jax import jit

def gaussian_random_walk_jit(num_realizations, num_timesteps):
    keys = random.split(key, num_realizations)
    grw_k_steps = make_gaussian_random_walk_func(num_timesteps)
    grw_k_steps = jit(grw_k_steps)
    final, trajectories = vmap(grw_k_steps)(keys)
    return final, trajectories

In [None]:
start = time()
final_jit, trajectories_jit = gaussian_random_walk_jit(N_REALIZATIONS, N_TIMESTEPS)
trajectories_jit.block_until_ready()
end = time()
print(f"{end - start:.2f} seconds")

In [None]:
for trajectory in trajectories_jit[:20]:
    plt.plot(trajectory)
sns.despine()

In our run, JIT-compilation gave us about a 1-2X speedup over the non-JIT-compiled code, and was about 20X faster than the pure Python version; your exact numbers will depend on your hardware. That shouldn't surprise you one bit :).
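One caveat about these timings: each call above includes tracing and compilation. Also, because `gaussian_random_walk_jit` creates a fresh closure every time it is called, calling it again would most likely compile again. A sketch of a variant that builds the function once, JIT-compiles the whole `vmap`-ed program, and times a warm-up call separately from a cached call (sizes are smaller than above to keep it quick):

```python
from time import time

import jax.numpy as np
from jax import jit, lax, random, vmap

def generate_new_gaussian(old_gaussian, key):
    new_gaussian = random.normal(key) + old_gaussian
    return new_gaussian, old_gaussian

def make_gaussian_random_walk_func(num_steps):
    def gaussian_random_walk(key):
        keys = random.split(key, num=num_steps)
        return lax.scan(generate_new_gaussian, np.zeros(()), keys)
    return gaussian_random_walk

# Build once; JIT the whole vmapped program rather than only the inner walk
grw = jit(vmap(make_gaussian_random_walk_func(1000)))
keys = random.split(random.PRNGKey(42), 200)

start = time()
final, trajectories = grw(keys)
trajectories.block_until_ready()
print(f"first call (includes compilation): {time() - start:.2f} seconds")

start = time()
final, trajectories = grw(keys)
trajectories.block_until_ready()
print(f"second call (cached): {time() - start:.2f} seconds")
```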

## A few pointers on syntax

Firstly, if we subscribe to the Zen of Python's notion that "flat is better than nested", then by following the idioms covered here (closures/partials, `vmap`, and `lax.scan`), we'll likely only ever go one closure deep in our programs. Notice how we basically never wrote any for-loops in our array code; they were handled elegantly by the looping constructs `vmap` and `lax.scan`.

Secondly, using `jit`, we get further optimizations on our code for free. A prerequisite of `jit` is that _every_ function call made in the program being `jit`-ed is written in a "pure functional" style, i.e. with no side effects and no mutation of global state. If you write a program using the idioms used here (closures to wrap state, `vmap`/`lax.scan` in lieu of loops, explicit random number generation using PRNGKeys), then you will be able to JIT-compile the program with ease.
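Here is a minimal illustration of why side effects and `jit` don't mix. A Python-level side effect runs while JAX traces the function, not on every call, so later calls with the same input shape and dtype silently skip it:

```python
from jax import jit

trace_log = []

@jit
def double(x):
    trace_log.append("traced")  # Python side effect: runs only during tracing
    return x * 2

double(1.0)
double(2.0)  # same shape/dtype as before, so the cached compiled version is reused
print(trace_log)  # ['traced']
```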