<a href="https://colab.research.google.com/github/deterministic-algorithms-lab/Jax-Journey/blob/main/flax_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install ml-collections & latest Flax version from Github.
!pip install -q ml-collections git+https://github.com/google/flax

[?25l[K     |███▊                            | 10kB 14.9MB/s eta 0:00:01[K     |███████▍                        | 20kB 19.8MB/s eta 0:00:01[K     |███████████                     | 30kB 22.6MB/s eta 0:00:01[K     |██████████████▉                 | 40kB 16.8MB/s eta 0:00:01[K     |██████████████████▌             | 51kB 10.7MB/s eta 0:00:01[K     |██████████████████████▏         | 61kB 12.0MB/s eta 0:00:01[K     |█████████████████████████▉      | 71kB 10.9MB/s eta 0:00:01[K     |█████████████████████████████▋  | 81kB 11.9MB/s eta 0:00:01[K     |████████████████████████████████| 92kB 6.3MB/s 
[?25h  Building wheel for flax (setup.py) ... [?25l[?25hdone


ML Collections is a library of Python collections (similar in spirit to the standard ```collections``` module) designed for machine-learning use cases such as experiment configuration. The repo can be viewed [here](https://github.com/google/ml_collections).

In [None]:
import ml_collections

def get_config():
    """Return the default training hyperparameters as a ConfigDict.

    Fields:
      learning_rate, momentum: used to build the momentum optimizer.
      batch_size: number of examples per training step.
      num_epochs: number of passes over the training set.
    """
    return ml_collections.ConfigDict({
        'learning_rate': 0.1,
        'momentum': 0.9,
        'batch_size': 128,
        'num_epochs': 10,
    })

# Imports

In [None]:
from absl import logging
import flax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import tensorflow_datasets as tfds

logging.set_verbosity(logging.INFO)

In [None]:
# Helper functions for images.

def show_img(img, ax=None, title=None):
  """Draw one image in grayscale on a matplotlib Axes, with ticks hidden.

  Args:
    img: array indexed as ``img[..., 0]``; only that first channel is drawn.
    ax: Axes to draw on. Defaults to the current Axes (``plt.gca()``).
    title: optional title. It is skipped when falsy (None or '').
  """
  target = plt.gca() if ax is None else ax
  target.imshow(img[..., 0], cmap='gray')
  # Tick marks carry no information for an image preview.
  for hide_ticks in (target.set_xticks, target.set_yticks):
    hide_ticks([])
  if title:
    target.set_title(title)

def show_img_grid(imgs, titles):
  """Show images in a square grid, each with its own title.

  Args:
    imgs: sequence of images, each drawn with ``show_img``.
    titles: titles paired with ``imgs``. As with ``zip``, extra entries in the
      longer sequence are ignored.
  """
  if len(imgs) == 0:
    # plt.subplots rejects a 0x0 grid, and there is nothing to draw anyway.
    return
  n = int(np.ceil(len(imgs)**.5))
  # squeeze=False keeps `axs` a 2-D array even when n == 1. Without it,
  # plt.subplots(1, 1) returns a bare Axes and `axs[0][0]` raises.
  _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n), squeeze=False)
  for i, (img, title) in enumerate(zip(imgs, titles)):
    show_img(img, axs[i // n][i % n], title)

In [None]:
# Local imports from current directory will auto reload.
# Any changes you make to local files will appear automatically.
%load_ext autoreload
%autoreload 2

In [None]:
# Build the hyperparameter config read by the optimizer and data-loading cells below.
config = get_config()

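* ```tfds.as_numpy()``` converts TensorFlow data into NumPy data. For a ```tf.data.Dataset``` it returns a Python iterable that yields NumPy arrays. Here, because ```batch_size=-1``` loads the whole split at once, it directly returns the split as a dict of NumPy arrays.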

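* ```tfds.core.DatasetBuilder.as_dataset()``` normally builds an input pipeline as a ```tf.data.Dataset``` (handling batching, shuffling, etc.). With ```batch_size=-1```, as used here, it instead returns the entire split as in-memory tensors. A ```tf.data.Dataset``` plays roughly the role of PyTorch's ```torch.utils.data.Dataset```/```DataLoader```.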

In [None]:
def get_datasets():
    """Load the MNIST train and test splits fully into memory.

    Returns:
      (train_ds, test_ds): dicts of the form {'image': ..., 'label': ...}.
      The images are converted to float32 and divided by 255, so raw pixel
      values in 0..255 become values in [0, 1].
    """
    builder = tfds.builder('mnist')
    builder.download_and_prepare()

    def _load(split):
        # batch_size=-1 loads the whole split as one batch instead of a
        # streaming pipeline. Datasets differ in structure, so inspect the
        # returned dict when adapting this to another dataset.
        return tfds.as_numpy(builder.as_dataset(split=split, batch_size=-1))

    train_ds, test_ds = _load('train'), _load('test')
    for split_data in (train_ds, test_ds):
        split_data['image'] = jnp.float32(split_data['image']) / 255
    return train_ds, test_ds

In [None]:
# Download MNIST if needed and load both splits into memory (pixels scaled to [0, 1]).
train_ds, test_ds = get_datasets()

In [None]:
# Preview the first 25 training images (a 5x5 grid) titled with their labels.
show_img_grid(
    [train_ds['image'][idx] for idx in range(25)],
    [f'label={train_ds["label"][idx]}' for idx in range(25)],
)

# Model

In [None]:
from flax import linen as nn
from flax import optim
from flax.metrics import tensorboard
import numpy as onp
from jax import random
import jax

In [None]:
class CNN(nn.Module):
    """Small convolutional classifier used for MNIST.

    The network has two Conv -> ReLU -> 2x2 average-pool stages, then a
    256-unit ReLU Dense layer and a 10-way Dense output. It returns
    ``log_softmax`` outputs (log-probabilities), which is what
    ``cross_entropy_loss`` expects as its ``logits`` argument.
    """

    @nn.compact
    def __call__(self, x):
        # Assumes NHWC input, e.g. (batch, 28, 28, 1) as used for init.
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        # Use the same square 2x2 window and stride as the first pooling stage,
        # so height and width are downsampled equally.
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the leading batch axis.
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x

In [None]:
# Initialize the CNN's parameters by running `init` on a dummy input batch.
key = random.PRNGKey(0)
# Two independent subkeys: key1 for the dummy input, key2 for parameter init.
key1, key2 = random.split(key)
# One random 28x28 single-channel image (NHWC). Presumably only its shape and
# dtype matter here, since init is what creates the parameters.
x = random.normal(key1, (1, 28, 28, 1))

model = CNN()
params = model.init(key2, x)
# Print the parameter tree to check its structure (which variables exist, etc.).
print(params)

# Optimizer

In [None]:
# Momentum optimizer from Flax's `flax.optim` API, configured from `config`.
# `create(params)` wraps the parameters; train_step reads them back via
# `optimizer.target`.
# NOTE(review): `flax.optim` is deprecated in favor of optax and is absent from
# recent Flax releases, while the install cell pulls the latest Flax from
# GitHub. Confirm that a Flax version which still ships `flax.optim` is pinned.
optimizer_def = optim.Momentum(learning_rate=config.learning_rate, 
                               beta=config.momentum)
optimizer = optimizer_def.create(params)

# Training

## Loss Function

In [None]:
def cross_entropy_loss(labels, logits):
    """Mean cross-entropy between one-hot ``labels`` and log-probabilities ``logits``.

    Args:
      labels: one-hot targets. The last axis indexes classes.
      logits: log-probabilities (e.g. ``log_softmax`` output), same shape as labels.

    Returns:
      Scalar: the negative of the per-example sum over the class axis of
      labels * logits, averaged over all remaining axes.
    """
    per_example = jnp.sum(labels * logits, axis=-1)
    return -per_example.mean()

In [None]:
max_classes = 10


def onehot(label, num_classes=None):
    """One-hot encode integer class labels.

    Args:
      label: integer array of class indices.
      num_classes: size of the one-hot axis. Defaults to the module-level
        ``max_classes``, which is read at call time.

    Returns:
      float32 array of shape ``label.shape + (num_classes,)``. Indices outside
      ``[0, num_classes)`` encode as all zeros.
    """
    if num_classes is None:
        num_classes = max_classes
    # Broadcast label[..., None] against every class id to compare all pairs.
    return (label[..., None] == jnp.arange(num_classes)).astype(jnp.float32)

In [None]:
def loss_fn(params, batch):
    """Compute the training loss for one batch.

    Args:
      params: CNN variables, as returned by ``CNN().init``.
      batch: dict with an ``'image'`` array and an integer ``'label'`` array.

    Returns:
      ``(loss, logits)``. With ``jax.value_and_grad(..., has_aux=True)`` only
      ``loss`` is differentiated, and ``logits`` is passed through as auxiliary
      data (which may be any pytree, not just one value).
    """
    # A fresh CNN() instance is fine: the weights all come from `params`.
    log_probs = CNN().apply(params, batch['image'])
    targets = onehot(batch['label'])
    return cross_entropy_loss(targets, log_probs), log_probs

## Metric Calculation

In [None]:
def compute_metric(logits, labels):
    """Compute loss and accuracy for a batch of predictions.

    Args:
      logits: log-probabilities from ``CNN``. The last axis indexes classes.
      labels: integer class labels.

    Returns:
      dict with scalar ``'loss'`` (mean cross-entropy) and ``'accuracy'``
      (fraction of argmax predictions equal to ``labels``).
    """
    # cross_entropy_loss's signature is (labels, logits). Pass arguments in that
    # order, as loss_fn does, rather than relying on the formula being symmetric.
    loss = cross_entropy_loss(onehot(labels), logits)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return {
        'loss': loss,
        'accuracy': accuracy,
    }

## Single Step Training

The ```has_aux=True``` argument below tells JAX that ```loss_fn``` returns a pair: the first element is the scalar value to differentiate, and the second is auxiliary data (here, the logits) that is returned alongside the gradient without being differentiated. Why values cannot simply be printed inside ```jit```-compiled functions, and the abstractions ```JAX``` adopts there, are very nicely explained [here]( https://github.com/google/jax/issues/196#issuecomment-451671635 ).

In [None]:
@jax.jit
def train_step(optimizer, batch):
    """Apply one optimizer update computed on ``batch``.

    Args:
      optimizer: ``flax.optim`` optimizer whose ``target`` holds the model params.
      batch: dict with ``'image'`` and ``'label'`` arrays.

    Returns:
      ``(new_optimizer, metrics)``, where ``metrics`` is ``compute_metric``'s dict.
    """
    # Gradients are taken w.r.t. loss_fn's first argument (the params) only.
    # has_aux=True because loss_fn returns (loss, logits).
    value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    params = optimizer.target
    (_, logits), grads = value_and_grad_fn(params, batch)
    # NOTE: a Python print() in this function runs only while jit traces it and
    # shows abstract placeholders, not values. Inspect metrics outside jit.
    return optimizer.apply_gradient(grads), compute_metric(logits, batch['label'])

## Epoch Training

### Setting up data loading

In [None]:
# Shuffle the training set once and split the indices into full batches.
train_ds_size = len(train_ds['image'])
# Integer division: trailing examples that don't fill a whole batch are dropped.
steps_per_epoch = train_ds_size // config.batch_size

# `key` was already consumed by random.split during model init, and JAX PRNG
# keys should not be reused. Derive a fresh subkey for the shuffle instead.
key, perm_key = random.split(key)
perms = random.permutation(perm_key, train_ds_size)
perms = perms[:steps_per_epoch * config.batch_size]
perms = perms.reshape((steps_per_epoch, config.batch_size))

### Training loop

In [None]:
# Single training epoch: one train_step per row of shuffled indices in `perms`.
# NOTE(review): config.num_epochs is never read, so only one epoch runs and
# `perms` is not reshuffled. If multi-epoch training is intended, wrap this in
# an epoch loop with a fresh shuffle each time.
metrics = []
for perm in perms:
    # Gather this batch's rows from every field of train_ds ('image', 'label');
    # the result is a dict, i.e. a pytree JAX can pass into train_step.
    batch = {k: v[perm] for k,v in train_ds.items()}
    optimizer, metric = train_step(optimizer, batch)
    metrics.append(metric)

# Transfer the per-batch metric dicts from device to host as NumPy values.
metrics = jax.device_get(metrics)
# Average each metric type over all batches. Every batch has config.batch_size
# examples, so this unweighted mean equals the per-example mean.
mean_metrics = {k : onp.mean([metric[k] for metric in metrics])
                    for k in metrics[0]}
# Printing works here because this code runs outside any jit-compiled function.
print(mean_metrics)

{'accuracy': 0.9871461, 'loss': 0.04197854}
