<a href="https://colab.research.google.com/github/deterministic-algorithms-lab/Jax-Journey/blob/main/jax_basic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

A notebook for this [blog](https://roberttlange.github.io/posts/2020/03/blog-post-10/) with additional notes. It implements an MLP and a CNN in ```JAX```. Reading the blog post side by side is recommended.

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as onp
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax import random

# Generate a key, which is used to generate random numbers
key = random.PRNGKey(1)                                                         #With the default PRNG, a key is an array of shape (2,)

In [None]:
# Generate a random matrix
x = random.uniform(key, (1000, 1000))
# Compare running times of 3 different matrix multiplications
%time y = onp.dot(x, x)
%time y = np.dot(x, x); print(y)
%time y = np.dot(x, x).block_until_ready()

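The differences in the timings above are due to [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html): ```JAX``` returns control to Python before the computation has finished, so a fair timing has to force the result, for example by printing it or by calling ```block_until_ready()```.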
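
The effect can be isolated with a small sketch like the one below. The names ```demo_x``` and ```timed``` are illustrative and not part of the blog, and the size of the gap depends on the backend (it is usually much more visible on a GPU than on a CPU).

```python
import time

import jax.numpy as np
from jax import random

demo_key = random.PRNGKey(0)                       # illustrative key, separate from the notebook's `key`
demo_x = random.uniform(demo_key, (500, 500))

def timed(fn):
    """Return (result, elapsed seconds) for a zero-argument callable."""
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start

# Warm-up call, so that one-off compilation cost is not part of the timings
np.dot(demo_x, demo_x).block_until_ready()

# Returns once the work is queued; the result may not be computed yet
y_async, t_dispatch = timed(lambda: np.dot(demo_x, demo_x))

# Waits until the device has finished, so this covers the whole computation
y_ready, t_blocked = timed(lambda: np.dot(demo_x, demo_x).block_until_ready())

print(f"dispatch only: {t_dispatch:.6f}s | blocked: {t_blocked:.6f}s")
```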

In [None]:
def ReLU(x):
    """ Rectified Linear Unit (ReLU) activation function """
    return np.maximum(0, x)

jit_ReLU = jit(ReLU)                                                            

```jit()``` speeds up a simple Python function written with ```jax.numpy```. Normally each operation has its own kernel, and the kernels are dispatched to the GPU one by one. For a sequence of operations, the ```@jit``` decorator or a call to ```jit()``` compiles them together using XLA.
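
As a minimal sketch of the idea, the hypothetical function below chains three elementwise operations (scale, shift, clip). Un-jitted, each one is dispatched separately; the jitted version is traced once and handed to XLA as a whole. Both must return the same values.

```python
import jax.numpy as np
from jax import jit

def scaled_shifted_relu(v):
    """Three elementwise steps: scale by 2, shift by -1, clip at zero."""
    return np.maximum(0.0, 2.0 * v - 1.0)

fused = jit(scaled_shifted_relu)        # traced and compiled on the first call

demo_v = np.linspace(-1.0, 1.0, 5)      # [-1, -0.5, 0, 0.5, 1]
print(scaled_shifted_relu(demo_v))      # op-by-op execution
print(fused(demo_v))                    # compiled execution, same values
```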

In [None]:
%time out = ReLU(x).block_until_ready()

# Call jitted version to compile for evaluation time!
%time jit_ReLU(x).block_until_ready()                                           #First time call will cause compilation, and may take longer.
%time out = jit_ReLU(x).block_until_ready()

CPU times: user 60.6 ms, sys: 0 ns, total: 60.6 ms
Wall time: 61.1 ms
CPU times: user 26 ms, sys: 841 µs, total: 26.8 ms
Wall time: 25.8 ms
CPU times: user 1.21 ms, sys: 0 ns, total: 1.21 ms
Wall time: 696 µs


The ```grad()``` function takes as input a function ```f``` and returns its derivative, the function ```f'```. Since ```f'``` is a new function, it should be ```jit()```-ted again.

In [None]:
def FiniteDiffGrad(x):
    """ Compute the finite difference derivative approx for the ReLU"""
    return np.array((ReLU(x + 1e-3) - ReLU(x - 1e-3)) / (2 * 1e-3))

# Compare the Jax gradient with a finite difference approximation
print("Jax Grad: ", jit(grad(jit(ReLU)))(2.))
print("FD Gradient:", FiniteDiffGrad(2.))

Jax Grad:  1.0
FD Gradient: 0.99998707
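
Because ```grad()``` returns an ordinary function, it composes with ```jit()``` and with itself. The sketch below uses a hypothetical ```cube``` function whose derivatives are known analytically, so the results are easy to check by hand.

```python
from jax import grad, jit

def cube(v):
    return v ** 3

d_cube = jit(grad(cube))            # first derivative:  3 * v**2
dd_cube = jit(grad(grad(cube)))     # second derivative: 6 * v

print(d_cube(3.0))     # 3 * 3**2 = 27
print(dd_cube(3.0))    # 6 * 3    = 18
```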


**vmap** makes batching easier than ever. While in PyTorch one always has to keep track of the dimension over which computations are performed, ```vmap``` lets you write the computation for a single sample and then wrap it to make it batch compatible.

In [None]:
batch_dim = 32
feature_dim = 100
hidden_dim = 512

# Generate a batch of vectors to process
X = random.normal(key, (batch_dim, feature_dim))

# Generate Gaussian weights and biases
params = [random.normal(key, (hidden_dim, feature_dim)),
          random.normal(key, (hidden_dim, ))]

def relu_layer(params, x):
    """ Simple ReLu layer for single sample """
    return ReLU(np.dot(params[0], x) + params[1])

def batch_version_relu_layer(params, x):
    """ Error prone batch version """
    return ReLU(np.dot(x, params[0].T) + params[1])                             #Use the argument x, not the global X

def vmap_relu_layer(params, x):
    """ vmap version of the ReLU layer """
    return jit(vmap(relu_layer, in_axes=(None, 0), out_axes=0))(params, x)      #Build the batched function, then apply it

out = np.stack([relu_layer(params, X[i, :]) for i in range(X.shape[0])])
out = batch_version_relu_layer(params, X)
out = vmap_relu_layer(params, X)

```vmap``` wraps the ```relu_layer``` function and takes as input the axes over which to batch each argument (```in_axes```). In our case the first argument of ```relu_layer``` is the parameters, which are shared by the entire batch [```None```]. The second argument is the feature vector, ```x```. We have stacked the vectors into a matrix, so our input has dimensions ```(batch_dim, feature_dim)```, and we tell ```vmap``` that the batch dimension is axis ```0``` so that it can parallelize the computation properly. ```out_axes``` then specifies along which axis the individual samples' outputs are stacked. To keep things consistent, we keep the first dimension as the batch dimension.
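
The role of ```in_axes``` and ```out_axes``` is easiest to see from the output shapes. The following is a small sketch with made-up sizes; ```affine```, ```demo_w``` and ```demo_batch``` are illustrative names.

```python
import jax.numpy as np
from jax import random, vmap

def affine(w, v):
    """Single-sample map: (out_dim, in_dim) times (in_dim,) gives (out_dim,)."""
    return np.dot(w, v)

k_w, k_v = random.split(random.PRNGKey(0))
demo_w = random.normal(k_w, (4, 3))        # shared weights, not batched (in_axes entry None)
demo_batch = random.normal(k_v, (8, 3))    # 8 samples stacked along axis 0

rows = vmap(affine, in_axes=(None, 0), out_axes=0)(demo_w, demo_batch)
cols = vmap(affine, in_axes=(None, 0), out_axes=1)(demo_w, demo_batch)

print(rows.shape)   # (8, 4): batch stays on the first axis
print(cols.shape)   # (4, 8): batch moved to the second axis
```

Both results contain the same numbers; ```out_axes``` only decides where the batch axis ends up.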

## MLP

In [None]:
from jax.scipy.special import logsumexp
from jax.example_libraries import optimizers                                    #Called jax.experimental.optimizers in older JAX releases

import torch
from torchvision import datasets, transforms

import time

In [None]:
batch_size = 100

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

In [None]:
print(key)
split = random.split(key, 5)                                                    #Can be split into any number of keys; the new keys are stacked along a new leading axis
print(split)
print(random.split(split[0]))                                                   #Only single keys, i.e. arrays of shape (2,), can be split

[0 1]
[[3243370355 1344208528]
 [ 532076793 2354449600]
 [1813813011 1313272271]
 [3522235465 4107438537]
 [1531693580 2391939978]]
[[1467608531 2825924092]
 [ 757006082 1868645737]]
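
Keys are explicit state: drawing twice with the same key gives the same numbers, which is why a key is split before each independent use. A minimal sketch (the key value 42 and the variable names are arbitrary):

```python
import jax.numpy as np
from jax import random

demo_key = random.PRNGKey(42)

a = random.normal(demo_key, (3,))
b = random.normal(demo_key, (3,))      # same key, same shape: identical draw

k1, k2 = random.split(demo_key)        # two fresh keys derived from demo_key
c = random.normal(k1, (3,))
d = random.normal(k2, (3,))            # different keys: different draws

print(bool(np.array_equal(a, b)))
print(bool(np.array_equal(c, d)))
```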


```JAX``` transformations such as ```grad()```, ```jit()``` and ```vmap()``` operate on pure functions, so instead of writing a stateful module class as in PyTorch, we write one function for initialization and one for the forward pass.

In [None]:
def initialize_mlp(sizes, key):
    """ Initialize the weights and biases of every layer of a fully connected network """

    keys = random.split(key, len(sizes))
    
    # Initialize a single layer with Gaussian weights -  helper function
    def initialize_layer(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
    
    return [initialize_layer(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]


layer_sizes = [784, 512, 512, 10]

# Return a list of (weights, biases) tuples, one per layer
params = initialize_mlp(layer_sizes, key)


The forward pass function should take as input all the parameters (```params```) of the model and the input (```in_array```) to it. The parameters are usually collected in a single container, here a list of ```(weights, biases)``` tuples (a dictionary works as well), so that the function can access them easily.

In [None]:
def forward_pass(params, in_array):
    """ 
    Compute the forward pass for each example individually.
    Inputs :  params: List of tuples. Tuples must be as required by relu_layer.
              in_array: Input array as needed by relu_layer.
    """
    activations = in_array

    # Loop over the ReLU hidden layers
    for w, b in params[:-1]:
        activations = relu_layer([w, b], activations)

    # Perform the final transformation to logits
    final_w, final_b = params[-1]
    logits = np.dot(final_w, activations) + final_b                             #Any jax.numpy operation can be used anywhere inside these functions.

    return logits - logsumexp(logits)                                           #Log-softmax: subtracting logsumexp turns the logits into log-probabilities.

# Make a batched version of the `forward_pass` function
batch_forward = vmap(forward_pass, in_axes=(None, 0), out_axes=0)
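
The last line of ```forward_pass``` returns log-probabilities rather than probabilities. The following sketch, with arbitrary example logits, checks that exponentiating ```logits - logsumexp(logits)``` gives values that sum to one and that it agrees with ```jax.nn.log_softmax```.

```python
import jax
import jax.numpy as np
from jax.scipy.special import logsumexp

demo_logits = np.array([2.0, 1.0, 0.1])              # arbitrary example values
log_probs = demo_logits - logsumexp(demo_logits)     # same expression as in forward_pass

probs = np.exp(log_probs)
print(probs)          # class probabilities
print(probs.sum())    # sums to 1 up to floating point error
print(np.allclose(log_probs, jax.nn.log_softmax(demo_logits)))
```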


In [None]:
def one_hot(x, k, dtype=np.float32):
    """Create a one-hot encoding of x of size k """
    return np.array(x[:, None] == np.arange(k), dtype)

def loss(params, in_arrays, targets):
    """ 
    Compute the multi-class cross-entropy loss.

    Inputs : params: list of model parameters as accepted by forward_pass
             in_arrays: input_array as accepted by forward_pass
             targets: jax.numpy array containing one hot targets
    """
    preds = batch_forward(params, in_arrays)
    return -np.sum(preds * targets)                                             #Cross-entropy loss, summed over the batch. Divide by the batch size (in_arrays.shape[0]) to average.

def accuracy(params, data_loader):
    """ Compute the accuracy for a provided dataloader """
    acc_total = 0
    num_classes = 10

    for batch_idx, (data, target) in enumerate(data_loader):
        images = np.array(data).reshape(data.size(0), 28*28)                    #Convert PyTorch tensors into jax.numpy arrays
        targets = one_hot(np.array(target), num_classes)

        target_class = np.argmax(targets, axis=1)
        predicted_class = np.argmax(batch_forward(params, images), axis=1)
        acc_total += np.sum(predicted_class == target_class)
    return acc_total/len(data_loader.dataset)

In [None]:
x = np.arange(3)
print(x.shape)
print(x[None, :].shape)
print(x[:,None].shape)
print(x+x[None,:])
print(x[None,:]+x[:,None])
print(x+x[:,None])

(3,)
(1, 3)
(3, 1)
[[0 2 4]]
[[0 1 2]
 [1 2 3]
 [2 3 4]]
[[0 1 2]
 [1 2 3]
 [2 3 4]]
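
The same broadcasting is what makes ```one_hot``` work: a column of labels compared against a row of class indices gives one row per label. The sketch below uses three made-up labels and then evaluates the summed cross-entropy for a uniform prediction, where each sample contributes ```log(3)```.

```python
import jax.numpy as np

labels = np.array([2, 0, 1])      # made-up class labels
num = 3                           # number of classes

# (3, 1) compared with (3,) broadcasts to (3, 3), one row per label
encoded = np.array(labels[:, None] == np.arange(num), np.float32)
print(encoded)                    # rows are one-hot vectors for classes 2, 0, 1

# Uniform log-probabilities: every class gets probability 1/3
uniform_log_probs = np.full((3, num), -np.log(3.0))
summed_loss = -np.sum(uniform_log_probs * encoded)
print(summed_loss)                # 3 * log(3), one log(3) term per sample
```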


```value_and_grad(fn)``` returns a function that takes the same arguments (```x```) as ```fn``` and returns, as a tuple, both the value ```fn(x)``` and the gradient ```fn'(x)``` with respect to the first argument.
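
A small worked case, using a hypothetical ```squared_error``` function with hand-checkable numbers: the prediction is ```2*1 + 3*(-1) = -1```, the value is ```(-1 - 1)**2 = 4```, and the gradient with respect to the first argument is ```2 * (-2) * inputs```.

```python
import jax.numpy as np
from jax import value_and_grad

def squared_error(w, inputs, target):
    """Squared error of a linear prediction; differentiated w.r.t. w (first argument)."""
    pred = np.dot(inputs, w)
    return (pred - target) ** 2

demo_w = np.array([1.0, -1.0])
demo_inputs = np.array([2.0, 3.0])

val, g = value_and_grad(squared_error)(demo_w, demo_inputs, 1.0)
print(val)   # 4.0
print(g)     # [-8., -12.]
```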

The optimizer below keeps its data (the parameters plus internal quantities such as Adam's moment estimates) in ```opt_state```. Its behaviour is defined by three functions: ```opt_init()```, ```opt_update()``` and ```get_params()```. Notice that there is no class. It can therefore be convenient to bundle all four in a dictionary.
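
To make the three-function pattern concrete, here is a hand-rolled sketch of plain gradient descent in the same shape. ```toy_sgd``` is a hypothetical helper written for illustration, not the library's implementation. Its state is just the parameters, whereas the real Adam state also carries moment estimates.

```python
import jax.numpy as np
from jax import grad

def toy_sgd(step_size):
    """Hypothetical optimizer in (init, update, get_params) form."""
    def init(params):
        return params                          # state is just the parameters
    def update(step, grads, state):
        return state - step_size * grads       # `step` is unused by plain SGD
    def get_params(state):
        return state
    return init, update, get_params

def quadratic(w):
    """Toy objective with its minimum at w = 3 in every coordinate."""
    return np.sum((w - 3.0) ** 2)

toy_init, toy_update, toy_get_params = toy_sgd(0.1)
toy_state = toy_init(np.zeros(2))
for step in range(100):
    g = grad(quadratic)(toy_get_params(toy_state))
    toy_state = toy_update(step, g, toy_state)

print(toy_get_params(toy_state))   # close to [3., 3.]
```

All hyperparameters live in the closure created by ```toy_sgd```, and everything that changes during training is passed in and out explicitly as state.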

In [None]:
@jit
def update(params, x, y, opt_state):
    """ 
    Compute the gradient for a batch and update the parameters

    Inputs :  params:     list of model parameters as accepted by loss function (in turn by forward_pass)
              x:          input as accepted by loss_function(in turn by forward_pass)
              y:          jax.numpy array containing one hot targets (as required by loss function)
              opt_state:  as required by opt_update
    Returns : 
              updated parameters, current optimizer state, computed value
    """
    value, grads = value_and_grad(loss)(params, x, y)
    opt_state = opt_update(0, grads, opt_state)                                 #opt_update is a global name that is looked up when update() is first traced, so it only has to exist before the first call.
    return get_params(opt_state), opt_state, value                              #The first argument to opt_update is the optimizer step number. It is fixed at 0 here, although Adam's bias correction expects a running step count.

# Defining an optimizer in Jax
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)                                                    #The initial opt_state is always built from the trainable parameters via opt_init.

num_epochs = 10
num_classes = 10

Notice how, in all the above code, each function makes sure that its inputs fit the functions it calls. This leads to a hierarchical structure, in stark contrast to the step-wise structure of typical PyTorch code.

In [None]:
def run_mnist_training_loop(num_epochs, opt_state, net_type="MLP"):
    """ Implements a learning loop over epochs. """

    # Initialize placeholder for logging
    log_acc_train, log_acc_test, train_loss = [], [], []

    # Get the initial set of parameters
    params = get_params(opt_state)                                              #Assumes all parameters are updatable. Otherwise, pass the parameters to this function as an extra argument.

    # Get initial accuracy after random init
    train_acc = accuracy(params, train_loader)
    test_acc = accuracy(params, test_loader)
    log_acc_train.append(train_acc)
    log_acc_test.append(test_acc)

    # Loop over the training epochs
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch_idx, (data, target) in enumerate(train_loader):
            if net_type == "MLP":
                # Flatten the image into 784-sized vectors for the MLP
                x = np.array(data).reshape(data.size(0), 28*28)
            elif net_type == "CNN":
                # No flattening of the input required for the CNN
                x = np.array(data)
            y = one_hot(np.array(target), num_classes)
            params, opt_state, batch_loss = update(params, x, y, opt_state)     #Named batch_loss to avoid confusion with the loss() function
            train_loss.append(batch_loss)

        epoch_time = time.time() - start_time
        train_acc = accuracy(params, train_loader)
        test_acc = accuracy(params, test_loader)
        log_acc_train.append(train_acc)
        log_acc_test.append(test_acc)
        print("Epoch {} | T: {:0.2f} | Train A: {:0.3f} | Test A: {:0.3f}".format(epoch+1, epoch_time,
                                                                    train_acc, test_acc))

    return train_loss, log_acc_train, log_acc_test


train_loss, train_log, test_log = run_mnist_training_loop(num_epochs,
                                                          opt_state,
                                                          net_type="MLP")


Epoch 1 | T: 16.56 | Train A: 0.973 | Test A: 0.968
Epoch 2 | T: 15.61 | Train A: 0.984 | Test A: 0.974
Epoch 3 | T: 15.55 | Train A: 0.990 | Test A: 0.979
Epoch 4 | T: 15.63 | Train A: 0.993 | Test A: 0.981
Epoch 5 | T: 15.41 | Train A: 0.992 | Test A: 0.978
Epoch 6 | T: 15.13 | Train A: 0.997 | Test A: 0.982
Epoch 7 | T: 15.27 | Train A: 0.996 | Test A: 0.980
Epoch 8 | T: 15.93 | Train A: 0.996 | Test A: 0.980
Epoch 9 | T: 15.80 | Train A: 0.995 | Test A: 0.981
Epoch 10 | T: 15.36 | Train A: 0.997 | Test A: 0.982


## CNN

In [None]:
# stax was called jax.experimental.stax in older JAX releases
from jax.example_libraries import stax
from jax.example_libraries.stax import (BatchNorm, Conv, Dense, Flatten,
                                        Relu, LogSoftmax)

The ```init_fun()``` below takes the ```key``` and the input shape as its arguments. It returns the output shape and the randomly initialized parameters.

The ```conv_net()``` function takes ```params``` and an input of the shape passed as the second argument of ```init_fun()```, and returns the result of applying the layers listed in ```stax.serial()``` in order. Since it is an ordinary function computing ```f(x)```, wrapping it (or a loss built on it) in ```grad()``` quickly gives ```f'(x)```.
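
The ```(init_fun, apply_fun)``` pair can be tried on a much smaller network first. The sketch below uses a made-up two-layer dense model rather than the CNN; all names starting with ```toy_``` or ```demo_``` are illustrative, and the import fallback is there only because the module path differs between JAX releases.

```python
import jax.numpy as np
from jax import grad, random

try:
    from jax.example_libraries import stax     # newer JAX releases
except ImportError:
    from jax.experimental import stax          # older JAX releases

toy_init, toy_apply = stax.serial(stax.Dense(4), stax.Relu,
                                  stax.Dense(2), stax.LogSoftmax)

# -1 marks the batch dimension as unspecified; 3 is the input feature size
out_shape, toy_params = toy_init(random.PRNGKey(0), (-1, 3))
print(out_shape)                               # (-1, 2)

demo_in = np.ones((5, 3))
toy_log_probs = toy_apply(toy_params, demo_in)
print(toy_log_probs.shape)                     # (5, 2)

def toy_loss(p, inputs):
    """Negative log-probability of class 0, summed over the batch."""
    return -np.sum(toy_apply(p, inputs)[:, 0])

# Gradients come back with the same nested structure as toy_params
toy_grads = grad(toy_loss)(toy_params, demo_in)
print(toy_grads[0][0].shape)                   # weights of the first Dense layer: (3, 4)
```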

In [None]:
init_fun, conv_net = stax.serial(Conv(32, (5, 5), (2, 2), padding="SAME"),      #First argument is number of out channels, second is filter shape, third stride.
                                 BatchNorm(), Relu,
                                 Conv(32, (5, 5), (2, 2), padding="SAME"),
                                 BatchNorm(), Relu,
                                 Conv(10, (3, 3), (2, 2), padding="SAME"),
                                 BatchNorm(), Relu,
                                 Conv(10, (3, 3), (2, 2), padding="SAME"), Relu,
                                 Flatten,
                                 Dense(num_classes),                            #Only the output size needs to be specified; the input size is inferred by init_fun.
                                 LogSoftmax)

output_shape, params = init_fun(key, (batch_size, 1, 28, 28))

Various types of initializations can also be specified for each layer. See [here](https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html#serial) for default initializations of each layer. 

In [None]:
def accuracy(params, data_loader):
    """ Compute the accuracy for the CNN case (no flattening of input)"""
    acc_total = 0
    for batch_idx, (data, target) in enumerate(data_loader):
        images = np.array(data)
        targets = one_hot(np.array(target), num_classes)

        target_class = np.argmax(targets, axis=1)
        predicted_class = np.argmax(conv_net(params, images), axis=1)
        acc_total += np.sum(predicted_class == target_class)
    return acc_total/len(data_loader.dataset)

def loss(params, images, targets):
    preds = conv_net(params, images)
    return -np.sum(preds * targets)

In [None]:
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)
num_epochs = 10

train_loss, train_log, test_log = run_mnist_training_loop(num_epochs,
                                                          opt_state,
                                                          net_type="CNN")