
# Introduction
This is the third notebook in our "All-about JAX" series.

In this notebook, we'll combine what we learnt in [jax101.ipynb](https://github.com/gauravjain14/All-about-JAX/blob/main/jax101.ipynb) and [Flax_Basics.ipynb](https://github.com/gauravjain14/All-about-JAX/blob/main/Flax_Basics.ipynb) to train a simple neural network, using our beloved `tensorflow/datasets` for data loading.

This notebook follows the JAX tutorial [Training a Simple Neural Network, with tensorflow/datasets Data Loading](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html).

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

## Hyperparameters

Let's get a few bookkeeping items out of the way, once and (hopefully) for all: parameter initialization and the hyperparameters we'll use throughout.

In [2]:
## Initialize weights and biases for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

## Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))
for p in params:
  print(p[0].shape, p[1].shape)



(512, 784) (512,)
(512, 512) (512,)
(10, 512) (10,)

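As a quick sanity check on the shapes printed above, here is a sketch that rebuilds the same initializer and counts the parameters. Because `params` is a list of `(w, b)` tuples, i.e. a pytree, `jax.tree_util.tree_leaves` can flatten it into its arrays. The count should be 784·512 + 512 + 512·512 + 512 + 512·10 + 10 = 669,706.

```python
import jax
from jax import random

def random_layer_params(m, n, key, scale=1e-2):
  # Same initializer as above: weights are stored as (n_out, n_in)
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

params = init_network_params([784, 512, 512, 10], random.PRNGKey(0))

# tree_leaves flattens the list of (w, b) tuples into a flat list of arrays
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(num_params)  # 669706
```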

## Auto-batching predictions

Here we use JAX's `vmap` function. The idea is that we write the prediction function below for a single image and let `vmap` turn it into one that handles mini-batches, without paying any **performance penalty**.
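A toy sketch of what "write it for one example, batch it with `vmap`" means (the function `predict_one` and the shapes are made up for illustration). The vmapped function returns the same values as calling the single-example function in a Python loop, but JAX traces it once and emits batched array operations instead of one call per example.

```python
import jax.numpy as jnp
from jax import random, vmap

w = random.normal(random.PRNGKey(0), (3, 5))

def predict_one(x):
  # Written for a single example: x has shape (5,), the output has shape (3,)
  return jnp.tanh(jnp.dot(w, x))

xs = random.normal(random.PRNGKey(1), (4, 5))  # a batch of 4 examples

# Reference: call the single-example function in a Python loop
looped = jnp.stack([predict_one(x) for x in xs])

# vmap maps predict_one over the leading axis of xs (in_axes=0 by default)
vmapped = vmap(predict_one)(xs)

print(vmapped.shape)  # (4, 3)
```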

In [3]:
from jax.scipy.special import logsumexp

def relu(x):
  # Defined by hand for clarity; JAX already provides jax.nn.relu
  # (which Flax re-exports as flax.linen.relu)
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)
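`predict` returns `logits - logsumexp(logits)`, which is a log-softmax. A small check with made-up logits that this yields valid log-probabilities and matches `jax.nn.log_softmax`:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

logits = jnp.array([2.0, 1.0, 0.1])

# Subtracting logsumexp(logits) turns raw scores into log-probabilities
log_probs = logits - logsumexp(logits)

# Exponentiating gives probabilities that sum to 1
print(jnp.exp(log_probs).sum())
# Same values as JAX's built-in log-softmax
print(jnp.allclose(log_probs, jax.nn.log_softmax(logits), atol=1e-6))
```

Using `logsumexp` rather than `jnp.log(jnp.sum(jnp.exp(logits)))` avoids overflow when the logits are large.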

In [4]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(10,)


In [5]:
# Let's see if this works for a batch of examples
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')

Invalid shapes!


### How to make batches work?

We saw that our `predict` function fails when we feed it **batched** inputs. To fix that, we use `jax.vmap` to vectorize `predict` over the batch: `in_axes=(None, 0)` tells `vmap` not to map over `params` (the same weights are shared by every example) and to map over the leading axis of the images.

In [6]:
batched_predict = vmap(predict, in_axes=(None, 0))

batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)
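To see what `in_axes=(None, 0)` buys us, here is a sketch for a single dense layer (a hypothetical `dense` function with random weights, not the notebook's `params`). Vmapping the per-example layer gives the same result as hand-writing the batched version `xs @ w.T + b`, which is exactly the rewrite we no longer have to do ourselves.

```python
import jax.numpy as jnp
from jax import random, vmap

def dense(params, x):
  # Single-example dense layer: w is (out, in), x is (in,)
  w, b = params
  return jnp.dot(w, x) + b

key_w, key_b, key_x = random.split(random.PRNGKey(0), 3)
params = (1e-2 * random.normal(key_w, (512, 784)),
          1e-2 * random.normal(key_b, (512,)))
xs = random.normal(key_x, (10, 784))  # batch of 10 flattened "images"

# in_axes=(None, 0): params are not mapped (shared by every example),
# xs is mapped over its leading (batch) axis
batched = vmap(dense, in_axes=(None, 0))(params, xs)

# Hand-batched equivalent: one matrix multiply over the whole batch
manual = xs @ params[0].T + params[1]

print(batched.shape)  # (10, 512)
```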


## Utility and Loss functions

In [28]:
def one_hot(x, k, dtype=jnp.float32):
  """ Create a one-hot encoding of x of size k. """
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

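def loss(params, images, targets):
  # Cross-entropy: preds are log-probabilities, so multiplying by the one-hot
  # targets keeps only the log-probability of the correct class.
  # jnp.mean averages over batch *and* classes (i.e. scales by 1/num_classes)
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)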

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]
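Because `predict` returns log-probabilities, `-jnp.mean(preds * targets)` is a cross-entropy, except that `jnp.mean` averages over both the batch and the classes. A sketch with made-up logits for 3 classes showing that it equals the usual per-example cross-entropy divided by the number of classes:

```python
import jax
import jax.numpy as jnp

def one_hot(x, k, dtype=jnp.float32):
  # Same helper as above: compare each label against 0..k-1
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

labels = jnp.array([2, 0, 1])
targets = one_hot(labels, 3)  # each row has a single 1.0 at the label's index

logits = jnp.array([[0.1, 0.2, 3.0],
                    [2.0, 0.5, 0.1],
                    [0.3, 1.5, 0.2]])
log_probs = jax.nn.log_softmax(logits)

# The notebook's loss: mean over all (batch x class) entries
notebook_loss = -jnp.mean(log_probs * targets)

# Standard cross-entropy: mean over the batch of -log p(correct class)
cross_entropy = -jnp.mean(jnp.sum(log_probs * targets, axis=1))

# They differ only by the constant factor num_classes (3 here)
print(jnp.allclose(notebook_loss * 3, cross_entropy))
```

The constant factor scales the gradients by 1/num_classes, so it acts like a smaller effective `step_size` rather than changing which parameters minimize the loss.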

## Data Loading with tensorflow/datasets

JAX deliberately doesn't reinvent the wheel here: rather than shipping its own data loader, it lets us pick from the many excellent data loading libraries already available. We'll use the `tensorflow/datasets` (TFDS) data loader.
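Since JAX simply consumes NumPy arrays, any loader that yields them will do; `tfds.as_numpy` is one way to get there. For comparison, a dependency-free sketch of a mini-batch iterator over in-memory arrays (`numpy_batches` is a hypothetical helper, and the data below is fake):

```python
import numpy as np

def numpy_batches(images, labels, batch_size, seed=0):
  """Hypothetical helper: yield shuffled (images, labels) mini-batches as NumPy arrays."""
  rng = np.random.default_rng(seed)
  perm = rng.permutation(len(images))
  for start in range(0, len(images), batch_size):
    idx = perm[start:start + batch_size]
    yield images[idx], labels[idx]

# Fake MNIST-shaped data, just to exercise the iterator
images = np.zeros((300, 28, 28, 1), dtype=np.uint8)
labels = np.arange(300) % 10

batches = list(numpy_batches(images, labels, batch_size=128))
print(len(batches))          # 3
print(batches[-1][0].shape)  # (44, 28, 28, 1)
```

As with `ds.batch(...)` in TFDS, the last batch is smaller when the dataset size is not a multiple of `batch_size`.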

In [14]:
import tensorflow as tf
# Hide the GPU from TF so it does not grab all the GPU memory that JAX needs
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

## Fetch full datasets for evaluation
# With batch_size=-1, tfds.load returns the full splits as tf.Tensors
# (otherwise it returns tf.data.Datasets). You can convert them to NumPy
# arrays (or iterables of NumPy arrays) with tfds.as_numpy
mnist_data, info = tfds.load(name='mnist', batch_size=-1, data_dir=data_dir,
                             with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

In [15]:
print(info)

tfds.core.DatasetInfo(
    name='mnist',
    full_name='mnist/3.0.1',
    description="""
    The MNIST database of handwritten digits.
    """,
    homepage='http://yann.lecun.com/exdb/mnist/',
    data_path='/tmp/tfds/mnist/3.0.1',
    file_format=tfrecord,
    download_size=11.06 MiB,
    dataset_size=21.00 MiB,
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=60000, num_shards=1>,
    },
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
)


#### Extract images and labels from the dataset

In [22]:
## Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

## Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)

In [24]:
print('Train: ', train_images.shape, train_labels.shape)
print('Test: ', test_images.shape, test_labels.shape)

Train:  (60000, 784) (60000, 10)
Test:  (10000, 784) (10000, 10)


## Training Loop - it's here, at last!

Most of the code in the following section is self-explanatory; I've added comments where they seemed necessary.

In [29]:
import time

def get_train_batches():
  """Stream the MNIST training split in mini-batches of `batch_size`.

  This may look redundant given the cells in the "Data Loading with
  tensorflow/datasets" section, but those loaded the full splits at once
  for evaluation; here we iterate over the training set batch by batch.
  """
  # as_supervised=True gives us (image, label) tuples instead of dicts
  ds = tfds.load(name='mnist', split='train', as_supervised=True,
                 data_dir=data_dir)
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.as_numpy converts the tf.data.Dataset into an iterable of
  # NumPy arrays
  return tfds.as_numpy(ds)

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  # Evaluate on the full train and test sets after each epoch
  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 20.61 sec
Training set accuracy 0.9821500182151794
Test set accuracy 0.9704999923706055
Epoch 1 in 10.33 sec
Training set accuracy 0.9837166666984558
Test set accuracy 0.9714999794960022
Epoch 2 in 20.58 sec
Training set accuracy 0.9849166870117188
Test set accuracy 0.972000002861023
Epoch 3 in 10.33 sec
Training set accuracy 0.9860166907310486
Test set accuracy 0.9731999635696411
Epoch 4 in 10.68 sec
Training set accuracy 0.9870833158493042
Test set accuracy 0.9739999771118164
Epoch 5 in 20.57 sec
Training set accuracy 0.9879500269889832
Test set accuracy 0.9747999906539917
Epoch 6 in 13.16 sec
Training set accuracy 0.9890666604042053
Test set accuracy 0.9751999974250793
Epoch 7 in 20.61 sec
Training set accuracy 0.9900833368301392
Test set accuracy 0.9752999544143677
Epoch 8 in 9.26 sec
Training set accuracy 0.9908000230789185
Test set accuracy 0.9756999611854553
Epoch 9 in 10.33 sec
Training set accuracy 0.9916000366210938
Test set accuracy 0.9763999581336975
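A note on reading the epoch times above: JAX dispatches work asynchronously, and the first call to a `jit`-compiled function (or a call with new input shapes, such as the last, smaller batch of 60000 % 128 = 96 images) includes tracing and compilation. A minimal sketch, with a made-up `step` function standing in for `update`, of timing a jitted call with `block_until_ready` so the clock stops only once the result has actually been computed:

```python
import time
import jax.numpy as jnp
from jax import jit, random

@jit
def step(w, x):
  # Stand-in for a training step: one matmul plus a nonlinearity
  return jnp.tanh(x @ w)

w = random.normal(random.PRNGKey(0), (784, 512))
x = random.normal(random.PRNGKey(1), (128, 784))

# First call traces and compiles; block_until_ready waits for the result
start = time.perf_counter()
step(w, x).block_until_ready()
first_call = time.perf_counter() - start

# Later calls with the same shapes and dtypes reuse the compiled function
start = time.perf_counter()
step(w, x).block_until_ready()
second_call = time.perf_counter() - start

print(f"first call (incl. compile): {first_call:.4f}s, second call: {second_call:.4f}s")
```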


## One additional experiment

In the training loop above, we defined the loss function by hand. To get a quick sense of how the choice of loss function affects training, let's try one more predefined loss function, the Huber loss [from jaxopt](https://jaxopt.github.io/stable/objective_and_loss.html).
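For reference, the Huber loss is quadratic for small residuals and linear for large ones, so its gradient is capped at `delta`. Below is a hand-written sketch of the standard definition; we assume `jaxopt.loss.huber_loss(target, pred, delta=1.0)` follows the same elementwise formula, but this is not jaxopt's source.

```python
import jax
import jax.numpy as jnp

def huber(target, pred, delta=1.0):
  """Elementwise Huber loss: quadratic for |residual| <= delta, linear beyond."""
  abs_diff = jnp.abs(target - pred)
  return jnp.where(abs_diff > delta,
                   delta * (abs_diff - 0.5 * delta),
                   0.5 * abs_diff ** 2)

residuals = jnp.array([0.5, 1.0, 2.0, 4.0])
print(huber(0.0, residuals))  # values 0.125, 0.5, 1.5, 3.5

# In the linear region the gradient magnitude is capped at delta
print(jax.grad(lambda p: huber(0.0, p))(4.0))  # 1.0
```

Because the loss is elementwise, it has to be reduced (e.g. with `jnp.mean`) before `jax.grad` can differentiate it.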

In [33]:
!pip install jaxopt
import jaxopt



In [None]:
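# Redefine the `update` function from above to use jaxopt's Huber loss.
# Re-initialize the parameters so this run starts from scratch instead of
# continuing from the network we just trained.
params = init_network_params(layer_sizes, random.PRNGKey(0))

def huber_objective(params, images, targets):
  # Compare predicted probabilities (exp of the log-softmax output) with the
  # one-hot targets; huber_loss is elementwise, so average it to a scalar
  probs = jnp.exp(batched_predict(params, images))
  return jnp.mean(jaxopt.loss.huber_loss(targets, probs))

@jit
def update(params, x, y):
  # Differentiate with respect to params (the first argument), not the
  # labels, and return the updated parameters
  grads = grad(huber_objective)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]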

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))