# ObJAX CIFAR10 example

This example is based on [cifar10_simple.py](https://github.com/google/objax/blob/master/examples/classify/img/cifar10_simple.py) with a few minor changes:

* it demonstrates how to do weight decay,
* it uses the Momentum optimizer with a learning rate schedule,
* it uses `tensorflow_datasets` instead of the `Keras` dataset.

It's recommended to run this notebook on a GPU. In Google Colab, this can be set through the `Runtime -> Change runtime type` menu.

# Installation and Imports

In [1]:
%pip --quiet install objax

In [2]:
import math
import random

import numpy as np
import tensorflow_datasets as tfds

import objax
from objax.zoo.wide_resnet import WideResNet

## Parameters

In [3]:
base_learning_rate = 0.1 # Learning rate used before any decay is applied
lr_decay_epochs = 30     # The learning rate is decayed once every this many epochs
lr_decay_factor = 0.2    # Factor the learning rate is multiplied by at each decay
weight_decay =  0.0005   # Coefficient of the sum-of-squared-weights term added to the loss
batch_size = 128         # Number of examples per batch (training and evaluation)
num_train_epochs = 100   # Number of passes over the training set
wrn_width = 2            # `width` argument of WideResNet
wrn_depth = 28           # `depth` argument of WideResNet

# Setup dataset and model

In [None]:
# Augmentation function for input data
def augment(x, max_shift=4):  # x is NCHW
  """Random flip and random shift augmentation of image batch.

  One flip decision and one shift are drawn per call and applied to every
  image in the batch.

  Args:
    x: batch of images laid out as NCHW.
    max_shift: largest shift, in pixels, along each spatial axis.

  Returns:
    Array with the same shape as `x`.
  """
  if random.random() < .5:
    x = x[:, :, :, ::-1]  # Mirror the images left-right (reverse the width axis)
  height, width = x.shape[2], x.shape[3]
  # Reflect-pad both spatial axes by max_shift on each side, then crop a
  # window of the original size at a random offset inside the padded batch.
  x_pad = np.pad(
      x,
      [[0, 0], [0, 0], [max_shift, max_shift], [max_shift, max_shift]],
      'reflect')
  # Offsets span 0..2*max_shift inclusive (randint's upper bound is exclusive).
  # An offset equal to max_shift is "no shift"; smaller and larger offsets
  # shift in opposite directions, so both directions are reachable.
  row_offset = np.random.randint(0, 2 * max_shift + 1)
  col_offset = np.random.randint(0, 2 * max_shift + 1)
  return x_pad[:, :, row_offset:row_offset + height, col_offset:col_offset + width]

# Data: load the CIFAR-10 train and test splits as NumPy arrays.
data = tfds.as_numpy(tfds.load(name='cifar10', batch_size=-1))
train_split, test_split = data['train'], data['test']
# Images: reorder axes with (0, 3, 1, 2) so they are NCHW, as augment()
# expects, and divide the pixel values by 255.
x_train = np.transpose(train_split['image'], (0, 3, 1, 2)) / 255.0
x_test = np.transpose(test_split['image'], (0, 3, 1, 2)) / 255.0
y_train = train_split['label']
y_test = test_split['label']
# Drop the references to the raw dataset; only the arrays above are used later.
del data, train_split, test_split

# Model
model = WideResNet(nin=3, nclass=10, depth=wrn_depth, width=wrn_width)
# Variables that receive weight decay: those whose name ends in '.w'.
weight_decay_vars = [
    var for name, var in model.vars().items() if name.endswith('.w')
]

# Optimizer
opt = objax.optimizer.Momentum(model.vars(), nesterov=True)

# Prediction operation: the model called with training=False, followed by
# softmax, wrapped in objax.Jit.
predict_op = objax.Jit(
    objax.nn.Sequential([
        objax.ForceArgs(model, training=False),
        objax.functional.softmax,
    ]))

# Loss and training op
@objax.Function.with_vars(model.vars())
def loss_fn(x, label):
  """Return mean cross-entropy plus the weight decay penalty for a batch.

  Args:
    x: batch of input images passed to the model.
    label: labels passed to the sparse cross-entropy loss.

  Returns:
    Mean cross-entropy of the model's logits plus `weight_decay` times the
    sum of squares of every variable in `weight_decay_vars`.
  """
  logits = model(x, training=True)
  per_example_xe = objax.functional.loss.cross_entropy_logits_sparse(logits, label)
  l2_penalty = sum((w ** 2).sum() for w in weight_decay_vars)
  return per_example_xe.mean() + weight_decay * l2_penalty

# Computes gradients of loss_fn with respect to the model variables, together
# with the value(s) loss_fn returns.
loss_gv = objax.GradValues(loss_fn, model.vars())

@objax.Function.with_vars(model.vars() + loss_gv.vars() + opt.vars())
def train_op(x, y, learning_rate):
  """Apply one optimizer update computed from a single batch.

  Args:
    x: batch of input images.
    y: labels for the batch.
    learning_rate: learning rate passed to the optimizer for this update.

  Returns:
    The second output of `loss_gv`; the training loop reads the loss from it
    with `[0]`.
  """
  gradients, loss_values = loss_gv(x, y)
  opt(learning_rate, gradients)
  return loss_values

# Replace the Python function with its objax.Jit-wrapped version.
train_op = objax.Jit(train_op)

**Model parameters**

In [5]:
# Show the variables held by the model.
print(model.vars())

(WideResNet)[0](Conv2D).w                                        432 (3, 3, 3, 16)
(WideResNet)[1](WRNBlock).proj_conv(Conv2D).w                    512 (1, 1, 16, 32)
(WideResNet)[1](WRNBlock).norm_1(BatchNorm2D).running_mean        16 (1, 16, 1, 1)
(WideResNet)[1](WRNBlock).norm_1(BatchNorm2D).running_var         16 (1, 16, 1, 1)
(WideResNet)[1](WRNBlock).norm_1(BatchNorm2D).beta                16 (1, 16, 1, 1)
(WideResNet)[1](WRNBlock).norm_1(BatchNorm2D).gamma               16 (1, 16, 1, 1)
(WideResNet)[1](WRNBlock).conv_1(Conv2D).w                      4608 (3, 3, 16, 32)
(WideResNet)[1](WRNBlock).norm_2(BatchNorm2D).running_mean        32 (1, 32, 1, 1)
(WideResNet)[1](WRNBlock).norm_2(BatchNorm2D).running_var         32 (1, 32, 1, 1)
(WideResNet)[1](WRNBlock).norm_2(BatchNorm2D).beta                32 (1, 32, 1, 1)
(WideResNet)[1](WRNBlock).norm_2(BatchNorm2D).gamma               32 (1, 32, 1, 1)
(WideResNet)[1](WRNBlock).conv_2(Conv2D).w                      9216 (3, 3, 32, 32)
(

# Training loop

In [6]:
def lr_schedule(epoch):
  """Return the learning rate for the given epoch (epochs count from 0).

  The base learning rate is multiplied by `lr_decay_factor` once for every
  full `lr_decay_epochs` epochs that have elapsed.
  """
  num_decays = epoch // lr_decay_epochs
  decay = math.pow(lr_decay_factor, num_decays)
  return base_learning_rate * decay

num_train_examples = x_train.shape[0]
num_test_examples = x_test.shape[0]
for epoch in range(num_train_epochs):
  # Training: one pass over the training set in a freshly shuffled order.
  learning_rate = lr_schedule(epoch)  # Constant within an epoch
  order = np.arange(num_train_examples)
  np.random.shuffle(order)
  for start in range(0, num_train_examples, batch_size):
    batch = order[start:start + batch_size]
    # `loss` is overwritten on every step, so after the loop it holds the
    # loss of the last training batch only.
    loss = train_op(augment(x_train[batch]), y_train[batch], learning_rate)[0]

  # Eval: count correct top-1 predictions over the test set, in order.
  num_correct = 0
  for start in range(0, num_test_examples, batch_size):
    stop = start + batch_size
    probs = predict_op(x_test[start:stop])
    num_correct += (np.argmax(probs, axis=1) == y_test[start:stop]).sum()
  accuracy = num_correct / num_test_examples
  print(f'Epoch {epoch+1:3} -- train loss {loss:.3f}   test accuracy {accuracy*100:.1f}', flush=True)


Epoch   1 -- train loss 2.183   test accuracy 57.6
Epoch   2 -- train loss 1.348   test accuracy 64.7
Epoch   3 -- train loss 1.186   test accuracy 64.6
Epoch   4 -- train loss 0.975   test accuracy 69.7
Epoch   5 -- train loss 1.114   test accuracy 68.5
Epoch   6 -- train loss 0.837   test accuracy 68.8
Epoch   7 -- train loss 0.871   test accuracy 74.4
Epoch   8 -- train loss 0.996   test accuracy 74.8
Epoch   9 -- train loss 0.764   test accuracy 79.1
Epoch  10 -- train loss 0.813   test accuracy 77.6
Epoch  11 -- train loss 0.933   test accuracy 70.4
Epoch  12 -- train loss 0.937   test accuracy 65.5
Epoch  13 -- train loss 0.814   test accuracy 79.5
Epoch  14 -- train loss 0.638   test accuracy 76.2
Epoch  15 -- train loss 0.984   test accuracy 77.6
Epoch  16 -- train loss 0.844   test accuracy 79.8
Epoch  17 -- train loss 0.690   test accuracy 77.9
Epoch  18 -- train loss 0.692   test accuracy 77.4
Epoch  19 -- train loss 0.706   test accuracy 76.2
Epoch  20 -- train loss 0.786  