In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `models` / `utils` files) are picked up live,
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Train the ML model

from models import Decoder
# from utils import load_model_data_new, normalize_params
from utils import plot_loss, decoder_files_to_tensors
from utils import sample_files

import time
import glob
import tensorflow as tf
from tensorflow import keras
import yaml
import os
import numpy as np
from datetime import datetime
import argparse
import matplotlib.pyplot as plt
import matplotlib as mpl


In [None]:
# Initialize parameters
# Root of the dataset; the 'ML_data/TRAINING' and 'ML_data/VALIDATION'
# sub-directories are resolved relative to it.
# data_dir = './tomo_data/datasets_decoder_02-12-22'
data_dir = './tomo_data/datasets_decoder_TF_16-12-22'

# Captured once, when this cell runs; it names the './trials/<timestamp>'
# directory that receives the weights, plots and summary of this run.
timestamp = datetime.now().strftime("%Y_%m_%d_%H-%M-%S")

# Data specific
# Side length used to build the (size, size, 1) shape handed to Decoder.
IMG_OUTPUT_SIZE = 128
# Batch size passed to model.fit.
BATCH_SIZE = 32  # 8
# latent_dim + additional_latent_dim (= 7) is the first entry of
# 'dense_layers' and equals the number of indices in 'loss_weights'.
latent_dim = 6  # 6 + the new VrfSPS - inten
additional_latent_dim = 1

# Not referenced elsewhere in this notebook; presumably the column order of
# the parameter tensor returned by decoder_files_to_tensors -- TODO confirm.
var_names = ['turn', 'phEr', 'enEr', 'bl', 'inten', 'Vrf', 'mu', 'VrfSPS']

# Train specific
# The whole dict is forwarded to Decoder as keyword arguments
# (Decoder(input_shape, **train_cfg)) and later copied into the YAML summary,
# so the key names are part of the Decoder interface.
train_cfg = {
    # Number of epochs passed to model.fit.
    'epochs': 100,
    'dense_layers': [latent_dim + additional_latent_dim, 64, 1024],
    'filters': [32, 16, 8, 1],
    'kernel_size': 7,
    'strides': [2, 2],
    'final_kernel_size': 5,
    'activation': 'relu',
    'final_activation': 'tanh',
    'dropout': 0.,
    'loss': 'mse',
    # Forwarded to decoder_files_to_tensors(normalization=...).
    'normalization': 'minmax',
    'lr': 1e-3,
    # Second positional argument of sample_files; presumably the fraction of
    # the available files to load -- TODO confirm against utils.sample_files.
    'dataset%': 0.01,
    # NOTE(review): despite the name, this notebook uses the list as column
    # indices to keep from x_train / x_valid. Index 4 is left out, which is
    # 'inten' if the columns follow var_names -- TODO confirm.
    'loss_weights': [0, 1, 2, 3, 5, 6, 7],
}

In [None]:
# Initialize directories
# Everything produced by this run is stored under ./trials/<timestamp>/
trial_dir = os.path.join('./trials/', timestamp)
weights_dir = os.path.join(trial_dir, 'weights')
plots_dir = os.path.join(trial_dir, 'plots')

# Initialize train/ test / validation paths
ML_dir = os.path.join(data_dir, 'ML_data')
TRAINING_PATH = os.path.join(ML_dir, 'TRAINING')
VALIDATION_PATH = os.path.join(ML_dir, 'VALIDATION')
# Fail early, and say which directory is missing.
assert os.path.exists(TRAINING_PATH), \
    f'Training data directory not found: {TRAINING_PATH}'
assert os.path.exists(VALIDATION_PATH), \
    f'Validation data directory not found: {VALIDATION_PATH}'

# Create the directories to store the results.
# exist_ok=True keeps this cell re-runnable: `timestamp` is fixed by an
# earlier cell, so running this cell a second time targets the same
# directories and must not raise FileExistsError.
os.makedirs(trial_dir, exist_ok=True)
os.makedirs(weights_dir, exist_ok=True)
os.makedirs(plots_dir, exist_ok=True)

# Initialize GPU
# `gpus` is reused at the end of the notebook to record the GPU count.
gpus = tf.config.experimental.list_physical_devices('GPU')
device_to_use = 0

if not gpus:
    print('No GPU available, using the CPU')
else:
    try:
        selected_gpu = gpus[device_to_use]
        # Currently, memory growth needs to be the same across GPUs
        tf.config.experimental.set_memory_growth(selected_gpu, True)
        # Expose the selected GPU as a single virtual device with a memory cap.
        gpu_memory_cap = tf.config.experimental.VirtualDeviceConfiguration(
            memory_limit=12*1024)
        tf.config.experimental.set_virtual_device_configuration(
            selected_gpu, [gpu_memory_cap])
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Report the RuntimeError and carry on with the current configuration.
        print(err)


In [None]:
# Both splits are loaded the same way: sample the files of the split, then
# convert them to (x, y) tensors with the configured normalization.
sample_fraction = train_cfg['dataset%']
norm_method = train_cfg['normalization']

# Training split
file_names = sample_files(TRAINING_PATH, sample_fraction, keep_every=1)
print('Number of Training files: ', len(file_names))
x_train, y_train = decoder_files_to_tensors(
    file_names, normalization=norm_method)

# Validation split
file_names = sample_files(VALIDATION_PATH, sample_fraction, keep_every=1)
print('Number of Validation files: ', len(file_names))
x_valid, y_valid = decoder_files_to_tensors(
    file_names, normalization=norm_method)


In [None]:
# Keep only the input-parameter columns listed in train_cfg['loss_weights'].
# This filters the model inputs x_train / x_valid; the targets y_train /
# y_valid are left untouched.
# A single tf.gather with the index list replaces the former per-column
# gather + expand_dims + concat; the result is the same for rank-2
# (samples, parameters) inputs -- assumes x_* are rank 2, TODO confirm.
# NOTE: x_train / x_valid are overwritten in place, so re-run the data-loading
# cell before re-running this one; otherwise the indices are applied to the
# already reduced tensors.
kept_columns = train_cfg['loss_weights']

x_train = tf.gather(x_train, kept_columns, axis=1)
print('x_train shape: ', x_train.shape)

x_valid = tf.gather(x_valid, kept_columns, axis=1)
print('x_valid shape: ', x_valid.shape)


In [None]:
%matplotlib inline
# plot some of the outputs

nrows = 3
# Get nrows * nrows random images
sample = np.random.choice(np.arange(len(x_train)), size=nrows * nrows, replace=False)

samples_X = tf.gather(x_train, sample)
samples_y = tf.gather(y_train, sample)

# Create 3x3 grid of figures
fig, axes = plt.subplots(ncols=nrows, nrows=nrows, figsize=(12, 12))
axes = np.ravel(axes)
for i in range(len(axes)):
    ax = axes[i]
    ax.set_xticks([])
    ax.set_yticks([])
    # show the image
    ax.imshow(samples_y[i], cmap='jet')
    # Set the label
    title = ','.join([f'{num:.1f}' for num in samples_X[i]])
    ax.set_title(f'{title}')


In [None]:
# Model instantiation
# Shape handed to Decoder as its first argument: square, single-channel image.
input_shape = (IMG_OUTPUT_SIZE, IMG_OUTPUT_SIZE, 1)

# All training options are forwarded to Decoder as keyword arguments.
decoder = Decoder(input_shape, **train_cfg)

# Keras' summary() prints the table itself and returns None, so wrapping it in
# print() only appended a stray "None" line to the output.
decoder.model.summary()


In [None]:
# Train the decoder

# Define the training callbacks; both watch the same validation metric.
monitored_metric = 'val_loss'
checkpoint_path = os.path.join(weights_dir, 'decoder.h5')

# Early stopping: patience of 10 epochs, restoring the best weights.
stop_early = keras.callbacks.EarlyStopping(
    monitor=monitored_metric, patience=10, restore_best_weights=True)

# Checkpointing: only the best model is kept at `checkpoint_path`.
save_best = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, monitor=monitored_metric, save_best_only=True)


In [None]:
# Fit the decoder and time the run.
# Both callbacks are passed to fit: `stop_early` was defined next to
# `save_best` but never handed to fit, so training always ran for the full
# train_cfg['epochs'] and the early-stopping callback was never applied.
start_time = time.time()
history = decoder.model.fit(
    x_train, y_train, epochs=train_cfg['epochs'],
    validation_data=(x_valid, y_valid), batch_size=BATCH_SIZE,
    callbacks=[save_best, stop_early])

total_time = time.time() - start_time
# len(history.history["loss"]) is the number of epochs actually run.
print(
    f'\n---- Training complete, epochs: {len(history.history["loss"])}, total time {total_time} ----\n')


In [None]:
# Plot the per-epoch training and validation loss and save the figure
print('\n---- Plotting loss ----\n')

# Per-epoch losses recorded by fit(); reused by the summary cell.
train_loss_l = np.array(history.history['loss'])
valid_loss_l = np.array(history.history['val_loss'])

loss_curves = {'Training': train_loss_l, 'Validation': valid_loss_l}
loss_figure_path = os.path.join(plots_dir, 'decoder_train_valid_loss.png')

plot_loss(loss_curves,
          title='Decoder Train/Validation Loss',
          figname=loss_figure_path)


In [None]:
# Write the experiment configuration and its outcome to a YAML summary
print('\n---- Saving a summary ----\n')

# Outcome of the run. 'epochs' replaces the configured value with the number
# of epochs recorded in the training history.
run_results = {
    'epochs': len(history.history["loss"]),
    'min_train_loss': float(np.min(train_loss_l)),
    'min_valid_loss': float(np.min(valid_loss_l)),
    'total_train_time': total_time,
    'used_gpus': len(gpus)
}

# Merge into a fresh dict so train_cfg itself is not modified.
config_dict = {'decoder': {**train_cfg, **run_results}}

# save config_dict
summary_path = os.path.join(trial_dir, 'decoder-summary.yml')
with open(summary_path, 'w') as summary_file:
    yaml.dump(config_dict, summary_file, default_flow_style=False)
