In [None]:
import os
import h5py
import numpy as np
import tensorflow as tf
from skimage.transform import resize
from sklearn.model_selection import train_test_split
# Expose only CUDA device index 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES']='0'

# Switch TensorFlow to eager execution (TF 1.x-style API call) and report the
# installed version so that the notebook output records it.
tf.enable_eager_execution()
print(tf.__version__)

In [None]:
# ---- Experiment configuration -------------------------------------------
# Data / optimisation settings.
num_classes = 9
batch_size = 8
train_src_days = 3
train_trg_days = 3
train_trg_env_days = 2
epochs = 500
learning_rate = 0.0001

# Network settings.
num_hidden = [1024, 512]
num_features = 256
disc_activation = 'selu'
gen_activation = 'selu'

# Weight used when updating the center-loss class centers.
alpha = 0.05

# Free-form experiment tag; appended verbatim to the run identifier below.
notes = "VMT_gauss_noise_cv_data_center_env_adapt_fdisc_serverconftime_{}trg_day_resize224".format(train_trg_env_days)

# Run identifier that encodes every setting above.
log_data = "classes-{}_bs-{}_train_src_days-{}_train_trg_days-{}_lr-{}_num_hidden-{}_num_feat-{}_disc_act-{}_gen_act-{}_center_alpha-{}_{}".format(
    num_classes, batch_size, train_src_days, train_trg_days,
    learning_rate, num_hidden, num_features,
    disc_activation, gen_activation, alpha, notes)

In [None]:
def resize_data(data, output_shape=(224, 224)):
  """Resize every sample of a 4-D array to `output_shape`.

  Args:
    data: array whose shape unpacks as (samples, height, width, channels).
    output_shape: target (height, width) handed to `resize`.

  Returns:
    Array of shape (samples, *output_shape, channels).
  """
  _, height, width, channels = data.shape

  # Move the sample axis last and merge it with the channel axis so the whole
  # batch goes through a single `resize` call as one multi-channel image.
  stacked = data.transpose((1, 2, 3, 0)).reshape(height, width, -1)
  resized = resize(stacked, output_shape)

  # Undo the merge, then bring the sample axis back to the front.
  unstacked = resized.reshape(*output_shape, channels, -1)
  return unstacked.transpose((3, 0, 1, 2))

In [None]:
#Read data
# The context manager closes the HDF5 file even if reading or resizing fails.
with h5py.File('/home/kjakkala/mmwave/data/source_data.h5', 'r') as hf:
  X_data = resize_data(np.expand_dims(hf.get('X_data'), axis=-1))
  y_data = np.array(hf.get('y_data'))
  classes = [n.decode("ascii", "ignore") for n in hf.get('classes')]
print(X_data.shape, y_data.shape, "\n", classes)

#balance dataset to at most 95 samples per day for each person
# y_data[:, 0] holds the class index and y_data[:, 1] the day index.
X_data_tmp = []
y_data_tmp = []
for day in range(10):
  for idx in range(len(classes)):
    day_class_mask = (y_data[:, 0] == idx) & (y_data[:, 1] == day)
    X_data_tmp.append(X_data[day_class_mask][:95])
    y_data_tmp.append(y_data[day_class_mask][:95])
X_data = np.concatenate(X_data_tmp)
y_data = np.concatenate(y_data_tmp)
del X_data_tmp, y_data_tmp
print(X_data.shape, y_data.shape)

#remove harika's data (class index 1)
keep_mask = y_data[:, 0] != 1
X_data = X_data[keep_mask]
y_data = y_data[keep_mask]

#update labels to handle 9 classes instead of 10
y_data[y_data[:, 0] >= 2, 0] -= 1
del classes[1]
print(X_data.shape, y_data.shape, "\n", classes)

#split days of data to train and test
# Source domain: days [0, train_src_days), split 90/10 into train/test.
X_src = X_data[y_data[:, 1] < train_src_days]
y_src = y_data[y_data[:, 1] < train_src_days, 0]
y_src = np.eye(len(classes))[y_src]
X_train_src, X_test_src, y_train_src, y_test_src = train_test_split(X_src,
                                                                    y_src,
                                                                    stratify=y_src,
                                                                    test_size=0.10,
                                                                    random_state=42)

# Time-shifted target domain: the next train_trg_days days are used for
# training, every later day for testing.
X_trg = X_data[y_data[:, 1] >= train_src_days]
y_trg = y_data[y_data[:, 1] >= train_src_days]
X_train_trg = X_trg[y_trg[:, 1] < train_src_days+train_trg_days]
y_train_trg = y_trg[y_trg[:, 1] < train_src_days+train_trg_days, 0]
y_train_trg = np.eye(len(classes))[y_train_trg]

X_test_trg = X_data[y_data[:, 1] >= train_src_days+train_trg_days]
y_test_trg = y_data[y_data[:, 1] >= train_src_days+train_trg_days, 0]
y_test_trg = np.eye(len(classes))[y_test_trg]

del X_src, y_src, X_trg, y_trg, X_data, y_data

#standardise dataset
src_mean = np.mean(X_train_src)
X_train_src -= src_mean
src_std  = np.std(X_train_src)
X_train_src /= src_std

X_test_src -= src_mean
X_test_src /= src_std

trg_mean = np.mean(X_train_trg)
X_train_trg -= trg_mean
trg_std  = np.std(X_train_trg)
X_train_trg /= trg_std

# NOTE(review): the target test split is standardised with the SOURCE
# statistics, whereas X_train_trg uses its own. Kept as-is because it may be
# deliberate (trg_mean is NaN when X_train_trg is empty) -- confirm.
X_test_trg -= src_mean
X_test_trg /= src_std

X_train_src = X_train_src.astype(np.float32)
y_train_src = y_train_src.astype(np.uint8)
X_test_src  = X_test_src.astype(np.float32)
y_test_src  = y_test_src.astype(np.uint8)
X_train_trg = X_train_trg.astype(np.float32)
y_train_trg = y_train_trg.astype(np.uint8)
X_test_trg  = X_test_trg.astype(np.float32)
y_test_trg  = y_test_trg.astype(np.uint8)

print(X_train_src.shape, y_train_src.shape,  X_test_src.shape, y_test_src.shape, X_train_trg.shape, y_train_trg.shape, X_test_trg.shape, y_test_trg.shape)

In [None]:
def get_trg_data(fname, src_classes, train_trg_days):
  """Load one target-environment HDF5 file and split it by recording day.

  Args:
    fname: path of an HDF5 file holding 'X_data', 'y_data' and 'classes'.
    src_classes: list of source-domain class names; output labels are one-hot
      over this list, in this order.
    train_trg_days: samples whose day index (y_data[:, 1]) is below this value
      form the train split, all others the test split.

  Returns:
    (X_train, y_train, X_test, y_test) with float32 inputs standardised by the
    train split's mean/std and uint8 one-hot labels.

  Raises:
    ValueError: if a sample belongs to a class name missing from src_classes.
  """
  #Read data
  # The context manager closes the file even if reading or resizing fails.
  with h5py.File(fname, 'r') as hf:
    X_data_trg = resize_data(np.expand_dims(hf.get('X_data'), axis=-1))
    y_data_trg = np.array(hf.get('y_data'))
    trg_classes = [n.decode("ascii", "ignore") for n in hf.get('classes')]

  def _to_src_one_hot(trg_label_ids):
    # Translate target-file class indices to source class indices by name.
    # The explicit dtype keeps an empty split usable as an index array.
    src_ids = np.array([src_classes.index(trg_classes[i]) for i in trg_label_ids],
                       dtype=np.int64)
    return np.eye(len(src_classes))[src_ids].astype(np.uint8)

  #split days of data to train and test
  is_train = y_data_trg[:, 1] < train_trg_days

  X_train_trg = X_data_trg[is_train]
  y_train_trg = _to_src_one_hot(y_data_trg[is_train, 0])

  # Test labels go through the same remapping as train labels; previously they
  # were left in the target file's own class order.
  X_test_trg = X_data_trg[~is_train]
  y_test_trg = _to_src_one_hot(y_data_trg[~is_train, 0])

  #standardise dataset with statistics of the train split only
  trg_mean     = np.mean(X_train_trg)
  X_train_trg -= trg_mean
  trg_std      = np.std(X_train_trg)
  X_train_trg /= trg_std

  X_test_trg  -= trg_mean
  X_test_trg  /= trg_std

  return X_train_trg.astype(np.float32), y_train_trg, X_test_trg.astype(np.float32), y_test_trg

# Conference-room and server-room environments: the first `train_trg_env_days`
# days are the train split, the remaining days the test split.
X_train_conf, y_train_conf, X_test_conf, y_test_conf = get_trg_data(
    '/home/kjakkala/mmwave/data/target_conf_data.h5', classes, train_trg_env_days)
X_train_server, y_train_server, X_test_server, y_test_server = get_trg_data(
    '/home/kjakkala/mmwave/data/target_server_data.h5', classes, train_trg_env_days)

# Office environment: only the "train" part (days below 3) is kept.
X_data_office, y_data_office, _, _ = get_trg_data(
    '/home/kjakkala/mmwave/data/target_office_data.h5', classes, 3)

for shapes in ((X_train_conf.shape, y_train_conf.shape, X_test_conf.shape, y_test_conf.shape),
               (X_train_server.shape, y_train_server.shape, X_test_server.shape, y_test_server.shape),
               (X_data_office.shape, y_data_office.shape)):
  print(*shapes)

In [None]:
#get tf.data objects for each set

def _make_dataset(features, labels, training=False, repeat=False):
  """Build a batched `tf.data.Dataset` from in-memory arrays.

  Args:
    features: array of inputs; its first axis indexes samples.
    labels: array of labels aligned with `features`.
    training: when True, shuffle with a buffer covering the whole array and
      drop the final partial batch; when False keep order and every sample.
    repeat: when True, repeat the dataset indefinitely.

  Returns:
    The dataset, batched with the module-level `batch_size`.
  """
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  if training:
    dataset = dataset.shuffle(features.shape[0])
  dataset = dataset.batch(batch_size, drop_remainder=training)
  # prefetch comes after batch, so its argument counts batches, not samples.
  dataset = dataset.prefetch(batch_size)
  if repeat:
    dataset = dataset.repeat(-1)
  return dataset

#Test
conf_test_set   = _make_dataset(X_test_conf,   y_test_conf)
server_test_set = _make_dataset(X_test_server, y_test_server)
office_test_set = _make_dataset(X_data_office, y_data_office)
src_test_set    = _make_dataset(X_test_src,    y_test_src)
trg_test_set    = _make_dataset(X_test_trg,    y_test_trg)

#Train
# The source set is finite; the target sets repeat forever.
src_train_set    = _make_dataset(X_train_src,    y_train_src,    training=True)
server_train_set = _make_dataset(X_train_server, y_train_server, training=True, repeat=True)
conf_train_set   = _make_dataset(X_train_conf,   y_train_conf,   training=True, repeat=True)

# time_train_set is only defined when the time-shifted target split is non-empty.
if X_train_trg.shape[0] > 0:
  time_train_set = _make_dataset(X_train_trg, y_train_trg, training=True, repeat=True)

In [None]:
# L2 penalty coefficient passed to the kernel/bias regularizers of the layers below.
L2_WEIGHT_DECAY = 1e-4
# `momentum` argument of every BatchNormalization layer below.
BATCH_NORM_DECAY = 0.9
# `epsilon` argument of every BatchNormalization layer below.
BATCH_NORM_EPSILON = 1e-5

class GaussianNoise(tf.keras.layers.Layer):
  """Adds zero-mean Gaussian noise to its input when `training` is true.

  Args:
    std: standard deviation of the noise.
  """
  def __init__(self, std):
    super().__init__()
    self.std = std

  def build(self, input_shapes):
    # The layer owns no weights, so there is nothing to create here.
    return None

  def call(self, inputs, training=False):
    # The noisy tensor is always computed; `training` only selects which of the
    # two tensors is returned.
    noisy_inputs = inputs + tf.random.normal(shape=tf.shape(inputs),
                                             mean=0.0,
                                             stddev=self.std)
    return tf.where(training, noisy_inputs, inputs)

class IdentityBlock(tf.keras.Model):
  """Bottleneck residual block whose shortcut is the unmodified input.

  Main path: 1x1 conv -> BN -> act -> kxk conv -> BN -> act -> 1x1 conv -> BN,
  then the input is added and the activation applied once more. Because the
  shortcut is the raw input, `filters[2]` must equal the input channel count.

  Args:
    kernel_size: kernel size of the middle conv layer of the main path.
    filters: three integers, the filter counts of the three conv layers.
    stage: integer stage label, used for generating layer names.
    block: 'a', 'b', ... block label, used for generating layer names.
    activation: activation identifier understood by `tf.keras.activations.get`.
  """
  def __init__(self, kernel_size, filters, stage, block, activation='relu'):
    # Initialise the Keras base class before assigning any attribute.
    super().__init__(name='stage-' + str(stage) + '_block-' + block)

    self.activation = activation
    # Resolve the activation once instead of building a fresh Activation layer
    # on every forward pass.
    self._activation_fn = tf.keras.activations.get(activation)

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    filters1, filters2, filters3 = filters
    bn_axis = -1

    self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1),
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '2a')
    self.bn2a = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '2a')

    self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size,
                                         padding='same',
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '2b')
    self.bn2b = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '2b')

    self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1),
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '2c')
    self.bn2c = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '2c')

  def call(self, input_tensor, training=False):
    """Apply the block; `training` is forwarded to the batch-norm layers."""
    x = self.conv2a(input_tensor)
    x = self.bn2a(x, training=training)
    x = self._activation_fn(x)

    x = self.conv2b(x)
    x = self.bn2b(x, training=training)
    x = self._activation_fn(x)

    x = self.conv2c(x)
    x = self.bn2c(x, training=training)

    # Residual connection with the untouched input.
    x = tf.add(x, input_tensor)
    return self._activation_fn(x)


"""A block that has a conv layer at shortcut.

Note that from stage 3,
the second conv layer at main path is with strides=(2, 2)
And the shortcut should have strides=(2, 2) as well

Args:
  kernel_size: the kernel size of middle conv layer at main path
  filters: list of integers, the filters of 3 conv layer at main path
  stage: integer, current stage label, used for generating layer names
  block: 'a','b'..., current block label, used for generating layer names
  strides: Strides for the second conv layer in the block.

Returns:
  A Keras model instance for the block.
"""
class ConvBlock(tf.keras.Model):
  """Bottleneck residual block with a 1x1 conv + BN on the shortcut.

  Main path: 1x1 conv -> BN -> act -> kxk conv (strided) -> BN -> act ->
  1x1 conv -> BN. The shortcut applies a strided 1x1 conv + BN so that both
  paths have matching shapes before they are added. The activated sum is then
  passed through a `GaussianNoise(1)` layer.

  Args:
    kernel_size: kernel size of the middle conv layer of the main path.
    filters: three integers, the filter counts of the three conv layers.
    stage: integer stage label, used for generating layer names.
    block: 'a', 'b', ... block label, used for generating layer names.
    strides: strides of the middle conv layer and of the shortcut conv.
    activation: activation identifier understood by `tf.keras.activations.get`.
  """
  def __init__(self, kernel_size, filters, stage, block, strides=(2, 2), activation='relu'):
    # Initialise the Keras base class before assigning any attribute.
    super().__init__(name='stage-' + str(stage) + '_block-' + block)

    self.activation = activation
    # Resolve the activation once instead of building a fresh Activation layer
    # on every forward pass.
    self._activation_fn = tf.keras.activations.get(activation)

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    filters1, filters2, filters3 = filters
    bn_axis = -1

    self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1),
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '2a')
    self.bn2a = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '2a')

    self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size,
                                         strides=strides,
                                         padding='same',
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '2b')
    self.bn2b = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '2b')

    self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1),
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '2c')
    self.bn2c = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '2c')

    # Shortcut branch.
    self.conv2s = tf.keras.layers.Conv2D(filters3, (1, 1),
                                         strides=strides,
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '1')
    self.bn2s = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '1')
    self.gauss1   = GaussianNoise(1)

  def call(self, input_tensor, training=False):
    """Apply the block; `training` reaches batch norm and the noise layer."""
    x = self.conv2a(input_tensor)
    x = self.bn2a(x, training=training)
    x = self._activation_fn(x)

    x = self.conv2b(x)
    x = self.bn2b(x, training=training)
    x = self._activation_fn(x)

    x = self.conv2c(x)
    x = self.bn2c(x, training=training)

    shortcut = self.conv2s(input_tensor)
    shortcut = self.bn2s(shortcut, training=training)

    x = tf.add(x, shortcut)
    x = self._activation_fn(x)
    # Forward the flag explicitly: without it the noise layer uses its
    # `training=False` default and never perturbs the activations.
    x = self.gauss1(x, training=training)
    return x
  
class Discriminator(tf.keras.Model):
  """Fully connected classifier head named 'discriminator'.

  Args:
    num_hidden: iterable with the width of each hidden Dense layer, in order.
    num_classes: number of output logits.
    activation: activation of the hidden layers; the output layer has none.
  """
  def __init__(self, num_hidden, num_classes=4, activation='relu'):
    super().__init__(name='discriminator')
    self.hidden_layers = [tf.keras.layers.Dense(width, activation=activation)
                          for width in num_hidden]
    self.logits = tf.keras.layers.Dense(num_classes, activation=None)

  def call(self, x):
    """Return the un-normalised logits for `x`."""
    hidden = x
    for dense in self.hidden_layers:
      hidden = dense(hidden)
    return self.logits(hidden)
  
"""Instantiates the ResNet50 architecture.

Args:
  num_classes: `int` number of classes for image classification.

Returns:
    A Keras model instance.
"""
class ResNet50(tf.keras.Model):
  """Residual feature extractor plus classifier, named 'generator'.

  Layout: 7x7 strided conv -> BN -> activation -> 3x3 max-pool, then three
  stages of one `ConvBlock` followed by one `IdentityBlock`, global average
  pooling, a `num_features`-wide dense layer (`fc1`) and a linear
  `num_classes`-wide output layer.

  Args:
    num_classes: `int` number of classes for image classification.
    num_features: `int` width of the `fc1` embedding layer.
    activation: name of the activation used throughout the network; it is
      also concatenated into a layer name, so it must be a string.
  """
  def __init__(self, num_classes, num_features, activation='relu'):
    super().__init__(name='generator')
    bn_axis = -1
    self.activation = activation

    self.conv1 = tf.keras.layers.Conv2D(32, (7, 7),
                                        strides=(2, 2),
                                        padding='valid',
                                        use_bias=False,
                                        kernel_initializer='he_normal',
                                        kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                        name='conv1')
    self.bn1 = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                  momentum=BATCH_NORM_DECAY,
                                                  epsilon=BATCH_NORM_EPSILON,
                                                  name='bn_conv1')
    self.act1 = tf.keras.layers.Activation(self.activation, name=self.activation+'1')
    self.max_pool1 = tf.keras.layers.MaxPooling2D((3, 3),
                                                  strides=(2, 2),
                                                  padding='same',
                                                  name='max_pool1')

    self.blocks = []
    self.blocks.append(ConvBlock(3, [32, 32, 128], strides=(1, 1), stage=2, block='a', activation=self.activation))
    self.blocks.append(IdentityBlock(3, [32, 32, 128], stage=2, block='b', activation=self.activation))

    self.blocks.append(ConvBlock(3, [64, 64, 256], stage=3, block='a', activation=self.activation))
    self.blocks.append(IdentityBlock(3, [64, 64, 256], stage=3, block='b', activation=self.activation))

    self.blocks.append(ConvBlock(3, [64, 64, 256], stage=4, block='a', activation=self.activation))
    self.blocks.append(IdentityBlock(3, [64, 64, 256], stage=4, block='b', activation=self.activation))

    self.avg_pool = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')
    self.fc1 = tf.keras.layers.Dense(num_features,
                                     activation=self.activation,
                                     kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
                                     kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                     bias_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                     name='fc1')
    self.logits = tf.keras.layers.Dense(num_classes,
                                        activation=None,
                                        kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
                                        kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                        bias_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                        name='logits')

  def call(self, img_input, training=False):
    """Run the network.

    Args:
      img_input: batch of input images.
      training: forwarded to every batch-norm layer and residual block.

    Returns:
      Tuple `(logits, fc1)`: class logits and the `fc1` embedding.
    """
    x = self.conv1(img_input)
    x = self.bn1(x, training=training)
    x = self.act1(x)
    x = self.max_pool1(x)

    for block in self.blocks:
      # Pass the flag explicitly so the blocks' batch norm (and noise layers)
      # do not silently fall back to their `training=False` default.
      x = block(x, training=training)

    x = self.avg_pool(x)
    fc1 = self.fc1(x)
    logits = self.logits(fc1)
    return logits, fc1

In [None]:
def get_cross_entropy_loss(labels, logits):
  """Softmax cross-entropy between `labels` and `logits`, averaged over all examples."""
  per_example = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels,
                                                           logits=logits)
  return tf.reduce_mean(per_example)

def get_domain_confusion_loss(src_logits, trg_logits):
  """Binary domain loss: source logits are labelled 1, target logits 0.

  Returns half the mean of the element-wise sum of both sigmoid
  cross-entropy terms.
  """
  src_term = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(src_logits),
                                                     logits=src_logits)
  trg_term = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(trg_logits),
                                                     logits=trg_logits)
  discriminator_loss = src_term + trg_term
  return 0.5 * tf.reduce_mean(discriminator_loss)

class CenterLoss():
    """Center loss with an additional margin term between per-batch class means.

    Keeps one running center per class in a non-trainable variable and exposes
    `get_center_loss`, which returns the loss and also creates the op that
    moves the running centers towards the current batch features.
    """
    def __init__(self, batch_size, num_classes, len_features, alpha):
      """Create the running centers and the constants used by the loss.

      Args:
        batch_size: only used in the normalisation constant of the loss.
        num_classes: number of classes (rows of `centers`).
        len_features: feature dimensionality (columns of `centers`).
        alpha: scale applied to the center update.
      """
      # Running class centers, zero-initialised and excluded from gradients.
      self.centers = tf.Variable(tf.zeros([num_classes, len_features]),
                                 dtype=tf.float32,
                                 trainable=False)
      self.alpha = alpha
      self.num_classes = num_classes
      self.batch_size = batch_size    
      # Hinge margin applied to the pairwise center distances.
      self.margin = tf.constant(100, dtype="float32")
      # Sum of squares along axis 1.
      self.norm = lambda x: tf.reduce_sum(tf.square(x), 1)
      # Ones everywhere except the diagonal, so a class is never paired with itself.
      self.EdgeWeights = tf.ones((self.num_classes,self.num_classes)) - \
                                  tf.eye(self.num_classes)

    def get_center_loss(self, features, labels):
      """Compute the loss for one batch and build the center update op.

      Args:
        features: batch features; presumably (batch, len_features) -- TODO confirm.
        labels: per-sample class scores; reduced to class ids with argmax over
          the last axis.

      Returns:
        The scalar `center_loss`. As side effects `inter_loss`, `intra_loss`,
        `center_loss` and `centers_update_op` are stored on the instance.
      """
      labels = tf.reshape(tf.argmax(labels, axis=-1), [-1])
      # Per-class mean of the features in this batch.
      centers0 = tf.math.unsorted_segment_mean(features, 
                                               labels, 
                                               self.num_classes)
      # Pairwise differences between those class means, reduced with `self.norm`
      # (sum of squares) into a class-by-class matrix.
      center_pairwise_dist = tf.transpose(self.norm(tf.expand_dims(centers0, 2) - \
                                                    tf.transpose(centers0)))
      # Hinge penalty on every off-diagonal pair that is closer than the margin.
      self.inter_loss = tf.math.reduce_sum(tf.multiply(tf.maximum(0.0, self.margin - center_pairwise_dist), 
                                                       self.EdgeWeights))

      # appear_times[i] = number of samples in the batch sharing sample i's label.
      unique_label, unique_idx, unique_count = tf.unique_with_counts(labels)
      appear_times = tf.gather(unique_count, unique_idx)
      appear_times = tf.reshape(appear_times, [-1, 1])
      # Running center of each sample's class, read before the update op below.
      centers_batch = tf.gather(self.centers, labels)
      diff = centers_batch - features
      diff /= tf.cast((1 + appear_times), tf.float32)
      diff *= self.alpha
      # Subtracts `diff` from the center rows selected by `labels`.
      self.centers_update_op = tf.compat.v1.scatter_sub(self.centers, 
                                                        labels, 
                                                        diff)

      self.intra_loss   = tf.nn.l2_loss(features - centers_batch)
      self.center_loss  = self.intra_loss + self.inter_loss
      # Normalise by (num_classes * batch_size + num_classes ** 2).
      self.center_loss /= (self.num_classes*self.batch_size+self.num_classes*self.num_classes)
      return self.center_loss
      
def virtual_adversarial_images(images, logits, pert_norm_radius=3.5):
  """Return `images` shifted along a virtual-adversarial direction.

  A tiny random perturbation is added to the images, the loss between the
  reference distribution and the module-level `generator`'s prediction on the
  perturbed images is differentiated with respect to that perturbation, and
  the normalised gradient scaled by `pert_norm_radius` is added to `images`.

  Args:
    images: batch of input images (batch axis first).
    logits: reference distribution used as `labels` of the cross-entropy.
      Despite the name, the callers in this file pass softmax probabilities.
    pert_norm_radius: L2 norm of the returned per-sample perturbation.

  Returns:
    Tensor shaped like `images`.
  """
  # Normalise per sample, i.e. over every axis except the batch axis. The same
  # axes are used for the random probe and for the resulting gradient.
  sample_axes = tf.range(1, len(images.shape))

  with tf.GradientTape() as tape:
    # Get normalised noise matrix
    noise = tf.random.normal(shape=tf.shape(images))
    noise = 1e-6 * tf.nn.l2_normalize(noise, axis=sample_axes)
    # `noise` is a plain tensor, which a tape does not track on its own.
    # Without this call the gradient below is unconnected and, with
    # unconnected_gradients='zero', the perturbation is silently all zeros.
    tape.watch(noise)

    # Add noise to image and get new logits
    noise_logits, _ = generator(images + noise,
                                tf.constant(False, dtype=tf.bool))

    # Get loss from noisy logits
    noise_loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=logits, logits=noise_logits)
    noise_loss = tf.reduce_mean(noise_loss)

  # Based on perturbed image loss, get direction of greatest error
  adversarial_noise = tape.gradient(noise_loss,
                                    noise,
                                    unconnected_gradients='zero')
  adversarial_noise = tf.nn.l2_normalize(adversarial_noise, axis=sample_axes)

  # return images with adversarial perturbation
  return images + pert_norm_radius * adversarial_noise

def mixup_preprocess(x, y, batch_size, alpha=1):
    """Mix every sample with a randomly chosen partner from the same batch.

    Args:
      x: batch of images; the weights are reshaped to (batch_size, 1, 1, 1),
        so `x` is expected to be 4-D.
      y: batch of per-class scores aligned with `x` (2-D).
      batch_size: number of samples in `x` and `y`.
      alpha: concentration of the symmetric Beta(alpha, alpha) distribution
        the mixing weights are drawn from.

    Returns:
      Tuple `(mixup_images, mixup_targets)` where `mixup_targets` is the
      softmax of the mixed `y`.
    """
    # Sample the weights with TensorFlow ops, not NumPy. A NumPy draw inside a
    # tf.function runs once at trace time and is frozen into the graph as a
    # constant, so every training step would reuse the same weights.
    # Beta(a, a) is obtained as G1 / (G1 + G2) with G1, G2 ~ Gamma(a, 1).
    concentration = tf.cast(alpha, tf.float32)
    gammas = tf.random.gamma([2, batch_size], alpha=concentration)
    weight = gammas[0] / (gammas[0] + gammas[1])

    x_weight = tf.cast(tf.reshape(weight, [batch_size, 1, 1, 1]), x.dtype)
    y_weight = tf.cast(tf.reshape(weight, [batch_size, 1]), y.dtype)

    # Perform the mixup.
    indices = tf.random.shuffle(tf.range(batch_size))
    mixup_images = (x * x_weight) + (tf.gather(x, indices) * (1 - x_weight))
    mixup_labels = (y * y_weight) + (tf.gather(y, indices) * (1 - y_weight))

    return mixup_images, tf.nn.softmax(mixup_labels)

In [None]:
# Running means of the individual loss terms, updated by the train steps below.
train_total_loss         = tf.keras.metrics.Mean(name='train_total_loss')
train_domain_loss        = tf.keras.metrics.Mean(name='train_domain_loss')
train_src_vat_loss       = tf.keras.metrics.Mean(name='train_src_vat_loss')
train_trg_vat_loss       = tf.keras.metrics.Mean(name='train_trg_vat_loss')
train_src_mixup_loss     = tf.keras.metrics.Mean(name='train_src_mixup_loss')
train_trg_mixup_loss     = tf.keras.metrics.Mean(name='train_trg_mixup_loss')
train_cond_entropy_loss  = tf.keras.metrics.Mean(name='train_cond_entropy_loss')
train_cross_entropy_loss = tf.keras.metrics.Mean(name='train_cross_entropy_loss')
train_discriminator_loss = tf.keras.metrics.Mean(name='train_discriminator_loss')
# Accuracy trackers, one per evaluation set plus two for the training batches.
src_test_accuracy        = tf.keras.metrics.CategoricalAccuracy(name='src_test_accuracy')
trg_test_accuracy        = tf.keras.metrics.CategoricalAccuracy(name='trg_test_accuracy')
office_test_accuracy     = tf.keras.metrics.CategoricalAccuracy(name='office_test_accuracy')
server_test_accuracy     = tf.keras.metrics.CategoricalAccuracy(name='server_test_accuracy')
conf_test_accuracy       = tf.keras.metrics.CategoricalAccuracy(name='conf_test_accuracy')
src_train_accuracy       = tf.keras.metrics.CategoricalAccuracy(name='src_train_accuracy')
trg_train_accuracy       = tf.keras.metrics.CategoricalAccuracy(name='trg_train_accuracy')

@tf.function
def train_gen_step(src_images, src_labels, trg_images_ser, trg_labels_ser, trg_images_conf, trg_labels_conf, trg_images, trg_labels):  
  """Run one optimisation step on the generator's trainable variables.

  Uses the module-level `generator`, `discriminator`, `center_loss` and
  `gen_optimizer` objects, which are created outside this part of the notebook.

  Args:
    src_images, src_labels: labelled source batch.
    trg_images_ser, trg_labels_ser: server-environment target batch.
    trg_images_conf, trg_labels_conf: conference-environment target batch.
    trg_images, trg_labels: time-shifted target batch.

  Target labels are only used to update `trg_train_accuracy`; they do not
  enter any loss term. NOTE(review): `trg_labels` is never read, and the
  accuracy metric covers only the server and conference batches -- confirm
  that this is intended.
  """
  with tf.GradientTape() as gen_tape:
    #Logits and fc1 encodings of every batch
    src_logits, src_enc           = generator(src_images,      training=True)
    
    trg_logits_ser, trg_enc_ser   = generator(trg_images_ser,  training=True) 
    trg_logits_conf, trg_enc_conf = generator(trg_images_conf, training=True) 
    trg_logits, trg_enc           = generator(trg_images,      training=True)
    
    #VAT: predictions on adversarially perturbed copies of every batch.
    #stop_gradient keeps the loss from back-propagating through the perturbation itself.
    src_adver_images    = virtual_adversarial_images(src_images, tf.nn.softmax(src_logits))
    src_adver_logits, _ = generator(tf.stop_gradient(src_adver_images), training=True)
    
    trg_adver_images_ser     = virtual_adversarial_images(trg_images_ser, tf.nn.softmax(trg_logits_ser))
    trg_adver_logits_ser, _  = generator(tf.stop_gradient(trg_adver_images_ser), training=True)
    trg_adver_images_conf    = virtual_adversarial_images(trg_images_conf, tf.nn.softmax(trg_logits_conf))
    trg_adver_logits_conf, _ = generator(tf.stop_gradient(trg_adver_images_conf), training=True)
    trg_adver_images         = virtual_adversarial_images(trg_images, tf.nn.softmax(trg_logits))
    trg_adver_logits, _      = generator(tf.stop_gradient(trg_adver_images), training=True)
    
    #MixUp: the model's own logits (not the ground-truth labels) are mixed
    #and used as targets. `batch_size` is the module-level constant.
    src_mixup_images, src_mixup_labels = mixup_preprocess(src_images, src_logits, batch_size)
    src_mixup_logits, _                = generator(tf.stop_gradient(src_mixup_images),
                                                   training=True)
    trg_mixup_images_ser, trg_mixup_labels_ser = mixup_preprocess(trg_images_ser, trg_logits_ser, batch_size)
    trg_mixup_logits_ser, _                    = generator(tf.stop_gradient(trg_mixup_images_ser),
                                                           training=True)
    trg_mixup_images_conf, trg_mixup_labels_conf = mixup_preprocess(trg_images_conf, trg_logits_conf, batch_size)
    trg_mixup_logits_conf, _                     = generator(tf.stop_gradient(trg_mixup_images_conf),
                                                             training=True)
    trg_mixup_images, trg_mixup_labels      = mixup_preprocess(trg_images, trg_logits, batch_size)
    trg_mixup_logits, _                     = generator(tf.stop_gradient(trg_mixup_images),
                                                        training=True)
    
    #Disc: domain logits for the fc1 encodings
    src_disc_logits      = discriminator(src_enc)
    trg_disc_logits_ser  = discriminator(trg_enc_ser)
    trg_disc_logits_conf = discriminator(trg_enc_conf)
    trg_disc_logits      = discriminator(trg_enc)

    # Supervised loss on the labelled source batch.
    cross_entropy_loss  = get_cross_entropy_loss(labels=src_labels, 
                                                 logits=src_logits)
    # Cross-entropy of the target predictions against their own softmax.
    cross_cond_loss     = get_cross_entropy_loss(labels=tf.nn.softmax(tf.concat([trg_logits_ser, 
                                                                                trg_logits_conf,
                                                                                trg_logits], 0)), 
                                                 logits=tf.concat([trg_logits_ser, 
                                                                  trg_logits_conf,
                                                                  trg_logits], 0))
    # VAT terms: clean predictions (gradient stopped) vs. adversarial predictions.
    src_vat_loss        = get_cross_entropy_loss(labels=tf.nn.softmax(tf.stop_gradient(src_logits)),
                                                 logits=src_adver_logits)
    trg_vat_loss        = get_cross_entropy_loss(labels=tf.nn.softmax(tf.stop_gradient(tf.concat([trg_logits_ser, 
                                                                                                 trg_logits_conf,
                                                                                                 trg_logits], 0))),
                                                 logits=tf.concat([trg_adver_logits_ser, 
                                                                  trg_adver_logits_conf,
                                                                  trg_adver_logits], 0))
    # MixUp terms: mixed targets (gradient stopped) vs. predictions on mixed images.
    src_mixup_loss      = get_cross_entropy_loss(labels=tf.stop_gradient(src_mixup_labels), 
                                                 logits=src_mixup_logits)
    trg_mixup_loss      = get_cross_entropy_loss(labels=tf.stop_gradient(tf.concat([trg_mixup_labels_ser, 
                                                                                   trg_mixup_labels_conf,
                                                                                   trg_mixup_labels], 0)), 
                                                 logits=tf.concat([trg_mixup_logits_ser, 
                                                                  trg_mixup_logits_conf,
                                                                  trg_mixup_logits], 0))
    # Four-way domain classification with the labels
    # source=0, server=1, conference=2, time-shifted=3.
    # NOTE(review): the literal 4 must match the discriminator's output width.
    domain_loss         = get_cross_entropy_loss(labels=tf.one_hot(tf.cast(tf.concat([tf.zeros(tf.shape(src_disc_logits)[0]),
                                                                                      tf.ones(tf.shape(trg_disc_logits_ser)[0]),
                                                                                      tf.ones(tf.shape(trg_disc_logits_conf)[0])*2,
                                                                                      tf.ones(tf.shape(trg_disc_logits)[0])*3], 0), tf.int32), 4),
                                                 logits=tf.concat([src_disc_logits,
                                                                  trg_disc_logits_ser, 
                                                                  trg_disc_logits_conf,
                                                                  trg_disc_logits], 0))
          
    # Center loss on the source encodings; this call also builds
    # center_loss.centers_update_op, used below.
    batch_center_loss   = center_loss.get_center_loss(src_enc, src_labels)

    # Weighted sum: source terms and the center loss use weight 1,
    # domain and target terms use weight 8e-2.
    total_loss = cross_entropy_loss + \
                 8e-2 * domain_loss + \
                 8e-2 * cross_cond_loss + \
                 1    * src_mixup_loss +\
                 8e-2 * trg_mixup_loss +\
                 8e-2 * trg_vat_loss + \
                 1    * src_vat_loss + \
                 1    * batch_center_loss
    
  # Only the generator's variables are updated here; the center update op is
  # attached as a control dependency of the optimizer step.
  gen_gradients = gen_tape.gradient(total_loss, generator.trainable_variables)
  with tf.control_dependencies([center_loss.centers_update_op]):
    gen_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))

  # Metric bookkeeping.
  src_train_accuracy(src_labels, src_logits)
  trg_train_accuracy(tf.stack([trg_labels_ser, trg_labels_conf]), tf.stack([trg_logits_ser, trg_logits_conf]))
  train_cross_entropy_loss(cross_entropy_loss)
  train_cond_entropy_loss(cross_cond_loss)
  train_src_mixup_loss(src_mixup_loss)
  train_trg_mixup_loss(trg_mixup_loss)
  train_src_vat_loss(src_vat_loss)
  train_trg_vat_loss(trg_vat_loss)
  train_domain_loss(domain_loss)
  train_total_loss(total_loss)
  
@tf.function
def train_disc_step(src_images, trg_images_ser, trg_images_conf, trg_images):
  """Run one optimisation step that updates only the discriminator.

  Every image batch is encoded with `generator`, each encoding is scored by
  `discriminator`, and a 4-class cross-entropy (`get_cross_entropy_loss`) is
  minimised with respect to `discriminator.trainable_variables`. Generator
  variables receive no update in this function.

  Label assignment used for this loss:
    * source rows are split into three contiguous, near-equal chunks that are
      labelled 1, 2 and 3 (for 8 rows: 1,1,1,2,2,2,3,3);
    * every row of the three target batches is labelled 0.
  NOTE(review): this differs from the assignment in the generator step's
  domain loss (source=0, ser=1, conf=2, trg=3) -- presumably intentional for
  the adversarial game; confirm.

  Args:
    src_images: source-domain image batch.
    trg_images_ser: first target image batch ("ser").
    trg_images_conf: second target image batch ("conf").
    trg_images: third target image batch.

  Side effects:
    Applies one `disc_optimizer` update and feeds the loss value into the
    `train_discriminator_loss` metric.
  """
  with tf.GradientTape() as disc_tape:
    # Only the second generator output (the encoding) is consumed here.
    _, src_enc           = generator(src_images, training=True)
    _, trg_enc_ser       = generator(trg_images_ser, training=True)
    _, trg_enc_conf      = generator(trg_images_conf, training=True)
    _, trg_enc           = generator(trg_images, training=True)

    src_disc_logits      = discriminator(src_enc)
    trg_disc_logits_ser  = discriminator(trg_enc_ser)
    trg_disc_logits_conf = discriminator(trg_enc_conf)
    trg_disc_logits      = discriminator(trg_enc)

    disc_logits = tf.concat([src_disc_logits,
                             trg_disc_logits_ser,
                             trg_disc_logits_conf,
                             trg_disc_logits], 0)

    # Derive the source labels from the actual batch size instead of the
    # previous hard-coded 3 + 3 + 2 split, which only matched 8 source rows.
    # floor(i * 3 / n) + 1 reproduces 1,1,1,2,2,2,3,3 exactly when n == 8.
    num_src        = tf.shape(src_disc_logits)[0]
    src_domain_ids = (tf.range(num_src) * 3) // tf.maximum(num_src, 1) + 1

    # All remaining rows (the three target batches) get class 0.
    num_trg        = tf.shape(disc_logits)[0] - num_src
    trg_domain_ids = tf.zeros([num_trg], dtype=tf.int32)

    domain_ids = tf.concat([src_domain_ids, trg_domain_ids], 0)

    domain_conf_loss = get_cross_entropy_loss(labels=tf.one_hot(domain_ids, 4),
                                              logits=disc_logits)

  disc_gradients = disc_tape.gradient(domain_conf_loss,
                                      discriminator.trainable_variables)
  disc_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))
  train_discriminator_loss(domain_conf_loss)
  
@tf.function
def test_step(images):
  """Run `generator` in inference mode and return the first of its outputs.

  The generator is called with `training=False`; only element 0 of what it
  returns is passed back (presumably the class logits -- confirm against the
  generator definition).
  """
  generator_outputs = generator(images, training=False)
  return generator_outputs[0]

In [None]:
# Console line formats: the training part is printed without a newline and the
# test part completes the same line.
train_template = ('Epoch: {:03d}, TotalL: {:.4f}, CrossE: {:.4f}, CondE: {:.4f}, '
                  'disc: {:.4f}, domain: {:.4f}, Src VAT: {:.4f}, Trg VAT: {:.4f}, '
                  'Src MixUp: {:.4f}, Trg MixUp: {:.4f}, Src Train Acc: {:.2f}, Trg Train Acc: {:.2f}, ')
test_template  = ('Src Test Acc: {:.2f}, Trg Test Acc: {:.2f}, Server Test Acc: {:.2f}, '
                  'Office Test Acc: {:.2f}, Conf Test Acc: {:.2f}')

# Networks, built from the hyper-parameters declared in the config cell.
generator      = ResNet50(num_classes, num_features, gen_activation)
discriminator  = Discriminator(num_hidden, disc_activation)

# Two independent Adam optimizers sharing the same settings.
adam_kwargs    = dict(learning_rate=learning_rate, beta_1=0.5)
disc_optimizer = tf.keras.optimizers.Adam(**adam_kwargs)
gen_optimizer  = tf.keras.optimizers.Adam(**adam_kwargs)

center_loss    = CenterLoss(batch_size, num_classes, num_features, alpha)

# Summary writer for this run; the directory name is derived from `log_data`.
summary_log_dir = '../logs/{}'.format(log_data)
summary_writer  = tf.contrib.summary.create_file_writer(summary_log_dir, flush_millis=10000)
summary_writer.set_as_default()
global_step = tf.train.get_or_create_global_step()

def log_loss():
  """Write the current result of every tracked metric as a scalar summary.

  Each (tag, metric) pair below is emitted through `tf.contrib.summary.scalar`
  inside an `always_record_summaries()` scope, in the order listed. The metric
  objects are module-level names looked up when this function runs.
  """
  tracked_metrics = (
      ("train_total_loss",         train_total_loss),
      ("train_cross_entropy_loss", train_cross_entropy_loss),
      ("train_cond_entropy_loss",  train_cond_entropy_loss),
      ("src_train_accuracy",       src_train_accuracy),
      ("trg_train_accuracy",       trg_train_accuracy),
      ("train_discriminator_loss", train_discriminator_loss),
      ("train_domain_loss",        train_domain_loss),
      ("train_src_vat_loss",       train_src_vat_loss),
      ("train_trg_vat_loss",       train_trg_vat_loss),
      ("train_src_mixup_loss",     train_src_mixup_loss),
      ("train_trg_mixup_loss",     train_trg_mixup_loss),
      ("src_test_accuracy",        src_test_accuracy),
      ("trg_test_accuracy",        trg_test_accuracy),
      ("office_test_accuracy",     office_test_accuracy),
      ("server_test_accuracy",     server_test_accuracy),
      ("conf_test_accuracy",       conf_test_accuracy),
  )
  with tf.contrib.summary.always_record_summaries():
    for tag, tracker in tracked_metrics:
      tf.contrib.summary.scalar(tag, tracker.result())
    
# Main loop: per epoch, train on the zipped training sets, print and log the
# running metrics, evaluate on every test set, then reset all metrics.
for epoch in range(epochs):
  # Summaries written by log_loss() use the global step, advanced once per epoch.
  global_step.assign_add(1)

  # zip() stops at the shortest dataset, so an epoch is as long as the
  # smallest of the four training sets.
  for source_data, target_data, server_data, conf_data in zip(src_train_set, time_train_set, server_train_set, conf_train_set):
    train_gen_step(source_data[0], source_data[1], server_data[0], server_data[1], conf_data[0], conf_data[1], target_data[0], target_data[1])
    train_disc_step(source_data[0], server_data[0], conf_data[0], target_data[0])

  print(train_template.format(epoch+1,
                              train_total_loss.result(),
                              train_cross_entropy_loss.result(),
                              train_cond_entropy_loss.result(),
                              train_discriminator_loss.result(),
                              train_domain_loss.result(),
                              train_src_vat_loss.result(),
                              train_trg_vat_loss.result(),
                              train_src_mixup_loss.result(),
                              train_trg_mixup_loss.result(),
                              src_train_accuracy.result()*100,
                              trg_train_accuracy.result()*100), end="")

  # Evaluation. Metrics are called as (labels, predictions), matching the
  # training-step calls and the Keras `update_state(y_true, y_pred)` contract;
  # the previous code passed the two arguments in the opposite order.
  for data in trg_test_set:
    trg_test_accuracy(data[1], test_step(data[0]))

  for data in src_test_set:
    src_test_accuracy(data[1], test_step(data[0]))

  for data in office_test_set:
    office_test_accuracy(data[1], test_step(data[0]))

  for data in server_test_set:
    server_test_accuracy(data[1], test_step(data[0]))

  for data in conf_test_set:
    conf_test_accuracy(data[1], test_step(data[0]))

  print(test_template.format(src_test_accuracy.result()*100,
                             trg_test_accuracy.result()*100,
                             server_test_accuracy.result()*100,
                             office_test_accuracy.result()*100,
                             conf_test_accuracy.result()*100))

  log_loss()

  # Clear every accumulator so the next epoch's numbers start from zero.
  for epoch_metric in (train_total_loss,
                       train_cross_entropy_loss,
                       train_cond_entropy_loss,
                       src_train_accuracy,
                       trg_train_accuracy,
                       train_discriminator_loss,
                       train_domain_loss,
                       train_src_vat_loss,
                       train_trg_vat_loss,
                       train_src_mixup_loss,
                       train_trg_mixup_loss,
                       src_test_accuracy,
                       trg_test_accuracy,
                       server_test_accuracy,
                       office_test_accuracy,
                       conf_test_accuracy):
    epoch_metric.reset_states()