## EEML2019: ConvNets and Computer Vision Tutorial (PART I)

### Supervised classification, overfitting and inductive biases in convnets, and how to improve models through self-supervision

* Exercise 1: Implement and train a Resnet-50 classifier using supervised learning; enable/disable batch norm updates to see the effect.
* Exercise 2: Inductive biases in convnets; comparison with MLP.
* Exercise 3: Overfitting and regularization using weight decay.
* Exercise 4: Enable self-supervised learning using data augmentation.

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import time

import tensorflow as tf

# Don't forget to select GPU runtime environment in Runtime -> Change runtime type
# Fail fast when no GPU is attached: any device name other than '/device:GPU:0'
# is treated as "no GPU" and aborts the notebook here.
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

import numpy as np

# Plotting library.
from matplotlib import pyplot as plt
import pylab as pl
from IPython import display

import collections
import enum
import warnings
warnings.filterwarnings('ignore')

In [0]:
# Reset graph
# Presumably so that (re-)running the notebook from this cell starts from an
# empty default graph; every op defined below is added to the default graph.
tf.reset_default_graph()

## Download dataset to be used for training and testing
* CIFAR-10: the equivalent of MNIST for natural RGB images

* 60000 32x32 colour images in 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck

* train: 50000; test: 10000

In [0]:
cifar10 = tf.keras.datasets.cifar10
# (down)load dataset; returns the train and test splits as (images, labels) pairs
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

# Check sizes of tensors
print ('Size of training images')
print (train_images.shape)
print ('Size of training labels')
print (train_labels.shape)
print ('Size of test images')
print (test_images.shape)
print ('Size of test labels')
print (test_labels.shape)

# Every image must come with exactly one label, in both splits.
assert train_images.shape[0] == train_labels.shape[0]
assert test_images.shape[0] == test_labels.shape[0]

## Display the images
The gallery function below shows sample images from the data, together with their labels.

In [0]:
MAX_IMAGES = 10
def gallery(images, label, title='Input images'):
  """Shows up to MAX_IMAGES images in one row, each titled with its class name.

  Args:
    images: array whose shape unpacks as (num_frames, height, width, num_channels).
    label: per-image class ids, read as `label[i][0]`.
    title: accepted for interface compatibility; the plot does not use it.
  """
  class_dict = [u'airplane', u'automobile', u'bird', u'cat', u'deer', u'dog', u'frog', u'horse', u'ship', u'truck']
  num_frames, _, _, num_channels = images.shape
  num_frames = min(num_frames, MAX_IMAGES)
  # squeeze=False keeps `axes` a 2-D array even for a single column; without it
  # plt.subplots returns a bare Axes for one image and indexing it fails.
  ff, axes = plt.subplots(1, num_frames,
                          figsize=(num_frames, 1),
                          squeeze=False,
                          subplot_kw={'xticks': [], 'yticks': []})
  for i in range(num_frames):
    ax = axes[0, i]
    if num_channels == 3:
      ax.imshow(np.squeeze(images[i]))
    else:
      # Anything that is not 3-channel is drawn with a gray colormap.
      ax.imshow(np.squeeze(images[i]), cmap='gray')
    ax.set_title(class_dict[label[i][0]])
    plt.setp(ax.get_xticklabels(), visible=False)
    plt.setp(ax.get_yticklabels(), visible=False)
  ff.subplots_adjust(wspace=0.1)
  plt.show()

In [0]:
# Preview the first training images (at most MAX_IMAGES) with their class names.
gallery(train_images, train_labels)

## Prepare the data for training and testing
* for training, we use stochastic optimizers (e.g. SGD, Adam), so we need to sample mini-batches at random from the training dataset
* for testing, we iterate sequentially through the test set

In [0]:
# define dimension of the batches to sample from the datasets
# (BATCH_SIZE_TRAIN is also read by train_image_preprocess, BATCH_SIZE_TEST by
# test_image_preprocess, further down in the notebook)
BATCH_SIZE_TRAIN = 100 #@param
BATCH_SIZE_TEST = 100 #@param

# create Dataset objects using the data previously downloaded
dataset_train = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
# we shuffle the data and sample repeatedly batches for training;
# repeat() has no count, so the iterator below is never exhausted
batched_dataset_train = dataset_train.shuffle(100000).repeat().batch(BATCH_SIZE_TRAIN)
# create iterator to retrieve batches
iterator_train = batched_dataset_train.make_one_shot_iterator()
# get a training batch of images and labels (symbolic tensors; each sess.run that
# depends on them pulls a fresh batch)
(batch_train_images, batch_train_labels) = iterator_train.get_next()

# check that the shape of the training batches is the expected one
# (printing the tensors themselves shows their static shape and dtype)
print ('Shape of training images')
print (batch_train_images)
print ('Shape of training labels')
print (batch_train_labels)

In [0]:
# we do the same for test dataset, except that there is no shuffle(): batches are
# read sequentially, and repeat() restarts from the beginning once the set is used up
dataset_test = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
batched_dataset_test = dataset_test.repeat().batch(BATCH_SIZE_TEST)
iterator_test = batched_dataset_test.make_one_shot_iterator() 
(batch_test_images, batch_test_labels) = iterator_test.get_next()
# printing the tensors shows their static shape and dtype
print ('Shape of test images')
print (batch_test_images)
print ('Shape of test labels')
print (batch_test_labels)

In [0]:
# Flatten labels to 1-D and convert from uint8 to int32 - required below by the loss op.
# reshape(..., [-1]) is used instead of an axis-less tf.squeeze so that the batch
# dimension survives when the batch size is 1; it is also safe to re-run this cell,
# because flattening an already flat tensor is a no-op.
batch_test_labels = tf.cast(tf.reshape(batch_test_labels, [-1]), tf.int32)
batch_train_labels = tf.cast(tf.reshape(batch_train_labels, [-1]), tf.int32)

## General setting; use the options below to switch between exercises.

In [0]:
# `model` and `flag_batch_norm` are compared against strings further down, so they
# stay strings. The remaining flags are used as truth values (`if flag_permute:`,
# `if flag_selfsup:`, `flag_regularize is True`), so they must be real booleans:
# the string "False" is truthy and would silently enable those branches.
# {type:"raw"} makes the Colab dropdown write the bare Python value, not a string.
model = "mlp" #@param['resnet_v2','mlp']
flag_batch_norm = 'OFF' #@param['ON', 'OFF']
flag_permute = False #@param['True', 'False'] {type:"raw"}
flag_regularize = False #@param['True', 'False'] {type:"raw"}
flag_selfsup = False #@param['True', 'False'] {type:"raw"}

## Preprocess input for training and testing

In [0]:
#@title 32x32 permutation list (for Exercise 2)
def get_permutation_cifar10():
  """Returns the fixed pixel-index list used to scramble images, as a tf.constant.

  The cell title describes it as a permutation of the 32x32 pixel grid, i.e. it is
  meant to hold every flat index 0..1023 exactly once (not re-verified here). The
  values are hard-coded, so every call yields the same ordering.
  """
  p = tf.constant([ 273,  746,  984,  197,  597,  519,  757,  113, 1009,  470,  321,
        552,  585,  246,  229,  569,  773,    6,  955,  379,  847,  548,
        148,  503,   27,  132, 1014,   82,  101,  260,  923,   53,  842,
        635,  203,  912,  439,  487,  162,  832,  395,  311,  593,   33,
        988,  856,  183,  215,  264,  699,  826,  692,  560,  124,  126,
        948,  708,   41,  368,  484,  467,  267,  731,   17,   73,   14,
        521,  240,  296,  846,   43,  779,  629,  640,  874,  268,  150,
        586,   56,  756,  769,  808,  382,  354,   95,  283,  900,  970,
        775,  432,  178,  998,  271,  118,  563,  445,  946,  261,  518,
        723,  725,  449,  595,  617,  935,  607,  400,  697,   96,  786,
        656,  138,  343,  653,  175,  993,  433,  681,  654,  574,  322,
        918,  831,  754,  381,   12,  797,  338,  182,  695,  829,  927,
        556, 1008,  491,  512,  717,  224,  303,  496, 1002,  693,  140,
        599,  332,  758,  465,   80,  501,  690, 1022,  422,  211,  331,
        926,  849,  313,  583,  128,  776,  168,  463,  201, 1018,  334,
        228,  643,  514,  934,  106,  799,  713,  507,  543,  703,  299,
         74,  263,  710,  622,  486,  344,  210,  687,  537,  362,  858,
        898,  688,  320,   91,  403,  778,  534,  674,  783,  284,  968,
        807,  724,  549,  184,  605,   15,    8,  535,   85,  495,  774,
        701,  848,  416,  642,  553,  897,  989,  370,  942,  571,  489,
        891,  388,  171,  761,  475,  844,  397,  227,  753,  278,  855,
        938,  794,  155,  748, 1017,  941,  745,  437,  414,  181,  759,
        234,  143,  554,  762,   46,  476,  417,  911, 1007,  882,  716,
        336,  117,   47,  977,  602,  837,  525,  880,  718,  660,  760,
        451,  142,  609,  405,  455,  315,  394,  987,   36,  389,  719,
        715,  386,  393,  446,  109,  658,  612,  685,  577, 1015,  967,
        641,  770,  510,  704,  793,  892,  275,  904,  335,  893,  259,
        307,  903,  985,  435,  712,  326,  232, 1021,  237,  827, 1005,
        172, 1000,  675,   84,  670,  963,  434,  485,   68,  677,  415,
        492,  947,  859,  732,  810,  366,  557,   22,  824,  765,  722,
        902,  950,  579,  288,  308,   48,  198, 1003,  481,  604,  139,
       1001,  647,  115,  618,  243,  466,  107,  795,  440,  152,  885,
        200,  230,   83,  821,  755,   10,  144,  749,  528,  494,  199,
        546,  282,  921,  223,  828,  962,  346,  925,  352,  421, 1023,
        763,  424,  894,  328,  290,   13,   62,  129,  156,  820,  436,
        871,  252,  359,  538,   35,  459,  226,  657,  401,  191,  483,
        187,  242,  680,  646,  473,  802,    4,  581,  130,  666,  709,
        889,    7,  864,  236,  991,  450,  532,  667,   70, 1011,  410,
        907,  266,  914,  189,  943,  796,  649,  990,  257,  937,  700,
        500,  188,  813,  809,  634,  789,   25,  517,  573,  104,  387,
        673,  966,  638,  845,  540, 1006,  910,  249,  610,  110,  480,
        663,    9,  225,  339,  398,  976,  131,  372,  628,  875,  174,
        488,  908,   79,  766,  310,  468,  691,  425,  289,  616,  309,
        915,  570,  636,  768,  591,  956,  464,  412,  120,  958,  939,
        782,  652,  541,  971,  100,  280,  721,  423,  430,  442,  506,
        160,  502,  333,  615,  399,   57,  250,  384,  959, 1012,   71,
        103,  429,  411,   59,  862,  887,  980,  529,  630,  444,  785,
        625,  916,  883,  901,  983,  852,  179,  747,  801,  218,  627,
        408,  443,  830,  305,  733,  509,  274,  682,  884,  584,  358,
        536,  739,  369,  933,  221,  247,  676,  982,  206,    1,  438,
        265,  954,  866,  672,  287,   26,   39,  606,  479,  102,  291,
         88,  205,   61,  931,  127,  351, 1004,  477,  655,  865,  355,
         67,   37,  735,  458,  454,  737,  873,  909,  173,  231,  158,
        555,  825,  945,  930,  337,  644,  505,  233,  730,  431, 1020,
        530,  580,  312,  720,  441,  550,  952,  367,  513,   50,  371,
         34,   45,  705,  153,  122,  209,   51,  870,  216,  185,  611,
        327,  815,  899,  603,  428, 1010,   42,  669, 1019,  601,  788,
        620,  771,  886,  116,  293,  986,  363,  834,  881,   81,   90,
        474,   94,  302,   31,  863,  317,  619,  471,   86,  869,   64,
        994,  683,   20,  330, 1013,  472,  650,  714,  380,  812,  853,
        196,  272,  736,  349,   75,  169,   28,  340,  163,  151,  979,
        798,  582,  559,  376,   24,  539,  818,  176,  207,  992,  600,
        975,  767,  867,  592,  978,  726,  277,  511,   98,  850,  498,
        668,  298,  292,  792,  523,  598,  742,  623,  426,  841,  361,
        121,  157,  964,  146,  490,  791,  780,  360,  679,   38,  222,
        419,  192,  587,   30,   77,  702,  235,  953,  997,  318,  751,
          2,  396,  542,  661,  499,   29,   69,  180,  621,  217,  588,
        972,   58,   60,  164,  840,  772,  545,  452,  170,  951,  752,
        281,  478,  711,  648,  575,  787,  213,  345,   19,  803,  190,
        527,  508,  149,  323,  624,  404,  817,  895,  420,  256,  413,
        626,  134,  390,  614,  342,  565,  238,  949,  241,  781,  590,
        533,  659,  365,  561,  112,  248,  357,  566,  407,  253,  913,
        461,  957,  932,  594,  255,  406,  784,  750,    3,  356,  141,
         97,   92,  919,  522,  734,  325,   54,  877,  738,  456,  133,
        917,  374,   66,  729,  835,  114,  833,  214,  504,  383,  631,
        347,  686,  905,  578,  613,  239,  806,  645,  790,  764,  427,
        651,  568,   87,  119,   63,   65,  202,  890,  940,  928,  286,
        409,  662,  551,   49,  251,  572,  632,    5,  524,  515,  888,
        608,  208,  329,   18,  516,  350,  295,  448,  385,  678,  936,
        896,  258,  204,  276,  177,  854,   72,  341,   16,  974,  836,
        851,  497,  316,  805,  262,  544,  981,  838,  843,  526,  707,
        348,  254,  447,  520,  453,  270,  304,  558,  462,  418,  279,
         99,  353,  314,  306,  564,  219,  167,  186,  297,  706,    0,
        804,   89,  878,   11,  816,  402,  868,  531,   78,  728,  373,
        562,  684,  944,  860,  876,  194,  195,  906,  973,  294,  960,
        567,  698,  378,  589,   40,  220,  493,  460,  929,  861,  823,
         76,  105, 1016,  839,  639,  324,  166,  740,   23,   52,  161,
        319,  996,  392,  135,  111,  391,  547,  145,  961,  999,  123,
        744,  364,  147,  469,  811,  125,  159,  664,  965,   93,  727,
        245,  814,  696,  377,   21,  665,  694,  920,  857,   55,  879,
        269,  285,  671,  165,  924,  193,  244,  969,  800,  457,  922,
        741,  375,  995,  482,  576,  108,  743,  689,  300,   44,   32,
        136,  872,  596,  637,  137,  819,  154,  633,  777,  301,  212,
        822])

  return p

In [0]:
# Data augmentation
# - scale image to [-1 , 1]
# - during training: apply horizontal flip randomly
# - random crop after padding
# - apply optional data augmentation (permutation, rotation)

def train_image_preprocess(h, w, num_transf=None):
  """Builds the training-time preprocessing function.

  Args:
    h: height of the random crop taken from the padded image.
    w: width of the random crop taken from the padded image.
    num_transf: number of self-supervision transformations; the self-supervised
      branch is skipped when this is None or 0.

  Returns:
    A function mapping an image batch to `(image, label_transf)`. `label_transf`
    is an empty list unless the self-supervised branch runs, in which case it is
    the stacked per-image transformation ids.
  """
  def fn(image):
    # Convert to float32 and scale to [-1, 1].
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = image * 2 - 1
    image = tf.image.random_flip_left_right(image)
    # Data augmentation: pad images and randomly sample a (h, w) patch.
    image = tf.pad(image, [[0, 0], [4, 4], [4, 4], [0, 0]], mode='REFLECT')
    image = tf.random_crop(image, size=(BATCH_SIZE_TRAIN, h, w, 3))

    # Exercise 2: permuted Cifar10; scramble the images using a fixed permutation
    if flag_permute:
      ################
      # YOUR CODE HERE  reshape image 2D -> 1D array using tf.reshape
      p = get_permutation_cifar10()
      ################
      # YOUR CODE HERE  permute pixels according to permutation p using tf.gather(..., axis=1). We have batches 
      # YOUR CODE HERE  reshape to original shape

    # Exercise 4: data augmentation as self-supervision signal; for every image 
    # in the batch, sample uniformly at random a transformation (rotation), 
    # and apply it to the image while returning the id of the transformation
    label_transf = []
    if flag_selfsup and num_transf:
      list_img = []
      # `range` rather than `xrange`: xrange exists only on Python 2, while range
      # works on both and is what the rest of the notebook uses.
      for i in range(BATCH_SIZE_TRAIN):
        ################
        # YOUR CODE HERE get a transformation label_ = tf.random.uniform...
        # YOUR CODE HERE apply transformation img = tf.image.rot90(image[i], k=label_)
        label_transf.append(label_)
        list_img.append(img)
      image = tf.stack(list_img, axis=0)
      label_transf = tf.stack(label_transf, axis=0)
    return image, label_transf
  return fn

def test_image_preprocess():
  """Builds the test-time preprocessing function.

  Returns:
    A function mapping an image batch to the preprocessed batch. Unlike the
    training version it applies no random flip or crop.
  """
  def fn(image):
    # Convert to float32 and scale to [-1, 1].
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = image * 2 - 1
    if flag_permute:
      ################
      # YOUR CODE HERE do the same as for training
      p = get_permutation_cifar10()
    else:
      # Re-declare the shape with an explicit batch size of BATCH_SIZE_TEST; the
      # remaining dimensions are taken from the tensor's static shape.
      sh = image.get_shape()
      image = tf.reshape(image, [BATCH_SIZE_TEST, sh[1], sh[2], sh[3]])
    return image
  return fn

## Define the model

In [0]:
# define parameters of resnet blocks for two resnet models
# Hyper-parameters of one residual unit: channels produced by the unit, channels
# of its internal bottleneck, and the stride it applies.
ResNetBlockParams = collections.namedtuple(
    "ResNetBlockParams", ("output_channels", "bottleneck_channels", "stride"))

# One inner tuple per stage. Each spec row reads
# (output_channels, bottleneck_channels, number of stride-1 units, append a stride-2 unit?).
BLOCKS_50 = tuple(
    (ResNetBlockParams(out_ch, mid_ch, 1),) * num_plain
    + ((ResNetBlockParams(out_ch, mid_ch, 2),) if downsample else ())
    for out_ch, mid_ch, num_plain, downsample in (
        (256, 64, 2, True),
        (512, 128, 3, True),
        (1024, 256, 5, True),
        (2048, 512, 3, False),
    ))

In [0]:
#@title Utils

def _fixed_padding(inputs, kernel_size):
  """Zero-pads axes 1 and 2 of `inputs` by a total of `kernel_size - 1` each.

  When the total is odd, the extra element goes at the end. Axes 0 and 3 are
  left untouched.
  """
  total = kernel_size - 1
  before = total // 2
  after = total - before
  spatial = [before, after]
  return tf.pad(inputs, [[0, 0], spatial, spatial, [0, 0]])

def _max_pool2d_same(inputs, kernel_size, stride, padding):
  """Strided 2-D max-pooling with fixed padding.

  For padding='SAME' with stride > 1 the input is first zero-padded explicitly
  and the pooling itself then runs with 'VALID' padding. In every other case
  `padding` is handed to the pooling layer unchanged.
  """
  needs_explicit_pad = (padding == "SAME") and (stride > 1)
  if needs_explicit_pad:
    inputs = _fixed_padding(inputs, kernel_size)
    padding = "VALID"
  pool = tf.layers.MaxPooling2D(kernel_size, strides=stride, padding=padding)
  return pool(inputs)

def _conv2d_same(inputs, num_outputs, kernel_size, stride, use_bias=False,
                 name="conv_2d_same"):
  """Strided 2-D convolution that behaves like 'SAME' padding.

  With stride 1 the layer uses 'SAME' padding directly. For any other stride the
  input is zero-padded explicitly first and the convolution runs with 'VALID'.
  """
  padding = "SAME"
  if stride != 1:
    padding = "VALID"
    inputs = _fixed_padding(inputs, kernel_size)

  conv = tf.layers.Conv2D(num_outputs, kernel_size, strides=stride,
                          padding=padding, use_bias=use_bias, name=name)
  return conv(inputs)

### [Resnet Block V2](https://arxiv.org/pdf/1603.05027.pdf) 

![alt text](https://github.com/eeml2019/PracticalSessions/blob/master/assets/bottleneck.png?raw=true)


In [0]:
# Exercise 1: define resnet block v2
# NOTE: this is an exercise skeleton. The `else:` branch and the three residual
# subunits below contain only placeholder comments, so the cell does not run
# until they are filled in.
def resnet_block(inputs, output_channels, bottleneck_channels, stride,
                 training=None, name="resnet_block"):
  """Create a resnet block.

  Args:
    inputs: input tensor; the number of input channels is read from its last axis.
    output_channels: number of channels the block must output.
    bottleneck_channels: channel count intended for the bottleneck convolutions
      (to be used by the code filled in below).
    stride: stride of the block; also applied to the shortcut when it is not 1.
    training: forwarded to the batch-norm layers as their `training` argument.
    name: prefix for the names of the layers created here.

  Returns:
    The sum of the shortcut branch and the residual branch.
  """
  num_input_channels = inputs.get_shape()[-1]
  batch_norm_args = {
      "training": training
      }
  # ResNet V2 uses pre-activation, where the batch norm and relu are before
  # convolutions, rather than after as in ResNet V1.
  preact = tf.layers.BatchNormalization(name=name+"/bn_preact")(inputs,
                                                                **batch_norm_args)
  preact = tf.nn.relu(preact)

  if output_channels == num_input_channels:
    # Use subsampling to match output size.
    # Note we always use `inputs` in this case, not `preact`.
    if stride == 1:
      shortcut = inputs
    else:
      # Kernel size 1 with the given stride: plain spatial subsampling.
      shortcut = _max_pool2d_same(inputs, 1, stride=stride, padding="SAME")
  else:
    # Use 1x1 convolution shortcut to increase channels to `output_channels`.
    ################
    # YOUR CODE HERE shortcut = tf.layers.Conv2D

  # add the 3 residual subunits: conv + batchnorm + relu
  # subunit 1
  ################
  # YOUR CODE HERE residual = tf.layers.Conv2D...
  # YOUR CODE HERE residual = tf.layers.BatchNormalization...
  # YOUR CODE HERE residual = tf.nn.relu...
  
  # subunit 2
  ################
  # YOUR CODE HERE residual = _conv2d_same
  # YOUR CODE HERE residual = tf.layers.BatchNormalization...
  # YOUR CODE HERE residual = tf.nn.relu...

  # subunit 3
  ################
  # YOUR CODE HERE residual = tf.layers.Conv2D...
  # YOUR CODE HERE residual = tf.layers.BatchNormalization...
  # YOUR CODE HERE residual = tf.nn.relu...
  
  # add residual to shortcut
  output = shortcut + residual

  return output

In [0]:
# stack resnet blocks
def _build_resnet_blocks(inputs, blocks, batch_norm_args):
  """Chains every resnet block onto `inputs` and collects the intermediate outputs.

  Args:
    inputs: tensor fed to the first block.
    blocks: iterable of stages, each an iterable of ResNetBlockParams.
    batch_norm_args: extra keyword arguments forwarded to every `resnet_block`.

  Returns:
    List with the output of every block, in the order the blocks were applied.
  """
  outputs = []
  net = inputs

  for stage_idx, stage in enumerate(blocks):
    # One variable scope per stage; block names restart at 0 inside each stage.
    with tf.variable_scope("block_{}".format(stage_idx)):
      for unit_idx, params in enumerate(stage):
        kwargs = dict(name="resnet_block_{}".format(unit_idx))
        kwargs.update(params._asdict())
        kwargs.update(batch_norm_args)
        net = resnet_block(net, **kwargs)
        outputs.append(net)

  return outputs

In [0]:
# define full architecture: input convs, resnet blocks, output classifier
def resnet_v2(inputs, blocks, is_training=True,
              num_classes=10, num_transf=None, use_global_pool=True,
              name="resnet_v2"):
  """ResNet V2.

  Args:
    inputs: batch of input images.
    blocks: iterable of stages, each an iterable of ResNetBlockParams.
    is_training: forwarded to every batch-norm layer as `training`.
    num_classes: number of output channels of the classification head.
    num_transf: number of transformations for the self-supervised head; the head
      is only considered when this is truthy and `flag_selfsup` is set.
    use_global_pool: whether to average over axes 1 and 2 before the classifier.
    name: variable scope; it is opened with AUTO_REUSE, so calling the function
      again with the same name shares the variables.

  Returns:
    Tuple `(logits, logits_transf)`; `logits_transf` is None unless the
    Exercise 4 head is implemented and enabled.
  """
  blocks = tuple(blocks)

  batch_norm_args = {
      "training": is_training
  }

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    # Add initial non-resnet conv layer and max_pool
    inputs = _conv2d_same(inputs, 64, 7, stride=2, name="root")
    inputs = _max_pool2d_same(inputs, 3, stride=2, padding="SAME")

    # Stack resnet blocks
    resnet_outputs = _build_resnet_blocks(inputs, blocks, batch_norm_args)
    # Take the activations of the last resnet block.
    inputs = resnet_outputs[-1]
    inputs = tf.layers.BatchNormalization(name="bn_postnorm")(inputs,
                                                              **batch_norm_args)
    inputs = tf.nn.relu(inputs)
    if use_global_pool:
      inputs = tf.reduce_mean(inputs, [1, 2], name="use_global_pool",
                              keepdims=True)

    # Add output classifier
    logits = tf.layers.Conv2D(num_classes, 1, name="logits")(inputs)
    # Squeeze the classifier output, not its input: squeezing `inputs` here would
    # throw the logits layer away and return the pooled features instead.
    logits = tf.squeeze(logits, axis=[1, 2])

    # Add second head for transformation prediction
    logits_transf = None
    if num_transf and flag_selfsup:
      pass
      ################
      # YOUR CODE HERE Exercise 4
      # YOUR CODE HERE logits_transf = tf.layers.Conv2D...
      # YOUR CODE HERE logits_transf = tf.squeeze...

  return (logits, logits_transf)

## Define simple MLP baseline

In [0]:
def mlp(inputs, num_classes=10, num_transf=None, is_training=True, name="mlp"):
  """Two-hidden-layer MLP baseline: (dense 1024 -> relu -> batch norm) x 2 -> logits.

  Args:
    inputs: batch of images; everything after the batch axis is flattened.
    num_classes: size of the classification output.
    num_transf: if truthy, a second dense head of this size is added on top of
      the last hidden layer.
    is_training: forwarded to the batch-norm layers as `training`.
    name: variable scope; it is opened with AUTO_REUSE so repeated calls share
      variables.

  Returns:
    Tuple `(logits, logits_transf)`; `logits_transf` is None when `num_transf`
    is falsy.
  """
  batch_norm_args = {
      "training": is_training
  }
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    bs = inputs.get_shape().as_list()[0]
    inputs = tf.reshape(inputs, [bs, -1])
    net = tf.layers.dense(inputs, 1024)
    net = tf.nn.relu(net)
    # Normalise the hidden activations `net`; feeding `inputs` here would bypass
    # the first dense + relu layer entirely.
    net = tf.layers.BatchNormalization(name="bn_postnorm1")(net, **batch_norm_args)
    net = tf.layers.dense(net, 1024)
    net = tf.nn.relu(net)
    net = tf.layers.BatchNormalization(name="bn_postnorm2")(net, **batch_norm_args)
    logits = tf.layers.dense(net, num_classes, name="logits")
    logits_transf = None
    if num_transf:
      logits_transf = tf.layers.dense(net, num_transf, name="logits_transf")
    return logits, logits_transf

## Set up training pipeline

In [0]:
# First define the preprocessing ops for the train/test data
# crop_height / crop_width are the size of the random training crop.
crop_height = 32 #@param 
crop_width = 32 #@param
# NUM_TRANSF can be None or 4 corresponding to 4 rotations (0, 90, 180, 270)
NUM_TRANSF = 4 #@param  
# These calls only build the preprocessing functions; they are applied to the
# batches when the model inputs are created.
preprocess_fn_train = train_image_preprocess(crop_height, crop_width, NUM_TRANSF)
preprocess_fn_test = test_image_preprocess()
# Number of output classes handed to the models as `num_classes`.
NUM_CLASSES = 10 #@param

### Get predictions from either MLP baseline or convnet

In [0]:
blocks = BLOCKS_50
# Apply the preprocessing to the symbolic batches. The training function also
# returns the self-supervision labels (an empty list when that branch is off).
inp_train, labels_selfsup = preprocess_fn_train(batch_train_images)
inp_test = preprocess_fn_test(batch_test_images)

# Build the chosen model twice, once on the training inputs (is_training=True)
# and once on the test inputs (is_training=False). Both models open their
# variable scope with AUTO_REUSE, so the two copies share weights.
if model == 'mlp':
  train_predictions, logits_selfsup = mlp(inp_train, num_classes=NUM_CLASSES,
                                          num_transf=NUM_TRANSF, is_training=True)
  test_predictions, _ = mlp(inp_test, num_classes=NUM_CLASSES,
                            num_transf=NUM_TRANSF, is_training=False)
else:  # model is resnet_v2
  train_predictions, logits_selfsup = resnet_v2(inp_train, blocks,
                                                num_classes=NUM_CLASSES,
                                                num_transf=NUM_TRANSF,
                                                is_training=True)
  test_predictions, _ = resnet_v2(inp_test, blocks, 
                                  num_classes=NUM_CLASSES,
                                  num_transf=NUM_TRANSF, is_training=False)
# Print the symbolic outputs to inspect their static shapes.
print(train_predictions)
print(logits_selfsup)
print(test_predictions)

In [0]:
# Get number of parameters in a scope by iterating through the trainable variables
def get_num_params(scope):
  """Returns the total number of scalar entries in the trainable variables of `scope`."""

  def _num_entries(variable):
    # The static shape iterates as tf.Dimension objects; multiply their sizes.
    count = 1
    for dim in variable.get_shape():
      count *= dim.value
    return count

  return sum(_num_entries(variable) for variable in tf.trainable_variables(scope))

In [0]:
# Get number of parameters in the model.
# Both scopes are queried, but only the model selected by `model` has been built,
# so the count for the other scope is 0.
print ("Total number of parameters of models")
print (get_num_params("resnet_v2"))
print (get_num_params("mlp"))

In [0]:
# classification loss using cross entropy
def classification_loss(logits=None, labels=None):
  """Mean sparse softmax cross-entropy between `logits` and integer `labels`."""
  per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  # Average over the batch so that the loss is a scalar.
  return tf.reduce_mean(per_example)

In [0]:
# l2 regularization on the weights
def regularization_loss(l2_regularization=1e-4):
  """Provides regularization loss if it is enabled.

  Args:
    l2_regularization: scale of the L2 penalty; values <= 0 disable it.

  Returns:
    Scalar tensor: the sum of the L2 penalties over all trainable variables, or a
    constant 0 when there are no trainable variables or the penalty is disabled.
  """
  variables = tf.trainable_variables()
  if variables and (l2_regularization > 0):
    l2_reg = tf.contrib.layers.l2_regularizer(l2_regularization)
    # Build a real list: on Python 3 `map` returns an iterator, which tf.add_n
    # does not accept.
    reg_losses = [l2_reg(variable) for variable in variables]
    return tf.add_n(reg_losses, name='regularization_loss')
  else:
    return tf.constant(0.)

In [0]:
# Define train and test loss functions
# Both are scalar cross-entropy losses; only train_loss is minimised later on.
train_loss = classification_loss(labels=batch_train_labels, logits=train_predictions)
test_loss = classification_loss(labels=batch_test_labels, logits=test_predictions)

In [0]:
# Exercise 3 - Add regularization
# NOTE: this is an identity check against the bool True, so the branch only runs
# when flag_regularize is exactly True (a string such as 'True' does not match).
if flag_regularize is True:
  pass
  ################
  # YOUR CODE HERE train_loss += ...

In [0]:
# Exercise 4: Add auxiliary loss for self-supervised learning; you can use the same classification_loss fn defined above
# NOTE: plain truthiness test, so any non-empty value of flag_selfsup enables
# this branch.
if flag_selfsup:
  pass
  ################
  # YOUR CODE HERE train_loss += ...

In [0]:
# For evaluation, we look at top_k_accuracy since it's easier to interpret; normally k=1 or k=5
def top_k_accuracy(k, labels, logits):
  """Fraction of examples whose label is among the `k` highest-scoring logits."""
  hits = tf.nn.in_top_k(predictions=tf.squeeze(logits), targets=labels, k=k)
  hits = tf.cast(hits, tf.float32)
  return tf.reduce_mean(hits)

In [0]:
def get_optimizer(step):
  """Returns a momentum optimizer whose learning rate decays step-wise with `step`.

  The rate starts at 0.01 and is multiplied by 0.1 once for every boundary in
  (90e3, 100e3, 110e3) that `step` has reached.
  """
  base_lr = 0.01  # learning rate before any decay
  boundaries = (90e3, 100e3, 110e3)  # iterations at which the rate is reduced
  boundaries = tf.cast(boundaries, tf.int64)
  decay = 0.1  # multiplicative factor applied at each boundary
  # Count how many boundaries have been reached so far.
  num_decays = tf.reduce_sum(tf.cast(step >= boundaries, tf.float32))
  learning_rate = base_lr * decay**num_decays

  return tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)

In [0]:
# Create a global step that is incremented during training; useful for e.g. learning rate annealing
global_step = tf.train.get_or_create_global_step()

# instantiate the optimizer; its learning rate is computed from global_step
optimizer = get_optimizer(global_step)

In [0]:
# Get training ops
# global_step is passed to minimize() as the step variable to advance.
training_op = optimizer.minimize(train_loss, global_step)

# With the flag set to anything other than 'ON', training_op does not include the
# ops from the UPDATE_OPS collection, so they are never run during training.
if flag_batch_norm == 'ON':
  # Retrieve the update ops, which contain the moving average ops
  update_ops = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))

  # Manually add the update ops to the dependency path executed at each training iteration
  training_op = tf.group(training_op, update_ops)

In [0]:
# Get test ops
# Top-1 accuracy on the current test batch and on the current training batch.
test_acc_op = top_k_accuracy(1, batch_test_labels, test_predictions)
train_acc_op = top_k_accuracy(1, batch_train_labels, train_predictions)

In [0]:
# Function that takes a list of losses and plots them.
def plot_losses(loss_list, steps):
  """Redraws the loss curve in the notebook output.

  Args:
    loss_list: loss values to plot on the y axis.
    steps: iteration numbers to plot on the x axis, one per loss value.
  """
  # Clear the previous output; wait=True defers the clear until new output arrives.
  display.clear_output(wait=True)
  # NOTE: the current figure is displayed before the new data is plotted onto it.
  display.display(pl.gcf())
  pl.plot(steps, loss_list, c='b')
  # Pause for one second after every redraw.
  time.sleep(1.0)

### Define training parameters

In [0]:
# Define number of training iterations and reporting intervals
# TRAIN_ITERS is a float literal; the training loop converts it with int().
TRAIN_ITERS = 100e3 #@param
# Record the training loss every REPORT_TRAIN_EVERY iterations.
REPORT_TRAIN_EVERY = 100 #@param
# Redraw the loss plot every PLOT_EVERY iterations.
PLOT_EVERY = 500 #@param
# Evaluate accuracy every REPORT_TEST_EVERY iterations...
REPORT_TEST_EVERY = 1000 #@param
# ...averaging over TEST_ITERS batches each time.
TEST_ITERS = 100 #@param

### Training the model

In [0]:
# Create the session and initialize variables
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Question: What is the accuracy of the model at iteration 0, i.e. before training starts? 
train_iter = 0
losses = []  # training loss recorded every REPORT_TRAIN_EVERY iterations
steps = []   # iteration index matching each entry of `losses`
for train_iter in range(int(TRAIN_ITERS)):
  # One optimisation step; the current loss, input batch and labels are fetched too.
  _, train_loss_np, inp_img, tr_lbl = sess.run([training_op, train_loss, inp_train, batch_train_labels])
  
  if (train_iter % REPORT_TRAIN_EVERY) == 0:
    losses.append(train_loss_np)
    steps.append(train_iter)
  if (train_iter % PLOT_EVERY) == 0:
    plot_losses(losses, steps)    
    
  if (train_iter % REPORT_TEST_EVERY) == 0:
    # Average the accuracy over TEST_ITERS batches. Each sess.run pulls a new
    # test batch and a new training batch from the iterators.
    avg_acc = 0.0
    train_avg_acc = 0.0
    for test_iter in range(TEST_ITERS):
      acc, acc_train = sess.run([test_acc_op, train_acc_op])
      avg_acc += acc
      train_avg_acc += acc_train
      
    avg_acc /= (TEST_ITERS)
    train_avg_acc /= (TEST_ITERS)
    print ('Test acc at iter {0:5d} out of {1:5d} is {2:.2f}%'.format(int(train_iter), int(TRAIN_ITERS), avg_acc*100.0))
    # print ('Train acc at iter {0:5d} out of {1:5d} is {2:.2f}%'.format(int(train_iter), int(TRAIN_ITERS), train_avg_acc*100.0))