In [1]:
import jax
from jax import nn, lax, random, vmap, jit, value_and_grad
from jax.nn import initializers
from jax.experimental import optimizers
from jax import numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.data import Dataset
import sys
from tensorflow.keras.utils import Progbar

In [2]:
# Load the MNIST digits from Keras as NumPy arrays. The images arrive without a
# channel axis (the training split is (60000, 28, 28)) because they are greyscale.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# The convolution layers need float inputs with an explicit channel axis, so cast
# to float32 and append a trailing axis: (N, 28, 28) -> (N, 28, 28, 1).
# NumPy is used on purpose instead of jax.numpy: these arrays are only handed to
# tf.data below, so staging the whole dataset on the JAX device and copying it
# back out again would waste time and (on an accelerator) device memory.
x_train = np.expand_dims(x_train.astype('float32'), -1)
x_test = np.expand_dims(x_test.astype('float32'), -1)

# Number of examples consumed per optimisation step.
batch_size = 64

# Wrap the arrays in tf.data pipelines. drop_remainder=True discards the final
# partial batch so that every batch holds exactly `batch_size` examples.
dataset = Dataset.from_tensor_slices((x_train, y_train))
train_dataset = dataset.batch(batch_size, drop_remainder=True)
test_dataset = Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(batch_size, drop_remainder=True)

# Number of full training batches per epoch (consistent with drop_remainder=True).
num_batches = x_train.shape[0] // batch_size

In [3]:
def get_weights(key=random.PRNGKey(100)):
    """Create the initial parameters of the convnet.

    Args:
        key: JAX PRNG key used to seed the weight initialisation. It is split
            into one independent subkey per layer, so the key itself is never
            consumed directly.

    Returns:
        A list of six ``(weights, bias)`` tuples, in forward order: three
        convolution layers followed by three dense layers. Weights use Glorot
        uniform initialisation, biases start at zero.
    """
    # JAX keys must not be reused: drawing every layer from the same key makes
    # the "random" tensors deterministic functions of each other. Derive one
    # subkey per weight tensor instead.
    (conv1_key, conv2_key, conv3_key,
     dense1_key, dense2_key, dense3_key) = random.split(key, 6)
    glorot = initializers.glorot_uniform()

    # Convolution kernels are laid out as (H, W, I, O): filter height, filter
    # width, input channels, output filters. MNIST is greyscale, so the first
    # layer has a single input channel; every later layer takes the previous
    # layer's filter count as its input channels.
    # Biases are shaped (1, 1, 1, O) so they broadcast over an NHWC activation.
    # The zeros initialiser ignores its key argument.
    conv1_weights = glorot(conv1_key, (3, 3, 1, 64))
    conv1_bias = initializers.zeros(conv1_key, (1, 1, 1, 64))
    conv2_weights = glorot(conv2_key, (3, 3, 64, 32))
    conv2_bias = initializers.zeros(conv2_key, (1, 1, 1, 32))
    conv3_weights = glorot(conv3_key, (3, 3, 32, 16))
    conv3_bias = initializers.zeros(conv3_key, (1, 1, 1, 16))

    # Dense layers. The first one consumes the flattened output of the last
    # conv layer, (22, 22, 16): each 3x3 unpadded convolution trims the 28x28
    # image by 2 pixels per side length (28 -> 26 -> 24 -> 22).
    dense1_weights = glorot(dense1_key, (22 * 22 * 16, 512))
    dense1_bias = initializers.zeros(dense1_key, (512,))
    dense2_weights = glorot(dense2_key, (512, 64))
    dense2_bias = initializers.zeros(dense2_key, (64,))
    # Final layer: one output per digit class.
    dense3_weights = glorot(dense3_key, (64, 10))
    dense3_bias = initializers.zeros(dense3_key, (10,))

    return [
        (conv1_weights, conv1_bias),
        (conv2_weights, conv2_bias),
        (conv3_weights, conv3_bias),
        (dense1_weights, dense1_bias),
        (dense2_weights, dense2_bias),
        (dense3_weights, dense3_bias),
    ]

In [4]:
# Build the initial model parameters using get_weights' default PRNG key.
params = get_weights()

In [5]:
def forward_conv(params, inputs):
    """Apply one convolution + bias + ReLU layer to a single example.

    Args:
        params: ``(kernel, bias)`` tuple. ``kernel`` is laid out as HWIO
            (filter height, filter width, input channels, output filters) and
            ``bias`` must broadcast against the NHWC output.
        inputs: one unbatched image laid out as HWC (height, width, channels).

    Returns:
        The activated feature map for that image, laid out as HWC.
    """
    kernel, bias = params
    # lax.conv_general_dilated works on batches, so give the single image a
    # leading batch axis of size 1: (H, W, C) -> (1, H, W, C).
    batched = jnp.expand_dims(inputs, 0)
    # Dimension numbers tell the convolution how to read each operand:
    #   'NHWC' - input  is (batch, height, width, channels)
    #   'HWIO' - kernel is (height, width, input channels, output filters)
    #   'NHWC' - output is produced in the same layout as the input
    dn = lax.conv_dimension_numbers(batched.shape, kernel.shape, ('NHWC', 'HWIO', 'NHWC'))
    # Stride 1 with no padding: each spatial dimension shrinks by kernel size - 1.
    out = lax.conv_general_dilated(batched, kernel, window_strides=(1, 1), dimension_numbers=dn, padding='valid')
    out = nn.relu(jnp.add(out, bias))
    # Remove only the batch axis added above. A bare squeeze() would also drop
    # any other size-1 axis (a single output filter, or a 1x1 spatial result)
    # and silently change the rank seen by the next layer.
    return jnp.squeeze(out, axis=0)

In [6]:
# vmap turns a function written for ONE example into a function that handles a
# whole batch at once. Conceptually it matches stacking the per-example results,
#   jnp.stack([forward_conv(params, x) for x in batch])
# but without a Python loop.
#
# in_axes says which axis of each argument is the batch axis:
#   None -> the layer params are shared by every example, so they are not mapped
#   0    -> the inputs are mapped over their first (batch) axis
# out_axes=0 puts the batch axis first in the result as well.
#
# jit = just-in-time compilation, which makes the code run faster.
# More background: https://www.youtube.com/watch?v=svJerixawV0
batch_conv = jit(
    vmap(forward_conv, in_axes=(None, 0), out_axes=0)
)

In [7]:
def forward_dense(params, inputs):
    """Fully connected layer without activation: ``dot(inputs, weights) + bias``.

    Args:
        params: ``(weights, bias)`` tuple for this layer.
        inputs: activations whose last axis matches the first axis of ``weights``.

    Returns:
        The affine transform of ``inputs``; the caller applies any activation.
    """
    kernel, offset = params
    projected = jnp.dot(inputs, kernel)
    return projected + offset

In [8]:
# Batched, JIT-compiled dense layer: the layer params are shared (None) and the
# inputs are mapped over their first axis (0), one example at a time.
batch_dense = jit(
    vmap(forward_dense, in_axes=(None, 0), out_axes=0)
)

In [9]:
def forward_pass(params, inputs):
    """Run the whole network on a single example.

    Args:
        params: list of ``(weights, bias)`` tuples. The first three entries are
            convolution layers, the remaining ones are dense layers, and the
            last entry is the output layer.
        inputs: one unbatched image laid out as HWC.

    Returns:
        A 1-D array of class probabilities (softmax over the output layer).
    """
    # batch_conv expects a leading batch axis, so wrap the example in a batch of 1.
    out = jnp.expand_dims(inputs, 0)
    # Convolution stack (ReLU is applied inside forward_conv).
    for param in params[:3]:
        out = batch_conv(param, out)
    # Flatten every feature map into one vector per example.
    out = out.reshape((out.shape[0], -1))

    # Hidden dense layers use ReLU; the output layer is handled separately
    # because it needs a different activation.
    for param in params[3:-1]:
        out = nn.relu(forward_dense(param, out))
    logits = forward_dense(params[-1], out)
    # Drop only the batch-of-1 axis added at the top, then turn the logits
    # into probabilities.
    return nn.softmax(jnp.squeeze(logits, axis=0))

In [10]:
# Batched, JIT-compiled model: forward_pass is mapped over the first axis of the
# inputs while every example shares the same params, giving one row of class
# probabilities per example.
forward_batch = jit(
    vmap(forward_pass, in_axes=(None, 0), out_axes=0)
)

In [11]:
def to_categorical(y, num_classes):
    """One-hot encode integer labels as a boolean matrix.

    Args:
        y: 1-D array of integer class labels.
        num_classes: number of columns in the result.

    Returns:
        Boolean array of shape ``(len(y), num_classes)`` where entry ``[i, c]``
        is True exactly when ``y[i] == c``.
    """
    class_ids = np.arange(num_classes)
    # Turn the labels into a column so the comparison broadcasts to (N, num_classes).
    labels_column = y[:, np.newaxis]
    return labels_column == class_ids

In [12]:
def to_numbers(y):
    """Convert categorical (one-hot or probability) rows back into class indices.

    The result is the 0-based position of the largest entry along the last
    axis, for example ``[0, 0, 1, 0, 0, 0, 0, 0, 0, 0] -> 2``.
    """
    class_axis = -1
    return jnp.argmax(y, axis=class_axis)

In [13]:
#check this link if you are interested in more information on the negative log likelihood loss
#https://ljvmiranda921.github.io/notebook/2017/08/13/softmax-and-the-negative-log-likelihood/
def NegativeLogLikelihood(params, inputs, targets):
    """Mean negative log-likelihood of the target classes for one batch.

    Args:
        params: model parameters, passed through to ``forward_batch``.
        inputs: batch of input images.
        targets: boolean one-hot mask with the same shape as the predictions,
            True at each example's correct class.

    Returns:
        Scalar loss: the mean of ``-log(p)`` over the entries selected by
        ``targets``.
    """
    preds = forward_batch(params, inputs)
    # A softmax output can underflow to exactly 0, and log(0) would turn the
    # loss into inf and the gradients into NaN. Floor the probabilities at the
    # smallest normal float of their dtype before taking the log.
    floor = jnp.finfo(preds.dtype).tiny
    log_probs = jnp.log(jnp.maximum(preds, floor))
    # Select the target entries with a mask instead of boolean indexing:
    # preds[targets] has a data-dependent shape, which cannot be traced under
    # jit. The masked sum divided by the number of selected entries is the
    # same mean.
    mask = jnp.asarray(targets, dtype=bool)
    selected = jnp.where(mask, log_probs, 0.0)
    return -jnp.sum(selected) / jnp.sum(mask)

In [14]:
# Optimizer configuration.
learning_rate = 1e-4

# A JAX optimizer is returned as a triple of functions:
#   init_fn    - builds the initial optimizer state from the model params
#   update_fn  - produces a new optimizer state from a step number, the
#                gradients and the current state
#   get_params - extracts the current model params from an optimizer state
init_fn, update_fn, get_params = optimizers.adam(learning_rate)

# Initialise the optimizer state from the freshly created model params.
optimizer_state = init_fn(params)

# Training length and number of target classes.
num_epochs, num_classes = 2, 10

In [15]:
def step(params, x, y, optimizer_state):
    #here we update the params
    #first we get the gradients using the loss with respect to the current params
    #value_and_grad function do 2 things, first it evaluates the function normally
    #then it gets the gradients of the function
    #we need to pass the model params to the loss because we need JAX to keep track of what happens to the params
    #to get the gradients
    value, grads = value_and_grad(NegativeLogLikelihood)(params, x, y)
    #then we pass the grads and the current optimizer state to return new optimizer state
    optimizer_state = update_fn(0, grads, optimizer_state)
    #finally we get the params from the optimizer state and return it with the optimizer state and the loss value
    return get_params(optimizer_state), optimizer_state, value

In [16]:
#here I used progress bar from tensorflow to show the training progress
#also we add the metrics we want to track
prgbar = Progbar(num_batches, stateful_metrics=['loss', 'Remaining Epochs']) 

In [17]:
#start the training by getting the params from the current optimizer state
params = get_params(optimizer_state)
#list to store the losses
losses = []
for epoch in range(num_epochs):
    #sums up the epoch loss then resets to zero
    epoch_loss = 0
    #keeps track of the processed batches
    finished_batches = 0
    for x, y in train_dataset:
        #convert the tensors from Tensorflow Eager tensors into numpy tensors
        x = x.numpy()
        y = y.numpy()
        #convert targets into categorical
        y_cat = to_categorical(y, num_classes)
        #updates the params by passing the current params, x, y and current optimizer state
        params, optimizer_state, loss_result = step(params, x, y_cat, optimizer_state)
        #updates the epoch loss
        epoch_loss += loss_result
        #update the finished batches
        finished_batches += 1
        #Here we update the progress bar after every update
        #we add the loss name which was specified when we defined the progress bar
        #and the value for that metric
        prgbar.update(finished_batches, values=[('loss', loss_result), ('Remaining Epochs', num_epochs-epoch)])
    losses.append(epoch_loss)



In [18]:
# Take the first batch of the test pipeline and convert both TensorFlow tensors
# (images and labels) to NumPy arrays in one go.
sample_test_x, sample_test_y = (
    tensor.numpy() for tensor in next(iter(test_dataset.take(1)))
)

In [19]:
# Class probabilities for every example of the sampled test batch.
preds = forward_batch(params, sample_test_x)

In [20]:
# Predicted class per example: the index of the highest probability.
preds_as_nums = to_numbers(preds)

In [21]:
# Fraction of the sampled batch whose predicted class matches the true label.
correct_predictions = preds_as_nums == sample_test_y
accuracy = jnp.sum(correct_predictions) / sample_test_y.shape[0]

In [22]:
# Report the accuracy measured on this single test batch only.
print(f'Batch 1 Accuracy: {accuracy}')

Batch 1 Accuracy: 0.984375
