# ObJAX CIFAR10 example

It is recommended to run this notebook on GPU. In Google Colab this could be set through `Runtime -> Change runtime type` menu.

# Installation and Imports

In [1]:
%pip --quiet install objax

In [2]:
import math
import random

import jax
import jax.numpy as jn
from jax.lax import lax

import numpy as np
import tensorflow_datasets as tfds

import objax
from objax.zoo.wide_resnet import WideResNet

## Parameters

In [3]:
base_learning_rate = 0.1 # Learning rate
lr_decay_epochs = 30     # How often to decay learning rate
lr_decay_factor = 0.2    # By how much to decay learning rate
weight_decay =  0.0005   # Weight decay
batch_size = 128         # Batch size
num_train_epochs = 100   # Number of training epochs
wrn_width = 2            # Width of WideResNet
wrn_depth = 28           # Depth of WideResNet

# Setup dataset and model

In [None]:
# Augmentation function for input data
def augment(x):  # x is NCHW
  """Random flip and random shift augmentation of image batch."""
  if random.random() < .5:
    x = x[:, :, :, ::-1]  # Flip the batch images about the horizontal axis
  # Pixel-shift all images in the batch by up to 4 pixels in any direction.
  x_pad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'reflect')
  rx, ry = np.random.randint(0, 4), np.random.randint(0, 4)
  x = x_pad[:, :, rx:rx + 32, ry:ry + 32]
  return x

# Data
data = tfds.as_numpy(tfds.load(name='cifar10', batch_size=-1))
x_train = data['train']['image'].transpose(0, 3, 1, 2) / 255.0
y_train = data['train']['label']
x_test = data['test']['image'].transpose(0, 3, 1, 2) / 255.0
y_test = data['test']['label']
del data

# Model
model = WideResNet(nin=3, nclass=10, depth=wrn_depth, width=wrn_width)
model_vars = model.vars()
weight_decay_vars = [v for k, v in model_vars.items() if k.endswith('.w')]

# Optimizer
opt = objax.optimizer.Momentum(model_vars, nesterov=True)

# Prediction operation
predict_op = lambda x: objax.functional.softmax(model(x, training=False))
predict_op = objax.Jit(predict_op, model_vars)

# Loss and training op
def loss_fn(x, label):
  logit = model(x, training=True)
  xe_loss = objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()
  wd_loss = sum((v.value ** 2).sum() for v in weight_decay_vars)
  return xe_loss + weight_decay * wd_loss

loss_gv = objax.GradValues(loss_fn, model.vars())

def train_op(x, y, learning_rate):
    grads, loss = loss_gv(x, y)
    opt(learning_rate, grads)
    return loss

all_vars = model_vars + opt.vars()
train_op = objax.Jit(train_op, all_vars)

**Model parameters**

In [5]:
print(model_vars)

(WideResNet)[0](Conv2D).w                                        432 (3, 3, 3, 16)
(WideResNet)[1](WRNBlock).proj_conv(Conv2D).w                    512 (1, 1, 16, 32)
(WideResNet)[1](WRNBlock).norm_1(BatchNorm2D).running_mean        16 (1, 16, 1, 1)
(WideResNet)[1](WRNBlock).norm_1(BatchNorm2D).running_var         16 (1, 16, 1, 1)
(WideResNet)[1](WRNBlock).norm_1(BatchNorm2D).beta                16 (1, 16, 1, 1)
(WideResNet)[1](WRNBlock).norm_1(BatchNorm2D).gamma               16 (1, 16, 1, 1)
(WideResNet)[1](WRNBlock).conv_1(Conv2D).w                      4608 (3, 3, 16, 32)
(WideResNet)[1](WRNBlock).norm_2(BatchNorm2D).running_mean        32 (1, 32, 1, 1)
(WideResNet)[1](WRNBlock).norm_2(BatchNorm2D).running_var         32 (1, 32, 1, 1)
(WideResNet)[1](WRNBlock).norm_2(BatchNorm2D).beta                32 (1, 32, 1, 1)
(WideResNet)[1](WRNBlock).norm_2(BatchNorm2D).gamma               32 (1, 32, 1, 1)
(WideResNet)[1](WRNBlock).conv_2(Conv2D).w                      9216 (3, 3, 32, 32)
(

# Training loop

In [6]:
def lr_schedule(epoch):
  return base_learning_rate * math.pow(lr_decay_factor, epoch // lr_decay_epochs)

num_train_examples = x_train.shape[0]
num_test_examples = x_test.shape[0]
for epoch in range(num_train_epochs):
  # Training
  example_indices = np.arange(num_train_examples)
  np.random.shuffle(example_indices)
  for idx in range(0, num_train_examples, batch_size):
    x = x_train[example_indices[idx:idx + batch_size]]
    y = y_train[example_indices[idx:idx + batch_size]]
    loss = train_op(augment(x), y, lr_schedule(epoch))[0]

  # Eval
  accuracy = 0
  for idx in range(0, num_test_examples, batch_size):
    x = x_test[idx:idx + batch_size]
    y = y_test[idx:idx + batch_size]
    p = predict_op(x)
    accuracy += (np.argmax(p, axis=1) == y).sum()
  accuracy /= num_test_examples
  print(f'Epoch {epoch+1:3} -- train loss {loss:.3f}   test accuracy {accuracy*100:.1f}', flush=True)


Epoch   1 -- train loss 2.059   test accuracy 35.4
Epoch   2 -- train loss 1.291   test accuracy 70.2
Epoch   3 -- train loss 1.266   test accuracy 65.7
Epoch   4 -- train loss 1.005   test accuracy 67.2
Epoch   5 -- train loss 0.909   test accuracy 74.4
Epoch   6 -- train loss 0.808   test accuracy 76.3
Epoch   7 -- train loss 0.813   test accuracy 71.0
Epoch   8 -- train loss 0.770   test accuracy 78.8
Epoch   9 -- train loss 0.701   test accuracy 77.2
Epoch  10 -- train loss 0.788   test accuracy 74.7
Epoch  11 -- train loss 0.834   test accuracy 74.5
Epoch  12 -- train loss 1.031   test accuracy 77.7
Epoch  13 -- train loss 0.889   test accuracy 77.8
Epoch  14 -- train loss 0.697   test accuracy 75.3
Epoch  15 -- train loss 0.669   test accuracy 75.8
Epoch  16 -- train loss 0.643   test accuracy 80.4
Epoch  17 -- train loss 0.840   test accuracy 80.2
Epoch  18 -- train loss 0.779   test accuracy 78.7
Epoch  19 -- train loss 0.768   test accuracy 78.8
Epoch  20 -- train loss 0.947  