# *Solving MNIST with JAX*

JAX is a cool library. Among other things, it:
- can JIT-compile code for CPUs, GPUs, TPUs, etc.
- makes parallel execution easy, even on separate devices
- can transform a scalar-valued function into one that computes its gradient

What better way to explore this library than with a neural net?

Special thanks to [You Don't Know JAX](https://colinraffel.com/blog/you-don-t-know-jax.html) and [Neural Networks and Deep Learning](http://neuralnetworksanddeeplearning.com).

### First, let's make sure we're able to use the GPUs

In [1]:
import jax
jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]

## Downloading the MNIST dataset

The MNIST dataset can be downloaded into a pandas dataframe via sklearn's OpenML interface. In order to save time in future runs, the pandas dataframe is cached.

In [2]:
from sklearn.datasets import fetch_openml
import pickle
import os

def load_mnist():
    pickle_path = '../data/mnist/data.pkl'
    if os.path.exists(pickle_path):
        with open(pickle_path, 'rb') as f:
            return pickle.load(f)
    mnist = fetch_openml(name='mnist_784', version=1, parser='auto')
    os.makedirs(os.path.dirname(pickle_path), exist_ok=True)  # make sure the cache directory exists
    with open(pickle_path, 'wb') as f:
        pickle.dump(mnist, f)
    return mnist

mnist = load_mnist()
all_features, all_targets = mnist['data'], mnist['target']

## Note: Randomness in JAX

JAX offers a randomness utility very similar to NumPy's. The major difference is that you need to explicitly provide a key (created from a seed) whenever you draw random values. This is more tedious, but the upside is that you have more control over the process. For example, if you use the same seed, you should end up with the same results.

All randomness in this notebook is derived from the same seed, so if you choose to run it (with the provided seed) you should get the exact same results.
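
As a minimal sketch of how explicit keys behave (the seed `0` and the variable names are illustrative, not the ones used in this notebook): drawing twice with the same key gives the same values, and `jrand.split` derives a fresh subkey when you want an independent draw.

```python
from jax import random as jrand

key = jrand.PRNGKey(0)  # illustrative seed, not the notebook's seed

# Reusing the same key reproduces the same sample.
a = jrand.normal(key, shape=(3,))
b = jrand.normal(key, shape=(3,))
same_draw = bool((a == b).all())

# Splitting gives a key to carry forward and a subkey to consume right now.
key, subkey = jrand.split(key)
c = jrand.normal(subkey, shape=(3,))
different_draw = not bool((a == c).all())

print(same_draw, different_draw)
```

The `DistSampler` class below follows a similar idea: it advances its stored key before every draw, so successive calls return different values while the whole sequence stays reproducible.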

In [3]:
from random import randint
from jax import random as jrand
from jax import numpy as jnp  # used by DistSampler.random_rows

class DistSampler:
    def __init__(self, seed=None):
        # only fall back to a random seed when none was given (so seed=0 is respected)
        seed = randint(0, 10**6) if seed is None else seed
        self.key = jrand.PRNGKey(seed)
        print('Using seed:', seed)
        
    def normal(self, *shape):
        self.key, _ = jrand.split(self.key)
        return jrand.normal(self.key, shape=tuple(shape))
    
    def choice(self, n, k):
        self.key, _ = jrand.split(self.key)
        return list(map(int, jrand.choice(self.key, n, (k,), replace=False)))
    
    def random_rows(self, n, *arrs):
        idxs = self.choice(arrs[0].shape[0], n) # assume arrays have same outer dim
        return [jnp.take(x, jnp.asarray(idxs), axis=0) for x in arrs]
    
    def shuffle(self, arr):
        self.key, _ = jrand.split(self.key)
        return jrand.permutation(self.key, arr)
    
seed = 298504
dist_sampler = DistSampler(seed)

Using seed: 298504


## Preparing the data

Now that the dataset is available, it needs to be split into train and test sets and the feature vectors need to be normalized. For my own sanity, I also performed a quick spot check on the data.

In [4]:
from jax import numpy as jnp
import pandas as pd

"""
Split the data into train and test segments, then format it as JAX matrices
"""
bool_vec = [i < 60_000 for i in range(len(all_targets))]
bool_vec = [bool(x) for x in dist_sampler.shuffle(jnp.asarray(bool_vec))]
split_df = lambda df : (
    df[pd.Series(bool_vec).values],
    df[pd.Series([not b for b in bool_vec]).values]
)

train_features, test_features = split_df(all_features)
train_targets, test_targets = split_df(all_targets)

format_jnp = lambda *dfs : tuple([jnp.asarray(df.to_numpy(), dtype='float32') for df in dfs])

ftr_train, ftr_test, tgt_train, tgt_test = format_jnp(
    train_features,
    test_features,
    pd.get_dummies(train_targets),
    pd.get_dummies(test_targets)
)

"""
Spot check: print a rough sketch of a sample number, along with the expected answer.
"""
print('Sample image:\n')

idx = dist_sampler.choice(ftr_train.shape[0], 1)[0]
x, y = ftr_train[idx], tgt_train[idx]

img = jnp.reshape(x, (28,28))
for row in img:
    print(''.join(['@' if pix > 100 else ' ' for pix in row]))
    
print(f'\nSample answer: {jnp.argmax(y)}')

"""
Normalize the features
"""
normalize = lambda ftr_df : ftr_df / jnp.linalg.norm(ftr_df, axis=1, keepdims=True)
ftr_train, ftr_test = normalize(ftr_train), normalize(ftr_test)

Obs, Resp = ftr_train, tgt_train
TestObs, TestResp = ftr_test, tgt_test

Sample image:

                            
                            
                            
                            
                            
             @@@            
           @@@@@@@          
          @@@@@@@@@@@@@     
         @@@@@@  @@@@@@@    
         @@@@      @@@@@    
         @@@        @@@@    
         @@@       @@@@     
         @@@     @@@@@@     
         @@@   @@@@@@       
          @@  @@@@@@        
          @@@@@@@@@         
          @@@@@@@@          
          @@@@@@            
         @@@@@@@            
        @@@@@@@@            
        @@@@@ @@@           
        @@@@@@@@@@          
        @@@@@@@@@@          
         @@@@@@@@           
          @@@@@@            
                            
                            
                            

Sample answer: 8
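
The normalization step above divides each feature row by its Euclidean norm, so every image becomes a unit-length vector. Here is a small sketch of the same operation on made-up data (the two 3-pixel "images" are hypothetical):

```python
from jax import numpy as jnp

# Two made-up "images" with three pixels each (hypothetical data).
features = jnp.asarray([[3.0, 4.0, 0.0],
                        [0.0, 5.0, 12.0]])

# keepdims=True keeps the norms as a (2, 1) column so the division broadcasts per row.
row_norms = jnp.linalg.norm(features, axis=1, keepdims=True)
unit_rows = features / row_norms

print(row_norms.ravel())                   # norms of the original rows
print(jnp.linalg.norm(unit_rows, axis=1))  # norms after normalization
```

One thing to keep in mind: a row of all zeros would lead to a division by zero, so this approach assumes every image has at least one non-zero pixel.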


## Constructing the Neural Network

Here we'll define a simple neural network that maps a 784-dimensional input to a 10-dimensional output (which is then softmaxed). In addition, we'll define the loss function (a cross-entropy summed over the 10 outputs), as well as a function to initialize the neural net's parameters.


Note: in order to leverage the full power of JAX's JIT compiler (more on that later), it's important to make sure our functions are "pure", i.e. free of side effects and dependent only on their inputs. Notice that the parameters are an explicit argument of the neural network, rather than being "baked in". Personally, this is one of the things I like about working with JAX. The structure of the network (its dimensions and activation functions) is clearly separated from the parameter values that are being learned, making the process of training more intuitive.

In [5]:
from jax import nn

def neural_net(x, params):
    w0, b0, w1, b1, w2, b2 = params
    x1 = jnp.tanh(jnp.dot(w0, x) + b0)
    x2 = jnp.tanh(jnp.dot(w1, x1) + b1)
    x3 = jnp.tanh(jnp.dot(w2, x2) + b2)
    return nn.softmax(x3)

def initialize_params():
    d0, d1, d2 = 50, 100, 10
    return [
        dist_sampler.normal(d0, Obs.shape[1]),
        dist_sampler.normal(d0),

        dist_sampler.normal(d1, d0),
        dist_sampler.normal(d1),

        dist_sampler.normal(d2, d1),
        dist_sampler.normal(d2),
    ]

def cross_entropy_loss(params, x, y):
    prediction = neural_net(x, params)
    return jnp.sum(
        -y * jnp.log(prediction) - (1. - y) * jnp.log(1. - prediction)
    )
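
To make the loss a bit more concrete, here is a small worked example with made-up numbers. It uses the same expression as `cross_entropy_loss`, but takes the prediction directly instead of running the network (the helper name `cross_entropy_from_prediction` and the 3-class vectors are hypothetical):

```python
import math
from jax import numpy as jnp

def cross_entropy_from_prediction(prediction, y):
    # Same expression as the notebook's loss, applied to a ready-made prediction.
    return jnp.sum(-y * jnp.log(prediction) - (1. - y) * jnp.log(1. - prediction))

y = jnp.asarray([0., 1., 0.])         # hypothetical one-hot label (class 1)
good = jnp.asarray([0.1, 0.8, 0.1])   # confident and correct
bad = jnp.asarray([0.8, 0.1, 0.1])    # confident and wrong

loss_good = float(cross_entropy_from_prediction(good, y))
loss_bad = float(cross_entropy_from_prediction(bad, y))
print(loss_good, loss_bad)
```

The confident-but-wrong prediction is penalized much more heavily, which is the behaviour we want the gradient to push against.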

## JAX MAGIC (GRAD, VMAP, JIT)

Seemingly magical function transformations are where JAX really comes into its own, and none exemplify this better than `jax.grad`, `jax.vmap` and `jax.jit`.

To briefly summarize, a neural network is just an "over-parameterized" differentiable function that maps inputs to outputs, which we can then interpret however we like. Training a neural net basically comes down to (1) passing data into this function, (2) computing a loss value based on the output, (3) computing the gradient of the loss function with respect to the neural net parameters, and then (4) updating the parameters in the opposite direction of the gradient.
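
The four steps can be sketched on a toy problem with a single parameter. Everything here (the one-weight `model`, the data point, the step size of `0.1`) is made up for illustration and is much simpler than the network in this notebook:

```python
from jax import grad

def model(w, x):
    return w * x                       # (1) pass data through the function

def loss(w, x, y):
    return (model(w, x) - y) ** 2      # (2) compute a loss from the output

loss_grad = grad(loss)                 # (3) gradient of the loss w.r.t. w

w = 0.0                                # hypothetical starting parameter
x, y = 2.0, 6.0                        # hypothetical data point; the ideal w is 3
for _ in range(50):
    w = w - 0.1 * loss_grad(w, x, y)   # (4) step against the gradient

print(float(w))
```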

### `jax.grad`
Step 3 of the training process is easily the trickiest part, which is where `jax.grad` comes in. Once we've defined a loss function around our neural net, we can then transform this loss function into one that takes the same arguments but instead returns the gradient with respect to the parameters of the neural net (note that, by default, `jax.grad` differentiates with respect to the first argument of the function being transformed, which is why `params` comes first in `cross_entropy_loss`). Once we have those gradients, updating the parameters becomes dead simple.
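
A handy detail is that the parameters don't have to be a single array. As a small sketch (the shapes and the squared-error loss here are hypothetical), passing a list of arrays gives back a list of gradients with matching shapes:

```python
from jax import grad, numpy as jnp

def loss(params, x, y):
    w, b = params                              # params is a list of arrays
    return jnp.sum((jnp.dot(w, x) + b - y) ** 2)

params = [jnp.ones((2, 3)), jnp.zeros(2)]      # hypothetical shapes
x, y = jnp.asarray([1., 2., 3.]), jnp.asarray([0., 1.])

grads = grad(loss)(params, x, y)               # differentiates w.r.t. the first argument
print([g.shape for g in grads])
```

This is what makes a simple `zip(params, grads)` update possible: each gradient lines up with the parameter it belongs to.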

### `jax.vmap`
Neural nets are known to be "embarrassingly parallel", which just means that a lot of the overall work of training can be done in parallel. However, actually taking advantage of this structural parallelism can be a delicate business. `jax.vmap` takes care of that for us by taking a function written for a single example and returning a vectorized version that operates on a whole batch at once.

When using `jax.vmap`, pay close attention to the `in_axes` and `out_axes` arguments. In the example below, note that:

- `in_axes=(None, 0, 0)`
- `out_axes=0`

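These axes tell JAX to pass the first argument through unchanged (`None`) and to map over axis 0 of the second and third arguments, stacking the results along axis 0. In other words, if our initial function `F` has a signature: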

`F : (a, b, c) -> d`
    
then `vmap(F)` has a signature of:

`vmap(F) : (a, [b0, ..., bn], [c0, ..., cn]) -> [F(a, b0, c0), ..., F(a, bn, cn)]`
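
Here is a small sketch of that signature in action. The function `F` and its inputs are made up for illustration; the point is that the batched call matches applying `F` row by row:

```python
from jax import vmap, numpy as jnp

def F(a, b, c):
    return a * jnp.dot(b, c)            # a: scalar, b and c: vectors -> scalar

a = 2.0                                 # shared across the batch (in_axes None)
bs = jnp.arange(6.0).reshape(3, 2)      # three b vectors stacked along axis 0
cs = jnp.ones((3, 2))                   # three c vectors stacked along axis 0

batched_F = vmap(F, in_axes=(None, 0, 0), out_axes=0)
out = batched_F(a, bs, cs)

# The same thing done one row at a time, for comparison.
expected = jnp.stack([F(a, b, c) for b, c in zip(bs, cs)])
print(out, expected)
```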

### `jax.jit`
Neural nets are also known to be very computationally intensive, so we'd like any speedup we can get. `jax.jit` lets us mark a function for just-in-time compilation, which can significantly improve the efficiency of that function's execution.

For a function to be "jit-compilable", it must be pure. You'll likely encounter a number of errors related to this when you give JAX a go; just stay patient and use the docs!
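
To see why purity matters, here is a small sketch with a deliberately impure function (the `trace_count` counter and `scaled_sum` are hypothetical). The side effect only runs while JAX traces the function, so later calls with the same input shape skip it entirely:

```python
from jax import jit, numpy as jnp

trace_count = 0

@jit
def scaled_sum(x):
    global trace_count
    trace_count += 1           # side effect: only runs while the function is traced
    return jnp.sum(x) * 2.0

x = jnp.arange(4.0)
first = scaled_sum(x)          # traces and compiles
second = scaled_sum(x)         # reuses the compiled version; the Python body is not re-run
print(trace_count, float(first), float(second))
```

The returned values are still correct, but anything the function does besides computing its output (printing, mutating globals, appending to lists) can't be relied on once it's jitted.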

In [6]:
from jax import jit, vmap, grad

# create the gradient function
loss_gradient = grad(cross_entropy_loss)

# parallelize the gradient function
vectorized_loss_gradient = vmap(loss_gradient, in_axes=(None, 0, 0), out_axes=0)

# define the "batch update" function and mark it for jit
@jit
def batch_update(params, X, Y, learning_rate):
    # (1) compute the gradients of loss function relative to the training data
    param_gradients = vectorized_loss_gradient(params, X, Y)
    # (2) average the gradients
    average_grad_per_param = [
        jnp.average(gradients, 0)
        for gradients in param_gradients
    ]
    # (3) subtract the gradients from the original params and return them
    return [
        param - (learning_rate * average_grad)
        for param, average_grad in zip(params, average_grad_per_param)
    ]

## Evaluation

In order to evaluate how well the neural network is actually doing during and after training, we need to measure its accuracy against the test set. These functions will let us do that. We can take advantage of `jax.vmap` and `jax.jit` again, this time using the decorator syntax to be more concise.

In [7]:
from functools import partial

@partial(vmap, in_axes=(None, 0, 0), out_axes=0)
def score_prediction(params, x, y):
    return jnp.argmax(neural_net(x, params)) == jnp.argmax(y)

@jit
def avg_prediction_accuracy(params, X, Y):
    scores = score_prediction(params, X, Y)
    return jnp.average(scores, 0)
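
As a quick illustration of the scoring logic, here is the same argmax comparison on made-up outputs (the `predictions` and `labels` arrays are hypothetical, and this sketch uses the `axis` argument of `jnp.argmax` instead of `vmap`):

```python
from jax import numpy as jnp

# Hypothetical network outputs for four samples over three classes.
predictions = jnp.asarray([[0.7, 0.2, 0.1],
                           [0.1, 0.8, 0.1],
                           [0.3, 0.3, 0.4],
                           [0.6, 0.3, 0.1]])
# One-hot labels for classes 0, 1, 2, 1.
labels = jnp.asarray([[1, 0, 0],
                      [0, 1, 0],
                      [0, 0, 1],
                      [0, 1, 0]])

# A sample counts as correct when the most likely class matches the label.
correct = jnp.argmax(predictions, axis=1) == jnp.argmax(labels, axis=1)
accuracy = jnp.mean(correct)
print(float(accuracy))
```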

## Training the neural net

The meat of the training process is defined in the `batch_update` function. Here we just run that function in a loop, using the updated parameters each time. Every once in a while we'll log the current accuracy, and occasionally we'll shrink the `learning_rate`.

In [8]:
import time

# initialize the parameters
neural_net_params = initialize_params()

learning_rate = 0.5
iterations = 5000
batch_size = 100

fmt_acc = lambda acc : f'{(acc * 100):.2f}%\n'

test_xs, test_ys = dist_sampler.random_rows(1000, TestObs, TestResp)
accuracy = avg_prediction_accuracy(neural_net_params, test_xs, test_ys)
print(f'{"Initial Accuracy:":<20}{fmt_acc(accuracy):>10}')

starting_time = time.time()

for iter_num in range(1, iterations + 1):
    # perform gradient descent to get new params
    batch_x, batch_y = dist_sampler.random_rows(batch_size, Obs, Resp)
    neural_net_params = batch_update(neural_net_params, batch_x, batch_y, learning_rate)

    # evaluate accuracy every 500 iterations
    if (iter_num) % 500 == 0:
        test_xs, test_ys = dist_sampler.random_rows(1000, TestObs, TestResp)
        accuracy = avg_prediction_accuracy(neural_net_params, test_xs, test_ys)
        print(f'{f"({iter_num}) Accuracy:":<20}{fmt_acc(accuracy):>10}')

    # decrease the learning rate over time
    if (iter_num) % 2000 == 0:
        learning_rate *= 0.5
        
elapsed_seconds = time.time() - starting_time

print(f'{elapsed_seconds:.2f} seconds elapsed during training.\n')

accuracy = avg_prediction_accuracy(neural_net_params, TestObs, TestResp)
print(f'Final accuracy against test set: {fmt_acc(accuracy):>10}')

Initial Accuracy:       8.70%

(500) Accuracy:        79.70%

(1000) Accuracy:       85.50%

(1500) Accuracy:       88.80%

(2000) Accuracy:       89.00%

(2500) Accuracy:       89.50%

(3000) Accuracy:       91.80%

(3500) Accuracy:       91.80%

(4000) Accuracy:       91.20%

(4500) Accuracy:       91.40%

(5000) Accuracy:       91.80%

54.20 seconds elapsed during training.

Final accuracy against test set:    91.74%
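
The training loop above halves `learning_rate` after every 2000th iteration. As a sanity check, a hypothetical helper (`learning_rate_at` is not part of the notebook) can reproduce the rate in effect at any given iteration:

```python
def learning_rate_at(iteration, initial=0.5, halve_every=2000):
    """Learning rate used for a given 1-based iteration (hypothetical helper)."""
    # The loop halves the rate *after* the update at iterations 2000 and 4000,
    # so iteration 2000 itself still uses the previous rate.
    return initial * 0.5 ** ((iteration - 1) // halve_every)

for it in (1, 2000, 2001, 4000, 4001, 5000):
    print(it, learning_rate_at(it))
```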



### So there you have it - a simple neural network in JAX!

Despite this being a toy example, it should be clear that JAX's transformations can be used to make the whole process of training a neural network intuitive, without sacrificing performance.

I should note that JAX ships with a small example library called `stax` (`jax.example_libraries.stax`), which is designed for neural nets in particular. If you're interested in JAX for this use case, you should check that out!
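
For comparison, here is a rough sketch of a similar architecture built with `stax`, assuming it is importable as `jax.example_libraries.stax` in your JAX version. It is not a drop-in replacement for the network above: `stax.Dense` uses its own initializers and works on batches, and this sketch leaves out the final `tanh` before the softmax.

```python
from jax import numpy as jnp
from jax import random as jrand
from jax.example_libraries import stax

# Layer sizes mirror the 50 / 100 / 10 layout used earlier in this notebook.
init_fn, apply_fn = stax.serial(
    stax.Dense(50), stax.Tanh,
    stax.Dense(100), stax.Tanh,
    stax.Dense(10), stax.Softmax,
)

key = jrand.PRNGKey(0)                          # illustrative seed
_, params = init_fn(key, (-1, 784))             # -1 stands in for the batch dimension

batch = jnp.ones((4, 784))                      # hypothetical batch of four inputs
probs = apply_fn(params, batch)
print(probs.shape)
```

Just like in the hand-rolled version, the parameters are returned explicitly and passed back in on every call, so `grad`, `vmap` and `jit` can be applied to `apply_fn` in much the same way.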