In [1]:
# Absolute path to the project checkout; the data, log and checkpoint paths
# used later in the notebook are all built from it.
repo_path = "/home/kjakkala/mmwave"

import os
# Restrict the visible CUDA devices to device 1. This is set before TensorFlow
# is imported further down.
os.environ['CUDA_VISIBLE_DEVICES']='1'

import sys
# Make the repo's `models` directory importable so the two local imports
# below (`utils`, `resnet`) can resolve.
sys.path.append(os.path.join(repo_path, 'models'))

# NOTE(review): this wildcard import is presumably where the helpers used later
# (get_h5dataset, resize_data, balance_dataset, mean_center, normalize,
# get_trg_data) and names such as `np` / `train_test_split` come from --
# confirm against models/utils.py before replacing it with explicit imports.
from utils import *
from resnet import ResNet50

import tensorflow as tf
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Open a TF1-compat interactive session whose GPU options have `allow_growth`
# switched on.
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

print(tf.__version__)

2.0.0


In [2]:
# Directory holding the .h5 datasets.
dataset_path = os.path.join(repo_path, 'data')

# Experiment settings. Every value below is also embedded in the run name.
num_classes = 9
batch_size = 64
train_src_days = 3
train_trg_days = 0
train_trg_env_days = 0
epochs = 250
init_lr = 0.0001
num_features = 256
activation_fn = 'selu'
notes = "vanilla_baseline"

# Run identifier: one `{}` slot per setting, filled in the order of
# `_run_name_fields`.
_run_name_template = "classes-{}_bs-{}_train_src_days-{}_train_trg_days-{}_train_trgenv_days-{}_initlr-{}_num_feat-{}_act_fn-{}_{}"
_run_name_fields = (
  num_classes,
  batch_size,
  train_src_days,
  train_trg_days,
  train_trg_env_days,
  init_lr,
  num_features,
  activation_fn,
  notes,
)
log_data = _run_name_template.format(*_run_name_fields)

# TensorBoard logs and checkpoints each get a per-run directory under the repo.
_log_subdir = 'logs/new_logs/baselines/{}'.format(log_data)
_checkpoint_subdir = 'checkpoints/{}'.format(log_data)
log_dir = os.path.join(repo_path, _log_subdir)
checkpoint_path = os.path.join(repo_path, _checkpoint_subdir)

In [3]:
# Load the source recordings and resize them.
source_h5_path = os.path.join(dataset_path, 'source_data.h5')
X_data, y_data, classes = get_h5dataset(source_h5_path)
X_data = resize_data(X_data)
print(X_data.shape, y_data.shape, "\n", classes)

X_data, y_data = balance_dataset(
  X_data,
  y_data,
  num_days=10,
  num_classes=len(classes),
  max_samples_per_class=95,
)
print(X_data.shape, y_data.shape)

# Remove harika's data (incomplete data): every row whose class label
# (column 0 of y_data) equals 1.
rows_to_drop = np.flatnonzero(y_data[:, 0] == 1)
X_data = np.delete(X_data, rows_to_drop, axis=0)
y_data = np.delete(y_data, rows_to_drop, axis=0)

# Update labels to handle 9 classes instead of 10: class ids above the removed
# one shift down by one, and the matching name is dropped from `classes`.
rows_to_shift = y_data[:, 0] >= 2
y_data[rows_to_shift, 0] -= 1
classes.pop(1)
print(X_data.shape, y_data.shape, "\n", classes)

# Split days of data into train and test. Column 1 of y_data is compared
# against the day thresholds; column 0 holds the class id.

# Source set: rows with day < train_src_days, one-hot labels, then a
# stratified 90/10 split.
source_rows = y_data[:, 1] < train_src_days
X_src = X_data[source_rows]
y_src = np.eye(len(classes))[y_data[source_rows, 0]]
X_train_src, X_test_src, y_train_src, y_test_src = train_test_split(
  X_src,
  y_src,
  stratify=y_src,
  test_size=0.10,
  random_state=42,
)

# Temporal target set: rows with day >= train_src_days. Its training part is
# the first `train_trg_days` of those days.
target_rows = y_data[:, 1] >= train_src_days
X_trg = X_data[target_rows]
y_trg = y_data[target_rows]
target_train_rows = y_trg[:, 1] < train_src_days+train_trg_days
X_train_trg = X_trg[target_train_rows]
y_train_trg = np.eye(len(classes))[y_trg[target_train_rows, 0]]

# Temporal test set: every row on or after the last training day.
target_test_rows = y_data[:, 1] >= train_src_days+train_trg_days
X_test_trg = X_data[target_test_rows]
y_test_trg = np.eye(len(classes))[y_data[target_test_rows, 0]]

# The intermediate arrays are no longer needed.
del X_src, y_src, X_trg, y_trg, X_data, y_data

# Mean center and normalize the dataset.
# The source-train statistics are computed first and reused for source-test.
X_train_src, src_mean = mean_center(X_train_src)
X_train_src, src_min, src_ptp = normalize(X_train_src)

X_test_src, _ = mean_center(X_test_src, src_mean)
X_test_src, _, _ = normalize(X_test_src, src_min, src_ptp)

# The temporal test set uses target-train statistics when target training
# samples exist, and falls back to the source statistics otherwise.
if X_train_trg.shape[0] != 0:
  X_train_trg, trg_mean = mean_center(X_train_trg)
  X_train_trg, trg_min, trg_ptp = normalize(X_train_trg)
  test_trg_mean, test_trg_min, test_trg_ptp = trg_mean, trg_min, trg_ptp
else:
  test_trg_mean, test_trg_min, test_trg_ptp = src_mean, src_min, src_ptp

X_test_trg, _ = mean_center(X_test_trg, test_trg_mean)
X_test_trg, _, _ = normalize(X_test_trg, test_trg_min, test_trg_ptp)

# Cast inputs to float32 and one-hot labels to uint8, one array at a time.
X_train_src = X_train_src.astype(np.float32)
X_test_src = X_test_src.astype(np.float32)
X_train_trg = X_train_trg.astype(np.float32)
X_test_trg = X_test_trg.astype(np.float32)

y_train_src = y_train_src.astype(np.uint8)
y_test_src = y_test_src.astype(np.uint8)
y_train_trg = y_train_trg.astype(np.uint8)
y_test_trg = y_test_trg.astype(np.uint8)

print("Final shapes: ")
print(*(array.shape for array in (X_train_src, y_train_src,
                                  X_test_src, y_test_src,
                                  X_train_trg, y_train_trg,
                                  X_test_trg, y_test_trg)))

# Load the three target-environment files. The conference and server files are
# requested with `train_trg_env_days`; the office file is requested with 0 and
# only its last two return values are kept.
conf_h5_path = os.path.join(dataset_path, 'target_conf_data.h5')
server_h5_path = os.path.join(dataset_path, 'target_server_data.h5')
office_h5_path = os.path.join(dataset_path, 'target_office_data.h5')

X_train_conf, y_train_conf, X_test_conf, y_test_conf = get_trg_data(
  conf_h5_path, classes, train_trg_env_days)
X_train_server, y_train_server, X_test_server, y_test_server = get_trg_data(
  server_h5_path, classes, train_trg_env_days)
_, _, X_data_office, y_data_office = get_trg_data(
  office_h5_path, classes, 0)

print(X_train_conf.shape, y_train_conf.shape, X_test_conf.shape, y_test_conf.shape)
print(X_train_server.shape, y_train_server.shape, X_test_server.shape, y_test_server.shape)
print(X_data_office.shape, y_data_office.shape)

# Get tf.data objects for each set.

def _make_dataset(features, labels, training=False):
  """Build a batched, prefetched tf.data pipeline over in-memory arrays.

  Args:
    features: array of inputs, sliced along its first axis.
    labels: array of labels, sliced along its first axis.
    training: when True, shuffle with a buffer covering the whole set and drop
      the final partial batch; when False, keep the order and every sample.

  Returns:
    A tf.data.Dataset yielding (features, labels) batches of `batch_size`.
  """
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  if training:
    dataset = dataset.shuffle(features.shape[0])
  dataset = dataset.batch(batch_size, drop_remainder=training)
  # prefetch() counts elements of the dataset it is applied to, which are whole
  # batches here. The previous prefetch(batch_size) therefore allowed up to
  # `batch_size` batches to be buffered per pipeline; let tf.data size the
  # buffer instead.
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)

#Test
conf_test_set   = _make_dataset(X_test_conf,   y_test_conf)
server_test_set = _make_dataset(X_test_server, y_test_server)
office_test_set = _make_dataset(X_data_office, y_data_office)
src_test_set    = _make_dataset(X_test_src,    y_test_src)
time_test_set   = _make_dataset(X_test_trg,    y_test_trg)

#Train
src_train_set = _make_dataset(X_train_src, y_train_src, training=True)

(9127, 256, 256, 1) (9127, 2) 
 ['arahman3', 'harika', 'hchen32', 'jlaivins', 'kjakkala', 'pjanakar', 'ppinyoan', 'pwang13', 'upattnai', 'wrang']
(8737, 256, 256, 1) (8737, 2)
(8547, 256, 256, 1) (8547, 2) 
 ['arahman3', 'hchen32', 'jlaivins', 'kjakkala', 'pjanakar', 'ppinyoan', 'pwang13', 'upattnai', 'wrang']
Final shapes: 
(2308, 256, 256, 1) (2308, 9) (257, 256, 256, 1) (257, 9) (0, 256, 256, 1) (0, 9) (5982, 256, 256, 1) (5982, 9)
(0, 256, 256, 1) (0,) (1350, 256, 256, 1) (1350, 9)
(0, 256, 256, 1) (0,) (1346, 256, 256, 1) (1346, 9)
(899, 256, 256, 1) (899, 9)


In [4]:
# Running mean of the per-batch training loss.
cross_entropy_loss = tf.keras.metrics.Mean(name='cross_entropy_loss')

# One categorical-accuracy tracker per split, created in the listed order and
# named after the variable it is bound to.
(source_train_acc,
 source_test_acc,
 office_test_acc,
 server_test_acc,
 temporal_test_acc,
 conference_test_acc) = (
  tf.keras.metrics.CategoricalAccuracy(name=metric_name)
  for metric_name in ('source_train_acc',
                      'source_test_acc',
                      'office_test_acc',
                      'server_test_acc',
                      'temporal_test_acc',
                      'conference_test_acc')
)

def get_cross_entropy_loss(labels, logits):
  """Return the mean softmax cross-entropy of `logits` against `labels`.

  The loss is computed with tf.nn.softmax_cross_entropy_with_logits and then
  averaged over all of its elements into a scalar.
  """
  return tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))

@tf.function
def test_step(images):
  """Run the global `model` in inference mode and return softmax probabilities.

  The model is called with training=False and returns two values; only the
  first (the logits) is used.
  """
  class_logits, _ = model(images, training=False)
  probabilities = tf.nn.softmax(class_logits)
  return probabilities

@tf.function
def train_step(src_images, src_labels):
  """Apply one optimizer update on a source batch and record training metrics.

  Uses the global `model` and `optimizer`. The batch loss is added to
  `cross_entropy_loss` and the batch accuracy to `source_train_acc`.
  """
  # Forward pass under the tape so the loss can be differentiated.
  with tf.GradientTape() as tape:
    src_logits, _ = model(src_images, training=True)
    batch_loss = get_cross_entropy_loss(labels=src_labels, logits=src_logits)

  trainable = model.trainable_variables
  grads = tape.gradient(batch_loss, trainable)
  optimizer.apply_gradients(zip(grads, trainable))

  # Update the running training metrics with this batch.
  src_probabilities = tf.nn.softmax(src_logits)
  source_train_acc(src_labels, src_probabilities)
  cross_entropy_loss(batch_loss)

# Learning-rate schedule: polynomial decay from `init_lr` down to 1% of it
# over the first 5000 optimizer steps.
learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(
  init_lr,
  decay_steps=5000,
  end_learning_rate=init_lr*1e-2,
)

# Network and optimizer used by train_step / test_step.
model = ResNet50(num_classes, num_features, activation_fn)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# TensorBoard writer for this run.
summary_writer = tf.summary.create_file_writer(log_dir)

# Checkpoint both model and optimizer state; keep only the 5 newest checkpoints.
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(
  ckpt,
  checkpoint_path,
  max_to_keep=5,
)

# Evaluation splits, each paired with the accuracy metric that tracks it.
evaluation_runs = (
  (time_test_set,   temporal_test_acc),
  (src_test_set,    source_test_acc),
  (office_test_set, office_test_acc),
  (server_test_set, server_test_acc),
  (conf_test_set,   conference_test_acc),
)

# (TensorBoard tag, metric) pairs that are written and then reset every epoch.
epoch_metrics = (
  ("cross_entropy_loss",  cross_entropy_loss),
  ("temporal_test_acc",   temporal_test_acc),
  ("source_train_acc",    source_train_acc),
  ("source_test_acc",     source_test_acc),
  ("office_test_acc",     office_test_acc),
  ("server_test_acc",     server_test_acc),
  ("conference_test_acc", conference_test_acc),
)

for epoch in range(epochs):
  # One pass over the source training set; train_step also updates
  # cross_entropy_loss and source_train_acc.
  for src_images, src_labels in src_train_set:
    train_step(src_images, src_labels)

  # Keras metrics are called as metric(y_true, y_pred), so the labels go first,
  # matching the call made inside train_step. The previous code passed the
  # predictions first; the accuracy value was the same only because comparing
  # two argmaxes is symmetric.
  for test_set, accuracy_metric in evaluation_runs:
    for images, labels in test_set:
      accuracy_metric(labels, test_step(images))

  # Log this epoch's values.
  with summary_writer.as_default():
    for tag, metric in epoch_metrics:
      tf.summary.scalar(tag, metric.result(), step=epoch)

  # Save a checkpoint every 25 epochs.
  if (epoch + 1) % 25 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  # Clear the accumulators so the next epoch starts from zero.
  for _, metric in epoch_metrics:
    metric.reset_states()

Saving checkpoint for epoch 25 at /home/kjakkala/mmwave/checkpoints/classes-9_bs-64_train_src_days-3_train_trg_days-0_train_trgenv_days-0_initlr-0.0001_num_feat-256_act_fn-selu_vanilla_baseline_no-lr-decay/ckpt-1
Saving checkpoint for epoch 50 at /home/kjakkala/mmwave/checkpoints/classes-9_bs-64_train_src_days-3_train_trg_days-0_train_trgenv_days-0_initlr-0.0001_num_feat-256_act_fn-selu_vanilla_baseline_no-lr-decay/ckpt-2
Saving checkpoint for epoch 75 at /home/kjakkala/mmwave/checkpoints/classes-9_bs-64_train_src_days-3_train_trg_days-0_train_trgenv_days-0_initlr-0.0001_num_feat-256_act_fn-selu_vanilla_baseline_no-lr-decay/ckpt-3
Saving checkpoint for epoch 100 at /home/kjakkala/mmwave/checkpoints/classes-9_bs-64_train_src_days-3_train_trg_days-0_train_trgenv_days-0_initlr-0.0001_num_feat-256_act_fn-selu_vanilla_baseline_no-lr-decay/ckpt-4
Saving checkpoint for epoch 125 at /home/kjakkala/mmwave/checkpoints/classes-9_bs-64_train_src_days-3_train_trg_days-0_train_trgenv_days-0_initlr-0