# Fashion MNIST using Linear Layers with JAX

- toc: true
- badges: true
- comments: true
- categories: [jupyter]
- image: images/chart-preview.png
- hide: true



In this article, I'm going to start implementing a JAX-based neural network library, and use it to build a Fashion-MNIST classifier. Features will be added over the coming weeks and months as I try to tackle more advanced topics and architectures. I don't intend to write a full-featured library or framework like TensorFlow or PyTorch. Instead, my goal is to compile a set of loosely coupled components that I can mix and match to solve problems that interest me. When there's some educational benefit, functionality will be written from scratch and validated against other well-known libraries.

Why do I want to do this: 


## Why JAX

JAX is a Python-based numerical computing library that uses Google's XLA (Accelerated Linear Algebra) compiler to generate high-performance code that can run on a variety of hardware platforms. It feels a lot like NumPy, with a few quirks. One of its major selling points, particularly for machine learning and scientific computing, is its ability to calculate derivatives automatically. It also heavily embraces the functional programming paradigm.
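As a minimal sketch of those two selling points, the snippet below differentiates and compiles a small function. The function `f` and the variable names are made up for illustration; only `jax.grad`, `jax.jit` and `jax.numpy` come from JAX itself.

```python
import jax
import jax.numpy as jnp

def f(x):
    # a pure function written with the NumPy-like jax.numpy API
    return jnp.sum(x ** 2)

df = jax.grad(f)       # new function that returns df/dx
fast_f = jax.jit(f)    # same function, compiled through XLA

x = jnp.array([1.0, 2.0, 3.0])
grad_x = df(x)         # analytically this is 2 * x
value = fast_f(x)      # 1 + 4 + 9
```

Both `jax.grad` and `jax.jit` take a function and return a function, which is where the functional style shows up in practice.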

There are several JAX-based neural network libraries in the open-source space 

```python
import tensorflow as tf
from tensorflow import keras

# load the data and scale pixel values to [0, 1]
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data() 
X_train, X_test = X_train / 255.0, X_test / 255.0

model = keras.Sequential([
  keras.layers.Flatten(input_shape=(28,28)),
  keras.layers.Dense(128, activation=keras.activations.relu),
  keras.layers.Dense(10, activation=keras.activations.softmax)                          
])

model.compile(
    loss=keras.losses.sparse_categorical_crossentropy, 
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'])

model.fit(X_train, y_train, epochs=5)
```

```
Epoch 1/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.4955 - accuracy: 0.8250
Epoch 2/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3726 - accuracy: 0.8643
Epoch 3/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3354 - accuracy: 0.8781
Epoch 4/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3136 - accuracy: 0.8845
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2937 - accuracy: 0.8927
```

## The Fashion MNIST Dataset



In [1]:
import jax 
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import time 

from typing import Tuple, List, Any, Dict, Callable




## Datasource API

In [2]:
class Dataset:
    def __init__(self, X, y):
        self.X, self.y = X, y
    def __len__(self):
        return jnp.shape(self.X)[0]
    def __getitem__(self, i):
        return self.X[i,:], self.y[i]

In [3]:
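class Dataloader:
    def __init__(self, dataset: Dataset, batchsize=32, shuffle=False):
        self.dataset = dataset
        self.batchsize = batchsize
        self.shuffle = shuffle
    def __iter__(self):
        n = len(self.dataset)
        # draw a fresh random order on each pass when shuffle=True
        order = np.random.permutation(n) if self.shuffle else np.arange(n)
        for i in range(0, n, self.batchsize): 
            yield self.dataset[order[i:i+self.batchsize]]
    def __len__(self):
        # ceiling division, so a final partial batch is counted
        return -(-len(self.dataset) // self.batchsize)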
        

In [139]:
fashion_mnist = tf.keras.datasets.fashion_mnist
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
#X_train, X_test = X_train / 255.0, X_test / 255.0

In [140]:
dataset = Dataset(X_train, y_train)

In [141]:
dataloader = Dataloader(dataset)

In [8]:
for X, y in dataloader:
    print(X.shape, y.shape)
    break
    

(32, 28, 28) (32,)
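
Slicing past the end of an array simply returns a shorter slice, so stepping through a dataset in strides of `batchsize` produces one smaller final batch whenever the dataset size is not a multiple of the batch size. The sketch below illustrates this with a made-up array of 100 blank images; the sizes are chosen only for the example.

```python
import math
import numpy as np

X = np.zeros((100, 28, 28))   # 100 fake images, illustration only
batchsize = 32

# same stride-and-slice pattern as the loader's __iter__
batches = [X[i:i + batchsize] for i in range(0, len(X), batchsize)]
sizes = [len(b) for b in batches]

num_batches_ceil = math.ceil(len(X) / batchsize)   # counts the partial batch
num_batches_floor = len(X) // batchsize            # misses it
```

Here the iteration yields batches of 32, 32, 32 and 4 images, so the batch count matches the ceiling rather than the floor of `len(X) / batchsize`.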


## Parametric Equations



In [68]:
class Parametric: pass


### Linear

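For JAX transformations such as `jax.grad` and the `jax.tree_util` functions to accept a layer object directly, the `Linear` class must be registered as a pytree. Registration tells JAX which attributes are array leaves (`w` and `b`) and which are static metadata (`ni` and `no`).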

In [149]:
@jax.tree_util.register_pytree_node_class
class Linear(Parametric):
    w: jnp.ndarray 
    b: jnp.ndarray
    ni: int 
    no: int 

    def __init__(self, num_inputs, num_outputs, build=True, seed=1234):
        self.ni = num_inputs 
        self.no = num_outputs 
        # want to add seed as internal object
        if build:
            key = jax.random.PRNGKey(seed)
            self.w = jax.random.normal(key, (num_inputs, num_outputs)) * jnp.sqrt(2.0 / num_inputs)
            self.b = jnp.zeros(num_outputs)
    
    def merge(self, params):
        self.w, self.b = params

    def __repr__(self):
        return f'Linear(num_inputs={self.ni}, num_outputs={self.no})'
        
    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b
        
    def tree_flatten(self):
        return (self.w, self.b), (self.ni, self.no)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layer = cls(*aux_data, build=False)
        layer.w, layer.b = children
        #layer.merge(params=children)
        return layer
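
To see what pytree registration buys, here is a minimal sketch using a hypothetical one-parameter `Scale` layer. `Scale`, its `name` field and the toy loss are assumptions made for illustration and are not part of the library in this post.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Scale:
    """Hypothetical layer computing w * x, for illustration only."""
    def __init__(self, w, name='scale'):
        self.w = w          # array leaf, visible to jax.grad
        self.name = name    # static metadata, carried along untouched

    def __call__(self, x):
        return self.w * x

    def tree_flatten(self):
        # returns (children, aux_data)
        return (self.w,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], name=aux_data)

layer = Scale(jnp.array(3.0))
leaves, treedef = jax.tree_util.tree_flatten(layer)

# the layer itself is the argument being differentiated
loss = lambda layer, x: jnp.sum(layer(x) ** 2)
grads = jax.grad(loss)(layer, jnp.array([1.0, 2.0]))
```

The gradient comes back as another `Scale` whose `w` holds the derivative of the loss with respect to the original `w`, which is what lets an optimizer walk the model and its gradients in lockstep.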

In [188]:
relu.__name__

'relu'

### Flatten

The `Flatten` layer converts a two-dimensional array to a one-dimensional array, or a batch of two-dimensional arrays to a batch of one-dimensional arrays. Each two-dimensional array is flattened row-wise.

In [220]:
@jax.tree_util.register_pytree_node_class
class Function:
    def __init__(self, fn):
        self.fn = fn 
        
    def __call__(self, x):
        return self.fn(x)
    
    def __repr__(self):
        return f'Function({self.fn.__name__})'
    
    def tree_flatten(self):
        return [None], self.fn

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)
    

In [245]:
a, b = jax.tree_util.tree_flatten(Function(relu))
print(a, b)
jax.tree_util.tree_unflatten(b,a)

[] PyTreeDef(CustomNode(<class '__main__.Function'>[<function relu at 0x7f7e65272dc0>], [None]))


Function(relu)

In [207]:
x = lambda a : a+1
x.__name__

'<lambda>'

In [143]:
def flatten(x):
    shape = jnp.shape(x)
        
    if len(shape) == 1:
        return x 
    elif len(shape) == 2:
        # flatten a single 2D image
        return jnp.reshape(x, -1)
    elif len(shape) == 3:
        # x is a batch of 2D images, flatten each image 
        batch_size = jnp.shape(x)[0]
        return jnp.reshape(x, (batch_size, -1))    

In [174]:
@jax.tree_util.register_pytree_node_class
class Flatten(Parametric):
    def __call__(self, x):
        shape = jnp.shape(x)
        
        if len(shape) == 1:
            return x 
        elif len(shape) == 2:
            # flatten a single 2D image
            return jnp.reshape(x, -1)
        elif len(shape) == 3:
            # x is a batch of 2D images, flatten each image 
            batch_size = jnp.shape(x)[0]
            return jnp.reshape(x, (batch_size, -1))
    
    def __repr__(self):
        return f'Flatten()'
    
    def tree_flatten(self):
        return [None], self.__class__.__name__

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls()

Here's an example of `Flatten` operating on a two-element batch of $2\times 3$ arrays.

In [175]:
x = np.array([[[1,2,3], [4,5,6]], [[7,8,9], [10,11,12]]])
y = Flatten()(x)
print(y)

[[ 1  2  3  4  5  6]
 [ 7  8  9 10 11 12]]


### RELU

In [177]:
def relu(x):
    return jnp.maximum(x, 0)  # element-wise max(x, 0)

In [178]:
@jax.tree_util.register_pytree_node_class
class Relu(Parametric):
    def __call__(self, x):
        return jnp.maximum(x, 0)  # element-wise max(x, 0)

    def __repr__(self):
        return f'Relu()'
    
    def tree_flatten(self):
        return [None], self.__class__.__name__

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls()

In [179]:
x = np.random.randn(10,10)
assert jnp.all(jnp.isclose(relu(x), jax.nn.relu(x))), 'test failed'

### Softmax

In [180]:
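def softmax(x):
    # subtract the row-wise max before exponentiating to avoid overflow
    ex = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
    return ex / jnp.sum(ex, axis=-1, keepdims=True)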

In [181]:
@jax.tree_util.register_pytree_node_class
class Softmax(Parametric):
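    def __call__(self, x):
        # subtract the row-wise max before exponentiating to avoid overflow
        ex = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
        return ex / jnp.sum(ex, axis=-1, keepdims=True)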

    def __repr__(self):
        return f'Softmax()'
    
    def tree_flatten(self):
        return [None], self.__class__.__name__

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls()
    

In [182]:
x = np.random.randn(5, 10)
assert jnp.all(jnp.isclose(softmax(x), jax.nn.softmax(x))), 'test failed'
assert jnp.isclose(jnp.sum(softmax(x)), jnp.shape(x)[0])
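
Softmax is unchanged by adding a constant to every logit in a row, which is why subtracting the row maximum before exponentiating is a common guard against overflow. The sketch below compares the two variants; the function names `softmax_naive` and `softmax_stable` and the test inputs are illustrative.

```python
import jax.numpy as jnp

def softmax_naive(x):
    ex = jnp.exp(x)
    return ex / jnp.sum(ex, axis=-1, keepdims=True)

def softmax_stable(x):
    # shifting by the row max leaves the result unchanged but keeps exp() finite
    ex = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
    return ex / jnp.sum(ex, axis=-1, keepdims=True)

small = jnp.array([[1.0, 2.0, 3.0]])
large = jnp.array([[1000.0, 1001.0, 1002.0]])

agree_on_small = bool(jnp.allclose(softmax_naive(small), softmax_stable(small)))
naive_overflows = bool(jnp.any(jnp.isnan(softmax_naive(large))))
stable_is_finite = bool(jnp.all(jnp.isfinite(softmax_stable(large))))
```

With default 32-bit floats, `exp(1000)` overflows to infinity, so the naive version divides infinity by infinity, while the shifted version only ever exponentiates values at or below zero.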

### NormalizeImage

In [253]:
def normalize_image(x):
    return x / 255.0

In [184]:
@jax.tree_util.register_pytree_node_class
class NormalizeImage(Parametric):
    
    def __call__(self, x):
        return x / 255.0 

    def __repr__(self):
        return f'NormalizeImage()'
    
    def tree_flatten(self):
        return [None], self.__class__.__name__

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls()

## Sequential

In [248]:
@jax.tree_util.register_pytree_node_class
class Sequential(Parametric):

    def __init__(self, *layers):
        self.layers = []
        for layer in layers:
            if isinstance(layer, Parametric) or isinstance(layer, Function):
                self.layers.append(layer)
            elif callable(layer):
                self.layers.append(Function(layer))
    
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    
    def __repr__(self):
        string = ''
        for layer in self.layers:
            string += (repr(layer) + '\n')
        return string
    
    def tree_flatten(self):
        aux_data, children = [], []
        for layer in self.layers:
            params, extra_stuff = jax.tree_util.tree_flatten(layer)
            aux_data.append(extra_stuff)
            children.append(params)
        return children, aux_data
    
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layers = []
        for params, spec in zip(children, aux_data):
            layers.append(jax.tree_util.tree_unflatten(spec, params))
        return cls(*layers)

In [254]:

def fashion_mnist_mlp():
    model = Sequential(
        normalize_image,
        flatten,
        Linear(784, 128),
        relu,
        Linear(128, 10),
        softmax  
    )

    return model
model = fashion_mnist_mlp()
print(model)
a, b = jax.tree_util.tree_flatten(model)
jax.tree_util.tree_unflatten(b, a)


Function(normalize_image)
Function(flatten)
Linear(num_inputs=784, num_outputs=128)
Function(relu)
Linear(num_inputs=128, num_outputs=10)
Function(softmax)



Function(normalize_image)
Function(flatten)
Linear(num_inputs=784, num_outputs=128)
Function(relu)
Linear(num_inputs=128, num_outputs=10)
Function(softmax)

## Cross-entropy Loss

## Optimizers

In [77]:
class Optimizer: pass

In [78]:
class SGD(Optimizer):
    def __init__(self, model, lr=1e-3):
        self.lr = lr 
    def step(self, model, grads):
        return jax.tree_util.tree_map(lambda p, g: p - self.lr*g, model, grads)

In [79]:
class Adam(Optimizer):
    def __init__(self, model, lr=1e-3, v_decay=0.9, s_decay=0.999, eps=1e-7):
        self.lr, self.v_decay, self.s_decay, self.eps = lr, v_decay, s_decay, eps
        # first- and second-moment estimates, one zero array per model leaf
        self.v = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), model) 
        self.s = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), model)
        self.k = 0 
    def step(self, model, grads):
        lr, v_decay, s_decay, eps = self.lr, self.v_decay, self.s_decay, self.eps
        v, s = self.v, self.s
        k = self.k = self.k+1
        # exponential moving averages of the gradient and squared gradient
        self.v = jax.tree_util.tree_map(lambda v, g: v_decay*v +(1-v_decay)*g, v, grads)
        self.s = jax.tree_util.tree_map(lambda s, g: s_decay*s +(1-s_decay)*g*g, s, grads)
        # bias correction for the zero initialization
        v_hat = jax.tree_util.tree_map(lambda v: v / (1-v_decay**k), self.v)
        s_hat = jax.tree_util.tree_map(lambda s: s / (1-s_decay**k), self.s)
        new_model = jax.tree_util.tree_map(lambda params, v_hat, s_hat: params - (lr*v_hat)/(jnp.sqrt(s_hat) + eps), model, v_hat, s_hat)
        return new_model
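
A quick way to sanity-check the Adam update is its first step: with zero-initialized moments, bias correction gives `v_hat = g` and `s_hat = g*g`, so every parameter moves by roughly `lr` in the direction opposite to the sign of its gradient, regardless of the gradient's magnitude. The sketch below replays one step on a hypothetical dict of parameters (the names `params`, `grads` and `tmap` and the values are made up for the example).

```python
import jax
import jax.numpy as jnp

tmap = jax.tree_util.tree_map
lr, v_decay, s_decay, eps = 1e-3, 0.9, 0.999, 1e-7

# hypothetical parameters and gradients, stored as a dict pytree
params = {'w': jnp.array([1.0, -2.0]), 'b': jnp.array(0.5)}
grads = {'w': jnp.array([0.3, -4.0]), 'b': jnp.array(10.0)}

v = tmap(jnp.zeros_like, params)
s = tmap(jnp.zeros_like, params)
k = 1  # first step

v = tmap(lambda v, g: v_decay * v + (1 - v_decay) * g, v, grads)
s = tmap(lambda s, g: s_decay * s + (1 - s_decay) * g * g, s, grads)
v_hat = tmap(lambda v: v / (1 - v_decay ** k), v)
s_hat = tmap(lambda s: s / (1 - s_decay ** k), s)
new_params = tmap(lambda p, vh, sh: p - lr * vh / (jnp.sqrt(sh) + eps),
                  params, v_hat, s_hat)

# how far each parameter moved on this first step
step = tmap(lambda p, q: q - p, params, new_params)
```

The gradients here differ by more than an order of magnitude, yet each entry of `step` has magnitude close to `lr`, which is the per-parameter scaling Adam is known for.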


## Training Loop  


In [136]:
def train(num_epochs, train_datasource, valid_datasource, optimizer, loss_fn, grad_fn, model):
    history = {'loss':[], 'accuracy':[]}
    
    if valid_datasource is not None:
        history = {**history, 'valid_loss': [], 'valid_accuracy': []}

    train_num_batches = len(train_datasource)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')

        # TRAINING PHASE
        train_loss_accum, train_accuracy_accum, train_batch_size = 0, 0, 0
        
        num_steps = 0 
        
        epoch_duration = 0.0
        for i, (X_train, y_train) in enumerate(train_datasource):
            
            # logging
            batch_start = time.time()

            num_steps += 1
            
            # training loss and gradients for this particular batch
            probs = model(X_train)
            loss = loss_fn(probs, y_train)
            
            grads = grad_fn(model, X_train, y_train)
            model = optimizer.step(model, grads)
                  
            # update for metrics
            train_loss_accum += loss 
            train_batch_size += len(y_train)
            train_accuracy_accum += jnp.sum(jnp.argmax(probs, axis=-1) == y_train)
            train_accuracy = train_accuracy_accum / train_batch_size
            train_loss = train_loss_accum / num_steps # running average loss per batch

            # Logging ....
            batch_duration = time.time() - batch_start
            epoch_duration += batch_duration 
            log_batch_count = f'{i}/{train_num_batches}'
            log_epoch_time = f'{epoch_duration:>5.2f}s'
            log_batch_time = f'{1_000*batch_duration:>5.2f}ms/batch'
            log_batch_loss = f'train_loss:  {train_loss:>5.2f}'
            log_batch_accuracy = f'train_accuracy:  {100*train_accuracy:<5.2f}'
            log_string =  f'{log_batch_count}  [           ]  {log_epoch_time} {log_batch_time}  ,  {log_batch_loss}  ,  {log_batch_accuracy} '
            print(log_string, end='\r') 

        # 
        history['loss'].append(train_loss)
        history['accuracy'].append(train_accuracy)      

        # VALIDATION PHASE
        if valid_datasource is not None:
            valid_loss_accum, valid_accuracy_accum, valid_batch_size = 0, 0, 0 
            valid_num_batches = 0

            # Run validation step ...
            for X_valid, y_valid in valid_datasource:
                valid_num_batches += 1
                probs = model(X_valid)
                loss = loss_fn(probs, y_valid)
                
                valid_accuracy_accum += jnp.sum(jnp.argmax(probs, axis=-1) == y_valid)

                valid_loss_accum += loss
                valid_batch_size += len(y_valid)

            # loss_fn returns a per-batch mean, so average over batches, not samples
            epoch_valid_loss = valid_loss_accum / valid_num_batches 
            epoch_valid_accuracy = valid_accuracy_accum / valid_batch_size

            # store validation metrics under their own keys
            history['valid_loss'].append(epoch_valid_loss)
            history['valid_accuracy'].append(epoch_valid_accuracy)
        
        # this log_string should include validation results
        print(log_string, end='\n')
    return history
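
The training loop evaluates the model once to get `probs` and the loss, and then again inside `grad_fn`. One possible alternative is `jax.value_and_grad` with `has_aux=True`, which returns the loss, an auxiliary output and the gradients from a single call. The sketch below uses a hypothetical linear classifier stored in a dict, not the `Sequential` model from this post.

```python
import jax
import jax.numpy as jnp

def loss_and_probs(params, X, y):
    # hypothetical one-layer classifier, for illustration only
    logits = X @ params['w'] + params['b']
    probs = jax.nn.softmax(logits, axis=-1)
    loss = -jnp.mean(jnp.log(probs[jnp.arange(len(y)), y]))
    return loss, probs   # probs is passed through as the auxiliary output

params = {'w': jnp.zeros((4, 3)), 'b': jnp.zeros(3)}
X = jnp.ones((2, 4))
y = jnp.array([0, 2])

# one call gives the loss, the predictions and the gradients
(loss, probs), grads = jax.value_and_grad(loss_and_probs, has_aux=True)(params, X, y)
```

With all-zero parameters the predictions are uniform over the three classes, so the loss equals `log(3)`, and `grads` has the same dict structure as `params`.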
    

In [137]:
def fashion_mnist_loss(probs, y_true, num_classes=10):
    # average cross entropy, batch
    y_one_hot = jax.nn.one_hot(y_true, num_classes)
    return -jnp.sum(jnp.log(probs) * y_one_hot) / len(y_true)
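
Because each one-hot row contains a single 1, the sum in this loss only picks out the log-probability assigned to the true class of each sample. The sketch below checks that equivalence on a hand-made batch; the helper names and the probabilities are made up for the example.

```python
import jax
import jax.numpy as jnp

def cross_entropy_one_hot(probs, y_true, num_classes):
    # same formulation as the loss above: mask log-probs with one-hot labels
    y_one_hot = jax.nn.one_hot(y_true, num_classes)
    return -jnp.sum(jnp.log(probs) * y_one_hot) / len(y_true)

def cross_entropy_indexed(probs, y_true):
    # pick the probability of the true class directly, then average
    picked = probs[jnp.arange(len(y_true)), y_true]
    return -jnp.mean(jnp.log(picked))

probs = jnp.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1]])
y_true = jnp.array([0, 1])

loss_a = cross_entropy_one_hot(probs, y_true, num_classes=3)
loss_b = cross_entropy_indexed(probs, y_true)
by_hand = -(jnp.log(0.7) + jnp.log(0.8)) / 2
```

All three values agree, which also makes clear that the loss is a mean over the samples in the batch.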

In [252]:
X_train_, y_train_ = X_train[:40_000,:,:], y_train[:40_000]
X_valid, y_valid = X_train[40_000:,:,:], y_train[40_000:]

train_dataset = Dataset(X_train_, y_train_)
valid_dataset = Dataset(X_valid, y_valid)

train_datasource = Dataloader(train_dataset, batchsize=64)
valid_datasource = Dataloader(valid_dataset, batchsize=64)
model = fashion_mnist_mlp()



grad_fn = jax.grad(lambda model, X, y: fashion_mnist_loss(model(X), y))

history = train(
    num_epochs=5, 
    train_datasource=train_datasource, 
    valid_datasource=valid_datasource, 
    optimizer=Adam(model, lr=1e-3), 
    loss_fn=fashion_mnist_loss, 
    model=model,
    grad_fn=grad_fn
)



Epoch 1/5
624/625  [           ]   4.78s  7.43ms/batch  ,  train_loss:   0.56  ,  train_accuracy:  80.31 
Epoch 2/5
624/625  [           ]   4.79s  7.07ms/batch  ,  train_loss:   0.41  ,  train_accuracy:  85.45 
Epoch 3/5
624/625  [           ]   4.85s  8.14ms/batch  ,  train_loss:   0.37  ,  train_accuracy:  86.83 
Epoch 4/5
624/625  [           ]   4.94s  7.94ms/batch  ,  train_loss:   0.34  ,  train_accuracy:  87.76 
Epoch 5/5
624/625  [           ]   5.61s  8.44ms/batch  ,  train_loss:   0.32  ,  train_accuracy:  88.52 


## Performance Curve

## Conclusion



In [94]:
print('98.56', end='')
time.sleep(1)
print('\r64.34')

64.34


In [48]:
x = 10
f'Hello {x:>10.1f}'

'Hello       10.0'

In [45]:
model

<__main__.Sequential at 0x7f7f78ee1f70>

In [48]:
a, b = model.tree_flatten()

In [49]:
Sequential.tree_unflatten(b, a)

<__main__.Sequential at 0x7f7ef447b760>