In [1]:
# Root of the local checkout; the other paths in this notebook are built from it.
# NOTE(review): absolute, user-specific path — adjust when running on another machine.
repo_path = "/home/kjakkala/mmwave"

import os
# Restrict this process to GPU 0. Set here, before TensorFlow is imported below.
os.environ['CUDA_VISIBLE_DEVICES']='0'

import sys
# Add <repo>/models to the import path (presumably where `utils` lives — confirm).
sys.path.append(os.path.join(repo_path, 'models'))

# NOTE(review): star import. `np` is used later in this notebook without an
# explicit import, so it presumably comes from here — confirm.
from utils import *
import h5py
from tqdm import tqdm

import tensorflow as tf
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Create a TF1-compat interactive session with the `allow_growth` GPU option
# enabled, so GPU memory is requested as needed instead of up front.
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

print(tf.__version__)

2.0.0


In [2]:
# ---- Experiment configuration --------------------------------------------
# In the cells shown here, several of these values (the *_days counts, alpha,
# num_features, recon_lambda) are only used to build the log directory name.
dataset_path = os.path.join(repo_path, 'data')

# Data / optimisation settings
num_classes = 9
batch_size = 64
train_src_days = 3
train_trg_days = 0
train_trg_env_days = 2
epochs = 1000
init_lr = 0.0001

# Model settings
num_features = 256
alpha = 0.05
activation_fn = 'selu'
recon_lambda = 20
disc_hidden = [1024]              # hidden-layer widths of the discriminator
clas_hidden = [2048, 2048, 1024]  # hidden-layer widths of the classifier

# Free-form run description appended to the log directory name.
notes = "AT-CMT-Center-Recon-{}-[256,256,256]-TAT-disc-{}-clas-{}-conf_adapt-cond-1e-2".format(
    recon_lambda, disc_hidden, clas_hidden)

# Run identifier that encodes the settings above.
log_data = "classes-{}_bs-{}_train_src_days-{}_train_trg_days-{}_train_trgenv_days-{}_alpha-{}_initlr-{}_num_feat-{}_act_fn-{}_{}".format(
    num_classes,
    batch_size,
    train_src_days,
    train_trg_days,
    train_trg_env_days,
    alpha,
    init_lr,
    num_features,
    activation_fn,
    notes,
)

# Output location for TensorBoard summaries and input location of the encodings.
log_dir = os.path.join(repo_path, 'logs/new_logs/TAT/{}'.format(log_data))
encodings_file = os.path.join(repo_path, 'data/encodings_conf.h5')

In [3]:
# Read every split stored in the encodings file into in-memory NumPy arrays.
# The context manager guarantees the HDF5 handle is closed even if one of the
# reads raises; the arrays are materialised inside the block because the
# datasets cannot be read once the file is closed.
# NOTE(review): `hf.get` is used rather than `hf[...]`, so a missing key does
# not raise here — presumably it yields an empty/None-backed array; confirm.
with h5py.File(encodings_file, 'r') as hf:
  # Source-domain and target (temporal) splits
  X_train_src = np.array(hf.get('X_train_src'))
  y_train_src = np.array(hf.get('y_train_src'))
  X_test_src  = np.array(hf.get('X_test_src'))
  y_test_src  = np.array(hf.get('y_test_src'))
  X_train_trg = np.array(hf.get('X_train_trg'))
  y_train_trg = np.array(hf.get('y_train_trg'))
  X_test_trg  = np.array(hf.get('X_test_trg'))
  y_test_trg  = np.array(hf.get('y_test_trg'))

  # Conference-room splits
  X_train_conf = np.array(hf.get('X_train_conf'))
  y_train_conf = np.array(hf.get('y_train_conf'))
  X_test_conf  = np.array(hf.get('X_test_conf'))
  y_test_conf  = np.array(hf.get('y_test_conf'))

  # Server-room splits
  X_train_server = np.array(hf.get('X_train_server'))
  y_train_server = np.array(hf.get('y_train_server'))
  X_test_server  = np.array(hf.get('X_test_server'))
  y_test_server  = np.array(hf.get('y_test_server'))

  # Office data (a single set, no train/test split in the file)
  X_data_office  = np.array(hf.get('X_data_office'))
  y_data_office  = np.array(hf.get('y_data_office'))

# Sanity check: print the shape of every array that was loaded.
print("Final shapes: ")
print(X_train_src.shape, y_train_src.shape,  X_test_src.shape, y_test_src.shape, X_train_trg.shape, y_train_trg.shape, X_test_trg.shape, y_test_trg.shape)
print(X_train_conf.shape,   y_train_conf.shape,    X_test_conf.shape,   y_test_conf.shape)
print(X_train_server.shape, y_train_server.shape,  X_test_server.shape, y_test_server.shape)
print(X_data_office.shape,  y_data_office.shape)

Final shapes: 
(2308, 256) (2308, 9) (257, 256) (257, 9) (0,) (0, 9) (5982, 256) (5982, 9)
(900, 256) (900, 9) (450, 256) (450, 9)
(0,) (0,) (1346, 256) (1346, 9)
(899, 256) (899, 9)


In [4]:
#get tf.data objects for each set

def _make_dataset(features, labels, training=False, repeat=False):
  """Build a batched, prefetched tf.data pipeline over (features, labels).

  Args:
    features: array of inputs; sliced along its first axis.
    labels: array of targets aligned with `features`.
    training: when True, shuffle with a buffer covering the whole array and
      drop the final partial batch; when False, keep order and every sample.
    repeat: when True, repeat the pipeline indefinitely.

  Returns:
    A `tf.data.Dataset` yielding `(features, labels)` batches of `batch_size`.
  """
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  if training:
    dataset = dataset.shuffle(features.shape[0])
  dataset = dataset.batch(batch_size, drop_remainder=training)
  # NOTE(review): prefetch is applied after batching, so this buffers up to
  # `batch_size` batches, not samples — kept as in the original; confirm intent.
  dataset = dataset.prefetch(batch_size)
  if repeat:
    dataset = dataset.repeat(-1)
  return dataset

#Test
conf_test_set   = _make_dataset(X_test_conf, y_test_conf)
server_test_set = _make_dataset(X_test_server, y_test_server)
office_test_set = _make_dataset(X_data_office, y_data_office)
src_test_set    = _make_dataset(X_test_src, y_test_src)
time_test_set   = _make_dataset(X_test_trg, y_test_trg)

#Train
src_train_set  = _make_dataset(X_train_src, y_train_src, training=True)
# The training loop zips this set with `src_train_set`; repeating it forever
# means the length of an epoch is set by the source set alone.
conf_train_set = _make_dataset(X_train_conf, y_train_conf, training=True, repeat=True)

In [5]:
class Discriminator(tf.keras.Model):
  """Stack of Dense layers followed by a linear layer producing logits.

  Args:
    num_hidden: iterable of hidden-layer widths; one Dense layer is created
      per entry, in order.
    num_classes: number of units in the final, activation-free Dense layer.
    activation: activation used by every hidden Dense layer.
  """
  def __init__(self, num_hidden, num_classes, activation='relu'):
    super().__init__(name='discriminator')
    self.hidden_layers = [
        tf.keras.layers.Dense(units, activation=activation)
        for units in num_hidden
    ]
    self.logits = tf.keras.layers.Dense(num_classes, activation=None)

  def call(self, x):
    """Pass `x` through the hidden layers in order and return the raw logits."""
    hidden = x
    for dense in self.hidden_layers:
      hidden = dense(hidden)
    return self.logits(hidden)
  
class Classifier(tf.keras.Model):
  """Multi-layer perceptron that maps an input to `num_classes` logits.

  Args:
    num_hidden: iterable of hidden-layer widths, applied in the given order.
    num_classes: size of the output layer (no activation is applied to it).
    activation: activation for the hidden layers.
  """
  def __init__(self, num_hidden, num_classes, activation='relu'):
    super().__init__(name='classifier')
    self.hidden_layers = [tf.keras.layers.Dense(width, activation=activation)
                          for width in num_hidden]
    self.logits = tf.keras.layers.Dense(num_classes, activation=None)

  def call(self, x):
    """Return un-normalised class scores for `x`."""
    features = x
    for hidden_layer in self.hidden_layers:
      features = hidden_layer(features)
    scores = self.logits(features)
    return scores
  
def get_cross_entropy_loss(labels, logits):
  """Mean softmax cross-entropy between `labels` and un-normalised `logits`.

  Args:
    labels: target distribution per example (passed as `labels` to
      `tf.nn.softmax_cross_entropy_with_logits`).
    logits: raw scores with the same shape as `labels`.

  Returns:
    Scalar tensor: the per-example losses averaged over all examples.
  """
  return tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))

In [6]:
# Running means of the per-batch losses, updated inside the train steps.
domain_loss, confusion_loss, cross_entropy_loss = (
    tf.keras.metrics.Mean(name=metric_name)
    for metric_name in ('domain_loss', 'confusion_loss', 'cross_entropy_loss')
)

# One categorical-accuracy tracker per evaluated split.
(temporal_test_acc,
 source_train_acc,
 source_test_acc,
 office_test_acc,
 server_test_acc,
 conference_train_acc,
 conference_test_acc) = (
    tf.keras.metrics.CategoricalAccuracy(name=metric_name)
    for metric_name in ('temporal_test_acc',
                        'source_train_acc',
                        'source_test_acc',
                        'office_test_acc',
                        'server_test_acc',
                        'conference_train_acc',
                        'conference_test_acc')
)

@tf.function
def test_step(images):
  """Return the classifier's softmax probabilities for a batch.

  Args:
    images: batch of inputs fed to the module-level `classifier`
      (called with `training=False`).

  Returns:
    Tensor of class probabilities, one row per input.
  """
  return tf.nn.softmax(classifier(images, training=False))

@tf.function
def train_clas_step(src_enc, src_labels, con_enc, con_labels):
  """Run one optimisation step on the classifier.

  The minimised loss is the sum of:
    * supervised cross-entropy on the source batch,
    * a domain loss on the discriminator outputs (source -> index 0,
      conference -> index 1),
    * 1e-2 times the cross-entropy of the conference predictions against
      themselves (a conditional-entropy term).

  NOTE(review): the discriminator is fed the input encodings directly and the
  gradient is taken only w.r.t. `classifier.trainable_variables`, so the
  domain term has no gradient path to the classifier here — it only affects
  the logged `domain_loss`. Confirm this is intended.

  Args:
    src_enc: batch of source-domain inputs.
    src_labels: labels for `src_enc`, used for the supervised loss and accuracy.
    con_enc: batch of conference-domain inputs.
    con_labels: labels for `con_enc`; used only to update the accuracy metric.
  """
  with tf.GradientTape() as tape:
    # Classifier predictions for both domains.
    src_logits = classifier(src_enc, training=True)
    con_logits = classifier(con_enc, training=True)

    # Discriminator outputs (inference mode; its weights are not updated here).
    src_disc_logits = discriminator(src_enc, training=False)
    con_disc_logits = discriminator(con_enc, training=False)

    # Domain targets: 0 for every source row, 1 for every conference row.
    domain_ids = tf.concat([tf.zeros(tf.shape(src_disc_logits)[0]),
                            tf.ones(tf.shape(con_disc_logits)[0])], 0)
    domain_targets = tf.one_hot(tf.cast(domain_ids, tf.uint8), 2)
    domain_logits = tf.concat([src_disc_logits, con_disc_logits], 0)

    # Individual loss terms.
    batch_cross_entropy_loss = get_cross_entropy_loss(labels=src_labels,
                                                      logits=src_logits)
    batch_domain_loss = get_cross_entropy_loss(labels=domain_targets,
                                               logits=domain_logits)
    batch_cond_entropy_loss = get_cross_entropy_loss(labels=tf.nn.softmax(con_logits),
                                                     logits=con_logits)

    total_loss = batch_cross_entropy_loss + batch_domain_loss + 1e-2*batch_cond_entropy_loss

  # Only the classifier's weights are updated.
  clas_vars = classifier.trainable_variables
  clas_optimizer.apply_gradients(zip(tape.gradient(total_loss, clas_vars), clas_vars))

  # Book-keeping for the epoch summaries.
  source_train_acc(src_labels, tf.nn.softmax(src_logits))
  conference_train_acc(con_labels, tf.nn.softmax(con_logits))
  cross_entropy_loss(batch_cross_entropy_loss)
  domain_loss(batch_domain_loss)

@tf.function
def train_disc_step(src_enc, con_enc):
  """Run one optimisation step on the discriminator.

  The discriminator is trained to separate the two domains with targets
  source -> index 1 and conference -> index 0 (the opposite assignment to the
  one used in `train_clas_step`). The resulting loss is accumulated in the
  `confusion_loss` metric.

  The original version also ran both batches through the classifier and
  discarded the results; those unused forward passes have been removed.

  Args:
    src_enc: batch of source-domain inputs.
    con_enc: batch of conference-domain inputs.
  """
  with tf.GradientTape() as tape:
    #Disc
    src_disc_logits = discriminator(src_enc, training=True)
    con_disc_logits = discriminator(con_enc, training=True)

    # Domain targets: 1 for every source row, 0 for every conference row.
    domain_ids = tf.concat([tf.ones(tf.shape(src_disc_logits)[0]),
                            tf.zeros(tf.shape(con_disc_logits)[0])], 0)
    batch_confusion_loss = get_cross_entropy_loss(labels=tf.one_hot(tf.cast(domain_ids, tf.uint8), 2),
                                                  logits=tf.concat([src_disc_logits,
                                                                    con_disc_logits], 0))

  # Only the discriminator's weights are updated.
  disc_gradients = tape.gradient(batch_confusion_loss, discriminator.trainable_variables)
  disc_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))

  confusion_loss(batch_confusion_loss)

In [7]:
# Learning-rate schedule: polynomial decay from init_lr down to init_lr*1e-2
# over 5000 steps, with cycle=True.
learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(
    init_lr,
    decay_steps=5000,
    end_learning_rate=init_lr*1e-2,
    cycle=True,
)

# Two-output domain discriminator and num_classes-output classifier.
discriminator = Discriminator(disc_hidden, 2, activation_fn)
classifier = Classifier(clas_hidden, num_classes, activation_fn)

# One Adam optimizer per network; both are given the same schedule object.
disc_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
clas_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

# TensorBoard writer targeting the run-specific log directory.
summary_writer = tf.summary.create_file_writer(log_dir)

In [8]:
# Evaluation splits paired with the accuracy metric each one feeds.
eval_sets = (
    (time_test_set,   temporal_test_acc),
    (src_test_set,    source_test_acc),
    (office_test_set, office_test_acc),
    (server_test_set, server_test_acc),
    (conf_test_set,   conference_test_acc),
)

# Every metric that is logged and then reset at the end of an epoch. Each
# metric was created with a `name` equal to its summary tag, so `metric.name`
# is used as the tag below.
tracked_metrics = (
    cross_entropy_loss,
    temporal_test_acc,
    source_train_acc,
    source_test_acc,
    office_test_acc,
    server_test_acc,
    conference_train_acc,
    conference_test_acc,
    domain_loss,
    confusion_loss,
)

for epoch in tqdm(range(epochs)):
  # Alternate one classifier update and one discriminator update per batch pair.
  for (src_enc, src_labels), (con_enc, con_labels) in zip(src_train_set, conf_train_set):
    train_clas_step(src_enc, src_labels, con_enc, con_labels)
    train_disc_step(src_enc, con_enc)

  # Evaluate on every held-out split. Keras metrics take (y_true, y_pred);
  # the original passed the two arguments in the opposite order.
  for eval_set, accuracy in eval_sets:
    for features, labels in eval_set:
      accuracy(labels, test_step(features))

  # Write one scalar per metric for this epoch.
  with summary_writer.as_default():
    for metric in tracked_metrics:
      tf.summary.scalar(metric.name, metric.result(), step=epoch)

  # Start the next epoch with fresh accumulators.
  for metric in tracked_metrics:
    metric.reset_states()

100%|██████████| 1000/1000 [06:00<00:00,  2.78it/s]
