In [1]:
import numpy as np
import time, math
from tqdm import tqdm_notebook as tqdm

import tensorflow as tf
import tensorflow.contrib.eager as tfe

import matplotlib.pyplot as plt

In [2]:
tf.enable_eager_execution()

In [3]:
BATCH_SIZE = 512 #@param {type:"integer"}
MOMENTUM = 0.9 #@param {type:"number"}
LEARNING_RATE = 0.4 #@param {type:"number"}
WEIGHT_DECAY = 5e-4 #@param {type:"number"}
EPOCHS =  20#@param {type:"integer"}

In [4]:
def init_pytorch(shape, dtype=tf.float32, partition_info=None):
  fan = np.prod(shape[:-1])
  bound = 1 / math.sqrt(fan)
  return tf.random.uniform(shape, minval=-bound, maxval=bound, dtype=dtype)

In [5]:
class ConvBN(tf.keras.Model):
  def __init__(self, c_out):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)
    self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

  def call(self, inputs):
    return tf.nn.relu(self.bn(self.conv(inputs)))

In [6]:
class ResBlk(tf.keras.Model):
  def __init__(self, c_out, pool, res = False):
    super().__init__()
    self.conv_bn = ConvBN(c_out)
    self.pool = pool
    self.res = res
    if self.res:
      self.res1 = ConvBN(c_out)
      self.res2 = ConvBN(c_out)

  def call(self, inputs):
    h = self.pool(self.conv_bn(inputs))
    if self.res:
      h = h + self.res2(self.res1(h))
    return h

In [7]:
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = ConvBN(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight

  def call(self, x, y):
    h = self.pool(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

In [8]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import matplotlib.pyplot as plt

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.datasets.cifar import load_batch
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.util.tf_export import keras_export


# @keras_export('keras.datasets.cifar10.load_data')
def load_data():
  """Loads CIFAR10 dataset.
  Returns:
      Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
  """
  dirname = 'cifar-10-batches-py'
  origin = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
  path = get_file(dirname, origin=origin, untar=True)
#   path = dirname #fix
  num_train_samples = 50000

  x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')
  y_train = np.empty((num_train_samples,), dtype='uint8')

  for i in range(1, 6):
    fpath = os.path.join(path, 'data_batch_' + str(i))
    (x_train[(i - 1) * 10000:i * 10000, :, :, :],
     y_train[(i - 1) * 10000:i * 10000]) = load_batch(fpath)

  fpath = os.path.join(path, 'test_batch')
  x_test, y_test = load_batch(fpath)

  y_train = np.reshape(y_train, (len(y_train), 1))
  y_test = np.reshape(y_test, (len(y_test), 1))

  if K.image_data_format() == 'channels_last':
    x_train = x_train.transpose(0, 2, 3, 1)
    x_test = x_test.transpose(0, 2, 3, 1)

  return (x_train, y_train), (x_test, y_test)

In [9]:
(x_train, y_train), (x_test, y_test) = load_data()
len_train, len_test = len(x_train), len(x_test)
y_train = y_train.astype('int64').reshape(len_train)
y_test = y_test.astype('int64').reshape(len_test)

# train_mean = np.mean(x_train, axis=(0,1,2))
# train_std = np.std(x_train, axis=(0,1,2))

normalize = lambda x: ((x - [125.31, 122.95, 113.87]) / [62.99, 62.09, 66.70]).astype('float32') ## custom normalization values
pad4 = lambda x: np.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)], mode='reflect')

x_train = normalize(pad4(x_train))
x_test = normalize(x_test)

In [10]:
(x_train.shape, y_train.shape), (x_test.shape, y_test.shape)

(((50000, 40, 40, 3), (50000,)), ((10000, 32, 32, 3), (10000,)))

In [11]:
### cutout
p=0.5
s_l=0.02
s_h=0.4
r_1=0.3
r_2=1/0.3
v_l=0
v_h=1
pixel_level=True
def eraser(input_img):
  img_h, img_w, img_c = input_img.shape
  p_1 = np.random.rand()

  if p_1 > p:
      return input_img

  while True:
      s = np.random.uniform(s_l, s_h) * img_h * img_w
      r = np.random.uniform(r_1, r_2)
      w = int(np.sqrt(s / r))
      h = int(np.sqrt(s * r))
      left = np.random.randint(0, img_w)
      top = np.random.randint(0, img_h)

      if left + w <= img_w and top + h <= img_h:
          break

  if pixel_level:
      c = np.random.uniform(v_l, v_h, (h, w, img_c))
  else:
      c = np.random.uniform(v_l, v_h)

  input_img[top:top + h, left:left + w, :] = c

  return input_img

In [12]:
model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1

lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//4, EPOCHS], [0, LEARNING_RATE, 0])[0]   ### adjusted learning rate for 20 epochs
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(eraser(x), [32, 32, 3])), y) ## with cutout

test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

In [13]:
### MIXUP

class MixupGenerator():
    def __init__(self, X_train, y_train, batch_size=512, alpha=0.35, shuffle=True, datagen=None):
        self.X_train = X_train
        self.y_train = y_train
        self.batch_size = batch_size
        self.alpha = alpha
        self.shuffle = shuffle
        self.sample_num = len(X_train)
        self.datagen = datagen

    def __call__(self):
        while True:
            indexes = self.__get_exploration_order()
            itr_num = int(len(indexes) // (self.batch_size * 2))

            for i in range(itr_num):
                batch_ids = indexes[i * self.batch_size * 2:(i + 1) * self.batch_size * 2]
                X, y = self.__data_generation(batch_ids)

                yield X, y

    def __get_exploration_order(self):
        indexes = np.arange(self.sample_num)

        if self.shuffle:
            np.random.shuffle(indexes)

        return indexes

    def __data_generation(self, batch_ids):
        _, h, w, c = self.X_train.shape
        l = np.random.beta(self.alpha, self.alpha, self.batch_size)
        X_l = l.reshape(self.batch_size, 1, 1, 1)
        y_l = l.reshape(self.batch_size, 1)

        X1 = self.X_train[batch_ids[:self.batch_size]]
        X2 = self.X_train[batch_ids[self.batch_size:]]
        X = X1 * X_l + X2 * (1 - X_l)

        if self.datagen:
            for i in range(self.batch_size):
                X[i] = self.datagen.random_transform(X[i])
                X[i] = self.datagen.standardize(X[i])

        if isinstance(self.y_train, list):
            y = []

            for y_train_ in self.y_train:
                y1 = y_train_[batch_ids[:self.batch_size]]
                y2 = y_train_[batch_ids[self.batch_size:]]
                y.append(y1 * y_l + y2 * (1 - y_l))
        else:
            y1 = self.y_train[batch_ids[:self.batch_size]]
            y2 = self.y_train[batch_ids[self.batch_size:]]
            y = y1 * y_l + y2 * (1 - y_l)

        return X, y

In [14]:
datagen = tf.keras.preprocessing.image.ImageDataGenerator(preprocessing_function=MixupGenerator(x_train,y_train))
datagen.fit(x_train) ## fitting mixup generator

In [16]:
train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

t = time.time()

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  
  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.5673791424560546 train acc: 0.43296 val loss: 1.1037547515869142 val acc: 0.5988 time: 25.25158452987671


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 0.8314814947509765 train acc: 0.703 val loss: 0.9692330718994141 val acc: 0.6746 time: 45.889484167099


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.6351113354492187 train acc: 0.77852 val loss: 0.769729443359375 val acc: 0.7456 time: 66.39383506774902


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.5458097592163086 train acc: 0.81184 val loss: 0.8992500427246094 val acc: 0.7248 time: 86.9902753829956


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.47441636596679687 train acc: 0.83776 val loss: 0.8303089157104492 val acc: 0.7446 time: 108.11946868896484


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37333333333333335 train loss: 0.39200757537841796 train acc: 0.86554 val loss: 0.4611866455078125 val acc: 0.8464 time: 128.452486038208


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.3466666666666667 train loss: 0.3127730975341797 train acc: 0.89156 val loss: 0.36985001754760743 val acc: 0.8766 time: 148.90296411514282


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.32 train loss: 0.2635659066772461 train acc: 0.90778 val loss: 0.37159037322998045 val acc: 0.875 time: 169.6071376800537


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.29333333333333333 train loss: 0.2252009294128418 train acc: 0.92164 val loss: 0.3725972038269043 val acc: 0.8775 time: 190.02497100830078


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.2666666666666667 train loss: 0.19296568939208986 train acc: 0.93436 val loss: 0.33903170318603515 val acc: 0.8866 time: 210.64727449417114


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.24000000000000002 train loss: 0.1690389906311035 train acc: 0.9424 val loss: 0.2967694290161133 val acc: 0.9014 time: 231.1677188873291


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.21333333333333335 train loss: 0.1428063077545166 train acc: 0.95098 val loss: 0.28631438865661624 val acc: 0.912 time: 251.88637137413025


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.18666666666666668 train loss: 0.1242778050994873 train acc: 0.95786 val loss: 0.2549659557342529 val acc: 0.9204 time: 272.8673470020294


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.16 train loss: 0.09915962455749512 train acc: 0.96634 val loss: 0.2740080863952637 val acc: 0.9181 time: 293.57465648651123


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.13333333333333336 train loss: 0.08518942386627197 train acc: 0.97222 val loss: 0.26593612747192386 val acc: 0.92 time: 314.8794147968292


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.10666666666666669 train loss: 0.07047460510253906 train acc: 0.97744 val loss: 0.25288733749389647 val acc: 0.9269 time: 335.0034348964691


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.08000000000000002 train loss: 0.057721946964263915 train acc: 0.9823 val loss: 0.24616375274658203 val acc: 0.932 time: 354.95211958885193


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.053333333333333344 train loss: 0.04621327480316162 train acc: 0.98608 val loss: 0.23305194969177245 val acc: 0.9369 time: 375.40617775917053


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.026666666666666672 train loss: 0.03916914981842041 train acc: 0.9886 val loss: 0.22364716024398804 val acc: 0.9396 time: 395.8001956939697


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.0 train loss: 0.03410941864013672 train acc: 0.9906 val loss: 0.2155842408180237 val acc: 0.941 time: 416.04638719558716
