# Supervised learning tutorial

In supervised learning, we are given a set of (input, label) pairs `(x_i, y_i)`, and the goal is to learn a function that maps the inputs to the correct labels. The hope is that, once learnt, this function will generalise to samples that were not seen during training.

In this tutorial, we focus on the task of image classification: we use the Cifar10 dataset and the Resnet18 neural architecture to learn the desired mapping (Part 1).

Once the model is trained, we will perform an [adversarial attack in pixel space](https://arxiv.org/pdf/1312.6199.pdf) on this classifier to highlight some of the weaknesses of these models (Part 2).

Finally, we will use a self-attention mechanism ([Squeeze-and-Excitation](https://arxiv.org/pdf/1709.01507.pdf)) to improve the generalisation power of the model (Part 3). *Note*: Although the test accuracy is improved, the robustness to adversarial attacks is not necessarily improved.

**Key takeaways**

By the end of this tutorial, you will know:
* how to implement (using the JAX and Haiku libraries) a residual convolutional neural network for image classification and how to train it on an image dataset, using standard data augmentation, with weight decay regularisation and batch normalisation
* how to write new network modules
* how to use the same backpropagation algorithm that was initially used for training the network, to build adversarial examples that fool the network
* how to add new loss terms (to perform the adversarial attack)
* how to implement a simple, yet effective self-attention mechanism to improve classification accuracy.

**Homework**

Test the classifier's robustness to changes in the input distribution in the geometric space (e.g. by applying rotations to the inputs). What do you observe? How can the observed behaviour be prevented (at least partially)?

# Part 1: Resnet18 classifier on Cifar10


In [None]:
from typing import Iterable, Mapping, Tuple, Generator, Optional, Sequence, Text, Union, List

# We will use haiku on top of jax; it is not included by default, so let's install it  
!pip install -q dm-haiku
import haiku as hk

import jax
from jax.experimental import optix  # package for optimizer
import jax.numpy as jnp  # equivalent of numpy on GPU and TPU
import numpy as np  # original numpy

!pip install -q dm-tree
import tree
import enum
import time

# Dataset libraries
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

# Plotting library.
from matplotlib import pyplot as plt
import pylab as pl
from IPython import display

# Don't forget to select GPU runtime environment in Runtime -> Change runtime type
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

In [None]:
# Define useful types.
OptState = Tuple[optix.TraceState, optix.ScaleByScheduleState, optix.ScaleState]
Scalars = Mapping[str, jnp.ndarray]
Batch = Mapping[Text, np.ndarray]
ClassNames = Sequence[str]

In [None]:
# Dataset constants for cifar10 dataset, the "MNIST of real images":
# it contains low-res natural images (32x32x3) belonging to 10 classes.
dataset_name = 'cifar10'
class_cifar10 = [u'airplane', u'automobile', u'bird', u'cat', u'deer', u'dog', u'frog', u'horse', u'ship', u'truck'] 
train_split = 'train'
eval_split = 'test'
num_examples = {train_split: 50000,
                eval_split: 10000}
num_classes = 10

### Hyper-parameters for training and optimiser

In [None]:
train_batch_size = 128 #@param
eval_batch_size = 100  #@param
model_bn_decay = 0.9  #@param
train_weight_decay = 1e-4  #@param
optimizer_momentum = 0.9  #@param
optimizer_use_nesterov = True  #@param
train_eval_every = 1000  #@param
train_init_random_seed = 42  #@param
train_log_every = 100  #@param
num_train_steps = 400e3  #@param
num_eval_steps = (num_examples[eval_split]) // eval_batch_size

### Dataset loading and preprocessing

In [None]:
# We use tensorflow readers; JAX does not have support for input data reading
# and pre-processing.
def load(split: str,
         *,
         is_training: bool,
         batch_size: int) -> Generator[Batch, None, None]:
  """Loads the dataset as a generator of batches."""
  ds = tfds.load(dataset_name, split=split).cache().repeat()
  
  if is_training:
    ds = ds.shuffle(10 * batch_size, seed=0)

  # Define the preprocessing for each train and test image
  def preprocess(example):
    image = _preprocess_image(example['image'], is_training)
    return {'image': image, 'label': example['label']}

  # Apply the preprocessing function to every example using `map`
  ds = ds.map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Get samples grouped in mini-batches to train using SGD
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

  return iter(tfds.as_numpy(ds))  # iterator over batches of numpy arrays

def _preprocess_image(
    image: tf.Tensor,
    is_training: bool,
) -> tf.Tensor:
  """Returns processed and resized image."""
  # Images are stored as uint8; we convert to float for further processing.
  image = tf.cast(image, tf.float32)
  # Rescale pixel values from [0, 255] to [-1, 1]. This roughly centres the
  # inputs around zero and keeps them in a small range, which makes training
  # more stable. It does not give exactly zero mean and unit variance; inside
  # the network, the batch normalisation layers standardise the activations
  # of each layer.
  image = 2 * (image / 255.0) - 1.0

  # During training, we use data augmentation (left-right flips, random crops).
  # In this way, we are effectively increasing the size of the training dataset,
  # leading to improved generalisation.
  if is_training:
    image = tf.image.random_flip_left_right(image)
    # Pad images by reflecting the boundaries and randomly sample a 32x32 patch.
    image = tf.pad(image, [[4, 4], [4, 4], [0, 0]], mode='REFLECT')
    image = tf.image.random_crop(image, size=(32, 32, 3))
  return image
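
To make the preprocessing concrete, here is a minimal NumPy sketch of the same three steps (rescaling, random flip, reflect-pad plus random crop) applied to a single image. It is an illustration only: the helper name `augment_numpy` and the use of `np.random.default_rng` are assumptions, and the tutorial itself relies on the TensorFlow ops above.

```python
import numpy as np

def augment_numpy(image_uint8: np.ndarray, rng: np.random.Generator) -> np.ndarray:
  """Hypothetical NumPy analogue of `_preprocess_image` with is_training=True."""
  # Rescale uint8 pixels from [0, 255] to floats in [-1, 1].
  image = 2.0 * (image_uint8.astype(np.float32) / 255.0) - 1.0
  # Flip left-right with probability 0.5.
  if rng.random() < 0.5:
    image = image[:, ::-1, :]
  # Reflect-pad 4 pixels on each spatial side: 32x32 -> 40x40.
  padded = np.pad(image, [(4, 4), (4, 4), (0, 0)], mode='reflect')
  # Sample the top-left corner of a random 32x32 crop (offsets 0..8).
  top, left = rng.integers(0, 9, size=2)
  return padded[top:top + 32, left:left + 32, :]

rng = np.random.default_rng(0)
fake_image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
augmented = augment_numpy(fake_image, rng)
print(augmented.shape, augmented.min() >= -1.0, augmented.max() <= 1.0)
```

The crop keeps the output at the original 32x32 resolution, so the augmented image can be fed to the network without any change to the model.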

### Function to display images

In [None]:
MAX_IMAGES = 8
def gallery(images: np.ndarray,
            label: np.ndarray,
            class_names: ClassNames=class_cifar10,
            title: str='Input images'):  
  """Display a batch of images."""
  num_frames, h, w, num_channels = images.shape
  num_frames = min(num_frames, MAX_IMAGES)
  ff, axes = plt.subplots(1, num_frames,
                          figsize=(32, 32),
                          subplot_kw={'xticks': [], 'yticks': []})
  if images.min() < 0:
    images = (images + 1.) / 2.
  for i in range(0, num_frames):
    if num_channels == 3:
      axes[i].imshow(np.squeeze(images[i]))
    else:
      axes[i].imshow(np.squeeze(images[i]), cmap='gray')
    axes[i].set_title(class_names[label[i]], fontsize=28)
    plt.setp(axes[i].get_xticklabels(), visible=False)
    plt.setp(axes[i].get_yticklabels(), visible=False)
  ff.subplots_adjust(wspace=0.1)
  plt.show()

### Create a resnet block (coding exercise)

In a typical sequential model (no branching), the network as a whole is optimised to find the mapping between inputs and correct labels. In residual networks, each layer can learn an additive residual representation with respect to the representation already computed up to the previous layer, making the optimisation easier.

As opposed to [resnet-v1](https://arxiv.org/pdf/1512.03385.pdf) blocks (left), [resnet-v2](https://arxiv.org/pdf/1603.05027.pdf) blocks (right) use pre-activation modules, i.e. the batch normalisation (`BN`) and the relu nonlinearity (`ReLU`) are applied within the resnet block, before the convolutional layer (`weight`). This allows the model to learn identity mappings over the shortcuts throughout the network, further improving the backpropagation of gradients.
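
As a toy illustration of why the shortcut makes identity mappings easy to represent, the NumPy sketch below applies a pre-activation block to a vector. Dense layers stand in for the convolutions and batch normalisation is omitted; these simplifications and the name `preact_block` are assumptions made for illustration, not the tutorial's implementation.

```python
import numpy as np

def preact_block(x, w0, w1):
  """Toy pre-activation residual block: x + w1 @ relu(w0 @ relu(x))."""
  h = np.maximum(x, 0.0)      # pre-activation ReLU (batch norm omitted)
  h = w0 @ h                  # first "weight" layer
  h = np.maximum(h, 0.0)      # second pre-activation ReLU
  h = w1 @ h                  # second "weight" layer
  return x + h                # the shortcut adds the unmodified input

rng = np.random.default_rng(0)
x = rng.normal(size=4)
w0 = rng.normal(size=(4, 4))

# With the last layer's weights at zero, the residual branch outputs zero
# and the block reduces to the identity mapping.
y = preact_block(x, w0, np.zeros((4, 4)))
print(np.allclose(y, x))
```

Because the input reaches the output untouched through the shortcut, the block only has to learn the correction on top of it.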

<img src="https://github.com/eemlcommunity/PracticalSessions2020/blob/master/assets/v1v2.png?raw=true" alt="resnet blocks" style="width: 80px;"/>

Figure from original [resnet-v2 paper](https://arxiv.org/pdf/1603.05027.pdf).

*Bottleneck blocks*: To reduce the number of parameters and the memory footprint without sacrificing expressivity, bottleneck blocks can be used. Instead of using 2 conv layers (`weight` in the figure above) with 3x3 filters, a bottleneck block projects into a lower dimensional space (using a 1x1 conv layer), applies a 3x3 convolution there, and then projects back into the original dimension space (using another 1x1 conv layer). Empirically, this has been shown not to affect accuracy.
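
A rough count of convolution weights shows where the saving comes from. The sketch below assumes bias-free convolutions, equal input and output channel counts, and a reduction factor of 4 (the `channel_div` used in the code further down); batch norm and shortcut parameters are ignored, and the helper names are hypothetical.

```python
def conv_params(kernel: int, c_in: int, c_out: int) -> int:
  """Weights in a bias-free 2D convolution with a square kernel."""
  return kernel * kernel * c_in * c_out

def basic_block_params(channels: int) -> int:
  # Two 3x3 convolutions, both mapping `channels` -> `channels`.
  return 2 * conv_params(3, channels, channels)

def bottleneck_block_params(channels: int, channel_div: int = 4) -> int:
  inner = channels // channel_div
  return (conv_params(1, channels, inner)     # 1x1 projection down
          + conv_params(3, inner, inner)      # 3x3 in the reduced space
          + conv_params(1, inner, channels))  # 1x1 projection back up

print(basic_block_params(256))       # 1179648
print(bottleneck_block_params(256))  # 69632
```

Under these assumptions, a 256-channel bottleneck block has roughly 17 times fewer convolution weights than the block with two 3x3 layers.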

*1x1 conv shortcuts*: when the input and the output of a resnet block have different numbers of channels, 1x1 convolutional layers are used on the shortcut to project the representation to the desired output feature dimension.

In [None]:
def check_length(length, value, name):
  if len(value) != length:
    raise ValueError(f"`{name}` must be of length {length} not {len(value)}")

class BlockV2(hk.Module):
  """ResNet V2 block with optional bottleneck."""

  def __init__(
      self,
      channels: int,
      stride: Union[int, Sequence[int]],
      use_projection: bool,
      bn_config: Mapping[str, float],
      bottleneck: bool,
      name: Optional[str] = None,
  ):
    super().__init__(name=name)
    self.use_projection = use_projection

    # Define batch norm parameters: the batch_norm layer normalises its inputs
    # to have zero mean and unit variance. To avoid limiting the expressivity
    # of the network, e.g. in cases where it would be better for the activations
    # not to be 0-centred or to have larger variance, batch_norm can optionally
    # learn scale and offset parameters.
    bn_config = dict(bn_config)
    bn_config.setdefault("create_scale", True)
    bn_config.setdefault("create_offset", True)

    # See comment above about 1x1 conv shortcut 
    if self.use_projection:
      self.proj_conv = hk.Conv2D(
          output_channels=channels,
          kernel_shape=1,
          stride=stride,
          with_bias=False,
          padding="SAME",
          name="shortcut_conv")

    # If we use bottleneck blocks (see comment above), inside the resnet block 
    # we first project the activations into a lower dimensional space, 
    # which has number of channels divided by `channel_div` compared to the 
    # desired number of channels in the output.
    channel_div = 4 if bottleneck else 1
    conv_0 = hk.Conv2D(
        output_channels=channels // channel_div,
        kernel_shape=1 if bottleneck else 3,
        stride=1,
        with_bias=False,
        padding="SAME",
        name="conv_0")

    bn_0 = hk.BatchNorm(name="batchnorm_0", **bn_config)
    # Then we apply the 3x3 conv layer
    conv_1 = hk.Conv2D(
        output_channels=channels // channel_div,
        kernel_shape=3,
        stride=stride,
        with_bias=False,
        padding="SAME",
        name="conv_1")

    bn_1 = hk.BatchNorm(name="batchnorm_1", **bn_config)
    layers = ((conv_0, bn_0), (conv_1, bn_1))

    # When using bottleneck, we also have a third 1x1 convolutional layer
    # within the resnet block (see comment above about bottleneck blocks)
    if bottleneck:
      conv_2 = hk.Conv2D(
          output_channels=channels,
          kernel_shape=1,
          stride=1,
          with_bias=False,
          padding="SAME",
          name="conv_2")

      bn_2 = hk.BatchNorm(name="batchnorm_2", **bn_config)
      layers = layers + ((conv_2, bn_2),)

    self.layers = layers

  def __call__(self, inputs, is_training, test_local_stats):
    x = shortcut = inputs
    ######################
    ### YOUR CODE HERE ###
    ######################
    for i, (conv_i, bn_i) in enumerate(self.layers):
      # Apply pre-activation: batch_norm + relu
      x = bn_i(x, is_training, test_local_stats)
      x = jax.nn.relu(x)
      # If using 1x1 conv projection on the shortcut, apply proj_conv once 
      if i == 0 and self.use_projection:
        shortcut = self.proj_conv(x)
      # Apply convolution
      x = conv_i(x)

    return x + shortcut

In [None]:
#@title Stack resnet blocks
class BlockGroup(hk.Module):
  """Group of blocks for ResNet implementation."""

  def __init__(
      self,
      channels: int,
      num_blocks: int,
      stride: Union[int, Sequence[int]],
      bn_config: Mapping[str, float],
      bottleneck: bool,
      use_projection: bool,
      name: Optional[str] = None,
  ):
    super().__init__(name=name)

    self.blocks = []
    for i in range(num_blocks):
      self.blocks.append(
          BlockV2(channels=channels,
                  stride=(1 if i else stride),
                  use_projection=(i == 0 and use_projection),
                  bottleneck=bottleneck,
                  bn_config=bn_config,
                  name="block_%d" % (i)))

  def __call__(self, inputs, is_training, test_local_stats):
    out = inputs
    for block in self.blocks:
      out = block(out, is_training, test_local_stats)
    return out

In [None]:
#@title Define a generic resnet architecture
# Note: This class is generic, it can be used to instantiate any Resnet 
# model, e.g. Resnet-50, Resnet-101, etc. by substituting the correct block
# parameters 
class ResNet(hk.Module):
  """ResNet model."""

  def __init__(
      self,
      blocks_per_group: Sequence[int],
      num_classes: int,
      bn_config: Optional[Mapping[str, float]] = None,
      bottleneck: bool = True,
      channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
      use_projection: Sequence[bool] = (True, True, True, True),
      name: Optional[str] = None,
  ):
    """Constructs a ResNet model.
    Args:
      blocks_per_group: A sequence of length 4 that indicates the number of
        blocks created in each group.
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers. By default the `decay_rate` is
        `0.9` and `eps` is `1e-5`.
      bottleneck: Whether the block should bottleneck or not. Defaults to True.
      channels_per_group: A sequence of length 4 that indicates the number
        of channels used for each block in each group.
      use_projection: A sequence of length 4 that indicates whether each
        residual block should use projection.
      name: Name of the module.
    """
    super().__init__(name=name)
    bn_config = dict(bn_config or {})
    bn_config.setdefault("decay_rate", 0.9)
    bn_config.setdefault("eps", 1e-5)
    bn_config.setdefault("create_scale", True)
    bn_config.setdefault("create_offset", True)

    # Number of blocks in each group for ResNet.
    check_length(4, blocks_per_group, "blocks_per_group")
    check_length(4, channels_per_group, "channels_per_group")

    # We first convolve the image with 7x7 filters, to be able to better extract
    # low-level features such as contours. Using conv with stride=2 halves the
    # resolution of the input, reducing considerably the computation cost, and
    # increasing the receptive field.  
    self.initial_conv = hk.Conv2D(
        output_channels=64,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        padding="SAME",
        name="initial_conv")

    self.block_groups = []
    strides = (1, 2, 2, 2)
    for i in range(4):
      self.block_groups.append(
          BlockGroup(channels=channels_per_group[i],
                     num_blocks=blocks_per_group[i],
                     stride=strides[i],
                     bn_config=bn_config,
                     bottleneck=bottleneck,
                     use_projection=use_projection[i],
                     name="block_group_%d" % (i)))

    self.final_batchnorm = hk.BatchNorm(name="final_batchnorm", **bn_config)
    self.logits = hk.Linear(num_classes, w_init=jnp.zeros, name="logits")

  def __call__(self, inputs, is_training, test_local_stats=False):
    out = inputs
    out = self.initial_conv(out)
    # Reduce the spatial resolution of the activations by a factor of 2. This
    # increases the receptive field and reduces the computation cost. Note that,
    # compared to a strided conv, which has the same effects, the pooling layer
    # does not have trainable parameters.
    out = hk.max_pool(out,
                      window_shape=(1, 3, 3, 1),
                      strides=(1, 2, 2, 1),
                      padding="SAME")

    for block_group in self.block_groups:
      out = block_group(out, is_training, test_local_stats)

    out = self.final_batchnorm(out, is_training, test_local_stats)
    out = jax.nn.relu(out)

    # Pool over spatial dimensions to obtain the final vector embedding
    # of the image. Use jnp.mean and not hk.avg_pool, to make sure that the
    # network can be applied to inputs with any resolution without modification
    # of the model.
    out = jnp.mean(out, axis=(1, 2))
    return self.logits(out)
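
To see how the strides above shrink a Cifar10 image, the following sketch traces the spatial resolution through the network. It assumes that a "SAME"-padded convolution or pooling layer with stride `s` maps a size `n` to `ceil(n / s)`; the helper `same_padding_output` is hypothetical and only used for this illustration.

```python
import math

def same_padding_output(size: int, stride: int) -> int:
  """Output size along one spatial axis for a "SAME"-padded conv or pool."""
  return math.ceil(size / stride)

size = 32  # Cifar10 images are 32x32
trace = [('input', size)]
size = same_padding_output(size, 2)   # initial 7x7 conv, stride 2
trace.append(('initial_conv', size))
size = same_padding_output(size, 2)   # 3x3 max pool, stride 2
trace.append(('max_pool', size))
for i, stride in enumerate((1, 2, 2, 2)):  # strides of the four block groups
  size = same_padding_output(size, stride)
  trace.append((f'block_group_{i}', size))

for name, s in trace:
  print(f'{name}: {s}x{s}')
```

With 32x32 inputs the last block group already produces a 1x1 feature map, so the final spatial mean averages over a single location.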

In [None]:
#@title Instantiate Resnet18
class ResNet18(ResNet):
  """ResNet18."""

  def __init__(self,
               num_classes: int,
               bn_config: Optional[Mapping[str, float]] = None,
               name: Optional[str] = None):
    """Constructs a ResNet model.
    Args:
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers.
      name: Name of the module.
    """
    super().__init__(blocks_per_group=(2, 2, 2, 2),
                     num_classes=num_classes,
                     bn_config=bn_config,
                     bottleneck=False,
                     channels_per_group=(64, 128, 256, 512),
                     use_projection=(False, True, True, True),
                     name=name)

### Create the forward pass of the model

In [None]:
def net_fn(
    batch: Batch,
    is_training: bool,
) -> jnp.ndarray:
  """Forward pass of the resnet."""
  model = ResNet18(num_classes, bn_config={'decay_rate': model_bn_decay})
  return model(batch['image'], is_training=is_training)

# Transform the forward function into a pair of pure functions.
# We use transform with state because we need to keep the state of the network,
# e.g. for batch norm statistics.
net = hk.transform_with_state(net_fn)

### Define learning rate schedule and optimizer

In [None]:
# We use learning rate annealing during training. We start with a larger
# learning rate `lr_init`, which allows faster exploration of the solution
# space, and we multiply it by `lr_factor` (i.e. reduce it by a factor of 10)
# after each predefined number of steps. A smaller learning rate at the end of
# training allows the model to explore a local neighbourhood and settle on a
# good local minimum.
def lr_schedule(step: jnp.ndarray) -> jnp.ndarray:
  """Define learning rate annealing schedule."""
  # After how many steps to apply the learning rate reduction
  boundaries = jnp.array((200e3, 300e3, 350e3))
  # Every time we hit a predefined number of steps, we apply the reduction
  # of the learning rate by `lr_factor`
  lr_decay_exponent = jnp.sum(step >= boundaries)
  lr_init = 0.1
  lr_factor = 0.1
  return lr_init * lr_factor**lr_decay_exponent

# Define the optimiser, we use SGD with momentum
def make_optimizer():
  """SGD with nesterov momentum and a custom lr schedule."""
  return optix.chain(optix.trace(decay=optimizer_momentum,
                                 nesterov=optimizer_use_nesterov),
                     optix.scale_by_schedule(lr_schedule),
                     optix.scale(-1))
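
To check the values that this schedule produces, here is a plain NumPy restatement of `lr_schedule` evaluated around the boundaries. The function name `lr_schedule_numpy` is hypothetical; the boundaries, the initial learning rate and the decay factor are copied from the cell above.

```python
import numpy as np

def lr_schedule_numpy(step: int) -> float:
  """NumPy restatement of `lr_schedule` for inspecting its values."""
  boundaries = np.array((200e3, 300e3, 350e3))
  # Number of boundaries already passed, i.e. how many decays to apply.
  num_decays = int(np.sum(step >= boundaries))
  return 0.1 * 0.1 ** num_decays

for step in (0, 199_999, 200_000, 300_000, 350_000):
  print(step, lr_schedule_numpy(step))
```

The learning rate stays at 0.1 for the first 200k steps and then drops to roughly 0.01, 0.001 and 0.0001 at the three boundaries.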


### Define the loss function: cross-entropy for classification and weight decay for regularization

In [None]:
# Function to compute l2 loss - useful for regularisation
def l2_loss(params: Iterable[jnp.ndarray]) -> jnp.ndarray:
  return 0.5 * sum(jnp.sum(jnp.square(p)) for p in params)

# Function to compute softmax cross entropy for classification
def softmax_cross_entropy(
    *,
    logits: jnp.ndarray,
    labels: jnp.ndarray,
) -> jnp.ndarray:
  return -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)

def loss_fn(
    params: hk.Params,
    state: hk.State,
    batch: Batch,
) -> Tuple[jnp.ndarray, hk.State]:
  """Computes a regularized loss for the given batch."""
  # The third parameter would be an rng key if one is needed in running
  # the model, e.g. for dropout. If not needed, pass `None`.
  logits, state = net.apply(params, state, None, batch, is_training=True)
  # The labels are given as class indices; convert to one_hot representation
  labels = jax.nn.one_hot(batch['label'], num_classes)
  # Compute classification loss
  cat_loss = jnp.mean(softmax_cross_entropy(logits=logits, labels=labels))
  # Get all the trainable parameters of the model, except batch_norm parameters,
  # and apply weight decay regularisation to them, i.e. we penalise weights with
  # large magnitude
  l2_params = [p for ((mod_name, _), p) in tree.flatten_with_path(params)
               if 'batchnorm' not in mod_name]
  # We apply a weighting factor to the regularisation loss, so that it does
  # not dominate the total loss
  reg_loss = train_weight_decay * l2_loss(l2_params)
  # Compute the final loss
  loss = cat_loss + reg_loss
  return loss, state
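
The two loss terms can be checked by hand on a tiny example. The NumPy sketch below mirrors `softmax_cross_entropy` and `l2_loss` without using the network; the function `regularised_loss` and the toy inputs are assumptions made for illustration.

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
  # Subtract the row-wise max before exponentiating, for numerical stability.
  shifted = logits - logits.max(axis=-1, keepdims=True)
  return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def regularised_loss(logits, labels, weights, weight_decay=1e-4, num_classes=10):
  one_hot = np.eye(num_classes)[labels]
  cat_loss = np.mean(-np.sum(one_hot * log_softmax(logits), axis=-1))
  reg_loss = weight_decay * 0.5 * sum(np.sum(np.square(w)) for w in weights)
  return cat_loss + reg_loss

# Uniform logits give every class probability 1/10, so the cross-entropy
# term equals log(10) whatever the labels are.
logits = np.zeros((4, 10))
labels = np.array([0, 3, 5, 9])
weights = [np.ones((2, 2))]  # squared L2 norm of 4, so reg_loss = 1e-4 * 0.5 * 4
loss = regularised_loss(logits, labels, weights)
print(loss, np.log(10) + 2e-4)
```

A classification loss close to `log(10)` is also what one would expect from a 10-class model whose predictions are still uniform, which makes it a useful sanity check on the first logged values.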

### Define the training step and training dataset

In [None]:
@jax.jit
def train_step(
    params: hk.Params,
    state: hk.State,
    opt_state: OptState,
    batch: Batch, 
) -> Tuple[hk.Params, hk.State, OptState, jnp.ndarray]:
  """Applies an update to parameters and returns new state."""
  (loss, state), grads = (
      jax.value_and_grad(loss_fn, has_aux=True)(params, state, batch))

  # Compute and apply updates via our optimizer.
  updates, opt_state = make_optimizer().update(grads, opt_state)
  params = optix.apply_updates(params, updates)

  return params, state, opt_state, loss

# Get training dataset
train_dataset = load(train_split, is_training=True, batch_size=train_batch_size)
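
The `has_aux=True` pattern used in `train_step` can be seen in isolation on a toy function. In the sketch below, `toy_loss` and its inputs are hypothetical; the auxiliary dictionary plays the role that the network state plays in `loss_fn`.

```python
import jax
import jax.numpy as jnp

def toy_loss(w, x):
  """Returns (scalar loss, auxiliary output); only the loss is differentiated."""
  pred = w * x
  loss = jnp.sum(pred ** 2)
  aux = {'pred_mean': jnp.mean(pred)}  # plays the role of the network state
  return loss, aux

w = jnp.array(2.0)
x = jnp.array([1.0, 3.0])
(loss, aux), grad = jax.value_and_grad(toy_loss, has_aux=True)(w, x)

# loss = sum((w * x)^2) = 4 + 36 = 40, and d(loss)/dw = 2 * w * sum(x^2) = 40.
print(loss, grad, aux['pred_mean'])
```

The gradient is taken with respect to the first argument only, and the auxiliary output is returned unchanged next to the loss value.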

### Define the evaluation

In [None]:
@jax.jit
def eval_batch(
    params: hk.Params,
    state: hk.State,
    batch: Batch,
) -> jnp.ndarray:
  """Evaluates a batch."""
  # The third parameter would be an rng key if one is needed in running the model,
  # e.g. for dropout. If not needed, pass `None`.
  logits, _ = net.apply(params, state, None, batch, is_training=False)
  predicted_label = jnp.argmax(logits, axis=-1)
  correct = jnp.sum(jnp.equal(predicted_label, batch['label']))
  return correct.astype(jnp.float32)

def evaluate(
    split: str,
    params: hk.Params,
    state: hk.State,
) -> float:
  """Evaluates the model at the given params/state and returns its accuracy."""
  test_dataset = load(split, is_training=False, batch_size=eval_batch_size)
  correct = jnp.array(0)
  total = 0
  for _ in range(num_eval_steps):
    correct += eval_batch(params, state, next(test_dataset))
    total += eval_batch_size
  return correct.item() / total
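
The accuracy computation accumulates the number of correct predictions over batches and divides by the number of examples seen. A minimal NumPy sketch with two made-up batches (the helper `count_correct` and the data are hypothetical):

```python
import numpy as np

def count_correct(logits: np.ndarray, labels: np.ndarray) -> int:
  """Number of examples whose highest-scoring class matches the label."""
  return int(np.sum(np.argmax(logits, axis=-1) == labels))

# Two hypothetical batches of 3 examples over 4 classes.
batches = [
    (np.array([[2., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 0., 3.]]),
     np.array([0, 1, 2])),   # 2 of 3 correct
    (np.array([[0., 0., 5., 0.], [1., 0., 0., 0.], [0., 4., 0., 0.]]),
     np.array([2, 0, 1])),   # 3 of 3 correct
]
correct = sum(count_correct(lg, lb) for lg, lb in batches)
total = sum(len(lb) for _, lb in batches)
print(correct / total)
```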

### Initialise the model and the optimiser

In [None]:
def make_initial_state(
    rng: jnp.ndarray,
    batch: Batch,
) -> Tuple[hk.Params, hk.State, OptState]:
  """Computes the initial network state."""
  params, state = net.init(rng, batch, is_training=True)
  opt_state = make_optimizer().init(params)
  return params, state, opt_state

# We need a random key for initialization
rng = jax.random.PRNGKey(train_init_random_seed)

# Initialization requires an example input to calculate shapes of parameters.
batch = next(train_dataset)
params, state, opt_state = make_initial_state(rng, batch)

### How many parameters in your model?

In [None]:
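def get_num_params(params: hk.Params) -> int:
  """Counts the scalar entries across all parameter arrays."""
  # `p.size` is the number of elements of an array, i.e. the product of its
  # shape. `jax.tree_util.tree_leaves` flattens the nested parameter dict.
  return sum(p.size for p in jax.tree_util.tree_leaves(params))
print('Total number of parameters %d' % get_num_params(params))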

### Display input images and shapes

In [None]:
print(batch['image'].shape)
print(batch['label'].shape)
gallery(batch['image'], batch['label'])

### Run training loop and evaluation; full training gives accuracy ~89.1%

In [None]:
eval_every = train_eval_every
log_every = train_log_every

for step_num in range(int(num_train_steps)):
  # Take a training step.
  params, state, opt_state, train_loss = (
      train_step(params, state, opt_state, next(train_dataset)))

  # We run evaluation during training to see the progress.
  if eval_every > 0 and step_num % eval_every == 0:
    eval_acc = evaluate(eval_split, params, state)
    print('[Eval acc %s/%s] %s'%(step_num, int(num_train_steps), eval_acc))

  # Log progress at fixed intervals.
  if step_num % log_every == 0:
    print('[Train loss %s/%s] %s'%(step_num, int(num_train_steps), train_loss))

# Once training has finished we run eval one more time to get final results.
eval_acc = evaluate(eval_split, params, state)
print('[Eval acc FINAL]: %s'%(eval_acc))