In [None]:
import numpy as np
import time, math
from tqdm.notebook import tqdm

import tensorflow as tf
import tensorflow.contrib.eager as tfe

In [2]:
tf.enable_eager_execution()

In [3]:
BATCH_SIZE = 512
MOMENTUM = 0.9
LEARNING_RATE = 0.4
WEIGHT_DECAY = 5e-4
EPOCHS = 24

In [4]:
def init_pytorch(shape, dtype=tf.float32, partition_info=None):
    fan = np.prod(shape[:-1])
    bound = 1 / math.sqrt(fan)
    return tf.random.uniform(shape, minval=-bound, maxval=bound,
                             dtype=dtype)

In [5]:
class ConvBN(tf.keras.Model):
    def __init__(self, c_out):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(filters=c_out,
                                           kernel_size=3,
                                           padding="SAME",
                                           kernel_initializer=init_pytorch,
                                           use_bias=False)
        self.bn = tf.keras.layers.BatchNormalization(momentum=0.9,
                                                     epsilon=1e-5)
        self.drop = tf.keras.layers.Dropout(0.05)
        
    def call(self, inputs):
        return tf.nn.relu(self.bn(self.drop(self.conv(inputs))))
    
        

In [6]:
class ResBlk(tf.keras.Model):
    def __init__(self, c_out, pool, res=False):
        super().__init__()
        self.conv_bn = ConvBN(c_out)
        self.pool = pool
        self.res = res
        if self.res:
            self.res1 = ConvBN(c_out)
            self.res2 = ConvBN(c_out)
    
    def call(self, inputs):
        h = self.pool(self.conv_bn(inputs))
        
        if self.res:
            h = h + self.res2(self.res1(h))
        return h

In [7]:
class DavidNet(tf.keras.Model):
    def __init__(self, c=64, weight=0.125):
        super().__init__()
        pool = tf.keras.layers.MaxPool2D()
        self.init_conv_bn = ConvBN(c)
        self.blk1 = ResBlk(c*2, pool, res=True)
        self.blk2 = ResBlk(c*4, pool)
        self.blk3 = ResBlk(c*8, pool, res=True)
        self.pool = tf.keras.layers.GlobalAvgPool2D()
        self.linear = tf.keras.layers.Dense(10,
                                            kernel_initializer=init_pytorch,
                                            use_bias=False)
        self.weight = weight
    
    def call(self, x, y):
        h = self.pool(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))
        h = self.linear(h) * self.weight
        ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
        loss = tf.reduce_sum(ce)
        correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis=1), y), tf.float32))
        
        return loss, correct

In [8]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
len_train, len_test = len(x_train), len(x_test)
y_train = y_train.astype('int64').reshape(len_train)
y_test = y_test.astype('int64').reshape(len_test)

train_mean = np.mean(x_train, axis=(0, 1, 2))
train_std = np.std(x_train, axis=(0, 1, 2))

normalize = lambda x : ((x - train_mean) / train_std).astype('float32')
pad4 = lambda x: np.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)],
                        mode='reflect')
print(x_train.shape)
x_train = normalize(pad4(x_train))
x_test = normalize(x_test)
print(x_train.shape)

(50000, 32, 32, 3)
(50000, 40, 40, 3)


In [9]:
model = DavidNet()

In [10]:
batches_per_epoch = len_train//BATCH_SIZE + 1
lr_schedule = lambda t : np.interp([t], [0, (EPOCHS + 1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
global_step = tf.train.get_or_create_global_step()
lr_func = lambda : lr_schedule(global_step/batches_per_epoch)/ BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y : (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)

In [11]:
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
    train_loss = test_loss = train_acc = test_acc = 0.0
    train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)
    
    tf.keras.backend.set_learning_phase(1)
    for (x, y) in tqdm(train_set):
        with tf.GradientTape() as tape:
            loss, correct = model(x, y)
            
        var = model.trainable_variables
        grads = tape.gradient(loss, var)
        for g, v in zip(grads, var):
            g += v * WEIGHT_DECAY * BATCH_SIZE
        opt.apply_gradients(zip(grads, var), global_step=global_step)
        train_loss += loss.numpy()
        train_acc += correct.numpy()
        
    tf.keras.backend.set_learning_phase(0)
    for (x, y) in test_set:
        loss, correct = model(x, y)
        test_loss += loss.numpy()
        test_acc += correct.numpy()
        
    print('epoch: ', epoch + 1, 'lr: ', lr_schedule(epoch + 1),
          'train_loss: ', train_loss/len_train,
          'train_acc: ', train_acc/len_train,
          'test_loss: ', test_loss/len_test,
          'test_acc: ', test_acc/len_test)

W1219 05:29:22.983031 139705006970688 module_wrapper.py:139] From /home/gauravp/anaconda3/envs/eip4/lib/python3.6/site-packages/tensorflow_core/python/autograph/converters/directives.py:119: The name tf.random_crop is deprecated. Please use tf.image.random_crop instead.



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  1 lr:  0.08 train_loss:  1.8289673193359375 train_acc:  0.32926 test_loss:  1.6231308197021483 test_acc:  0.4173


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  2 lr:  0.16 train_loss:  1.1198017279052734 train_acc:  0.59648 test_loss:  1.0625398498535157 test_acc:  0.6186


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  3 lr:  0.24 train_loss:  0.8053392712402344 train_acc:  0.71628 test_loss:  1.5090457336425782 test_acc:  0.5677


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  4 lr:  0.32 train_loss:  0.643346462097168 train_acc:  0.7759 test_loss:  1.6825183166503905 test_acc:  0.5365


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  5 lr:  0.4 train_loss:  0.5548943740844726 train_acc:  0.80814 test_loss:  0.6880031539916992 test_acc:  0.7785


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  6 lr:  0.37894736842105264 train_loss:  0.48420679779052733 train_acc:  0.83132 test_loss:  0.6044529357910157 test_acc:  0.8013


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  7 lr:  0.35789473684210527 train_loss:  0.41908712829589845 train_acc:  0.85548 test_loss:  0.8229438095092774 test_acc:  0.7493


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  8 lr:  0.33684210526315794 train_loss:  0.36626926025390627 train_acc:  0.87342 test_loss:  0.524585270690918 test_acc:  0.8323


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  9 lr:  0.31578947368421056 train_loss:  0.33487300689697264 train_acc:  0.88426 test_loss:  0.4781340881347656 test_acc:  0.8467


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  10 lr:  0.2947368421052632 train_loss:  0.29702436935424803 train_acc:  0.8982 test_loss:  0.4155321350097656 test_acc:  0.8636


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  11 lr:  0.2736842105263158 train_loss:  0.2748430215454102 train_acc:  0.90502 test_loss:  0.5174979858398437 test_acc:  0.8348


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  12 lr:  0.25263157894736843 train_loss:  0.2459860530090332 train_acc:  0.91568 test_loss:  0.3896016410827637 test_acc:  0.8727


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  13 lr:  0.23157894736842108 train_loss:  0.22293497985839844 train_acc:  0.92188 test_loss:  0.4208643295288086 test_acc:  0.8631


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  14 lr:  0.2105263157894737 train_loss:  0.20312972930908202 train_acc:  0.92944 test_loss:  0.4416583953857422 test_acc:  0.8654


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  15 lr:  0.18947368421052635 train_loss:  0.1823797004699707 train_acc:  0.93674 test_loss:  0.5224834365844726 test_acc:  0.8517


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  16 lr:  0.16842105263157897 train_loss:  0.16210940963745119 train_acc:  0.94332 test_loss:  0.33241577377319337 test_acc:  0.8923


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  17 lr:  0.1473684210526316 train_loss:  0.1479202377319336 train_acc:  0.9489 test_loss:  0.35800165252685545 test_acc:  0.8898


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  18 lr:  0.12631578947368421 train_loss:  0.12889591247558593 train_acc:  0.95652 test_loss:  0.4357103317260742 test_acc:  0.868


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  19 lr:  0.10526315789473689 train_loss:  0.11636688232421875 train_acc:  0.95956 test_loss:  0.36166288681030273 test_acc:  0.8909


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  20 lr:  0.08421052631578951 train_loss:  0.10240814102172852 train_acc:  0.96534 test_loss:  0.29917692413330077 test_acc:  0.9085


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  21 lr:  0.06315789473684214 train_loss:  0.09061344188690186 train_acc:  0.96968 test_loss:  0.29494773941040037 test_acc:  0.914


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  22 lr:  0.04210526315789476 train_loss:  0.07740202465057373 train_acc:  0.9751 test_loss:  0.2761667045593262 test_acc:  0.9161


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  23 lr:  0.02105263157894738 train_loss:  0.06746797229766846 train_acc:  0.97942 test_loss:  0.27044158821105957 test_acc:  0.9197


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch:  24 lr:  0.0 train_loss:  0.0627386039352417 train_acc:  0.98054 test_loss:  0.2598004585266113 test_acc:  0.9202
