# Fashion MNIST using Linear Layers with JAX

- toc: true
- badges: true
- comments: true
- categories: [jupyter]
- image: images/chart-preview.png
- hide: true



In [1]:
import jax 
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from typing import Tuple, List, Any, Dict, Callable


2022-07-18 15:40:07.488699: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/lib:/usr/local/bin:/usr/local/lib:


## PyTorch / fast.ai like Data API

In [2]:
class Dataset:
    def __init__(self, X, y):
        self.X, self.y = X, y
    def __len__(self):
        return jnp.shape(self.X)[0]
    def __getitem__(self, i):
        return self.X[i,:], self.y[i]

In [3]:
class Dataloader:
    def __init__(self, dataset: Dataset, batchsize=32, shuffle=False):
        self.dataset = dataset
        self.batchsize = batchsize
        self.shuffle = shuffle
    def __iter__(self):
        for i in range(0, len(self.dataset), self.batchsize): 
            yield self.dataset[i:i+self.batchsize]
        

In [4]:
fashion_mnist = tf.keras.datasets.fashion_mnist
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0

In [5]:
fashion_mnist.load_data??

[0;31mSignature:[0m [0mfashion_mnist[0m[0;34m.[0m[0mload_data[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
[0;34m@[0m[0mkeras_export[0m[0;34m([0m[0;34m'keras.datasets.fashion_mnist.load_data'[0m[0;34m)[0m[0;34m[0m
[0;34m[0m[0;32mdef[0m [0mload_data[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m  [0;34m"""Loads the Fashion-MNIST dataset.[0m
[0;34m[0m
[0;34m  This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories,[0m
[0;34m  along with a test set of 10,000 images. This dataset can be used as[0m
[0;34m  a drop-in replacement for MNIST.[0m
[0;34m[0m
[0;34m  The classes are:[0m
[0;34m[0m
[0;34m  | Label | Description |[0m
[0;34m  |:-----:|-------------|[0m
[0;34m  |   0   | T-shirt/top |[0m
[0;34m  |   1   | Trouser     |[0m
[0;34m  |   2   | Pullover    |[0m
[0;34m  |   3   | Dress       |[0m
[0;34m  |   4   | Coat        |[0m
[0;34m  |   5   | Sandal      |[0m
[0;34

In [6]:
dataset = Dataset(X_train, y_train)

In [7]:
dataloader = Dataloader(dataset)

In [8]:
for X, y in dataloader:
    print(X.shape, y.shape)
    

(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28)

## Model API

In [9]:
class Module: pass


### Linear Layer

In [10]:
class Linear(Module):
    w: jnp.ndarray 
    b: jnp.ndarray
    ni: int 
    no: int 

    def __init__(self, num_inputs, num_outputs, seed=1234):
        self.ni = num_inputs 
        self.no = num_outputs 
        key = jax.random.PRNGKey(seed)
        self.w = jax.random.normal(key, (num_inputs, num_outputs)) * jnp.sqrt(2.0 / num_inputs)
        self.b = jnp.zeros(num_outputs)
        
    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b
        
    def params(self):
        return {'b': self.b, 'w': self.w}


In [11]:
l = Linear(2, 1)
x = np.random.randn(5,2)
y = l(x)

print(y)




[[-0.37713784]
 [ 0.5444933 ]
 [ 0.9541705 ]
 [ 0.3122406 ]
 [-0.03162232]]


In [12]:
def mse(model, X, y):
    preds = jax.vmap(model)(X)
    return jnp.mean((preds - y)**2)


In [13]:
print(mse(l, x, 2.0))
print(mse(l, np.random.randn(10, 2), np.random.randn(10)))

3.167813
1.0486954


In [14]:
mse_grad = jax.grad(mse)
mse_grad(l, x, 2.0)

TypeError: Argument '<__main__.Linear object at 0x7fd2a0592f70>' of type <class '__main__.Linear'> is not a valid JAX type.

To get this to work, the `Linear` class must be registered as a pytree.  

In [15]:
@jax.tree_util.register_pytree_node_class
class Linear(Module):
    w: jnp.ndarray 
    b: jnp.ndarray
    ni: int 
    no: int 

    def __init__(self, num_inputs, num_outputs, build=True, seed=1234):
        self.ni = num_inputs 
        self.no = num_outputs 
        # want to add seed as internal object
        if build:
            key = jax.random.PRNGKey(seed)
            self.w = jax.random.normal(key, (num_inputs, num_outputs)) * jnp.sqrt(2.0 / num_inputs)
            self.b = jnp.zeros(num_outputs)
    
    def merge(self, params):
        self.w, self.b = params

    def __repr__(self):
        return f'Linear(num_inputs={self.ni}, num_outputs={self.no})'
        
    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b
        
    def params(self):
        return {'b': self.b, 'w': self.w}

    def tree_flatten(self):
        return [self.w, self.b], [self.ni, self.no]

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layer = cls(*aux_data, build=False)
        layer.merge(params=children)
        return layer

In [16]:
lin = Linear(2, 1)

In [17]:
params, extra_stuff = lin.tree_flatten()

In [18]:
lin2 = Linear.tree_unflatten(extra_stuff, params)

In [19]:
print(lin.w)
print(lin2.w) 

[[ 0.43957582]
 [-0.26563603]]
[[ 0.43957582]
 [-0.26563603]]


In [20]:
@jax.jit
@jax.value_and_grad
def mse(model, X, y):
    preds = jax.vmap(model)(X)
    return jnp.mean((preds - y)**2)

In [21]:
X = np.random.randn(10, 2)
y = np.random.randn(10)

loss, g_loss = mse(lin, X, y)
print(loss, g_loss)

1.2169735 Linear(num_inputs=2, num_outputs=1)


In [22]:
g_loss.__dict__

{'ni': 2,
 'no': 1,
 'w': DeviceArray([[ 1.5454377 ],
              [-0.11858664]], dtype=float32),
 'b': DeviceArray([-0.7558269], dtype=float32)}

In [23]:
jax.tree_util.tree_flatten(g_loss)

([DeviceArray([[ 1.5454377 ],
               [-0.11858664]], dtype=float32),
  DeviceArray([-0.7558269], dtype=float32)],
 PyTreeDef(CustomNode(<class '__main__.Linear'>[[2, 1]], [*, *])))

In [24]:
locals()['Linear'].__class__

type

### Helper Functions

In [25]:
def flatten(x: jnp.ndarray): 
    shape = jnp.shape(x)
    new_shape = -1 if len(shape) == 1 else (shape[0], -1) 
    return jnp.reshape(x, new_shape) 


In [26]:
def relu(x: jnp.ndarray): 
    return jnp.clip(x, a_min=0)
   

In [27]:
x = np.random.randn(10,10)
assert jnp.all(jnp.isclose(relu(x), jax.nn.relu(x))), 'test failed'

In [28]:
def softmax(x: jnp.ndarray):
    ex = jnp.exp(x)
    return ex / jnp.sum(ex)

In [29]:
x = np.random.randn(10)
assert jnp.all(jnp.isclose(softmax(x), jax.nn.softmax(x))), 'test failed'

In [30]:
_registry = {
    flatten.__name__: flatten,
    softmax.__name__: softmax,
    relu.__name__: relu
}

### Sequential Module

In [31]:
@jax.tree_util.register_pytree_node_class
class Sequential(Module):
    layers: List
    def __init__(self, *layers):
        self.layers = layers
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    def tree_flatten(self):
        aux_data, children = [], []
        for layer in self.layers:
            if isinstance(layer, Module):
                params, extra_stuff = layer.tree_flatten()
                aux_data.append([layer.__class__.__name__] + extra_stuff)
                children.append(params) 
            elif callable(layer):
                # a layer function that doesn't have any paramerers ...
                aux_data.append(layer.__name__)
                children.append(None)    
        return children, aux_data
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layers = []
        # Want a more generic way to unflatten
        for params, spec in zip(children, aux_data):
            if isinstance(spec, list):
                layer_name, *args = spec
                if layer_name == 'Linear':
                    layers.append(Linear.tree_unflatten(args, params))   
            elif isinstance(spec, str) and spec in _registry:
                layers.append(_registry[spec])
        return Sequential(*layers)
    

In [32]:

def fashion_mnist_mlp():
    model = Sequential(
        flatten,
        Linear(784, 128),
        relu,
        Linear(128, 10),
        softmax   
    )

    return model


model = fashion_mnist_mlp()
params, extra_stuff = model.tree_flatten()
print(extra_stuff)
print(params)

['flatten', ['Linear', 784, 128], 'relu', ['Linear', 128, 10], 'softmax']
[None, [DeviceArray([[-0.00503162, -0.11710759,  0.05479915, ..., -0.07662067,
              -0.03762808,  0.037621  ],
             [-0.02311066,  0.00427538,  0.06703123, ...,  0.05820996,
              -0.03371886, -0.0653995 ],
             [-0.03936624,  0.08184296, -0.00103856, ..., -0.02543773,
               0.00404367,  0.10533019],
             ...,
             [-0.05674443,  0.01220774, -0.04277196, ...,  0.00793091,
              -0.03246848,  0.05214054],
             [-0.10229313, -0.04473471, -0.05902693, ..., -0.026743  ,
               0.01399903, -0.02305236],
             [ 0.02624378, -0.040582  ,  0.04346804, ..., -0.0069246 ,
               0.04329436,  0.07048796]], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

## Cross-entropy Loss

In [33]:
@jax.value_and_grad
def cross_entropy(model, X, y, num_cats=10):
    y_one_hot = jax.nn.one_hot(y, num_cats)
    log_softmax = jnp.log(jax.vmap(model)(X))
    return -jnp.mean(log_softmax * y_one_hot)
    

In [34]:

value, grads = cross_entropy(model, np.random.randn(5, 28, 28), [1,0,1,1,0])
print(value)

TypeError: Incompatible shapes for dot: got (28, 28) and (784, 128).

In [35]:
updated_model = jax.tree_util.tree_map(lambda p, g: p - 1e-3*g, model, grads)

NameError: name 'grads' is not defined

In [36]:
assert jnp.all(jnp.isclose(updated_model.layers[1].w, model.layers[1].w - 1e-3*grads.layers[1].w))

NameError: name 'updated_model' is not defined

## Stochastic Gradient Descent

In [37]:
class Optimizer: pass 

In [38]:
class SGD(Optimizer):
    def __init__(self, model, lr=1e-3):
        self.lr = lr 
    def step(self, model, grads):
        return jax.tree_map(lambda p, g: p - self.lr*g, model, grads)

In [45]:
class Adam(Optimizer):
    def __init__(self, model, lr=1e-3, v_decay=0.9, s_decay=0.999, eps=1e-7):
        self.lr, self.v_decay, self.s_decay, self.eps = lr, v_decay, s_decay, eps
        self.v = jax.tree_map(lambda x: jnp.zeros_like(x), model) 
        self.s = jax.tree_map(lambda x: jnp.zeros_like(x), model)
        self.k = 0 
    def step(self, model, grads):
        lr, v_decay, s_decay, eps = self.lr, self.v_decay, self.s_decay, self.eps
        k = self.k = self.k+1
        self.v = jax.tree_map(lambda v, g: v_decay*v +(1-v_decay)*g, self.v, grads)
        self.s = jax.tree_map(lambda s, g: s_decay*s +(1-s_decay)*g*g, self.s, grads)
        v_hat = jax.tree_map(lambda v: v / (1-v_decay**k), self.v)
        s_hat = jax.tree_map(lambda s: s / (1-s_decay**k), self.s)
        new_model = jax.tree_map(lambda params, v_hat, s_hat: params - (lr*v_hat)/(jnp.sqrt(s_hat) + eps), model, v_hat, s_hat)
        return new_model


## Training Loop  


## API Improvements

I can't claim the credit for the API implemented in this section; it's **heavily** inspired by the excellent fastai library.  I'm not lifting code from the fastai repository, but I'm definitely using some of the . 

In [46]:
from collections import OrderedDict 

def train(num_epochs, train_datasource, valid_datasource, optimizer, loss_fn, grad_fn, model):
    history = {'loss':[], 'accuracy':[]}
    
    if valid_datasource is not None:
        history = {**history, 'valid_loss': [], 'valid_accuracy': []}

    for epoch in range(num_epochs):

        # TRAINING PHASE
        loss_accum, accuracy_accum, sample_cnt = 0, 0, 0
        
        for X_train, y_train in train_datasource:

            # training loss and gradients for this particular batch
            probs = model(X_train)#jax.vmap(model)(X_train)
            loss = loss_fn(probs, y_train)
            
            grads = grad_fn(model, X_train, y_train)
            model = optimizer.step(model, grads)
            
            #loss, grad = loss_fn(net, X_train, y_train)

            # update the model using gradient descent
            #net = optimizer.step(net, grad)
                
            # update for metrics
            loss_accum += loss
            sample_cnt += len(y_train)
            accuracy_accum += jnp.sum(jnp.argmax(probs, axis=-1) == y_train)
            #accuracy_accum += jnp.sum(jnp.argmax(jax.vmap(net)(X_train), axis=-1) == y_train)


        epoch_train_loss = loss_accum / sample_cnt 
        epoch_train_accuracy = accuracy_accum / sample_cnt

        history['loss'].append(epoch_train_loss)
        history['accuracy'].append(epoch_train_accuracy)      

        # VALIDATION PHASE
        if valid_datasource is not None:
            loss_accum, accuracy_accum, sample_cnt = 0, 0, 0 

            # Run validation step ...
            for X_valid, y_valid in valid_datasource:
                probs = model(X_valid)
                loss = loss_fn(probs, y_valid)
                
                accuracy_accum += jnp.sum(jnp.argmax(probs, axis=-1) == y_valid)
                #loss, _ = loss_fn(net, X_valid, y_valid)

                #y_pred = predict(X_train)
                
                #valid_loss = loss_fn(predict(X_valid), y_valid)
                loss_accum += loss
                sample_cnt += len(y_valid)
                #accuracy_accum += jnp.sum(jnp.argmax(jax.vmap(net)(X_valid), axis=-1) == y_valid)

            epoch_valid_loss = loss_accum / sample_cnt 
            epoch_valid_accuracy = accuracy_accum / sample_cnt

            history['loss'].append(epoch_valid_loss)
            history['accuracy'].append(epoch_valid_accuracy)
  
        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'train_loss: {epoch_train_loss:.6f} , train_accuracy: {100*epoch_train_accuracy:.2f}', end='')
        if valid_datasource:
            print(f' , ', end='')
        print(f'valid_loss: {epoch_valid_loss:.6f} , valid_accuracy: {100*epoch_valid_accuracy:.2f}')
    return history
        


In [47]:
def fashion_mnist_loss(probs, y_true):
    y_one_hot = jax.nn.one_hot(y_true, 10)
    return -jnp.mean(jnp.log(probs) * y_one_hot)

In [48]:
#@jax.value_and_grad
#def fashion_mnist_loss(model, X, y):
#    y_one_hot = jax.nn.one_hot(y, 10)
#    log_softmax = jnp.log(jax.vmap(model)(X))
#    return -jnp.mean(log_softmax * y_one_hot)

In [49]:
X_train_, y_train_ = X_train[:40_000,:,:], y_train[:40_000]
X_valid, y_valid = X_train[40_000:,:,:], y_train[40_000:]

train_dataset = Dataset(X_train_, y_train_)
valid_dataset = Dataset(X_valid, y_valid)

train_datasource = Dataloader(train_dataset, batchsize=64)
valid_datasource = Dataloader(valid_dataset, batchsize=64)
model = fashion_mnist_mlp()

#@jax.grad
#def grad_fn(model, X, y):
#    y_one_hot = jax.nn.one_hot(y, 10)
#    log_softmax = jnp.log(model(X))
#    return -jnp.mean(log_softmax * y_one_hot)

grad_fn = jax.grad(lambda model, X, y: fashion_mnist_loss(model(X), y))

history = train(
    num_epochs=5, 
    train_datasource=train_datasource, 
    valid_datasource=valid_datasource, 
    optimizer=Adam(model, lr=1e-3), 
    loss_fn=fashion_mnist_loss, 
    model=model,
    grad_fn=grad_fn
)



Epoch 1/5
train_loss: 0.007568 , train_accuracy: 80.38 , valid_loss: 0.007337 , valid_accuracy: 84.12
Epoch 2/5
train_loss: 0.007250 , train_accuracy: 85.44 , valid_loss: 0.007235 , valid_accuracy: 85.59
Epoch 3/5
train_loss: 0.007160 , train_accuracy: 86.83 , valid_loss: 0.007178 , valid_accuracy: 86.43
Epoch 4/5
train_loss: 0.007110 , train_accuracy: 87.64 , valid_loss: 0.007146 , valid_accuracy: 87.08
Epoch 5/5
train_loss: 0.007075 , train_accuracy: 88.32 , valid_loss: 0.007117 , valid_accuracy: 87.49


## Performance Curve

Let's see the trend in the loss function.

## Conclusion



In [50]:
tf.keras.layers.Dense

[0;31mInit signature:[0m [0mtf[0m[0;34m.[0m[0mkeras[0m[0;34m.[0m[0mlayers[0m[0;34m.[0m[0mDense[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m        
[0;32mclass[0m [0mDense[0m[0;34m([0m[0mLayer[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m  [0;34m"""Just your regular densely-connected NN layer.[0m
[0;34m[0m
[0;34m  `Dense` implements the operation:[0m
[0;34m  `output = activation(dot(input, kernel) + bias)`[0m
[0;34m  where `activation` is the element-wise activation function[0m
[0;34m  passed as the `activation` argument, `kernel` is a weights matrix[0m
[0;34m  created by the layer, and `bias` is a bias vector created by the layer[0m
[0;34m  (only applicable if `use_bias` is `True`). These are all attributes of[0m
[0;34m  `Dense`.[0m
[0;34m[0m
[0;34m  Note: If the input to the layer has a rank greater than 2, then `Dense`[0m
[0;34m  computes the 

In [254]:
print??

[0;31mDocstring:[0m
print(value, ..., sep=' ', end='\n', file=sys.stdout, flush=False)

Prints the values to a stream, or to sys.stdout by default.
Optional keyword arguments:
file:  a file-like object (stream); defaults to the current sys.stdout.
sep:   string inserted between values, default a space.
end:   string appended after the last value, default a newline.
flush: whether to forcibly flush the stream.
[0;31mType:[0m      builtin_function_or_method


In [333]:
model = fashion_mnist_mlp()

In [335]:
model(np.random.randn(2,28,28)).shape

(2, 10)

In [323]:
np.reshape(np.random.randn(2,28,28), (2, -1))

array([[-1.36403014, -0.35703367,  0.45756755, ...,  0.53845649,
        -0.44979458, -0.77233799],
       [-0.98455966, -0.17481828,  1.16870836, ...,  0.91413078,
         0.46064645, -1.13064741]])

In [324]:
jnp.shape(np.random.randn(10))

(10,)