# MNIST using Linear Layers with JAX

- toc: true 
- badges: true
- comments: true
- categories: [jupyter]
- image: images/chart-preview.png

In [23]:
import jax 
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from typing import Callable, List

## Linear Layer API


In [28]:
class Linear(object):
    """Fully connected layer computing ``w @ x + b`` for a single input vector.

    ``w`` has shape ``(output_dim, input_dim)`` and ``b`` has shape ``(output_dim,)``.
    To apply the layer to a batch, map it over the batch axis with ``jax.vmap``.
    """

    def __init__(self, input_dim: int, output_dim: int, key=None):
        """Create a layer with standard-normal random parameters.

        Args:
            input_dim: length of each input vector.
            output_dim: length of each output vector.
            key: optional ``jax.random.PRNGKey`` used for initialisation. Defaults to
                ``PRNGKey(1234)``, so layers built without an explicit key all start
                from identical values; pass distinct keys to get distinct layers.
        """
        self.output_dim, self.input_dim = output_dim, input_dim
        if key is None:
            key = jax.random.PRNGKey(1234)
        key_w, key_b = jax.random.split(key, 2)

        # TODO: better initialization, e.g. Xavier or Kaiming.
        self.w = jax.random.normal(key_w, (output_dim, input_dim))
        self.b = jax.random.normal(key_b, (output_dim,))

        # TODO: adopt fastai's approach of registering parameters into some
        # collection structure.

    def __call__(self, x):
        """Return ``w @ x + b`` for an input vector ``x`` of length ``input_dim``."""
        return jnp.dot(self.w, x) + self.b

    def __repr__(self):
        # Argument order mirrors the constructor: Linear(input_dim, output_dim).
        return f'{type(self).__name__}({self.input_dim}, {self.output_dim})'

    @property
    def params(self):
        """Current parameters as ``{'weights': w, 'bias': b}``."""
        return {'weights': self.w, 'bias': self.b}


In [30]:
# Build a layer mapping 3 input features to 2 outputs and display its parameters.
layer = Linear(input_dim=3, output_dim=2)
layer.params

{'weights': DeviceArray([[ 5.2179140e-01,  1.4660535e-03,  3.2984921e-01],
              [-3.9343223e-01, -1.9224546e+00, -1.3630803e-01]],            dtype=float32),
 'bias': DeviceArray([-1.7366334, -1.7102827], dtype=float32)}

In [18]:
# vmap maps the single-vector layer over axis 0: (10, 3) inputs -> (10, 2) outputs.
batched = jax.vmap(Linear(input_dim=3, output_dim=2), in_axes=0)
result = batched(np.random.randn(10, 3))
result

DeviceArray([[-2.3366485 , -2.3083436 ],
             [-1.4197106 , -1.2063366 ],
             [-1.7308872 ,  2.254344  ],
             [-0.02429783, -5.9196377 ],
             [-1.9397157 , -4.931116  ],
             [-1.9212854 , -0.96998787],
             [-1.9635744 ,  0.27272022],
             [-0.00939989, -4.705719  ],
             [-2.6962264 , -1.90379   ],
             [-2.1112676 , -1.3299353 ]], dtype=float32)

In [25]:
class Sequential(object):
    """Chain of callables applied one after another."""

    def __init__(self, layers: List[Callable]):
        # Layers are called in list order; each receives the previous one's output.
        self.layers = layers

    def __call__(self, x):
        """Feed ``x`` through every layer in order and return the final output.

        An empty ``layers`` list acts as the identity.
        """
        for layer in self.layers:
            x = layer(x)
        return x

    @property
    def params(self):
        """Parameters of each layer that exposes a ``params`` attribute, in layer order."""
        return [layer.params for layer in self.layers if hasattr(layer, 'params')]
    

## Loss Function

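The loss function takes the current parameters, the model, and the data. It computes the mean squared error between the model's predictions and the targets.
Because it is wrapped in `jax.value_and_grad`, it returns both the loss value and its gradient with respect to the parameters.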

In [102]:
@jax.value_and_grad
def mse_loss(params, model, X, y):
    """Mean squared error of ``model`` on ``(X, y)``, together with its gradient.

    ``jax.value_and_grad`` differentiates with respect to the first argument, so a
    call returns ``(loss, grads)`` where ``grads`` has the same structure as
    ``params``.

    Args:
        params: model parameters (any pytree); the differentiated argument.
        model: callable ``model(params, x)`` evaluated on one sample at a time.
        X: inputs, batched along axis 0.
        y: targets, compared element-wise with the stacked predictions.
    """
    # Share params across the batch (None) and map over the leading axis of X (0).
    per_sample_model = jax.vmap(model, in_axes=(None, 0))
    predictions = per_sample_model(params, X)
    squared_errors = (predictions - y) ** 2
    return jnp.mean(squared_errors)

## Training Loop

The training loop tunes the model parameters for a specified number of epochs. For now the parameters are updated once per epoch, using the gradient computed over the entire dataset (full-batch gradient descent).

In [103]:
def train(data, model, lr=1e-2, num_epochs=50, loss=mse_loss, init_params=None):
    """Fit ``model`` to ``data`` with full-batch gradient descent.

    Each epoch evaluates ``loss`` on the whole dataset and takes one gradient
    step. The loss recorded for an epoch is the value *before* that step.

    Args:
        data: ``(X, y)`` pair of inputs and targets.
        model: passed through unchanged to ``loss``; with the default
            ``mse_loss`` it is called as ``model(params, x)`` for each sample.
        lr: learning rate (step size).
        num_epochs: number of gradient steps to take.
        loss: callable ``loss(params, model, X, y) -> (value, grads)``, where
            ``grads`` has the same structure as ``params`` (as produced by a
            ``jax.value_and_grad``-wrapped loss such as ``mse_loss``).
        init_params: optional starting parameters (any pytree). If omitted, a
            zero float32 weight array of shape ``X.shape[1:]`` and a bias of 0.0
            are used.

    Returns:
        ``(loss_vals, params)``: a numpy array with one loss value per epoch,
        and the parameters after the final update.
    """
    X, y = data
    loss_vals = np.zeros(num_epochs)

    if init_params is None:
        # Size the weights from the per-sample feature shape instead of assuming
        # exactly 3 features.
        init_params = {'w': jnp.zeros(np.shape(X)[1:], dtype=jnp.float32), 'b': 0.0}
    params = init_params

    for epoch in range(num_epochs):
        loss_value, grads = loss(params, model, X, y)
        # tree_map builds a fresh pytree: caller-supplied init_params are never
        # mutated, and nested parameter structures are updated leaf by leaf.
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        loss_vals[epoch] = loss_value
    return loss_vals, params

## Performance Curve

Let's see the trend in the loss function.

## Conclusion

