# Task 3 - Implement two hidden layers neural network classifier from scratch in JAX

* Two hidden layers here means (input - hidden1 - hidden2 - output).
* You must not use flax, optax, or any other library for this task.
* Use MNIST dataset with 80:20 train:test split.
* Manually optimize the number of neurons in hidden layers.
* Use gradient descent from scratch to optimize your network. You should use the Pytree concept of JAX to do this elegantly.
* Plot loss v/s iterations curve with matplotlib.
* Evaluate the model on test data with various classification metrics and briefly discuss their implications.




### Importing Libraries

In [13]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import tensorflow as tf
from sklearn.model_selection import train_test_split
from keras.datasets import mnist
tf.config.set_visible_devices([], device_type='GPU')
from jax.scipy.special import logsumexp
import tensorflow_datasets as tfds
import time
import matplotlib.pyplot as plt

### Dataset Processing

In [3]:
def one_hot(x, k, dtype=jnp.float32):
  return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
data_dir = '/tmp/tfds'
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x = jnp.concatenate((x_train, x_test))
y = jnp.concatenate((y_train, y_test))
train_size = 0.8
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=train_size, random_state=42)
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c
train_images, train_labels = x_train, y_train
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)
test_images, test_labels = x_test, y_test
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)

[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /tmp/tfds/mnist/3.0.1...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]


[1mDataset mnist downloaded and prepared to /tmp/tfds/mnist/3.0.1. Subsequent calls will reuse this data.[0m
Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz




In [4]:
print('Train Shape:', train_images.shape, train_labels.shape)
print('Test Shape:', test_images.shape, test_labels.shape)

Train Shape: (56000, 784) (56000, 10)
Test Shape: (14000, 784) (14000, 10)


### Model Building and Evaluation

In [5]:
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(512, activation=tf.nn.tanh),
  tf.keras.layers.Dense(512, activation=tf.nn.tanh),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In [9]:
def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)
  
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)
batched_predict = vmap(predict, in_axes=(None, 0))

In [10]:
def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

In [12]:
def get_train_batches():
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  ds = ds.batch(batch_size).prefetch(1)
  return tfds.as_numpy(ds)

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 20.60 sec
Training set accuracy 0.9253214001655579
Test set accuracy 0.9264999628067017
Epoch 1 in 9.63 sec
Training set accuracy 0.942339301109314
Test set accuracy 0.9432142972946167
Epoch 2 in 9.37 sec
Training set accuracy 0.9525356888771057
Test set accuracy 0.9540714025497437
Epoch 3 in 10.34 sec
Training set accuracy 0.9593035578727722
Test set accuracy 0.9597142934799194
Epoch 4 in 9.47 sec
Training set accuracy 0.9646071195602417
Test set accuracy 0.9639999866485596
Epoch 5 in 10.34 sec
Training set accuracy 0.9685892462730408
Test set accuracy 0.9672142863273621
Epoch 6 in 10.36 sec
Training set accuracy 0.9718571305274963
Test set accuracy 0.9700714349746704
Epoch 7 in 9.33 sec
Training set accuracy 0.9746785759925842
Test set accuracy 0.9722856879234314
Epoch 8 in 10.34 sec
Training set accuracy 0.9769642949104309
Test set accuracy 0.9750714302062988
Epoch 9 in 10.34 sec
Training set accuracy 0.9791249632835388
Test set accuracy 0.9772142767906189
