# *Solving MNIST with JAX*

JAX is a cool library. Among other things, it:
- can JIT-compile code for CPUs, GPUs, TPUs and other accelerators
- makes parallel execution easy, even on separate devices
- can transform a scalar-valued function into one that computes its gradient
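As a quick illustration of the transformations listed above, `jax.grad`, `jax.jit` and `jax.vmap` can be stacked on an ordinary Python function. The function `f` is a hypothetical toy and is not part of the notebook. Multi-device parallelism goes through `jax.pmap`, which this sketch doesn't use.

```python
import jax
import jax.numpy as jnp

# a toy scalar-valued function: f(x) = sum of squares
def f(x):
    return jnp.sum(x ** 2)

df = jax.grad(f)          # gradient function: returns 2 * x
fast_df = jax.jit(df)     # the same gradient, JIT-compiled for the available backend
batched_f = jax.vmap(f)   # evaluates f independently on each row of a 2-D input

x = jnp.array([1.0, 2.0, 3.0])
print(fast_df(x))                   # [2. 4. 6.]
print(batched_f(jnp.ones((4, 3))))  # [3. 3. 3. 3.]
```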

What better way to explore this library than with a neural net?

Special thanks to [You Don't Know JAX](https://colinraffel.com/blog/you-don-t-know-jax.html) and [Neural Networks and Deep Learning](http://neuralnetworksanddeeplearning.com).

### First, let's make sure we're able to use the GPUs

```python
import jax
jax.devices()
```

```
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]
```

## Downloading the MNIST dataset

The MNIST dataset can be downloaded via sklearn's OpenML interface (`fetch_openml`), which returns the features and targets as pandas objects. To save time on future runs, the downloaded dataset is cached to disk with `pickle`.

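```python
from sklearn.datasets import fetch_openml
import pickle
import os

def load_mnist():
    pickle_path = '../data/mnist/data.pkl'
    # reuse the cached copy if one exists
    if os.path.exists(pickle_path):
        with open(pickle_path, 'rb') as f:
            return pickle.load(f)
    # otherwise download it from OpenML (70,000 images of 28x28 = 784 pixels)
    mnist = fetch_openml(name='mnist_784', version=1, parser='auto')
    # make sure the cache directory exists before writing to it
    os.makedirs(os.path.dirname(pickle_path), exist_ok=True)
    with open(pickle_path, 'wb') as f:
        pickle.dump(mnist, f)
    return mnist

mnist = load_mnist()
all_features, all_targets = mnist['data'], mnist['target']
```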

## Note: Randomness in JAX

JAX offers a random-number API very similar to NumPy's. The major difference is that there is no hidden global state. Every random function takes an explicit PRNG key, and you get fresh randomness by splitting an existing key. This is more tedious, but the upside is that you have more control over the process. For example, the same key always produces the same numbers.

All randomness in this notebook is derived from a single seed, so if you run it with the provided seed you should get the same results.
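Here is a minimal sketch of the explicit-key pattern, using the documented `jax.random` functions `PRNGKey`, `split` and `normal`. It follows the JAX docs' convention: each split keeps one half as the running key and consumes the other half for sampling. The `DistSampler` class below does similar bookkeeping in a slightly different way.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(298504)

# split off a subkey to consume; keep the other half for later draws
key, subkey = random.split(key)
a = random.normal(subkey, shape=(3,))

# the same key always yields the same numbers...
b = random.normal(subkey, shape=(3,))

# ...so fresh randomness requires a fresh split
key, subkey = random.split(key)
c = random.normal(subkey, shape=(3,))

print(jnp.array_equal(a, b))  # True
print(jnp.array_equal(a, c))  # False
```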

```python
from random import randint
from jax import numpy as jnp
from jax import random as jrand

class DistSampler:
    """Convenience wrapper that carries a JAX PRNG key between calls."""

    def __init__(self, seed=None):
        # only fall back to a random seed when none is given (so seed=0 still works)
        if seed is None:
            seed = randint(0, 10**6)
        self.key = jrand.PRNGKey(seed)
        print('Using seed:', seed)

    # each method advances self.key with jrand.split before sampling,
    # so successive calls return different values
    def normal(self, *shape):
        self.key, _ = jrand.split(self.key)
        return jrand.normal(self.key, shape=tuple(shape))

    def choice(self, n, k):
        # k distinct indices drawn from range(n)
        self.key, _ = jrand.split(self.key)
        return list(map(int, jrand.choice(self.key, n, (k,), replace=False)))

    def random_rows(self, n, *arrs):
        # the same n random rows from each array
        idxs = self.choice(arrs[0].shape[0], n) # assume arrays have same outer dim
        return [jnp.take(x, jnp.asarray(idxs), axis=0) for x in arrs]

    def shuffle(self, arr):
        self.key, _ = jrand.split(self.key)
        return jrand.permutation(self.key, arr)

seed = 298504
dist_sampler = DistSampler(seed)
```

```
Using seed: 298504
```


## Preparing the data

Now that the dataset is available, it needs to be split into train and test sets (60,000 and 10,000 images respectively), and the feature vectors need to be normalized. For my own sanity, I also performed a quick spot check on the data.

```python
from jax import numpy as jnp
import pandas as pd

"""
Split the data into train and test segments, then format it as JAX matrices
"""
# mark the first 60,000 rows as training data, then shuffle the mask
# so that the 60,000 / 10,000 split is random
bool_vec = [i < 60_000 for i in range(len(all_targets))]
bool_vec = [bool(x) for x in dist_sampler.shuffle(jnp.asarray(bool_vec))]
split_df = lambda df : (
    df[pd.Series(bool_vec).values],                  # rows where the mask is True
    df[pd.Series([not b for b in bool_vec]).values]  # rows where it is False
)

train_features, test_features = split_df(all_features)
train_targets, test_targets = split_df(all_targets)

# convert pandas objects to float32 JAX arrays
format_jnp = lambda *dfs : tuple([jnp.asarray(df.to_numpy(), dtype='float32') for df in dfs])

# the targets are one-hot encoded with pd.get_dummies (one column per digit class)
ftr_train, ftr_test, tgt_train, tgt_test = format_jnp(
    train_features,
    test_features,
    pd.get_dummies(train_targets),
    pd.get_dummies(test_targets)
)

"""
Spot check: print a rough sketch of a sample number, along with the expected answer.
"""
print('Sample image:\n')

idx = dist_sampler.choice(ftr_train.shape[0], 1)[0]
x, y = ftr_train[idx], tgt_train[idx]

# raw pixel values run from 0 to 255; draw the brighter ones
img = jnp.reshape(x, (28,28))
for row in img:
    print(''.join(['@' if pix > 100 else ' ' for pix in row]))

print(f'\nSample answer: {jnp.argmax(y)}')

"""
Normalize the features: scale each image vector to unit (L2) length
"""
normalize = lambda ftr_df : ftr_df / jnp.linalg.norm(ftr_df, axis=1, keepdims=True)
ftr_train, ftr_test = normalize(ftr_train), normalize(ftr_test)
Obs, Resp = ftr_train, tgt_train
TestObs, TestResp = ftr_test, tgt_test
# quick sanity check: every row should now have norm 1
assert jnp.allclose(jnp.linalg.norm(Obs, axis=1), 1.0)
```

```
Sample image:

                            
                            
                            
                            
                            
             @@@            
           @@@@@@@          
          @@@@@@@@@@@@@     
         @@@@@@  @@@@@@@    
         @@@@      @@@@@    
         @@@        @@@@    
         @@@       @@@@     
         @@@     @@@@@@     
         @@@   @@@@@@       
          @@  @@@@@@        
          @@@@@@@@@         
          @@@@@@@@          
          @@@@@@            
         @@@@@@@            
        @@@@@@@@            
        @@@@@ @@@           
        @@@@@@@@@@          
        @@@@@@@@@@          
         @@@@@@@@           
          @@@@@@            
                            
                            
                            

Sample answer: 8
```
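For comparison, here is a sketch of an index-based version of the same random split and L2 normalization. It runs on a small synthetic array, so the sizes and data are made up, and the helpers `random_split_indices` and `l2_normalize` are hypothetical names. Shuffling row indices with `jax.random.permutation` avoids building and iterating over a Python list of 70,000 booleans, which can be slow.

```python
import jax.numpy as jnp
from jax import random

def random_split_indices(num_rows, num_train, key):
    """Shuffle row indices and cut them into train and test index arrays."""
    perm = random.permutation(key, num_rows)
    return perm[:num_train], perm[num_train:]

def l2_normalize(m):
    """Scale each row of m to unit L2 norm."""
    return m / jnp.linalg.norm(m, axis=1, keepdims=True)

# synthetic stand-ins: 10 "images" of 784 non-negative pixels, labels 0..9
key_data, key_split = random.split(random.PRNGKey(0))
fake_images = random.uniform(key_data, (10, 784), minval=0.0, maxval=255.0)
fake_labels = jnp.arange(10)

train_idx, test_idx = random_split_indices(10, 7, key_split)
# index features and labels with the same indices so the rows stay aligned
x_train, x_test = l2_normalize(fake_images[train_idx]), l2_normalize(fake_images[test_idx])
y_train, y_test = fake_labels[train_idx], fake_labels[test_idx]
print(x_train.shape, x_test.shape)  # (7, 784) (3, 784)
```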


## Constructing the Neural Network

To leverage the full power of JAX (its JIT compiler in particular), it's important to make sure our functions are "pure", i.e. free of side effects and hidden state, so that their output depends only on their inputs. Impure functions can behave in surprising ways under `jit`; for example, a side effect may run only once, while the function is being traced. The JAX docs do a good job of covering this. The `compose_layers` function makes it easy to assemble the core of the neural net in an expressive manner.

Note that the parameters are an explicit argument of the neural network function, rather than being "baked in". Personally, this is one of the things I like about working with JAX. The structure of the network (its dimensions and activation functions) is clearly separated from the parameter values being learned, which makes the training process more intuitive.
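The following hypothetical example shows why purity matters under `jit`. A function that reads a global variable has that value baked in when it is first traced, so later changes to the global are silently ignored for inputs of the same shape and type. Passing the value as an explicit argument, as this notebook does with the network parameters, avoids the problem. The names `impure_scale` and `pure_scale` are illustrative only.

```python
import jax
import jax.numpy as jnp

scale = 2.0

@jax.jit
def impure_scale(x):
    # reads the global `scale`; its value is captured when the function is traced
    return x * scale

@jax.jit
def pure_scale(x, s):
    # everything the result depends on is an explicit argument
    return x * s

x = jnp.array(1.0)
print(impure_scale(x))       # 2.0 (first call traces and compiles)
scale = 3.0
print(impure_scale(x))       # still 2.0: the cached compiled version is reused
print(pure_scale(x, scale))  # 3.0
```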

```python
from jax import nn
from jax import numpy as jnp

"""
Given a list of activation functions [A0, A1, ..., An], construct a function that takes
an input x and a list of parameters [W0, B0, W1, B1, ..., Wn, Bn] and returns an output.

The function will essentially have the structure:

fn(x, [W0, B0, ..., Wn, Bn]):
    return An(Wn · An-1(... A1(W1 · A0(W0 · x + B0) + B1) ...) + Bn)
"""
def compose_layers(activations, index=0):
    afn = activations[index]
    wi = 2*index    # position of this layer's weight matrix in params
    bi = 2*index+1  # position of this layer's bias vector in params

    # x has dims (#In,)
    # params[wi] has dims (#Out, #In)
    # params[bi] has dims (#Out,)
    layer_fn = lambda x, params : afn(jnp.dot(params[wi], x) + params[bi])

    if len(activations) - 1 == index:
        # last layer: nothing left to compose
        return layer_fn
    else:
        # feed this layer's output into the composition of the remaining layers
        nextfn = compose_layers(activations, index + 1)
        return lambda x, params : nextfn(layer_fn(x, params), params)

"""
Function to construct the neural net and randomly initialize its parameters
"""
def make_nn_and_params(*layers, in_dim=None, sampler=None):
    # layers is a sequence of (activation, width) pairs
    activations, dims = list(zip(*layers))

    # create the neural layer structure
    composed_layers = compose_layers(activations)

    # wrap the final output in a softmax
    nn_func = lambda x, params : nn.softmax(composed_layers(x, params))

    # sample params from a normal distribution given the dimensions
    param_list = []
    prev_dim = in_dim
    for d in dims:
        param_list.append(sampler(d, prev_dim)) # Wi
        param_list.append(sampler(d)) # Bi
        prev_dim = d

    return nn_func, param_list


"""
Define a three-layer neural net of widths 50, 100 and 10,
using the tanh function as an activation in each layer
"""
num_obs, num_ftrs = Obs.shape
neural_net, neural_net_params = make_nn_and_params(
    (jnp.tanh, 50),
    (jnp.tanh, 100),
    (jnp.tanh, 10),
    in_dim = num_ftrs,
    sampler = dist_sampler.normal
)

"""
Define the cross-entropy loss as a function of the
neural net parameters, an input and the expected output.

This is the per-output (binary) form of cross-entropy used in
Neural Networks and Deep Learning, summed over the 10 outputs.

Note: by default, jax.grad only differentiates with respect
      to the first argument of this function (params)
"""
def loss_fn(params, x, y):
    # define cross-entropy
    cross_entropy = lambda prediction, truth : jnp.sum(
        -truth * jnp.log(prediction) - (1. - truth) * jnp.log(1. - prediction)
    )
    # pass the input and params through the neural network
    out = neural_net(x, params)
    # compare the prediction against the expected output
    return cross_entropy(out, y)
```
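The recursion in `compose_layers` is equivalent to a plain loop over (weight, bias) pairs. Below is a hypothetical, self-contained rewrite; `init_params` and `mlp_forward` are illustrative names, not functions from the notebook. It keeps the same parameter layout `[W0, B0, ..., Wn, Bn]`, the same `W · x + b` convention and the same final softmax, which can make shapes easier to check.

```python
import jax.numpy as jnp
from jax import nn, random

def init_params(key, in_dim, dims):
    """Sample [W0, B0, W1, B1, ...] from a standard normal, like make_nn_and_params."""
    params, prev = [], in_dim
    for d in dims:
        key, wkey, bkey = random.split(key, 3)
        params.append(random.normal(wkey, (d, prev)))  # Wi: (#Out, #In)
        params.append(random.normal(bkey, (d,)))       # Bi: (#Out,)
        prev = d
    return params

def mlp_forward(x, params, activations):
    """Apply each layer as afn(W · x + b), then a softmax over the final outputs."""
    for i, afn in enumerate(activations):
        w, b = params[2 * i], params[2 * i + 1]
        x = afn(jnp.dot(w, x) + b)
    return nn.softmax(x)

acts = [jnp.tanh, jnp.tanh, jnp.tanh]
params = init_params(random.PRNGKey(0), 784, [50, 100, 10])
out = mlp_forward(jnp.ones(784) / jnp.sqrt(784.0), params, acts)
print(out.shape)  # (10,)
```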

## Defining the training process

Training a neural network is (in this case) a simple matter of gradient descent. That is, given a loss function, compute the average gradient G of the loss over a batch of inputs, then subtract G, scaled by a learning rate, from the parameters. Because each batch is a random sample of the training data, this is mini-batch stochastic gradient descent.
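Written out, one step of this update for parameters θ, per-example loss ℓ, learning rate η and a batch B of m examples is:

$$
G = \frac{1}{m} \sum_{(x, y) \in B} \nabla_{\theta}\, \ell(\theta; x, y),
\qquad
\theta \leftarrow \theta - \eta\, G
$$

In the training cell, m is `batch_size = 100`. η is `learning_rate`, which starts at 0.5 and is halved every 2,000 iterations.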


Using JAX's function transformation utilities, the training process can be implemented at nearly the same level of abstraction as the description above. The hope is that the compiler keeps that abstraction from costing performance.


The `make_loss_grad` function is where the JAX magic happens:
1. `jax.grad` is applied to the loss function. It returns a function that takes the same inputs as the loss function but returns the gradient with respect to the parameters instead.
2. `functools.partial` is then used to "embed" the neural net parameters inside the gradient function, so that it has the form `f : (input, expected_result) -> gradient`. This signature makes step 3 more straightforward.
3. `jax.vmap` is applied to vectorize the gradient function, so that a single call computes per-example gradients for a whole batch. This is automatic batching rather than multi-device parallelism, which is what `jax.pmap` is for.
4. `jax.jit` is applied so that the vectorized function is JIT-compiled, again for the sake of performance.
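The following hypothetical toy version makes the four steps concrete. It uses a squared-error loss on a single linear layer rather than the notebook's network, applies the same chain of `grad`, `partial`, `vmap` and `jit`, and prints the shapes of the resulting per-example gradients.

```python
from functools import partial
import jax
import jax.numpy as jnp

def toy_loss(params, x, y):
    """Squared error of a single linear layer; params = [W, b]."""
    w, b = params
    return jnp.sum((jnp.dot(w, x) + b - y) ** 2)

params = [jnp.ones((2, 3)), jnp.zeros(2)]

loss_grad = jax.grad(toy_loss)                    # 1. gradient w.r.t. params
partial_grad = partial(loss_grad, params)         # 2. now a function of (x, y)
batched_grad = jax.vmap(partial_grad, (0, 0), 0)  # 3. map over a batch of (x, y)
fast_grad = jax.jit(batched_grad)                 # 4. JIT-compile the batched version

xs = jnp.ones((5, 3))   # batch of 5 inputs
ys = jnp.zeros((5, 2))  # batch of 5 targets
grads = fast_grad(xs, ys)
print([g.shape for g in grads])  # [(5, 2, 3), (5, 2)]
```

Each gradient gains a leading batch dimension of 5, which is why the notebook averages over axis 0 afterwards.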


The `sgd_vector_minibatch` function then implements one step of gradient descent, taking in the neural net parameters and returning a "better" set of parameters.


Note that I'm uncertain how efficient this implementation is compared to what's achievable. The promise of JAX is abstraction without loss of performance, but of course no compiler is perfect. This is something I look forward to exploring further.

```python
from functools import partial
from jax import vmap, grad, jit
from jax import numpy as jnp

"""
Apply useful transformations to the loss function
"""
def make_loss_grad(params):
    # transform the loss function into a function that computes its gradient
    # with respect to params
    loss_grad = grad(loss_fn)
    # "embed" the parameters, leaving a function of (x, y)
    partial_loss_grad = partial(loss_grad, params)
    # vectorize over the batch: map over axis 0 of both x and y
    vector_loss_grad = vmap(partial_loss_grad, (0, 0), 0)
    # JIT-compile the vectorized gradient function
    return jit(vector_loss_grad)


"""
Define the gradient-descent process
"""
def sgd_vector_minibatch(params, batch_size, learning_rate):
    # transform the loss function (this builds a new function on every call,
    # which JAX will likely retrace and recompile each time)
    loss_grad = make_loss_grad(params)

    # get random subset of observations and responses
    batch_x, batch_y = dist_sampler.random_rows(batch_size, Obs, Resp)

    # compute the per-example gradients of the loss on the batch
    param_gradients = loss_grad(batch_x, batch_y)

    # average the gradients over the batch (I suspect including this step in the
    # jit-compiled function could yield significant performance gains... tbd)
    average_grad_per_param = [
        jnp.average(gradients, 0)
        for gradients in param_gradients
    ]

    # step against the averaged gradient, scaled by the learning rate
    updated_params = [
        param - (learning_rate * average_grad)
        for param, average_grad in zip(params, average_grad_per_param)
    ]
    return updated_params


"""
Return the average accuracy of the neural net with the current parameters
over a random sample of n test examples.
"""
jit_nn = jit(neural_net)
def avg_prediction_accuracy(params, n=1000):
    # a guess is correct when the most probable class matches the one-hot label
    check_guess = lambda x, y : jnp.argmax(jit_nn(x, params)) == jnp.argmax(y)
    vector_nn = vmap(check_guess, (0,0), 0)
    xs, ys = dist_sampler.random_rows(n, TestObs, TestResp)
    scores = vector_nn(xs, ys)
    return jnp.average(scores, 0)
```
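One likely source of overhead in `sgd_vector_minibatch` is that it calls `make_loss_grad` on every training step. Each call builds a new function with the current parameters embedded as constants, so JAX probably has to retrace and recompile it every iteration. The sketch below shows one possible alternative on a hypothetical toy linear model rather than the notebook's network; it hasn't been benchmarked here. It passes the parameters as an ordinary argument, tells `vmap` not to map over them with `in_axes=(None, 0, 0)`, and puts the averaging and the update inside a single jitted function (`sgd_step`, an illustrative name) that is compiled once.

```python
import jax
import jax.numpy as jnp

def toy_loss(params, x, y):
    """Squared error of a single linear layer; params = [W, b]."""
    w, b = params
    return jnp.sum((jnp.dot(w, x) + b - y) ** 2)

# per-example gradients: params are shared (None), x and y are mapped over axis 0
per_example_grad = jax.vmap(jax.grad(toy_loss), in_axes=(None, 0, 0))

@jax.jit
def sgd_step(params, batch_x, batch_y, learning_rate):
    """One mini-batch SGD step, compiled once per input shape and then reused."""
    grads = per_example_grad(params, batch_x, batch_y)
    # average over the batch dimension, inside the compiled function
    avg_grads = [jnp.mean(g, axis=0) for g in grads]
    return [p - learning_rate * g for p, g in zip(params, avg_grads)]

params = [jnp.ones((2, 3)), jnp.zeros(2)]
xs, ys = jnp.ones((5, 3)), jnp.zeros((5, 2))
for _ in range(3):
    params = sgd_step(params, xs, ys, 0.01)
```

Because `learning_rate` is a traced argument rather than a baked-in constant, halving it during training shouldn't trigger a recompile.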

## Training the neural net

```python
import time

learning_rate = 0.5
iterations = 5000
batch_size = 100

fmt_acc = lambda acc : f'{(acc * 100):.2f}%\n'

accuracy = avg_prediction_accuracy(neural_net_params)
print(f'{"Initial Accuracy:":<20}{fmt_acc(accuracy):>10}')

starting_time = time.time()

for iter_num in range(1, iterations + 1):
    # perform gradient descent to get new params
    neural_net_params = sgd_vector_minibatch(neural_net_params, batch_size, learning_rate)

    # evaluate accuracy every 500 iterations
    if iter_num % 500 == 0:
        accuracy = avg_prediction_accuracy(neural_net_params)
        print(f'{f"({iter_num}) Accuracy:":<20}{fmt_acc(accuracy):>10}')

    # halve the learning rate every 2000 iterations
    if iter_num % 2000 == 0:
        learning_rate *= 0.5

elapsed_seconds = time.time() - starting_time

print(f'{elapsed_seconds:.2f} seconds elapsed during training.')
```

```
Initial Accuracy:       8.60%

(500) Accuracy:        78.60%

(1000) Accuracy:       84.00%

(1500) Accuracy:       87.70%

(2000) Accuracy:       88.20%

(2500) Accuracy:       89.30%

(3000) Accuracy:       90.90%

(3500) Accuracy:       90.80%

(4000) Accuracy:       90.20%

(4500) Accuracy:       89.80%
```


### So there you have it - a simple neural network in JAX!