In [None]:
# Root of the project checkout. Every data, log and checkpoint path used by
# later cells is built from this value, so it is the one line to change when
# running on another machine.
repo_path = "/home/kjakkala/mmwave"

import os
# Set before TensorFlow is imported further down; presumably this restricts
# the process to GPU 0 -- TODO confirm on multi-GPU hosts.
os.environ['CUDA_VISIBLE_DEVICES']='0'

import sys
# Make <repo_path>/models importable so the three project imports below resolve.
sys.path.append(os.path.join(repo_path, 'models'))

# NOTE(review): later cells use names that are never imported explicitly in
# this notebook (np, train_test_split, get_h5dataset, resize_data,
# balance_dataset, mean_center, normalize, get_trg_data, cutmix, CenterLoss);
# presumably they all arrive through this wildcard import -- confirm against
# utils.py.
from utils import *
from resnet import ResNet50
# NOTE(review): `upsample` is not referenced in the visible cells.
from pix2pix import upsample

import tensorflow as tf
print(tf.__version__)

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration. Everything a run depends on is set here and is
# also encoded into the log / checkpoint directory name below.
# ---------------------------------------------------------------------------
dataset_path = os.path.join(repo_path, 'data')

# Data split
num_classes = 9
batch_size = 32
train_src_days = 3       # samples with day index < this value form the source split
train_trg_days = 0       # extra days (after the source days) used as target training data
train_trg_env_days = 0   # forwarded to get_trg_data for the conference / server sets

# Optimisation
epochs = 1000
init_lr = 0.0001

# Model
num_features = 256
alpha = 0.05             # forwarded to CenterLoss
activation_fn = 'selu'
rot_lambda = 1           # weight of the rotation-prediction loss term
num_hidden = [256,256,128]  # hidden-layer widths of the rotation head

# Free-form run tag appended to the run name.
notes = "AT-CMT-Center-Rot-{}-{}-Baseline".format(rot_lambda, num_hidden)

# Run name: one field per hyper-parameter, in the order listed.
log_data = "classes-{}_bs-{}_train_src_days-{}_train_trg_days-{}_train_trgenv_days-{}_alpha-{}_initlr-{}_num_feat-{}_act_fn-{}_{}".format(
    num_classes,
    batch_size,
    train_src_days,
    train_trg_days,
    train_trg_env_days,
    alpha,
    init_lr,
    num_features,
    activation_fn,
    notes,
)

log_dir = os.path.join(repo_path, 'logs/new_logs/Baselines/{}'.format(log_data))
checkpoint_path = os.path.join(repo_path, 'checkpoints/Baselines/{}'.format(log_data))

In [None]:
# Load the source recordings and prepare the train / test splits.
# get_h5dataset, resize_data, balance_dataset, mean_center, normalize and
# get_trg_data are project helpers that are not defined in this notebook;
# comments about them below describe only how they are called here.
X_data, y_data, classes = get_h5dataset(os.path.join(dataset_path, 'source_data.h5'))
X_data = resize_data(X_data)
print(X_data.shape, y_data.shape, "\n", classes)

# NOTE(review): presumably caps each class at 95 samples per day over 10
# days -- confirm against balance_dataset in utils.
X_data, y_data = balance_dataset(X_data, y_data, 
                                 num_days=10, 
                                 num_classes=len(classes), 
                                 max_samples_per_class=95)
print(X_data.shape, y_data.shape)

#remove harika's data (incomplete data)
# Column 0 of y_data is used as the class index: every row labelled 1 is
# dropped. X_data is filtered first because both deletes compute their row
# indices from y_data, which must still be unfiltered for the first call.
X_data = np.delete(X_data, np.where(y_data[:, 0] == 1)[0], 0)
y_data = np.delete(y_data, np.where(y_data[:, 0] == 1)[0], 0)

#update labels to handle 9 classes instead of 10
# Shift every class index above the removed one down by one so the labels are
# contiguous again, and drop the matching entry from `classes`.
y_data[y_data[:, 0] >= 2, 0] -= 1
del classes[1]
print(X_data.shape, y_data.shape, "\n", classes)

#split days of data to train and test
# Column 1 of y_data is compared against the day counts from the config cell:
# rows with a value below train_src_days form the source split.
X_src = X_data[y_data[:, 1] < train_src_days]
y_src = y_data[y_data[:, 1] < train_src_days, 0]
# One-hot encode by indexing rows of an identity matrix with the class indices.
y_src = np.eye(len(classes))[y_src]
# 90/10 stratified split of the source days, with a fixed seed.
X_train_src, X_test_src, y_train_src, y_test_src = train_test_split(X_src,
                                                                    y_src,
                                                                    stratify=y_src,
                                                                    test_size=0.10,
                                                                    random_state=42)

# Remaining days: the first train_trg_days of them are target training data.
# With train_trg_days == 0 this selection is empty.
X_trg = X_data[y_data[:, 1] >= train_src_days]
y_trg = y_data[y_data[:, 1] >= train_src_days]
X_train_trg = X_trg[y_trg[:, 1] < train_src_days+train_trg_days]
y_train_trg = y_trg[y_trg[:, 1] < train_src_days+train_trg_days, 0]
y_train_trg = np.eye(len(classes))[y_train_trg]

# Everything after the source and target training days is the temporal test set.
X_test_trg = X_data[y_data[:, 1] >= train_src_days+train_trg_days]
y_test_trg = y_data[y_data[:, 1] >= train_src_days+train_trg_days, 0]
y_test_trg = np.eye(len(classes))[y_test_trg]

# Release the intermediate arrays; only the split arrays are used from here on.
del X_src, y_src, X_trg, y_trg, X_data, y_data

#mean center and normalize dataset
# Statistics are taken from the source training split and passed back into
# the helpers for the source test split.
X_train_src, src_mean = mean_center(X_train_src)
X_train_src, src_min, src_ptp = normalize(X_train_src)

X_test_src, _    = mean_center(X_test_src, src_mean)
X_test_src, _, _ = normalize(X_test_src, src_min, src_ptp)

# If target training data exists, the temporal test set is processed with
# statistics from that data; otherwise the source statistics are reused.
if(X_train_trg.shape[0] != 0):
  X_train_trg, trg_mean = mean_center(X_train_trg)
  X_train_trg, trg_min, trg_ptp = normalize(X_train_trg)

  X_test_trg, _    = mean_center(X_test_trg, trg_mean)
  X_test_trg, _, _ = normalize(X_test_trg, trg_min, trg_ptp)  
else:
  X_test_trg, _    = mean_center(X_test_trg, src_mean)
  X_test_trg, _, _ = normalize(X_test_trg, src_min, src_ptp)
  
# Cast inputs to float32 and one-hot labels to uint8.
X_train_src = X_train_src.astype(np.float32)
y_train_src = y_train_src.astype(np.uint8)
X_test_src  = X_test_src.astype(np.float32)
y_test_src  = y_test_src.astype(np.uint8)
X_train_trg = X_train_trg.astype(np.float32)
y_train_trg = y_train_trg.astype(np.uint8)
X_test_trg  = X_test_trg.astype(np.float32)
y_test_trg  = y_test_trg.astype(np.uint8)
print("Final shapes: ")
print(X_train_src.shape, y_train_src.shape,  X_test_src.shape, y_test_src.shape, X_train_trg.shape, y_train_trg.shape, X_test_trg.shape, y_test_trg.shape)

# Other-environment recordings. get_trg_data returns four values, unpacked
# here as (X_train, y_train, X_test, y_test). For the office file the first
# two are discarded and 0 is passed as the last argument.
# NOTE(review): whether get_trg_data applies the same centering and
# normalisation as above cannot be seen from this notebook -- confirm.
X_train_conf,   y_train_conf,   X_test_conf,   y_test_conf   = get_trg_data(os.path.join(dataset_path, 'target_conf_data.h5'),   classes, train_trg_env_days)
X_train_server, y_train_server, X_test_server, y_test_server = get_trg_data(os.path.join(dataset_path, 'target_server_data.h5'), classes, train_trg_env_days)
_             , _             , X_data_office, y_data_office = get_trg_data(os.path.join(dataset_path, 'target_office_data.h5'), classes, 0)

print(X_train_conf.shape,   y_train_conf.shape,    X_test_conf.shape,   y_test_conf.shape)
print(X_train_server.shape, y_train_server.shape,  X_test_server.shape, y_test_server.shape)
print(X_data_office.shape,  y_data_office.shape)

# Wrap every numpy split in a tf.data pipeline.

def _as_eval_set(features, labels):
  """Build an unshuffled, batched evaluation pipeline over (features, labels).

  The final partial batch is kept so every sample is evaluated.
  """
  pipeline = tf.data.Dataset.from_tensor_slices((features, labels))
  pipeline = pipeline.batch(batch_size, drop_remainder=False)
  return pipeline.prefetch(batch_size)


# Evaluation sets: one per environment, plus the held-out source split and
# the later-days (temporal) split.
conf_test_set   = _as_eval_set(X_test_conf, y_test_conf)
server_test_set = _as_eval_set(X_test_server, y_test_server)
office_test_set = _as_eval_set(X_data_office, y_data_office)
src_test_set    = _as_eval_set(X_test_src, y_test_src)
time_test_set   = _as_eval_set(X_test_trg, y_test_trg)

# Training set: shuffle buffer spans the whole split; the last partial batch
# is dropped so every training batch has exactly batch_size samples.
src_train_set = (
    tf.data.Dataset.from_tensor_slices((X_train_src, y_train_src))
    .shuffle(X_train_src.shape[0])
    .batch(batch_size, drop_remainder=True)
    .prefetch(batch_size)
)

In [None]:
def virtual_adversarial_images(images, logits, pert_norm_radius=3.5):
  """Return `images` shifted along the direction that most increases the loss.

  A tiny random perturbation is added to `images`, the global `model` is run
  on the result (with training=False), and the gradient of the cross-entropy
  loss with respect to that perturbation gives the adversarial direction.

  Args:
    images: batch of model inputs; axis 0 is the batch axis.
    logits: target distribution used as `labels` in the cross-entropy. Despite
      the name, the caller in this notebook passes the one-hot class labels.
    pert_norm_radius: L2 norm of the perturbation added to each image.

  Returns:
    A tensor shaped like `images`, with gradients stopped.
  """
  # Normalise over every non-batch axis, whatever the input rank is.
  reduce_axes = list(range(1, len(images.shape)))

  # Only the noise needs a gradient here, so the model variables are not watched.
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    # Get normalised noise matrix
    noise = tf.random.normal(shape=tf.shape(images))
    noise = 1e-6 * tf.nn.l2_normalize(noise, axis=reduce_axes)

    # A GradientTape only tracks trainable variables automatically. `noise` is
    # a plain tensor, so without this call the gradient below is unconnected
    # and the returned images would be identical to the inputs.
    tape.watch(noise)

    # Add noise to image and get new logits
    noise_logits, _, _ = model(images + noise,
                               tf.constant(False, dtype=tf.bool))

    # Get loss from noisy logits
    noise_loss = tf.nn.softmax_cross_entropy_with_logits(labels=logits, logits=noise_logits)
    noise_loss = tf.reduce_mean(noise_loss)

  # Based on perturbed image loss, get direction of greatest error
  adversarial_noise = tape.gradient(noise_loss, noise)
  adversarial_noise = tf.nn.l2_normalize(adversarial_noise, axis=reduce_axes)

  # return images with adversarial perturbation
  return tf.stop_gradient(images + pert_norm_radius * adversarial_noise)

  
def get_cross_entropy_loss(labels, logits):
  """Softmax cross-entropy between `labels` and `logits`, averaged over all elements of the result."""
  per_example = tf.nn.softmax_cross_entropy_with_logits(labels=labels,
                                                        logits=logits)
  return tf.reduce_mean(per_example)
  
class Discriminator(tf.keras.Model):
  """Stack of Dense layers followed by a linear output layer.

  Args:
    num_hidden: iterable of hidden-layer widths, applied in order.
    num_classes: width of the final (un-activated) output layer.
    activation: activation used by every hidden layer.
  """

  def __init__(self, num_hidden, num_classes, activation='relu'):
    super().__init__(name='discriminator')
    self.hidden_layers = [
      tf.keras.layers.Dense(units, activation=activation)
      for units in num_hidden
    ]
    self.logits = tf.keras.layers.Dense(num_classes, activation=None)

  def call(self, x):
    """Run `x` through the hidden layers and return the raw output logits."""
    features = x
    for dense in self.hidden_layers:
      features = dense(features)
    return self.logits(features)
  
"""Instantiates the ResNet50 architecture with discriminator and GRL layer.

Args:
  num_classes: `int` number of classes for image classification.

Returns:
    A Keras model instance.
"""
class RotResNet50(ResNet50):
  """ResNet50 with an extra head that predicts one of four rotations.

  The backbone layers (conv1, bn1, act1, max_pool1, blocks, avg_pool, fc1,
  logits) and `self.activation` are not created here; presumably the
  ResNet50 base class provides them -- confirm against resnet.py.

  Args:
    num_classes: forwarded to ResNet50.
    num_features: forwarded to ResNet50.
    num_hidden: hidden-layer widths of the rotation head.
    activation: forwarded to ResNet50.
  """

  def __init__(self, num_classes, num_features, num_hidden, activation='relu'):
    super().__init__(num_classes, num_features, activation)
    # Four outputs: one per rotation class produced by random_rotate.
    self.rot_classifier = Discriminator(num_hidden, 4, activation=self.activation)

  def call(self, img_input, training=False):
    """Return (class logits, fc1 features, rotation logits) for `img_input`."""
    stem = self.conv1(img_input)
    stem = self.bn1(stem, training=training)
    stem = self.act1(stem)
    features = self.max_pool1(stem)

    for residual_block in self.blocks:
      features = residual_block(features, training=training)

    embedding = self.fc1(self.avg_pool(features))

    # Both heads read the same fc1 embedding.
    class_logits = self.logits(embedding)
    rotation_logits = self.rot_classifier(embedding)
    return class_logits, embedding, rotation_logits
  
def random_rotate(x):
  """Rotate each image in the batch by a random multiple of 90 degrees.

  Args:
    x: batch of float32 images; axis 0 is the batch axis.

  Returns:
    (rotated images, one-hot rotation labels with 4 classes), both with
    gradients stopped.
  """
  num_images = tf.shape(x)[0]
  quarter_turns = tf.random.uniform([num_images],
                                    minval=0,
                                    maxval=4,
                                    dtype=tf.dtypes.int32)

  def _rotate_one(image_and_turns):
    # map_fn hands over one (image, k) pair per batch element.
    image, turns = image_and_turns[0], image_and_turns[1]
    return tf.image.rot90(image, k=turns)

  rotated = tf.map_fn(_rotate_one, [x, quarter_turns], dtype=tf.float32)
  rotation_labels = tf.one_hot(quarter_turns, 4)
  return tf.stop_gradient(rotated), tf.stop_gradient(rotation_labels)

In [None]:
# Running means of each loss term, accumulated across the batches of one epoch.
(cross_entropy_loss,
 source_vat_loss,
 source_cutmix_loss,
 center_loss,
 source_rot_loss) = [
  tf.keras.metrics.Mean(name=metric_name)
  for metric_name in ('cross_entropy_loss',
                      'source_vat_loss',
                      'source_cutmix_loss',
                      'center_loss',
                      'source_rot_loss')
]

# Accuracy on the training batches and on each evaluation set.
(source_train_acc,
 source_test_acc,
 office_test_acc,
 server_test_acc,
 temporal_test_acc,
 conference_test_acc) = [
  tf.keras.metrics.CategoricalAccuracy(name=metric_name)
  for metric_name in ('source_train_acc',
                      'source_test_acc',
                      'office_test_acc',
                      'server_test_acc',
                      'temporal_test_acc',
                      'conference_test_acc')
]

@tf.function
def test_step(images):
  """Run the global `model` in inference mode and return class probabilities."""
  # The model returns (class logits, features, rotation logits); only the
  # class logits are needed for evaluation.
  class_logits = model(images, training=False)[0]
  return tf.nn.softmax(class_logits)

@tf.function
def train_step(src_images, src_labels):
  """Run one optimisation step on a batch of source data.

  Four forward passes are made through the global `model` (rotated, clean,
  adversarial and CutMix inputs). The total loss is the sum of the clean
  cross-entropy, CutMix, adversarial, center and rotation-prediction terms,
  with the rotation term weighted by `rot_lambda`. The running loss and
  accuracy metrics are updated as a side effect.

  Args:
    src_images: batch of source inputs.
    src_labels: one-hot class labels for `src_images`.
  """
  with tf.GradientTape() as tape:
    # Forward passes
    rotated_images, rotation_targets = random_rotate(src_images)
    rotation_logits = model(rotated_images, training=True)[2]

    clean_logits, clean_features, _ = model(src_images, training=True)

    adversarial_images = virtual_adversarial_images(src_images, src_labels)
    adversarial_logits = model(adversarial_images, training=True)[0]

    mixed_images, mixed_labels = cutmix(src_images, src_labels)
    mixed_logits = model(mixed_images, training=True)[0]

    # Loss terms
    center_term = center_loss_obj.get_center_loss(clean_features, src_labels)
    ce_term = get_cross_entropy_loss(labels=src_labels, logits=clean_logits)
    vat_term = get_cross_entropy_loss(labels=src_labels, logits=adversarial_logits)
    cutmix_term = get_cross_entropy_loss(labels=mixed_labels, logits=mixed_logits)
    rotation_term = get_cross_entropy_loss(labels=rotation_targets, logits=rotation_logits)

    total_loss = (ce_term
                  + 1 * cutmix_term
                  + 1 * vat_term
                  + 1 * center_term
                  + rot_lambda * rotation_term)

  trainable = model.trainable_variables
  optimizer.apply_gradients(zip(tape.gradient(total_loss, trainable), trainable))

  # Update the running metrics for this epoch.
  source_train_acc(src_labels, tf.nn.softmax(clean_logits))
  cross_entropy_loss(ce_term)
  source_vat_loss(vat_term)
  source_cutmix_loss(cutmix_term)
  center_loss(center_term)
  source_rot_loss(rotation_term)
  
# Learning rate decays polynomially from init_lr to 1% of init_lr over 200
# passes through the source training set, then the schedule cycles.
learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=init_lr,
    decay_steps=(X_train_src.shape[0]//batch_size)*200,
    end_learning_rate=init_lr*1e-2,
    cycle=True,
)

model = RotResNet50(num_classes=num_classes,
                    num_features=num_features,
                    num_hidden=num_hidden,
                    activation=activation_fn)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
center_loss_obj = CenterLoss(batch_size, num_classes, num_features, alpha)

# TensorBoard writer for the per-epoch scalars.
summary_writer = tf.summary.create_file_writer(log_dir)

# Checkpoints cover the model and optimizer state; the five most recent are kept.
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(checkpoint=ckpt,
                                          directory=checkpoint_path,
                                          max_to_keep=5)

In [None]:
# Main training loop: one pass over the source training set per epoch, then
# evaluation on every test set, TensorBoard logging, periodic checkpointing,
# and a reset of all running metrics.

# (evaluation dataset, accuracy metric) pairs, evaluated in this order.
eval_plan = (
  (time_test_set,   temporal_test_acc),
  (src_test_set,    source_test_acc),
  (office_test_set, office_test_acc),
  (server_test_set, server_test_acc),
  (conf_test_set,   conference_test_acc),
)

# (summary tag, metric) pairs, written in this order and reset every epoch.
logged_metrics = (
  ("cross_entropy_loss",  cross_entropy_loss),
  ("temporal_test_acc",   temporal_test_acc),
  ("source_train_acc",    source_train_acc),
  ("source_test_acc",     source_test_acc),
  ("office_test_acc",     office_test_acc),
  ("server_test_acc",     server_test_acc),
  ("conference_test_acc", conference_test_acc),
  ("center_loss",         center_loss),
  ("source_cutmix_loss",  source_cutmix_loss),
  ("source_vat_loss",     source_vat_loss),
  ("source_rot_loss",     source_rot_loss),
)

for epoch in range(epochs):
  for src_images, src_labels in src_train_set:
    train_step(src_images, src_labels)

  for eval_set, acc_metric in eval_plan:
    for images, labels in eval_set:
      # Keras metrics take (y_true, y_pred): labels first, predictions
      # second, matching the call in train_step.
      acc_metric(labels, test_step(images))

  with summary_writer.as_default():
    for tag, metric in logged_metrics:
      tf.summary.scalar(tag, metric.result(), step=epoch)

  if (epoch + 1) % 25 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  # Start every epoch with fresh running statistics.
  for _, metric in logged_metrics:
    metric.reset_states()