# MNIST using Linear Layers with JAX

- toc: true
- badges: true
- comments: true
- categories: [jupyter]
- image: images/chart-preview.png

In [215]:
import jax 
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from typing import Callable, List, NamedTuple


## Model API

In [236]:
class Model(NamedTuple):
    """Lightweight record describing a layer or model: a name plus two callables.

    Parameters are not stored on the instance. In this notebook `init` creates
    them and `forward` receives them as its first argument.
    """
    # Label returned by __repr__.
    name: str
    # Called with no arguments by the constructors in this notebook; returns the parameters.
    init: Callable 
    # Called as forward(params, x); returns the output for input x.
    forward: Callable 
    def __repr__(self) -> str:
        # Show only the name instead of the default NamedTuple field dump.
        return self.name 

In [237]:
def Linear(input_dim: int, output_dim: int, seed: int = 1234):
    """Build a fully connected layer computing `W x + b`.

    Args:
        input_dim: Size of the last axis of the input.
        output_dim: Number of output units.
        seed: Seed for the default PRNG key used by `init()`. The default
            (1234) reproduces the key that used to be hard-coded.

    Returns:
        A `Model` whose `init(key=None)` returns a dict with `'weights'` of
        shape `(output_dim, input_dim)` and `'bias'` of shape `(output_dim,)`,
        and whose `forward(params, x)` applies the layer.
    """
    def init(key=None):
        """Create fresh parameters; pass `key` to override the seeded default."""
        if key is None:
            key = jax.random.PRNGKey(seed)
        # He-normal initialisation: unit normal scaled by sqrt(2 / fan_in).
        scale = jnp.sqrt(2.0 / input_dim)
        return {
            'weights': jax.random.normal(key, (output_dim, input_dim)) * scale,
            'bias': jnp.zeros(output_dim),
        }

    def forward(params, x):
        """Apply the layer to `x`, whose last axis must have size `input_dim`."""
        W, b = params['weights'], params['bias']
        # x @ W.T equals W @ x for a single vector, and also works when x
        # carries leading batch axes, so forward can be used with or without vmap.
        return jnp.dot(x, W.T) + b

    return Model(init=init, forward=forward, name=f'Linear({input_dim}, {output_dim})')

In [238]:
def Sequential(*layers):
    """Chain layers and plain functions into a single `Model`.

    Each entry of `layers` is either a `Model` (which owns parameters) or a
    parameter-free callable such as an activation function.

    Returns:
        A `Model` whose `init()` returns a list with one parameter entry per
        `Model` layer, in order, and whose `forward(params, x)` feeds `x`
        through every entry in turn.
    """
    def init():
        # Only Model entries own parameters; bare callables contribute nothing.
        return [layer.init() for layer in layers if isinstance(layer, Model)]

    def forward(params, x):
        out, slot = x, 0
        for layer in layers:
            if not isinstance(layer, Model):
                # Parameter-free step (e.g. an activation function).
                out = layer(out)
                continue
            # `slot` counts Model layers only, matching the list built by init().
            out = layer.forward(params[slot], out)
            slot += 1
        return out

    return Model(name='Sequential', init=init, forward=forward)

The type of model we're looking to build can be represented in Keras, as follows:

```python
model = keras.Sequential([
  keras.layers.Flatten(input_shape=(28,28)),
  keras.layers.Dense(128, activation=keras.activations.relu),
  keras.layers.Dense(10, activation=keras.activations.softmax)                          
])
```

In [239]:
# Two-layer MLP over flattened 28x28 inputs. The parameters are not stored on
# the model: `model.init()` creates them and `model.forward(params, x)` takes
# them as an argument, so they can be differentiated with respect to.
model = Sequential(
    Linear(28 * 28, 128),  # flattened input -> 128 hidden units
    jax.nn.relu,
    Linear(128, 10),       # hidden units -> 10 outputs
)
print(model)

Sequential


## Loss Function

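The loss function takes the current parameters, the model's forward function, and a batch of inputs with one-hot targets, and measures how far the model's predictions are from the targets using cross-entropy.
Because it is wrapped in `jax.value_and_grad`, calling it returns both the loss and its gradient with respect to the parameters.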

In [232]:
# parameters pulled from model, this seems really clunky ...
@jax.value_and_grad
def cross_entropy_loss(params, forward, X, y_one_hot):
    """Mean softmax cross-entropy of `forward` over a batch.

    Because of the `jax.value_and_grad` decorator, calling this returns
    `(loss, grads)`, where `grads` is the gradient with respect to `params`
    (the first argument) and has the same structure as `params`.

    Args:
        params: Parameters handed unchanged to `forward`.
        forward: Callable `forward(params, x)` returning the logits for ONE
            example; it is mapped over the batch with `jax.vmap`.
        X: Batch of inputs; the leading axis is the batch axis.
        y_one_hot: One-hot targets, expected to match the shape of the batched
            logits, i.e. `(batch, num_classes)`.

    Returns:
        Scalar loss: cross-entropy summed over classes and averaged over the
        examples in the batch.
    """
    # Share params across the batch (None) and map over the leading axis of X (0).
    logits = jax.vmap(forward, in_axes=(None, 0))(params, X)
    # Normalise every example over its OWN classes. Without axis=-1 the
    # log-sum-exp would be taken over the whole batch, coupling unrelated examples.
    log_probs = logits - jax.nn.logsumexp(logits, axis=-1, keepdims=True)
    # Sum over classes first so the result is a per-example loss, then average
    # over the batch (a plain .mean() would also divide by num_classes).
    per_example = -(log_probs * y_one_hot).sum(axis=-1)
    return per_example.mean()

In [233]:
# Smoke-test the loss on a random batch of five flattened 28x28 inputs.
params = model.init()
y = jax.nn.one_hot([0, 2, 1, 3, 4], num_classes=10)
# value_and_grad makes the call return (loss value, gradient w.r.t. params).
a, b = cross_entropy_loss(
    params,
    model.forward,
    np.random.randn(5, 28 * 28),
    y,
)

## Training Loop

The training loop tunes the model parameters for a specified number of epochs. For now the parameters are updated once per epoch, using the gradient computed over the entire dataset (full-batch gradient descent).

In [103]:
def train(data, model, lr=1e-2, num_epochs=50, loss=None, params=None):
    """Fit `model` with full-batch gradient descent.

    Args:
        data: Pair `(X, y)` of inputs and targets, passed unchanged to `loss`.
        model: Object providing `init()` and `forward(params, x)` (for example
            the result of `Sequential`). A bare callable is also accepted, in
            which case it is handed to `loss` as-is and `params` must be given.
        lr: Learning rate (step size) of each update.
        num_epochs: Number of parameter updates; one per pass over `data`.
        loss: Callable `loss(params, forward, X, y)` returning
            `(loss_value, grads)` with `grads` structured like `params`.
            Defaults to `cross_entropy_loss` defined earlier in this notebook.
        params: Optional starting parameters; defaults to `model.init()`.

    Returns:
        `(loss_vals, params)`: a NumPy array holding the loss before each
        update, and the parameters after the final update.
    """
    X, y = data

    if loss is None:
        # Resolved at call time so defining `train` never fails on a missing name.
        loss = cross_entropy_loss
    # The loss expects the forward function, not the Model record itself.
    forward = getattr(model, 'forward', model)
    if params is None:
        params = model.init()

    loss_vals = np.zeros(num_epochs)
    for i in range(num_epochs):
        loss_i, grads = loss(params, forward, X, y)
        # tree_map walks params and grads in lock-step, so any nesting of
        # lists/dicts works; it also builds new params instead of mutating.
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        loss_vals[i] = loss_i
    return loss_vals, params

## Performance Curve

Let's see the trend in the loss function.

## Conclusion



In [41]:
# Scratch data for checking vmap against NumPy: ten random 3-feature inputs
# and a (5, 3) weight matrix.
X = np.random.standard_normal(size=(10, 3))
w = np.random.standard_normal(size=(5, 3))

In [42]:
# NumPy reference: apply w to every row of X in one matrix product -> shape (10, 5).
X @ w.T

array([[ 0.95615652, -0.60910943,  0.47719404,  0.50628421, -0.57886369],
       [ 3.09498684,  1.70024379, -1.01957485,  2.25813896, -0.37952626],
       [-3.67871561, -0.1351104 ,  0.04294664, -1.84479421,  0.98852387],
       [-1.00596024,  0.46315551,  0.67104569,  3.13116358, -1.09329311],
       [ 2.07878921,  1.88358723, -1.42939824,  0.84352964,  0.39912581],
       [ 2.91327168, -1.07647233,  0.73089571,  1.0445742 , -1.1687821 ],
       [ 1.03321149,  0.51671102,  0.70734573,  4.30933842, -1.72017219],
       [-2.73799462, -1.54710401,  1.83484434,  1.17423297, -1.08464385],
       [-1.6715637 , -2.74092988,  1.91974295, -1.18807747, -0.68628707],
       [ 1.21170605,  0.15311168, -0.1815407 ,  0.30359954, -0.13915325]])

In [44]:
def lin(x):
    """Apply the global weight matrix `w` to a single input vector `x`."""
    return jnp.dot(w, x)


# vmap lifts the single-example function over the leading axis of X, which
# should reproduce the NumPy matrix product computed in the previous cell.
yy = jax.vmap(lin, in_axes=0)(X)
print(yy)

[[ 0.9561565  -0.6091094   0.47719404  0.50628424 -0.5788637 ]
 [ 3.0949867   1.7002438  -1.0195749   2.258139   -0.37952614]
 [-3.6787155  -0.13511032  0.04294658 -1.8447943   0.9885239 ]
 [-1.0059603   0.46315545  0.67104566  3.1311636  -1.093293  ]
 [ 2.0787892   1.8835871  -1.4293982   0.8435297   0.39912578]
 [ 2.9132717  -1.0764723   0.7308957   1.044574   -1.1687821 ]
 [ 1.0332114   0.516711    0.7073457   4.309338   -1.7201722 ]
 [-2.7379947  -1.547104    1.8348444   1.174233   -1.0846438 ]
 [-1.6715636  -2.7409298   1.919743   -1.1880776  -0.686287  ]
 [ 1.2117062   0.15311167 -0.1815407   0.3035995  -0.13915324]]
