Importing required libraries

*   numpy
*   time
*   math
*   tqdm_notebook (tqdm is a Progress Bar)
*   tensorflow

In [1]:
import numpy as np
import time, math
from tqdm import tqdm_notebook as tqdm

import tensorflow as tf
import tensorflow.contrib.eager as tfe

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



Using Eager Mode of TesorFlow. In Eager mode, functions execute immediately and return the values

In [0]:
tf.enable_eager_execution()

In [0]:
BATCH_SIZE = 512 #@param {type:"integer"}
MOMENTUM = 0.9 #@param {type:"number"}
LEARNING_RATE = 0.4 #@param {type:"number"}
WEIGHT_DECAY = 5e-4 #@param {type:"number"}
EPOCHS = 24 #@param {type:"integer"}

https://mc.ai/tutorial-1-cifar10-with-google-colabs-free-gpu%E2%80%8A-%E2%80%8A92-5/

Function to return a tensor of 'shape' filled with random values from a uniform distribution between minval and maxval (i.e between -bound and +bound)

fan : the product of the input elements  
bound : inverse square root of fan

In [0]:
def init_pytorch(shape, dtype=tf.float32, partition_info=None):
  fan = np.prod(shape[:-1])
  bound = 1 / math.sqrt(fan)
  return tf.random.uniform(shape, minval=-bound, maxval=bound, dtype=dtype)

Convolution Block :   
A Convolution Layer followed by BatchNormalization Layer and ReLu 

*   The Conv2D layer uses previously defined function for initialising the kernel with random values from a uniform distribution (Kernel size = 3x3)
*   momentum 0.9 and epsilon 1e-5 is used in BatchNormalization to match the PyTorch implementation of DavidNet
*   Dropout of 0.05 is used


In [0]:
class ConvBN(tf.keras.Model):
  def __init__(self, c_out):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)
    self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)
    self.drop = tf.keras.layers.Dropout(0.05)

  def call(self, inputs):
    return tf.nn.relu(self.bn(self.drop(self.conv(inputs))))

Residual Block :  
Conv-BN-ReLu Block followed by 2x2 MaxPooling and 2 optional Conv-BN-ReLu blocks with a residual connection

In [0]:
class ResBlk(tf.keras.Model):
  def __init__(self, c_out, pool, res = False):
    super().__init__()
    self.conv_bn = ConvBN(c_out)
    self.pool = pool
    self.res = res
    if self.res:
      self.res1 = ConvBN(c_out)
      self.res2 = ConvBN(c_out)

  def call(self, inputs):
    h = self.pool(self.conv_bn(inputs))
    if self.res:
      h = h + self.res2(self.res1(h))
    return h

DavidNet :  

*   Conv2D-BN-ReLu 
*   Residual Block 1 with a residual connection
*   Residual Block 2 without a residual connection 
*   Residual Block 3 with a residual connection
*   Global Max Pooling 
*   Fully Connected Layer
*   Multiply output of fully connected layer by 0.125






In [0]:
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = ConvBN(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight

  def call(self, x, y):
    h = self.pool(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

Loading CIFAR-10 dataset  



*   normalize : subtract the image by mean and divide by standard deviation 
*   pad4 : Padding 4 pixels to make the input image 40x40 


 


In [0]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
len_train, len_test = len(x_train), len(x_test)
y_train = y_train.astype('int64').reshape(len_train)
y_test = y_test.astype('int64').reshape(len_test)

train_mean = np.mean(x_train, axis=(0,1,2))
train_std = np.std(x_train, axis=(0,1,2))

normalize = lambda x: ((x - train_mean) / train_std).astype('float32') # todo: check here
pad4 = lambda x: np.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)], mode='reflect')

x_train = normalize(pad4(x_train))
x_test = normalize(x_test)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


Instantiating the DavidNet model and defining functions for scheduling the learning rate, choosing the optimiser and data augmentation technique

*   lr_schedule : evaluates the learning rate as the linear interpolant values at t (taking EPOCHS along x-axis and LEARNING_RATE along y-axis)  
increasing data point values along x-axis are in the range [0,(epochs+1)/epochs, epochs]. The corresponding LR values for this is a one cycle from 0 to LEARNING_RATE and goes back to 0. Forming an isosceles triangle.
*   global_step : the number of batches seen by the graph. Every time a batch is provided, the weights are updated in the direction that minimizes the loss
*   lr_func : Dividing the learning rate by batch size which is 512. (This is because loss is summed instead of average in DavidNet which has scaled up the loss and gradient by 512)





In [0]:
model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1

lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)

Training and testing the model

*   tf.data.Dataset.from_tensor_slices is a method which returns a Dataset. It takes tensors as input. 
*   Dataset.map() is used to apply per element transformation  
*   Dataset.batch() is used to apply multi-element transformations 
*   tf.keras.backend.set_learning_phase(1) denotes training phase
*   tf.keras.backend.set_learning_phase(0) denotes testing phase
*   the tf.GradientTape API is used for automatic differentiation - computing the gradient of a computation with respect to its input variables. 
*   WEIGHT_DECAY is scaled down by lr_func. Thus we upscale it by multiplying with batch size  
*   The gradients computed using GradientTape are applied  to the optimizer using opt.apply_gradients()

In [0]:

t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.6219283795166015 train acc: 0.4088 val loss: 1.3596600708007813 val acc: 0.5421 time: 37.37372827529907


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 0.8898376220703125 train acc: 0.68372 val loss: 0.8799490173339843 val acc: 0.6921 time: 66.53042817115784


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.6744012210083008 train acc: 0.76492 val loss: 0.6637866775512695 val acc: 0.7724 time: 95.55683445930481


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.5680239816284179 train acc: 0.80318 val loss: 0.9503156463623047 val acc: 0.7089 time: 124.48915910720825


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.510378998413086 train acc: 0.82436 val loss: 0.8559248016357421 val acc: 0.7625 time: 153.36513471603394


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37894736842105264 train loss: 0.41239007385253906 train acc: 0.85692 val loss: 0.5525562118530274 val acc: 0.8158 time: 182.28042149543762


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.35789473684210527 train loss: 0.34824659439086914 train acc: 0.88084 val loss: 0.4841607467651367 val acc: 0.8403 time: 211.0377643108368


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.33684210526315794 train loss: 0.29988678756713866 train acc: 0.89642 val loss: 0.38255838317871094 val acc: 0.8702 time: 239.8541920185089


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.31578947368421056 train loss: 0.2589371276855469 train acc: 0.90964 val loss: 0.4507476753234863 val acc: 0.8587 time: 268.5690622329712


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.2947368421052632 train loss: 0.22727200714111329 train acc: 0.92154 val loss: 0.33030320434570315 val acc: 0.8931 time: 297.5945861339569


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.2736842105263158 train loss: 0.19945408248901367 train acc: 0.92966 val loss: 0.40582934799194337 val acc: 0.8701 time: 326.2247312068939


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.25263157894736843 train loss: 0.17665268104553222 train acc: 0.9378 val loss: 0.29891041107177735 val acc: 0.9014 time: 354.8569004535675


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.23157894736842108 train loss: 0.15444558387756346 train acc: 0.94592 val loss: 0.4297833343505859 val acc: 0.8671 time: 383.4799575805664


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.2105263157894737 train loss: 0.13676265922546388 train acc: 0.95258 val loss: 0.28208601913452147 val acc: 0.9115 time: 412.3975296020508


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.18947368421052635 train loss: 0.12060557788848877 train acc: 0.95866 val loss: 0.3070586860656738 val acc: 0.9039 time: 440.9925649166107


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.16842105263157897 train loss: 0.10169705291748046 train acc: 0.9659 val loss: 0.3177379936218262 val acc: 0.9078 time: 469.6459541320801


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.1473684210526316 train loss: 0.09139842258453369 train acc: 0.96906 val loss: 0.3025899299621582 val acc: 0.9068 time: 498.2742748260498


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.12631578947368421 train loss: 0.07664332817077636 train acc: 0.97438 val loss: 0.30393065643310546 val acc: 0.91 time: 526.80393242836


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.10526315789473689 train loss: 0.06705754245758057 train acc: 0.97868 val loss: 0.2588904460906982 val acc: 0.9226 time: 555.4115362167358


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.08421052631578951 train loss: 0.05834117286682129 train acc: 0.98134 val loss: 0.2559453193664551 val acc: 0.9229 time: 583.9860785007477


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.06315789473684214 train loss: 0.04809682218551636 train acc: 0.98454 val loss: 0.26116416130065917 val acc: 0.923 time: 612.7558953762054


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.04210526315789476 train loss: 0.04236889501571655 train acc: 0.98692 val loss: 0.25088658485412596 val acc: 0.9264 time: 641.4947671890259


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.02105263157894738 train loss: 0.03534799867630005 train acc: 0.99004 val loss: 0.2448337703704834 val acc: 0.9293 time: 670.0616793632507


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.0 train loss: 0.03465892559051514 train acc: 0.9901 val loss: 0.24664610748291016 val acc: 0.9301 time: 698.6935245990753
