<a href="https://colab.research.google.com/github/neethipoonacha/EIP3/blob/master/Copy_of_DN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np
import time, math
from tqdm import tqdm_notebook as tqdm

import tensorflow as tf
import tensorflow.contrib.eager as tfe

In [0]:
tf.enable_eager_execution()


In [0]:
BATCH_SIZE = 512 #@param {type:"integer"}
MOMENTUM = 0.9 #@param {type:"number"}
LEARNING_RATE = 0.45 #@param {type:"number"}
WEIGHT_DECAY = 5e-4 #@param {type:"number"}
EPOCHS = 24 #@param {type:"integer"}

In [0]:
def init_pytorch(shape, dtype=tf.float32, partition_info=None):
  fan = np.prod(shape[:-1])
  bound = 1 / math.sqrt(fan)
  return tf.random.uniform(shape, minval=-bound, maxval=bound, dtype=dtype)

In [0]:
class ConvBN(tf.keras.Model):
  def __init__(self, c_out):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)
    self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

  def call(self, inputs):
    return tf.nn.relu(self.bn(self.conv(inputs)))

In [0]:
class ResBlk(tf.keras.Model):
  def __init__(self, c_out, pool, res = False):
    super().__init__()
    self.conv_bn = ConvBN(c_out)
    self.pool = pool
    self.res = res
    if self.res:
      self.res1 = ConvBN(c_out)
      self.res2 = ConvBN(c_out)

  def call(self, inputs):
    h = self.pool(self.conv_bn(inputs))
    if self.res:
      h = h + self.res2(self.res1(h))
    return h

In [0]:
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = ConvBN(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight

  def call(self, x, y):
    h = self.pool(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

In [8]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
len_train, len_test = len(x_train), len(x_test)
y_train = y_train.astype('int64').reshape(len_train)
y_test = y_test.astype('int64').reshape(len_test)

train_mean = np.mean(x_train, axis=(0,1,2))
train_std = np.std(x_train, axis=(0,1,2))

normalize = lambda x: ((x - train_mean) / train_std).astype('float32') # todo: check here
pad4 = lambda x: np.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)], mode='reflect')

x_train = normalize(pad4(x_train))
x_test = normalize(x_test)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [0]:
model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1

lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)

In [10]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.09 train loss: 1.5584104125976563 train acc: 0.43534 val loss: 1.2922035705566406 val acc: 0.5436 time: 48.112001180648804


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.18 train loss: 0.8319435290527344 train acc: 0.70546 val loss: 0.9736420166015625 val acc: 0.6834 time: 83.4289767742157


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.27 train loss: 0.6381332528686523 train acc: 0.78028 val loss: 0.7805060745239257 val acc: 0.7391 time: 118.36550378799438


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.36 train loss: 0.5452634252929688 train acc: 0.8119 val loss: 0.9472632293701172 val acc: 0.7194 time: 153.11157155036926


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.45 train loss: 0.4797203634643555 train acc: 0.83454 val loss: 0.6021984848022461 val acc: 0.8003 time: 188.17966485023499


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.4263157894736842 train loss: 0.39974306732177733 train acc: 0.86064 val loss: 0.6010054397583008 val acc: 0.8114 time: 223.02980971336365


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.40263157894736845 train loss: 0.31525009185791014 train acc: 0.89148 val loss: 0.40362328186035157 val acc: 0.864 time: 257.961181640625


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.37894736842105264 train loss: 0.2667337287902832 train acc: 0.9069 val loss: 0.49140782318115234 val acc: 0.8459 time: 293.0248303413391


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.35526315789473684 train loss: 0.23334291320800782 train acc: 0.91922 val loss: 0.35816749725341795 val acc: 0.8832 time: 327.93345308303833


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.3315789473684211 train loss: 0.19865652252197266 train acc: 0.93074 val loss: 0.47629331436157224 val acc: 0.8625 time: 362.8189058303833


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.3078947368421053 train loss: 0.17565225715637206 train acc: 0.9375 val loss: 0.4199491394042969 val acc: 0.8709 time: 397.68713116645813


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.28421052631578947 train loss: 0.15158911376953124 train acc: 0.94668 val loss: 0.3228204746246338 val acc: 0.8981 time: 432.6622145175934


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.2605263157894737 train loss: 0.1280238045501709 train acc: 0.95584 val loss: 0.335256893157959 val acc: 0.8952 time: 467.5586693286896


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.2368421052631579 train loss: 0.1090817115020752 train acc: 0.96276 val loss: 0.3158982414245605 val acc: 0.9033 time: 502.42597460746765


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.2131578947368421 train loss: 0.09376337188720703 train acc: 0.96796 val loss: 0.31522478256225583 val acc: 0.9044 time: 537.2902143001556


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.1894736842105263 train loss: 0.07698083969116211 train acc: 0.97448 val loss: 0.30342591972351074 val acc: 0.9072 time: 572.2188241481781


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.16578947368421054 train loss: 0.06634500980377198 train acc: 0.97832 val loss: 0.29709734268188476 val acc: 0.9129 time: 607.19904088974


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.14210526315789473 train loss: 0.05405950391769409 train acc: 0.98266 val loss: 0.27794020385742185 val acc: 0.9207 time: 642.0326807498932


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.11842105263157893 train loss: 0.044616647510528566 train acc: 0.98652 val loss: 0.27083468246459963 val acc: 0.9211 time: 676.8724899291992


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.09473684210526317 train loss: 0.03517359901428223 train acc: 0.98962 val loss: 0.2763101615905762 val acc: 0.9227 time: 711.7292401790619


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.07105263157894737 train loss: 0.02913374098777771 train acc: 0.9919 val loss: 0.26464235343933107 val acc: 0.9269 time: 746.589349269867


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.04736842105263156 train loss: 0.02581085828781128 train acc: 0.99304 val loss: 0.2656492668151855 val acc: 0.9269 time: 781.4544138908386


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.023684210526315808 train loss: 0.022002239084243775 train acc: 0.99426 val loss: 0.2577177505493164 val acc: 0.9274 time: 816.3312058448792


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.0 train loss: 0.01923098879814148 train acc: 0.99554 val loss: 0.2583326156616211 val acc: 0.9284 time: 851.1878259181976


In [11]:
score = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

RuntimeError: ignored