# Fashion MNIST using Linear Layers with JAX

- toc: true
- badges: true
- comments: true
- categories: [jupyter]
- image: images/chart-preview.png
- hide: true



## Introduction

In this article, I'm going to start implementing a JAX-based neural network library and use it to build a Fashion-MNIST classifier.  Features will be added over the coming weeks and months as I try to tackle more advanced topics.  I don't intend to write a full-featured library/framework like TensorFlow or PyTorch.  Instead, my goal is to compile a set of loosely coupled components that I can mix and match to solve problems that interest me.  When there's some educational benefit, functionality will be written from scratch and validated against other well-known libraries (including JAX itself).  

```python

# Keras baseline: the same 784-128-10 MLP that is rebuilt with JAX below
import tensorflow as tf
from tensorflow import keras

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# scale pixel values from [0, 255] to [0, 1]
X_train, X_test = X_train / 255.0, X_test / 255.0

model = keras.Sequential([
  keras.layers.Flatten(input_shape=(28,28)),
  keras.layers.Dense(128, activation=keras.activations.relu),
  keras.layers.Dense(10, activation=keras.activations.softmax)
])

model.compile(
    loss=keras.losses.sparse_categorical_crossentropy,
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'])

model.fit(X_train, y_train, epochs=5)
```

```
Epoch 1/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.4955 - accuracy: 0.8250
Epoch 2/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3726 - accuracy: 0.8643
Epoch 3/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3354 - accuracy: 0.8781
Epoch 4/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3136 - accuracy: 0.8845
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2937 - accuracy: 0.8927
```

## What is JAX 

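I've only been learning about JAX for a couple of weeks and still have a lot to learn.  From what I've gathered so far, JAX is a high-performance numerical computing library with a NumPy-like API (`jax.numpy`).  It compiles code with Google's XLA (Accelerated Linear Algebra) compiler and provides composable function transformations such as `jax.grad` for automatic differentiation and `jax.jit` for just-in-time compilation.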
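
A minimal sketch of what that looks like in practice: `jax.grad` differentiates a plain Python function and `jax.jit` compiles the result with XLA. The function `sum_of_squares` is a made-up example, not part of the library built in this article.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    # ordinary NumPy-style code, written against jax.numpy
    return jnp.sum(x ** 2)

# grad builds a function that returns d(sum_of_squares)/dx;
# jit compiles that function with XLA the first time it is called
fast_grad = jax.jit(jax.grad(sum_of_squares))

x = jnp.array([1.0, 2.0, 3.0])
g = fast_grad(x)  # analytically, the gradient is 2 * x
print(g)
```

The two transformations compose in either order; here the gradient function is built first and then compiled.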

## The Fashion MNIST Dataset



In [5]:
import jax 
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import time 

from typing import Tuple, List, Any, Dict, Callable


## Data API

In [6]:
class Dataset:
    def __init__(self, X, y):
        self.X, self.y = X, y
    def __len__(self):
        return jnp.shape(self.X)[0]
    def __getitem__(self, i):
        return self.X[i,:], self.y[i]

In [7]:
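class Dataloader:
    def __init__(self, dataset: Dataset, batchsize=32, shuffle=False):
        self.dataset = dataset
        self.batchsize = batchsize
        self.shuffle = shuffle
    def __iter__(self):
        n = len(self.dataset)
        # shuffle=True visits the samples in a fresh random order on every pass
        order = np.random.permutation(n) if self.shuffle else np.arange(n)
        for i in range(0, n, self.batchsize): 
            yield self.dataset[order[i:i+self.batchsize]]
    def __len__(self):
        # number of batches, counting a final partial batch (ceiling division)
        return -(-len(self.dataset) // self.batchsize)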
        

In [8]:
fashion_mnist = tf.keras.datasets.fashion_mnist
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0

In [9]:
fashion_mnist.load_data??

Signature: fashion_mnist.load_data()
Source:   
@keras_export('keras.datasets.fashion_mnist.load_data')
def load_data():
  """Loads the Fashion-MNIST dataset.

  This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories,
  along with a test set of 10,000 images. This dataset can be used as
  a drop-in replacement for MNIST.

  The classes are:

  | Label | Description |
  |:-----:|-------------|
  |   0   | T-shirt/top |
  |   1   | Trouser     |
  |   2   | Pullover    |
  |   3   | Dress       |
  |   4   | Coat        |
  |   5   | Sandal      |

In [10]:
dataset = Dataset(X_train, y_train)

In [11]:
dataloader = Dataloader(dataset)

In [12]:
for X, y in dataloader:
    print(X.shape, y.shape)
    

(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
...

## Modules

In [13]:
class Module: pass


### Linear Layer

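For JAX's gradient and tree utilities to see a layer's weights, the `Linear` class must be registered as a pytree.  `tree_flatten` returns the arrays (`w`, `b`) as children and the layer sizes as static auxiliary data, and `tree_unflatten` rebuilds a layer from those two pieces.  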

In [14]:
@jax.tree_util.register_pytree_node_class
class Linear(Module):
    w: jnp.ndarray 
    b: jnp.ndarray
    ni: int 
    no: int 

    def __init__(self, num_inputs, num_outputs, build=True, seed=1234):
        self.ni = num_inputs 
        self.no = num_outputs 
        # want to add seed as internal object
        if build:
            key = jax.random.PRNGKey(seed)
            self.w = jax.random.normal(key, (num_inputs, num_outputs)) * jnp.sqrt(2.0 / num_inputs)
            self.b = jnp.zeros(num_outputs)
    
    def merge(self, params):
        self.w, self.b = params

    def __repr__(self):
        return f'Linear(num_inputs={self.ni}, num_outputs={self.no})'
        
    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b
        
    def params(self):
        return {'b': self.b, 'w': self.w}

    def tree_flatten(self):
        return [self.w, self.b], [self.ni, self.no]

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layer = cls(*aux_data, build=False)
        layer.merge(params=children)
        return layer
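
The registration pattern can be seen in isolation with a much smaller class. This is only a sketch: `Scale` is a hypothetical one-parameter layer invented for the illustration, and it is not used anywhere else in this article.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Scale:
    """Hypothetical layer that multiplies its input by a learnable factor."""
    def __init__(self, factor, name):
        self.factor = factor  # array leaf, visible to grad and tree_map
        self.name = name      # static metadata, carried as aux_data

    def __call__(self, x):
        return self.factor * x

    def tree_flatten(self):
        # (children, aux_data): children are traversed, aux_data is kept as-is
        return (self.factor,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)

layer = Scale(jnp.array(2.0), "scale")

# the only leaf JAX finds is the factor
leaves = jax.tree_util.tree_leaves(layer)

# tree_map rebuilds a Scale whose leaf has been transformed
doubled = jax.tree_util.tree_map(lambda p: 2 * p, layer)

# differentiating with respect to the layer returns another Scale holding the gradient
grads = jax.grad(lambda m, x: jnp.sum(m(x)))(layer, jnp.ones(3))
print(leaves, doubled.factor, grads.factor)
```

Because the gradient comes back in the same structure as the layer, an optimizer can pair every parameter with its gradient using a single `tree_map` call.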

### Layer Functions

In [15]:
def flatten(x: jnp.ndarray): 
    shape = jnp.shape(x)
    new_shape = -1 if len(shape) == 1 else (shape[0], -1) 
    return jnp.reshape(x, new_shape) 


In [16]:
def relu(x: jnp.ndarray): 
    # element-wise max(x, 0); avoids the version-dependent keyword names of jnp.clip
    return jnp.maximum(x, 0)
   

In [17]:
x = np.random.randn(10,10)
assert jnp.all(jnp.isclose(relu(x), jax.nn.relu(x))), 'test failed'

In [18]:
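def softmax(x: jnp.ndarray):
    # subtract the row-wise max before exponentiating so large logits don't overflow
    ex = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
    return ex / jnp.sum(ex, axis=-1, keepdims=True) 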

In [19]:
x = np.random.randn(5, 10)
assert jnp.all(jnp.isclose(softmax(x), jax.nn.softmax(x))), 'test failed'
assert jnp.isclose(jnp.sum(softmax(x)), jnp.shape(x)[0])

In [20]:
_registry = {
    flatten.__name__: flatten,
    softmax.__name__: softmax,
    relu.__name__: relu
}

### Sequential Module

In [21]:
@jax.tree_util.register_pytree_node_class
class Sequential(Module):
    layers: List
    def __init__(self, *layers):
        self.layers = layers
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    def tree_flatten(self):
        aux_data, children = [], []
        for layer in self.layers:
            if isinstance(layer, Module):
                params, extra_stuff = layer.tree_flatten()
                aux_data.append([layer.__class__.__name__] + extra_stuff)
                children.append(params) 
            elif callable(layer):
                # a layer function that doesn't have any parameters ...
                aux_data.append(layer.__name__)
                children.append(None)    
        return children, aux_data
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layers = []
        # Want a more generic way to unflatten
        for params, spec in zip(children, aux_data):
            if isinstance(spec, list):
                layer_name, *args = spec
                if layer_name == 'Linear':
                    layers.append(Linear.tree_unflatten(args, params))   
            elif isinstance(spec, str) and spec in _registry:
                layers.append(_registry[spec])
        return Sequential(*layers)
    

In [22]:

def fashion_mnist_mlp():
    model = Sequential(
        flatten,
        Linear(784, 128),
        relu,
        Linear(128, 10),
        softmax   
    )

    return model


model = fashion_mnist_mlp()
params, extra_stuff = model.tree_flatten()
print(extra_stuff)
print(params)

['flatten', ['Linear', 784, 128], 'relu', ['Linear', 128, 10], 'softmax']
[None, [DeviceArray([[-0.00503162, -0.11710759,  0.05479915, ..., -0.07662067,
              -0.03762808,  0.037621  ],
             [-0.02311066,  0.00427538,  0.06703123, ...,  0.05820996,
              -0.03371886, -0.0653995 ],
             [-0.03936624,  0.08184296, -0.00103856, ..., -0.02543773,
               0.00404367,  0.10533019],
             ...,
             [-0.05674443,  0.01220774, -0.04277196, ...,  0.00793091,
              -0.03246848,  0.05214054],
             [-0.10229313, -0.04473471, -0.05902693, ..., -0.026743  ,
               0.01399903, -0.02305236],
             [ 0.02624378, -0.040582  ,  0.04346804, ..., -0.0069246 ,
               0.04329436,  0.07048796]], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

## Cross-entropy Loss

## Optimizers

In [23]:
class Optimizer: pass 

In [24]:
class SGD(Optimizer):
    def __init__(self, model, lr=1e-3):
        self.lr = lr 
    def step(self, model, grads):
        return jax.tree_util.tree_map(lambda p, g: p - self.lr*g, model, grads)

In [25]:
class Adam(Optimizer):
    def __init__(self, model, lr=1e-3, v_decay=0.9, s_decay=0.999, eps=1e-7):
        self.lr, self.v_decay, self.s_decay, self.eps = lr, v_decay, s_decay, eps
        # first (v) and second (s) moment estimates, one per parameter leaf
        self.v = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), model) 
        self.s = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), model)
        self.k = 0 
    def step(self, model, grads):
        lr, v_decay, s_decay, eps = self.lr, self.v_decay, self.s_decay, self.eps
        k = self.k = self.k+1
        # exponential moving averages of the gradient and the squared gradient
        self.v = jax.tree_util.tree_map(lambda v, g: v_decay*v +(1-v_decay)*g, self.v, grads)
        self.s = jax.tree_util.tree_map(lambda s, g: s_decay*s +(1-s_decay)*g*g, self.s, grads)
        # bias correction for the zero initialisation
        v_hat = jax.tree_util.tree_map(lambda v: v / (1-v_decay**k), self.v)
        s_hat = jax.tree_util.tree_map(lambda s: s / (1-s_decay**k), self.s)
        new_model = jax.tree_util.tree_map(lambda params, v_hat, s_hat: params - (lr*v_hat)/(jnp.sqrt(s_hat) + eps), model, v_hat, s_hat)
        return new_model
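
To sanity-check the update rule, here is a stateless sketch of the same arithmetic applied to a small dictionary of parameters. The function `adam_step` and the toy `params`/`grads` values are assumptions made up for this check; they are not part of the `Adam` class above.

```python
import jax
import jax.numpy as jnp

def adam_step(params, grads, v, s, k, lr=1e-3, v_decay=0.9, s_decay=0.999, eps=1e-7):
    """One Adam update; returns the new parameters and the new moment estimates."""
    tree_map = jax.tree_util.tree_map
    v = tree_map(lambda v_, g: v_decay * v_ + (1 - v_decay) * g, v, grads)
    s = tree_map(lambda s_, g: s_decay * s_ + (1 - s_decay) * g * g, s, grads)

    def update(p, v_, s_):
        v_hat = v_ / (1 - v_decay ** k)  # bias-corrected first moment
        s_hat = s_ / (1 - s_decay ** k)  # bias-corrected second moment
        return p - lr * v_hat / (jnp.sqrt(s_hat) + eps)

    return tree_map(update, params, v, s), v, s

# toy values, chosen only for illustration
params = {'w': jnp.array([1.0, -2.0]), 'b': jnp.array(0.5)}
grads = {'w': jnp.array([0.3, -4.0]), 'b': jnp.array(10.0)}
zeros = jax.tree_util.tree_map(jnp.zeros_like, params)

new_params, v, s = adam_step(params, grads, zeros, zeros, k=1)
print(new_params)
```

On the very first step the bias correction makes `v_hat` equal to the gradient and `s_hat` equal to its square, so every parameter moves by roughly `lr` in the direction opposite to its gradient's sign, regardless of the gradient's magnitude.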


## Training Loop  


In [26]:
def train(num_epochs, train_datasource, valid_datasource, optimizer, loss_fn, grad_fn, model):
    history = {'loss':[], 'accuracy':[]}
    
    if valid_datasource is not None:
        history = {**history, 'valid_loss': [], 'valid_accuracy': []}

    train_num_batches = len(train_datasource)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')

        # TRAINING PHASE
        train_loss_accum, train_accuracy_accum, train_batch_size = 0, 0, 0
        
        num_steps = 0 
        
        epoch_duration = 0.0
        for i, (X_train, y_train) in enumerate(train_datasource):
            
            # logging
            batch_start = time.time()

            num_steps += 1
            # training loss and gradients for this particular batch
            probs = model(X_train)
            loss = loss_fn(probs, y_train)
            
            grads = grad_fn(model, X_train, y_train)
            model = optimizer.step(model, grads)
            
                
            # update for metrics
            train_loss_accum += loss 
            train_batch_size += len(y_train)
            train_accuracy_accum += jnp.sum(jnp.argmax(probs, axis=-1) == y_train)
            train_accuracy = train_accuracy_accum / train_batch_size
            train_loss = train_loss_accum / (i + 1) # running average loss per batch seen so far

            # Logging ....
            batch_duration = time.time() - batch_start
            epoch_duration += batch_duration 
            log_batch_count = f'{i}/{train_num_batches}'
            log_epoch_time = f'{epoch_duration:.2f}s'
            log_batch_time = f'{1_000*batch_duration:.2f}ms/batch'
            log_batch_loss = f'train_loss:  {train_loss:.2f}'
            log_batch_accuracy = f'train_accuracy:  {100*train_accuracy:.2f}'
            log_string =  f'{log_batch_count}  [           ]  {log_epoch_time} {log_batch_time}  ,  {log_batch_loss}  ,  {log_batch_accuracy} '
            print(log_string, end='\r') 



        history['loss'].append(train_loss)
        history['accuracy'].append(train_accuracy)      

        # VALIDATION PHASE
        if valid_datasource is not None:
            valid_loss_accum, valid_accuracy_accum, valid_batch_size = 0, 0, 0 
            valid_num_batches = 0

            # Run validation step ...
            for X_valid, y_valid in valid_datasource:
                num_steps += 1
                valid_num_batches += 1
                probs = model(X_valid)
                loss = loss_fn(probs, y_valid)
                
                valid_accuracy_accum += jnp.sum(jnp.argmax(probs, axis=-1) == y_valid)

                valid_loss_accum += loss
                valid_batch_size += len(y_valid)

            # loss_fn returns a per-batch mean, so average it over batches;
            # accuracy counts correct samples, so average it over samples
            epoch_valid_loss = valid_loss_accum / valid_num_batches
            epoch_valid_accuracy = valid_accuracy_accum / valid_batch_size

            history['valid_loss'].append(epoch_valid_loss)
            history['valid_accuracy'].append(epoch_valid_accuracy)
  
        print(log_string, end='\n')
    return history
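
The loop above runs the forward pass twice per batch: once for the metrics and once inside `grad_fn`. One possible way to avoid that, sketched here on a toy linear model rather than on the `Sequential` model, is `jax.value_and_grad` with `has_aux=True`, which returns the loss, an auxiliary output and the gradients from a single call. The names `loss_and_probs` and `value_and_grad_fn` and the toy shapes are assumptions for this illustration.

```python
import jax
import jax.numpy as jnp

def loss_and_probs(params, X, y, num_classes=3):
    # toy linear classifier; returns (loss, auxiliary output)
    logits = X @ params['w'] + params['b']
    probs = jax.nn.softmax(logits, axis=-1)
    y_one_hot = jax.nn.one_hot(y, num_classes)
    loss = -jnp.sum(jnp.log(probs) * y_one_hot) / len(y)
    return loss, probs

# has_aux=True: differentiate only the first output, pass the second one through
value_and_grad_fn = jax.value_and_grad(loss_and_probs, has_aux=True)

params = {'w': jnp.zeros((4, 3)), 'b': jnp.zeros(3)}
X = jnp.ones((2, 4))
y = jnp.array([0, 2])

(loss, probs), grads = value_and_grad_fn(params, X, y)
print(loss, probs.shape, grads['w'].shape)
```

With all-zero parameters the predicted distribution is uniform over the three classes, so the loss equals `log(3)`, which makes the result easy to check by hand.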
    

In [27]:
jnp.sum(jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]))/2

DeviceArray(1., dtype=float32)

In [36]:
def fashion_mnist_loss(probs, y_true, num_classes=10):
    # average cross entropy, batch
    y_one_hot = jax.nn.one_hot(y_true, num_classes)
    return -jnp.sum(jnp.log(probs) * y_one_hot) / len(y_true)

In [37]:
y_true = np.array([1, 2])
y_pred = np.array([[0.05, 0.95, 0.01], [0.1, 0.8, 0.1]])
fashion_mnist_loss(y_pred, y_true, num_classes=3)

DeviceArray(1.1769392, dtype=float32)
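
Multiplying by a one-hot matrix simply picks out the predicted probability of the true class in each row, so the loss is the mean negative log-probability of the correct labels. The sketch below checks that equivalence on the same numbers as the cell above; `cross_entropy_one_hot` and `cross_entropy_indexed` are illustrative names, not functions used elsewhere in this article.

```python
import jax
import jax.numpy as jnp

def cross_entropy_one_hot(probs, y_true, num_classes):
    # same arithmetic as fashion_mnist_loss: mask with one-hot labels, then average
    y_one_hot = jax.nn.one_hot(y_true, num_classes)
    return -jnp.sum(jnp.log(probs) * y_one_hot) / len(y_true)

def cross_entropy_indexed(probs, y_true):
    # pick the probability assigned to the true class of each row directly
    picked = probs[jnp.arange(len(y_true)), y_true]
    return -jnp.mean(jnp.log(picked))

y_true = jnp.array([1, 2])
y_pred = jnp.array([[0.05, 0.95, 0.01], [0.1, 0.8, 0.1]])

loss_a = cross_entropy_one_hot(y_pred, y_true, num_classes=3)
loss_b = cross_entropy_indexed(y_pred, y_true)
print(loss_a, loss_b)
```

Both versions reduce to `-(log(0.95) + log(0.1)) / 2`, which is about 1.1769 and agrees with the output above.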

In [30]:
X_train_, y_train_ = X_train[:40_000,:,:], y_train[:40_000]
X_valid, y_valid = X_train[40_000:,:,:], y_train[40_000:]

train_dataset = Dataset(X_train_, y_train_)
valid_dataset = Dataset(X_valid, y_valid)

train_datasource = Dataloader(train_dataset, batchsize=64)
valid_datasource = Dataloader(valid_dataset, batchsize=64)
model = fashion_mnist_mlp()



grad_fn = jax.grad(lambda model, X, y: fashion_mnist_loss(model(X), y))

history = train(
    num_epochs=5, 
    train_datasource=train_datasource, 
    valid_datasource=valid_datasource, 
    optimizer=Adam(model, lr=1e-3), 
    loss_fn=fashion_mnist_loss, 
    model=model,
    grad_fn=grad_fn
)



Epoch 1/5
624/625  [           ]  5.71s 10.72ms/batch  ,  train_loss:  0.56  ,  train_accuracy:  80.31 
Epoch 2/5
624/625  [           ]  6.13s 10.34ms/batch  ,  train_loss:  0.41  ,  train_accuracy:  85.45 
Epoch 3/5
624/625  [           ]  6.50s 7.56ms/batch  ,  train_loss:  0.37  ,  train_accuracy:  86.83  
Epoch 4/5
624/625  [           ]  6.16s 7.62ms/batch  ,  train_loss:  0.34  ,  train_accuracy:  87.76  
Epoch 5/5
624/625  [           ]  6.46s 9.58ms/batch  ,  train_loss:  0.32  ,  train_accuracy:  88.52  


## Performance Curve

## Conclusion



In [94]:
print('98.56', end='')
time.sleep(1)
print('\r64.34')

64.34
