In [1]:
!pip install tensorflow_addons



In [2]:
import tensorflow as tf
from tensorflow.keras import (layers, models, optimizers, Model,
                              activations, datasets, regularizers)
import tensorflow.keras.backend as K
import tensorflow_addons as tfa
from keras.utils import np_utils
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
import glob
from sklearn.model_selection import train_test_split
from scipy.io import loadmat
%matplotlib inline
from tqdm import tqdm

In [3]:
# params
EPOCHS = 5
BATCH_SIZE = 128
LR_MAX = 1e-3
LR_MIN = 1e-5
WARMUP_EPOCH = 10
START_ANNEALING_EPOCH = 10
END_ANNEALING_EPOCH = 100

In [4]:
workspace_dir = 'drive/MyDrive/Colab Notebooks/d-SNE'
if os.getcwd() != workspace_dir:
  os.chdir(workspace_dir)

## load datasets

In [5]:
# load
(x_train_mnist, y_train_mnist),(x_test_mnist, y_test_mnist) = datasets.mnist.load_data()

x_train_mnist = [cv2.resize(x, (32,32), interpolation=cv2.INTER_AREA) for x in x_train_mnist]
x_test_mnist = [cv2.resize(x, (32,32), interpolation=cv2.INTER_AREA) for x in x_test_mnist]
# to 3ch
x_train_mnist = np.array(x_train_mnist)[...,np.newaxis]
x_test_mnist = np.array(x_test_mnist)[...,np.newaxis]
x_train_mnist = np.array([np.concatenate((x,x,x), axis=2) for x in x_train_mnist])
x_test_mnist = np.array([np.concatenate((x,x,x), axis=2) for x in x_test_mnist])

# to one-hot
y_train_mnist = np_utils.to_categorical(y_train_mnist)
y_test_mnist = np_utils.to_categorical(y_test_mnist)

# split
x_train_mnist, x_valid_mnist, y_train_mnist, y_valid_mnist = train_test_split(x_train_mnist, y_train_mnist, test_size=0.1)
x_train_mnist.shape, x_valid_mnist.shape, y_train_mnist.shape, y_valid_mnist.shape

((54000, 32, 32, 3), (6000, 32, 32, 3), (54000, 10), (6000, 10))

In [6]:
# load
train_raw = loadmat('datasets/train_32x32.mat')
test_raw = loadmat('datasets/test_32x32.mat')
x_train_svhn, y_train_svhn = train_raw['X'], train_raw['y']
x_test_svhn, y_test_svhn = test_raw['X'], test_raw['y']

# samplesize first
x_train_svhn = np.transpose(x_train_svhn, (3,0,1,2))
x_test_svhn = np.transpose(x_test_svhn, (3,0,1,2))
                                         
                                         
# label 10 to 0
y_train_svhn = np.squeeze(y_train_svhn)
y_test_svhn = np.squeeze(y_test_svhn)
y_train_svhn[y_train_svhn == 10] = 0
y_test_svhn[y_test_svhn == 10] = 0

# to onehot
y_train_svhn = np_utils.to_categorical(y_train_svhn, 10)
y_test_svhn = np_utils.to_categorical(y_test_svhn, 10)

# split
x_train_svhn, x_valid_svhn, y_train_svhn, y_valid_svhn = train_test_split(x_train_svhn, y_train_svhn, test_size=0.1)
x_train_svhn.shape, x_valid_svhn.shape, y_train_svhn.shape, y_valid_svhn.shape

((65931, 32, 32, 3), (7326, 32, 32, 3), (65931, 10), (7326, 10))

In [7]:
x_train_svhn_10samples = []
y_train_svhn_10samples = []
for l in range(10):
  cnt_per_class = 0
  for i in range(len(x_train_svhn)):
    label = np.argmax(y_train_svhn[i])
    if label == l:
      x_train_svhn_10samples.append(x_train_svhn[i])
      y_train_svhn_10samples.append(y_train_svhn[i])
      cnt_per_class += 1
      if cnt_per_class >= 10:
        break

x_train_svhn_10samples = np.asarray(x_train_svhn_10samples)
y_train_svhn_10samples = np.asarray(y_train_svhn_10samples)
x_train_svhn_10samples.shape, y_train_svhn_10samples.shape

((100, 32, 32, 3), (100, 10))

In [8]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

def preprocess_train(x, y):
  x = x / 255
  x = tf.image.random_brightness(x, 0.2)
  x = tf.image.random_contrast(x, 0.8, 1.2)
  x = tf.image.random_crop(x, [BATCH_SIZE, 28, 28, 3])
  x = tf.image.resize(x, [32, 32], tf.image.ResizeMethod.BICUBIC)
  return x, y

def preprocess_test(x, y):
  x = x / 255
  return x, y

def create_train_datasets(x, y, batch_size=BATCH_SIZE):
  ds = tf.data.Dataset.from_tensor_slices((x,y))
  ds = ds.shuffle(len(x))
  ds = ds.batch(batch_size, drop_remainder=True)
  ds = ds.map(preprocess_train, num_parallel_calls=AUTOTUNE)
  return ds

def create_test_datasets(x, y, batch_size=BATCH_SIZE):
  ds = tf.data.Dataset.from_tensor_slices((x,y))
  ds = ds.batch(batch_size, drop_remainder=True)
  ds = ds.map(preprocess_test, num_parallel_calls=AUTOTUNE)
  return ds

train_mnist_ds = create_train_datasets(x_train_mnist, y_train_mnist)
train_svhn_ds = create_train_datasets(x_train_svhn_10samples, y_train_svhn_10samples)
valid_mnist_ds = create_test_datasets(x_valid_mnist, y_valid_mnist)
valid_svhn_ds = create_test_datasets(x_valid_svhn, y_valid_svhn)
test_mnist_ds = create_test_datasets(x_test_mnist, y_test_mnist)
test_svhn_ds = create_test_datasets(x_test_svhn, y_test_svhn)



In [9]:
for x_t, y_t in train_svhn_ds:
  print(x_t.shape, y_t.shape)

In [10]:
def _make_conv_block(block_index, num_chan=32, num_layer=2, stride=1, pad=2):
  convs = []
  for _ in range(num_layer):
    convs.append(layers.Conv2D(num_chan, kernel_size=(3,3), strides=stride, padding='same'))
    convs.append(layers.LeakyReLU(0.2))
  convs.append(layers.MaxPool2D((2,2)))
  return convs


In [11]:
class AngularLinear(layers.Layer):
  def __init__(self, classes):
    super(AngularLinear, self).__init__()
    self.classes = classes

  def build(self, input_shape):
    self.kernel = self.add_variable("w_angular",
                                    shape=[int(input_shape[-1]), classes])

  def call(self, x):
    x_normed = K.l2_normalize(x, axis=-1)
    kernel_normed = K.l2_normalize(self.kernel, axis=-1)
    cos_theta = tf.matmul(x_normed, kernel_normed, name='cos_theta')
    cos_theta = K.clip(cos_theta, -1, 1)
    return cos_theta

class LeNetPlus(Model):
  """
  LeNetPlus model
  """
  def __init__(self, classes=10, feature_size=256, use_dropout=True, use_norm=False, use_bn=False, use_inn=False,
               use_angular=False, **kwargs):
    super(LeNetPlus, self).__init__()
    self.num_chans = [32, 64, 128]
    self.use_dropout = use_dropout
    self.use_norm = use_norm
    self.feature_size = feature_size
    self.use_bn = use_bn
    self.use_inn = use_inn
    self.use_angular = use_angular

    #self.features = gluon.nn.HybridSequential(prefix='')
    self.features = []

    if self.use_inn:
        #self.features.add(gluon.nn.InstanceNorm())
        self.features.append(tfa.layers.InstanceNormalization())
    for i, num_chan in enumerate(self.num_chans):
        if use_bn:
            self.features.append(layers.BatchNormalization())

        self.features += _make_conv_block(i, num_chan=num_chan)

        if use_dropout and i > 0:
            self.features.append(layers.Dropout(0.5))

    self.features.append(layers.Flatten())
    self.features.append(layers.Dense(self.feature_size))

    if self.use_norm:
      #self.features.add(L2Normalization(mode='instance'))
      self.features.append(layers.Lambda(lambda x : K.l2_normalize(x, axis=-1)))

    if use_angular:
      self.outputs = AngularLinear(classes)
    else:
      self.outputs = layers.Dense(classes)

  def call(self, x):
    for f in self.features:
      x = f(x)
    features = x
    preds = self.outputs(features)
    preds = activations.softmax(preds)
    #print(outputs)
    return preds, features


In [12]:
class dSNEloss(tf.keras.losses.Loss):
  def __init__(self, margin=1.0, feature_norm=False, name='dSNEloss'):
    super().__init__(name=name)
    self.margin = margin
    self.feature_norm = feature_norm

  def __call__(self, feature_s, y_s, feature_t, y_t):
    if self.feature_norm:
      feature_s = layers.Lambda(lambda x : K.l2_normalize(x, axis=-1))(feature_s)
      feature_t = layers.Lambda(lambda x : K.l2_normalize(x, axis=-1))(feature_t)
    bs_s, h_s, w_s, ch_s = BATCH_SIZE, 32, 32, 512#tf.shape(feature_s)
    bs_t, h_t, w_t, ch_t = BATCH_SIZE, 32, 32, 512#tf.shape(feature_t)
    features_s_repeat = tf.broadcast_to(tf.expand_dims(feature_s, axis=0), shape=(bs_t, bs_s, ch_s))
    features_t_repeat = tf.broadcast_to(tf.expand_dims(feature_t, axis=1), shape=(bs_t, bs_s, ch_s))

    dists = tf.math.reduce_sum(tf.square(features_t_repeat - features_s_repeat), axis=2)

    if len(tf.shape(y_t)) != 1:
      y_t = tf.argmax(y_t, axis=-1)
    if len(tf.shape(y_s)) != 1:
      y_s = tf.argmax(y_s, axis=-1) 
    y_t_repeat = tf.broadcast_to(tf.expand_dims(y_t, axis=1), shape=(bs_t, bs_s))
    y_s_repeat = tf.broadcast_to(tf.expand_dims(y_s, axis=0), shape=(bs_t, bs_s))
    y_t_repeat = tf.cast(y_t_repeat, tf.int32)
    y_s_repeat = tf.cast(y_s_repeat, tf.int32)

    y_same = tf.equal(y_t_repeat, y_s_repeat)
    y_diff = tf.not_equal(y_t_repeat, y_s_repeat)
    y_same = tf.cast(y_same, tf.float32)
    y_diff = tf.cast(y_diff, tf.float32)

    intra_cls_dists = dists * y_same
    inter_cls_dists = dists * y_diff

    max_dists = tf.math.reduce_max(dists, axis=1, keepdims=True)
    max_dists = tf.broadcast_to(max_dists, shape=(bs_t, bs_s))
    revised_inter_cls_dists = tf.where(tf.cast(y_same, tf.bool), max_dists, inter_cls_dists)

    max_intra_cls_dist = tf.math.reduce_max(intra_cls_dists, axis=1)
    min_inter_cls_dist = tf.math.reduce_min(revised_inter_cls_dists, axis=1)

    loss = tf.nn.relu(max_intra_cls_dist - min_inter_cls_dist + self.margin)
    return loss    

In [13]:
model = LeNetPlus(classes=10, feature_size=512, use_dropout=True, 
                  use_norm=False, use_bn=False, use_inn=True, 
                  use_angular=False)
model.build((BATCH_SIZE,32,32,3))
model.summary()

Model: "le_net_plus"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
instance_normalization (Inst multiple                  6         
_________________________________________________________________
conv2d (Conv2D)              multiple                  896       
_________________________________________________________________
leaky_re_lu (LeakyReLU)      multiple                  0         
_________________________________________________________________
conv2d_1 (Conv2D)            multiple                  9248      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    multiple                  0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0         
_________________________________________________________________
conv2d_2 (Conv2D)            multiple                  

# define model

In [14]:
def __warmup(epoch, warmup_epoch, warmup_rate, lr_end, visible=False):
    warmup_grad = warmup_rate / warmup_epoch
    if epoch <= warmup_epoch:
        lr = lr_end * 10**(-warmup_rate + warmup_grad*epoch)
    if visible:
        print('lr : {}'.format(lr))
        return lr
    else:
        return lr

def __cosine_annealing(epoch, lr_max, lr_min, epoch_start, epoch_end, visible=False):
    if epoch <= epoch_start:
        # before annealing
        return lr_max
    elif epoch_start < epoch <= epoch_end:
        # during annealing
        lr = lr_min + 0.5*(lr_max - lr_min)*(1+np.cos(np.pi*(epoch-epoch_start)/(epoch_end - epoch_start)))
        if visible:
            print('lr : {}'.format(lr))
        return lr
    else:
        # after annealing
        return lr

def update_lr(epoch):
    visible = False
    warmup_epoch = 10
    if epoch <= warmup_epoch:
        lr = __warmup(epoch, warmup_epoch=WARMUP_EPOCH, warmup_rate=2.0, lr_end=LR_MAX, visible=visible)
    else:
        lr = __cosine_annealing(epoch,
                                lr_max=LR_MAX,
                                lr_min=LR_MIN,
                                epoch_start=START_ANNEALING_EPOCH,
                                epoch_end=END_ANNEALING_EPOCH,
                                visible=visible)
    return lr

In [15]:
x_entropy_loss = tf.keras.losses.CategoricalCrossentropy()
dsne_loss = dSNEloss()
optimizer = tf.keras.optimizers.Adam(learning_rate=LR_MIN)

log_train_loss_total = tf.keras.metrics.Mean(name='train_loss_total')
log_train_loss_x_ent = tf.keras.metrics.Mean(name='train_x_ent_loss')
log_train_loss_d_sne = tf.keras.metrics.Mean(name='train_d_sne_loss')

log_valid_loss_total = tf.keras.metrics.Mean(name='valid_loss_total')
log_valid_loss_x_ent = tf.keras.metrics.Mean(name='valid_x_ent_loss')
log_valid_loss_d_sne = tf.keras.metrics.Mean(name='valid_d_sne_loss')

log_train_accuracy_mnist = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy_mnist')
log_train_accuracy_svhn = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy_svhn')

log_valid_accuracy_mnist = tf.keras.metrics.CategoricalAccuracy(name='valid_accuracy_mnist')
log_valid_accuracy_svhn = tf.keras.metrics.CategoricalAccuracy(name='valid_accuracy_svhn')


In [16]:
@tf.function
def train_step(x_s, y_s, x_t, y_t):
  with tf.GradientTape() as tape:
    preds_s, feature_s = model(x_s)
    preds_t, feature_t = model(x_t)
    train_x_entropy_s = x_entropy_loss(y_s, preds_s)
    train_x_entropy_t = x_entropy_loss(y_t, preds_t)
    train_x_entropy = train_x_entropy_s + train_x_entropy_t
    train_dsne = dsne_loss(feature_s=feature_s, y_s=y_s, feature_t=feature_t, y_t=y_t)
    #train_loss = 0.9*train_x_entropy + 0.1*train_dsne
    #train_loss = [0.9*train_x_entropy, 0.1*train_dsne]
  gradients = tape.gradient(train_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  # records
  log_train_loss_total(train_loss)
  log_train_loss_x_ent(train_x_entropy)
  log_train_loss_d_sne(train_dsne)

  log_train_accuracy_mnist(y_s, preds_s)
  log_train_accuracy_svhn(y_t, preds_t)

In [17]:
@tf.function
def valid_step(x_s, y_s, x_t, y_t):
  preds_s, feature_s = model(x_s)
  preds_t, feature_t = model(x_t)
  valid_x_entropy_s = x_entropy_loss(y_s, preds_s)
  valid_x_entropy_t = x_entropy_loss(y_t, preds_t)
  valid_x_entropy = valid_x_entropy_s + valid_x_entropy_t
  valid_dsne = dsne_loss(feature_s=feature_s, y_s=y_s, feature_t=feature_t, y_t=y_t)
  valid_loss = 0.9*valid_x_entropy + 0.1*valid_dsne

  # records
  log_valid_loss_total(valid_loss)
  log_valid_loss_x_ent(valid_x_entropy)
  log_valid_loss_d_sne(valid_dsne)

  log_valid_accuracy_mnist(y_s, preds_s)
  log_valid_accuracy_svhn(y_t, preds_t)

In [18]:
for epoch in range(EPOCHS):
  for x_s, y_s in train_mnist_ds:
    for x_t, y_t in train_svhn_ds:
      train_step(x_s, y_s, x_t, y_t)

  for x_s, y_s in valid_mnist_ds:
    for x_t, y_t in valid_svhn_ds:
      valid_step(x_s, y_s, x_t, y_t)
  
  optimizer.lr = update_lr(epoch)
  print("epoch : {}".format(epoch))
  print('train_loss : {}, valid_loss : {}'.format(log_train_loss_total.result(), log_valid_loss_total.result()))
  print('train x ent : {}, valid x ent : {}'.format(log_train_loss_x_ent.result(), log_valid_loss_x_ent.result()))
  print('train dsne : {}, valid dsne : {}'.format(log_train_loss_d_sne.result(), log_valid_loss_d_sne.result()))
  print('train mnist acc : {}, valid mnist acc : {}'.format(log_train_accuracy_mnist.result(), log_valid_accuracy_mnist.result()))
  print('train svhn acc : {}, valid svhn acc : {}'.format(log_train_accuracy_svhn.result(), log_valid_accuracy_svhn.result()))

  # 次のエポック用にメトリクスをリセット
  log_train_loss_total.reset_states()
  log_valid_loss_total.reset_states()
  log_train_loss_x_ent.reset_states()
  log_valid_loss_x_ent.reset_states()
  log_train_loss_d_sne.reset_states()
  log_valid_loss_d_sne.reset_states()
  log_train_accuracy_mnist.reset_states()
  log_valid_accuracy_mnist.reset_states()
  log_train_accuracy_svhn.reset_states()
  log_valid_accuracy_svhn.reset_states()


epoch : 0
train_loss : 0.0, valid_loss : 4.3658342361450195
train x ent : 0.0, valid x ent : 4.635626316070557
train dsne : 0.0, valid dsne : 1.9376541376113892
train mnist acc : 0.0, valid mnist acc : 0.10190217196941376
train svhn acc : 0.0, valid svhn acc : 0.06661184132099152
epoch : 1
train_loss : 0.0, valid_loss : 4.3658342361450195
train x ent : 0.0, valid x ent : 4.635626316070557
train dsne : 0.0, valid dsne : 1.9376541376113892
train mnist acc : 0.0, valid mnist acc : 0.10190217196941376
train svhn acc : 0.0, valid svhn acc : 0.06661184132099152
epoch : 2
train_loss : 0.0, valid_loss : 4.3658342361450195
train x ent : 0.0, valid x ent : 4.635626316070557
train dsne : 0.0, valid dsne : 1.9376541376113892
train mnist acc : 0.0, valid mnist acc : 0.10190217196941376
train svhn acc : 0.0, valid svhn acc : 0.06661184132099152
epoch : 3
train_loss : 0.0, valid_loss : 4.3658342361450195
train x ent : 0.0, valid x ent : 4.635626316070557
train dsne : 0.0, valid dsne : 1.9376541376113