In [None]:
# Root of the mmwave repository; data, model code, logs and checkpoints are resolved from here.
# NOTE(review): hard-coded absolute local path — consider making it configurable.
repo_path = "/home/kjakkala/mmwave"

import os
# Expose only GPU 1 to TensorFlow. This has to happen before TensorFlow initialises CUDA,
# which is why it comes before `import tensorflow` below.
os.environ['CUDA_VISIBLE_DEVICES']='1'

import sys
# Make the repo's `models/` directory importable so `utils` and `resnet` resolve below.
sys.path.append(os.path.join(repo_path, 'models'))

# NOTE(review): later cells use `np`, `train_test_split`, `get_h5dataset`, `resize_data`,
# `balance_dataset`, `mean_center`, `normalize` and `get_trg_data` without importing them;
# they presumably all arrive through this wildcard import — confirm and import explicitly.
from utils import *
from resnet import ResNet50

import tensorflow as tf
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# TF1-compat session with GPU memory growth enabled, so GPU memory is allocated on demand
# rather than reserved up front. NOTE(review): confirm this still takes effect for the
# TF2 eager / tf.function code used in the rest of the notebook.
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

print(tf.__version__)

In [None]:
# Experiment configuration. Apart from the paths, every value below is encoded into
# `log_data`, which names both the TensorBoard log directory and the checkpoint directory.
dataset_path       = os.path.join(repo_path, 'data')
num_classes        = 9            # labels left after the data cell drops one of the original 10
batch_size         = 8
train_src_days     = 6            # days [0, train_src_days) form the labelled source domain
train_trg_days     = 0            # days right after the source days used as labelled target-train data
train_trg_env_days = 2            # passed to get_trg_data for the server environment
epochs             = 500
init_lr            = 0.0001
num_features       = 256          # feature size given to ResNet50; also the width of each center-loss center
activation_fn      = 'selu'
alpha              = 0.05         # step size of the running center-loss center update
disc_hidden        = [1024, 512]  # hidden-layer widths for the domain discriminator
notes              = "resnet_server_adapt_center_grl"  # free-form run tag (plain literal, no placeholders)

log_data = ("classes-{}_bs-{}_train_src_days-{}_train_trg_days-{}_train_trgenv_days-{}"
            "_initlr-{}_num_feat-{}_act_fn-{}_alpha-{}_disc_hidden-{}_{}").format(
    num_classes, batch_size, train_src_days, train_trg_days, train_trg_env_days,
    init_lr, num_features, activation_fn, alpha, disc_hidden, notes)
log_dir         = os.path.join(repo_path, 'logs/new_logs/VMT/{}'.format(log_data))
checkpoint_path = os.path.join(repo_path, 'checkpoints/{}'.format(log_data))

In [None]:
# Load the source-environment recordings and resize them to the model's input size.
source_file = os.path.join(dataset_path, 'source_data.h5')
X_data, y_data, classes = get_h5dataset(source_file)
X_data = resize_data(X_data)
print(X_data.shape, y_data.shape, "\n", classes)

# Balance the data (balance_dataset from utils: 10 days, at most 95 samples per class).
X_data, y_data = balance_dataset(X_data, y_data,
                                 num_days=10,
                                 num_classes=len(classes),
                                 max_samples_per_class=95)
print(X_data.shape, y_data.shape)

# Drop every sample with label 1 (harika's data is incomplete). The row indices are
# computed once and applied to both arrays so X and y stay aligned.
dropped_rows = np.where(y_data[:, 0] == 1)[0]
X_data = np.delete(X_data, dropped_rows, axis=0)
y_data = np.delete(y_data, dropped_rows, axis=0)

# Shift the labels above the removed one down by one (10 -> 9 classes) and drop its name.
y_data[y_data[:, 0] >= 2, 0] -= 1
del classes[1]
print(X_data.shape, y_data.shape, "\n", classes)

# Split by recording day (column 1 of y_data; column 0 is the class label):
#   day <  train_src_days                                    -> source, split 90/10 train/test
#   train_src_days <= day < train_src_days + train_trg_days  -> labelled target ("time") train
#   day >= train_src_days + train_trg_days                   -> target ("time") test
# Labels are one-hot encoded via rows of an identity matrix.
day_idx   = y_data[:, 1]
label_idx = y_data[:, 0]
eye       = np.eye(len(classes))
trg_start = train_src_days
trg_stop  = train_src_days + train_trg_days

src_rows = day_idx < trg_start
X_src, y_src = X_data[src_rows], eye[label_idx[src_rows]]
X_train_src, X_test_src, y_train_src, y_test_src = train_test_split(
    X_src, y_src, stratify=y_src, test_size=0.10, random_state=42)

trg_train_rows = (day_idx >= trg_start) & (day_idx < trg_stop)
X_train_trg, y_train_trg = X_data[trg_train_rows], eye[label_idx[trg_train_rows]]

trg_test_rows = day_idx >= trg_stop
X_test_trg, y_test_trg = X_data[trg_test_rows], eye[label_idx[trg_test_rows]]

# Free the full arrays (and the views into them) now that the splits are materialised.
del X_src, y_src, X_data, y_data, day_idx, label_idx, eye, src_rows, trg_train_rows, trg_test_rows

# Normalise every split with statistics fitted on a training split only.
# Source splits use source-train statistics. The temporal ("time") test split uses
# target-train statistics when labelled target days exist, otherwise source-train ones.
X_train_src, src_mean = mean_center(X_train_src)
X_train_src, src_min, src_ptp = normalize(X_train_src)

X_test_src, _    = mean_center(X_test_src, src_mean)
X_test_src, _, _ = normalize(X_test_src, src_min, src_ptp)

if X_train_trg.shape[0] == 0:
  time_mean, time_min, time_ptp = src_mean, src_min, src_ptp
else:
  X_train_trg, trg_mean = mean_center(X_train_trg)
  X_train_trg, trg_min, trg_ptp = normalize(X_train_trg)
  time_mean, time_min, time_ptp = trg_mean, trg_min, trg_ptp

X_test_trg, _    = mean_center(X_test_trg, time_mean)
X_test_trg, _, _ = normalize(X_test_trg, time_min, time_ptp)

# Features as float32, one-hot labels as uint8.
X_train_src, X_test_src, X_train_trg, X_test_trg = (
    arr.astype(np.float32) for arr in (X_train_src, X_test_src, X_train_trg, X_test_trg))
y_train_src, y_test_src, y_train_trg, y_test_trg = (
    arr.astype(np.uint8) for arr in (y_train_src, y_test_src, y_train_trg, y_test_trg))

print("Final shapes: ")
print("Source:", X_train_src.shape, y_train_src.shape, X_test_src.shape, y_test_src.shape)
print("Time:", X_train_trg.shape, y_train_trg.shape, X_test_trg.shape, y_test_trg.shape)

# Target environments. Only the server set receives `train_trg_env_days` as the third
# get_trg_data argument; conference and office get 0 (semantics defined in utils).
conf_file   = os.path.join(dataset_path, 'target_conf_data.h5')
server_file = os.path.join(dataset_path, 'target_server_data.h5')
office_file = os.path.join(dataset_path, 'target_office_data.h5')

(X_train_conf, y_train_conf,
 X_test_conf, y_test_conf) = get_trg_data(conf_file, classes, 0)
(X_train_server, y_train_server,
 X_test_server, y_test_server) = get_trg_data(server_file, classes, train_trg_env_days)
# Office is evaluation-only: its first two returned arrays are discarded.
_, _, X_data_office, y_data_office = get_trg_data(office_file, classes, 0)

print("Conf:", X_train_conf.shape, y_train_conf.shape, X_test_conf.shape, y_test_conf.shape)
print("Server", X_train_server.shape, y_train_server.shape, X_test_server.shape, y_test_server.shape)
print("Office:", X_data_office.shape, y_data_office.shape)

#get tf.data objects for each set
# NOTE: `prefetch(batch_size)` is applied after `batch`, so it prefetches `batch_size`
# *batches*, not samples; kept as-is.

def make_test_set(X, y):
  """Ordered, batched evaluation pipeline over (X, y); keeps the last partial batch."""
  return (tf.data.Dataset.from_tensor_slices((X, y))
          .batch(batch_size, drop_remainder=False)
          .prefetch(batch_size))

def make_train_set(X, y, repeat=False):
  """Shuffled (buffer = whole split), batched training pipeline over (X, y).

  The last partial batch is dropped. With `repeat=True` the pipeline cycles
  forever, reshuffling on every pass, so zipping it with a finite dataset stops
  when the finite one is exhausted.
  """
  dataset = (tf.data.Dataset.from_tensor_slices((X, y))
             .shuffle(X.shape[0])
             .batch(batch_size, drop_remainder=True)
             .prefetch(batch_size))
  return dataset.repeat(-1) if repeat else dataset

#Test
conf_test_set   = make_test_set(X_test_conf, y_test_conf)
server_test_set = make_test_set(X_test_server, y_test_server)
office_test_set = make_test_set(X_data_office, y_data_office)
src_test_set    = make_test_set(X_test_src, y_test_src)
time_test_set   = make_test_set(X_test_trg, y_test_trg)

#Train
src_train_set    = make_train_set(X_train_src, y_train_src)
server_train_set = make_train_set(X_train_server, y_train_server, repeat=True)

In [None]:
"""GradientReversal: helper function
"""
@tf.custom_gradient
def reverse_gradient(x, hp_lambda):
    """Gradient reversal: identity in the forward pass, gradient scaled by -hp_lambda backward.

    Args:
        x: tensor passed through unchanged.
        hp_lambda: scalar scale for the reversed gradient (a Python number or a
            tensor such as a schedule value). It receives no gradient itself.

    Returns:
        The forward output and its gradient function, as `tf.custom_gradient` requires.
    """
    def custom_grad(dy):
        # tf.custom_gradient expects one gradient per input, i.e. for both `x` and
        # `hp_lambda`; `hp_lambda` is a fixed schedule value, so its gradient is None.
        # The cast lets an integer lambda (e.g. 1) scale a floating-point gradient.
        scale = tf.cast(hp_lambda, dy.dtype)
        return tf.math.multiply(tf.negative(dy), scale), None
    return tf.identity(x), custom_grad


class GradientReversal(tf.keras.layers.Layer):
  """Keras layer that reverses and scales the gradient flowing through it.

  call Args:
    inputs: tensor, output of the preceding layer.
    lambda_hp: scaling factor applied to the reversed gradient.

  Returns (from call):
    `inputs` unchanged in the forward pass; in the backward pass the gradient
    reaching `inputs` is multiplied by -lambda_hp.
  """
  def __init__(self):
    super(GradientReversal, self).__init__()

  def call(self, inputs, lambda_hp):
    return reverse_gradient(inputs, lambda_hp)
  
  

class Discriminator(tf.keras.Model):
  """Fully connected head mapping a feature vector to `num_classes` unnormalised logits.

  Args:
    num_hidden: iterable of hidden-layer widths, applied in order.
    num_classes: number of output logits.
    activation: activation of every hidden layer; the output layer is linear.
  """
  def __init__(self, num_hidden, num_classes, activation='relu'):
    super().__init__(name='discriminator')
    self.hidden_layers = [tf.keras.layers.Dense(width, activation=activation)
                          for width in num_hidden]
    self.logits = tf.keras.layers.Dense(num_classes, activation=None)

  def call(self, x):
    hidden = x
    for dense in self.hidden_layers:
      hidden = dense(hidden)
    return self.logits(hidden)


class GRLResNet50(ResNet50):
  """ResNet50 with a gradient-reversal layer feeding a domain discriminator.

  Args:
    num_classes: `int` number of classes for image classification.
    num_features: forwarded to ResNet50 (presumably the width of `fc1` — confirm in resnet.py).
    num_hidden: hidden-layer widths of the discriminator.
    num_disc: number of discriminator outputs (domains).
    activation: forwarded to ResNet50; the discriminator uses `self.activation`,
      which is assumed to be set by ResNet50.

  Returns:
    A Keras model instance whose call returns `(logits, fc1, disc_logits)`.
  """
  def __init__(self, num_classes, num_features, num_hidden, num_disc, activation='relu'):
    super().__init__(num_classes, num_features, activation)
    self.disc = Discriminator(num_hidden, num_disc, activation=self.activation)
    self.rev_grad = GradientReversal()

  def call(self, img_input, training=False, lambda_hp=1):
    class_logits, features = super().call(img_input, training)
    # The discriminator sees the features through the GRL, so its gradient is
    # reversed (and scaled by lambda_hp) before it reaches the backbone.
    domain_logits = self.disc(self.rev_grad(features, lambda_hp=lambda_hp))
    return class_logits, features, domain_logits

In [None]:
def get_cross_entropy_loss(labels, logits):
  """Batch mean of the softmax cross-entropy between `labels` (one distribution per row) and `logits`."""
  return tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))

class CenterLoss():
    """Center loss with an additional inter-class margin term.

    Keeps one running center per class in the non-trainable variable `centers`.
    `get_center_loss` combines
      * an intra-class term, `tf.nn.l2_loss(features - centers[label])`, and
      * an inter-class hinge, `max(0, margin - ||m_i - m_j||^2)`, summed over every
        ordered pair i != j of classes present in the batch, where `m_k` is the batch
        mean of class k's features,
    divided by `num_classes * batch_size + num_classes ** 2`.

    Args:
        batch_size: batch size, used only in the normaliser.
        num_classes: number of classes (rows of `centers`).
        len_features: feature length (columns of `centers`).
        alpha: step size of the running center update.
    """
    def __init__(self, batch_size, num_classes, len_features, alpha):
        self.centers = tf.Variable(tf.zeros([num_classes, len_features]),
                                   dtype=tf.float32,
                                   trainable=False)
        self.alpha = alpha
        self.num_classes = num_classes
        self.batch_size = batch_size
        # Hinge margin on the *squared* distance between class means.
        self.margin = tf.constant(100, dtype="float32")
        # Row-wise squared L2 norm.
        self.norm = lambda x: tf.reduce_sum(tf.square(x), 1)
        # 1 off the diagonal, 0 on it: a class is never compared with itself.
        self.EdgeWeights = tf.ones((self.num_classes, self.num_classes)) - \
                           tf.eye(self.num_classes)

    def get_center_loss(self, features, labels):
        """Compute the combined center loss and build the running-center update.

        Args:
            features: per-sample feature vectors (assumed shape (batch, len_features)).
            labels: one-hot labels; converted to class indices with argmax on the last axis.

        Returns:
            Scalar loss. Also sets `self.inter_loss`, `self.intra_loss`,
            `self.center_loss` and `self.centers_update_op` (the scatter_sub that
            moves each sample's center towards its feature by `alpha`).
        """
        labels = tf.reshape(tf.argmax(labels, axis=-1), [-1])

        # --- inter-class term on the batch class means ---
        batch_means = tf.math.unsorted_segment_mean(features, labels, self.num_classes)
        # unsorted_segment_mean yields 0 for classes absent from the batch. Those zero
        # rows are not real means, so every pair involving an absent class is masked
        # out; otherwise each present class would be pushed away from the origin and
        # absent pairs would add a constant.
        class_counts = tf.math.unsorted_segment_sum(tf.ones_like(labels, dtype=tf.float32),
                                                    labels,
                                                    self.num_classes)
        present = tf.cast(class_counts > 0, tf.float32)
        pair_weights = self.EdgeWeights * tf.expand_dims(present, 1) * tf.expand_dims(present, 0)
        # For (C, F) means: (C, F, 1) - (F, C) broadcasts to (C, F, C); summing axis 1
        # gives the (C, C) matrix of squared distances between class means.
        center_pairwise_dist = tf.transpose(self.norm(tf.expand_dims(batch_means, 2) -
                                                      tf.transpose(batch_means)))
        self.inter_loss = tf.math.reduce_sum(
            tf.multiply(tf.maximum(0.0, self.margin - center_pairwise_dist), pair_weights))

        # --- running center update ---
        # Each sample's step is divided by (1 + number of samples of its class in the batch).
        _, unique_idx, unique_count = tf.unique_with_counts(labels)
        appear_times = tf.reshape(tf.gather(unique_count, unique_idx), [-1, 1])
        centers_batch = tf.gather(self.centers, labels)
        diff = centers_batch - features
        diff /= tf.cast((1 + appear_times), tf.float32)
        diff *= self.alpha
        self.centers_update_op = tf.compat.v1.scatter_sub(self.centers,
                                                          labels,
                                                          diff)

        # --- intra-class term (uses the centers read before the update) ---
        self.intra_loss   = tf.nn.l2_loss(features - centers_batch)
        self.center_loss  = self.intra_loss + self.inter_loss
        self.center_loss /= (self.num_classes*self.batch_size+self.num_classes*self.num_classes)
        return self.center_loss
      
def virtual_adversarial_images(images, logits, pert_norm_radius=3.5):
  """Shift `images` along an approximate virtual-adversarial direction (VAT).

  A tiny random probe is added to the images. The cross-entropy between `logits`
  and the generator's predictions on the probed images is differentiated with
  respect to that probe. The per-sample unit-norm gradient, scaled to
  `pert_norm_radius`, is then added to the images.

  Args:
    images: batch of inputs; normalisation is per sample, over every axis except
      the first (batch) axis.
    logits: target distribution used as the cross-entropy `labels`. Despite the
      name, the callers in this notebook pass softmax probabilities.
    pert_norm_radius: L2 norm of the perturbation added to each sample.

  Returns:
    `images + pert_norm_radius * d`, with `d` the per-sample unit-norm adversarial direction.
  """
  # Reduce over all non-batch axes, whatever the input rank.
  non_batch_axes = tf.range(1, len(images.shape))

  with tf.GradientTape() as tape:
    # Small random probe with per-sample L2 norm 1e-6.
    noise = tf.random.normal(shape=tf.shape(images))
    noise = 1e-6 * tf.nn.l2_normalize(noise, axis=non_batch_axes)
    # `noise` is a plain tensor, not a Variable, so the tape must be told to track
    # it. Otherwise the gradient below is unconnected (replaced by zeros) and the
    # returned perturbation is always zero.
    tape.watch(noise)

    # The generator returns (logits, features, disc_logits); only logits are needed.
    noise_logits = generator(images + noise, training=False)[0]

    noise_loss = tf.nn.softmax_cross_entropy_with_logits(labels=logits, logits=noise_logits)
    noise_loss = tf.reduce_mean(noise_loss)

  # Direction of steepest increase of the loss with respect to the probe.
  adversarial_noise = tape.gradient(noise_loss, noise, unconnected_gradients='zero')
  adversarial_noise = tf.nn.l2_normalize(adversarial_noise, axis=non_batch_axes)

  return images + pert_norm_radius * adversarial_noise

def mixup_preprocess(x, y, batch_size, alpha=1):
    """Mixup: convex combination of each sample with a randomly paired sample.

    Args:
        x: batch of inputs, batch axis first.
        y: per-sample scores mixed alongside `x`. The training step passes logits,
            which is why the mixed result goes through softmax.
        batch_size: number of samples in the batch (first dimension of x and y).
        alpha: concentration of the Beta(alpha, alpha) mixing-weight distribution.

    Returns:
        (mixed_x, softmax(mixed_y)).
    """
    # Sample weight ~ Beta(alpha, alpha) as G1 / (G1 + G2) with Gi ~ Gamma(alpha), using
    # TF ops so a fresh weight is drawn every time the enclosing tf.function runs.
    # (np.random.beta would execute only once, at trace time, freezing the weights into
    # the graph.) divide_no_nan guards 0/0 if both draws underflow for very small alpha.
    gamma_1 = tf.random.gamma([batch_size], alpha)
    gamma_2 = tf.random.gamma([batch_size], alpha)
    weight  = tf.math.divide_no_nan(gamma_1, gamma_1 + gamma_2)

    # Per-sample weights broadcastable against inputs/scores of any rank.
    x_weight = tf.reshape(tf.cast(weight, x.dtype), [batch_size] + [1] * (len(x.shape) - 1))
    y_weight = tf.reshape(tf.cast(weight, y.dtype), [batch_size] + [1] * (len(y.shape) - 1))

    # Perform the mixup against a random permutation of the batch.
    indices = tf.random.shuffle(tf.range(batch_size))
    mixup_images = (x * x_weight) + (tf.gather(x, indices) * (1 - x_weight))
    mixup_labels = (y * y_weight) + (tf.gather(y, indices) * (1 - y_weight))

    return mixup_images, tf.nn.softmax(mixup_labels)

In [None]:
# Running means of the individual loss terms (logged and reset once per epoch).
cross_entropy_loss = tf.keras.metrics.Mean('cross_entropy_loss')
cond_entropy_loss  = tf.keras.metrics.Mean('cond_entropy_loss')
source_vat_loss    = tf.keras.metrics.Mean('source_vat_loss')
target_vat_loss    = tf.keras.metrics.Mean('target_vat_loss')
source_mixup_loss  = tf.keras.metrics.Mean('source_mixup_loss')
target_mixup_loss  = tf.keras.metrics.Mean('target_mixup_loss')
domain_loss        = tf.keras.metrics.Mean('domain_loss')
center_loss        = tf.keras.metrics.Mean('center_loss')

# Categorical accuracies, grouped by domain.
source_train_acc    = tf.keras.metrics.CategoricalAccuracy('source_train_acc')
source_test_acc     = tf.keras.metrics.CategoricalAccuracy('source_test_acc')
temporal_test_acc   = tf.keras.metrics.CategoricalAccuracy('temporal_test_acc')
server_train_acc    = tf.keras.metrics.CategoricalAccuracy('server_train_acc')
server_test_acc     = tf.keras.metrics.CategoricalAccuracy('server_test_acc')
office_test_acc     = tf.keras.metrics.CategoricalAccuracy('office_test_acc')
conference_test_acc = tf.keras.metrics.CategoricalAccuracy('conference_test_acc')

@tf.function
def test_step(images):
  """Class probabilities for `images`, with the generator in inference mode."""
  class_logits = generator(images, training=False)[0]
  return tf.nn.softmax(class_logits)

@tf.function
def train_gen_step(src_images, src_labels, ser_images, ser_labels, lambda_hp):
  """One generator optimisation step on a source batch and a server batch.

  Args:
    src_images, src_labels: labelled source-domain batch (one-hot labels).
    ser_images: server-domain batch.
    ser_labels: server-domain labels; only fed to the `server_train_acc` metric,
      never to a loss.
    lambda_hp: gradient-reversal weight for the two main forward passes.

  Side effects: applies one `gen_optimizer` update to `generator`, builds the
  center-loss update (`center_loss_obj.centers_update_op`) and updates the running
  loss/accuracy metrics.
  """
  with tf.GradientTape() as gen_tape:
    #Logits
    # Main forward passes; the discriminator logits pass through the GRL scaled by lambda_hp.
    src_logits, src_enc, src_disc_logits = generator(src_images, training=True, lambda_hp=lambda_hp)
    ser_logits, ser_enc, ser_disc_logits = generator(ser_images, training=True, lambda_hp=lambda_hp)
    
    #VAT
    # Perturbed copies of each batch; stop_gradient treats them as constant inputs.
    # These extra passes use the default lambda_hp and discard features/disc logits.
    src_adver_images       = virtual_adversarial_images(src_images, tf.nn.softmax(src_logits))
    src_adver_logits, _, _ = generator(tf.stop_gradient(src_adver_images), training=True)
    ser_adver_images       = virtual_adversarial_images(ser_images, tf.nn.softmax(ser_logits))
    ser_adver_logits, _, _ = generator(tf.stop_gradient(ser_adver_images), training=True)
    
    #MixUp
    # Mixup targets are built from the model's own logits (not true labels) in both domains.
    src_mixup_images, src_mixup_labels = mixup_preprocess(src_images, src_logits, batch_size)
    src_mixup_logits, _, _             = generator(tf.stop_gradient(src_mixup_images), training=True)
    ser_mixup_images, ser_mixup_labels = mixup_preprocess(ser_images, ser_logits, batch_size)
    ser_mixup_logits, _, _             = generator(tf.stop_gradient(ser_mixup_images), training=True)
    
    #Loss
    # Supervised loss on source labels; the server "cross-entropy with itself" is the
    # entropy of the server predictions (conditional-entropy minimisation).
    batch_cross_entropy_loss  = get_cross_entropy_loss(labels=src_labels,
                                                       logits=src_logits)
    batch_cond_entropy_loss   = get_cross_entropy_loss(labels=tf.nn.softmax(ser_logits), 
                                                       logits=ser_logits)
    
    # VAT consistency: clean predictions (held fixed) vs. predictions on perturbed inputs.
    src_vat_loss              = get_cross_entropy_loss(labels=tf.nn.softmax(tf.stop_gradient(src_logits)),
                                                       logits=src_adver_logits)
    trg_vat_loss              = get_cross_entropy_loss(labels=tf.nn.softmax(tf.stop_gradient(ser_logits)),
                                                       logits=ser_adver_logits)

    # Mixup consistency against the (fixed) mixed targets.
    src_mixup_loss            = get_cross_entropy_loss(labels=tf.stop_gradient(src_mixup_labels), 
                                                       logits=src_mixup_logits)
    trg_mixup_loss            = get_cross_entropy_loss(labels=tf.stop_gradient(ser_mixup_labels), 
                                                       logits=ser_mixup_logits)

    # Domain classification: label 0 for every source sample, 1 for every server sample.
    batch_domain_loss         = get_cross_entropy_loss(labels=tf.one_hot(tf.cast(tf.concat([tf.zeros(tf.shape(src_disc_logits)[0]),
                                                                                            tf.ones(tf.shape(ser_disc_logits)[0])], 0), tf.uint8), 2),
                                                       logits=tf.concat([src_disc_logits,
                                                                         ser_disc_logits], 0))

    # Also builds center_loss_obj.centers_update_op as a side effect.
    batch_center_loss         = center_loss_obj.get_center_loss(src_enc, src_labels)


    # Hard-coded weights: 1 for the source-side terms and the center loss,
    # 8e-2 for the domain and server-side terms.
    total_loss = batch_cross_entropy_loss + \
                 8e-2 * batch_domain_loss + \
                 8e-2 * batch_cond_entropy_loss + \
                 1    * src_mixup_loss +\
                 8e-2 * trg_mixup_loss +\
                 8e-2 * trg_vat_loss + \
                 1    * src_vat_loss + \
                 1    * batch_center_loss
    
  gen_gradients = gen_tape.gradient(total_loss, generator.trainable_variables)
  # NOTE(review): inside tf.function, stateful ops already run in program order, so this
  # control_dependencies block is likely redundant in TF2; kept unchanged.
  with tf.control_dependencies([center_loss_obj.centers_update_op]):
    gen_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))

  # Metric updates for logging; accuracies take (y_true, y_pred).
  source_train_acc(src_labels, tf.nn.softmax(src_logits))
  server_train_acc(ser_labels, tf.nn.softmax(ser_logits))
  cross_entropy_loss(batch_cross_entropy_loss)
  cond_entropy_loss(batch_cond_entropy_loss)
  source_vat_loss(src_vat_loss)
  target_vat_loss(trg_vat_loss)
  source_mixup_loss(src_mixup_loss)
  target_mixup_loss(trg_mixup_loss)
  domain_loss(batch_domain_loss)
  center_loss(batch_center_loss)

In [None]:
class grl_lambda(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Gradient-reversal weight schedule: lambda(p) = 2 / (1 + exp(-delta * p)) - 1.

  Equals 0 at p = 0 and rises towards 1 as delta * p grows (about 0.9999 at p = 1
  with the default delta = 10). In this notebook it is called directly with the
  training progress p = epoch / epochs.
  """
  def __init__(self, delta=10):
    super().__init__()
    self.delta = delta

  def __call__(self, progress):
    exponent = -self.delta * progress
    return 2 / (1 + tf.exp(exponent)) - 1

# GRL weight schedule with its default delta.
lambda_schedule = grl_lambda()
learning_rate   = init_lr
# Discriminator hidden widths come from `disc_hidden` (config cell, also encoded in
# log_data). It has 2 outputs: source vs. server domain.
generator       = GRLResNet50(num_classes, num_features, disc_hidden, 2, activation_fn)
gen_optimizer   = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = 0.5)
center_loss_obj = CenterLoss(batch_size, num_classes, num_features, alpha)

summary_writer = tf.summary.create_file_writer(log_dir)

# Checkpoint the model, the optimizer and the non-trainable center-loss centers.
# The centers live outside the model and would otherwise restart from zero on restore.
ckpt = tf.train.Checkpoint(generator=generator,
                           gen_optimizer=gen_optimizer,
                           center_loss_centers=center_loss_obj.centers)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

In [None]:
# Test pipelines paired with the accuracy metric each one updates.
eval_sets = [(time_test_set,   temporal_test_acc),
             (src_test_set,    source_test_acc),
             (office_test_set, office_test_acc),
             (server_test_set, server_test_acc),
             (conf_test_set,   conference_test_acc)]

# (summary tag, metric) pairs written to TensorBoard once per epoch, then reset.
epoch_metrics = [("cross_entropy_loss",  cross_entropy_loss),
                 ("temporal_test_acc",   temporal_test_acc),
                 ("source_train_acc",    source_train_acc),
                 ("source_test_acc",     source_test_acc),
                 ("office_test_acc",     office_test_acc),
                 ("server_train_acc",    server_train_acc),
                 ("server_test_acc",     server_test_acc),
                 ("conference_test_acc", conference_test_acc),
                 ("cond_entropy_loss",   cond_entropy_loss),
                 ("source_vat_loss",     source_vat_loss),
                 ("target_vat_loss",     target_vat_loss),
                 ("source_mixup_loss",   source_mixup_loss),
                 ("target_mixup_loss",   target_mixup_loss),
                 ("domain_loss",         domain_loss),
                 ("center_loss",         center_loss)]

for epoch in range(epochs):
  # GRL weight for this epoch, from training progress epoch / epochs (constant within the epoch).
  grl_weight = lambda_schedule(epoch / epochs)

  # One pass over the source training set; the endlessly repeated server set is
  # consumed in lockstep, so zip stops when the source set is exhausted.
  for (src_images, src_labels), (ser_images, ser_labels) in zip(src_train_set, server_train_set):
    train_gen_step(src_images, src_labels, ser_images, ser_labels, grl_weight)

  # Evaluation. Keras accuracy metrics take (y_true, y_pred).
  for dataset, accuracy in eval_sets:
    for images, labels in dataset:
      accuracy(labels, test_step(images))

  with summary_writer.as_default():
    for tag, metric in epoch_metrics:
      tf.summary.scalar(tag, metric.result(), step=epoch)

  if (epoch + 1) % 25 == 0:
    ckpt_save_path = ckpt_manager.save()
    print('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))

  for tag, metric in epoch_metrics:
    metric.reset_states()