In [1]:
import jax
# Show which devices JAX will dispatch work to; the output below records what
# this kernel found.
jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]

# *Solving MNIST with JAX*

JAX is a cool library. Among other things, it:
- can JIT-compile code for CPUs, GPUs, TPUs, and other accelerators
- makes parallel execution easy, even across multiple devices
- can transform a scalar-valued function into one that computes its gradient

What better way to explore this library than by building a neural net?

Special thanks to [You Don't Know JAX](https://colinraffel.com/blog/you-don-t-know-jax.html) and [Neural Networks and Deep Learning](http://neuralnetworksanddeeplearning.com).

## Loading the MNIST dataset

We load the MNIST dataset from OpenML into pandas objects using scikit-learn's `fetch_openml`. The result is cached as a pickle under `../data/mnist/`, so later runs skip the download.

In [2]:
from sklearn.datasets import fetch_openml
import pickle
import os

def load_mnist(pickle_path='../data/mnist/data.pkl'):
    """Return the OpenML ``mnist_784`` dataset, caching it on disk as a pickle.

    The first call downloads the data with ``fetch_openml`` and pickles the
    result to ``pickle_path``. Later calls unpickle that file instead of
    hitting the network.

    Args:
        pickle_path: location of the cache file. The default is the path this
            notebook has always used.

    Returns:
        The object returned by
        ``fetch_openml(name='mnist_784', version=1, parser='auto')``, or the
        previously cached copy of it.
    """
    if os.path.exists(pickle_path):
        # NOTE: pickle.load can execute arbitrary code -- only point this at a
        # cache file this notebook wrote itself, never at downloaded files.
        with open(pickle_path, 'rb') as f:
            return pickle.load(f)

    mnist = fetch_openml(name='mnist_784', version=1, parser='auto')

    cache_dir = os.path.dirname(pickle_path)
    if cache_dir:
        # A fresh checkout may not have the cache directory yet; without this
        # the download succeeds and then open(..., 'wb') raises.
        os.makedirs(cache_dir, exist_ok=True)

    # Write to a temporary file first and atomically move it into place, so an
    # interrupted dump never leaves a truncated pickle that breaks later runs.
    tmp_path = pickle_path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(mnist, f)
    os.replace(tmp_path, pickle_path)
    return mnist

mnist = load_mnist()
# The OpenML bunch exposes the pixel features under 'data' and the digit
# labels under 'target'.
all_features, all_targets = (mnist[key] for key in ('data', 'target'))

In [3]:
from random import shuffle, randint
from jax import numpy as jnp
import pandas as pd


# Split the data into train and test segments, then format it as JAX matrices.
# Exactly 60,000 rows are flagged for training; shuffling the flags picks
# which ones at random.
# NOTE(review): shuffle() uses Python's unseeded global RNG, so the split
# differs on every run; call random.seed(...) first for reproducibility.
bool_vec = [row < 60_000 for row in range(len(all_targets))]
shuffle(bool_vec)


def split_df(df):
    """Return ``(rows where bool_vec is True, rows where it is False)``."""
    train_mask = pd.Series(bool_vec).values
    test_mask = pd.Series([not flag for flag in bool_vec]).values
    return df[train_mask], df[test_mask]


train_features, test_features = split_df(all_features)
train_targets, test_targets = split_df(all_targets)


def format_jnp(*dfs):
    """Convert each DataFrame/Series to a float32 JAX array; returns a tuple."""
    return tuple(jnp.asarray(df.to_numpy(), dtype='float32') for df in dfs)


# pd.get_dummies expands the digit labels into one indicator column per class.
ftr_train, ftr_test, tgt_train, tgt_test = format_jnp(
    train_features,
    test_features,
    pd.get_dummies(train_targets),
    pd.get_dummies(test_targets),
)

# Spot check: print a rough sketch of one training image along with its
# expected answer, to confirm features and targets stayed aligned.
print('Sample image\n')

# random.randint includes BOTH endpoints, so the largest valid row index is
# shape[0] - 1. With shape[0] as the bound, JAX would either raise or (depending
# on version) silently clamp to the last row.
idx = randint(0, ftr_train.shape[0] - 1)
x, y = ftr_train[idx], tgt_train[idx]

img = jnp.reshape(x, (28, 28))
for row in img:
    # Pixels are still un-normalized here; anything above 100 is drawn as ink.
    print(''.join('@' if pix > 100 else ' ' for pix in row))

print(f'\nSample answer: {jnp.argmax(y)}')

# Normalize the features: scale every image (row) to unit Euclidean length.
# NOTE(review): an all-zero row would divide by zero and produce NaNs, which
# the sanity checks below would then catch.
def normalize(ftr_df):
    """Divide each row of a 2-D feature matrix by its L2 norm."""
    return ftr_df / jnp.linalg.norm(ftr_df, axis=1, keepdims=True)


ftr_train, ftr_test = normalize(ftr_train), normalize(ftr_test)
Obs, Resp = ftr_train, tgt_train
TestObs, TestResp = ftr_test, tgt_test

# Sanity check: every row now has norm 1 in both splits. jnp.allclose reduces
# on the device and returns one boolean. Python's all() over a JAX array pulls
# every element back to the host individually, which is very slow here.
assert jnp.allclose(jnp.linalg.norm(Obs, axis=1), 1.0)
assert jnp.allclose(jnp.linalg.norm(TestObs, axis=1), 1.0)

Sample image

                            
                            
                            
                            
                            
            @@@@@@@         
          @@@@@@@@@@@       
         @@@@@@@@@@@@@      
         @@@@@@@@@@@@@      
        @@@@@@   @@@@       
         @@@@@   @@@@       
          @@@@@@@@@@        
          @@@@@@@@@@        
           @@@@@@@@         
            @@@@@@          
            @@@@@           
            @@@@@           
           @@@@@@           
          @@@@@@@@          
          @@@@@@@@          
         @@@@@@@@           
        @@@@@@@@@           
        @@@@@@@@            
        @@@@@@@@            
         @@@@               
                            
                            
                            

Sample answer: 8


In [4]:
from random import randint
from jax import random as jrand

class DistSampler:
    """Stateful wrapper around JAX's functional PRNG.

    JAX random functions take an explicit key. This class owns one key and
    splits it before every draw, so successive calls return fresh samples.

    Args:
        seed: integer seed for the underlying ``PRNGKey``. When ``None`` (the
            default), a seed is drawn from Python's ``random`` module. Pass an
            int to make parameter initialization reproducible.
    """

    def __init__(self, seed=None):
        if seed is None:
            seed = randint(0, 10**6)
        self.key = jrand.PRNGKey(seed)

    def _next_subkey(self):
        """Advance the stored key and return a fresh subkey for one draw."""
        self.key, subkey = jrand.split(self.key)
        return subkey

    def normal(self, *shape):
        """Return a standard-normal array whose dimensions are ``shape``."""
        return jrand.normal(self._next_subkey(), shape=tuple(shape))

    def uniform(self, *shape):
        """Return a uniform array (JAX's default [0, 1) range) of ``shape``."""
        return jrand.uniform(self._next_subkey(), shape=tuple(shape))


dist_sampler = DistSampler()

In [5]:
from jax import grad, nn, jit

# Training-matrix dimensions: rows are observations, columns are features.
(num_obs, num_ftrs) = Obs.shape

def xlayer(x, w, b):
    """Affine part of one dense layer: ``w . x + b``.

    x has dims (#In,), w has dims (#Out, #In) and b has dims (#Out,), so the
    result has dims (#Out,).
    """
    weighted = jnp.dot(w, x)
    return weighted + b

def compose_layers(activations, index=0):
    """Chain dense layers ``index .. len(activations)-1`` into one function.

    Layer ``i`` reads its weight matrix from ``params[2*i]`` and its bias from
    ``params[2*i + 1]``, applies ``xlayer``, then ``activations[i]``.

    Returns:
        A function ``f(x, params)`` that runs those layers in order.
    """
    activation = activations[index]
    w_slot, b_slot = 2 * index, 2 * index + 1

    def this_layer(x, params):
        return activation(xlayer(x, params[w_slot], params[b_slot]))

    if index == len(activations) - 1:
        return this_layer

    remaining = compose_layers(activations, index + 1)

    def chained(x, params):
        return remaining(this_layer(x, params), params)

    return chained

def make_nn_and_params(*layers, in_dim=None, sampler=None):
    """Build a dense softmax classifier and a matching parameter factory.

    Args:
        *layers: one ``(activation, width)`` pair per layer, from input to
            output. Each layer computes ``activation(W @ x + b)`` and
            ``width`` is its output size.
        in_dim: size of the input vector fed to the first layer. Required.
        sampler: callable ``sampler(*shape)`` returning a random array of that
            shape. It initializes every weight matrix and bias vector.
            Required.

    Returns:
        ``(nn_func, make_params)``:

        - ``nn_func(x, params)`` runs the layers and applies a softmax to the
          last layer's output.
        - ``make_params()`` returns a fresh flat list ``[W1, b1, W2, b2, ...]``.
          ``Wi`` has shape ``(width_i, width_{i-1})`` and ``bi`` has shape
          ``(width_i,)``.

    Raises:
        ValueError: if no layers are given, or ``in_dim`` or ``sampler`` is
            missing.
    """
    # Fail fast: otherwise these mistakes surface much later as opaque
    # tuple-unpacking or JAX shape errors.
    if not layers:
        raise ValueError('make_nn_and_params needs at least one (activation, width) layer')
    if in_dim is None:
        raise ValueError('in_dim (size of the input vector) is required')
    if sampler is None:
        raise ValueError('sampler (a callable taking *shape) is required')

    activations, dims = zip(*layers)
    composed_layers = compose_layers(activations)

    def nn_func(x, params):
        return nn.softmax(composed_layers(x, params))

    def make_params():
        param_list = []
        prev_dim = in_dim
        for d in dims:
            param_list.append(sampler(d, prev_dim))  # Wi: (out, in)
            param_list.append(sampler(d))            # Bi: (out,)
            prev_dim = d
        return param_list

    return nn_func, make_params

# Three tanh layers (input -> 30 -> 50 -> #classes) followed by the softmax
# that make_nn_and_params applies. The input and output widths come from the
# data (num_ftrs features, one column per one-hot class in Resp) rather than
# hard-coded 784 / 10, so they cannot drift out of sync with the dataset.
# NOTE(review): tanh on the output layer squashes the logits into [-1, 1]
# before softmax. That caps the top class probability at about 0.45 for 10
# classes and likely explains why the summed loss plateaus around 1.5.
# Consider an identity activation on the last layer.
neural_net, make_params = make_nn_and_params(
    (jnp.tanh, 30),
    (jnp.tanh, 50),
    (jnp.tanh, Resp.shape[1]),
    in_dim=num_ftrs,
    sampler=dist_sampler.normal,
)

# def neural_net(x, params):
#     w1, b1, w2, b2 = params #, w3, b3 = params
#     t0 = tanh_layer(x, w1, b1)
#     t1 = tanh_layer(t0, w2, b2)
#     #t2 = leaky_relu_layer(t1, w3, b3)
#     return nn.softmax(t1)

# params = [
#     dist_sampler.normal(3000, num_ftrs), # w1
#     dist_sampler.normal(3000), # b1
# #     dist_sampler.uniform(50, 784), # w2
# #     dist_sampler.uniform(50), # b2
#     dist_sampler.normal(10, 3000), # w3
#     dist_sampler.normal(10), # b3
# ]

def cross_entropy(prediction, truth, eps=1e-7):
    """Element-wise binary cross-entropy ``-t*log(p) - (1-t)*log(1-p)``.

    Each output unit is scored independently. With a one-hot ``truth``, this
    penalizes both a low probability on the true class and high probabilities
    on the others.

    Args:
        prediction: predicted probabilities (here, softmax outputs).
        truth: targets of the same shape (here, one-hot rows).
        eps: predictions are clipped to ``[eps, 1 - eps]`` before the logs.
            A probability of exactly 0 or 1 then yields a large finite loss
            instead of inf/NaN (``0 * log(0)`` evaluates to NaN). Values
            already inside that range are unaffected.

    Returns:
        An array with the broadcast shape of ``prediction`` and ``truth``.
    """
    p = jnp.clip(prediction, eps, 1. - eps)
    return -truth * jnp.log(p) - (1. - truth) * jnp.log(1. - p)

def loss_fn(params, x, y):
    """Summed cross-entropy of the network's prediction for one example.

    Args:
        params: flat list ``[W1, b1, W2, b2, ...]`` consumed by ``neural_net``.
        x: a single feature vector.
        y: the matching one-hot target row.

    Returns:
        A scalar: ``cross_entropy(neural_net(x, params), y)`` summed over all
        output units.
    """
    return jnp.sum(cross_entropy(neural_net(x, params), y))

# GOTCHA
# by default, grad only differentiates with respect to the FIRST parameter of
# the input function (here `params`); use `argnums` to change that.
loss_gradient = jit(grad(loss_fn))
jit_nn = jit(neural_net)
jit_loss = jit(loss_fn)

In [6]:
from functools import partial
from jax import vmap

@partial(jit, static_argnums=0)
def _batched_grad(loss, params, xs, ys):
    """Per-example gradients of ``loss`` w.r.t. ``params`` over a batch.

    ``params`` is shared by every example (``in_axes=None``), while ``xs`` and
    ``ys`` are mapped over their leading axis. This function is jitted once at
    module level and ``params`` is a traced argument, so it compiles once per
    (loss, input shapes) combination rather than once per training step.
    """
    return vmap(grad(loss), in_axes=(None, 0, 0))(params, xs, ys)


def make_loss_grad(params, loss=None):
    """Return ``f(batch_x, batch_y)`` giving per-example gradients at ``params``.

    Args:
        params: the parameter pytree (here a flat list of arrays) at which the
            gradients are evaluated.
        loss: scalar loss ``loss(params, x, y)`` to differentiate. Defaults to
            this notebook's ``loss_fn``.

    Returns:
        A callable taking a batch of inputs and targets (leading axis = batch).
        It returns a pytree shaped like ``params`` whose leaves each gain a
        leading batch axis.
    """
    if loss is None:
        loss = loss_fn
    # Do NOT jit a new closure here: a freshly created function misses jit's
    # cache, so it would be re-traced (with params baked in as constants) on
    # every SGD step.
    return partial(_batched_grad, loss, params)

In [28]:
from random import sample

def random_rows(n, *arrs):
    """Select the same ``n`` random rows from every array in ``arrs``.

    Row indices are drawn without replacement with ``random.sample`` over the
    first array's leading dimension, so corresponding rows stay aligned.

    Returns:
        A list with one array per input, each holding the selected rows.
    """
    num_rows = arrs[0].shape[0]
    picked = sample(range(num_rows), n)
    return [jnp.take(arr, jnp.asarray(picked), axis=0) for arr in arrs]

def sgd_minibatch(params, batch_size, learning_rate=0.01):
    """One SGD step built from ``batch_size`` single-example gradients.

    Each example index is drawn with ``randint`` (so with replacement). Every
    gradient is evaluated at the ORIGINAL ``params`` and scaled by
    ``learning_rate / batch_size``, so the step follows the mean gradient of
    the drawn examples. This is the non-vectorized counterpart of
    ``sgd_vector_minibatch``.

    Returns:
        A new list of parameter arrays.
    """
    updated = [jnp.copy(p) for p in params]
    for _ in range(batch_size):
        row = randint(0, num_obs - 1)
        example_grads = loss_gradient(params, Obs[row], Resp[row])
        updated = [
            current - ((learning_rate / batch_size) * g)
            for current, g in zip(updated, example_grads)
        ]
    return updated

def sgd_vector_minibatch(params, batch_size, learning_rate=0.01):
    """One SGD step using a vmap-ed batch of per-example gradients.

    Draws ``batch_size`` training rows with ``random_rows`` and computes every
    example's gradient at ``params`` in one call. Each parameter's gradients
    are averaged over the batch axis, and a step of size ``learning_rate`` is
    taken.

    Returns:
        A new list of parameter arrays.
    """
    per_example_grad = make_loss_grad(params)
    xs, ys = random_rows(batch_size, Obs, Resp)
    return [
        p - (learning_rate * jnp.average(g, 0))
        for p, g in zip(params, per_example_grad(xs, ys))
    ]

def avg_training_loss(params, n=1000):
    """Mean ``loss_fn`` value over ``n`` randomly chosen training rows."""
    xs, ys = random_rows(n, Obs, Resp)
    batched_loss = vmap(lambda x, y: jit_loss(params, x, y), in_axes=(0, 0), out_axes=0)
    return jnp.average(batched_loss(xs, ys), 0)
    
def avg_prediction_accuracy(params, n=1000):
    """Fraction of ``n`` random TEST rows whose arg-max prediction is correct."""
    def is_correct(x, y):
        return jnp.argmax(jit_nn(x, params)) == jnp.argmax(y)

    xs, ys = random_rows(n, TestObs, TestResp)
    hits = vmap(is_correct, (0, 0), 0)(xs, ys)
    return jnp.average(hits, 0)

In [29]:
# Draw a fresh, randomly initialized [W1, b1, W2, b2, W3, b3] using the
# sampler passed to make_nn_and_params; re-running this cell restarts
# training from scratch.
params = make_params()

In [None]:
# Each iteration is ONE minibatch update, not a full pass over the data, so
# the counter is a step count rather than an epoch count.
# sgd_minibatch(params, 30, .5) is the slower, non-vectorized alternative.
NUM_STEPS = 10_000
BATCH_SIZE = 100
LEARNING_RATE = .5
LOG_EVERY = 100

for step in range(1, NUM_STEPS + 1):
    params = sgd_vector_minibatch(params, BATCH_SIZE, LEARNING_RATE)
    if step % LOG_EVERY == 0:
        loss = avg_training_loss(params)
        print(f'({step})\taverage loss:\t\t{loss:.4f}')
        # Accuracy is measured on a random sample of the held-out test rows.
        acc = avg_prediction_accuracy(params)
        print(f'\taverage accuracy:\t{(acc * 100):.2f}%\n')

(100)	average loss:		1.5602
	average accuracy:	91.10%

(200)	average loss:		1.5473
	average accuracy:	89.60%

(300)	average loss:		1.6171
	average accuracy:	91.30%

(400)	average loss:		1.5565
	average accuracy:	89.40%

(500)	average loss:		1.5793
	average accuracy:	90.50%

(600)	average loss:		1.5612
	average accuracy:	94.00%

(700)	average loss:		1.5776
	average accuracy:	91.20%

(800)	average loss:		1.5577
	average accuracy:	90.70%

(900)	average loss:		1.5253
	average accuracy:	90.30%

(1000)	average loss:		1.5317
	average accuracy:	91.20%

(1100)	average loss:		1.5566
	average accuracy:	91.40%

(1200)	average loss:		1.5342
	average accuracy:	92.00%

(1300)	average loss:		1.5443
	average accuracy:	91.30%

(1400)	average loss:		1.5386
	average accuracy:	93.10%

(1500)	average loss:		1.5417
	average accuracy:	92.30%

(1600)	average loss:		1.5423
	average accuracy:	91.30%

(1700)	average loss:		1.5205
	average accuracy:	91.90%

(1800)	average loss:		1.5299
	average accuracy:	94.00%

(

In [None]:
# import time

# class Clock():
#     def __init__(self):
#         self.start_time = time.time()
        
#     def start(self):
#         self.start_time = time.time()
    
#     def stop(self):
#         return time.time() - self.start_time
    
# clock = Clock()

# clock.start()
# for i, epoch_num in enumerate(range(100)):
#     params = sgd_minibatch(params, 50, .5)
#     if (i+1) % 20 == 0:
#         print(f'({i+1}) average loss: {avg_training_loss(params)}')
# print(f'minibatch time = {clock.stop()}')

# clock.start()
# for i, epoch_num in enumerate(range(100)):
#     params = sgd_vector_minibatch(params, 50, .5)
#     if (i+1) % 20 == 0:
#         print(f'({i+1}) average loss: {avg_training_loss(params)}')
# print(f'vector_minibatch time = {clock.stop()}')