# Implementation of a two-hidden-layer neural network classifier from scratch in JAX
* Two hidden layers here means (input - hidden1 - hidden2 - output).
* You must not use flax, optax, or any other library for this task.
* Use MNIST dataset with 80:20 train:test split.
* Manually optimize the number of neurons in hidden layers.
* Use gradient descent from scratch to optimize your network. You should use the Pytree
  concept of JAX to do this elegantly.
* Plot the loss vs. iterations curve with matplotlib.
* Evaluate the model on test data with various classification metrics and briefly discuss
  their implications.

# Introduction 

* What is JAX?
  - JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
  - JAX can automatically differentiate native Python and NumPy code.
  - It can differentiate through a large subset of Python's features, including loops, ifs, recursion, and closures, and it can even take derivatives of derivatives of derivatives.
  
* What is a neural network?
  - Neural networks are loosely inspired by the behavior of the human brain, allowing computer programs to recognize patterns and solve common problems in the fields of AI, machine learning, and deep learning.
  - Neural networks, also known as artificial neural networks, are a subset of machine learning and are at the heart of deep learning algorithms.
    
* Details of the dataset used
  - Dataset name: MNIST
  - MNIST stands for Modified National Institute of Standards and Technology, and is a dataset of handwritten digits. It is one of the most researched datasets in machine learning and is used to benchmark handwritten digit classification. Its size (70,000 labelled images) makes it large enough to train a small network well while remaining quick to experiment with.
  - Information:
    - name: MNIST
    - length: 70000
  - Input summary:
    - shape: (28, 28, 1)
    - range: (0.0, 1.0)
  - Target summary:
    - shape: (10,)
    - range: (0.0, 1.0)
    
* Pipeline 
  - We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` and Keras to load images and labels, and scikit-learn to make the 80:20 split.

  - Of course, we can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.
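
As a small illustration of the automatic-differentiation claims in the introduction, the sketch below differentiates a function that uses an ordinary Python loop, then differentiates the result twice more. The function `f` and the evaluation point are made up for this example and are not part of the classifier.

```python
import jax.numpy as jnp
from jax import grad

def f(x):
    # Ordinary Python control flow: builds x + x**2 + x**3 in a loop.
    total = 0.0
    for power in range(1, 4):
        total = total + x ** power
    return total

df = grad(f)      # analytically: 1 + 2x + 3x^2
d2f = grad(df)    # analytically: 2 + 6x
d3f = grad(d2f)   # analytically: 6

x = 2.0
first, second, third = df(x), d2f(x), d3f(x)
print(first, second, third)
```

Each call to `grad` returns a new Python function, which is why the transformations can be stacked.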

In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Hyperparameters
* Let's get a few bookkeeping items out of the way.

In [None]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

# Layer widths: 784 input pixels, two hidden layers of 512 units, 10 output classes
layer_sizes = [784, 512, 512, 10]
step_size = 0.01   # learning rate for gradient descent
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))
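
The `params` object built above is a list of `(weights, biases)` tuples, which JAX treats as a pytree. A minimal sketch of inspecting that structure with `jax.tree_util` follows; the two initializer functions are repeated only so the snippet runs on its own, and the names `shapes` and `num_params` are introduced here for illustration.

```python
import jax
from jax import random

def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

params = init_network_params([784, 512, 512, 10], random.PRNGKey(0))

# Replace every leaf array by its shape, keeping the list-of-tuples layout.
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

# Flatten the pytree into its leaves to count trainable parameters.
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

print(shapes)
print(num_params)
```

Changing the hidden widths in `layer_sizes` changes only the leaf shapes; the pytree structure (three layers, each a weight and a bias) stays the same.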

## Auto-batching predictions

Let us first define our prediction function. Note that we're defining this for a _single_ image example. We're going to use JAX's `vmap` function to automatically handle mini-batches, with no performance penalty.

In [None]:

from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)
  
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)
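
The last line of `predict` subtracts `logsumexp(logits)` from the logits, which gives log-probabilities (a log-softmax). A quick check on a made-up logit vector, assuming nothing beyond `jax.numpy`, `jax.nn`, and `jax.scipy.special`:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

logits = jnp.array([2.0, 1.0, 0.1])  # illustrative values only

# Same normalization as the final line of `predict`.
log_probs = logits - logsumexp(logits)

# Exponentiating log-probabilities should give a distribution summing to 1.
total = jnp.exp(log_probs).sum()

# The built-in log-softmax should agree with the manual version.
matches_builtin = jnp.allclose(log_probs, jax.nn.log_softmax(logits))
print(total, matches_builtin)
```

Working in log space avoids overflow in the exponentials and pairs naturally with the cross-entropy loss used later.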

Let's check that our prediction function only works on single images.

In [None]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(10,)


In [None]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')

Invalid shapes!


In [None]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)
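
To make the role of `in_axes` concrete, here is a small sketch on a toy affine layer (the function `affine` and the array sizes are hypothetical). `None` means "do not map over this argument", and `0` means "map over its leading axis". The result should match an explicit Python loop over the examples.

```python
import jax.numpy as jnp
from jax import random, vmap

def affine(w, b, x):
    # Per-example computation: x is a single vector of length 5.
    return jnp.dot(w, x) + b

key_w, key_x = random.split(random.PRNGKey(0))
w = random.normal(key_w, (3, 5))
b = jnp.zeros(3)
xs = random.normal(key_x, (4, 5))  # a batch of 4 examples

# Share w and b across the batch; map over the first axis of xs.
batched_affine = vmap(affine, in_axes=(None, None, 0))
batched = batched_affine(w, b, xs)

# Reference result computed one example at a time.
looped = jnp.stack([affine(w, b, x) for x in xs])
print(batched.shape)
```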


At this point, we have all the ingredients we need to define our neural network and train it. We've built an auto-batched version of `predict`, which we should be able to use in a loss function. We should be able to use `grad` to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use `jit` to speed up everything.
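
Before applying `grad` and `jit` to the network, it may help to see them on a tiny problem. The sketch below uses a made-up linear model and squared-error loss (not the classifier's loss). It shows that `grad` differentiates with respect to the first argument and returns gradients with the same structure as the parameters.

```python
import jax.numpy as jnp
from jax import grad, jit

def squared_error(params, x, y):
    w, b = params
    return jnp.mean((jnp.dot(x, w) + b - y) ** 2)

# Illustrative parameters and data.
params = (jnp.array([1.0, -1.0]), jnp.array(0.5))
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([0.0, 1.0])

# Compose the transformations: differentiate, then compile.
grad_fn = jit(grad(squared_error))
dw, db = grad_fn(params, x, y)
print(dw, db)
```

Because the gradients mirror the `(w, b)` tuple, they can be unpacked or combined with the parameters leaf by leaf.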

In [None]:
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]
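
Since the task statement asks for the pytree concept, one possible way to write the same gradient-descent step is with `jax.tree_util.tree_map`, which applies a function to matching leaves of `params` and `grads`. This is a sketch on toy parameters. The helper name `sgd_update` is introduced here and the gradients are placeholders, not real loss gradients.

```python
import jax
import jax.numpy as jnp

step_size = 0.01

def sgd_update(params, grads, step_size):
    # One gradient-descent step applied to every leaf of the pytree.
    return jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)

# Toy two-layer parameters in the same list-of-(w, b) layout as the network.
params = [(jnp.ones((2, 3)), jnp.zeros(2)), (jnp.ones((1, 2)), jnp.zeros(1))]
# Placeholder gradients: all ones, same structure as params.
grads = jax.tree_util.tree_map(jnp.ones_like, params)

new_params = sgd_update(params, grads, step_size)

# The explicit list comprehension gives the same result.
manual = [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]
```

The `tree_map` form does not depend on how the parameters are nested, so it keeps working if the layout changes (for example, to a dict per layer).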

## Data Loading with `tensorflow/datasets`

JAX is laser-focused on program transformations and accelerator-backed NumPy, so it does not include data loading or munging. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader for dataset metadata and training batches, and `keras.datasets.mnist` for the raw arrays.

`sklearn.model_selection.train_test_split` is used to produce the 80:20 train:test split. Any library that returns NumPy-compatible arrays can be used to load MNIST.

In [None]:
import tensorflow as tf
from sklearn.model_selection import train_test_split
from keras.datasets import mnist

# Hide the GPU from TensorFlow so it does not grab all GPU memory that JAX needs.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch the full dataset and its metadata
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x = jnp.concatenate((x_train, x_test))
y = jnp.concatenate((y_train, y_test))
train_size = 0.8
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=train_size, random_state=42)
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = x_train, y_train
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = x_test, y_test
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)

In [None]:
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)

Train: (56000, 784) (56000, 10)
Test: (14000, 784) (14000, 10)
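
The two preprocessing steps behind these shapes, flattening each image and one-hot encoding each label, can be checked on a tiny stand-in batch. The arrays below are made up for illustration. `one_hot` is repeated from above so the snippet runs on its own.

```python
import jax.numpy as jnp

def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

labels = jnp.array([3, 0, 9])        # three illustrative digit labels
encoded = one_hot(labels, 10)        # one row per label, one column per class

images = jnp.zeros((3, 28, 28))      # stand-in for three 28x28 images
flat = jnp.reshape(images, (len(images), 28 * 28))

# argmax along the class axis recovers the integer labels.
recovered = jnp.argmax(encoded, axis=1)
print(encoded.shape, flat.shape, recovered)
```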


## Training Loop

In [None]:
import time

def get_train_batches():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 10.03 sec
Training set accuracy 0.9253214001655579
Test set accuracy 0.9264999628067017
Epoch 1 in 8.86 sec
Training set accuracy 0.942339301109314
Test set accuracy 0.9432142972946167
Epoch 2 in 9.21 sec
Training set accuracy 0.9525356888771057
Test set accuracy 0.9540714025497437
Epoch 3 in 9.27 sec
Training set accuracy 0.9593035578727722
Test set accuracy 0.9597142934799194
Epoch 4 in 9.77 sec
Training set accuracy 0.9646071195602417
Test set accuracy 0.9639999866485596
Epoch 5 in 10.34 sec
Training set accuracy 0.9685892462730408
Test set accuracy 0.9672142863273621
Epoch 6 in 9.42 sec
Training set accuracy 0.9718571305274963
Test set accuracy 0.9700714349746704
Epoch 7 in 9.14 sec
Training set accuracy 0.9746785759925842
Test set accuracy 0.9722856879234314
Epoch 8 in 9.36 sec
Training set accuracy 0.9769642949104309
Test set accuracy 0.9750714302062988
Epoch 9 in 9.27 sec
Training set accuracy 0.9791249632835388
Test set accuracy 0.9772142767906189
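
The loop above draws its batches from the `tfds` `train` split, while the accuracy is measured on the arrays from the 80:20 split. To train strictly on the 80% portion and to record the loss for the required loss vs. iterations plot, one option is to slice mini-batches directly from `train_images` and `train_labels`. The sketch below is one way to do that under stated assumptions: the helper names `iterate_minibatches` and `plot_loss_curve`, the output file name, and the stand-in arrays are all hypothetical.

```python
import numpy as np

def iterate_minibatches(images, labels, batch_size, rng):
    """Yield shuffled (images, labels) mini-batches that cover the arrays once."""
    order = rng.permutation(len(images))
    for start in range(0, len(images), batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]

def plot_loss_curve(losses, path="loss_curve.png"):
    """Save a loss-vs-iteration plot; `losses` is a list of floats."""
    # Imported here so the batching helper can be used without matplotlib.
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.plot(range(len(losses)), losses)
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Training loss")
    fig.savefig(path)
    plt.close(fig)

# Stand-in arrays with the same layout as train_images / train_labels.
rng = np.random.default_rng(0)
demo_images = rng.normal(size=(10, 784)).astype(np.float32)
demo_labels = np.eye(10, dtype=np.float32)
batches = list(iterate_minibatches(demo_images, demo_labels, batch_size=4, rng=rng))

# Inside the training loop, each step would then look roughly like:
#   losses.append(float(loss(params, x, y)))
#   params = update(params, x, y)
# followed by plot_loss_curve(losses) once training finishes.
```

Recording the loss before each update costs one extra forward pass per step; logging every few iterations is a reasonable compromise if that becomes slow.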


We've now used the core of the JAX API: `grad` for derivatives, `jit` for speedups, and `vmap` for auto-vectorization. We used NumPy-style code to specify all of our computation, borrowed the data loaders from `tensorflow/datasets`, and ran the whole thing on the GPU.
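
Accuracy alone can hide per-digit weaknesses, so the task also asks for further classification metrics. A minimal sketch of a confusion matrix with per-class precision, recall, and F1, written with `jax.numpy` only, is shown below. The helper names and the six-example label vectors are made up for illustration. On the real model the inputs would be `jnp.argmax(test_labels, axis=1)` and `jnp.argmax(batched_predict(params, test_images), axis=1)`.

```python
import jax.numpy as jnp

def confusion_matrix(true_class, predicted_class, num_classes):
    """Rows index the true class, columns the predicted class."""
    flat_index = true_class * num_classes + predicted_class
    counts = jnp.bincount(flat_index, length=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

def per_class_metrics(cm):
    true_pos = jnp.diag(cm)
    predicted_totals = cm.sum(axis=0)   # how often each class was predicted
    actual_totals = cm.sum(axis=1)      # how often each class truly occurs
    precision = true_pos / jnp.maximum(predicted_totals, 1)
    recall = true_pos / jnp.maximum(actual_totals, 1)
    f1 = 2 * precision * recall / jnp.maximum(precision + recall, 1e-12)
    return precision, recall, f1

# Toy labels for a 3-class problem.
true_class = jnp.array([0, 0, 1, 1, 2, 2])
predicted_class = jnp.array([0, 1, 1, 1, 2, 0])

cm = confusion_matrix(true_class, predicted_class, 3)
precision, recall, f1 = per_class_metrics(cm)
overall_accuracy = jnp.trace(cm) / cm.sum()
print(cm)
print(precision, recall, f1, overall_accuracy)
```

Low precision for a digit means other digits are often mistaken for it, while low recall means that digit is often missed. The off-diagonal entries of the confusion matrix show which pairs of digits are confused, and averaging the per-class F1 scores gives a macro-F1 that weights every digit equally.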