# Fashion MNIST using Linear Layers with JAX

- toc: true
- badges: true
- comments: true
- categories: [jupyter]
- image: images/chart-preview.png

## Introduction

In this post, I'm going to implement a basic Fashion-MNIST classifier using JAX. JAX is an array-processing library that uses Google's XLA (Accelerated Linear Algebra) compiler to generate high-performance code that can run on CPUs, GPUs, and TPUs. It feels a lot like NumPy, with a number of advantages including built-in automatic differentiation, vectorization and parallelization, and just-in-time compilation.
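Here's a tiny, self-contained illustration of three of those transformations on a made-up function `f` (nothing here is specific to the classifier): `jax.grad` differentiates it, `jax.vmap` maps it over a batch, and `jax.jit` compiles it with XLA.

```python
import jax
import jax.numpy as jnp

def f(w, x):
    # toy scalar function of a weight vector w and an input vector x
    return jnp.sum(jnp.tanh(w * x) ** 2)

df_dw = jax.grad(f)                          # automatic differentiation with respect to w
batched_f = jax.vmap(f, in_axes=(None, 0))   # vectorize over the rows of x, sharing w
fast_df_dw = jax.jit(df_dw)                  # compile the gradient function with XLA

w = jnp.array([0.5, -1.0, 2.0])
xs = jnp.ones((4, 3))

print(batched_f(w, xs).shape)   # (4,): one value per row of xs
print(fast_df_dw(w, xs[0]))     # same values as df_dw(w, xs[0]), from compiled code
```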


## The Goal

To be a little more specific, here's what we're going to work toward:

```python
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data() 
train_dataset = Dataset(X_train, y_train)
valid_dataset = Dataset(X_test, y_test)
train_datasource = Dataloader(train_dataset, batchsize=32)
valid_datasource = Dataloader(valid_dataset, batchsize=64)

network = Sequential(
    rescale_image,
    flatten,
    Linear(28*28, 128),
    relu,
    Linear(128, 10),
    softmax  
)

grad_fn = jax.jit(jax.grad(lambda model, X, y: sparse_cross_entropy(y, model(X))))

history = train(
    num_epochs=5, 
    train_datasource=train_datasource, 
    valid_datasource=None,
    optimizer=Adam(network, lr=1e-3), 
    loss_fn=sparse_cross_entropy, 
    model=network,
    grad_fn=grad_fn
)
```

```
Epoch 1/5
1874/1875  [===============================] - 3s  1.49ms/batch  -  loss: 0.4958   -  accuracy: 0.8245      
Epoch 2/5
1874/1875  [===============================] - 2s  1.40ms/batch  -  loss: 0.3725   -  accuracy: 0.8654    
Epoch 3/5
1874/1875  [===============================] - 2s  1.53ms/batch  -  loss: 0.3339   -  accuracy: 0.8778    
Epoch 4/5
1874/1875  [===============================] - 2s  1.55ms/batch  -  loss: 0.3080   -  accuracy: 0.8869    
Epoch 5/5
1874/1875  [===============================] - 2s  1.46ms/batch  -  loss: 0.2900   -  accuracy: 0.8934
```


## Load Libraries

```python
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import time

from typing import Tuple, List, Any, Dict, Callable

# machine epsilon, added inside log() to avoid log(0)
eps = jnp.finfo(jnp.float64).eps
```

## Loading the Data

I really like PyTorch's approach to handling data, and could have just imported its built-in `Dataset` and `DataLoader` classes. Instead, here are two minimal versions written from scratch.

```python
class Dataset:
    def __init__(self, X, y):
        self.X, self.y = X, y
    def __len__(self):
        return jnp.shape(self.X)[0]
    def __getitem__(self, i):
        return self.X[i,:], self.y[i]
```

```python
class Dataloader:
    def __init__(self, dataset: Dataset, batchsize=32, shuffle=False):
        self.dataset = dataset
        self.batchsize = batchsize
        self.shuffle = shuffle  # stored, but not used yet: batches always come out in order
    def __iter__(self):
        for i in range(0, len(self.dataset), self.batchsize):
            yield self.dataset[i:i+self.batchsize]
    def __len__(self):
        # ceiling division, so a final partial batch is counted too
        return (len(self.dataset) + self.batchsize - 1) // self.batchsize
```
        

```python
(X_train, y_train), _ = tf.keras.datasets.fashion_mnist.load_data()
dataset = Dataset(X_train, y_train)
dataloader = Dataloader(dataset)
```
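The batching arithmetic is worth a quick check. When the dataset size isn't a multiple of the batch size, slicing with `range(0, n, batchsize)` yields one more batch than `n // batchsize` counts, because the last batch is partial. The sketch below uses a hypothetical helper, `batch_sizes`, that mirrors the slicing in `Dataloader.__iter__` on plain integers.

```python
import math

def batch_sizes(num_examples, batchsize):
    # Hypothetical helper: sizes of the slices produced by
    # range(0, num_examples, batchsize), as in Dataloader.__iter__
    return [min(batchsize, num_examples - i) for i in range(0, num_examples, batchsize)]

train_sizes = batch_sizes(60_000, 32)   # Fashion-MNIST training set
test_sizes = batch_sizes(10_000, 32)    # Fashion-MNIST test set

print(len(train_sizes), 60_000 // 32)   # 1875 1875: divides evenly
print(len(test_sizes), 10_000 // 32)    # 313 312: floor division misses the last batch
print(test_sizes[-1])                   # 16: the final, partial batch
print(math.ceil(10_000 / 32))           # 313: ceiling division matches what is yielded
```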

## The Sequential Model

A sequential model is a list of `Callable` objects that is evaluated by calling each member's `__call__` method in order, feeding the output of one layer into the next. Each list element is a registered pytree, but I also wanted the flexibility to pass ordinary functions to the `Sequential` constructor and have everything just work. As you'll soon see, this feature is implemented in the `__init__` method.

### `Linear` Layer


The `Linear` layer defined below is very similar to (if less general than) the implementations you'd find in non-JAX neural network libraries.

```python
@jax.tree_util.register_pytree_node_class
class Linear:
    w: jnp.ndarray
    b: jnp.ndarray
    ni: int
    no: int

    def __init__(self, num_inputs, num_outputs, build=True, seed=1234):
        self.ni = num_inputs
        self.no = num_outputs
        # want to add seed as internal object
        if build:
            key = jax.random.PRNGKey(seed)
            self.w = jax.random.normal(key, (num_inputs, num_outputs)) * jnp.sqrt(2.0 / num_inputs)
            self.b = jnp.zeros(num_outputs)

    def __repr__(self):
        return f'Linear(num_inputs={self.ni}, num_outputs={self.no})'

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b

    def params(self):
        return {'w': self.w, 'b': self.b}

    def tree_flatten(self):
        return (self.w, self.b), (self.ni, self.no)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layer = cls(*aux_data, build=False)
        layer.w, layer.b = children
        return layer
```

The one glaring difference is the decorator and the two *tree* methods. As the JAX documentation explains, the `register_pytree_node_class` decorator, the `tree_flatten` method, and the `tree_unflatten` class method are what add a user-defined class to the JAX pytree registry. Once a class is registered, JAX knows how to convert back and forth between objects that are specific to your application and the flat collections of arrays that the rest of the JAX system (`grad`, `jit`, `tree_map`, and so on) operates on.

The `tree_flatten` method returns a two-element tuple: the parameters you want to expose to JAX (the *children*), and any metadata needed to reconstruct the object (the *auxiliary data*). Because JAX embraces the functional paradigm of immutable data structures, I chose to return the parameters as a tuple. For `Linear`, the parameters are the layer's weights and biases. For now, the only useful metadata are the numbers of inputs and outputs (although these could also be derived from the shape of the weights).

Another thing to notice about `Linear` is the `build` argument to `__init__`. Most of the time, you want to initialize the weights and biases when the layer is created. However, you don't want to do this when JAX reconstructs the object from its flattened representation; you just want to plop the parameters into a freshly constructed object. That's exactly what `tree_unflatten` does by passing `build=False`.
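To see the flatten/unflatten round trip in action, here's a sketch using `TinyLinear`, a hypothetical stripped-down stand-in for `Linear` (redefined so the snippet runs on its own). It shows what JAX sees after flattening, how an object is rebuilt from modified leaves, and how `jax.tree_util.tree_map` does both steps in one call.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class TinyLinear:
    # Hypothetical stand-in for Linear, with the same pytree methods
    def __init__(self, ni, no, build=True, seed=0):
        self.ni, self.no = ni, no
        if build:
            self.w = jax.random.normal(jax.random.PRNGKey(seed), (ni, no))
            self.b = jnp.zeros(no)

    def __call__(self, x):
        return x @ self.w + self.b

    def tree_flatten(self):
        return (self.w, self.b), (self.ni, self.no)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layer = cls(*aux_data, build=False)   # skip random initialization
        layer.w, layer.b = children
        return layer

layer = TinyLinear(3, 2)

# Flatten: two array leaves, plus a treedef that carries (ni, no)
leaves, treedef = jax.tree_util.tree_flatten(layer)
print([leaf.shape for leaf in leaves])   # [(3, 2), (2,)]

# Unflatten: rebuild a TinyLinear from modified leaves
rebuilt = jax.tree_util.tree_unflatten(treedef, [2 * leaf for leaf in leaves])
print(type(rebuilt).__name__, rebuilt.ni, rebuilt.no)   # TinyLinear 3 2

# tree_map flattens, applies the function to each leaf, and unflattens
zeroed = jax.tree_util.tree_map(jnp.zeros_like, layer)
print(zeroed(jnp.ones((1, 3))))   # [[0. 0.]]
```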

### `Function` Layer

The `Function` class fills the same need that `Lambda` layers do in Keras: being able to conveniently plug functions into models.  As the `tree_flatten` method shows, classes registered as pytrees can be parameter-free. 

```python
@jax.tree_util.register_pytree_node_class
class Function:
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x)

    def __repr__(self):
        return f'Function({self.fn.__name__})'

    def tree_flatten(self):
        return [], self.fn

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)
```
    

### Helpful Functions

Here are a few functions that will be *lifted* into `Function` layers by the `Sequential` model. One common way to improve classification accuracy is to normalize the input data. For gray-scale images, this typically means rescaling the pixel values from $[0,255]$ to $[0,1]$, which is what `rescale_image` does.

```python
def rescale_image(x): return x / 255.0
```

The model built in this post operates on batches of two-dimensional gray-scale images. Each batch is a three-dimensional array that can be interpreted as a vertical stack of 2D images, where the height of the stack is the number of images. The `flatten` function turns each slice of the stack from a 2D array into a 1D array, so the 3D input becomes a 2D array of shape `(batch_size, height * width)`.

```python
def flatten(x):
    shape = jnp.shape(x)
    assert len(shape) == 3, 'x must represent a batch of two-dimensional gray-scale images'
    batch_size = shape[0]
    return jnp.reshape(x, (batch_size, -1))
```
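To make the shape bookkeeping concrete, here's the reshape at the heart of `flatten` applied to a dummy batch (zeros standing in for real images), plus a tiny batch that shows how the pixels end up ordered.

```python
import jax.numpy as jnp

batch = jnp.zeros((32, 28, 28))   # a dummy batch of 32 gray-scale 28x28 images

# Same reshape as flatten(): keep the batch axis, merge the two image axes
flat = jnp.reshape(batch, (batch.shape[0], -1))
print(flat.shape)   # (32, 784)

# Row i of the result is image i read row by row (C order)
images = jnp.arange(2 * 2 * 3).reshape(2, 2, 3)
print(jnp.reshape(images, (2, -1)))
# [[ 0  1  2  3  4  5]
#  [ 6  7  8  9 10 11]]
```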
    

The last two functions we'll implement in this section are `relu` and `softmax`.  A couple of tests are also provided. 

```python
def relu(x): return jnp.maximum(x, 0)

x = np.random.randn(10,10)
assert np.all(np.isclose(relu(x), tf.nn.relu(x))), 'test failed'
```

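```python
def softmax(x):
    # subtract the row-wise max before exponentiating to avoid overflow;
    # the shift cancels in the ratio, so the result is unchanged
    ex = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
    return ex / jnp.sum(ex, axis=-1, keepdims=True)

x = np.random.randn(5, 10)
assert np.all(np.isclose(softmax(x), tf.nn.softmax(x))), 'test failed'
assert np.isclose(jnp.sum(softmax(x)), jnp.shape(x)[0])
```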
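The naive formula `exp(x) / sum(exp(x))` overflows in float32 once a logit exceeds roughly 88. Subtracting each row's maximum first is the standard fix, since the shift cancels between numerator and denominator. The sketch below compares the two on deliberately large logits (the function names are just for this illustration) and checks against `jax.nn.softmax`.

```python
import jax
import jax.numpy as jnp

def naive_softmax(x):
    ex = jnp.exp(x)
    return ex / jnp.sum(ex, axis=-1, keepdims=True)

def stable_softmax(x):
    # exp(x - m) / sum(exp(x - m)) == exp(x) / sum(exp(x)) for any constant m,
    # and choosing m = max(x) keeps every exponent <= 0, so exp() cannot overflow
    z = x - jnp.max(x, axis=-1, keepdims=True)
    ez = jnp.exp(z)
    return ez / jnp.sum(ez, axis=-1, keepdims=True)

logits = jnp.array([[1000.0, 0.0, -1000.0]])
print(naive_softmax(logits))    # first entry is nan: exp(1000.) overflows to inf
print(stable_softmax(logits))   # [[1. 0. 0.]]
print(jax.nn.softmax(logits))   # JAX's built-in softmax agrees with the stable version
```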

### Sequential Model

With the supporting pieces implemented, we can define `Sequential`.

```python
@jax.tree_util.register_pytree_node_class
class Sequential:

    def __init__(self, *layers):
        self.layers = []
        for layer in layers:
            if hasattr(layer, 'tree_flatten'):
                self.layers.append(layer)            # already a registered pytree
            elif callable(layer):
                self.layers.append(Function(layer))  # lift plain functions
            else:
                raise TypeError(f'{layer!r} is neither a pytree layer nor callable')

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def __repr__(self):
        return ''.join(repr(layer) + '\n' for layer in self.layers)

    def tree_flatten(self):
        aux_data, children = [], []
        for layer in self.layers:
            params, treedef = jax.tree_util.tree_flatten(layer)
            aux_data.append(treedef)
            children.append(params)
        # JAX compares and hashes aux_data (e.g. for jit's cache), so use a tuple
        return children, tuple(aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layers = [jax.tree_util.tree_unflatten(spec, params)
                  for params, spec in zip(children, aux_data)]
        return cls(*layers)
```

Based on the previous high-level description, it should not be surprising that `Sequential` is essentially a wrapper around a list of layers. The `__init__` method loops through the input objects and, for each one:

1. if it is already a pytree (it has a `tree_flatten` method), adds it to the list of layers;
2. if it is a plain function rather than a pytree, wraps it in a `Function` object and adds that object to the list of layers.

Either way, after construction every `Sequential` instance holds a list of registered pytrees. `Sequential` itself is then added to the pytree registry by implementing `tree_flatten` and `tree_unflatten`, which loop over the layers, flatten or unflatten each one, and collect the results.
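The payoff of registering the whole container is that `jax.grad` can differentiate with respect to the model itself and hand back a gradient with the same structure. The sketch below uses hypothetical minimal classes (`Affine`, `Fn`, `Chain`) standing in for `Linear`, `Function` and `Sequential`, so that it runs on its own.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Affine:
    # Hypothetical stand-in for Linear: leaves are w and b
    def __init__(self, w, b):
        self.w, self.b = w, b
    def __call__(self, x):
        return x @ self.w + self.b
    def tree_flatten(self):
        return (self.w, self.b), None
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@jax.tree_util.register_pytree_node_class
class Fn:
    # Hypothetical stand-in for Function: the function is aux data, so no leaves
    def __init__(self, fn):
        self.fn = fn
    def __call__(self, x):
        return self.fn(x)
    def tree_flatten(self):
        return (), self.fn
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)

@jax.tree_util.register_pytree_node_class
class Chain:
    # Hypothetical stand-in for Sequential: its children are the layers
    def __init__(self, *layers):
        self.layers = list(layers)
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    def tree_flatten(self):
        return tuple(self.layers), None
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

def relu(x):
    return jnp.maximum(x, 0)

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
model = Chain(Affine(jax.random.normal(key1, (4, 3)), jnp.zeros(3)),
              Fn(relu),
              Affine(jax.random.normal(key2, (3, 2)), jnp.zeros(2)))

# Only the two Affine layers contribute leaves: w1, b1, w2, b2
print(len(jax.tree_util.tree_leaves(model)))   # 4

def loss(model, x):
    return jnp.mean(model(x) ** 2)

# The gradient is itself a Chain with the same layout as the model
grads = jax.grad(loss)(model, jnp.ones((5, 4)))
print(type(grads).__name__, grads.layers[0].w.shape)   # Chain (4, 3)

# ...so one tree_map over two identically structured trees is an SGD step
new_model = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, model, grads)
```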

### Some Testing ...

```python
def fashion_mnist_mlp():
    return Sequential(
        rescale_image,
        flatten,
        Linear(784, 128),
        relu,
        Linear(128, 10),
        softmax
    )

model = fashion_mnist_mlp()
print(model)
```

```
Function(rescale_image)
Function(flatten)
Linear(num_inputs=784, num_outputs=128)
Function(relu)
Linear(num_inputs=128, num_outputs=10)
Function(softmax)
```




## Cross Entropy Loss

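Cross-entropy is one of the most common loss functions for classification problems. I'm going to spare you the long-winded mathematical justification of why it's useful; let me just say that it measures how well a predicted probability distribution matches the true one, and for a single labeled example it is smallest when the prediction puts all of its probability on the correct class.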

Here are two different versions of cross-entropy.  The first version (`cross_entropy`) assumes that `y_true` is one-hot encoded while the second version (`sparse_cross_entropy`) assumes that `y_true` is an array of indices.  I like the sparse version because it seems more efficient and less dependent on knowing the number of categories in the dataset.  
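In symbols, for a batch of $N$ examples and $C$ classes, with predicted probabilities $p_{ic}$, one-hot labels $y_{ic}$, and integer labels $t_i$, both functions below compute the same batch average (the code also adds a tiny `eps` inside the logarithm to avoid $\log 0$):

$$
\ell = -\frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C} y_{ic}\,\log p_{ic} = -\frac{1}{N}\sum_{i=1}^{N} \log p_{i,t_i}
$$

Only one term per row of the double sum is nonzero, which is why the sparse version can simply index it. For the test batches used below, where the true classes receive probabilities $0.95$ and $0.1$, this gives $\ell = -(\log 0.95 + \log 0.1)/2 \approx 1.177$.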

```python
@jax.jit
def cross_entropy(y_true, probs):
    batch_size = jnp.shape(probs)[0]
    return -jnp.sum(jnp.log(probs + eps) * y_true) / batch_size

y_true = np.array([[0, 1, 0], [0, 0, 1]])
y_pred = np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])
keras_cross_entropy = tf.keras.losses.CategoricalCrossentropy()
assert np.all(np.isclose(cross_entropy(y_true, y_pred), keras_cross_entropy(y_true, y_pred))), 'Not close'
```

```python
@jax.jit
def sparse_cross_entropy(y_true, probs):
    batch_size = len(y_true)
    X = jnp.log(probs + eps)[jnp.arange(batch_size), y_true]
    return -jnp.sum(X) / batch_size

y_true = jnp.array([1, 2])
y_pred = jnp.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])
keras_cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy()
assert np.all(np.isclose(sparse_cross_entropy(y_true, y_pred), keras_cross_entropy(y_true, y_pred))), 'Not close'
```
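As a sanity check that the two versions agree, here's a sketch that one-hot encodes the sparse labels with `jax.nn.one_hot` and compares the results. Both functions are redefined without `@jax.jit` so the snippet runs on its own, and `eps` is taken from float32 here, which is an assumption of this sketch rather than the value used above.

```python
import jax
import jax.numpy as jnp

eps = jnp.finfo(jnp.float32).eps   # small constant that keeps log() finite

def cross_entropy(y_true, probs):
    # y_true is one-hot, shape (batch, classes)
    return -jnp.sum(jnp.log(probs + eps) * y_true) / probs.shape[0]

def sparse_cross_entropy(y_true, probs):
    # y_true holds integer class indices, shape (batch,)
    batch_size = len(y_true)
    return -jnp.sum(jnp.log(probs + eps)[jnp.arange(batch_size), y_true]) / batch_size

labels = jnp.array([1, 2])
probs = jnp.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])

dense = cross_entropy(jax.nn.one_hot(labels, 3), probs)
sparse = sparse_cross_entropy(labels, probs)
print(dense, sparse)   # both approximately 1.177
```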

## Optimizers

Optimizers update model parameters at each minibatch.  The simplest optimizer, stochastic gradient descent (aka SGD), is shown below and updates the parameters ($W$) as follows: 

$$
    W \leftarrow W - \alpha \frac{\partial \ell}{\partial W}
$$

where $\ell$ represents the loss-function used in model training and $\alpha$ is the learning rate.  

While certainly simple and fast, SGD is not the *go-to* optimizer these days. That title seems to belong to the Adam optimizer, which is the optimizer used in most of the examples I've seen. In their original paper, Kingma and Ba describe Adam as

> "an algorithm for first-order gradient-based optimization of stochastic objective functions, based on adaptive estimates of lower-order moments."

and from the same paper:

> "The method computes individual adaptive learning rates for different parameters from estimates of first and second moments of the gradients; the name Adam is derived from adaptive moment estimation."
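Written with the same symbols as the SGD rule and with $g = \partial \ell / \partial W$, the update performed by the `Adam` class below is as follows, where $\beta_1$, $\beta_2$ and $\epsilon$ correspond to its `v_decay`, `s_decay` and `eps` arguments and $k$ counts steps starting at 1:

$$
\begin{aligned}
v &\leftarrow \beta_1 v + (1-\beta_1)\, g \\
s &\leftarrow \beta_2 s + (1-\beta_2)\, g^2 \\
\hat{v} &= \frac{v}{1-\beta_1^k}, \qquad \hat{s} = \frac{s}{1-\beta_2^k} \\
W &\leftarrow W - \alpha\,\frac{\hat{v}}{\sqrt{\hat{s}} + \epsilon}
\end{aligned}
$$

The bias corrections $\hat{v}$ and $\hat{s}$ compensate for $v$ and $s$ being initialized at zero, which would otherwise shrink the early estimates.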

For now, an optimizer has an `__init__` and `step` method.  The `__init__` method initializes the parameters used to calculate the step direction or step size, while the `step` method updates the model and any optimizer parameters.  

```python
class Optimizer: pass
```

```python
class SGD(Optimizer):
    def __init__(self, model, lr=1e-3):
        self.lr = lr
    def step(self, model, grads):
        # W <- W - lr * dL/dW, applied to every leaf of the model pytree
        return jax.tree_util.tree_map(lambda p, g: p - self.lr*g, model, grads)
```

```python
class Adam(Optimizer):
    def __init__(self, model, lr=1e-3, v_decay=0.9, s_decay=0.999, eps=1e-7):
        self.lr, self.v_decay, self.s_decay, self.eps = lr, v_decay, s_decay, eps
        # first- and second-moment estimates: one zero array per model parameter
        self.v = jax.tree_util.tree_map(jnp.zeros_like, model)
        self.s = jax.tree_util.tree_map(jnp.zeros_like, model)
        self.k = 0  # step counter, used for bias correction
    def step(self, model, grads):
        lr, v_decay, s_decay, eps = self.lr, self.v_decay, self.s_decay, self.eps
        tree_map = jax.tree_util.tree_map
        k = self.k = self.k + 1
        self.v = tree_map(lambda v, g: v_decay*v + (1-v_decay)*g, self.v, grads)
        self.s = tree_map(lambda s, g: s_decay*s + (1-s_decay)*g*g, self.s, grads)
        v_hat = tree_map(lambda v: v / (1-v_decay**k), self.v)
        s_hat = tree_map(lambda s: s / (1-s_decay**k), self.s)
        new_model = tree_map(lambda p, vh, sh: p - (lr*vh)/(jnp.sqrt(sh) + eps), model, v_hat, s_hat)
        return new_model
```
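One property of the bias correction is easy to check: on the very first step ($k=1$), $\hat{v} = g$ and $\hat{s} = g^2$, so each parameter moves by roughly $\alpha\,\mathrm{sign}(g)$ regardless of how large the gradient is. The sketch below is a functional restatement of the same arithmetic as `Adam.step`; the name `adam_step` and the dict-of-arrays parameters are assumptions of this example.

```python
import jax
import jax.numpy as jnp

def adam_step(params, grads, v, s, k, lr=1e-3, b1=0.9, b2=0.999, eps=1e-7):
    # One Adam update over arbitrary pytrees; returns new params and moments
    tm = jax.tree_util.tree_map
    v = tm(lambda v, g: b1 * v + (1 - b1) * g, v, grads)
    s = tm(lambda s, g: b2 * s + (1 - b2) * g * g, s, grads)
    v_hat = tm(lambda v: v / (1 - b1 ** k), v)
    s_hat = tm(lambda s: s / (1 - b2 ** k), s)
    params = tm(lambda p, vh, sh: p - lr * vh / (jnp.sqrt(sh) + eps), params, v_hat, s_hat)
    return params, v, s

params = {'w': jnp.array([1.0, 1.0, 1.0]), 'b': jnp.array(0.0)}
grads = {'w': jnp.array([100.0, 0.01, -3.0]), 'b': jnp.array(0.5)}
zeros = jax.tree_util.tree_map(jnp.zeros_like, params)

new_params, v, s = adam_step(params, grads, zeros, zeros, k=1)
print(new_params['w'])   # approximately [0.999 0.999 1.001]: each weight moves by about lr
```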


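The `step` implementations illustrate why registering models as pytrees is so important. In functional programming, you often *map* a function that operates on a single item over a whole collection of items. `jax.tree_map` applies the same idea to pytrees: it calls the function on every leaf (or, given several trees with the same structure, on corresponding leaves) and reassembles the results into a tree with the original structure. Because `Sequential` and `Linear` are registered, a single `tree_map` call updates every parameter in the model and returns a new model.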
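Here's a minimal illustration on built-in containers (dicts, lists and tuples are pytrees out of the box). It uses `jax.tree_util.tree_map`, which newer JAX releases prefer over the deprecated `jax.tree_map` alias; the same calls work on any registered class.

```python
import jax
import jax.numpy as jnp

# A nested pytree built from standard containers
params = {'layer1': [jnp.ones(2), jnp.zeros(2)], 'layer2': (jnp.full(3, 2.0),)}
grads = {'layer1': [jnp.ones(2), jnp.ones(2)], 'layer2': (jnp.ones(3),)}

# A function of one leaf, applied to every leaf ...
doubled = jax.tree_util.tree_map(lambda p: 2 * p, params)

# ... or a function of matching leaves from several trees with the same structure
updated = jax.tree_util.tree_map(lambda p, g: p - 0.5 * g, params, grads)

print(jax.tree_util.tree_structure(updated) == jax.tree_util.tree_structure(params))  # True
print(updated['layer1'][1])   # [-0.5 -0.5]
print(updated['layer2'][0])   # [1.5 1.5 1.5]
```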

IPython's `jax.tree_map??` reports the following signature:

```python
jax.tree_map(
    f: Callable[..., Any],
    tree: Any,
    *rest: Any,
    is_leaf: Union[Callable[[Any], bool], NoneType] = None,
) -> Any
```

## Progress Bar

I really like how Keras logs information to the screen during model training, and decided to mimic the style.  Here's my version of the progress bar.  

```python
def progress_bar(percentage, total=30):
    x = int(percentage*total)
    if x < total:
        r = '[' + '='*x + '>' + '.'*(total-x) + ']'
    else:
        r = '[' + '='*(total+1) + ']'
    return r
```

It has space for 31 characters sandwiched between an opening and closing bracket.  Examples at various completion percentages are shown below.

1. 0% progress

```python
print(progress_bar(0))
```

```
[>..............................]
```

2. 10% progress

```python
print(progress_bar(0.1))
```

```
[===>...........................]
```

3. 100% progress

```python
print(progress_bar(1))
```

```
[===============================]
```


## The Training Loop  

The training loop is pretty basic, but it looks cluttered because everything is arranged as a single function. Most frameworks split the training loop into several pieces and incorporate a callback system that lets users customize its behavior; otherwise, you'd likely have to write a new training loop for each problem. A callback system would also lead to less cluttered (and therefore less bug-prone) code.
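To make the idea concrete, here's a minimal sketch of what a callback system could look like. Every name in it (`Callback`, `LossHistory`, `fire`, `fit`, and the hook names) is hypothetical, loosely inspired by Keras-style hooks, and the "training step" is a stand-in lambda rather than a real model update.

```python
from typing import Callable, Dict, List

class Callback:
    # Hypothetical base class: each hook receives the shared training state
    def on_epoch_begin(self, state: Dict) -> None: pass
    def on_batch_end(self, state: Dict) -> None: pass
    def on_epoch_end(self, state: Dict) -> None: pass

class LossHistory(Callback):
    # Example callback: records the mean batch loss of every epoch
    def __init__(self):
        self.epoch_losses: List[float] = []
    def on_epoch_begin(self, state):
        self.total, self.count = 0.0, 0
    def on_batch_end(self, state):
        self.total += state['loss']
        self.count += 1
    def on_epoch_end(self, state):
        self.epoch_losses.append(self.total / self.count)

def fire(callbacks: List[Callback], hook: str, state: Dict) -> None:
    # Call the named hook on every registered callback
    for cb in callbacks:
        getattr(cb, hook)(state)

def fit(num_epochs: int, batches: List, step_fn: Callable, callbacks: List[Callback]) -> None:
    # The loop only does the work; logging and bookkeeping live in callbacks
    state: Dict = {}
    for epoch in range(num_epochs):
        state['epoch'] = epoch
        fire(callbacks, 'on_epoch_begin', state)
        for batch in batches:
            state['loss'] = step_fn(batch)
            fire(callbacks, 'on_batch_end', state)
        fire(callbacks, 'on_epoch_end', state)

history = LossHistory()
fit(num_epochs=2, batches=[1.0, 2.0, 3.0], step_fn=lambda b: b * 0.5, callbacks=[history])
print(history.epoch_losses)   # [1.0, 1.0]
```

A progress bar, validation pass, or early stopping would each become another `Callback` subclass instead of more code inside the loop.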



```python
def train(num_epochs, train_datasource, valid_datasource, optimizer, loss_fn, grad_fn, model):
    history = {'loss': [], 'accuracy': []}

    if valid_datasource is not None:
        history = {**history, 'valid_loss': [], 'valid_accuracy': []}

    train_num_batches = len(train_datasource)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')

        # TRAINING PHASE
        train_loss_accum, train_accuracy_accum, train_batch_size = 0, 0, 0
        num_steps = 0
        epoch_duration = 0.0

        for i, (X_train, y_train) in enumerate(train_datasource):

            # logging
            batch_start = time.time()
            num_steps += 1

            # training loss and gradients for this particular batch
            probs = model(X_train)
            loss = loss_fn(y_train, probs)
            grad = grad_fn(model, X_train, y_train)
            model = optimizer.step(model, grad)

            # Results aggregation
            num_correct = jnp.sum(jnp.argmax(probs, axis=-1) == y_train)
            train_loss_accum += loss
            train_batch_size += len(y_train)
            train_accuracy_accum += num_correct
            train_accuracy = train_accuracy_accum / train_batch_size
            train_loss = train_loss_accum / num_steps  # running average loss per batch

            # Logging ....
            batch_duration = time.time() - batch_start
            epoch_duration += batch_duration
            log_batch_count = f'{i}/{train_num_batches}'
            log_epoch_time = f'{int(epoch_duration)}s'
            log_batch_time = f'{1_000*batch_duration:.2f}ms/batch'
            log_batch_loss = f'loss: {train_loss:.4f}'
            log_batch_accuracy = f'accuracy: {train_accuracy:.4f}'
            log_string = f'{log_batch_count:<10s} {progress_bar((i+1)/train_num_batches)} - {log_epoch_time:<3s} {log_batch_time:<5s} - {log_batch_loss:<13s} - {log_batch_accuracy:<20s}'
            print(log_string, end='\r')

        history['loss'].append(train_loss)
        history['accuracy'].append(train_accuracy)

        # VALIDATION PHASE
        if valid_datasource is not None:
            valid_loss_accum, valid_accuracy_accum, valid_batch_size = 0, 0, 0

            for X_valid, y_valid in valid_datasource:
                probs = model(X_valid)
                # loss_fn returns a per-batch mean, so weight it by the batch size
                valid_loss_accum += loss_fn(y_valid, probs) * len(y_valid)
                valid_accuracy_accum += jnp.sum(jnp.argmax(probs, axis=-1) == y_valid)
                valid_batch_size += len(y_valid)

            epoch_valid_loss = valid_loss_accum / valid_batch_size
            epoch_valid_accuracy = valid_accuracy_accum / valid_batch_size

            history['valid_loss'].append(epoch_valid_loss)
            history['valid_accuracy'].append(epoch_valid_accuracy)
            log_string += f' - valid_loss: {epoch_valid_loss:.4f} - valid_accuracy: {epoch_valid_accuracy:.4f}'

        print(log_string, end='\n')
    # the updated model only lives in this function's scope; return it too if you need it later
    return history
```
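One inefficiency in the loop above is that every training batch runs the forward pass twice: once in `model(X_train)` for logging and again inside `grad_fn`. One way to avoid that is `jax.value_and_grad` with `has_aux=True`, which returns the loss, the gradients, and an auxiliary output (here the predicted probabilities) from a single call. The toy linear-softmax model, the params dict, and the names `loss_and_probs`, `train_step` and `LR` are assumptions of this sketch, not part of the code above.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # learning rate for this toy example

def loss_and_probs(params, X, y):
    # Toy model: one linear layer followed by softmax
    probs = jax.nn.softmax(X @ params['w'] + params['b'])
    # Sparse cross-entropy; the small constant keeps log() finite
    loss = -jnp.mean(jnp.log(probs[jnp.arange(len(y)), y] + 1e-7))
    return loss, probs  # (value to differentiate, auxiliary output)

@jax.jit
def train_step(params, X, y):
    # A single forward pass yields the loss, the probabilities and the gradients
    (loss, probs), grads = jax.value_and_grad(loss_and_probs, has_aux=True)(params, X, y)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    num_correct = jnp.sum(jnp.argmax(probs, axis=-1) == y)
    return params, loss, num_correct

X = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
y = jnp.array([0, 1, 2, 0, 1, 2, 0, 1])
params = {'w': jnp.zeros((4, 3)), 'b': jnp.zeros(3)}
initial_loss, _ = loss_and_probs(params, X, y)   # log(3) for uniform predictions

for _ in range(20):
    params, loss, num_correct = train_step(params, X, y)
print(float(initial_loss), float(loss), int(num_correct))
```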
    

## Execution

```python
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
train_dataset = Dataset(X_train, y_train)
valid_dataset = Dataset(X_test, y_test)
train_datasource = Dataloader(train_dataset, batchsize=32)
valid_datasource = Dataloader(valid_dataset, batchsize=32)

network = Sequential(
    rescale_image,
    flatten,
    Linear(28*28, 128),
    relu,
    Linear(128, 10),
    softmax
)

grad_fn = jax.jit(jax.grad(lambda model, X, y: sparse_cross_entropy(y, model(X))))

history = train(
    num_epochs=5,
    train_datasource=train_datasource,
    valid_datasource=valid_datasource,
    optimizer=Adam(network, lr=1e-3),
    loss_fn=sparse_cross_entropy,
    model=network,
    grad_fn=grad_fn
)
```



```
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
```


## Keras Execution

For comparison, here are the training results for Keras on the same dataset. While there are a few differences in the output, I'm pretty pleased with how well they match. My guess is that the remaining gap comes mostly from slightly different parameter initialization: `Linear` is currently limited to Kaiming (He) normal initialization, whereas Keras `Dense` layers use Glorot (Xavier) uniform initialization by default. Keras `model.fit` also shuffles the training data every epoch by default, while `Dataloader` does not.
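If you wanted to match Keras more closely, JAX provides both schemes in `jax.nn.initializers`. The sketch below draws a 784×128 weight matrix each way and compares the sample spread with the theoretical standard deviations. Passing such an initializer into `Linear` (for example as an extra constructor argument) would be a hypothetical extension of the class above. Note that JAX's `he_normal` samples from a truncated normal rescaled to the target variance, which differs slightly from the plain normal used in `Linear`.

```python
import jax
import jax.numpy as jnp

fan_in, fan_out = 784, 128
key_glorot, key_he = jax.random.split(jax.random.PRNGKey(0))

# Keras Dense default: Glorot (Xavier) uniform on [-limit, limit]
glorot = jax.nn.initializers.glorot_uniform()
w_glorot = glorot(key_glorot, (fan_in, fan_out), jnp.float32)
limit = jnp.sqrt(6.0 / (fan_in + fan_out))

# Kaiming (He): variance 2 / fan_in, the scaling used by Linear above
he = jax.nn.initializers.he_normal()
w_he = he(key_he, (fan_in, fan_out), jnp.float32)

print(bool(jnp.all(jnp.abs(w_glorot) <= limit + 1e-6)))                 # True: all draws within the bound
print(float(w_glorot.std()), float(jnp.sqrt(2.0 / (fan_in + fan_out))))  # sample vs. theoretical std
print(float(w_he.std()), float(jnp.sqrt(2.0 / fan_in)))                  # sample vs. theoretical std
```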

```python
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

model = tf.keras.Sequential([
    tf.keras.layers.Rescaling(1/255.0),
    tf.keras.layers.Flatten(input_shape=(28,28)),
    tf.keras.layers.Dense(128, activation=tf.keras.activations.relu),
    tf.keras.layers.Dense(10, activation=tf.keras.activations.softmax)
])

model.compile(
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

history = model.fit(X_train, y_train, epochs=5)
```

```
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
```


## Conclusion

Some natural next steps:

1. Implement convolutional layers and train a convolutional neural network
2. Add a callback system to simplify the training loop
3. Show validation results
4. Evaluate performance on the test set
