# Fashion MNIST using Linear Layers with JAX

- toc: true
- badges: true
- comments: true
- categories: [jupyter]
- image: images/chart-preview.png
- hide: true



In [1]:
import jax 
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from typing import Tuple, List, Any, Dict, Callable


2022-07-16 21:52:53.905857: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/lib:/usr/local/bin:/usr/local/lib:


## PyTorch / fast.ai like Data API

In [2]:
class Dataset:
    def __init__(self, X, y):
        self.X, self.y = X, y
    def __len__(self):
        return jnp.shape(self.X)[0]
    def __getitem__(self, i):
        return self.X[i,:], self.y[i]

In [3]:
class Dataloader:
    def __init__(self, dataset: Dataset, batchsize=32, shuffle=False):
        self.dataset = dataset
        self.batchsize = batchsize
        self.shuffle = shuffle
    def __iter__(self):
        for i in range(0, len(self.dataset), self.batchsize): 
            yield self.dataset[i:i+self.batchsize]
        

In [4]:
fashion_mnist = tf.keras.datasets.fashion_mnist
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0

In [5]:
fashion_mnist.load_data??

[0;31mSignature:[0m [0mfashion_mnist[0m[0;34m.[0m[0mload_data[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
[0;34m@[0m[0mkeras_export[0m[0;34m([0m[0;34m'keras.datasets.fashion_mnist.load_data'[0m[0;34m)[0m[0;34m[0m
[0;34m[0m[0;32mdef[0m [0mload_data[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m  [0;34m"""Loads the Fashion-MNIST dataset.[0m
[0;34m[0m
[0;34m  This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories,[0m
[0;34m  along with a test set of 10,000 images. This dataset can be used as[0m
[0;34m  a drop-in replacement for MNIST.[0m
[0;34m[0m
[0;34m  The classes are:[0m
[0;34m[0m
[0;34m  | Label | Description |[0m
[0;34m  |:-----:|-------------|[0m
[0;34m  |   0   | T-shirt/top |[0m
[0;34m  |   1   | Trouser     |[0m
[0;34m  |   2   | Pullover    |[0m
[0;34m  |   3   | Dress       |[0m
[0;34m  |   4   | Coat        |[0m
[0;34m  |   5   | Sandal      |[0m
[0;34

In [6]:
dataset = Dataset(X_train, y_train)

In [7]:
dataloader = Dataloader(dataset)

In [8]:
for X, y in dataloader:
    print(X.shape, y.shape)
    

(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28)

## Model API

In [9]:
class Module: pass


### Linear Layer

In [10]:
class Linear(Module):
    w: jnp.ndarray 
    b: jnp.ndarray
    ni: int 
    no: int 

    def __init__(self, num_inputs, num_outputs, seed=1234):
        self.ni = num_inputs 
        self.no = num_outputs 
        key = jax.random.PRNGKey(seed)
        self.w = jax.random.normal(key, (num_inputs, num_outputs)) * jnp.sqrt(2.0 / num_inputs)
        self.b = jnp.zeros(num_outputs)
        
    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b
        
    def params(self):
        return {'b': self.b, 'w': self.w}


In [11]:
l = Linear(2, 1)
x = np.random.randn(2)
y = l(x)

print(y)




[1.2276616]


In [12]:
def mse(model, X, y):
    preds = jax.vmap(model)(X)
    return jnp.mean((preds - y)**2)


In [13]:
print(mse(l, x, 2.0))
print(mse(l, np.random.randn(10, 2), np.random.randn(10)))

4.035294
0.523962


In [14]:
mse_grad = jax.grad(mse)
mse_grad(l, x, 2.0)

TypeError: Argument '<__main__.Linear object at 0x7f9b4ac74b50>' of type <class '__main__.Linear'> is not a valid JAX type.

To get this to work, the `Linear` class must be registered as a pytree.  

In [15]:
@jax.tree_util.register_pytree_node_class
class Linear(Module):
    w: jnp.ndarray 
    b: jnp.ndarray
    ni: int 
    no: int 

    def __init__(self, num_inputs, num_outputs, build=True, seed=1234):
        self.ni = num_inputs 
        self.no = num_outputs 
        # want to add seed as internal object
        if build:
            key = jax.random.PRNGKey(seed)
            self.w = jax.random.normal(key, (num_inputs, num_outputs)) * jnp.sqrt(2.0 / num_inputs)
            self.b = jnp.zeros(num_outputs)
    
    def merge(self, params):
        self.w, self.b = params

    def __repr__(self):
        return f'Linear(num_inputs={self.ni}, num_outputs={self.no})'
        
    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b
        
    def params(self):
        return {'b': self.b, 'w': self.w}

    def tree_flatten(self):
        return [self.w, self.b], [self.ni, self.no]

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layer = cls(*aux_data, build=False)
        layer.merge(params=children)
        return layer

In [16]:
lin = Linear(2, 1)

In [17]:
params, extra_stuff = lin.tree_flatten()

In [18]:
lin2 = Linear.tree_unflatten(extra_stuff, params)

In [19]:
print(lin.w)
print(lin2.w) 

[[ 0.43957582]
 [-0.26563603]]
[[ 0.43957582]
 [-0.26563603]]


In [20]:
@jax.jit
@jax.value_and_grad
def mse(model, X, y):
    preds = jax.vmap(model)(X)
    return jnp.mean((preds - y)**2)

In [21]:
X = np.random.randn(10, 2)
y = np.random.randn(10)

loss, g_loss = mse(lin, X, y)
print(loss, g_loss)

1.7428839 Linear(num_inputs=2, num_outputs=1)


In [22]:
g_loss.__dict__

{'ni': 2,
 'no': 1,
 'w': DeviceArray([[ 1.4666067],
              [-0.3670275]], dtype=float32),
 'b': DeviceArray([-0.55795026], dtype=float32)}

In [23]:
jax.tree_util.tree_flatten(g_loss)

([DeviceArray([[ 1.4666067],
               [-0.3670275]], dtype=float32),
  DeviceArray([-0.55795026], dtype=float32)],
 PyTreeDef(CustomNode(<class '__main__.Linear'>[[2, 1]], [*, *])))

In [24]:
locals()['Linear'].__class__

type

### Helper Functions

In [25]:
def flatten(x: jnp.ndarray): 
    return jnp.reshape(x, -1) 


In [26]:
def relu(x: jnp.ndarray): 
    return jnp.clip(x, a_min=0)

   

In [27]:
x = np.random.randn(10,10)
assert jnp.all(jnp.isclose(relu(x), jax.nn.relu(x))), 'test failed'

In [28]:
def softmax(x: jnp.ndarray):
    ex = jnp.exp(x)
    return ex / jnp.sum(ex)

In [29]:
x = np.random.randn(10)
assert jnp.all(jnp.isclose(softmax(x), jax.nn.softmax(x))), 'test failed'

In [30]:
_registry = {
    flatten.__name__: flatten,
    softmax.__name__: softmax,
    relu.__name__: relu
}

### Sequential Module

In [31]:
@jax.tree_util.register_pytree_node_class
class Sequential(Module):
    layers: List
    def __init__(self, *layers):
        self.layers = layers
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    def tree_flatten(self):
        aux_data, children = [], []
        for layer in self.layers:
            if isinstance(layer, Module):
                params, extra_stuff = layer.tree_flatten()
                aux_data.append([layer.__class__.__name__] + extra_stuff)
                children.append(params) 
            elif callable(layer):
                # a layer function that doesn't have any paramerers ...
                aux_data.append(layer.__name__)
                children.append(None)    
        return children, aux_data
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layers = []
        # Need a bettern way to unflatten a sequential structure, more flexible at least...
        for params, spec in zip(children, aux_data):
            if isinstance(spec, list):
                layer_name, *args = spec
                if layer_name == 'Linear':
                    layers.append(Linear.tree_unflatten(args, params))   
            elif isinstance(spec, str) and spec in _registry:
                layers.append(_registry[spec])
        return Sequential(*layers)
    

In [57]:

def fashion_mnist_mlp():
    model = Sequential(
        flatten,
        Linear(784, 128),
        relu,
        Linear(128, 10),
        softmax   
    )

    return model 


model = fashion_mnist_mlp()
params, extra_stuff = model.tree_flatten()
print(extra_stuff)
print(params)

['flatten', ['Linear', 784, 128], 'relu', ['Linear', 128, 10], 'softmax']
[None, [DeviceArray([[-0.00503162, -0.11710759,  0.05479915, ..., -0.07662067,
              -0.03762808,  0.037621  ],
             [-0.02311066,  0.00427538,  0.06703123, ...,  0.05820996,
              -0.03371886, -0.0653995 ],
             [-0.03936624,  0.08184296, -0.00103856, ..., -0.02543773,
               0.00404367,  0.10533019],
             ...,
             [-0.05674443,  0.01220774, -0.04277196, ...,  0.00793091,
              -0.03246848,  0.05214054],
             [-0.10229313, -0.04473471, -0.05902693, ..., -0.026743  ,
               0.01399903, -0.02305236],
             [ 0.02624378, -0.040582  ,  0.04346804, ..., -0.0069246 ,
               0.04329436,  0.07048796]], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

## Cross-entropy Loss

In [38]:
@jax.value_and_grad
def cross_entropy(model, X, y, num_cats=10):
    y_one_hot = jax.nn.one_hot(y, num_cats)
    log_softmax = jnp.log(jax.vmap(model)(X))
    return -jnp.mean(log_softmax * y_one_hot)
    

In [39]:

value, grads = cross_entropy(model, np.random.randn(5, 28, 28), [1,0,1,1,0])
print(value)

0.31954172


In [63]:
list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

jax.tree_map(lambda x: x*2, list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

In [42]:
updated_model = jax.tree_util.tree_map(lambda p, g: p - 1e-3*g, model, grads)

In [43]:
assert jnp.all(jnp.isclose(updated_model.layers[1].w, model.layers[1].w - 1e-3*grads.layers[1].w))

## Stochastic Gradient Descent

In [44]:
class Optimizer: pass 

In [45]:
class SGD(Optimizer):
    def __init__(self, model, lr=1e-3):
        self.lr = lr 
    def step(self, model, grads):
        return jax.tree_map(lambda p, g: p - self.lr*g, model, grads)

In [78]:
class Adam(Optimizer):
    def __init__(self, model, lr=1e-3, v_decay=0.9, s_decay=0.999, eps=1e-8):
        self.lr, self.v_decay, self.s_decay, self.eps = lr, v_decay, s_decay, eps
        self.v = jax.tree_map(lambda x: jnp.zeros_like(x), model) 
        self.s = jax.tree_map(lambda x: jnp.zeros_like(x), model)
        self.k = 0 
    def step(self, model, grads):
        lr, v_decay, s_decay, eps = self.lr, self.v_decay, self.s_decay, self.eps
        k = self.k = self.k+1
        self.v = jax.tree_map(lambda v, g: v_decay*v +(1-v_decay)*g, self.v, grads)
        self.s = jax.tree_map(lambda s, g: s_decay*s +(1-s_decay)*g*g, self.s, grads)
        v_hat = jax.tree_map(lambda v: v / (1-v_decay**k), self.v)
        s_hat = jax.tree_map(lambda s: s / (1-s_decay**k), self.s)
        result = jax.tree_map(lambda params, v_hat, s_hat: params - (lr*v_hat)/(jnp.sqrt(s_hat) + eps), model, v_hat, s_hat)
        return result


<__main__.Sequential at 0x7f9b47b31790>

## Training Loop  

In [46]:
dataset = Dataset(X_train, y_train)
dataloader = Dataloader(dataset, batchsize=64)
model = fashion_mnist_mlp()
num_epochs, lr = 10, 1e-2
opt = SGD(model, lr=lr)

for i in range(num_epochs):
    epoch_correct_prediction_count = 0
    epoch_loss = 0.0
    num_training_examples = 0
    for X, y in dataloader:
        # evaluate the model
        loss, grad = cross_entropy(model, X, y)
        
        # update the model using gradient descent
        model = opt.step(model, grad)

        # metrics
        y_preds = jnp.argmax(jax.vmap(model)(X), axis=1)
        correct = jnp.sum(y_preds == y)
        
        epoch_correct_prediction_count += correct
        epoch_loss += loss 
    
        minibatch_size = jnp.shape(X)[0]
        num_training_examples += minibatch_size

    epoch_accuracy = epoch_correct_prediction_count / num_training_examples
    epoch_loss = epoch_loss / num_training_examples

    print(f'Epoch {i}: {100*epoch_accuracy:.2f}')


Epoch 0: 47.87
Epoch 1: 66.56
Epoch 2: 69.68
Epoch 3: 72.04
Epoch 4: 73.66
Epoch 5: 75.04
Epoch 6: 76.17
Epoch 7: 77.09
Epoch 8: 77.82
Epoch 9: 78.55


In [81]:
dataset = Dataset(X_train, y_train)
dataloader = Dataloader(dataset, batchsize=64)
model = fashion_mnist_mlp()
num_epochs, lr = 10, 1e-3
opt = Adam(model)

for i in range(num_epochs):
    epoch_correct_prediction_count = 0
    epoch_loss = 0.0
    num_training_examples = 0
    for X, y in dataloader:
        # evaluate the model
        loss, grad = cross_entropy(model, X, y)
        
        # update the model using gradient descent
        model = opt.step(model, grad)

        # metrics
        y_preds = jnp.argmax(jax.vmap(model)(X), axis=1)
        correct = jnp.sum(y_preds == y)
        
        epoch_correct_prediction_count += correct
        epoch_loss += loss 
    
        minibatch_size = jnp.shape(X)[0]
        num_training_examples += minibatch_size

    epoch_accuracy = epoch_correct_prediction_count / num_training_examples
    epoch_loss = epoch_loss / num_training_examples

    print(f'Epoch {i}: {100*epoch_accuracy:.2f}')

Epoch 0: 82.50
Epoch 1: 86.81
Epoch 2: 88.04
Epoch 3: 88.97
Epoch 4: 89.62
Epoch 5: 90.18
Epoch 6: 90.64
Epoch 7: 91.10
Epoch 8: 91.48
Epoch 9: 91.78


## API Improvements

Heavily inspired by fast.ai.  The **Deep Learning for Coders with fastai and PyTorch** book is great and full of great ideas and software design techniques.

In [82]:
# This is a metric, but used in the callback system
class AccuracyTracker:
    def __init__(self):
        self.correct_prediction_count = 0 
        self.num_examples = 0 
    def on_minibatch_end(self, model, X, y):
        y_preds = jnp.argmax(jax.vmap(model)(X), axis=1)
        self.num_examples += jnp.shape(X)[0]
        self.correct_count += jnp.sum(y_preds == y)  
    def on_epoch_end(self):
        accuracy = self.correct_count / self.num_examples
        return accuracy 

In [None]:
dataset = Dataset(X_train, y_train)
dataloader = Dataloader(dataset, batchsize=64)
model = fashion_mnist_mlp()
num_epochs, lr = 10, 1e-3
opt = Adam(model)

for i in range(num_epochs):
    epoch_correct_prediction_count = 0
    epoch_loss = 0.0
    num_training_examples = 0
    for X, y in dataloader:
        # evaluate the model
        loss, grad = cross_entropy(model, X, y)
        
        # update the model using gradient descent
        model = opt.step(model, grad)

        for cb in callbacks:
            cb.on_minibatch_end(model, X, y)
            
        # metrics
        y_preds = jnp.argmax(jax.vmap(model)(X), axis=1)
        correct = jnp.sum(y_preds == y)
        
        epoch_correct_prediction_count += correct
        epoch_loss += loss 
    
        minibatch_size = jnp.shape(X)[0]
        num_training_examples += minibatch_size

    epoch_accuracy = epoch_correct_prediction_count / num_training_examples
    epoch_loss = epoch_loss / num_training_examples

    print(f'Epoch {i}: {100*epoch_accuracy:.2f}')

## Performance Curve

Let's see the trend in the loss function.

## Conclusion

