## Part III: [Distilling the knowledge](https://arxiv.org/pdf/1503.02531.pdf) from a (larger) teacher model

- import an already trained baseline model
- add KL distillation loss between teacher and student
- train Mobilenet classifier with this joint loss


### Train Mobilenet with a distillation loss based on the Kullback-Leibler (KL) divergence

Define the loss as
\begin{equation}
\mathcal{L} = \lambda \mathcal{L}_{\text{distill}} + \mathcal{L}_{\text{classif}},
\end{equation}

where $\mathcal{L}_{\text{classif}}$ is the classification loss on the ground-truth labels and
\begin{equation}
\mathcal{L}_{\text{distill}} = \text{KL}(p_{\text{teacher}} \,\|\, p_{\text{student}}).
\end{equation}


Recall the definition of the KL divergence between two discrete distributions $p$ and $q$ over $N$ outcomes:
$$\text{KL}(p \,\|\, q) = \sum_{i=1}^{N}p(x_i) \cdot \log \frac{p(x_i)}{q(x_i)} . $$
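
As a quick numerical illustration of this definition, here is a minimal NumPy sketch. The function name `kl_divergence`, the `eps` guard and the two example distributions are assumptions made for illustration; they are not part of the notebook.

```python
import numpy as np

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as 1-D arrays."""
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    # eps (an assumption of this sketch) guards against log(0).
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

# Hypothetical distributions over three outcomes.
p = np.array([0.7, 0.2, 0.1])
q = np.array([0.5, 0.3, 0.2])

print(kl_divergence(p, p))                       # zero for identical distributions
print(kl_divergence(p, q), kl_divergence(q, p))  # positive, and not symmetric
```

Because the divergence is not symmetric, the order of the arguments matters: the teacher distribution goes first and the student distribution second.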

The outputs of the networks are logits, which we interpret as probabilities when passed through softmax:

$$p_i^{(T)} =\frac{\exp{(\text{logits}_i / T) }}{\sum_j \exp{(\text{logits}_j / T) }} $$

where $T$ is a temperature, usually set to $1$. A higher temperature smooths the probability distribution. To be fully precise, we will use

\begin{equation}
\mathcal{L}_{\text{distill}} = \text{KL}(p_{\text{teacher}}^{(T)} \,\|\, p_{\text{student}}^{(T)}),
\end{equation}

with the weight

$$\lambda = T^2.$$
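
The weighted distillation term can be sketched in NumPy as follows. This is only an illustration under stated assumptions: the helper names and the example logits are hypothetical, and the loss is averaged over the batch, which the text above does not specify.

```python
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    z = np.asarray(logits, dtype=np.float64) / temperature
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distillation_loss(teacher_logits, student_logits, temperature):
    """T^2 * KL(p_teacher^(T) || p_student^(T)), averaged over the batch."""
    p = softmax_with_temperature(teacher_logits, temperature)
    q = softmax_with_temperature(student_logits, temperature)
    kl = np.sum(p * (np.log(p) - np.log(q)), axis=-1)
    return temperature ** 2 * kl.mean()

# Hypothetical logits for a batch of 2 examples and 3 classes.
teacher_logits = np.array([[5.0, 1.0, -2.0], [0.5, 3.0, 0.0]])
student_logits = np.array([[3.0, 2.0, -1.0], [1.0, 1.5, 0.2]])

print(distillation_loss(teacher_logits, student_logits, temperature=5.0))
# A student that reproduces the teacher's logits incurs no distillation loss.
print(distillation_loss(teacher_logits, teacher_logits, temperature=5.0))
```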






### Imports

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import time

import tensorflow as tf

# Don't forget to select GPU runtime environment in Runtime -> Change runtime type
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

# we will use Sonnet on top of TF 
!pip install -q dm-sonnet
import sonnet as snt

import numpy as np

# Plotting library.
from matplotlib import pyplot as plt
import pylab as pl
from IPython import display

Found GPU at: /device:GPU:0


In [0]:
# Reset graph
tf.reset_default_graph()

### Copy the pretrained weights of the baseline model to the virtual machine
- you need to upload all three files from the *baseline_weights* folder

In [0]:
from google.colab import files

uploaded = files.upload()

for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(uploaded[fn])))

Saving baseline.ckpt.data-00000-of-00001 to baseline.ckpt (1).data-00000-of-00001


### Download dataset to be used for training and testing
- CIFAR-10 is the equivalent of MNIST for natural RGB images
- 60000 32x32 colour images in 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
- train: 50000; test: 10000

In [0]:
cifar10 = tf.keras.datasets.cifar10
# (down)load dataset
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

### Prepare the data for training and testing
- for training, we use stochastic optimizers (e.g. SGD, Adam), so we need to sample mini-batches at random from the training dataset
- for testing, we iterate sequentially through the test set
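
The two access patterns can be sketched with plain NumPy index generators. This is a simplified illustration, not the `tf.data` pipeline used in the notebook: the function names are hypothetical, and details such as the handling of the last partial batch are assumptions of the sketch.

```python
import numpy as np

def random_batches(num_examples, batch_size, rng):
    """Yield index arrays forever, reshuffling after every pass over the data."""
    while True:
        perm = rng.permutation(num_examples)
        # This sketch drops the last partial batch of each pass.
        for start in range(0, num_examples - batch_size + 1, batch_size):
            yield perm[start:start + batch_size]

def sequential_batches(num_examples, batch_size):
    """Yield index arrays that cover the data once, in order."""
    for start in range(0, num_examples, batch_size):
        yield np.arange(start, min(start + batch_size, num_examples))

rng = np.random.default_rng(0)
train_iter = random_batches(num_examples=50000, batch_size=64, rng=rng)
first_train_batch = next(train_iter)
test_batches = list(sequential_batches(num_examples=10000, batch_size=100))
print(first_train_batch.shape, len(test_batches))
```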

In [0]:
# define dimension of the batches to sample from the datasets
BATCH_SIZE_TRAIN = 64 #@param
BATCH_SIZE_TEST = 100 #@param

# create Dataset objects using the data previously downloaded
dataset_train = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
# we shuffle the data and repeatedly sample batches for training
batched_dataset_train = dataset_train.shuffle(100000).repeat().batch(BATCH_SIZE_TRAIN)
# create iterator to retrieve batches
iterator_train = batched_dataset_train.make_one_shot_iterator()
# get a training batch of images and labels
(batch_train_images, batch_train_labels) = iterator_train.get_next()

In [0]:
# we do the same for test dataset
dataset_test = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
batched_dataset_test = dataset_test.repeat().batch(BATCH_SIZE_TEST)
iterator_test = batched_dataset_test.make_one_shot_iterator() 
(batch_test_images, batch_test_labels) = iterator_test.get_next()

In [0]:
# preprocess input for training and testing
def random_flip_left_right(image, flip_index, seed=None):
  shape = image.get_shape()
  if shape.ndims == 3 or shape.ndims is None:
    uniform_random = tf.random_uniform([], 0, 1.0, seed=seed)
    mirror_cond = tf.less(uniform_random, .5)
    result = tf.cond(
        mirror_cond,
        lambda: tf.reverse(image, [flip_index]),
        lambda: image
    )
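    # fix_image_flip_shape is a private TF helper that is not defined here;
    # restore the static shape (which tf.cond may lose) directly instead.
    result.set_shape(image.get_shape())
    return result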
  elif shape.ndims == 4:
    uniform_random = tf.random_uniform(
        [tf.shape(image)[0]], 0, 1.0, seed=seed
    )
    mirror_cond = tf.less(uniform_random, .5)
    return tf.where(
        mirror_cond,
        image,
        tf.map_fn(lambda x: tf.reverse(x, [flip_index]), image, dtype=image.dtype)
    )
  else:
    raise ValueError("'image' must have either 3 or 4 dimensions.")
    
def train_image_preprocess(h, w, random_flip=True):
  """Image processing required for training the model."""

  def fn(image):
    # Ensure the data is in range [-1, 1].
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = image * 2.0 - 1.0
    # Randomly choose an (h, w, 3) patch to be used for training;
    # the same crop offset is shared by every image in the batch.
    image = tf.random_crop(image, size=(BATCH_SIZE_TRAIN, h, w, 3))
    # Randomly flip the image.
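    # For a batch, map_fn passes each (h, w, c) image to tf.reverse,
    # so the width axis is index 1 (index 2 would reverse the channels).
    image = random_flip_left_right(image, 1)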
    return image

  return fn

def test_image_preprocess():
  def fn(image):
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = image * 2.0 - 1.0
    return image
  return fn
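
To make the training-time preprocessing concrete, here is a minimal NumPy sketch of the same three steps: scaling to [-1, 1], a random crop, and a random left-right flip. The function name and the random input batch are hypothetical, and the flip is taken along the width axis, which is what a left-right flip is assumed to mean here.

```python
import numpy as np

def preprocess_train_batch(images_uint8, crop_h, crop_w, rng):
    """Scale uint8 images to [-1, 1], crop one random window, flip some images."""
    images = images_uint8.astype(np.float32) / 255.0 * 2.0 - 1.0
    _, height, width, _ = images.shape
    top = rng.integers(0, height - crop_h + 1)
    left = rng.integers(0, width - crop_w + 1)
    # As with tf.random_crop on a 4-D tensor, one offset is shared by the batch.
    images = images[:, top:top + crop_h, left:left + crop_w, :]
    # Mirror roughly half of the images along the width axis.
    flip = rng.random(images.shape[0]) < 0.5
    images[flip] = images[flip][:, :, ::-1, :]
    return images

rng = np.random.default_rng(0)
# Random stand-in for a batch of 8 CIFAR-10 images.
batch = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
out = preprocess_train_batch(batch, crop_h=24, crop_w=24, rng=rng)
print(out.shape, out.dtype, out.min(), out.max())
```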

### Teacher model is baseline

In [0]:
class Baseline(snt.AbstractModule):
  
  def __init__(self, num_classes, name="baseline"):
    super(Baseline, self).__init__(name=name)
    self._num_classes = num_classes
    self._output_channels = [
        64, 64, 128, 128, 128, 256, 256, 256, 512, 512, 512
        ]
    self._num_layers = len(self._output_channels)

    self._kernel_shapes = [[3, 3]] * self._num_layers  # All kernels are 3x3.
    self._strides = [1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1]
    self._paddings = [snt.SAME] * self._num_layers
   
  def _build(self, inputs, is_training=None, test_local_stats=False):
    net = inputs
    # instantiate all the convolutional layers 
    layers = [snt.Conv2D(name="conv_2d_{}".format(i),
                         output_channels=self._output_channels[i],
                         kernel_shape=self._kernel_shapes[i],
                         stride=self._strides[i],
                         padding=self._paddings[i],
                         use_bias=True) for i in range(self._num_layers)]
    # connect them to the graph, adding batch norm and non-linearity
    for i, layer in enumerate(layers):
      net = layer(net)
      bn = snt.BatchNorm(name="batch_norm_{}".format(i))
      net = bn(net, is_training=is_training, test_local_stats=test_local_stats)
      net = tf.nn.relu(net)

    net = tf.reduce_mean(net, axis=[1, 2], keepdims=False,
                         name="avg_pool")

    logits = snt.Linear(self._num_classes)(net)

    return logits

### Student model is Mobilenet

In [0]:
class Mobilenet(snt.AbstractModule):
  
  def __init__(self, num_classes, name="mobilenet"):
    super(Mobilenet, self).__init__(name=name)
    self._num_classes = num_classes
    self._channel_multipliers = [
        0, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1
    ]
    self._output_channels = [
        64, 64, 128, 128, 128, 256, 256, 256, 512, 512, 512
    ]
    self._num_layers = len(self._output_channels)

    self._kernel_shapes = [[3, 3]] * self._num_layers  # All kernels are 3x3.
    self._strides = [1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1]
    self._paddings = [snt.SAME] * self._num_layers
   
  def _build(self, inputs, is_training=None, test_local_stats=False):
    net = inputs
    # instantiate all the convolutional layers
    first_conv = snt.Conv2D(name="conv_2d_0",
                            output_channels=self._output_channels[0],
                            kernel_shape=self._kernel_shapes[0],
                            stride=self._strides[0],
                            padding=self._paddings[0],
                            use_bias=True)
    
    # instantiate depthwise conv layers
    conv_layers_dw = [snt.DepthwiseConv2D(name="conv_dw_2d_{}".format(i),
                                          channel_multiplier=self._channel_multipliers[i],
                                          kernel_shape=self._kernel_shapes[i],
                                          stride=self._strides[i],
                                          padding=self._paddings[i],
                                          use_bias=True)
                      for i in range(1, self._num_layers)]
    
    # instantiate 1x1 conv layers
    conv_layers_1x1 = [snt.Conv2D(name="conv_1x1_2d_{}".format(i),
                                  output_channels=self._output_channels[i],
                                  kernel_shape=(1, 1),
                                  stride=1,  # the depthwise conv already applies the stride
                                  padding=self._paddings[i],
                                  use_bias=True)
                       for i in range(1, self._num_layers)]
    # connect first layer to the graph, adding batch norm and non-linearity
    net = first_conv(net)
    bn = snt.BatchNorm(name="batch_norm_0")
    net = bn(net, is_training=is_training, test_local_stats=test_local_stats)
    net = tf.nn.relu(net)
    
    # connect the rest of the layers
    for i, (layer_dw, layer_1x1) in enumerate(zip(conv_layers_dw, conv_layers_1x1)):
      net = layer_dw(net)
      bn = snt.BatchNorm(name="batch_norm_{}_0".format(i))
      net = bn(net, is_training=is_training, test_local_stats=test_local_stats)
      net = tf.nn.relu(net)
      net = layer_1x1(net)
      bn = snt.BatchNorm(name="batch_norm_{}_1".format(i))
      net = bn(net, is_training=is_training, test_local_stats=test_local_stats)
      net = tf.nn.relu(net)      

    net = tf.reduce_mean(net, axis=[1, 2], keepdims=False,
                         name="avg_pool")

    logits = snt.Linear(self._num_classes)(net)

    return logits
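
To get a rough sense of why Mobilenet is the smaller model, the sketch below counts convolution weights for both architectures using the channel lists from the two classes. It is a back-of-the-envelope estimate under stated assumptions: biases, batch-norm parameters and the final linear layer are ignored, and the helper names are hypothetical.

```python
output_channels = [64, 64, 128, 128, 128, 256, 256, 256, 512, 512, 512]
channel_multipliers = [0, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1]

def standard_conv_weights(c_in, c_out, k=3):
    """Weights of a k x k convolution from c_in to c_out channels."""
    return k * k * c_in * c_out

def separable_conv_weights(c_in, c_out, multiplier, k=3):
    """Weights of a k x k depthwise conv followed by a 1x1 pointwise conv."""
    depthwise = k * k * c_in * multiplier
    pointwise = c_in * multiplier * c_out
    return depthwise + pointwise

baseline_total = 0
mobilenet_total = 0
c_in = 3  # RGB input
for i, c_out in enumerate(output_channels):
    baseline_total += standard_conv_weights(c_in, c_out)
    if i == 0:
        # The first Mobilenet layer is a standard convolution.
        mobilenet_total += standard_conv_weights(c_in, c_out)
    else:
        mobilenet_total += separable_conv_weights(c_in, c_out, channel_multipliers[i])
    c_in = c_out

print("baseline conv weights: ", baseline_total)
print("mobilenet conv weights:", mobilenet_total)
print("ratio: {:.1f}x".format(baseline_total / mobilenet_total))
```

Under these assumptions the student has several times fewer convolution weights than the teacher, which is the regime where distillation is expected to help.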

In [0]:
# First define the preprocessing ops for the train/test data
crop_height = 24 #@param
crop_width = 24 #@param
preprocess_fn_train = train_image_preprocess(crop_height, crop_width)
preprocess_fn_test = test_image_preprocess()

num_classes = 10 #@param

In [0]:
# for evaluation, we look at top_k_accuracy since it's easier to interpret; normally k=1 or k=5
def top_k_accuracy(k, labels, logits):
  in_top_k = tf.nn.in_top_k(predictions=tf.squeeze(logits), targets=tf.squeeze(tf.cast(labels, tf.int32)), k=k)
  return tf.reduce_mean(tf.cast(in_top_k, tf.float32))
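
For intuition, the same metric can be written in NumPy. This is a sketch with hypothetical names and toy data; it does not replicate how `tf.nn.in_top_k` treats ties, which do not occur in this example.

```python
import numpy as np

def top_k_accuracy_np(k, labels, logits):
    """Fraction of rows whose true label is among the k largest logits."""
    labels = np.asarray(labels).reshape(-1)
    top_k = np.argsort(-np.asarray(logits), axis=-1)[:, :k]
    hits = (top_k == labels[:, None]).any(axis=-1)
    return hits.mean()

# Toy batch of 4 examples and 3 classes.
logits = np.array([[0.1, 2.0, 0.3],
                   [1.5, 0.2, 1.0],
                   [0.0, 0.1, 3.0],
                   [2.0, 1.0, 0.5]])
labels = np.array([[1], [2], [2], [1]])  # shape [batch, 1], like the CIFAR-10 labels

print(top_k_accuracy_np(1, labels, logits))  # 0.5
print(top_k_accuracy_np(2, labels, logits))  # 1.0
```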

In [0]:
# Define number of training iterations and reporting intervals
TRAIN_ITERS = 90e3 #@param
REPORT_TRAIN_EVERY = 100 #@param
PLOT_EVERY = 500 #@param
REPORT_TEST_EVERY = 1000 #@param
TEST_ITERS = 10 #@param
lr_init = 0.01 #@param
display_inputs = False #@param

class_mapping = [u'airplane', u'automobile', u'bird', u'cat', u'deer', u'dog', u'frog', u'horse', u'ship', u'truck']

### Instantiate teacher and load pre-trained weights


In [0]:
with tf.variable_scope("teacher"):
  teacher_model = Baseline(num_classes)
predictions_teacher = teacher_model(preprocess_fn_train(batch_train_images), is_training=False)


We do not want to alter the teacher weights, so we stop gradients from flowing back into the teacher.

In [0]:
predictions_teacher = tf.stop_gradient(predictions_teacher)

In [0]:
var_list = snt.get_variables_in_scope("teacher", collection=tf.GraphKeys.GLOBAL_VARIABLES)  

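# Map checkpoint names to teacher variables: strip the "teacher/" scope
# prefix and the trailing ":0" so the names match those in the checkpoint.
var_map = {}
for var in var_list:
  name = var.name[len("teacher/"):-2]
  var_map[name] = var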

saver = tf.train.Saver(var_map, reshape=True)
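
The renaming step can be illustrated in plain Python. The helper `checkpoint_name` and the variable names below are hypothetical examples; the actual names depend on how Sonnet names the variables in the graph.

```python
def checkpoint_name(variable_name, scope="teacher"):
    """Strip the leading 'scope/' prefix and the trailing ':0' output suffix."""
    prefix = scope + "/"
    if not variable_name.startswith(prefix):
        raise ValueError("{} is not in scope {}".format(variable_name, scope))
    return variable_name[len(prefix):].rsplit(":", 1)[0]

# Hypothetical variable names, for illustration only.
for name in ["teacher/baseline/conv_2d_0/w:0", "teacher/baseline/linear/b:0"]:
    print(name, "->", checkpoint_name(name))
```

The resulting keys are what the `Saver` looks up in the checkpoint, while the values remain the variables that live under the `teacher` scope.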

### Instantiate student

In [0]:
with tf.variable_scope("student"):
  student_model = Mobilenet(num_classes=num_classes)
# get predictions from the model
predictions_student = student_model(preprocess_fn_train(batch_train_images), is_training=True)
test_predictions_student = student_model(preprocess_fn_test(batch_test_images), is_training=False)

### For distillation, we use softmax with higher temperature. Normally T = 1; for distillation we use T>1.

\begin{equation}
q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}
\end{equation}

In [0]:
# visualise how the softmax temperature influences the output of the teacher
# (despite the names, these tensors hold softmax probabilities, not logits)
softmax_temp_distill = 5.0  # temperature used for distillation
softmax_temp_normal = 1.0  # standard softmax
logits_high_temp = tf.nn.softmax(tf.div(predictions_teacher, softmax_temp_distill))
logits_low_temp = tf.nn.softmax(tf.div(predictions_teacher, softmax_temp_normal))
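
The softening effect can also be checked numerically without the teacher network. The sketch below uses made-up logits for a single image (an assumption for illustration) and reports the largest probability and the entropy at several temperatures.

```python
import numpy as np

def softmax_with_temperature(logits, temperature):
    z = np.asarray(logits, dtype=np.float64) / temperature
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

def entropy(p):
    return float(-np.sum(p * np.log(p)))

# Hypothetical logits for one image over the 10 CIFAR-10 classes.
logits = np.array([9.0, 1.0, 0.5, 4.0, 0.0, 3.0, -1.0, 0.2, 1.5, -0.5])

for temperature in (1.0, 5.0, 20.0):
    p = softmax_with_temperature(logits, temperature)
    print("T={:>4}: max prob {:.3f}, entropy {:.3f} (uniform would be {:.3f})".format(
        temperature, p.max(), entropy(p), np.log(len(logits))))
```

As the temperature grows, the largest probability shrinks and the entropy approaches that of a uniform distribution, while the ranking of the classes stays the same.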

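### Set up the training for Mobilenet, adding the distillation loss weighted by the square of the temperature
- the gradients of the distillation loss scale as $1/T^2$, so multiplying by $T^2$ keeps its contribution comparable to the classification loss when the temperature changes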
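
This scaling can be checked numerically. The sketch below uses the closed-form gradient of the KL term with respect to the student logits, $(q^{(T)} - p^{(T)})/T$, on made-up logits (an assumption for illustration), and compares the raw and $T^2$-scaled gradient norms.

```python
import numpy as np

def softmax_with_temperature(logits, temperature):
    z = np.asarray(logits, dtype=np.float64) / temperature
    e = np.exp(z - z.max())
    return e / e.sum()

def distill_grad(teacher_logits, student_logits, temperature):
    """Gradient of KL(p_teacher^(T) || p_student^(T)) w.r.t. the student logits."""
    p = softmax_with_temperature(teacher_logits, temperature)
    q = softmax_with_temperature(student_logits, temperature)
    return (q - p) / temperature

# Hypothetical logits for a single example with 4 classes.
teacher_logits = np.array([4.0, 1.0, -2.0, 0.5])
student_logits = np.array([1.0, 2.0, -1.0, 0.0])

for temperature in (10.0, 20.0, 40.0):
    g = distill_grad(teacher_logits, student_logits, temperature)
    raw = np.linalg.norm(g)
    print("T={:>4}: |grad| = {:.5f}, T^2 * |grad| = {:.4f}".format(
        temperature, raw, temperature ** 2 * raw))
```

The raw norm drops by roughly a factor of four each time the temperature doubles, while the $T^2$-scaled norm stays nearly constant.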

In [0]:
lambda_ = softmax_temp_distill * softmax_temp_distill

**Define the classification loss**

In [0]:

###################
#                 # 
# YOUR CODE       #
# train_loss = ...# 
#                 #
###################

**Define the distillation loss**

You may do this either with

* `tf.distributions.kl_divergence` between distributions
* `tf.nn.softmax_cross_entropy_with_logits`. Remember that the labels are expected to sum to 1, while the output of the teacher network is logits, so pass the teacher's temperature-softened softmax as the labels.
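
The two options are closely related: the cross-entropy against the teacher's soft targets equals the KL divergence plus the teacher's entropy, and the entropy term does not depend on the student. The NumPy sketch below checks this identity on made-up logits; it is an illustration with hypothetical names, not a TensorFlow solution to the exercise.

```python
import numpy as np

def log_softmax(logits):
    z = np.asarray(logits, dtype=np.float64)
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

T = 5.0  # distillation temperature
# Hypothetical logits for a batch of 2 examples and 3 classes.
teacher_logits = np.array([[5.0, 1.0, -2.0], [0.5, 3.0, 0.0]])
student_logits = np.array([[3.0, 2.0, -1.0], [1.0, 1.5, 0.2]])

log_p = log_softmax(teacher_logits / T)  # teacher soft targets (log-probabilities)
log_q = log_softmax(student_logits / T)  # student soft predictions
p = np.exp(log_p)

kl = np.sum(p * (log_p - log_q), axis=-1)
cross_entropy = -np.sum(p * log_q, axis=-1)
teacher_entropy = -np.sum(p * log_p, axis=-1)

# Cross-entropy = KL + H(teacher); only the KL part depends on the student.
print(np.allclose(cross_entropy, kl + teacher_entropy))  # True
```

Since the teacher is frozen, both formulations give the same gradients for the student; only the reported loss value differs by the entropy term.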

In [0]:

########################
#                      # 
# YOUR CODE            #
# distill_kl_loss = ...# 
#                      #
########################

**Define the joint training loss**

In [0]:

###################
#                 # 
# YOUR CODE       #
# train_loss = ...# 
#                 #
###################

### Create the training ops

Make sure the batch-norm moving averages get updated: run the ops in the `UPDATE_OPS` collection together with the training op.

In [0]:
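def get_optimizer(step):
  """Get the optimizer used for training."""
  # Steps at which the learning rate is multiplied by lr_factor.
  lr_schedule = (40e3, 60e3, 80e3)
  lr_schedule = tf.to_int64(lr_schedule)
  lr_factor = 0.1

  # Note: this local value shadows the global lr_init defined earlier.
  lr_init = 0.1
  # Number of schedule boundaries that the current step has passed.
  num_decays = tf.reduce_sum(tf.to_float(step >= lr_schedule))
  lr = lr_init * lr_factor**num_decays

  return tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9)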
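
The resulting piecewise-constant schedule is easy to inspect in plain Python. The helper `learning_rate_at` is a hypothetical stand-in that mirrors the arithmetic of `get_optimizer` without building a TensorFlow graph.

```python
def learning_rate_at(step, lr_init=0.1, lr_factor=0.1,
                     boundaries=(40000, 60000, 80000)):
    """Multiply lr_init by lr_factor once for every boundary already passed."""
    num_decays = sum(step >= b for b in boundaries)
    return lr_init * lr_factor ** num_decays

for step in (0, 39999, 40000, 60000, 80000):
    print("step {:>6}: lr = {:.0e}".format(step, learning_rate_at(step)))
```

With these values the learning rate starts at 0.1 and is divided by 10 at steps 40000, 60000 and 80000.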

In [0]:
# Create a global step that is incremented during training; useful for e.g. learning rate annealing
global_step = tf.train.get_or_create_global_step()

# instantiate the optimizer
optimizer = get_optimizer(global_step)

In [0]:
# Get training and test ops
training_op = optimizer.minimize(train_loss, global_step)
update_ops = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))
training_op = tf.group(training_op, update_ops)

### Teacher and student accuracy

In [0]:
test_acc = top_k_accuracy(1, batch_test_labels, test_predictions_student)
acc_teacher = top_k_accuracy(1, batch_train_labels, predictions_teacher) 

In [0]:
# Create the session and initialize variables
sess = tf.Session()
sess.run(tf.global_variables_initializer())

### Load pre-trained weights for teacher, and check accuracy to make sure the import was successful

In [0]:
saver.restore(sess, "baseline.ckpt")
 
test_batch_size = 100
num_batches = 100  # 100 batches * 100 samples per batch = 10000

avg_accuracy = 0.

###################
#                 # 
# YOUR CODE       #
#                 #
###################

# Check that the import was done correctly by evaluating on the CIFAR-10 train set
# expected_accuracy ~ 0.94
print("Teacher accuracy {:.3f}".format(avg_accuracy))

INFO:tensorflow:Restoring parameters from baseline.ckpt



### Visualize the impact of temperature on the teacher's softmax output

In [0]:
logits_ht, logits_lt, gt = sess.run([logits_high_temp, logits_low_temp, tf.one_hot(batch_train_labels, num_classes)])
# pick one sample and plot
idx = 33
plt.plot(logits_ht[idx], c='r', label='High Temp')
plt.plot(logits_lt[idx], c='g', label='Low Temp')
plt.plot(gt[idx,0], 'b--', label='GT')
plt.xlim([0,9])
plt.legend()
plt.show()

In [0]:
# Helper that plots a list of losses against the training steps.
def plot_losses(loss_list, steps):
  display.clear_output(wait=True)
  display.display(pl.gcf())
  pl.plot(steps, loss_list, c='b')
  time.sleep(1.0)

### Train the model

In [0]:
train_iter = 0
losses = []
steps = []
for train_iter in range(int(TRAIN_ITERS)):
  _, train_loss_np = sess.run([training_op, train_loss])
  
  if (train_iter % REPORT_TRAIN_EVERY) == 0:
    losses.append(train_loss_np)
    steps.append(train_iter)
  if (train_iter % PLOT_EVERY) == 0:
    plot_losses(losses, steps)    
    
  if (train_iter % REPORT_TEST_EVERY) == 0:
    avg_acc = 0.0
    for test_iter in range(TEST_ITERS):
      acc = sess.run(test_acc)
      avg_acc += acc
      
    avg_acc /= TEST_ITERS
    print('Test acc at iter {0:5d} out of {1:5d} is {2:.2f}%'.format(int(train_iter), int(TRAIN_ITERS), avg_acc*100.0))