# Fashion MNIST using Linear Layers with JAX

- toc: true
- badges: true
- comments: true
- categories: [jupyter]
- image: images/chart-preview.png

## Introduction

In this article, I'm going to start implementing a JAX-based neural network library and use it to build a Fashion-MNIST classifier. Features will be added over the coming weeks and months as I tackle more advanced topics and architectures. I don't intend to write a full-featured framework. My goal is to engineer some components that I can mix and match (and easily customize) to solve problems that interest me. Except for basic linear algebra and automatic differentiation, I'll try to implement everything from scratch and use TensorFlow/Keras to validate my implementation. The reason I'm choosing to test against TensorFlow is that I have more experience working with it (plus it's very easy to throw something together in Keras). I'll be sure to give credit to any design ideas and techniques I use in the process of building out functionality.


## JAX 

JAX is an array-processing library that uses Google's XLA (Accelerated Linear Algebra) compiler to generate highly performant code that can run on a variety of hardware platforms. It feels a lot like NumPy, with a number of advantages including built-in automatic differentiation, parallelization, and just-in-time compilation. Its built-in array type is called a `DeviceArray` (newer JAX releases expose it as `jax.Array`). Unlike NumPy's `ndarray` type though, elements of a `DeviceArray` cannot be directly mutated.
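
To make the immutability point concrete, here is a minimal sketch of the functional update syntax that JAX arrays offer in place of item assignment. The variable names are only for illustration.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# In-place assignment such as `x[0] = 1.0` raises a TypeError on JAX arrays.
# The functional alternative returns a new array and leaves `x` unchanged.
y = x.at[0].set(1.0)

print(x)  # [0. 0. 0. 0. 0.]
print(y)  # [1. 0. 0. 0. 0.]
```

Because every update produces a new array, functions written this way have no hidden side effects, which is what lets JAX trace and compile them.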

The other fundamental datatype in JAX is the *pytree*, which (not surprisingly given the name) is a tree of Python objects. A lot of the magic and simplicity of JAX comes from working with pytrees. More on this later.
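
As a quick taste of what working with a pytree looks like, the sketch below builds a small nested container of arrays and runs the three basic operations from `jax.tree_util` on it. The `params` dictionary is a made-up example, not part of the library built in this post.

```python
import jax
import jax.numpy as jnp

# A pytree: nested dicts and lists whose leaves are arrays.
params = {
    'layer1': {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)},
    'layer2': [jnp.ones(3), jnp.zeros(1)],
}

# Flattening separates the leaves from the structure (the "treedef").
leaves, treedef = jax.tree_util.tree_flatten(params)
print(len(leaves))  # 4

# tree_map applies a function to every leaf and preserves the structure.
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, params)
print(doubled['layer1']['w'][0, 0])  # 2.0

# The original structure can be rebuilt from the leaves and the treedef.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```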


## The Goal

Here is an example of how you could go about building a basic Fashion MNIST classifier in Keras.  The API design is very elegant and easy to 
understand.  I even like the logging information provided by the call to `model.fit`, so much in fact that I'm going to replicate the style in
my training loop.

```python
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data() 

model = tf.keras.Sequential([
  tf.keras.layers.Rescaling(1/255),
  tf.keras.layers.Flatten(input_shape=(28,28)),
  tf.keras.layers.Dense(128, activation=tf.keras.activations.relu),
  tf.keras.layers.Dense(10, activation=tf.keras.activations.softmax)                          
])

model.compile(
    loss=tf.keras.losses.sparse_categorical_crossentropy, 
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

history = model.fit(X_train, y_train, epochs=5)
```

```
Epoch 1/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.4955 - accuracy: 0.8250
Epoch 2/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3726 - accuracy: 0.8643
Epoch 3/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3354 - accuracy: 0.8781
Epoch 4/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3136 - accuracy: 0.8845
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2937 - accuracy: 0.8927
```


Here's what I'll be able to do by the end of this post.  Okay, it's nowhere near as nice
as the Keras API at this point.  Part of that comes from the fact that Keras tries to hide
some of the details from you, while I'm definitely not.  

```python
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data() 
train_dataset = Dataset(X_train, y_train)
valid_dataset = Dataset(X_test, y_test)
train_datasource = Dataloader(train_dataset, batchsize=32)
valid_datasource = Dataloader(valid_dataset, batchsize=64)

network = Sequential(
    rescale_image,
    flatten,
    Linear(28*28, 128),
    relu,
    Linear(128, 10),
    softmax  
)

grad_fn = jax.jit(jax.grad(lambda model, X, y: fashion_mnist_loss(model(X), y)))

history = train(
    num_epochs=5, 
    train_datasource=train_datasource, 
    valid_datasource=None,
    optimizer=Adam(network, lr=1e-3), 
    loss_fn=fashion_mnist_loss, 
    model=network,
    grad_fn=grad_fn
)
```

```
Epoch 1/5
1874/1875  [===============================] - 3s  1.49ms/batch  -  loss: 0.4958   -  accuracy: 0.8245      
Epoch 2/5
1874/1875  [===============================] - 2s  1.40ms/batch  -  loss: 0.3725   -  accuracy: 0.8654    
Epoch 3/5
1874/1875  [===============================] - 2s  1.53ms/batch  -  loss: 0.3339   -  accuracy: 0.8778    
Epoch 4/5
1874/1875  [===============================] - 2s  1.55ms/batch  -  loss: 0.3080   -  accuracy: 0.8869    
Epoch 5/5
1874/1875  [===============================] - 2s  1.46ms/batch  -  loss: 0.2900   -  accuracy: 0.8934
```


## Let's Start

In [433]:
import jax 
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import time 

from typing import Tuple, List, Any, Dict, Callable


## Loading the Data

I've looked at enough PyTorch and fast.ai code to realize that the `Dataset` and `Dataloader` approach to batching data is really convenient to work with. In PyTorch, a `Dataset` is much more general than what I have here. In fact, because the data returned by `tf.keras.datasets.fashion_mnist.load_data()` is just a few NumPy arrays, the `Dataset` abstraction is pretty useless right now.

In [434]:
class Dataset:
    def __init__(self, X, y):
        self.X, self.y = X, y
    def __len__(self):
        return jnp.shape(self.X)[0]
    def __getitem__(self, i):
        return self.X[i,:], self.y[i]

The `Dataloader` is responsible for slicing the dataset it is given into a sequence of batches that can be used for training, validating, testing, or anything else you can think of.

In [573]:
class Dataloader:
    def __init__(self, dataset: Dataset, batchsize=32, shuffle=False):
        self.dataset = dataset
        self.batchsize = batchsize
        self.shuffle = shuffle  # stored for later use; batches are not shuffled yet
    def __iter__(self):
        for i in range(0, len(self.dataset), self.batchsize): 
            yield self.dataset[i:i+self.batchsize]
    def __len__(self):
        # round up so that a final, partially filled batch is counted too
        return (len(self.dataset) + self.batchsize - 1) // self.batchsize
        

By design, `Dataloader`s are lazy, meaning that batches are only constructed and returned when you ask for them.
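
Here is a small self-contained sketch of that laziness, using synthetic data instead of Fashion MNIST. `ArrayDataset` and `LazyLoader` are hypothetical stand-ins written for this example, and the shuffling shown is one possible approach (permuting indices once per pass), not something the classes in this post do.

```python
import numpy as np

class ArrayDataset:
    """Hypothetical stand-in for a dataset backed by NumPy arrays."""
    def __init__(self, X, y):
        self.X, self.y = X, y
    def __len__(self):
        return len(self.X)
    def __getitem__(self, i):
        return self.X[i], self.y[i]

class LazyLoader:
    """Hypothetical loader: builds one batch per iteration step."""
    def __init__(self, dataset, batchsize=32, shuffle=False, seed=0):
        self.dataset, self.batchsize, self.shuffle = dataset, batchsize, shuffle
        self.rng = np.random.default_rng(seed)
    def __iter__(self):
        n = len(self.dataset)
        # assumption: shuffle by drawing a fresh index permutation per pass
        order = self.rng.permutation(n) if self.shuffle else np.arange(n)
        for start in range(0, n, self.batchsize):
            yield self.dataset[order[start:start + self.batchsize]]
    def __len__(self):
        return -(-len(self.dataset) // self.batchsize)  # ceiling division

X = np.zeros((100, 28, 28), dtype=np.uint8)
y = np.arange(100) % 10
loader = LazyLoader(ArrayDataset(X, y), batchsize=32, shuffle=True)

batches = iter(loader)            # nothing has been sliced yet
X_batch, y_batch = next(batches)  # the first batch is built here
print(X_batch.shape, y_batch.shape)  # (32, 28, 28) (32,)
print(len(loader))                   # 4
```

Since `__iter__` is a generator, each batch is sliced only when the loop asks for it, so memory use stays at one batch beyond the underlying arrays.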

In [575]:
(X_train, y_train), _ = tf.keras.datasets.fashion_mnist.load_data()
dataset = Dataset(X_train, y_train)
dataloader = Dataloader(dataset)


## The Model

In this section, I'll describe each component of the `Sequential` model shown below:

```python
model = Sequential(
    rescale_image,
    flatten,
    Linear(784, 128),
    relu,
    Linear(128, 10),
    softmax  
)
```

### Linear Layers

Like the `Treex` and `Equinox` neural network libraries, I interpret each layer as a parametric function. A parametric function is just like an ordinary function, except that its behavior is governed by a set of parameters. In the case of a linear layer in a neural network, these parameters correspond to the weights and biases that are updated in the training loop.

In [439]:
class Parametric: pass

In [440]:
@jax.tree_util.register_pytree_node_class
class Linear(Parametric):
    w: jnp.ndarray 
    b: jnp.ndarray
    ni: int 
    no: int 

    def __init__(self, num_inputs, num_outputs, build=True, seed=1234):
        self.ni = num_inputs 
        self.no = num_outputs 
        # want to add seed as internal object
        if build:
            key = jax.random.PRNGKey(seed)
            self.w = jax.random.normal(key, (num_inputs, num_outputs)) * jnp.sqrt(2.0 / num_inputs)
            self.b = jnp.zeros(num_outputs)

    def __repr__(self):
        return f'Linear(num_inputs={self.ni}, num_outputs={self.no})'
        
    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b
        
    def tree_flatten(self):
        return (self.w, self.b), (self.ni, self.no)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layer = cls(*aux_data, build=False)
        layer.w, layer.b = children
        return layer
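
To show what pytree registration buys, here is a minimal sketch with a hypothetical one-parameter layer called `Scale` (it is not part of the library above). Because the layer is a pytree, `jax.grad` can differentiate with respect to the layer object itself and returns gradients packaged in the same class.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Scale:
    """Hypothetical one-parameter layer: f(x) = a * x."""
    def __init__(self, a):
        self.a = a
    def __call__(self, x):
        return self.a * x
    def tree_flatten(self):
        # (children that JAX differentiates, static auxiliary data)
        return (self.a,), None
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

def loss(layer, x, y):
    return jnp.mean((layer(x) - y) ** 2)

layer = Scale(jnp.array(1.0))
x = jnp.array([1.0, 2.0, 3.0])
y = 2.0 * x

# The gradient has the same type and structure as the layer itself.
grads = jax.grad(loss)(layer, x, y)
print(type(grads).__name__)  # Scale
print(grads.a)               # d/da mean((a*x - 2x)^2) evaluated at a = 1
```

The same mechanism is what allows an optimizer to update a whole model with a single `tree_map` over the model and its gradients.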

### Functional Layers



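A `Function` wraps a plain callable that has no trainable parameters, such as `relu` or `flatten`, so that it can sit in a model next to parametric layers. The wrapped function is stored as auxiliary data rather than as a child, so it contributes no leaves to the pytree.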

In [441]:
@jax.tree_util.register_pytree_node_class
class Function:
    def __init__(self, fn):
        self.fn = fn 
        
    def __call__(self, x):
        return self.fn(x)
    
    def __repr__(self):
        return f'Function({self.fn.__name__})'
    
    def tree_flatten(self):
        return [None], self.fn

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)
    

### Flatten

In [442]:
def flatten(x):
    shape = jnp.shape(x)
    if len(shape) == 2:
        # flatten a single 2D image
        return jnp.reshape(x, -1)
    elif len(shape) == 3:
        # x is a batch of 2D images, flatten each image 
        batch_size = jnp.shape(x)[0]
        return jnp.reshape(x, (batch_size, -1)) 
    else:
        raise Exception(
            f'At the moment you can only pass 2D or 3D arrays to flatten, you passed a {len(shape)}D array' 
        )

### ReLU

In [443]:
def relu(x): return jnp.maximum(x, 0)

In [444]:
x = np.random.randn(10,10)
assert np.all(np.isclose(relu(x), tf.nn.relu(x))), 'test failed'

### Softmax

In [445]:
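def softmax(x):
    # subtract the row maximum before exponentiating to avoid overflow
    ex = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
    return ex / jnp.sum(ex, axis=-1, keepdims=True)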

In [446]:
x = np.random.randn(5, 10)
assert np.all(np.isclose(softmax(x), tf.nn.softmax(x))), 'test failed'
assert np.isclose(jnp.sum(softmax(x)), jnp.shape(x)[0])
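
One caveat when exponentiating raw logits is overflow for large inputs. A common remedy is to subtract the row maximum first, which leaves the result mathematically unchanged. The sketch below compares the two variants; `naive_softmax` and `stable_softmax` are names chosen just for this illustration.

```python
import jax.numpy as jnp

def naive_softmax(x):
    ex = jnp.exp(x)
    return ex / jnp.sum(ex, axis=-1, keepdims=True)

def stable_softmax(x):
    # Softmax is unchanged by subtracting a constant from every logit in a row,
    # so shifting by the row maximum keeps the largest exponent at exp(0) = 1.
    shifted = x - jnp.max(x, axis=-1, keepdims=True)
    ex = jnp.exp(shifted)
    return ex / jnp.sum(ex, axis=-1, keepdims=True)

logits = jnp.array([[1000.0, 1001.0, 1002.0]])
print(naive_softmax(logits))   # exp overflows in float32, giving nan
print(stable_softmax(logits))  # approximately [[0.09 0.245 0.665]]
```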

### Rescaling

In [447]:
callable(Linear(1,1))

True

In [448]:
def rescale_image(x): return x / 255.0

## Sequential

A `Sequential` model is a list of function-like objects. Looking at the `__init__` method below, each item in the list of layers must be an instance of `Parametric` or `Function`, or be `callable`. Note that the order of these checks is important because any object that implements a `__call__` method is by definition `callable`. The other methods defined in the `Sequential` class are pretty straightforward and consist of simply looping over the list of layers and calling that particular method (and possibly appending results).
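
Stripped of the pytree bookkeeping, calling a sequential model is just left-to-right function composition. The sketch below shows that idea in isolation; `compose` is a hypothetical helper, not part of the library.

```python
from functools import reduce

def compose(*layers):
    """Hypothetical helper: chain callables left to right."""
    def forward(x):
        # feed the output of each layer into the next one
        return reduce(lambda value, layer: layer(value), layers, x)
    return forward

double = lambda x: 2 * x
add_one = lambda x: x + 1

pipeline = compose(double, add_one)  # computes add_one(double(x))
print(pipeline(3))  # 7
```

Reversing the order of the layers changes the result, which is why the layers are kept in a list rather than a set or dictionary.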

In [449]:
@jax.tree_util.register_pytree_node_class
class Sequential(Parametric):

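    def __init__(self, *layers):
        self.layers = []
        for layer in layers:
            if isinstance(layer, (Parametric, Function)):
                self.layers.append(layer)
            elif callable(layer):
                # wrap plain functions so that every layer is a pytree node
                self.layers.append(Function(layer))
            else:
                raise TypeError(f'Cannot use {layer!r} as a layer: it is not callable')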
    
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    
    def __repr__(self):
        string = ''
        for layer in self.layers:
            string += (repr(layer) + '\n')
        return string
    
    def tree_flatten(self):
        # flatten each layer separately: its leaves become children,
        # its tree definition becomes auxiliary data
        aux_data, children = [], []
        for layer in self.layers:
            params, extra_stuff = jax.tree_util.tree_flatten(layer)
            aux_data.append(extra_stuff)
            children.append(params)
        # a tuple keeps the auxiliary data hashable
        return children, tuple(aux_data)
    
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layers = []
        for params, spec in zip(children, aux_data):
            layers.append(jax.tree_util.tree_unflatten(spec, params))
        return cls(*layers)

In [450]:
def fashion_mnist_mlp():
    return Sequential(
        rescale_image,
        flatten,
        Linear(784, 128),
        relu,
        Linear(128, 10),
        softmax  
    )

model = fashion_mnist_mlp()
print(model)


Function(rescale_image)
Function(flatten)
Linear(num_inputs=784, num_outputs=128)
Function(relu)
Linear(num_inputs=128, num_outputs=10)
Function(softmax)



## The Training Loss

In [430]:
def cross_entropy_loss(y_true, probs):
    batch_size, _ = jnp.shape(probs)
    return -jnp.sum(jnp.log(probs + 1.0e-16) * y_true) / batch_size

In [431]:
y_true = np.array([[0, 1, 0], [0, 0, 1]])
y_pred = np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])
keras_cross_entropy = tf.keras.losses.CategoricalCrossentropy()
assert np.all(np.isclose(cross_entropy_loss(y_true, y_pred), keras_cross_entropy(y_true, y_pred))), 'Not close'
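
To see what the number being compared actually is, the arithmetic for the two-example batch above can be worked by hand: only the probability assigned to the true class contributes to each row, and the loss is the mean of the negative logs. This is a plain NumPy sketch that leaves out the small epsilon used in `cross_entropy_loss`.

```python
import numpy as np

y_true = np.array([[0, 1, 0], [0, 0, 1]])
probs = np.array([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])

# Pick out the probability assigned to the true class in each row.
p_true = np.sum(probs * y_true, axis=-1)  # [0.95, 0.1]
per_example = -np.log(p_true)             # [-log(0.95), -log(0.1)]
loss = per_example.mean()
print(round(float(loss), 4))  # 1.1769
```

The second example dominates the average because the model put only 0.1 on the correct class.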

## The Optimizers


In [77]:
class Optimizer: pass

In [78]:
class SGD(Optimizer):
    def __init__(self, model, lr=1e-3):
        self.lr = lr 
    def step(self, model, grads):
        return jax.tree_util.tree_map(lambda p, g: p - self.lr*g, model, grads)

In [79]:
class Adam(Optimizer):
    def __init__(self, model, lr=1e-3, v_decay=0.9, s_decay=0.999, eps=1e-7):
        self.lr, self.v_decay, self.s_decay, self.eps = lr, v_decay, s_decay, eps
        # first (v) and second (s) moment estimates, one per model parameter
        self.v = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), model) 
        self.s = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), model)
        self.k = 0 
    def step(self, model, grads):
        lr, v_decay, s_decay, eps = self.lr, self.v_decay, self.s_decay, self.eps
        v, s = self.v, self.s
        k = self.k = self.k+1
        # exponential moving averages of the gradient and the squared gradient
        self.v = jax.tree_util.tree_map(lambda v, g: v_decay*v +(1-v_decay)*g, v, grads)
        self.s = jax.tree_util.tree_map(lambda s, g: s_decay*s +(1-s_decay)*g*g, s, grads)
        # bias correction for the zero initialization
        v_hat = jax.tree_util.tree_map(lambda v: v / (1-v_decay**k), self.v)
        s_hat = jax.tree_util.tree_map(lambda s: s / (1-s_decay**k), self.s)
        new_model = jax.tree_util.tree_map(lambda params, v_hat, s_hat: params - (lr*v_hat)/(jnp.sqrt(s_hat) + eps), model, v_hat, s_hat)
        return new_model
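
A useful sanity check on the Adam update is its very first step: after bias correction the moment estimates reduce to the gradient and its square, so every parameter moves by roughly the learning rate in the direction opposite to the sign of its gradient. The sketch below works this through on a made-up gradient vector using the same default hyperparameters as above.

```python
import jax.numpy as jnp

lr, v_decay, s_decay, eps = 1e-3, 0.9, 0.999, 1e-7
g = jnp.array([0.5, -2.0, 10.0])  # example gradient values (made up)
p = jnp.zeros(3)                  # parameters before the step

k = 1
v = (1 - v_decay) * g             # first-moment estimate, starting from 0
s = (1 - s_decay) * g * g         # second-moment estimate, starting from 0
v_hat = v / (1 - v_decay ** k)    # bias correction recovers g
s_hat = s / (1 - s_decay ** k)    # bias correction recovers g * g

p_new = p - lr * v_hat / (jnp.sqrt(s_hat) + eps)
print(p_new)  # close to -lr * sign(g)
```

Note how the gradient magnitudes (0.5 versus 10.0) have no effect on the size of this first step.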


## The Logging

In [492]:
def progress_bar(percentage, total=30):
    x = int(percentage*total)
    if x < total:
        r = '[' + '=' * x + '>' + '.' * (total - x) + ']'
    else:
        r = '[' + '=' * (total + 1) + ']'
    return r

## The Training Loop  

In [542]:
def train(num_epochs, train_datasource, valid_datasource, optimizer, loss_fn, grad_fn, model):
    history = {'loss':[], 'accuracy':[]}
    
    if valid_datasource is not None:
        history = {**history, 'valid_loss': [], 'valid_accuracy': []}

    train_num_batches = len(train_datasource)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')

        # TRAINING PHASE
        train_loss_accum, train_accuracy_accum, train_batch_size = 0, 0, 0
        
        num_steps = 0 
        
        epoch_duration = 0.0
        
        # we know how many batches there are ... keep track
        #train_loss_per_batch = []
        #train_accuracy_per_batch = []
        for i, (X_train, y_train) in enumerate(train_datasource):
            
            # logging
            batch_start = time.time()

            num_steps += 1
            
            # training loss and gradients for this particular batch
            probs = model(X_train)
            loss = loss_fn(probs, y_train)
            
            grads = grad_fn(model, X_train, y_train)
            model = optimizer.step(model, grads)
            
            
            # Results aggregation
            num_correct = jnp.sum(jnp.argmax(probs, axis=-1) == y_train)
            train_loss_accum += loss 
            train_batch_size += len(y_train)
            train_accuracy_accum += num_correct
            train_accuracy = train_accuracy_accum / train_batch_size
            train_loss = train_loss_accum / (i + 1) # running average loss per batch seen so far

            # Logging ....
            batch_duration = time.time() - batch_start
            epoch_duration += batch_duration 
            log_batch_count = f'{i}/{train_num_batches}'
            log_epoch_time = f'{int(epoch_duration)}s'
            log_batch_time = f'{1_000*batch_duration:.2f}ms/batch'
            log_batch_loss = f'loss: {train_loss:.4f}'
            log_batch_accuracy = f'accuracy: {train_accuracy:.4f}'
            log_string =  f'{log_batch_count:<10s} {progress_bar((i+1)/train_num_batches)} - {log_epoch_time:<3s} {log_batch_time:<5s} - {log_batch_loss:<13s} - {log_batch_accuracy:<20s}'
            print(log_string, end='\r') 

        # record the epoch-level training metrics
        history['loss'].append(train_loss)
        history['accuracy'].append(train_accuracy)      

        # VALIDATION PHASE
        if valid_datasource is not None:
            valid_loss_accum, valid_accuracy_accum, valid_batch_size = 0, 0, 0 

            # Run validation step ...
            for i, (X_valid, y_valid) in enumerate(valid_datasource):
                num_steps += 1
                probs = model(X_valid)
                loss = loss_fn(probs, y_valid)
                
                valid_accuracy_accum += jnp.sum(jnp.argmax(probs, axis=-1) == y_valid)

                valid_loss_accum += loss
                valid_batch_size += len(y_valid)

            # each accumulated loss is already a batch average, so divide by the number of batches
            epoch_valid_loss = valid_loss_accum / (i + 1)
            epoch_valid_accuracy = valid_accuracy_accum / valid_batch_size

            history['valid_loss'].append(epoch_valid_loss)
            history['valid_accuracy'].append(epoch_valid_accuracy)
        
        # this log_string should include validation results
        print(log_string, end='\n')
    return history
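
The loop above runs the forward pass twice per batch: once to get `probs` and once inside `grad_fn`. One way this might be avoided is `jax.value_and_grad` with `has_aux=True`, which returns the loss, an auxiliary output, and the gradients from a single call. The sketch below uses a toy softmax-regression model stored in a dictionary; `loss_and_probs`, `step_fn` and the random data are assumptions made for the example, not the model from this post.

```python
import jax
import jax.numpy as jnp

def loss_and_probs(params, X, y):
    """Toy softmax-regression loss that also returns the predictions."""
    logits = X @ params['w'] + params['b']
    probs = jax.nn.softmax(logits, axis=-1)
    y_one_hot = jax.nn.one_hot(y, probs.shape[-1])
    loss = -jnp.sum(jnp.log(probs + 1e-16) * y_one_hot) / len(y)
    return loss, probs  # (value to differentiate, auxiliary output)

# has_aux=True tells JAX that only the first output is differentiated.
step_fn = jax.jit(jax.value_and_grad(loss_and_probs, has_aux=True))

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (8, 4))
y = jnp.array([0, 1, 2, 0, 1, 2, 0, 1])
params = {'w': jnp.zeros((4, 3)), 'b': jnp.zeros(3)}

losses = []
for _ in range(20):
    (loss, probs), grads = step_fn(params, X, y)  # one forward pass per step
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    losses.append(float(loss))

print(losses[0] > losses[-1])  # True: gradient descent reduces the loss
```

The returned `probs` can then feed the accuracy computation directly, without calling the model a second time.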
    

In [539]:
@jax.jit
def fashion_mnist_loss(probs, y_true, num_classes=10):
    # average cross entropy, batch
    y_one_hot = jax.nn.one_hot(y_true, num_classes)
    return -jnp.sum(jnp.log(probs) * y_one_hot) / len(y_true)
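
The one-hot trick in this loss is just a way of selecting, for each example, the probability the model assigned to the true class. The sketch below checks that on made-up numbers by comparing it with direct integer indexing.

```python
import jax
import jax.numpy as jnp

labels = jnp.array([0, 2, 1])         # integer class labels
probs = jnp.array([[0.7, 0.2, 0.1],
                   [0.1, 0.1, 0.8],
                   [0.3, 0.4, 0.3]])  # made-up predictions

one_hot = jax.nn.one_hot(labels, 3)
via_one_hot = jnp.sum(probs * one_hot, axis=-1)  # mask, then sum each row
via_index = probs[jnp.arange(3), labels]         # pick entries directly
print(via_one_hot)  # [0.7 0.8 0.4]
```

Both routes give the same values, so the loss is the mean of the negative logs of these selected probabilities.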

In [541]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data() 
train_dataset = Dataset(X_train, y_train)
valid_dataset = Dataset(X_test, y_test)
train_datasource = Dataloader(train_dataset, batchsize=32)
valid_datasource = Dataloader(valid_dataset, batchsize=64)

network = Sequential(
    rescale_image,
    flatten,
    Linear(28*28, 128),
    relu,
    Linear(128, 10),
    softmax  
)

grad_fn = jax.jit(jax.grad(lambda model, X, y: fashion_mnist_loss(model(X), y)))

history = train(
    num_epochs=5, 
    train_datasource=train_datasource, 
    valid_datasource=None,
    optimizer=Adam(network, lr=1e-3), 
    loss_fn=fashion_mnist_loss, 
    model=network,
    grad_fn=grad_fn
)



Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


## Keras Run

In [463]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data() 


model = tf.keras.Sequential([
    tf.keras.layers.Rescaling(1/255.0),
    tf.keras.layers.Flatten(input_shape=(28,28)),
    tf.keras.layers.Dense(128, activation=tf.keras.activations.relu),
    tf.keras.layers.Dense(10, activation=tf.keras.activations.softmax)                          
])

model.compile(
    loss=tf.keras.losses.sparse_categorical_crossentropy, 
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

history = model.fit(X_train, y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


## Next Time

1. Implement convolutional layers and train a convolutional neural network
2. Add a callback system to simplify the training loop
3. Show validation results
4. Evaluate performance on test set 