In [1]:
# Root of the local checkout. The models directory appended to sys.path below
# is derived from it, as are the dataset/log/checkpoint paths built in the
# configuration cell.
repo_path = "/home/kjakkala/mmwave"

import os
# Expose only the GPU with index 1 to this process. This is set before
# TensorFlow is imported further down.
os.environ['CUDA_VISIBLE_DEVICES']='1'

import sys
# Make the project's `models` directory importable (utils, resnet, pix2pix).
sys.path.append(os.path.join(repo_path, 'models'))

# NOTE(review): later cells use `np`, `train_test_split`, `get_h5dataset`,
# `resize_data`, `balance_dataset`, `mean_center`, `normalize` and
# `get_trg_data` without importing them explicitly, so they presumably all
# arrive through this wildcard import -- confirm against models/utils.py.
from utils import *
from resnet import ResNet50
from pix2pix import UNetGenerator, PatchGanDiscriminator

import tensorflow as tf
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Open a TF1-compat interactive session whose GPU options enable
# `allow_growth`. `config` and `session` are not referenced again in the
# visible cells; presumably the session is kept only for this memory
# setting -- TODO confirm.
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

print(tf.__version__)

2.0.0


In [2]:
# ---- Experiment configuration ------------------------------------------------
# Every tunable of this run lives here; the values are also encoded into the
# run name (`log_data`) so logs and checkpoints of different runs never collide.
dataset_path       = os.path.join(repo_path, 'data')
num_classes        = 9        # classifier output size
batch_size         = 16
train_src_days     = 6        # source samples with day index < this are used for training
train_trg_days     = 0        # further days (after the source days) used as temporal training data
train_trg_env_days = 2        # passed to get_trg_data for the server environment
epochs             = 500
init_lr            = 0.0001   # initial learning rate of the generator/discriminator schedule
num_features       = 256      # passed to ResNet50
activation_fn      = 'selu'
notes              = "cyclegan_server_clas_lr_0.000001"

# Human-readable run identifier built from the settings above.
log_data = (
    "classes-{}_bs-{}_train_src_days-{}_train_trg_days-{}_train_trgenv_days-{}"
    "_initlr-{}_num_feat-{}_act_fn-{}_{}"
).format(num_classes,
         batch_size,
         train_src_days,
         train_trg_days,
         train_trg_env_days,
         init_lr,
         num_features,
         activation_fn,
         notes)

log_dir         = os.path.join(repo_path, 'logs/new_logs/CycleGAN/{}'.format(log_data))
checkpoint_path = os.path.join(repo_path, 'checkpoints/{}'.format(log_data))

# Checkpoint directory of the baseline classifier run that both classifiers
# are initialised from. Built from repo_path (instead of repeating the
# absolute repo root) so that moving the checkout only requires changing
# repo_path; with the current repo_path the resulting string is unchanged.
classifier_checkpoint_path = os.path.join(
    repo_path,
    "checkpoints",
    "classes-9_bs-64_train_src_days-6_train_trg_days-0_train_trgenv_days-0_initlr-0.0001_num_feat-256_act_fn-selu_vanilla_baseline")

In [3]:
def _eval_dataset(images, labels):
  """Wrap arrays in a tf.data pipeline for evaluation.

  Samples keep their stored order, are grouped into batches of `batch_size`
  (the last, smaller batch is kept) and prefetched.
  """
  dataset = tf.data.Dataset.from_tensor_slices((images, labels))
  dataset = dataset.batch(batch_size, drop_remainder=False)
  return dataset.prefetch(batch_size)


def _train_dataset(images, labels, repeat=False):
  """Wrap arrays in a tf.data pipeline for training.

  Samples are shuffled with a buffer covering the whole array, grouped into
  full batches of `batch_size` (a trailing partial batch is dropped) and
  prefetched. With `repeat=True` the dataset restarts indefinitely, which lets
  a small set be zipped against a larger one.
  """
  dataset = tf.data.Dataset.from_tensor_slices((images, labels))
  dataset = dataset.shuffle(images.shape[0])
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.prefetch(batch_size)
  if repeat:
    dataset = dataset.repeat(-1)
  return dataset


X_data, y_data, classes = get_h5dataset(os.path.join(dataset_path, 'source_data.h5'))
X_data = resize_data(X_data)
print(X_data.shape, y_data.shape, "\n", classes)

X_data, y_data = balance_dataset(X_data, y_data,
                                 num_days=10,
                                 num_classes=len(classes),
                                 max_samples_per_class=95)
print(X_data.shape, y_data.shape)

# Remove harika's data (incomplete data). y_data[:, 0] holds the class index
# and harika is class 1; the row indices are computed once and reused for
# both arrays.
incomplete_rows = np.where(y_data[:, 0] == 1)[0]
X_data = np.delete(X_data, incomplete_rows, 0)
y_data = np.delete(y_data, incomplete_rows, 0)

# Update labels to handle 9 classes instead of 10: shift every class index
# above the removed one down by one and drop its name.
y_data[y_data[:, 0] >= 2, 0] -= 1
del classes[1]
print(X_data.shape, y_data.shape, "\n", classes)

# Split days of data to train and test. y_data[:, 1] is the day index of each
# sample; class indices are turned into one-hot rows via an identity matrix.
sample_day = y_data[:, 1]
one_hot = np.eye(len(classes))

src_mask = sample_day < train_src_days
X_src = X_data[src_mask]
y_src = one_hot[y_data[src_mask, 0]]
X_train_src, X_test_src, y_train_src, y_test_src = train_test_split(X_src,
                                                                    y_src,
                                                                    stratify=y_src,
                                                                    test_size=0.10,
                                                                    random_state=42)

# Temporal target data: days after the source days. The first train_trg_days
# of them form the training part (empty when train_trg_days == 0), the rest
# is the test part.
trg_train_mask = (sample_day >= train_src_days) & (sample_day < train_src_days+train_trg_days)
X_train_trg = X_data[trg_train_mask]
y_train_trg = one_hot[y_data[trg_train_mask, 0]]

trg_test_mask = sample_day >= train_src_days+train_trg_days
X_test_trg = X_data[trg_test_mask]
y_test_trg = one_hot[y_data[trg_test_mask, 0]]

# Release the full arrays; sample_day is a view into y_data, so it has to go
# as well for the memory to be freed.
del X_src, y_src, X_data, y_data, sample_day, src_mask, trg_train_mask, trg_test_mask

# Mean center and normalize dataset. Statistics come from the training split
# and are re-applied to the matching test split.
X_train_src, src_mean = mean_center(X_train_src)
X_train_src, src_min, src_ptp = normalize(X_train_src)

X_test_src, _    = mean_center(X_test_src, src_mean)
X_test_src, _, _ = normalize(X_test_src, src_min, src_ptp)

if(X_train_trg.shape[0] != 0):
  X_train_trg, trg_mean = mean_center(X_train_trg)
  X_train_trg, trg_min, trg_ptp = normalize(X_train_trg)

  X_test_trg, _    = mean_center(X_test_trg, trg_mean)
  X_test_trg, _, _ = normalize(X_test_trg, trg_min, trg_ptp)
else:
  # No temporal training data: fall back to the source statistics.
  X_test_trg, _    = mean_center(X_test_trg, src_mean)
  X_test_trg, _, _ = normalize(X_test_trg, src_min, src_ptp)

X_train_src = X_train_src.astype(np.float32)
y_train_src = y_train_src.astype(np.uint8)
X_test_src  = X_test_src.astype(np.float32)
y_test_src  = y_test_src.astype(np.uint8)
X_train_trg = X_train_trg.astype(np.float32)
y_train_trg = y_train_trg.astype(np.uint8)
X_test_trg  = X_test_trg.astype(np.float32)
y_test_trg  = y_test_trg.astype(np.uint8)
print("Final shapes: ")
print("Source:", X_train_src.shape, y_train_src.shape,  X_test_src.shape, y_test_src.shape)
print("Time:", X_train_trg.shape, y_train_trg.shape, X_test_trg.shape, y_test_trg.shape)

# Other environments. The last argument of get_trg_data is passed as 0 for
# the conference and office sets and as train_trg_env_days for the server set.
X_train_conf,   y_train_conf,   X_test_conf,   y_test_conf   = get_trg_data(os.path.join(dataset_path, 'target_conf_data.h5'),   classes, 0)
X_train_server, y_train_server, X_test_server, y_test_server = get_trg_data(os.path.join(dataset_path, 'target_server_data.h5'), classes, train_trg_env_days)
_             , _             , X_data_office, y_data_office = get_trg_data(os.path.join(dataset_path, 'target_office_data.h5'), classes, 0)

print("Conf:",   X_train_conf.shape,   y_train_conf.shape,    X_test_conf.shape,   y_test_conf.shape)
print("Server",  X_train_server.shape, y_train_server.shape,  X_test_server.shape, y_test_server.shape)
print("Office:", X_data_office.shape,  y_data_office.shape)

# Get tf.data objects for each set.

# Test
conf_test_set   = _eval_dataset(X_test_conf, y_test_conf)
server_test_set = _eval_dataset(X_test_server, y_test_server)
office_test_set = _eval_dataset(X_data_office, y_data_office)
src_test_set    = _eval_dataset(X_test_src, y_test_src)
time_test_set   = _eval_dataset(X_test_trg, y_test_trg)

# Train. The server set repeats forever so it can be zipped against the
# source set; the length of an epoch is then set by the source set.
src_train_set    = _train_dataset(X_train_src, y_train_src)
server_train_set = _train_dataset(X_train_server, y_train_server, repeat=True)

(9127, 256, 256, 1) (9127, 2) 
 ['arahman3', 'harika', 'hchen32', 'jlaivins', 'kjakkala', 'pjanakar', 'ppinyoan', 'pwang13', 'upattnai', 'wrang']
(8737, 256, 256, 1) (8737, 2)
(8547, 256, 256, 1) (8547, 2) 
 ['arahman3', 'hchen32', 'jlaivins', 'kjakkala', 'pjanakar', 'ppinyoan', 'pwang13', 'upattnai', 'wrang']
Final shapes: 
Source: (4615, 256, 256, 1) (4615, 9) (513, 256, 256, 1) (513, 9)
Time: (0, 256, 256, 1) (0, 9) (3419, 256, 256, 1) (3419, 9)
Conf: (0, 256, 256, 1) (0,) (1350, 256, 256, 1) (1350, 9)
Server (898, 256, 256, 1) (898, 9) (448, 256, 256, 1) (448, 9)
Office: (899, 256, 256, 1) (899, 9)


In [4]:
# Weight of the L1 image-reconstruction terms; identity_loss below scales its
# mean absolute error by LAMBDA * 0.5.
LAMBDA = 10

# Shared adversarial criterion: binary cross-entropy applied to raw
# (un-squashed) discriminator outputs, hence from_logits=True.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):
  """Adversarial loss of a discriminator.

  Args:
    real: discriminator output for real images.
    generated: discriminator output for generated images.

  Returns:
    Half the sum of two `loss_obj` terms: `real` scored against an all-ones
    target and `generated` scored against an all-zeros target.
  """
  loss_on_real = loss_obj(tf.ones_like(real), real)
  loss_on_fake = loss_obj(tf.zeros_like(generated), generated)
  return (loss_on_real + loss_on_fake) * 0.5

def generator_loss(generated):
  """Adversarial loss of a generator.

  Args:
    generated: discriminator output for the generator's images.

  Returns:
    `loss_obj` of `generated` against an all-ones target, i.e. the loss is
    low when the discriminator rates the generated images as real.
  """
  wanted_verdict = tf.ones_like(generated)
  return loss_obj(wanted_verdict, generated)

def identity_loss(real_image, same_image):
  """Penalty for a generator altering an image it was expected to leave alone.

  Args:
    real_image: the image fed to the generator.
    same_image: the generator's output for that image.

  Returns:
    The mean absolute difference between the two, scaled by LAMBDA * 0.5.
  """
  mean_abs_diff = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * mean_abs_diff

def get_cross_entropy_loss(labels, logits):
  """Softmax cross-entropy between `labels` and `logits`, averaged over all examples.

  Args:
    labels: target distribution per example (one-hot or soft labels).
    logits: unnormalised class scores per example.

  Returns:
    Scalar tensor: the mean of the per-example cross-entropies.
  """
  per_example = tf.nn.softmax_cross_entropy_with_logits(labels=labels,
                                                        logits=logits)
  return tf.reduce_mean(per_example)

In [5]:
# Short aliases for the two metric types used below.
_MeanMetric     = tf.keras.metrics.Mean
_AccuracyMetric = tf.keras.metrics.CategoricalAccuracy

# Running means of the individual loss terms, updated inside train_step.
# -- classification terms
tb_gen_src_entropy_loss           = _MeanMetric(name='gen_src_entropy_loss')
tb_gen_trg_entropy_loss_src_guide = _MeanMetric(name='gen_trg_entropy_loss_src_guide')
tb_gen_trg_entropy_loss_src_data  = _MeanMetric(name='gen_trg_entropy_loss_src_data')
# -- adversarial terms
tb_disc_src_loss                  = _MeanMetric(name='disc_src_loss')
tb_disc_trg_loss                  = _MeanMetric(name='disc_trg_loss')
tb_gen_src_loss                   = _MeanMetric(name='gen_src_loss')
tb_gen_trg_loss                   = _MeanMetric(name='gen_trg_loss')
# -- identity terms
tb_src_identity_loss              = _MeanMetric(name='src_identity_loss')
tb_trg_identity_loss              = _MeanMetric(name='trg_identity_loss')
# -- totals actually optimised
tb_total_gen_s_loss               = _MeanMetric(name='total_gen_s_loss')
tb_total_gen_t_loss               = _MeanMetric(name='total_gen_t_loss')
tb_total_clas_s_loss              = _MeanMetric(name='total_clas_s_loss')
tb_total_clas_t_loss              = _MeanMetric(name='total_clas_t_loss')

# Accuracies: the two *_train_acc metrics are updated in train_step, the
# *_test_acc metrics are updated by the evaluation passes of the training loop.
temporal_test_acc    = _AccuracyMetric(name='temporal_test_acc')
source_train_acc     = _AccuracyMetric(name='source_train_acc')
source_test_acc      = _AccuracyMetric(name='source_test_acc')
office_test_acc      = _AccuracyMetric(name='office_test_acc')
server_train_acc     = _AccuracyMetric(name='server_train_acc')
server_test_acc      = _AccuracyMetric(name='server_test_acc')
conference_test_acc  = _AccuracyMetric(name='conference_test_acc')

@tf.function
def train_step(src_x, src_y, trg_x, trg_y, cycle_weight=None):
  """Run one optimisation step for both generators, discriminators and classifiers.

  generator_s maps target images to the source domain and generator_t maps
  source images to the target domain. All losses are computed in a single
  forward pass, all gradients are taken from that pass, and only then are the
  six optimizers applied.

  Args:
    src_x: batch of source-domain images.
    src_y: labels of `src_x` as per-class rows (the notebook passes one-hot rows).
    trg_x: batch of target-domain images.
    trg_y: labels of `trg_x`. Only used to update `server_train_acc`; it
      never enters a loss.
    cycle_weight: Python number weighting the cycle-consistency loss.
      `None` (default) uses `LAMBDA`; pass 0 to train without the cycle term.

  Returns:
    None. Model weights, optimizer state and the `tb_*` / `*_train_acc`
    metrics are updated as side effects.
  """
  if cycle_weight is None:
    cycle_weight = LAMBDA

  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # target -> source -> target
    src_x_fake   = generator_s(trg_x,      training=True)
    trg_x_cycled = generator_t(src_x_fake, training=True)

    # source -> target -> source
    trg_x_fake   = generator_t(src_x,      training=True)
    src_x_cycled = generator_s(trg_x_fake, training=True)

    # Each generator applied to an image that is already in its output domain.
    src_x_same = generator_s(src_x, training=True)
    trg_x_same = generator_t(trg_x, training=True)

    src_logits, _ = classifier_s(src_x, training=True)
    trg_logits, _ = classifier_t(trg_x, training=True)

    src_logits_fake, _ = classifier_s(src_x_fake, training=True)
    trg_logits_fake, _ = classifier_t(trg_x_fake, training=True)

    #Classifier##########################################
    gen_src_entropy_loss = get_cross_entropy_loss(labels=src_y,
                                                  logits=src_logits)

    # Target classifier on real target images, supervised by the source
    # classifier's softmax output on the translated (target -> source) images.
    gen_trg_entropy_loss_src_guide = get_cross_entropy_loss(labels=tf.nn.softmax(src_logits_fake),
                                                            logits=trg_logits)
    # Target classifier on translated (source -> target) images, supervised
    # by the true source labels.
    gen_trg_entropy_loss_src_data  = get_cross_entropy_loss(labels=src_y,
                                                            logits=trg_logits_fake)

    #Discriminator##########################################
    disc_src      = discriminator_s(src_x, training=True)
    disc_src_fake = discriminator_s(src_x_fake, training=True)
    disc_src_loss = discriminator_loss(disc_src, disc_src_fake)

    disc_trg      = discriminator_t(trg_x, training=True)
    disc_trg_fake = discriminator_t(trg_x_fake, training=True)
    disc_trg_loss = discriminator_loss(disc_trg, disc_trg_fake)

    #Generator##########################################
    gen_src_loss = generator_loss(disc_src_fake)
    gen_trg_loss = generator_loss(disc_trg_fake)

    src_identity_loss = identity_loss(src_x, src_x_same)
    trg_identity_loss = identity_loss(trg_x, trg_x_same)

    # Cycle consistency: an image translated to the other domain and back
    # should match the original. The cycled images used to be computed but
    # never entered any loss, so nothing tied the two generators together.
    # There is no dedicated metric for this term; it shows up in the
    # total_gen_*_loss metrics.
    total_cycle_loss = cycle_weight * (tf.reduce_mean(tf.abs(src_x - src_x_cycled)) +
                                       tf.reduce_mean(tf.abs(trg_x - trg_x_cycled)))

    #Loss##########################################
    # NOTE(review): gen_src_entropy_loss is computed from classifier_s(src_x)
    # only, so it carries no gradient for generator_s. Likewise
    # gen_trg_entropy_loss_src_guide depends on generator_s (through
    # src_x_fake) but not on generator_t, so it carries no gradient for
    # generator_t. Both terms are kept so the logged totals stay comparable
    # with earlier runs -- confirm which generator each was meant to guide.
    total_gen_s_loss = gen_src_loss + total_cycle_loss + src_identity_loss + gen_src_entropy_loss
    total_gen_t_loss = gen_trg_loss + \
                       total_cycle_loss + \
                       trg_identity_loss + \
                       gen_trg_entropy_loss_src_guide + \
                       gen_trg_entropy_loss_src_data

    total_clas_s_loss = gen_src_entropy_loss
    total_clas_t_loss = gen_trg_entropy_loss_src_guide + gen_trg_entropy_loss_src_data

  # (loss, model, optimizer) for every network, in the order the updates are
  # applied.
  updates = [
      (total_clas_s_loss, classifier_s,    classifier_s_optimizer),
      (total_clas_t_loss, classifier_t,    classifier_t_optimizer),
      (total_gen_s_loss,  generator_s,     generator_s_optimizer),
      (total_gen_t_loss,  generator_t,     generator_t_optimizer),
      (disc_src_loss,     discriminator_s, discriminator_s_optimizer),
      (disc_trg_loss,     discriminator_t, discriminator_t_optimizer),
  ]

  # Calculate every gradient before applying any of them, so all networks are
  # updated from the same forward pass.
  gradients = [tape.gradient(loss, model.trainable_variables)
               for loss, model, _ in updates]
  del tape  # drop the persistent tape's references once the gradients are taken

  # Apply the gradients to the optimizer
  for grads, (_, model, optimizer) in zip(gradients, updates):
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

  tb_gen_src_entropy_loss(gen_src_entropy_loss)
  tb_gen_trg_entropy_loss_src_guide(gen_trg_entropy_loss_src_guide)
  tb_gen_trg_entropy_loss_src_data(gen_trg_entropy_loss_src_data)
  tb_disc_src_loss(disc_src_loss)
  tb_disc_trg_loss(disc_trg_loss)
  tb_gen_src_loss(gen_src_loss)
  tb_gen_trg_loss(gen_trg_loss)
  tb_src_identity_loss(src_identity_loss)
  tb_trg_identity_loss(trg_identity_loss)
  tb_total_gen_s_loss(total_gen_s_loss)
  tb_total_gen_t_loss(total_gen_t_loss)
  tb_total_clas_s_loss(total_clas_s_loss)
  tb_total_clas_t_loss(total_clas_t_loss)
  source_train_acc(src_y, tf.nn.softmax(src_logits))
  server_train_acc(trg_y, tf.nn.softmax(trg_logits))

@tf.function
def test_step(images):
  """Return class probabilities for `images` from the source classifier.

  classifier_s is run in inference mode; the second value it returns is
  discarded and the logits are passed through a softmax.

  NOTE(review): every test set, including the target environments, is scored
  with classifier_s; classifier_t is never evaluated here -- confirm intended.
  """
  class_logits, _ = classifier_s(images, training=False)
  probabilities = tf.nn.softmax(class_logits)
  return probabilities

In [6]:
# ---- Networks ----------------------------------------------------------------
# generator_t: source -> target images, generator_s: target -> source images
# (see how train_step feeds them); one discriminator and classifier per domain.
generator_t     = UNetGenerator(input_channels=1, output_channels=1, norm_type='instancenorm')
generator_s     = UNetGenerator(input_channels=1, output_channels=1, norm_type='instancenorm')
discriminator_s = PatchGanDiscriminator(input_channels=1, norm_type='instancenorm', target=False)
discriminator_t = PatchGanDiscriminator(input_channels=1, norm_type='instancenorm', target=False)
# Use the configured activation rather than a second hard-coded "selu", so
# the architecture cannot drift from the activation recorded in the run name.
classifier_s    = ResNet50(num_classes, num_features, activation_fn)
classifier_t    = ResNet50(num_classes, num_features, activation_fn)

# ---- Optimizers --------------------------------------------------------------
# Generators/discriminators decay from init_lr to init_lr*1e-2 over 5000
# steps; the classifiers use the same schedule shape at 1/100 of those rates.
learning_rate  = tf.keras.optimizers.schedules.PolynomialDecay(init_lr,
                                                               decay_steps=5000,
                                                               end_learning_rate=init_lr*1e-2)
classifier_learning_rate  = tf.keras.optimizers.schedules.PolynomialDecay(init_lr*1e-2,
                                                                          decay_steps=5000,
                                                                          end_learning_rate=init_lr*1e-4)
classifier_s_optimizer    = tf.keras.optimizers.Adam(classifier_learning_rate, beta_1=0.5)
classifier_t_optimizer    = tf.keras.optimizers.Adam(classifier_learning_rate, beta_1=0.5)
generator_s_optimizer     = tf.keras.optimizers.Adam(learning_rate, beta_1=0.5)
generator_t_optimizer     = tf.keras.optimizers.Adam(learning_rate, beta_1=0.5)
discriminator_s_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.5)
discriminator_t_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.5)

summary_writer = tf.summary.create_file_writer(log_dir)

# ---- Initialise both classifiers from the baseline run -------------------------
# Both classifiers start from the same checkpoint. expect_partial() silences
# complaints about checkpoint entries that are not consumed here. A missing
# checkpoint is now reported instead of silently training from random weights.
for domain, classifier in (("source", classifier_s), ("target", classifier_t)):
  classifier_ckpt = tf.train.Checkpoint(model=classifier)
  pretrained_manager = tf.train.CheckpointManager(classifier_ckpt, classifier_checkpoint_path, max_to_keep=5)
  if pretrained_manager.latest_checkpoint:
    classifier_ckpt.restore(pretrained_manager.latest_checkpoint).expect_partial()
    print ('Latest checkpoint restored to {} classifier!!'.format(domain))
  else:
    print ('WARNING: no checkpoint found in {}; {} classifier starts from random weights'.format(
        classifier_checkpoint_path, domain))

# ---- Checkpointing of this run -------------------------------------------------
# Tracks the six networks only; optimizer state is not part of the checkpoint.
ckpt = tf.train.Checkpoint(generator_t=generator_t,
                           generator_s=generator_s,
                           discriminator_t=discriminator_t,
                           discriminator_s=discriminator_s,
                           classifier_t=classifier_t,
                           classifier_s=classifier_s)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

Latest checkpoint restored to source classifier!!
Latest checkpoint restored to target classifier!!


In [7]:
# Evaluation passes run after every epoch, as (dataset, accuracy metric).
evaluation_runs = [
    (time_test_set,   temporal_test_acc),
    (src_test_set,    source_test_acc),
    (office_test_set, office_test_acc),
    (server_test_set, server_test_acc),
    (conf_test_set,   conference_test_acc),
]

# Every per-epoch scalar as (TensorBoard tag, metric). The same table drives
# both the summary writing and the reset at the end of each epoch, so a
# metric cannot be logged without also being reset (or vice versa).
epoch_scalars = [
    ("tb_gen_src_entropy_loss",           tb_gen_src_entropy_loss),
    ("tb_gen_trg_entropy_loss_src_guide", tb_gen_trg_entropy_loss_src_guide),
    ("tb_gen_trg_entropy_loss_src_data",  tb_gen_trg_entropy_loss_src_data),
    ("tb_disc_src_loss",                  tb_disc_src_loss),
    ("tb_disc_trg_loss",                  tb_disc_trg_loss),
    ("tb_gen_src_loss",                   tb_gen_src_loss),
    ("tb_gen_trg_loss",                   tb_gen_trg_loss),
    ("tb_src_identity_loss",              tb_src_identity_loss),
    ("tb_trg_identity_loss",              tb_trg_identity_loss),
    ("tb_total_gen_s_loss",               tb_total_gen_s_loss),
    ("tb_total_gen_t_loss",               tb_total_gen_t_loss),
    ("tb_total_clas_s_loss",              tb_total_clas_s_loss),
    ("tb_total_clas_t_loss",              tb_total_clas_t_loss),
    ("source_train_acc",                  source_train_acc),
    ("server_train_acc",                  server_train_acc),
    ("temporal_test_acc",                 temporal_test_acc),
    ("source_test_acc",                   source_test_acc),
    ("office_test_acc",                   office_test_acc),
    ("server_test_acc",                   server_test_acc),
    ("conference_test_acc",               conference_test_acc),
]

for epoch in range(epochs):
  # One pass over the source set; the server set repeats indefinitely, so the
  # zip ends when the source batches run out.
  for (src_images, src_labels), (server_images, server_labels) in zip(src_train_set, server_train_set):
    train_step(src_images, src_labels, server_images, server_labels)

  # Keras metrics take (y_true, y_pred). The previous calls passed the
  # predictions first; the accuracy value is the same for one-hot labels
  # (both sides are only arg-maxed and compared), but the order now matches
  # the metric signature and the calls in train_step.
  for dataset, accuracy in evaluation_runs:
    for images, labels in dataset:
      accuracy(labels, test_step(images))

  with summary_writer.as_default():
    for tag, metric in epoch_scalars:
      tf.summary.scalar(tag, metric.result(), step=epoch)

  if (epoch + 1) % 25 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  # Start the next epoch's averages from scratch.
  for _, metric in epoch_scalars:
    metric.reset_states()

Saving checkpoint for epoch 25 at /home/kjakkala/mmwave/checkpoints/classes-9_bs-16_train_src_days-6_train_trg_days-0_train_trgenv_days-2_initlr-0.0001_num_feat-256_act_fn-selu_cyclegan_server_clas_lr_0.000001/ckpt-1
Saving checkpoint for epoch 50 at /home/kjakkala/mmwave/checkpoints/classes-9_bs-16_train_src_days-6_train_trg_days-0_train_trgenv_days-2_initlr-0.0001_num_feat-256_act_fn-selu_cyclegan_server_clas_lr_0.000001/ckpt-2
Saving checkpoint for epoch 75 at /home/kjakkala/mmwave/checkpoints/classes-9_bs-16_train_src_days-6_train_trg_days-0_train_trgenv_days-2_initlr-0.0001_num_feat-256_act_fn-selu_cyclegan_server_clas_lr_0.000001/ckpt-3
Saving checkpoint for epoch 100 at /home/kjakkala/mmwave/checkpoints/classes-9_bs-16_train_src_days-6_train_trg_days-0_train_trgenv_days-2_initlr-0.0001_num_feat-256_act_fn-selu_cyclegan_server_clas_lr_0.000001/ckpt-4
Saving checkpoint for epoch 125 at /home/kjakkala/mmwave/checkpoints/classes-9_bs-16_train_src_days-6_train_trg_days-0_train_trgenv

KeyboardInterrupt: 