In [1]:
# This notebook evaluates a trained Vox2Vox model by predicting segmentations for the test data and exporting them. TODO!

In [2]:
# TODO: exchange config dictionary with config file and change corresponding lines in code to conf.<param>
# NOTE(review): the visible cells below already read their settings as conf.<param>; this TODO looks resolved — confirm and remove.
import config as conf
import importlib
# Re-executes config.py so that edits to it take effect without restarting the kernel.
# As the last expression of the cell, the reloaded module is also displayed as the cell output.
importlib.reload(conf)

<module 'config' from '/home/msc_student/vox2vox/config.py'>

In [3]:
# Uncomment to install missing pre-requisites
#!pip install tensorflow_addons
#!pip install matplotlib

In [4]:
# System imports
import glob
import time
import os
from sys import stdout
import concurrent.futures

# Tensorflow
import tensorflow as tf

# Numerical calculations
import numpy as np

# Own stuff
from helper_functions import scan_loader
from model.data_generator import DataGenerator
from model import architecture

In [5]:
### Sets up GPU(s).

# Order CUDA devices by PCI bus ID and restrict TensorFlow to the GPU(s) listed in conf.gpu.
# CUDA only reads these variables when TensorFlow initializes its devices (the first device
# query below), so changing conf.gpu afterwards requires a kernel restart.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ only accepts strings: accept an int index or a list/tuple of indices as well.
gpu_ids = conf.gpu
if isinstance(gpu_ids, (list, tuple)):
    gpu_ids = ",".join(str(gpu_id) for gpu_id in gpu_ids)
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_ids)

# Tensorflow 2.x
allow_multi_gpu = True
tf_version = 2

if tf_version == 2 and allow_multi_gpu:
    gpus = tf.config.experimental.list_physical_devices('GPU')
    print("Num GPUs Available: ", len(gpus))
    for gpu in gpus:
        try:
            # Allocate GPU memory on demand instead of reserving all of it up front.
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            # TensorFlow raises this once the devices are initialized (e.g. when this cell is
            # re-run after the model was built); memory growth can then no longer be changed.
            print(f"Could not enable memory growth for {gpu.name}: {err}")

Num GPUs Available:  1


In [6]:
### Finds all test data.

# Fetches the list of image paths for each modality. conf.dataset_mask_test is used as a glob
# prefix, so it presumably ends with a path separator or a folder mask — TODO confirm.
path_test = conf.dataset_mask_test
t1_list    = sorted(glob.glob(path_test + '*t1.nii.gz'))
t2_list    = sorted(glob.glob(path_test + '*t2.nii.gz'))
t1ce_list  = sorted(glob.glob(path_test + '*t1ce.nii.gz'))
flair_list = sorted(glob.glob(path_test + '*flair.nii.gz'))
print(f"Num test data: {len(t1_list)}")

# The four modalities are paired purely by their position in the sorted lists. Check that every
# modality was found for every patient; otherwise scans of different patients could be mixed
# silently (or an IndexError raised).
assert t1_list, f"No test scans found with mask {path_test!r}"
modality_counts = {len(t1_list), len(t2_list), len(t1ce_list), len(flair_list)}
assert len(modality_counts) == 1, (
    f"Unequal number of scans per modality: t1={len(t1_list)}, t2={len(t2_list)}, "
    f"t1ce={len(t1ce_list)}, flair={len(flair_list)}")


def _scan_prefix(path, modality):
    """Returns `path` without its '<modality>.nii.gz' suffix, i.e. the part shared by one patient's scans."""
    return path[:-len(modality + '.nii.gz')]


test_data = []
for t1, t2, t1ce, flair in zip(t1_list, t2_list, t1ce_list, flair_list):
    prefixes = {_scan_prefix(t1, 't1'), _scan_prefix(t2, 't2'),
                _scan_prefix(t1ce, 't1ce'), _scan_prefix(flair, 'flair')}
    assert len(prefixes) == 1, f"Scans do not belong to the same patient: {t1}, {t2}, {t1ce}, {flair}"
    # The fifth slot stays None — presumably the ground-truth slot, which the test data lacks.
    test_data.append([t1, t2, t1ce, flair, None])

Num test data: 125


In [7]:
### Creates the test data generator.

# shuffle=False keeps the batches in the order of test_data, so the patient IDs returned per
# batch follow the sorted file lists. ground_truth=False: no labels are loaded for the test set.
# NOTE(review): the exact meaning of preprocessed=False is defined in DataGenerator — confirm.
test_gen = DataGenerator(
    test_data,
    batch_size=conf.batch_size,
    shuffle=False,
    ground_truth=False,
    preprocessed=False,
    input_dim=conf.dataset_dim,
    output_dim=conf.model_dim,
    n_channels=conf.num_channels,
    n_classes=conf.num_classes,
)

## GAN: Vox2Vox

In [8]:
### Creates the model we want to evaluate on the test data.

# Model input: conf.num_channels image channels; model output: one channel per class.
im_shape = (*conf.model_dim, conf.num_channels)
gt_shape = (*conf.model_dim, conf.num_classes)
# Required by the vox2vox constructor. NOTE(review): presumably only used by the training loss.
class_weights = np.load('resources/class_weights.npy')

# Initializes the model.
gan = architecture.vox2vox(im_shape, gt_shape, class_weights,
                           output_path=conf.output_path,
                           save_images=conf.export_images)

# Loads the generator weights. Only gan.generator is used for the predictions in this notebook,
# so the discriminator and combined-model weights do not need to be restored.
# os.path.join avoids double slashes and does not turn an empty weight_folder into '/Generator.h5'.
gan.generator.load_weights(os.path.join(conf.weight_folder, 'Generator.h5'))

In [9]:
### Exports the test segmentation results.

# Each segmentation is saved in .nii.gz format and named using the patient ID, given by the
# folder name containing the 4 modalities for each patient. In other words, for subjects whose
# input files are named ID_t1.nii.gz, ID_t2.nii.gz, etc., the exported segmentation is named
# ID.nii.gz.

# Post-processing thresholds, in voxels of the resized segmentation.
# ET regions (model class 3) smaller than this are relabelled as label 1 instead of label 4.
ET_MIN_VOXELS = 1500
# NT regions (label 1, after the ET step) smaller than this are relabelled as label 2.
NT_MIN_VOXELS = 500


def postprocess_segmentation(class_probs, et_min_voxels=ET_MIN_VOXELS, nt_min_voxels=NT_MIN_VOXELS):
    """Turns per-voxel class scores into an exportable label map with the ET/NT size thresholds applied.

    Args:
        class_probs: array of shape (X, Y, Z, num_classes) with one score per class and voxel.
        et_min_voxels: minimum number of class-3 (ET) voxels for them to be exported as label 4;
            smaller ET regions are relabelled as label 1.
        nt_min_voxels: minimum number of label-1 voxels (including ET voxels relabelled in the
            previous step) for them to be kept; smaller regions are relabelled as label 2.

    Returns:
        float32 label map of shape (X, Y, Z). Class 3 never remains: it becomes 4 or 1.
    """
    # Finds the most likely class for every voxel.
    seg = np.argmax(class_probs, axis=-1).astype('float32')

    # The model predicts ET as class 3; it is exported as label 4 unless the region is too small.
    enhancing = seg == 3
    seg[enhancing] = 4 if np.sum(enhancing) >= et_min_voxels else 1

    # Applies the size threshold to NT (label 1).
    core = seg == 1
    if np.sum(core) < nt_min_voxels:
        seg[core] = 2
    return seg


# Same path layout as before (plain f-string); the folder is created up front so that saving
# does not fail on a fresh output directory.
export_dir = f"{conf.output_path}/{conf.eval_export_path}"
os.makedirs(export_dir, exist_ok=True)

for Xbatch, _, IDbatch in test_gen:
    # Predicts the tumor segmentation for the batch.
    # gen_pred: batch_size x X x Y x Z x num_classes
    gen_pred = gan.generator.predict(Xbatch)

    # Transforms the segmentation back to the original dimensions.
    # The segmentation must have the same shape as the test data, or it can't be evaluated by BraTS!
    _, gen_pred = scan_loader.make_size_batch(None, gen_pred, conf.dataset_dim)

    # Saves every segmentation in this batch separately.
    for i, patient_id in enumerate(IDbatch):
        seg = postprocess_segmentation(gen_pred[i])
        scan_loader.save_img(seg, f"{export_dir}/{patient_id}.nii.gz")
        print(f"Saved {patient_id}")
    

Saved BraTS20_Validation_001
Saved BraTS20_Validation_002
Saved BraTS20_Validation_003
Saved BraTS20_Validation_004
Saved BraTS20_Validation_005
Saved BraTS20_Validation_006
Saved BraTS20_Validation_007
Saved BraTS20_Validation_008
Saved BraTS20_Validation_009
Saved BraTS20_Validation_010
Saved BraTS20_Validation_011
Saved BraTS20_Validation_012
Saved BraTS20_Validation_013
Saved BraTS20_Validation_014
Saved BraTS20_Validation_015
Saved BraTS20_Validation_016
Saved BraTS20_Validation_017
Saved BraTS20_Validation_018
Saved BraTS20_Validation_019
Saved BraTS20_Validation_020
Saved BraTS20_Validation_021
Saved BraTS20_Validation_022
Saved BraTS20_Validation_023
Saved BraTS20_Validation_024
Saved BraTS20_Validation_025
Saved BraTS20_Validation_026
Saved BraTS20_Validation_027
Saved BraTS20_Validation_028
Saved BraTS20_Validation_029
Saved BraTS20_Validation_030
Saved BraTS20_Validation_031
Saved BraTS20_Validation_032
Saved BraTS20_Validation_033
Saved BraTS20_Validation_034
Saved BraTS20_