In [1]:
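### Install requirements

# Uncomment to install missing prerequisites.
# `%pip` installs into the running kernel's environment (unlike `!pip`).
# Note: scikit-learn's PyPI package is `scikit-learn`; the old `sklearn` alias is deprecated and no longer installs.
#%pip install elasticdeform
#%pip install tensorflow_addons
#%pip install nibabel
#%pip install matplotlib
#%pip install scikit-learn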

In [2]:
### Imports

# System
import glob
import time
import os
import concurrent.futures
import importlib

# Visualization
import matplotlib.pyplot as plt

# Tensorflow
import tensorflow as tf

# Numerical calculations
import numpy as np
from sklearn.model_selection import train_test_split

# Own stuff
import config as conf
importlib.reload(conf)

import scan_loader
from data_generator import DataGenerator
import architecture


from print2file import *
# Start every run with an empty log file, if one is configured: opening in
# "w" mode truncates an existing file or creates a new one.
if conf.log_file is not None:
    open(conf.log_file, "w").close()

In [3]:
# GPU Setup

# Pins the GPU(s) TensorFlow may use. CUDA reads these variables only when
# TensorFlow first initializes its GPU devices, so they have no effect if the
# GPUs were already initialized earlier in this kernel. os.environ only
# accepts strings, hence str() in case conf.gpu is configured as an int.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(conf.gpu)

# TensorFlow 2.x: enable memory growth so GPU memory is allocated on demand
# instead of being reserved up front.
allow_multi_gpu = True
# Derived from the installed package instead of being hard-coded, so the
# TF2-only configuration below is skipped on other major versions.
tf_version = int(tf.__version__.split(".")[0])

if tf_version == 2 and allow_multi_gpu:
    gpus = tf.config.experimental.list_physical_devices('GPU')
    print2file(f"Num GPUs Available: {len(gpus)}")
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            # set_memory_growth raises once the runtime has initialized the
            # GPU, e.g. when this cell is re-run after training has started.
            print2file(f"Could not enable memory growth on {gpu.name}: {err}")

Num GPUs Available: 1


In [4]:
### Loads training/validation data

# Fetches list of image paths for each modality and segmentation ground truth.
# Scans are searched one sub-directory level below aug_export_path.
path_train = conf.aug_export_path + "/*/"

t1_list_train    = sorted(glob.glob(path_train + '*t1.nii.gz'))
t2_list_train    = sorted(glob.glob(path_train + '*t2.nii.gz'))
t1ce_list_train  = sorted(glob.glob(path_train + '*t1ce.nii.gz'))
flair_list_train = sorted(glob.glob(path_train + '*flair.nii.gz'))
seg_list_train   = sorted(glob.glob(path_train + '*seg.nii.gz'))

# The modality lists are paired purely by sorted position, so a single missing
# file would silently shift every later pairing. Fail loudly instead.
modality_counts = {
    "t1": len(t1_list_train),
    "t2": len(t2_list_train),
    "t1ce": len(t1ce_list_train),
    "flair": len(flair_list_train),
    "seg": len(seg_list_train),
}
if len(set(modality_counts.values())) != 1:
    raise ValueError(f"Mismatched number of files per modality in {path_train}: {modality_counts}")
if len(t1_list_train) == 0:
    raise FileNotFoundError(f"No scans found matching {path_train}")

# Splits data, or loads existing splits.
if conf.make_new_splits:
    ids_list = list(np.arange(len(t1_list_train)))

    # Splits off a random 20% as validation data. The split is not seeded, so
    # the indices are saved to disk and reused when make_new_splits is off.
    idxTrain, idxValid = train_test_split(ids_list, test_size=0.2, shuffle=True)
    idxTrain = sorted(idxTrain)
    idxValid = sorted(idxValid)

    # Saves indices to numpy files (creating the output directory if needed).
    os.makedirs(conf.output_path, exist_ok=True)
    np.save(conf.output_path + "/idxValid.npy", idxValid)
    np.save(conf.output_path + "/idxTrain.npy", idxTrain)
else:
    # Loads indices from numpy files.
    idxTrain = np.load(conf.output_path + '/idxTrain.npy')
    idxValid = np.load(conf.output_path + '/idxValid.npy')

    # Stored indices are positions in the file lists above; they are only
    # usable if the dataset has not shrunk since the split was written.
    stored_indices = np.concatenate([idxTrain, idxValid])
    if stored_indices.size and stored_indices.max() >= len(t1_list_train):
        raise IndexError(
            f"Stored splits in {conf.output_path} reference scan index {int(stored_indices.max())}, "
            f"but only {len(t1_list_train)} scans were found. Set make_new_splits to regenerate them."
        )

# Groups the per-modality paths of each scan into one record:
# [t1, t2, t1ce, flair, seg].
scan_records = [
    list(paths)
    for paths in zip(t1_list_train, t2_list_train, t1ce_list_train, flair_list_train, seg_list_train)
]
sets = {
    'train': [scan_records[i] for i in idxTrain],
    'valid': [scan_records[i] for i in idxValid],
}

# Debug info
num_data = len(idxValid) + len(idxTrain)

print2file(f"Num validation data: {len(idxValid)} ({len(idxValid)/num_data*100}%)")
print2file(f"Num training data: {len(idxTrain)} ({len(idxTrain)/num_data*100}%)")

Num validation data: 2 (33.33333333333333%)
Num training data: 4 (66.66666666666666%)


In [5]:
### Data Generator Initialization

# Settings shared by the training and validation generators; only the scan
# records passed as the first argument differ between the two.
generator_kwargs = {
    "input_dim":  conf.augmented_dim,
    "output_dim": conf.model_dim,
    "batch_size": conf.batch_size,
    "n_channels": conf.num_channels,
    "n_classes":  conf.num_classes,
}

# Training data generator
train_gen = DataGenerator(sets['train'], shuffle=True, **generator_kwargs)

# Validation data generator
valid_gen = DataGenerator(sets['valid'], shuffle=True, **generator_kwargs)

## GAN: Vox2Vox

In [6]:
### Model Initialization

# Input shapes: the spatial model dimensions plus a trailing channel axis
# (one channel per modality for images, one per class for the ground truth).
spatial_dims = tuple(conf.model_dim)
im_shape = spatial_dims + (conf.num_channels,)
gt_shape = spatial_dims + (conf.num_classes,)

# Initial class weights, read relative to the current working directory.
class_weights = np.load('class_weights.npy')

# Builds the Vox2Vox GAN.
gan = architecture.vox2vox(
    im_shape,
    gt_shape,
    class_weights,
    output_path=conf.output_path,
    save_images=conf.export_images,
)

# Optionally resumes from previously saved weights, loaded in the order
# generator -> discriminator -> combined model.
if conf.continue_training:
    for attr_name, weights_file in (("generator", "Generator.h5"),
                                    ("discriminator", "Discriminator.h5"),
                                    ("combined", "Vox2Vox.h5")):
        getattr(gan, attr_name).load_weights(conf.output_path + '/' + weights_file)

In [7]:
### Model training!
# perf_counter is monotonic, so the measured duration is unaffected by
# system clock adjustments during a long training run.
start_time = time.perf_counter()

try:
    trends_train, trends_valid = gan.train(train_gen, valid_gen, conf.num_epochs)
finally:
    # Log the duration even when training is interrupted (e.g. KeyboardInterrupt);
    # the exception still propagates after logging.
    duration = time.perf_counter() - start_time
    print2file(f"Training duration: {duration} seconds")

Training process:
Training on 1 and validating on 1 batches.

Epoch 1/100
Training Batch: 1/1 - v2v_loss: 16.6837
Training Batch Average: v2v_loss: 16.6837

Validation Batch: 1/1 - v2v_loss: 10.0833
Validation Batch Average: v2v_loss_val: 10.0833

Elapsed time: 0:33 mm:ss

Epoch 0 was best model so far.
Epoch 2/100
Training Batch: 1/1 - v2v_loss: 5.6400
Training Batch Average: v2v_loss: 5.6400

Validation Batch: 1/1 - v2v_loss: 6.0003
Validation Batch Average: v2v_loss_val: 6.0003

Elapsed time: 0:17 mm:ss

Epoch 1 was best model so far.
Epoch 3/100
Training Batch: 1/1 - v2v_loss: 5.8880
Training Batch Average: v2v_loss: 5.8880

Validation Batch: 1/1 - v2v_loss: 6.4689
Validation Batch Average: v2v_loss_val: 6.4689

Elapsed time: 0:18 mm:ss

Epoch 4/100
Training Batch: 1/1 - v2v_loss: 5.3554
Training Batch Average: v2v_loss: 5.3554

Validation Batch: 1/1 - v2v_loss: 6.7894
Validation Batch Average: v2v_loss_val: 6.7894

Elapsed time: 0:17 mm:ss

Epoch 5/100


KeyboardInterrupt: 

In [None]:
### Simple training evaluation

# Loads the loss histories written during training. They are stored as pickled
# objects, so allow_pickle is required -- only load files produced by this project.
train_history = np.load(conf.output_path + '/history_train.npy', allow_pickle=True).tolist()
valid_history = np.load(conf.output_path + '/history_valid.npy', allow_pickle=True).tolist()

train_losses = train_history['loss']
valid_losses = valid_history['loss']

# Epochs with the (up to) 10 smallest validation losses, best first. The count
# is based on the validation history itself, which may be shorter than the
# training history if training was interrupted mid-epoch.
num_min = min(10, len(valid_losses))
min_loss_indices = np.argsort(valid_losses)[:num_min]

print2file(f'Minimum validation loss: Epoch: {np.argmin(valid_losses)} - Loss: {np.min(valid_losses)}')
print2file([f"Epoch: {e} - Loss: {round(valid_losses[e],2)}" for e in min_loss_indices])

# Shows plot for training and validation loss.
fig, ax = plt.subplots()
ax.plot(train_losses, label="Training Loss")
ax.plot(valid_losses, label="Validation Loss")
ax.set(title="Vox2Vox loss per epoch", xlabel="Epoch", ylabel="Loss")
ax.legend()
plt.show()

In [8]:
### Saves the model settings for the segmentation module.

import json

# The segmentation module appears to read exactly these keys from the config:
# model_dim, model_path, class_weights, num_classes, num_modalities.
seg_config = {
    "model_dim": conf.model_dim,
    # Point at the generator weights in this run's output directory (the same
    # file the model-initialization cell loads) instead of a hard-coded,
    # machine-specific path. Made absolute so the config can be read from any
    # working directory.
    "model_path": os.path.abspath(os.path.join(conf.output_path, "Generator.h5")),
    "class_weights": class_weights.tolist(),
    "num_classes": conf.num_classes,
    "num_modalities": conf.num_channels
}

with open(os.path.join(conf.output_path, "seg.json"), "w") as seg_config_file:
    json.dump(seg_config, seg_config_file, indent=4)

print("Done")

Done
