In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# What do we gain with compositionality?

To get a handle on what kind of gains we get, I'm going to compare a Gaussian random walk written in pure Python against our JAX version composed from `lax.scan` and `vmap`, and then against a JIT-compiled version of the latter.

## Writing a Gaussian random walk in pure Python

Let's start with a pure Python implementation of a Gaussian random walk, leveraging vanilla NumPy's random module for API convenience only (and not for performance).
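
Concretely, each realization starts at $x_0 = 0$, and each new draw is centred on the previous one with unit standard deviation, so a single trajectory is a running sum of standard normal increments:

$$
x_t = x_{t-1} + \epsilon_t = \sum_{s=1}^{t} \epsilon_s, \qquad \epsilon_t \sim \mathcal{N}(0, 1), \quad t = 1, \dots, T,
$$

where $T$ is `num_timesteps`.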

In [None]:
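import numpy as onp


def gaussian_random_walk_python(num_realizations, num_timesteps):
    """Simulate Gaussian random walks using plain Python loops."""
    rws = []
    for _ in range(num_realizations):  # outer loop: one walk per realization
        rw = []
        prev_draw = 0
        for _ in range(num_timesteps):  # inner loop: one draw per timestep
            # Each draw is centred on the previous one (unit standard deviation).
            prev_draw = onp.random.normal(loc=prev_draw)
            rw.append(prev_draw)
        rws.append(rw)
    return rws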

In [None]:
from time import time


N_REALIZATIONS = 500
N_TIMESTEPS = 10_000
start = time()
trajectories_python = gaussian_random_walk_python(N_REALIZATIONS, N_TIMESTEPS)
end = time()
print(f"{end - start:.2f} seconds")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

for trajectory in trajectories_python[:20]:
    plt.plot(trajectory)
sns.despine()

## Comparison against our JAX program

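Let's now compare the pure Python program above against our JAX version, which uses `vmap` to map the `lax.scan`-based function returned by `make_gaussian_random_walk_func` over many realizations.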

In [None]:
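from jax import random, vmap

from dl_workshop.jax_idioms import make_gaussian_random_walk_func


def gaussian_random_walk_jax(num_realizations, num_timesteps, key=random.PRNGKey(42)):
    # One independent PRNG key per realization.
    keys = random.split(key, num_realizations)
    # Closure over num_timesteps; maps a single key to (final, trajectory).
    grw_k_steps = make_gaussian_random_walk_func(num_timesteps)
    # vmap replaces the outer Python loop over realizations.
    final, trajectories = vmap(grw_k_steps)(keys)
    return final, trajectories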
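
The helper `make_gaussian_random_walk_func` lives in the workshop's `dl_workshop.jax_idioms` module and is not shown in this notebook. Below is a minimal sketch of what such a helper could look like, assuming it closes over `num_timesteps`, takes a single PRNGKey, and returns a `(final, trajectory)` pair built with `lax.scan`. The name `make_gaussian_random_walk_sketch` is hypothetical, and the real helper may differ in its details.

```python
import jax.numpy as jnp
from jax import lax, random, vmap


def make_gaussian_random_walk_sketch(num_timesteps):
    """Hypothetical stand-in for `make_gaussian_random_walk_func`.

    Returns a function of a single PRNGKey that simulates one walk of
    `num_timesteps` steps; `num_timesteps` is captured by the closure.
    """

    def step(prev_draw, key):
        # Draw the next position from N(prev_draw, 1), mirroring
        # `onp.random.normal(loc=prev_draw)` in the pure Python version.
        new_draw = prev_draw + random.normal(key)
        # scan carries `new_draw` forward and also stacks it as an output.
        return new_draw, new_draw

    def grw_k_steps(key):
        step_keys = random.split(key, num_timesteps)  # one key per timestep
        # jnp.zeros(()) gives a float32 scalar carry, matching new_draw's dtype.
        final, trajectory = lax.scan(step, jnp.zeros(()), step_keys)
        return final, trajectory

    return grw_k_steps


demo_keys = random.split(random.PRNGKey(0), 5)  # 5 independent realizations
demo_final, demo_trajectories = vmap(make_gaussian_random_walk_sketch(100))(demo_keys)
```

Here `lax.scan` plays the role of the inner `for t in range(num_timesteps)` loop, carrying the previous draw from step to step, while `vmap` plays the role of the outer loop over realizations.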

In [None]:
from jax import random
key = random.PRNGKey(42)
start = time()
final_jax, trajectories_jax = gaussian_random_walk_jax(N_REALIZATIONS, N_TIMESTEPS)
trajectories_jax.block_until_ready()
end = time()
print(f"{end - start:.2f} seconds")

In [None]:
for trajectory in trajectories_jax[:20]:
    plt.plot(trajectory)
sns.despine()

### Compare against a JIT-compiled version of our JAX program

Now we're going to JIT-compile our Gaussian random walk function and see how long the program takes to run.

In [None]:
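from jax import jit, random, vmap

from dl_workshop.jax_idioms import make_gaussian_random_walk_func


def gaussian_random_walk_jit(num_realizations, num_timesteps, key=random.PRNGKey(42)):
    keys = random.split(key, num_realizations)
    grw_k_steps = make_gaussian_random_walk_func(num_timesteps)
    # JIT-compile the single-realization function before vmapping it.
    grw_k_steps = jit(grw_k_steps)
    final, trajectories = vmap(grw_k_steps)(keys)
    return final, trajectories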

In [None]:
start = time()
final_jit, trajectories_jit = gaussian_random_walk_jit(N_REALIZATIONS, N_TIMESTEPS)
trajectories_jit.block_until_ready()
end = time()
print(f"{end - start:.2f} seconds")

In [None]:
for trajectory in trajectories_jit[:20]:
    plt.plot(trajectory)
sns.despine()

In our runs, JIT compilation gave roughly a 1-2X speedup over the non-JIT-compiled JAX code and was about 20X faster than the pure Python version; the exact numbers will depend on your hardware. That shouldn't surprise you one bit :).
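
The timed cells above each include a function's first call, so their numbers include one-off tracing and compilation overhead, not just execution. A minimal sketch for measuring steady-state run time is to call the compiled function once as a warm-up and only then time it. The function `cumsum_walk` is a hypothetical stand-in that builds each walk as a running sum of normal increments (equivalent to chaining the draws), not the workshop helper.

```python
from time import perf_counter

import jax.numpy as jnp
from jax import jit, random, vmap


def cumsum_walk(key, num_timesteps=1_000):
    # Hypothetical stand-in for one Gaussian random walk:
    # a running sum of standard normal increments.
    return jnp.cumsum(random.normal(key, (num_timesteps,)))


batched_walk = jit(vmap(cumsum_walk))  # compile the whole batched program
bench_keys = random.split(random.PRNGKey(0), 500)

# Warm-up call: triggers tracing and XLA compilation, excluded from timing.
batched_walk(bench_keys).block_until_ready()

start = perf_counter()
bench_trajectories = batched_walk(bench_keys).block_until_ready()
elapsed = perf_counter() - start
print(f"steady-state: {elapsed:.4f} seconds")
```

`block_until_ready()` matters here because JAX dispatches work asynchronously; without it, the timer could stop before the computation has actually finished.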

## A few pointers on syntax

Firstly, if we subscribe to the Zen of Python's notion that "flat is better than nested", then by following the idioms listed here (closures/partials, `vmap`, and `lax.scan`) we'll likely only ever go one closure deep into our programs. Notice how we never wrote any explicit for-loops in our array code; they were handled elegantly by the looping constructs `vmap` and `lax.scan`.
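
To make the "flat is better than nested" point concrete, here is a small sketch, assuming a running-sum random walk rather than the workshop's helper: `functools.partial` fixes the configuration (the number of timesteps) one level deep, and `vmap` replaces the explicit Python loop over realizations. The function name `walk_for` is hypothetical.

```python
from functools import partial

import jax.numpy as jnp
from jax import random, vmap


def walk_for(num_timesteps, key):
    # One realization: running sum of N(0, 1) increments.
    return jnp.cumsum(random.normal(key, (num_timesteps,)))


# partial fixes the configuration argument, leaving a function of the key only.
walk_100 = partial(walk_for, 100)
flat_keys = random.split(random.PRNGKey(0), 8)

# Nested style: an explicit Python loop over realizations.
looped = jnp.stack([walk_100(k) for k in flat_keys])

# Flat style: vmap maps over the leading axis of `flat_keys` for us.
vmapped = vmap(walk_100)(flat_keys)
```

Both versions compute the same trajectories; the `vmap` version simply lets JAX handle the batching instead of a Python loop.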

Secondly, by using `jit`, we get further optimizations on our code for free. A prerequisite of `jit` is that _every_ function called inside the function being `jit`-ed must be written in a "pure functional" style, i.e. with no side effects and no mutation of global state. If you write a program using the idioms used here (closures to wrap state, `vmap`/`lax.scan` in lieu of loops, explicit random number generation using PRNGKeys), then you will be able to JIT-compile the program with ease.
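
A minimal sketch of why side effects and global state clash with `jit`: JAX runs the Python body of a jitted function only while tracing it, then reuses the compiled version for later calls with the same input shapes and dtypes. A mutation of global state therefore happens once rather than on every call, and a global read inside the function is frozen at its trace-time value. The names `trace_count`, `impure_add_one`, `scale`, and `scale_by_global` are illustrative only.

```python
import jax.numpy as jnp
from jax import jit

trace_count = 0  # global state that the function below (wrongly) mutates


@jit
def impure_add_one(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces
    return x + 1


for _ in range(3):
    impure_add_one(jnp.array(1.0))
print(trace_count)  # 1: incremented only during the first (tracing) call

scale = 2.0


@jit
def scale_by_global(x):
    return x * scale  # the global's value is baked in at trace time


first = scale_by_global(jnp.array(1.0))
scale = 3.0  # changing the global does not trigger recompilation
second = scale_by_global(jnp.array(1.0))  # still uses scale == 2.0
```

Passing state in explicitly (as the PRNG keys are passed above) avoids both surprises.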