## Model Training
---
**Import Statements**

In [None]:
import os
import sys
import random
import pickle

In [None]:
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import pandas as pd
import import_ipynb
from losses import *

In [None]:
import tensorflow as tf
from keras.callbacks import CSVLogger
import warnings
# Silence all Python warnings for the rest of the notebook.
warnings.filterwarnings('ignore')

# Build the session config with GPU memory growth enabled (memory is claimed
# on demand rather than reserved up front), then open an interactive session
# that installs itself as the default session.
config = tf.compat.v1.ConfigProto(
    gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)
)
session = tf.compat.v1.InteractiveSession(config=config)

# Report the GPU device name TensorFlow sees (empty string if none).
print("The following GPU devices are available: %s" % tf.test.gpu_device_name())

In [None]:
#Select model
from model_192x192x96 import *
#from model_192x160x80 import *

**Settings and Model Hyperparameters**

In [None]:
# Path to the directory with processed CTs (output of the preprocess script).
CT_dir = ""

# Path to the directory with processed masks (output of the preprocess script).
mask_dir = ""

In [None]:
# Filename template for saved checkpoints. The `{epoch}` and `{val_loss}`
# placeholders are filled in by the ModelCheckpoint callback at save time;
# replace "[model_parameters]" with a descriptive prefix for this run.
model_name = "[model_parameters].{epoch:02d}-{val_loss:.2f}.h5"

In [None]:
# --- Data / generator settings ---
# NOTE(review): the model-selection cell imports `model_192x192x96`, while the
# patch size here is 192x160x80 - confirm the two are meant to be used together.
patch_size = (192, 160, 80)
augment = True
batch_size = 2
val_batch_size = 2
min_max_norm = False

# --- Network / optimisation settings ---
loss_function = bce_dice
metrics = [dice_coef, "accuracy"]
n_filters = 32
instancenorm = False  # will use batch norm if false
leakyrelu = True
epochs = 200

# --- Callbacks ---
# Name of the csv file to log training info to (';'-separated, appended to).
csv_logger = CSVLogger("", append=True, separator=";")

# Stop once val_loss has not improved for 30 epochs.
early_stopping = EarlyStopping(monitor="val_loss", patience=30, verbose=1)

# Multiply the learning rate by 0.1 after 10 epochs without val_loss
# improvement, never going below 1e-6.
reduce_lr = ReduceLROnPlateau(
    monitor="val_loss", factor=0.1, patience=10, min_lr=1e-6, verbose=1
)

# Save weights only, and only when val_loss improves on the best so far.
checkpoint = ModelCheckpoint(
    model_name,
    monitor="val_loss",
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
)

callbacks = [early_stopping, reduce_lr, checkpoint, csv_logger]

In [None]:
# Fix every random number generator used in this notebook so that the
# train/val split, augmentation and TensorFlow-side random ops are repeatable.
seed = 2020
np.random.seed(seed)
random.seed(seed)
# TensorFlow keeps its own seed, separate from NumPy's and the stdlib's.
# The compat.v1 API is used to match the session setup earlier in the notebook.
tf.compat.v1.set_random_seed(seed)

**Train/Val Sets**

In [None]:
# Fail early with a clear message when the mask directory is unset or wrong;
# otherwise next(os.walk(...)) below would raise a bare StopIteration.
if not os.path.isdir(mask_dir):
    raise FileNotFoundError("mask_dir is not an existing directory: %r" % (mask_dir,))

# File names directly inside mask_dir. They are sorted before shuffling because
# os.walk returns entries in filesystem order, which would otherwise make the
# seeded shuffle differ between machines.
ids = sorted(next(os.walk(mask_dir))[2])
random.shuffle(ids)

# Patient ID = the part of the file name before the first underscore.
# NOTE(review): np.unique returns its result sorted, so the shuffle above does
# not influence which patients go to validation - it is always the first
# `val_data_size` patient IDs in sorted order. Confirm this is intended.
pts = np.unique([i.split("_")[0] for i in ids])
val_data_size = round(len(pts)*0.14) # 0.14 gives a val set size of 15% (0.15 not used b/c this gives too many val CTs after accounting for pts with more than one CT)
train_pts = pts[val_data_size:]
val_pts = pts[:val_data_size]

# Split at patient level so that all CTs of one patient end up on the same
# side. Sets give constant-time membership tests instead of array scans.
train_pt_set = set(train_pts.tolist())
val_pt_set = set(val_pts.tolist())
train_ids = [i for i in ids if i.split("_")[0] in train_pt_set]
val_ids = [i for i in ids if i.split("_")[0] in val_pt_set]

# Validation data is read from the same directories as the training data.
val_CT_dir = CT_dir
val_mask_dir = mask_dir

print("Train", len(train_ids), "CTs")
print("Val:", len(val_ids), "CTs")

**Example CT Volume**

In [None]:
# Example: draw one random batch from a training generator and show every
# volume in it as a slice montage with its mask overlaid.
gen = DataGenerator(
    train_ids,
    CT_dir,
    mask_dir,
    patch_size=patch_size,
    batch_size=batch_size,
    min_max_norm=min_max_norm,
    validation=False,
    augment=augment,
    seed=seed,
    shuffle=True,
)
batch_index = random.randrange(len(gen))
x, y = gen[batch_index]

for i, vol in enumerate(x):
    # Take index 0 of the last axis and move the new last axis to the front
    # so that montage() tiles along it.
    ct_slices = np.moveaxis(vol[:, :, :, 0], -1, 0)
    mask_slices = np.moveaxis(-y[i, :, :, :, 0], -1, 0)

    fig, ax = plt.subplots(figsize=(40, 40))
    ax.imshow(montage(ct_slices), cmap="gray")
    # Negated mask drawn semi-transparently on top of the CT.
    ax.imshow(montage(mask_slices), cmap="flag", alpha=0.2)
print(x.shape)

# If the image is red, this is because augmentation resulted in a CT/patch without tumor (e.g. from scaling/rotations); this happens infrequently.

# **Compile Model**

In [None]:
# Compile model
# Input shape is the patch size with one trailing dimension of size 1 appended
# (presumably the channel axis - confirm against the generator output).
vol_shape = patch_size + (1,)
input_vol = Input(vol_shape, name='vol')

model = unet(
    input_vol,
    n_filters=n_filters,
    instancenorm=instancenorm,
    leakyrelu=leakyrelu,
)
model.compile(
    optimizer=Adam(),  # default Adam settings
    loss=loss_function,
    metrics=metrics,
)
model.summary()

In [None]:
# Train and val generators
# Settings shared by both generators.
shared_gen_kwargs = dict(patch_size=patch_size, min_max_norm=min_max_norm, seed=seed)

# Training: augmentation as configured, shuffled.
train_gen = DataGenerator(
    train_ids,
    CT_dir,
    mask_dir,
    batch_size=batch_size,
    validation=False,
    augment=augment,
    shuffle=True,
    **shared_gen_kwargs
)

# Validation: no augmentation, fixed order.
val_gen = DataGenerator(
    val_ids,
    val_CT_dir,
    val_mask_dir,
    batch_size=val_batch_size,
    validation=True,
    augment=False,
    shuffle=False,
    **shared_gen_kwargs
)

**Train Model**

In [None]:
# Train on the training generator and evaluate on the validation generator
# after each epoch; `results.history` holds the per-epoch loss/metric values.
# NOTE(review): `fit_generator` is deprecated in recent TensorFlow/Keras
# releases in favour of `fit` - check against the installed version.
results = model.fit_generator(
    train_gen,
    validation_data=val_gen,
    callbacks=callbacks,
    epochs=epochs,
    verbose=1,
    # Batches are prepared in parallel worker processes; tune the worker count
    # and queue size to the machine this runs on.
    use_multiprocessing=True,
    workers=39,
    max_queue_size=39,
)

**Plot Learning Curve**

In [None]:
# Training vs. validation loss per epoch, with the epoch of lowest validation
# loss marked by a red cross.
train_loss = results.history["loss"]
val_loss = results.history["val_loss"]
best_epoch = np.argmin(val_loss)

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_title("Learning curve")
ax.plot(train_loss, label="loss")
ax.plot(val_loss, label="val_loss")
ax.plot(best_epoch, np.min(val_loss), marker="x", color="r", label="best model")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.legend()