<a href="https://colab.research.google.com/github/nrajmalwar/Project/blob/master/Session%2013/Assignment_13.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import libraries

In [0]:
import numpy as np
import time, math
from tqdm import tqdm_notebook as tqdm

import tensorflow as tf
import tensorflow.contrib.eager as tfe

In [13]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
# Enable eager execution in tensorflow
tf.enable_eager_execution()

# Hyperparameters

In [0]:
# Changing the batch size to 128
BATCH_SIZE = 128 #@param {type:"integer"}
MOMENTUM = 0.9 #@param {type:"number"}
LEARNING_RATE = 0.01 #@param {type:"number"}
WEIGHT_DECAY = 5e-4 #@param {type:"number"}
EPOCHS = 50 #@param {type:"integer"}

# Model Building

In [0]:
# Function to use values for kernel initializer
def init_pytorch(shape, dtype=tf.float32, partition_info=None):
  fan = np.prod(shape[:-1])
  bound = 1 / math.sqrt(fan)
  return tf.random.uniform(shape, minval=-bound, maxval=bound, dtype=dtype)

In [0]:
# Class to perform Convolution2d, BatchNormalization and Relu activation
class ConvBN(tf.keras.Model):
  def __init__(self, c_out, k_size=3, stride=(1,1)):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=k_size, strides=stride, padding="same", kernel_initializer=init_pytorch, use_bias=False)
    self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)
    self.drop = tf.keras.layers.Dropout(0.05)

  def call(self, inputs):
    return tf.nn.relu(self.bn(self.drop(self.conv(inputs))))

In [0]:
# Class to define one Resnet Block
class ResBlk(tf.keras.Model):
  def __init__(self, c_out, res = False):
    super().__init__()
    self.conv_bn = ConvBN(c_out, k_size=1)
    self.res = res
    if self.res:
      self.res1 = ConvBN(c_out)
      self.res2 = ConvBN(c_out)
      self.res3 = ConvBN(c_out)
      self.res4 = ConvBN(c_out)

  # Adding the shortcut connection
  def call(self, inputs):
    h = self.conv_bn(inputs)
    if self.res:
      # Repeating one block twice
      h = h + self.res2(self.res1(h))
      h = h + self.res4(self.res3(h))
    return h

In [0]:
# ResNet18 Model architecture
class ResNet18(tf.keras.Model):
  def __init__(self, c=32, weight=0.125):
    super().__init__()
    # First layer
    # In the defined architecture, it used 7x7 kernel with stride 2
    # We're using 3x3 kernel with stride 1, because the input resolution is just 32x32
    self.init_conv_bn = ConvBN(c, 3, (1,1))

    # Originally used MaxPooling with kernel 3x3 and stride 2
    # We're using MaxPooling with kernel 2x2 and stride 2, gives an output resolution of 16x16
    self.pool1 = tf.keras.layers.MaxPooling2D()

    # First resnet block with 32 kernels, which is used twice. The block is repeated in the class ResBlk
    self.blk1 = ResBlk(c, res = True)
    # Second resnet block with 64 kernels
    self.blk2 = ResBlk(c*2, res = True)
    # Third resnet block with 128 kernels
    self.blk3 = ResBlk(c*4, res = True)
    # Fourth resnet block with 256 kernels
    self.blk4 = ResBlk(c*8, res = True)
    # Average pooling of all the pixels along the channels
    self.pool = tf.keras.layers.GlobalAveragePooling2D()
    # Connect to a dense layer of 10 units
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2((self.blk1(self.pool1(self.init_conv_bn(x))))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

In [0]:
# Load the dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
len_train, len_test = len(x_train), len(x_test)
y_train = y_train.astype('int64').reshape(len_train)
y_test = y_test.astype('int64').reshape(len_test)

# We're not using the calculated normalization values
# train_mean = np.mean(x_train, axis=(0,1,2))
# train_std = np.std(x_train, axis=(0,1,2))

# Normalization values provided in the assignment
train_mean = (0.4914, 0.4822, 0.4465)
train_std = (0.2023, 0.1994, 0.2010)

normalize = lambda x: ((x - train_mean) / train_std).astype('float32') # todo: check here
#Padding with 4px
pad4 = lambda x: np.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)], mode='reflect')

x_train = normalize(pad4(x_train))
x_test = normalize(x_test)

In [0]:
model = ResNet18()
# Use SGD optimizer with momentum
opt = tf.keras.optimizers.SGD(learning_rate=LEARNING_RATE, momentum=MOMENTUM, nesterov=True)
# Do random crop of 32, and horizontal flip
data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)

# Model Training

In [22]:
test_acc_prev = 0
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

# Model Training
for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var))

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()

  # Save the best model
  if test_acc > test_acc_prev:
    model.save_weights('/content/drive/My Drive/Colab Notebooks/EVA/Session 13/my_model') 
    test_acc_prev = test_acc

  print('epoch:', epoch+1, 'lr:', 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: train loss: 1.6955046745300293 train acc: 0.34928 val loss: 2.1069267730712893 val acc: 0.3394 time: 74.14128971099854


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: train loss: 1.279337512664795 train acc: 0.53222 val loss: 1.2382920316696167 val acc: 0.5519 time: 148.75813174247742


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: train loss: 1.1015299655151367 train acc: 0.60204 val loss: 1.699995288848877 val acc: 0.454 time: 222.81707072257996


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: train loss: 0.9929155990600586 train acc: 0.64402 val loss: 1.4562041244506836 val acc: 0.5547 time: 296.4142792224884


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: train loss: 0.916309303741455 train acc: 0.67156 val loss: 1.1219068881988525 val acc: 0.6323 time: 370.6456456184387


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: train loss: 0.8628193586730957 train acc: 0.69394 val loss: 0.999749201965332 val acc: 0.6495 time: 445.82898902893066


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: train loss: 0.8027659367370605 train acc: 0.7184 val loss: 1.1066302982330323 val acc: 0.6582 time: 519.8961760997772


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: train loss: 0.757985344543457 train acc: 0.73378 val loss: 0.997033524608612 val acc: 0.675 time: 593.6638534069061


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: train loss: 0.7127339073181153 train acc: 0.74996 val loss: 0.9594331497192383 val acc: 0.7051 time: 667.8821392059326


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: train loss: 0.6778335984802246 train acc: 0.76452 val loss: 0.7845813182830811 val acc: 0.7418 time: 742.806478023529


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: train loss: 0.6458773269653321 train acc: 0.77414 val loss: 0.7975868702888489 val acc: 0.7308 time: 817.516340970993


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: train loss: 0.6129593707656861 train acc: 0.78978 val loss: 0.9095335754394531 val acc: 0.7064 time: 891.7544531822205


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: train loss: 0.5865533914947509 train acc: 0.79638 val loss: 0.836979676246643 val acc: 0.7364 time: 965.979731798172


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: train loss: 0.5695428606414795 train acc: 0.8034 val loss: 0.7299232934951783 val acc: 0.7716 time: 1040.6315279006958


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: train loss: 0.5484435848999023 train acc: 0.81014 val loss: 0.7300183179855346 val acc: 0.7718 time: 1115.4183056354523


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: train loss: 0.5307934096527099 train acc: 0.81548 val loss: 0.5490532110691071 val acc: 0.8123 time: 1188.537669658661


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: train loss: 0.5100941340637207 train acc: 0.82248 val loss: 0.5823674599647521 val acc: 0.8113 time: 1261.318788766861


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: train loss: 0.491742587890625 train acc: 0.82908 val loss: 0.6032387234687805 val acc: 0.8156 time: 1334.28524684906


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: train loss: 0.4800326528930664 train acc: 0.83468 val loss: 0.5429653169631958 val acc: 0.8267 time: 1407.6623814105988


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: train loss: 0.46977256271362305 train acc: 0.83802 val loss: 0.6803572854995728 val acc: 0.777 time: 1480.3959517478943


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: train loss: 0.45780005905151366 train acc: 0.84118 val loss: 0.6422099924087524 val acc: 0.799 time: 1553.170037984848


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: train loss: 0.44255664779663084 train acc: 0.84814 val loss: 0.6511699142932892 val acc: 0.7921 time: 1625.857398033142


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: train loss: 0.4283175271987915 train acc: 0.85166 val loss: 0.6213892225265503 val acc: 0.808 time: 1699.3991377353668


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: train loss: 0.41811118980407713 train acc: 0.85594 val loss: 0.5498752805709839 val acc: 0.8218 time: 1771.9357256889343


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 25 lr: train loss: 0.41119349281311035 train acc: 0.85914 val loss: 0.518316900396347 val acc: 0.8315 time: 1846.0149552822113


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 26 lr: train loss: 0.39574644947052 train acc: 0.86302 val loss: 0.6326039533615112 val acc: 0.8096 time: 1919.1798431873322


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 27 lr: train loss: 0.3892643804550171 train acc: 0.86494 val loss: 0.5171526236534119 val acc: 0.8379 time: 1993.0145082473755


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 28 lr: train loss: 0.3794901997756958 train acc: 0.86784 val loss: 0.46764779987335203 val acc: 0.8474 time: 2066.5663862228394


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 29 lr: train loss: 0.3679442000198364 train acc: 0.8701 val loss: 0.5237010236740113 val acc: 0.8412 time: 2140.2664930820465


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 30 lr: train loss: 0.35845392658233644 train acc: 0.87566 val loss: 0.5233635840415954 val acc: 0.8471 time: 2214.509661436081


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 31 lr: train loss: 0.35217904399871824 train acc: 0.87714 val loss: 0.5160766987800598 val acc: 0.8429 time: 2288.4467856884003


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 32 lr: train loss: 0.3438318761062622 train acc: 0.88264 val loss: 0.6317338792800903 val acc: 0.8272 time: 2362.875613927841


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 33 lr: train loss: 0.3409166767501831 train acc: 0.88064 val loss: 0.44321422863006593 val acc: 0.8588 time: 2436.731882572174


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 34 lr: train loss: 0.3272658473587036 train acc: 0.8857 val loss: 0.6823587480545044 val acc: 0.8073 time: 2510.665113925934


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 35 lr: train loss: 0.32516290866851805 train acc: 0.88766 val loss: 0.5606253541946411 val acc: 0.8288 time: 2584.4183151721954


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 36 lr: train loss: 0.3163463254547119 train acc: 0.88834 val loss: 0.4854100999832153 val acc: 0.8502 time: 2659.211688518524


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 37 lr: train loss: 0.310328168296814 train acc: 0.89136 val loss: 0.5215570050239563 val acc: 0.8457 time: 2732.931015253067


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 38 lr: train loss: 0.3045850645446777 train acc: 0.89314 val loss: 0.42795965700149535 val acc: 0.8669 time: 2807.2147369384766


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 39 lr: train loss: 0.29710972080230713 train acc: 0.89558 val loss: 0.6011090332984924 val acc: 0.8245 time: 2880.7953588962555


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 40 lr: train loss: 0.2883019812774658 train acc: 0.89854 val loss: 0.5144401002883912 val acc: 0.8473 time: 2954.657609939575


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 41 lr: train loss: 0.2854562558364868 train acc: 0.90046 val loss: 0.5028400290489197 val acc: 0.8484 time: 3028.2953128814697


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 42 lr: train loss: 0.27995041009902955 train acc: 0.90174 val loss: 0.47690606899261473 val acc: 0.8585 time: 3103.190388917923


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 43 lr: train loss: 0.2705476710510254 train acc: 0.9043 val loss: 0.4096337494134903 val acc: 0.8771 time: 3177.580826282501


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 44 lr: train loss: 0.2681604339981079 train acc: 0.90772 val loss: 0.48553410506248473 val acc: 0.8571 time: 3252.037034034729


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 45 lr: train loss: 0.2625629962158203 train acc: 0.90754 val loss: 0.4654635350227356 val acc: 0.8644 time: 3324.9861335754395


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 46 lr: train loss: 0.2581636982727051 train acc: 0.90756 val loss: 0.5289528499126435 val acc: 0.8463 time: 3398.3123137950897


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 47 lr: train loss: 0.25018366138458253 train acc: 0.91166 val loss: 0.40377490119934084 val acc: 0.8753 time: 3471.0546083450317


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 48 lr: train loss: 0.25042452993392944 train acc: 0.91202 val loss: 0.4431751935005188 val acc: 0.8678 time: 3544.7276723384857


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 49 lr: train loss: 0.23935777126312255 train acc: 0.91558 val loss: 0.45910451550483705 val acc: 0.8646 time: 3617.98685503006


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 50 lr: train loss: 0.23767643747329711 train acc: 0.91664 val loss: 0.5386530284881592 val acc: 0.8539 time: 3692.19700050354
