In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In this notebook, I'll show how to use JAX's `stax` submodule to write arbitrary models.

## Prerequisites

I'm assuming you have read through the `jax-programming.ipynb` notebook, as well as the `tutorial.ipynb` notebook.

The main `tutorial.ipynb` notebook gives you a general introduction to differentiable programming using `grad`, while the `jax-programming.ipynb` notebook gives you a flavour of the other four main JAX idioms: `vmap`, `lax.scan`, `random.PRNGKey`, and `jit`.

## What is `stax`?

Most deep learning libraries use objects as the data structure for a neural network layer.
As such, the tunable parameters of the layer (for example, `w` and `b` for a linear ("dense") layer)
are class attributes associated with the forward function.

In some sense, because a neural network layer is nothing more than a math function,
specifying the layer in terms of a function might also make sense.
`stax`, then, is a new take on writing neural network models using pure functions rather than objects.

## How does `stax` work?

The way that `stax` layers work is as follows.
Every neural network layer is nothing more than a math function with a "forward" pass.
Neural network models typically have their parameters 
initialized into the right shapes using random number generators.
Put these two together, and we have a _pair_ of functions that specify a layer:

- An `init_fun` function, that _initializes_ parameters into the correct shapes, and
- An `apply_fun` function, that _applies_ the specified math transformations onto incoming data, using parameters of the correct shape.
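To make the pair concrete, here is a minimal sketch of a hand-written layer that follows the same `(init_fun, apply_fun)` convention. `my_dense` and its initialization scheme (scaled normal weights, zero bias) are hypothetical choices for illustration; this is not `stax`'s own implementation.

```python
import jax.numpy as jnp
from jax import random


def my_dense(out_dim):
    """Hypothetical stax-style dense layer returning an (init_fun, apply_fun) pair."""

    def init_fun(rng, input_shape):
        in_dim = input_shape[-1]
        # Replace the last (feature) dimension with out_dim; keep any leading dims.
        output_shape = input_shape[:-1] + (out_dim,)
        # Illustrative initialization: scaled normal weights, zero bias.
        W = random.normal(rng, (in_dim, out_dim)) / jnp.sqrt(in_dim)
        b = jnp.zeros((out_dim,))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        # Pure function: the output depends only on params and inputs.
        W, b = params
        return jnp.dot(inputs, W) + b

    return init_fun, apply_fun


my_init, my_apply = my_dense(3)
out_shape, my_params = my_init(random.PRNGKey(0), (4,))
y = my_apply(my_params, jnp.ones((4,)))
print(out_shape, y.shape)  # (3,) (3,)
```

Because the parameters live outside the function, the same `my_apply` can be reused with any `my_params` of the right shape, which is what makes it easy to hand to `grad`, `vmap`, and `jit`.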

## Example: Linear layer

Let's see this in action by studying the implementation of the linear ("dense") layer in `stax`.

In [None]:
from jax.example_libraries import stax  # formerly jax.experimental.stax in older JAX releases

In [None]:
stax.Dense??

As you can see, the `apply_fun` specifies the linear transformation.
It accepts a parameter called `params`,
which gets tuple-unpacked into the appropriate `W` and `b`.

**Notice how the `params` argument matches up with the second output of `init_fun`!**
The `init_fun` always accepts an `rng` parameter, which is a key created with JAX's `jax.random.PRNGKey()`.
It also accepts an `input_shape` parameter,
which specifies the shape of a single sample of data.
So if your entire dataset has shape `(n_samples, n_columns)`,
then you would pass in `(n_columns,)`,
leaving out the sample dimension.
This lets us use `vmap` to map our model function
over each and every i.i.d. sample in our dataset.
The `init_fun` also returns the `output_shape`,
which is used later when we chain layers together.
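As a quick check of the `vmap` reasoning above, here is a sketch showing that mapping a per-sample function over the leading axis gives the same result as the batched matrix product. `per_sample_dense` is a hypothetical stand-in for a dense `apply_fun`, and the data are random.

```python
from functools import partial

import jax.numpy as jnp
from jax import random, vmap


def per_sample_dense(params, x):
    # x has shape (n_columns,): a single sample, no batch dimension.
    W, b = params
    return jnp.dot(x, W) + b


key_w, key_x = random.split(random.PRNGKey(0))
params = (random.normal(key_w, (4, 1)), jnp.zeros((1,)))
X_demo = random.normal(key_x, (200, 4))  # (n_samples, n_columns)

# vmap restores the sample dimension by mapping over axis 0 of X_demo.
y_mapped = vmap(partial(per_sample_dense, params))(X_demo)
y_batched = jnp.dot(X_demo, params[0]) + params[1]
print(y_mapped.shape, jnp.allclose(y_mapped, y_batched, atol=1e-5))
```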

Let's see how we can use the Dense layer to specify a linear regression model.

### Create the initialization and application function pairs

Firstly, we create the `init_fun` and `apply_fun` pair:

In [None]:
init_fun, apply_fun = stax.Dense(1)

### Initialize the parameters

Now, let's initialize parameters using the `init_fun`.

Let's assume that our data has only 4 columns.

In [None]:
from jax import random, numpy as np

key = random.PRNGKey(42)

output_shape, params_initial = init_fun(key, input_shape=(4,))

In [None]:
params_initial

### Apply parameters and data through function

We'll create some randomly generated data.

In [None]:
X = random.normal(key, shape=(200, 4))
X[0:5], X.shape

Here are some `y_true` values that I've snuck in: they are generated from known weights `[1, 2, 3, 4]` and a bias of `5`.

In [None]:
y_true = np.dot(X, np.array([1, 2, 3, 4])) + 5
y_true = y_true.reshape(-1, 1)
y_true[0:5], y_true.shape

Now, we'll pass data through the linear model!

In [None]:
apply_fun??

In [None]:
from jax import vmap
from functools import partial

y_pred = vmap(partial(apply_fun, params_initial))(X)
y_pred[0:5], y_pred.shape

Voilà! We have a simple linear model implemented just like that.

## Optimization

Next question: how do we *optimize* the parameters using JAX?

Instead of writing a training loop on our own, we can take advantage of JAX's optimizers, which are also written in a functional paradigm!

JAX's optimizers are constructed as a "triplet" set of functions:

- `init`: Takes `params` and packs them into an optimizer `state` that `update` can operate on.
- `update`: Takes in `i`, `g`, and `state`, and returns the updated `state`. The inputs are:
    - `i`: The current loop iteration.
    - `g`: Gradients calculated from `grad`!
    - `state`: The current optimizer state.
- `get_params`: Takes in the `state` at a given point, and returns the parameters in their original structure.
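To see how the three functions fit together, here is a minimal sketch of plain gradient descent written in the same triplet style. `sgd_triplet` is hypothetical and much simpler than JAX's own optimizers (which wrap their state in their own container); it only illustrates the calling convention.

```python
import jax.numpy as jnp
from jax import tree_util


def sgd_triplet(step_size):
    """Hypothetical optimizer returning (init, update, get_params)."""

    def init(params):
        # For plain SGD, the state is just the parameters themselves.
        return params

    def update(i, g, state):
        # i is unused here; step-size schedules or Adam's bias correction would use it.
        return tree_util.tree_map(lambda p, gi: p - step_size * gi, state, g)

    def get_params(state):
        return state

    return init, update, get_params


sgd_init, sgd_update, sgd_get_params = sgd_triplet(0.1)
state = sgd_init((jnp.array([1.0, 2.0]), jnp.array(0.5)))
grads = (jnp.array([1.0, -1.0]), jnp.array(2.0))  # same structure as the params
state = sgd_update(0, grads, state)
print(sgd_get_params(state))
```

The gradient pytree must have the same structure as the parameters, which is exactly what `grad` returns when it differentiates with respect to `params`.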

In [None]:
from jax import jit, grad
from jax.example_libraries.optimizers import adam  # formerly jax.experimental.optimizers

init, update, get_params = adam(step_size=1e-1)
update = jit(update)
get_params = jit(get_params)

### Loss Function

We're still missing one piece: the loss function.
For illustration purposes, let's use the mean squared error.

In [None]:
def mseloss(params, model, x, y_true):
    y_preds = vmap(partial(model, params))(x)
    return np.mean(np.power(y_preds - y_true, 2))

dmseloss = grad(mseloss)
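Because `grad` differentiates with respect to the first argument, `dmseloss` returns gradients with the same `(W, b)` structure as `params`. As a sanity check, here is a sketch comparing it with the closed-form MSE gradient for a linear model, $\partial L / \partial W = \frac{2}{n} X^\top (XW + b - y)$ and $\partial L / \partial b = \frac{2}{n} \sum_i (x_i^\top W + b - y_i)$. It uses a hypothetical per-sample `linear_apply` in place of `stax.Dense`'s `apply_fun`.

```python
from functools import partial

import jax.numpy as jnp
from jax import grad, random, vmap


def linear_apply(params, x):
    W, b = params
    return jnp.dot(x, W) + b


def mseloss(params, model, x, y_true):
    y_preds = vmap(partial(model, params))(x)
    return jnp.mean(jnp.power(y_preds - y_true, 2))


k1, k2 = random.split(random.PRNGKey(1))
X_demo = random.normal(k1, (50, 4))
y_demo = jnp.dot(X_demo, jnp.array([1.0, 2.0, 3.0, 4.0]))[:, None] + 5.0
params = (random.normal(k2, (4, 1)), jnp.zeros((1,)))

# Only params (argument 0) is differentiated; the model function is passed through.
dW, db = grad(mseloss)(params, linear_apply, X_demo, y_demo)

# Closed-form gradient of mean((XW + b - y)^2).
resid = jnp.dot(X_demo, params[0]) + params[1] - y_demo  # shape (n, 1)
n = X_demo.shape[0]
dW_exact = 2.0 / n * jnp.dot(X_demo.T, resid)
db_exact = 2.0 / n * jnp.sum(resid, axis=0)
print(jnp.allclose(dW, dW_exact, rtol=1e-4, atol=1e-4))
print(jnp.allclose(db, db_exact, rtol=1e-4, atol=1e-4))
```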

### "Step" portion of update loop

Now, we're going to define the "step" portion of the update loop.

In [None]:
def step(i, state, dlossfunc, get_params, update, model, x, y_true):
    params = get_params(state)
    g = dlossfunc(params, model, x, y_true)
    state = update(i, g, state)
    return state

### JIT compilation

Because it takes so many parameters (in order to remain pure, and not rely on notebook state),
we're going to bind some of them using `functools.partial`.

I'm also going to show you what happens when we JIT-compile vs. don't JIT-compile the function.

In [None]:
step_partial = partial(step, get_params=get_params, dlossfunc=dmseloss, update=update, model=apply_fun, x=X, y_true=y_true)
step_partial_jit = jit(step_partial)

### Explicit loops

Firstly, let's see what kind of code we'd write if we _did_ write the loop explicitly.

In [None]:
from time import time
start = time()
state = init(params_initial)
for i in range(1000):
    params = get_params(state)
    g = dmseloss(params, apply_fun, X, y_true)
    state = update(i, g, state)
end = time()
print(end - start)

### Partialled out loop step

Now, let's run the loop with the partialled out function.

In [None]:
start = time()
state = init(params_initial)
for i in range(1000):
    state = step_partial(i, state)
end = time()
print(end - start)

### JIT-compiled loop!

This loop is much cleaner, but we did have to do some work up front.

What happens if we now use the JIT-ed function?

In [None]:
start = time()
state = init(params_initial)
for i in range(1000):
    state = step_partial_jit(i, state)
end = time()
print(end - start)

Whoa, holy smokes, that's fast! In my run, it was at least 10X faster with JIT-compilation, even though the first call also includes compilation time.
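One caveat when reading these timings: JAX dispatches work asynchronously, and the first call to a `jit`-ed function includes compilation. Here is a hedged sketch of a timing helper that warms the function up once and then blocks until the results are ready. `time_jax` is a hypothetical helper, not part of JAX, and the function being timed is an arbitrary example.

```python
from time import perf_counter

import jax
import jax.numpy as jnp


def _block(tree):
    # Wait for every array in a pytree to finish computing.
    return jax.tree_util.tree_map(
        lambda a: a.block_until_ready() if hasattr(a, "block_until_ready") else a, tree
    )


def time_jax(fn, *args, n_repeats=10):
    """Hypothetical helper: mean seconds per call, excluding the first (compiling) call."""
    _block(fn(*args))  # warm-up call triggers tracing and compilation
    start = perf_counter()
    for _ in range(n_repeats):
        out = fn(*args)
    _block(out)  # make sure queued work is done before stopping the clock
    return (perf_counter() - start) / n_repeats


f = jax.jit(lambda x: jnp.sin(x) @ jnp.cos(x).T)
x = jnp.ones((256, 256))
print(f"{time_jax(f, x):.6f} s per call")
```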

### `lax.scan` loop

Now we'll use some JAX trickery to write a training loop without ever writing a for-loop.

In [None]:
from jax import lax

def make_scannable_step(stepfunc):
    def scannable_step(previous_state, iteration):
        new_state = stepfunc(iteration, previous_state)
        return new_state, previous_state
    return scannable_step

scannable_step = make_scannable_step(step_partial_jit)

start = time()
initial_state = init(params_initial)
final_state, states_history = lax.scan(scannable_step, initial_state, np.arange(1000))
end = time()
print(end - start)

In [None]:
get_params(final_state)
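`lax.scan(f, init, xs)` threads a carry through `f` and stacks the second return value of every call. Here is a minimal sketch with a running sum that mirrors how `scannable_step` returns `(new_state, previous_state)`: the first output of `lax.scan` is the final carry, the second is the stacked history. `running_sum_step` is a hypothetical toy function.

```python
import jax.numpy as jnp
from jax import lax


def running_sum_step(carry, x):
    new_carry = carry + x
    # Like scannable_step, emit the *previous* carry as the per-step output.
    return new_carry, carry


final, history = lax.scan(running_sum_step, jnp.zeros(()), jnp.arange(1.0, 5.0))
print(final)    # 10.0
print(history)  # [0. 1. 3. 6.]
```

Note that, as with `states_history` above, the history holds the carry *before* each step, so it starts with the initial value and does not include the final one.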

### `vmap`-ed training loop over multiple starting points

Now, we're going to do the ultimate: we'll create 100 different parameter initializations and run our training loop over each of them.

In [None]:
def make_training_start(params_initializer, state_initializer, scanfunc, n_steps):
    def train_one_start(key):
        output_shape, params = params_initializer(key)
        initial_state = state_initializer(params)
        final_state, states_history = lax.scan(scanfunc, initial_state, np.arange(n_steps))
        return final_state, states_history
    return train_one_start

train_linear = make_training_start(partial(init_fun, input_shape=(-1, 4)), init, scannable_step, 1000)

start = time()
N_INITIALIZATIONS = 100
initialization_keys = random.split(key, N_INITIALIZATIONS)
final_states, states_histories = vmap(train_linear)(initialization_keys)
end = time()
print(end - start)

In [None]:
w_final, b_final = vmap(get_params)(final_states)
w_final.squeeze()[0:5]

In [None]:
b_final.squeeze()[0:5]

Looks like we were also able to run the whole optimization pretty fast, _and_ recover the correct parameters (weights close to `[1, 2, 3, 4]` and a bias close to `5`, the values used to generate `y_true`) over multiple training starts.
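The same pattern works on any problem. Here is a toy sketch, independent of `stax`, that minimizes $(w - 3)^2$ by gradient descent from several random starting points at once: `lax.scan` runs the loop for one start, and `vmap` maps it over split PRNG keys. The function names, number of starts, and step size are illustrative choices.

```python
import jax.numpy as jnp
from jax import grad, jit, lax, random, vmap


def loss(w):
    return (w - 3.0) ** 2


def train_one_start(key, n_steps=200, step_size=0.1):
    w0 = random.normal(key, ()) * 10.0  # random starting point

    def step(w, _):
        w_new = w - step_size * grad(loss)(w)
        return w_new, w  # carry the new value, record the old one

    w_final, history = lax.scan(step, w0, jnp.arange(n_steps))
    return w_final, history


keys = random.split(random.PRNGKey(0), 8)
w_finals, histories = jit(vmap(train_one_start))(keys)
print(w_finals)          # every entry should be close to 3.0
print(histories.shape)   # (8, 200)
```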

### JIT-compiled training loop

What happens if we JIT-compile the training function before `vmap`-ing it over the initializations?

In [None]:
from jax import tree_util

start = time()
N_INITIALIZATIONS = 100
initialization_keys = random.split(key, N_INITIALIZATIONS)
train_linear_jit = jit(train_linear)
final_states, states_histories = vmap(train_linear_jit)(initialization_keys)
# JAX dispatches work asynchronously; block until every result is actually computed.
tree_util.tree_map(lambda a: a.block_until_ready(), final_states)
end = time()
print(end - start)

HOOOOOLY SMOKES! Did you see that? In my run, JIT-compiling the training function brought the time for 100 starting points down to roughly the time for a single starting point. Naturally, I don't expect this result to hold 100% of the time, but it's pretty darn rad to see that live.

The craziest piece here is that we could `vmap` our training loop over multiple starting points and get massive speedups there.

## Neural Network Model: Redux

We're now going to rewrite the neural network model from the `tutorial.ipynb` notebook using `stax` syntax, and train it using the techniques we have learned above.

### Reconstruct model using `stax.serial`

Firstly, let's replicate the model using `stax.serial`. It's a serial composition of a Dense+Tanh layer, followed by a Dense+Sigmoid layer.

In [None]:
nn_init, nn_apply = stax.serial(
    stax.Dense(20),
    stax.Tanh,
    stax.Dense(1),
    stax.Sigmoid
)


def nn_init_wrapper(input_shape):
    def inner(key):
        return nn_init(key, input_shape)
    return inner

nn_initializer = nn_init_wrapper(input_shape=(-1, 41))
nn_initializer
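Under the hood, composing layers is mostly about threading `input_shape` through each `init_fun` and the activations through each `apply_fun`. Here is a simplified, hypothetical `my_serial` that illustrates the idea; the real `stax.serial` may differ in details (for example, in how it splits RNG keys), so treat this as a sketch only. `my_dense` and `my_tanh` are also hypothetical.

```python
import jax.numpy as jnp
from jax import random


def my_dense(out_dim):
    def init_fun(rng, input_shape):
        W = random.normal(rng, (input_shape[-1], out_dim)) * 0.1
        return input_shape[:-1] + (out_dim,), (W, jnp.zeros((out_dim,)))

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return jnp.dot(inputs, W) + b

    return init_fun, apply_fun


# A parameter-free layer: its params are an empty tuple.
my_tanh = (lambda rng, input_shape: (input_shape, ()),
           lambda params, inputs, **kwargs: jnp.tanh(inputs))


def my_serial(*layers):
    init_funs, apply_funs = zip(*layers)

    def init_fun(rng, input_shape):
        params = []
        for init in init_funs:
            rng, layer_rng = random.split(rng)  # fresh key for each layer
            input_shape, p = init(layer_rng, input_shape)  # thread the shape forward
            params.append(p)
        return input_shape, params

    def apply_fun(params, inputs, **kwargs):
        for apply, p in zip(apply_funs, params):
            inputs = apply(p, inputs, **kwargs)  # output of one layer feeds the next
        return inputs

    return init_fun, apply_fun


net_init, net_apply = my_serial(my_dense(20), my_tanh, my_dense(1))
out_shape, net_params = net_init(random.PRNGKey(0), (-1, 41))
print(out_shape)                                       # (-1, 1)
print(net_apply(net_params, jnp.ones((5, 41))).shape)  # (5, 1)
```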

Now, we initialize one instance of the parameters.

In [None]:
output_shape, params_init = nn_initializer(key)

We'll need a loss function to optimize as well.

In [None]:
import jax.numpy as np

def binary_cross_entropy(y_true, y_pred, tol=1e-6):
    return y_true * np.log(y_pred + tol) + (1 - y_true) * np.log(1 - y_pred + tol)

def logistic_loss(params, model, x, y):
    preds = vmap(partial(model, params))(x)
    bces = vmap(binary_cross_entropy)(y, preds)
    return -np.sum(bces)

dlogistic_loss = grad(logistic_loss)
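For reference, `logistic_loss` computes the negative summed binary cross-entropy over the $n$ samples, where $\hat{y}_i$ is the model's prediction for sample $x_i$ and the small tolerance $\epsilon$ (`tol`) guards against taking $\log 0$:

$$
\mathcal{L}(\theta) = -\sum_{i=1}^{n} \Big[ y_i \log\big(\hat{y}_i + \epsilon\big) + (1 - y_i) \log\big(1 - \hat{y}_i + \epsilon\big) \Big], \qquad \hat{y}_i = f_\theta(x_i)
$$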

### Load in data

Now, we load in the data.

In [None]:
import pandas as pd
from pyprojroot import here

X = pd.read_csv(here() / 'data/biodeg_X.csv', index_col=0)
y = pd.read_csv(here() / 'data/biodeg_y.csv', index_col=0)

### Test-drive functions to make sure they work

This is always important: a quick test run reveals whether anything is wrong with our code.

In [None]:
logistic_loss(params_init, nn_apply, X.values, y.values)


### Progressively construct our training functions

Firstly, we make sure the step function works with our logistic loss, model func, and actual data.

In [None]:
from jax.example_libraries.optimizers import adam  # formerly jax.experimental.optimizers

adam_init, update, get_params = adam(0.0005)

In [None]:
stepfunc_nn = partial(step, dlossfunc=dlogistic_loss, get_params=get_params, update=update, model=nn_apply, x=X.values, y_true=y.values)
scannable_step = make_scannable_step(stepfunc_nn)
train_nn = make_training_start(nn_initializer, adam_init, scannable_step, n_steps=3000)
start = time()
final_state, states_history = train_nn(key)
end = time()
print(end - start)

Friends, if you remember where we started in the `tutorial.ipynb` notebook, the original neural network took approximately a minute to train on a GPU (and longer on a CPU); compare that with the time printed above.

Let's now plot the loss over training iterations, starting with a function that returns the loss for a given `state` object.

In [None]:
import matplotlib.pyplot as plt


def calculate_loss(state, get_params, model, lossfunc, x, y):
    params = get_params(state)
    return lossfunc(params, model, x, y)

calculate_loss(final_state, get_params, nn_apply, logistic_loss, X.values, y.values)

Now, we need to `vmap` it over all states in the states history, to get back the loss at every training iteration.

In [None]:
calc_loss_vmap = partial(
    calculate_loss,
    get_params=get_params,
    model=nn_apply,
    lossfunc=logistic_loss,
    x=X.values,
    y=y.values
)
start = time()
losses = vmap(calc_loss_vmap)(states_history)
end = time()
print(end - start)

plt.plot(losses)
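This works because `vmap` maps over the leading axis of *every* leaf in a pytree, so a stacked history of optimizer states (each leaf with a leading axis of length 3000) is handed to `calculate_loss` one state at a time. A small sketch with a hypothetical `(weights, bias)` tuple standing in for an optimizer state:

```python
import jax.numpy as jnp
from jax import vmap

# A stacked history of 5 states; state k is (weights=[k, k, k], bias=k).
history = (jnp.ones((5, 3)) * jnp.arange(5.0)[:, None], jnp.arange(5.0))


def score(state):
    # Receives ONE state: w has shape (3,), b has shape ().
    w, b = state
    return jnp.sum(w) + b


scores = vmap(score)(history)
print(scores)  # 4 * k for k = 0..4, i.e. 0, 4, 8, 12, 16
```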

### Training with multiple starting points

Just as above, we can also train the neural network with multiple starting points, again by `vmap`-ing our training function across split PRNGKeys.

In [None]:
keys = random.split(key, 5)

start = time()
final_states, state_histories = vmap(train_nn)(keys)
end = time()
print(end - start)

In [None]:
get_params(final_states)[0][0].shape

Let's plot the losses over each of the state histories. Our function `calc_loss_vmap` calculates the loss for one time point, which we then `vmap` over a single `states_history`; so we need another function that encapsulates this behaviour and `vmap`s over all state histories.

In [None]:
def state_history_loss(state_history):
    losses = vmap(calc_loss_vmap)(state_history)
    return losses

losses = vmap(state_history_loss)(state_histories)
losses.shape

In [None]:
losses

Correctly-shaped! And now plotting it...

In [None]:
plt.plot(losses.T)

Now that's pretty cool! We were able to see the loss from five independent runs.

With sufficient memory, one would be able to do more runs; when I was writing this notebook early on, I saw that it was getting difficult to do on the order of tens of runs due to memory allocation issues.
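Part of that memory pressure is by design: `make_scannable_step` emits the full previous state at every step, so each run keeps 3000 copies of the optimizer state. If only the loss curve is needed, one option is to have the scanned function emit a scalar loss instead of the state. Here is a hedged sketch using a toy quadratic loss and plain gradient descent rather than the notebook's Adam setup; all names are illustrative.

```python
import jax.numpy as jnp
from jax import grad, lax, random, vmap


def loss(w):
    return jnp.sum((w - 3.0) ** 2)


def make_loss_only_step(step_size):
    def step(w, _):
        # Emit only a scalar per step instead of the whole parameter vector.
        return w - step_size * grad(loss)(w), loss(w)
    return step


def train_one_start(key, n_steps=3000):
    w0 = random.normal(key, (1000,))
    # xs=None with an explicit length: the loop needs no per-step inputs.
    w_final, losses = lax.scan(make_loss_only_step(0.1), w0, None, length=n_steps)
    return w_final, losses


keys = random.split(random.PRNGKey(0), 16)
w_finals, loss_curves = vmap(train_one_start)(keys)
print(loss_curves.shape)  # (16, 3000): one float per step, not one state per step
```

The trade-off is that the intermediate parameters are no longer available afterwards, so this fits when the loss curve is all you plan to inspect.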

## Summary

In this notebook, we saw a few things in action.

Firstly, we saw how to use the `stax` module on a linear model. Any time we have a new framework for doing differentiable programming, it's super important to be able to explore it in the context of a linear model, which is _basically_ the foundation of all deep learning.

Secondly, we also explored how to leverage the JAX idioms to create fast, parallelized training loops. We mixed and matched `jit`, `vmap`, `lax.scan`, and `grad` into a performant training loop that was minimally nested.

A corollary of this programming style is that _every piece of the code can, in principle, be properly tested_, because each piece is properly isolated. Have you written training loops where you modify a little piece here and a little piece there, until you lose track of what your original working version looked like? With minimally nested training functions, we can easily control their behaviour explicitly using closures and partials. Even when experimenting, our code can run reliably and fast.

Thirdly, we saw how to apply the same lessons to training a neural network _really fast_ with multiple starting points. The essence of the solution was to structure our program in progressively higher-level layers of abstraction. We wrote the program from the innermost layer outwards until we reached our goal of training from multiple starting points. The key here is that each level of abstraction is very natural, and corresponds to a "unit computation" being applied consistently across an "array" of things. Once we identify that "unit computation", writing the `vmap`-able or `lax.scan`-able function becomes very easy.