# Fashion MNIST using Linear Layers with JAX

- toc: true
- badges: true
- comments: true
- categories: [jupyter]
- image: images/chart-preview.png
- hide: true



In [67]:
import jax 
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from typing import Tuple, List, Any, Dict, Callable


## PyTorch / fast.ai like Data API

In [2]:
class Dataset:
    def __init__(self, X, y):
        self.X, self.y = X, y
    def __len__(self):
        return jnp.shape(self.X)[0]
    def __getitem__(self, i):
        return self.X[i,:], self.y[i]

In [71]:
class Dataloader:
    def __init__(self, dataset: Dataset, batchsize=32, shuffle=False):
        self.dataset = dataset
        self.batchsize = batchsize
        self.shuffle = shuffle
    def __iter__(self):
        for i in range(0, len(self.dataset), self.batchsize): 
            yield self.dataset[i:i+self.batchsize]
        

In [116]:
fashion_mnist = tf.keras.datasets.fashion_mnist
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0

In [124]:
fashion_mnist.load_data??

[0;31mSignature:[0m [0mfashion_mnist[0m[0;34m.[0m[0mload_data[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
[0;34m@[0m[0mkeras_export[0m[0;34m([0m[0;34m'keras.datasets.fashion_mnist.load_data'[0m[0;34m)[0m[0;34m[0m
[0;34m[0m[0;32mdef[0m [0mload_data[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m  [0;34m"""Loads the Fashion-MNIST dataset.[0m
[0;34m[0m
[0;34m  This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories,[0m
[0;34m  along with a test set of 10,000 images. This dataset can be used as[0m
[0;34m  a drop-in replacement for MNIST.[0m
[0;34m[0m
[0;34m  The classes are:[0m
[0;34m[0m
[0;34m  | Label | Description |[0m
[0;34m  |:-----:|-------------|[0m
[0;34m  |   0   | T-shirt/top |[0m
[0;34m  |   1   | Trouser     |[0m
[0;34m  |   2   | Pullover    |[0m
[0;34m  |   3   | Dress       |[0m
[0;34m  |   4   | Coat        |[0m
[0;34m  |   5   | Sandal      |[0m
[0;34

In [117]:
dataset = Dataset(X_train, y_train)

In [118]:
dataloader = Dataloader(dataset)

In [119]:
for X, y in dataloader:
    print(X.shape, y.shape)
    

(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28) (32,)
(32, 28, 28)

## Model API

In [120]:
class Module: pass


### Linear Layer

In [121]:
class Linear(Module):
    w: jnp.ndarray 
    b: jnp.ndarray
    ni: int 
    no: int 

    def __init__(self, num_inputs, num_outputs, seed=1234):
        self.ni = num_inputs 
        self.no = num_outputs 
        key = jax.random.PRNGKey(seed)
        self.w = jax.random.normal(key, (num_inputs, num_outputs)) * jnp.sqrt(2.0 / num_inputs)
        self.b = jnp.zeros(num_outputs)
        
    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b
        
    def params(self):
        return {'b': self.b, 'w': self.w}


In [5]:
l = Linear(2, 1)
x = np.random.randn(2)
y = l(x)

print(y)




[-1.4597868]


In [6]:
def mse(model, X, y):
    preds = jax.vmap(model)(X)
    return jnp.mean((preds - y)**2)


In [7]:
print(mse(l, x, 2.0))
print(mse(l, np.random.randn(10, 2), np.random.randn(10)))

4.4841824
1.44302


In [8]:
mse_grad = jax.grad(mse)
mse_grad(l, x, 2.0)

TypeError: Argument '<__main__.Linear object at 0x7f6b801ebb80>' of type <class '__main__.Linear'> is not a valid JAX type.

To get this to work, the `Linear` class must be registered as a pytree.  

In [9]:
@jax.tree_util.register_pytree_node_class
class Linear(Module):
    w: jnp.ndarray 
    b: jnp.ndarray
    ni: int 
    no: int 

    def __init__(self, num_inputs, num_outputs, build=True, seed=1234):
        self.ni = num_inputs 
        self.no = num_outputs 
        # want to add seed as internal object
        if build:
            key = jax.random.PRNGKey(seed)
            self.w = jax.random.normal(key, (num_inputs, num_outputs)) * jnp.sqrt(2.0 / num_inputs)
            self.b = jnp.zeros(num_outputs)
    
    def merge(self, params):
        self.w, self.b = params

    def __repr__(self):
        return f'Linear(num_inputs={self.ni}, num_outputs={self.no})'
        
    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b
        
    def params(self):
        return {'b': self.b, 'w': self.w}

    def tree_flatten(self):
        return [self.w, self.b], [self.ni, self.no]

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layer = cls(*aux_data, build=False)
        layer.merge(params=children)
        return layer

In [10]:
lin = Linear(2, 1)

In [11]:
params, extra_stuff = lin.tree_flatten()

In [12]:
lin2 = Linear.tree_unflatten(extra_stuff, params)

In [13]:
print(lin.w)
print(lin2.w) 

[[ 0.43957582]
 [-0.26563603]]
[[ 0.43957582]
 [-0.26563603]]


In [14]:
@jax.jit
@jax.value_and_grad
def mse(model, X, y):
    preds = jax.vmap(model)(X)
    return jnp.mean((preds - y)**2)

In [15]:
X = np.random.randn(10, 2)
y = np.random.randn(10)

loss, g_loss = mse(lin, X, y)
print(loss, g_loss)

1.5986859 Linear(num_inputs=2, num_outputs=1)


In [16]:
g_loss.__dict__

{'ni': 2,
 'no': 1,
 'w': DeviceArray([[ 1.0373731 ],
              [-0.40601766]], dtype=float32),
 'b': DeviceArray([-0.7894485], dtype=float32)}

In [17]:
jax.tree_util.tree_flatten(g_loss)

([DeviceArray([[ 1.0373731 ],
               [-0.40601766]], dtype=float32),
  DeviceArray([-0.7894485], dtype=float32)],
 PyTreeDef(CustomNode(<class '__main__.Linear'>[[2, 1]], [*, *])))

In [18]:
locals()['Linear'].__class__

type

### Helper Functions

In [19]:
def flatten(x: jnp.ndarray): 
    return jnp.reshape(x, -1) 


In [20]:
def relu(x: jnp.ndarray): 
    return jnp.clip(x, a_min=0)

   

In [21]:
x = np.random.randn(10,10)
assert jnp.all(jnp.isclose(relu(x), jax.nn.relu(x))), 'test failed'

In [22]:
def softmax(x: jnp.ndarray):
    ex = jnp.exp(x)
    return ex / jnp.sum(ex)

In [122]:
x = np.random.randn(10)
assert jnp.all(jnp.isclose(softmax(x), jax.nn.softmax(x))), 'test failed'

In [24]:
_registry = {
    flatten.__name__: flatten,
    softmax.__name__: softmax,
    relu.__name__: relu
}

### Sequential Module

In [25]:
@jax.tree_util.register_pytree_node_class
class Sequential(Module):
    layers: List
    def __init__(self, *layers):
        self.layers = layers
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    def tree_flatten(self):
        aux_data, children = [], []
        for layer in self.layers:
            if isinstance(layer, Module):
                params, extra_stuff = layer.tree_flatten()
                aux_data.append([layer.__class__.__name__] + extra_stuff)
                children.append(params) 
            elif callable(layer):
                # a layer function that doesn't have any paramerers ...
                aux_data.append(layer.__name__)
                children.append(None)    
        return children, aux_data
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        layers = []
        # Need a bettern way to unflatten a sequential structure, more flexible at least...
        for params, spec in zip(children, aux_data):
            if isinstance(spec, list):
                layer_name, *args = spec
                if layer_name == 'Linear':
                    layers.append(Linear.tree_unflatten(args, params))   
            elif isinstance(spec, str) and spec in _registry:
                layers.append(_registry[spec])
        return Sequential(*layers)
    

In [99]:


model = Sequential(
    flatten,
    Linear(784, 128),
    relu,
    Linear(128, 10),
    softmax   
)

params, extra_stuff = model.tree_flatten()
print(extra_stuff)
print(params)

['flatten', ['Linear', 784, 128], 'relu', ['Linear', 128, 10], 'softmax']
[None, [DeviceArray([[-0.00503162, -0.11710759,  0.05479915, ..., -0.07662067,
              -0.03762808,  0.037621  ],
             [-0.02311066,  0.00427538,  0.06703123, ...,  0.05820996,
              -0.03371886, -0.0653995 ],
             [-0.03936624,  0.08184296, -0.00103856, ..., -0.02543773,
               0.00404367,  0.10533019],
             ...,
             [-0.05674443,  0.01220774, -0.04277196, ...,  0.00793091,
              -0.03246848,  0.05214054],
             [-0.10229313, -0.04473471, -0.05902693, ..., -0.026743  ,
               0.01399903, -0.02305236],
             [ 0.02624378, -0.040582  ,  0.04346804, ..., -0.0069246 ,
               0.04329436,  0.07048796]], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [100]:
model(np.random.randn(28, 28))

DeviceArray([0.24745142, 0.0232116 , 0.00846777, 0.05969231, 0.04158318,
             0.01553843, 0.33251768, 0.06073404, 0.20046304, 0.01034052],            dtype=float32)

In [101]:
model2 = Sequential.tree_unflatten(extra_stuff, params)

In [102]:
model2.tree_flatten()

([None,
  [DeviceArray([[-0.00503162, -0.11710759,  0.05479915, ..., -0.07662067,
                 -0.03762808,  0.037621  ],
                [-0.02311066,  0.00427538,  0.06703123, ...,  0.05820996,
                 -0.03371886, -0.0653995 ],
                [-0.03936624,  0.08184296, -0.00103856, ..., -0.02543773,
                  0.00404367,  0.10533019],
                ...,
                [-0.05674443,  0.01220774, -0.04277196, ...,  0.00793091,
                 -0.03246848,  0.05214054],
                [-0.10229313, -0.04473471, -0.05902693, ..., -0.026743  ,
                  0.01399903, -0.02305236],
                [ 0.02624378, -0.040582  ,  0.04346804, ..., -0.0069246 ,
                  0.04329436,  0.07048796]], dtype=float32),
   DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                0.

In [103]:
X = np.random.randn(10, 28, 28)
y = jax.vmap(model)(X)
print(y)

[[0.18219435 0.0288466  0.08972018 0.04580158 0.03094671 0.15161341
  0.13047464 0.1996091  0.04998811 0.09080528]
 [0.21781811 0.19041373 0.04586045 0.02944165 0.04131059 0.14847653
  0.14403759 0.01657318 0.10250735 0.06356081]
 [0.0913539  0.03204991 0.00621432 0.00281485 0.00478983 0.26314914
  0.28993967 0.01488724 0.20624228 0.08855885]
 [0.08637979 0.02319437 0.02073215 0.27249578 0.017279   0.12466064
  0.25296152 0.01484643 0.10506434 0.08238598]
 [0.01162764 0.01259067 0.05375895 0.4700068  0.0066564  0.05829709
  0.11621077 0.02655041 0.15631145 0.08798981]
 [0.04634016 0.22555009 0.08186731 0.01344233 0.03639027 0.3018414
  0.11826753 0.02965728 0.07825199 0.06839168]
 [0.22286822 0.07548939 0.02164324 0.03841386 0.01069496 0.2535701
  0.17003557 0.04174198 0.04416189 0.12138066]
 [0.30497798 0.01807475 0.02036336 0.18955673 0.05524043 0.17666912
  0.0173182  0.01110036 0.05980157 0.14689754]
 [0.01721036 0.05527912 0.00589779 0.02485557 0.06157751 0.12263539
  0.0881054  0

## Cross-entropy Loss

In [104]:
jnp.log(10)

DeviceArray(2.3025851, dtype=float32, weak_type=True)

In [105]:
@jax.value_and_grad
def cross_entropy(model, X, y, num_cats=10):
    y_one_hot = jax.nn.one_hot(y, num_cats)
    log_softmax = jnp.log(jax.vmap(model)(X))
    return -jnp.mean(log_softmax * y_one_hot)
    

In [106]:

value, grads = cross_entropy(model, np.random.randn(5, 28, 28), [1,0,1,1,0])
print(value)

0.3070485


In [107]:
jax.tree_map(lambda x: print(x.shape), grads)

(784, 128)
(128,)
(128, 10)
(10,)


<__main__.Sequential at 0x7f6b5c622610>

In [108]:
mm = jax.tree_util.tree_map(lambda x: print(x.shape), model)

(784, 128)
(128,)
(128, 10)
(10,)


In [109]:
updated_model = jax.tree_util.tree_map(lambda p, g: p - 1e-3*g, model, grads)

In [110]:
assert jnp.all(jnp.isclose(updated_model.layers[1].w, model.layers[1].w - 1e-3*grads.layers[1].w))

## Optimizer 

In [126]:
class Optimizer: pass 

In [128]:
class SGD(Optimizer):
    def __init__(self, lr=1e-3):
        self.lr = lr 
    def step(self, model, grads):
        return jax.tree_util.tree_map(lambda p, g: p - self.lr*g, model, grads)

## Metrics

In [None]:
x

## Training Loop  

In [129]:
dataset = Dataset(X_train, y_train)
dataloader = Dataloader(dataset, batchsize=64)
opt = SGD()
num_epochs, lr = 10, 1e-3
for i in range(num_epochs):
    epoch_correct_prediction_count = 0
    epoch_loss = 0.0
    num_training_examples = 0
    for X, y in dataloader:
        # evaluate the model
        loss, grad = cross_entropy(model, X, y)
        
        # update the model using gradient descent
        model = opt.step(model, grad)
        #model = jax.tree_util.tree_map(lambda p, g: p - lr*g, model, grad)

        # metrics
        y_preds = jnp.argmax(jax.vmap(model)(X), axis=1)
        correct = jnp.sum(y_preds == y)
        
        epoch_correct_prediction_count += correct
        epoch_loss += loss 
    
        minibatch_size = jnp.shape(X)[0]
        num_training_examples += minibatch_size

    epoch_accuracy = epoch_correct_prediction_count / num_training_examples
    epoch_loss = epoch_loss / num_training_examples

    print(f'Epoch {i}: {100*epoch_accuracy:.2f}')


AttributeError: 'Linear' object has no attribute '__name__'

In [334]:
X = np.random.randn(2, 28, 28)
y = np.array([0, 1])
o = jax.vmap(model)(X)
print(o)
y_preds = jnp.argmax(o, axis=1)
print(y_preds)
accuracy = jnp.sum(y_preds == y)
print(accuracy)

[[0.01349435 0.050846   0.00470727 0.0737321  0.01404345 0.08124171 0.32983273 0.09614405 0.11126529 0.22469307]
 [0.00822283 0.01901213 0.01239796 0.01414352 0.01985135 0.07420428 0.08907535 0.00788129 0.6057198  0.14949153]]
[6 8]
0


## Performance Curve

Let's see the trend in the loss function.

## Conclusion



In [41]:
X = np.random.randn(10, 3)
w = np.random.randn(5, 3)

In [42]:
np.dot(X, np.transpose(w))

array([[ 0.95615652, -0.60910943,  0.47719404,  0.50628421, -0.57886369],
       [ 3.09498684,  1.70024379, -1.01957485,  2.25813896, -0.37952626],
       [-3.67871561, -0.1351104 ,  0.04294664, -1.84479421,  0.98852387],
       [-1.00596024,  0.46315551,  0.67104569,  3.13116358, -1.09329311],
       [ 2.07878921,  1.88358723, -1.42939824,  0.84352964,  0.39912581],
       [ 2.91327168, -1.07647233,  0.73089571,  1.0445742 , -1.1687821 ],
       [ 1.03321149,  0.51671102,  0.70734573,  4.30933842, -1.72017219],
       [-2.73799462, -1.54710401,  1.83484434,  1.17423297, -1.08464385],
       [-1.6715637 , -2.74092988,  1.91974295, -1.18807747, -0.68628707],
       [ 1.21170605,  0.15311168, -0.1815407 ,  0.30359954, -0.13915325]])

In [44]:
def lin(x): return jnp.dot(w, x)

yy = jax.vmap(lin)(X)
print(yy)

[[ 0.9561565  -0.6091094   0.47719404  0.50628424 -0.5788637 ]
 [ 3.0949867   1.7002438  -1.0195749   2.258139   -0.37952614]
 [-3.6787155  -0.13511032  0.04294658 -1.8447943   0.9885239 ]
 [-1.0059603   0.46315545  0.67104566  3.1311636  -1.093293  ]
 [ 2.0787892   1.8835871  -1.4293982   0.8435297   0.39912578]
 [ 2.9132717  -1.0764723   0.7308957   1.044574   -1.1687821 ]
 [ 1.0332114   0.516711    0.7073457   4.309338   -1.7201722 ]
 [-2.7379947  -1.547104    1.8348444   1.174233   -1.0846438 ]
 [-1.6715636  -2.7409298   1.919743   -1.1880776  -0.686287  ]
 [ 1.2117062   0.15311167 -0.1815407   0.3035995  -0.13915324]]
