# Image classification using a ConvNet with Flax, Linen and Optax

Flax is a deep learning library for JAX that is more flexible. Linen and Optax are used for training the Flax model.

- Perform a quick setup
- Build a convolutional neural network model with the Linen API that classifies images
- Define a loss and accuracy metrics function
- Create a dataset function with TensorFlow Datasets
- Define training and evaluation functions
- Load the MNIST dataset
- Initialize the parameters with PRNGs and instantiate the optimizer with Optax
- Train the network and evaluate it

### Setup

In [1]:
!pip install --upgrade -q pip jax jaxlib flax optax tensorflow-datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.6/72.6 MB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m204.3/204.3 KB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.9/154.9 KB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m239.0/239.0 KB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m73.9/73.9 KB[0m [31m4.6 MB/s[0m eta

In [2]:
import jax
import jax.numpy as jnp 

from flax import linen as nn 
from flax.training import train_state
import optax

import numpy as np
import tensorflow_datasets as tfds

### Build a ConvNet

In [3]:
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x) # 10 classes 
    return x

### Metrics

In [4]:
# optax has built-in softmax cross-entropy
# labels are one-hot encoded

def compute_metrics(logits, labels):
  loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10)))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy
  }
  return metrics

### Dataset function

In [5]:
def get_datasets():
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  # Split into training/test sets
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  # Convert to floating-points
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.0
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.0
  return train_ds, test_ds

### Training and eval

In [6]:
@jax.jit
def train_step(state, batch):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = jnp.mean(optax.softmax_cross_entropy(
        logits=logits, 
        labels=jax.nn.one_hot(batch['label'], num_classes=10)))
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits, batch['label'])
  return state, metrics


@jax.jit
def eval_step(params, batch):
  logits = CNN().apply({'params': params}, batch['image'])
  return compute_metrics(logits, batch['label'])

### Define training and eval loop

In [7]:
# shuffle data
# run optimization
# get metrics
# update params

def train_epoch(state, train_ds, batch_size, epoch, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]  # Skip an incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))

  batch_metrics = []

  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  training_batch_metrics = jax.device_get(batch_metrics)
  training_epoch_metrics = {
      k: np.mean([metrics[k] for metrics in training_batch_metrics])
      for k in training_batch_metrics[0]}

  print('Training - epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, training_epoch_metrics['loss'], training_epoch_metrics['accuracy'] * 100))

  return state, training_epoch_metrics

# eval model on test data
# get metrics
# copy metrics
# print test loss acc
def eval_model(model, test_ds):
  metrics = eval_step(model, test_ds)
  metrics = jax.device_get(metrics)
  eval_summary = jax.tree_map(lambda x: x.item(), metrics)
  return eval_summary['loss'], eval_summary['accuracy']

### Load dataset


In [8]:
train_ds, test_ds = get_datasets()

Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...


Dl Completed...:   0%|          | 0/5 [00:00<?, ? file/s]

Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.




### Initialize ConvNet

In [9]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

cnn = CNN()
# random init
params = cnn.init(init_rng, jnp.ones([1, 28, 28, 1]))['params']

### Setup optimizer and state for updating params

In [10]:
nesterov_momentum = 0.9
learning_rate = 0.001
tx = optax.sgd(learning_rate=learning_rate, nesterov=nesterov_momentum)

# create a TrainState data class that applies the gradients and updates the optimizer state and parameters.
state = train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)

### Train and eval ConvNet

In [None]:
num_epochs = 10
batch_size = 128

for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run an optimization step over a training batch
  state, train_metrics = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  print('Testing - epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))

Training - epoch: 1, loss: 2.2437, accuracy: 32.41
Testing - epoch: 1, loss: 2.16, accuracy: 60.28
Training - epoch: 2, loss: 2.0582, accuracy: 66.50
Testing - epoch: 2, loss: 1.91, accuracy: 73.00
Training - epoch: 3, loss: 1.6968, accuracy: 72.72
Testing - epoch: 3, loss: 1.42, accuracy: 77.04
Training - epoch: 4, loss: 1.1834, accuracy: 77.76
Testing - epoch: 4, loss: 0.93, accuracy: 82.32
Training - epoch: 5, loss: 0.8114, accuracy: 82.19
Testing - epoch: 5, loss: 0.67, accuracy: 85.12
Training - epoch: 6, loss: 0.6258, accuracy: 84.82
Testing - epoch: 6, loss: 0.54, accuracy: 86.83
Training - epoch: 7, loss: 0.5316, accuracy: 86.22
Testing - epoch: 7, loss: 0.48, accuracy: 87.73
Training - epoch: 8, loss: 0.4776, accuracy: 87.25
Testing - epoch: 8, loss: 0.43, accuracy: 88.54


In [None]:
# todos
# - pytorch data loader
# - visualize predictions

### Acknowledgements
This notebook is built with code from [JAX-Flax-Tutorial-Image-Classification-with-Linen](https://github.com/8bitmp3/JAX-Flax-Tutorial-Image-Classification-with-Linen)