In [None]:
# imports and constants
import os
import glob
import nibabel as nib
import nilearn
import numpy as np
import tensorflow as tf
import random
import scipy
import shutil
from tensorflow import keras
from keras import Sequential, Model
from keras.layers import Dense, Flatten, Dropout, Reshape, Conv3D, Conv2D, Activation, UpSampling3D, \
                         MaxPooling3D, SpatialDropout3D, BatchNormalization, Conv3DTranspose
from keras.layers import concatenate, Input
# from keras.engine import Model
from keras.optimizers import Adam
from keras import backend as K
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from nilearn import image
from datetime import datetime
from scipy import ndimage
import keras_unet
import imageio


###### MODIFY HPARAMS ###########

DEBUG = False # for testing purposes

# Fold held out from training. A fold number 1-5 (int or str) removes that
# fold from TRAIN_DIRS; any other value (e.g. "ALLFOLDS") trains on all folds.
K_FOLD_VAL_NUM = "ALLFOLDS"

# MRI sequences stacked as input channels, in this order.
SEQUENCES = ["t1ce", "flair", "t1", "t2"]

# One of "TC", "ET" or "ALL" (see _convert_labels).
TYPE_OF_SEGMENTATION = "ET"

##################################


# Short run when debugging.
N_EPOCHS = 2 if DEBUG else 60


# Per-directory cap on the number of samples loaded when DEBUG is set.
DEBUG_SAMPLES = 20

# tf.test.is_gpu_available() is deprecated; query the device list instead.
print(len(tf.config.list_physical_devices('GPU')) > 0)

# TRAINING_DATA_ROOT_DIR = "data/"
# TRAIN_DIRS = ["train/LGG", "train/HGG"]

TRAINING_DATA_ROOT_DIR = "data_folds/"

# Compare as strings so that both K_FOLD_VAL_NUM = 1 and K_FOLD_VAL_NUM = "1"
# exclude fold_1; an int-vs-str comparison would silently keep every fold.
TRAIN_DIRS = ["fold_" + str(k) for k in range(1, 6) if str(k) != str(K_FOLD_VAL_NUM)]

print(TRAIN_DIRS)

VAL_DIRS = ["val"]


# Filename suffix (before ".nii.gz") of the ground-truth segmentation volume.
GROUND_TRUTH_FILE = "seg"



# Full volume extent used by the patch sampler for bounds checks.
N_X = 240
N_Y = 240
N_Z = 155
NUM_CHANNELS = len(SEQUENCES)
NUM_CLASSES = 4 if TYPE_OF_SEGMENTATION == "ALL" else 2


# Size of the training patches cropped by get_patches.
PATCH_X = 80
PATCH_Y = 80
PATCH_Z = 80


BATCH_SIZE = 2

# Standard deviation of the Gaussian jitter applied around the tumour centre
# of mass when sampling patch centres.
PERCENTILE_95 = 10.204082 * 2 # 20 pixels

SEQUENCES_JOINED = ",".join(SEQUENCES)

EXPERIMENT_NAME = f"{TYPE_OF_SEGMENTATION}_{SEQUENCES_JOINED}_{N_EPOCHS}-epochs_fold-{K_FOLD_VAL_NUM}-holdout"

if DEBUG:
    EXPERIMENT_NAME = "DEBUG_EXPERIMENT"


print(EXPERIMENT_NAME)

EXPERIMENT_FOLDER = os.path.join("experiments", EXPERIMENT_NAME)

# Checkpoints are written as <prefix>.<epoch>.h5 and the final model as <prefix>.h5.
CKPT_PATH_PREFIX = os.path.join(EXPERIMENT_FOLDER, EXPERIMENT_NAME + "_ckpt")

print(CKPT_PATH_PREFIX)

In [None]:
import sys

import tensorflow.keras
import pandas as pd
import sklearn as sk
import tensorflow as tf

# Environment report: library versions, interpreter version and GPU visibility.
print(f"Tensor Flow Version: {tf.__version__}")
print(f"Keras Version: {tensorflow.keras.__version__}")
print()
print(f"Python {sys.version}")
print(f"Pandas {pd.__version__}")
print(f"Scikit-Learn {sk.__version__}")
# True when TensorFlow lists at least one physical GPU device.
gpu = len(tf.config.list_physical_devices('GPU'))>0
print("GPU is", "available" if gpu else "NOT AVAILABLE")

In [None]:
# Prepare the experiment output folder.
# In debug mode a stale folder is wiped; otherwise an existing folder is
# treated as an error so that a previous experiment is never overwritten.
if DEBUG == True and os.path.isdir(EXPERIMENT_FOLDER):
    shutil.rmtree(EXPERIMENT_FOLDER)

if os.path.isdir(EXPERIMENT_FOLDER):
    raise ValueError(f"Experiment Folder {EXPERIMENT_FOLDER} already exists")

os.makedirs(EXPERIMENT_FOLDER, exist_ok=True)

In [None]:
# Reading data into X and y

def read_training_data(root_path, dirs):
    """Load every sample folder under the given directories into memory.

    Each immediate sub-folder of ``root_path + d`` is treated as one sample
    (visited in sorted path order). For a sample, one NIfTI volume is loaded
    per entry of the module-level ``SEQUENCES`` list, plus the ground-truth
    volume whose filename ends in ``GROUND_TRUTH_FILE + ".nii.gz"``.

    When ``DEBUG`` is set, only the first ``DEBUG_SAMPLES`` samples of each
    directory are loaded.

    Args:
        root_path: Prefix that is string-concatenated with each entry of
            ``dirs`` (so it is expected to end with a path separator).
        dirs: Iterable of directory names below ``root_path``.

    Returns:
        Tuple ``(X, y)`` of numpy arrays. ``X`` stacks the samples, each
        sample being the stack of its sequence volumes in ``SEQUENCES``
        order; ``y`` stacks the ground-truth volumes in the same sample order.

    Raises:
        FileNotFoundError: If a sample folder has no file for one of the
            sequences or for the ground truth.
    """
    def _first_match(sample_dir, suffix):
        # The pattern is expected to match exactly one file per sample.
        matches = glob.glob('{}/*{}.nii.gz'.format(sample_dir, suffix))
        if not matches:
            raise FileNotFoundError(
                "no '*{}.nii.gz' file found in {}".format(suffix, sample_dir))
        return matches[0]

    # list of numpy arrays
    X = []
    y = []

    for d in dirs:
        path = root_path + d
        samples = sorted(glob.glob('{}/*/'.format(path)))

        if DEBUG: # don't want to train on entire dataset if debugging
            samples = samples[:DEBUG_SAMPLES]

        for t in samples:
            x_sample = [np.array(nilearn.image.get_data(_first_match(t, s)))
                        for s in SEQUENCES]
            X.append(np.array(x_sample))
            y.append(np.array(nilearn.image.get_data(_first_match(t, GROUND_TRUTH_FILE))))

    return np.array(X), np.array(y)


# Load all training folds into the module-level arrays X (images) and y (labels).
X, y = read_training_data(TRAINING_DATA_ROOT_DIR, TRAIN_DIRS)

print("X shape: ", X.shape)
print("y shape: ", y.shape)

In [None]:
# preprocessing

def get_patches(X, y, mode):
    """Crop one patch from every sample of a channels-last batch.

    ``X`` and ``y`` are indexed as ``[sample, x, y, z, channel]``; the same
    spatial window is cut from both so image and label stay aligned.

    Training mode:
        If channel 1 of ``y`` contains any positive voxel, the patch centre is
        drawn from a Gaussian (std ``PERCENTILE_95``) around that channel's
        centre of mass, retrying up to 99 times until the patch lies inside
        the ``N_X x N_Y x N_Z`` volume. On the 100th attempt the centre is
        drawn around the middle of the volume instead (std
        ``PERCENTILE_95/4``) and used without a bounds check.
        If channel 1 is empty (or the centre of mass cannot be computed), a
        uniformly random patch containing more than 100 non-zero channel-1
        voxels is searched for (up to 100 tries); failing that, any uniformly
        random patch is taken.

    Validation mode:
        A uniformly random patch is taken.

    Any other ``mode`` produces no patches (two empty arrays are returned).

    Args:
        X: Image batch, channels last.
        y: One-hot label batch, channels last.
        mode: ``"Training"`` or ``"Validation"``.

    Returns:
        Tuple ``(X_patches, y_patches)`` of numpy arrays, one patch per sample.
    """
    # Half-extents of a patch; loop-invariant, so computed once.
    PATCH_X_RADIUS = int(PATCH_X/2)
    PATCH_Y_RADIUS = int(PATCH_Y/2)
    PATCH_Z_RADIUS = int(PATCH_Z/2)

    X_reshaped = []
    y_reshaped = []

    def _append_patch(b, xs, ys, zs):
        # Copy the same spatial window out of image and label.
        X_reshaped.append(np.copy(X[b, xs, ys, zs, :]))
        y_reshaped.append(np.copy(y[b, xs, ys, zs, :]))

    def _append_random_patch(b):
        # Uniformly random patch origin such that the patch fits in the volume.
        x_rand = np.random.randint(0, N_X - PATCH_X + 1)
        y_rand = np.random.randint(0, N_Y - PATCH_Y + 1)
        z_rand = np.random.randint(0, N_Z - PATCH_Z + 1)
        _append_patch(b,
                      slice(x_rand, x_rand + PATCH_X),
                      slice(y_rand, y_rand + PATCH_Y),
                      slice(z_rand, z_rand + PATCH_Z))

    for b in range(y.shape[0]):

        if mode == "Training":
            which_sampling = np.random.uniform(0, 1)
            threshold = 1 # alpha value

            centre_of_mass = None
            if np.count_nonzero(y[b, :, :, :, 1] > 0):
                try:
                    # scipy.ndimage.measurements is a deprecated namespace;
                    # the function lives directly in scipy.ndimage.
                    centre_of_mass = ndimage.center_of_mass(y[b, :, :, :, 1])
                except Exception:
                    # Fall back to random sampling, but do not swallow
                    # KeyboardInterrupt/SystemExit like a bare except would.
                    which_sampling = 1
            else:
                which_sampling = 1

            # select which sampling mode we want
            if which_sampling < threshold:
                for attempt in range(100):
                    x_mid = int(centre_of_mass[0] + np.random.normal(0, PERCENTILE_95))
                    y_mid = int(centre_of_mass[1] + np.random.normal(0, PERCENTILE_95))
                    z_mid = int(centre_of_mass[2] + np.random.normal(0, PERCENTILE_95))

                    last_attempt = attempt == 99

                    out_of_bounds = (
                        x_mid - PATCH_X_RADIUS < 0 or x_mid + PATCH_X_RADIUS >= N_X
                        or y_mid - PATCH_Y_RADIUS < 0 or y_mid + PATCH_Y_RADIUS >= N_Y
                        or z_mid - PATCH_Z_RADIUS < 0 or z_mid + PATCH_Z_RADIUS >= N_Z
                    )
                    if out_of_bounds and not last_attempt:
                        continue

                    if last_attempt:
                        print("Generate anyway")
                        # Failed to generate within bounds, generating around middle of image
                        x_mid = int(N_X/2 + np.random.normal(0, PERCENTILE_95/4))
                        y_mid = int(N_Y/2 + np.random.normal(0, PERCENTILE_95/4))
                        z_mid = int(N_Z/2 + np.random.normal(0, PERCENTILE_95/4))

                    _append_patch(b,
                                  slice(x_mid - PATCH_X_RADIUS, x_mid + PATCH_X_RADIUS),
                                  slice(y_mid - PATCH_Y_RADIUS, y_mid + PATCH_Y_RADIUS),
                                  slice(z_mid - PATCH_Z_RADIUS, z_mid + PATCH_Z_RADIUS))
                    break
            else:
                done = False
                for retries in range(100):
                    x_rand = np.random.randint(0, N_X - PATCH_X + 1)
                    y_rand = np.random.randint(0, N_Y - PATCH_Y + 1)
                    z_rand = np.random.randint(0, N_Z - PATCH_Z + 1)

                    xs = slice(x_rand, x_rand + PATCH_X)
                    ys = slice(y_rand, y_rand + PATCH_Y)
                    zs = slice(z_rand, z_rand + PATCH_Z)

                    if np.count_nonzero(y[b, xs, ys, zs, 1]) > 100:
                        _append_patch(b, xs, ys, zs)
                        done = True
                        break

                if not done: # failed to pick a patch with at least 100 tumorous voxels. Generate anyway.
                    _append_random_patch(b)

        elif mode == "Validation":
            _append_random_patch(b)

    X_reshaped = np.array(X_reshaped)
    y_reshaped = np.array(y_reshaped)

    return X_reshaped, y_reshaped



def _standardize_nii_file(x):
    ret = np.copy(x)
    ret = (ret - x.mean())/x.std() if x.std() > 0 else (ret - x.mean())
    return ret

def standardize_training_features(X, y):
    for i in range(X.shape[0]):
        X[i, :, :, :, :] = np.array(_standardize_nii_file(X[i, :, :, :, :]))
    return X, y

def _convert_labels(y):
    """Remap BraTS voxel labels in place for the configured segmentation task.

    The BraTS ground truth uses labels 0, 1, 2 and 4 (label 3 is not present
    in this dataset; for BraTS 2017+ it was merged into label 1). According to
    https://arxiv.org/pdf/1811.02629.pdf:

        Label 1 (+ 3): NCR/NET -- necrotic core and non-enhancing tumor
        Label 2:       ED      -- edema
        Label 4:       AT      -- enhancing regions within the gross tumor
                                  abnormality (light blue in
                                  https://www.med.upenn.edu/sbia/brats2018.html)

    Remapping per ``TYPE_OF_SEGMENTATION``:

        "TC":  2 -> 0, 4 -> 1           (label 1 is kept as foreground)
        "ET":  1 -> 0, 2 -> 0, 4 -> 1   (only the enhancing tumor is foreground)
        "ALL": 4 -> 3                   (labels become the contiguous 0..3)

    Args:
        y: Label array; modified in place.

    Returns:
        The same array ``y``.

    Raises:
        ValueError: If ``TYPE_OF_SEGMENTATION`` is not one of the above.
    """
    # Ordered (old, new) substitutions; they are applied one after another.
    relabel_steps = {
        "TC": ((2, 0), (4, 1)),
        "ET": ((1, 0), (2, 0), (4, 1)),
        "ALL": ((4, 3),),
    }

    if TYPE_OF_SEGMENTATION not in relabel_steps:
        raise ValueError('invalid segmentation type.')

    for old_label, new_label in relabel_steps[TYPE_OF_SEGMENTATION]:
        y[y == old_label] = new_label

    return y


def convert_target_labels(X, y):
    """Pipeline step: remap the labels in ``y`` via ``_convert_labels``; ``X`` is untouched."""
    return X, _convert_labels(y)

def switch_to_last_channel_mode(X, y):
    """Pipeline step: move axis 1 of the 5-D array ``X`` to the end.

    ``[sample, channel, d1, d2, d3]`` becomes ``[sample, d1, d2, d3, channel]``.
    ``y`` is returned unchanged.
    """
    channels_last = np.transpose(X, (0, 2, 3, 4, 1))
    return channels_last, y

def one_hot_labels(X, y):
    """Pipeline step: one-hot encode ``y`` into ``NUM_CLASSES`` channels; ``X`` is untouched."""
    return X, keras.utils.to_categorical(y, NUM_CLASSES)



def preprocess_data(X, y, validation=False):
    """Apply the whole-dataset preprocessing steps.

    Labels are remapped for the configured segmentation type, then ``X`` is
    switched to channels-last layout. One-hot encoding is intentionally not
    done here (it is applied per batch when data is read from memory).

    Args:
        X: Image array, channels first.
        y: Label array (remapped in place).
        validation: Accepted for call-site symmetry; not used.

    Returns:
        Tuple ``(X, y)`` after preprocessing.
    """
    relabelled = convert_target_labels(X, y)
    return switch_to_last_channel_mode(*relabelled)




def translation(data):
    """Shift ``data`` by 5 voxels along each of its first two axes.

    Nearest-neighbour interpolation (``order=0``) keeps label values intact and
    ``mode='nearest'`` fills the exposed border with the edge values. All axes
    after the first two are left unshifted, so the function works for both the
    4-D image volumes and the 3-D label volumes.
    """
    # One shift entry per axis; the original fixed 4-tuple failed on 3-D input.
    offsets = (5, 5) + (0,) * (np.ndim(data) - 2)
    # scipy.ndimage.interpolation is a deprecated namespace; use scipy.ndimage.
    return scipy.ndimage.shift(data, offsets, order=0, mode='nearest')

def rotation(data):
    """Rotate ``data`` by a random multiple of 90 degrees (0, 90, 180 or 270).

    The output keeps the input shape (``reshape=False``); nearest-neighbour
    interpolation and edge replication are used.
    """
    angle = 90 * int(np.random.randint(4))
    return scipy.ndimage.rotate(data, angle, reshape=False, order=0, mode='nearest')

def intensity(data):
    """Scale ``data`` by a single random factor drawn uniformly from [0.8, 1.2]."""
    gain = float(random.uniform(0.8, 1.2))
    return data * gain

def flip(data):
    """Randomly mirror ``data``: left/right and up/down, each with probability 0.5.

    The two decisions are independent; the left/right draw happens first.
    """
    for mirror in (np.fliplr, np.flipud):
        if random.uniform(0, 1) < 0.5:
            data = mirror(data)
    return data

def data_augmentation(data):
    """Return an independent copy of ``data``.

    Augmentation is currently disabled, so no transformation is applied.
    Background reading: https://mlnotebook.github.io/post/dataaug/
    """
    # Disabled for now. The intended pipeline was:
    #
    #   if random.uniform(0, 1) < 0.25: # randomly augment the data.
    #       if random.uniform(0, 1) < 0.25:
    #           ret = translation(ret)
    #
    #       #if random.uniform(0, 1) < 0.20:
    #       #    ret = intensity(ret)
    #
    #       if random.uniform(0, 1) < 0.25:
    #           ret = rotation(ret)
    #
    #       if random.uniform(0, 1) < 0.25:
    #           ret = flip(ret)
    return np.array(data, copy=True)




def read_training_data_from_memory(data_ids):
    """Assemble a training batch from the module-level ``X`` and ``y`` arrays.

    Every selected sample (image and label) goes through ``data_augmentation``
    and the labels are one-hot encoded.

    Args:
        data_ids: Iterable of sample indices into ``X`` / ``y``.

    Returns:
        Tuple ``(images, one_hot_labels)``.
    """
    # Image and label of a sample are augmented back to back, sample by sample.
    pairs = [(data_augmentation(X[i, :, :, :, :]), data_augmentation(y[i, :, :, :]))
             for i in data_ids]
    images = np.array([image for image, _ in pairs])
    labels = np.array([label for _, label in pairs])
    return one_hot_labels(images, labels)

def read_validation_data_from_memory(data_ids):
    """Assemble a validation batch from the module-level ``X_val`` and ``y_val`` arrays.

    Mirrors ``read_training_data_from_memory``: each selected sample goes
    through ``data_augmentation`` and the labels are one-hot encoded.

    Args:
        data_ids: Iterable of sample indices into ``X_val`` / ``y_val``.

    Returns:
        Tuple ``(images, one_hot_labels)``.
    """
    # Image and label of a sample are augmented back to back, sample by sample.
    pairs = [(data_augmentation(X_val[i, :, :, :, :]), data_augmentation(y_val[i, :, :, :]))
             for i in data_ids]
    images = np.array([image for image, _ in pairs])
    labels = np.array([label for _, label in pairs])
    return one_hot_labels(images, labels)


def get_data(data_ids, mode):
    """Build one batch of patches for the given sample indices.

    The samples are read from the in-memory training or validation arrays
    (augmented and one-hot encoded) and then cropped to patches by
    ``get_patches``.

    Args:
        data_ids: Iterable of sample indices.
        mode: ``"Training"`` or ``"Validation"``.

    Returns:
        Tuple ``(X, y)`` of patch arrays.

    Raises:
        ValueError: If ``mode`` is neither ``"Training"`` nor ``"Validation"``
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if mode == "Training":
        X, y = read_training_data_from_memory(data_ids)
    elif mode == "Validation":
        X, y = read_validation_data_from_memory(data_ids)
    else:
        raise ValueError(f"invalid mode {mode!r}; expected 'Training' or 'Validation'")

    X, y = get_patches(X, y, mode)

    # Standardization is done once on the whole dataset, not per batch.
    #X, y = standardize_training_features(X, y)

    return X, y


# Z-score every volume once, up front.
X, y = standardize_training_features(X, y)

# Carve off 10% as a test split, then 10% of the remainder as a validation
# split. NOTE(review): no random_state is passed, so the splits differ from
# run to run.
X, X_test, y, y_test = train_test_split(X, y, test_size=0.1)

X, X_val, y, y_val = train_test_split(X, y, test_size=0.1)

# Remap labels and switch all three splits to channels-last layout.
X, y = preprocess_data(X, y)

X_val, y_val = preprocess_data(X_val, y_val, validation=True)

X_test, y_test = preprocess_data(X_test, y_test)



# The held-out splits are dropped again here; only X and y are kept.
# NOTE(review): read_validation_data_from_memory reads X_val / y_val and will
# fail while these are None.
X_val = None
X_test = None
y_val = None
y_test = None


In [None]:
# https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly

class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that yields batches of randomly sampled patches.

    Batches are not tied to the requested batch index: every call draws
    ``batch_size`` random sample indices (with replacement) and delegates to
    ``get_data``. An epoch is a fixed number of batches, not a pass over the
    dataset.

    Pattern based on
    https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
    """

    def __init__(self, size, batch_size=BATCH_SIZE, dim=(PATCH_X,PATCH_Y,PATCH_Z), n_channels=NUM_CHANNELS,
                 n_classes=NUM_CLASSES, shuffle=True, mode="Training"):
        """Store the generator configuration.

        Args:
            size: Number of samples available; indices are drawn from
                ``0 .. size - 1``.
            batch_size: Number of samples per batch.
            dim, n_channels, n_classes, shuffle: Stored on the instance only.
            mode: ``"Training"`` or ``"Validation"``; forwarded to ``get_data``.
        """
        self.size = size
        self.batch_size = batch_size
        self.mode = mode
        self.dim = dim
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch: 3000 samples for training, 300 otherwise."""
        samples_per_epoch = 3000 if self.mode == "Training" else 300
        return int(samples_per_epoch / self.batch_size)

    def __getitem__(self, index):
        """Generate one batch; ``index`` is ignored, samples are drawn at random."""
        sample_ids = [random.randint(0, self.size - 1) for _ in range(self.batch_size)]

        batch_X, batch_y = self.__data_generation(sample_ids)
        return batch_X, batch_y

    def on_epoch_end(self):
        """Hook called after each epoch; nothing to update."""
        return None

    def __data_generation(self, list_IDs_temp):
        """Fetch the patches for the given sample indices via ``get_data``."""
        batch_X, batch_y = get_data(list_IDs_temp, self.mode)
        return batch_X, batch_y

In [None]:
# metrics

def recall(y, y_hat):
    """Recall over all elements: true positives / actual positives.

    Products and targets are clipped to [0, 1] and rounded before counting;
    ``K.epsilon()`` guards against division by zero.
    """
    hit_count = K.sum(K.round(K.clip(y * y_hat, 0, 1)))
    actual_count = K.sum(K.round(K.clip(y, 0, 1)))
    return hit_count / (actual_count + K.epsilon())

def precision(y, y_hat):
    """Precision over all elements: true positives / predicted positives.

    Products and predictions are clipped to [0, 1] and rounded before
    counting; ``K.epsilon()`` guards against division by zero.
    """
    hit_count = K.sum(K.round(K.clip(y * y_hat, 0, 1)))
    predicted_count = K.sum(K.round(K.clip(y_hat, 0, 1)))
    return hit_count / (predicted_count + K.epsilon())

def f1(y, y_hat):
    """F1 score: harmonic mean of ``precision`` and ``recall`` (epsilon-stabilised)."""
    rec = recall(y, y_hat)
    prec = precision(y, y_hat)
    harmonic_half = (prec * rec) / (prec + rec + K.epsilon())
    return 2 * harmonic_half


def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient, reduced over axes 0-3.

    Sums run over the first four axes, so one coefficient is returned per
    entry of the remaining (last) axis. ``smooth`` is added to numerator and
    denominator.
    """
    reduce_axes = (0, 1, 2, 3)
    overlap = K.sum(K.abs(y_true * y_pred), axis=reduce_axes)
    true_energy = K.sum(K.square(y_true), axis=reduce_axes)
    pred_energy = K.sum(K.square(y_pred), axis=reduce_axes)

    return (2. * overlap + smooth) / (true_energy + pred_energy + smooth)


def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus the mean of the ``dice_coef`` values."""
    mean_dice = K.mean(dice_coef(y_true, y_pred))
    return 1 - mean_dice




def tversky(y_true, y_pred, smooth=0.01):
    """Smoothed Tversky index, reduced over axes 0-3.

    False negatives are weighted by ``alpha = 0.7`` and false positives by
    ``1 - alpha``.
    """
    alpha = 0.7
    reduce_axes = (0, 1, 2, 3)
    tp = K.sum(y_true * y_pred, axis=reduce_axes)
    fn = K.sum(y_true * (1-y_pred), axis=reduce_axes)
    fp = K.sum((1-y_true)*y_pred, axis=reduce_axes)
    return (tp + smooth)/(tp + alpha*fn + (1-alpha)*fp + smooth)

def tversky_loss(y_true, y_pred):
    """Tversky loss: one minus the mean of the ``tversky`` values."""
    mean_index = K.mean(tversky(y_true,y_pred))
    return 1 - mean_index


def combined_loss(y_true, y_pred):
    """Equal-weight blend of ``dice_coef_loss`` and ``tversky_loss``."""
    dice_part = 0.5*dice_coef_loss(y_true, y_pred)
    tversky_part = 0.5*tversky_loss(y_true, y_pred)
    return dice_part + tversky_part


In [None]:
# baseline model
def unet_2d_model():
    """Build the 2-D U-Net baseline from ``keras_unet``.

    The input is one ``N_X x N_Y`` slice with one channel per MRI sequence
    (``NUM_CHANNELS``); the output has ``NUM_CLASSES`` softmax channels.
    The input depth was previously set to ``NUM_CLASSES``, which only matches
    the data by coincidence.
    """
    input_shape = (N_X, N_Y, NUM_CHANNELS)
    model = keras_unet.models.custom_unet(input_shape, use_batch_norm=True, num_classes=NUM_CLASSES, filters=64, dropout=0.5, output_activation='softmax')
    return model



def basic_cnn_model():
    """Build the 3-D U-Net used for segmentation.

    Architecture:
        * contraction: four stages with 32, 64, 128 and 256 filters, each
          followed by 2x2x2 max pooling;
        * trough: one stage with 512 filters;
        * expansion: four stages (256, 128, 64, 32 filters), each preceded by
          a stride-2 transposed convolution and concatenation with the
          matching contraction output along the channel axis;
        * head: 1x1x1 convolution with ``NUM_CLASSES`` softmax outputs.

    A stage is Conv3D -> BatchNorm -> SpatialDropout3D(0.5) -> Conv3D ->
    BatchNorm, with 3x3x3 'same' ReLU convolutions. The spatial input size is
    left unspecified; the input has ``NUM_CHANNELS`` channels.

    Returns:
        An uncompiled ``keras.Model``.
    """
    base_filters = 32
    dropout_prob = 0.5
    kernel = (3, 3, 3)
    stride = (2, 2, 2)
    pooling = (2, 2, 2)

    def conv_stage(tensor, n_filters):
        # Two convolutions with batch norm; dropout only after the first one.
        tensor = Conv3D(filters=n_filters, kernel_size=kernel, padding='same', activation='relu')(tensor)
        tensor = BatchNormalization()(tensor)
        tensor = SpatialDropout3D(dropout_prob)(tensor)
        tensor = Conv3D(filters=n_filters, kernel_size=kernel, padding='same', activation='relu')(tensor)
        tensor = BatchNormalization()(tensor)
        return tensor

    inputs = Input(shape=(None, None, None, NUM_CHANNELS))

    # contraction
    features = inputs
    skip_connections = []
    for multiplier in (1, 2, 4, 8):
        features = conv_stage(features, base_filters * multiplier)
        skip_connections.append(features)
        features = MaxPooling3D(pool_size=pooling)(features)

    # trough
    features = conv_stage(features, base_filters * 16)

    # expansion: consume the skip connections deepest-first
    for multiplier in (8, 4, 2, 1):
        features = Conv3DTranspose(base_filters * multiplier, pooling, strides=stride, padding='same')(features)
        features = concatenate([features, skip_connections.pop()], axis=4)
        features = conv_stage(features, base_filters * multiplier)

    # segmentation
    outputs = Conv3D(filters=NUM_CLASSES, kernel_size=1, padding='same', activation='softmax')(features)

    return Model(inputs=inputs, outputs=outputs)

In [None]:
from keras_unet.metrics import iou, iou_thresholded
from keras.callbacks import ModelCheckpoint,  LearningRateScheduler
from keras.optimizers import SGD, Adam
from keras.metrics import MeanIoU
import math

# training

#model = unet_2d_model()
model = basic_cnn_model()

# `learning_rate` is the supported keyword; `lr` is a deprecated alias.
model.compile(optimizer=Adam(learning_rate=0.0001),
              loss=dice_coef_loss,
              metrics=[f1]
             )

# Only the number of training samples is needed: batches are sampled at random.
training_generator = DataGenerator(X.shape[0], mode="Training")
#validation_generator = DataGenerator(X_val.shape[0], mode="Validation")

# A checkpoint named <prefix>.<epoch>.h5 is requested every 5 epochs.
# NOTE(review): `period` is deprecated in newer Keras releases - confirm it is
# still accepted by the installed version.
training_history = model.fit(x=training_generator, #validation_data=validation_generator, 
                             epochs=N_EPOCHS, verbose=1,
                             callbacks=[ModelCheckpoint(filepath=CKPT_PATH_PREFIX + ".{epoch:02d}.h5",  save_freq='epoch', period=5)])

# Final model after the last epoch.
model.save(CKPT_PATH_PREFIX + ".h5") 


model.summary()



#model.save("HGG_model_verification.h5") 


#model = keras.models.load_model('HGG_model_original_alpha1.h5', custom_objects={'dice_coef_loss': dice_coef_loss, 'f1': f1})

# Statistics Inference

In [None]:
#validation inference

# Experiment name -> segmentation type passed to do_inference.
CHECKPOINTS = {
    "TC_flair_60-epochs_fold-1-holdout": "TC",
    "TC_t1ce_60-epochs_fold-1-holdout": "TC",
}

# Experiment name -> MRI sequences to load as input channels for that
# experiment. Keys must match those of CHECKPOINTS.
SEQUENCES_LIST = {
    "TC_t1ce_60-epochs_fold-1-holdout": ["t1ce"],
    "TC_flair_60-epochs_fold-1-holdout": ["flair"],
}




def read_validation_data(root_path, dirs, seq_list):
    """Load the image volumes (no labels) of every sample under ``dirs``.

    Each immediate sub-folder of ``root_path + d`` is one sample (sorted by
    path); one NIfTI volume is loaded per entry of ``seq_list``.

    Args:
        root_path: Prefix string-concatenated with each entry of ``dirs``.
        dirs: Iterable of directory names below ``root_path``.
        seq_list: Sequence-name suffixes; file ``*<name>.nii.gz`` is loaded.

    Returns:
        Tuple ``(sample_names, X)``: the sample folder paths in load order and
        the stacked volumes moved to channels-last layout
        ``[sample, d1, d2, d3, sequence]``. Both shapes (before and after the
        axis move) are printed.

    Raises:
        FileNotFoundError: If a sample folder lacks a file for a sequence.
    """
    # list of numpy arrays
    volumes = []
    sample_names = []

    for d in dirs:
        path = root_path + d
        samples = sorted(glob.glob('{}/*/'.format(path)))

        for t in samples:
            x_sample = []
            for s in seq_list:
                matches = glob.glob('{}/*{}.nii.gz'.format(t, s)) # expected to be unique
                if not matches:
                    raise FileNotFoundError(
                        "no '*{}.nii.gz' file found in {}".format(s, t))
                x_sample.append(np.array(nilearn.image.get_data(matches[0])))

            volumes.append(np.array(x_sample))
            sample_names.append(t)
    X = np.array(volumes)
    print(X.shape)
    X = np.moveaxis(X, [0, 1, 2, 3, 4], [0, 4, 1, 2, 3])
    print(X.shape)
    return sample_names, X

In [None]:
# Folder that holds one sub-folder per experiment (checkpoints and segmentations).
root_folder = "experiments"

def do_pred_with_ckpt(exp_name, c, segType, seq_list, train_folders):
    """Run one checkpoint over every sample in ``train_folders`` and save the masks.

    Steps: load the model, load the volumes, zero-pad the third spatial axis
    by 5 slices, standardize and predict sample by sample, drop the padding,
    take the arg-max class, map the class indices back to BraTS labels and
    write one ``<sample>.nii.gz`` per sample under
    ``<root_folder>/<exp_name>/segmentations/<checkpoint name>/``.

    Args:
        exp_name: Experiment folder name below ``root_folder``.
        c: Path of the ``.h5`` checkpoint to load.
        segType: ``"TC"``, ``"ET"`` or ``"ALL"``.
        seq_list: Sequence names to load as input channels.
        train_folders: Directory names (below ``TRAINING_DATA_ROOT_DIR``)
            whose samples are predicted.

    Raises:
        ValueError: If ``segType`` is not one of the supported values (raised
            after all predictions have been computed).
    """
    model = keras.models.load_model(c, custom_objects={'dice_coef_loss': dice_coef_loss, 'f1': f1})
    # These bind local names only; the module-level X and y are unaffected.
    X = None
    y = None
#     X_validation = read_validation_data(TRAINING_DATA_ROOT_DIR, VAL_DIRS, seq_list)
    
    samples, X_validation = read_validation_data(TRAINING_DATA_ROOT_DIR, train_folders, seq_list)

    print("X_validation shape: ", X_validation.shape)
    #print("y shape: ", y.shape)
    
    # Local values derived from the arguments; they shadow the module constants.
    NUM_CLASSES = 4 if segType == "ALL" else 2
    
    NUM_CHANNELS = len(seq_list)
    
    #print(X_validation.shape)
    # Output buffer, padded by 5 slices along the third spatial axis.
    preds = np.zeros((X_validation.shape[0], X_validation.shape[1], X_validation.shape[2], X_validation.shape[3] + 5, NUM_CLASSES))

    # Zero-padded copy of the input.
    # presumably the padding makes the depth divisible by 16 for the four
    # pooling stages of the network - TODO confirm
    tmp111 = np.zeros((X_validation.shape[0], X_validation.shape[1], X_validation.shape[2], X_validation.shape[3] + 5, NUM_CHANNELS))

    tmp111[:, :, :, :-5, :] = X_validation

    X_validation = tmp111

    #print(X_validation.shape)

    # Predict one sample at a time.
    # NOTE(review): the standardization here includes the zero padding in the
    # mean/std, unlike the training-time standardization - confirm intended.
    for b in range(0, X_validation.shape[0]):
        #print(b)
        X_normalized, _ = standardize_training_features(np.copy(X_validation[b:b+1, :, :, :, :]), None)
        logits = model.predict(X_normalized, batch_size=1)
        logits = np.squeeze(logits)
        preds[b, :, :, :, :] += logits


    # Drop the references to the large input buffers.
    tmp111 = None
    X_validation = None

    #print(preds.shape)

    # Remove the 5 padded slices again.
    preds = preds[:, :, :, :-5, :]

    #print(preds.shape)

    # Class index with the highest score per voxel.
    preds = np.argmax(preds, axis=-1)
    
    # Map class indices back to BraTS label values.
    if segType == "TC":
        pass
    elif segType == "ET":
        preds[preds == 1] = 4
    elif segType == "ALL":
        preds[preds == 3] = 4
    else:
        raise ValueError('invalid segmentation type.')

    for b in range(preds.shape[0]):
        if np.count_nonzero(preds[b, :, :, :]) == 0:
            #print("all 0")
            preds[b, 0, 0, 0] = 1 # to avoid invalid data errors when submitting
    
    # Checkpoint file name without ".h5", dots replaced by underscores.
    chkpt_folder_name = c[:-3].replace(".", "_").split("/")[-1]
    
    #PREDS_DIR = "loop_preds/predictions5/" 
    path = os.path.join(root_folder, exp_name, "segmentations", chkpt_folder_name)
    
    #samples = sorted(glob.glob('{}/*/'.format(TRAINING_DATA_ROOT_DIR + VAL_DIRS[0])))
    #samples1 = sorted(glob.glob('{}/*/'.format(TRAINING_DATA_ROOT_DIR + TRAIN_DIRS[0])))
    #samples2 = sorted(glob.glob('{}/*/'.format(TRAINING_DATA_ROOT_DIR + TRAIN_DIRS[1])))
    #samples = samples1 + samples2
    #samples = sorted(samples)
    #samples = [x.split('/')[-2] for x in samples ]

    #print("Outputting", c, "to folder", path)
    os.makedirs(path, exist_ok=True)

    #print(samples)

    for i in range(preds.shape[0]):
        s = samples[i]
        # Sample paths end with "/", so the folder name is the second-to-last part.
        s = s.split("/")[-2]
        output = preds[i, :, :, :]
        # Cast the arg-max result to int32 before writing.
        output = output.astype(np.int32)
        filename = os.path.join(path, "{}.nii.gz".format(s))
        # Saved with an identity affine, not the affine of the input volume.
        ni_img = nib.Nifti1Image(output, affine=np.eye(4))
        nib.save(ni_img, filename)
        #print(s, np.count_nonzero(preds[i, :, :, :]))



def do_inference(f, segType):
    """Run inference for the epoch-30 checkpoint(s) of experiment ``f``.

    Checkpoints are the ``.h5`` files in ``<root_folder>/<f>/`` whose path
    starts with ``<root_folder>/<f>/<f>``, excluding ``<f>.h5`` itself and
    (temporarily) everything that does not contain ``".30.h5"``. Each selected
    checkpoint is applied to ``fold_1`` .. ``fold_5``, one fold per call.

    Args:
        f: Experiment name; also the key into ``SEQUENCES_LIST``.
        segType: Segmentation type forwarded to ``do_pred_with_ckpt``.
    """
    ckpt_dir = os.path.join(root_folder, f)
    ckpt_prefix = ckpt_dir + "/" + f
    excluded_path = ckpt_prefix + ".h5"

    # TMP JUST DO 30 EPOCHS
    selected = []
    for candidate in glob.glob(f"{ckpt_dir}/*.h5"):
        if not candidate.startswith(ckpt_prefix):
            continue
        if candidate == excluded_path:
            continue
        if ".30.h5" not in candidate:
            continue
        selected.append(candidate)

    for ckpt in selected:
        print(ckpt)
        # TODO: filter out the fold used for holdout
        for fold_number in range(1, 6):
            do_pred_with_ckpt(f, ckpt, segType, SEQUENCES_LIST[f], ["fold_" + str(fold_number)])


# Run inference for every configured experiment (k: experiment name, v: segmentation type).
for k, v in CHECKPOINTS.items():
    do_inference(k, v)

# LOOP INFERENCE

In [None]:
#validation inference

# Checkpoint filename prefix -> segmentation type passed to do_inference.
# Commented-out entries are earlier experiments kept for reference.
CHECKPOINTS = {
    #"tc_t1ce_60": "TC",
    #"tc_flair_60": "TC",
    #"at_t1ce_60": "ET",
    #"et_all_100": "ET", #not working....
#     "at_flair_60": "ET"
    #"et_t1ce_flair_60": "ET",
    #"tc_t1ce_flair_60": "TC",
    "tc_all_60": "TC",
    "et_all_60": "ET"
}

# Checkpoint filename prefix -> MRI sequences loaded as input channels.
# Keys must match those of CHECKPOINTS.
# NOTE(review): the order here differs from the SEQUENCES constant at the top
# of the notebook; confirm it matches the order each model was trained with.
SEQUENCES_LIST = {
    #"tc_t1ce_60": ["t1ce"],
    #"tc_flair_60": ["flair"],
    #"at_t1ce_60": ["t1ce"],
    #"et_all_100": ["t1ce", "flair", "t2", "t1"],
#     "at_flair_60": ["flair"]
    #"et_t1ce_flair_60": ["flair", "t1ce"],
    #"tc_t1ce_flair_60": ["flair", "t1ce"],
    "tc_all_60": ["flair", "t1ce", "t1", "t2"],
    "et_all_60": ["flair", "t1ce", "t1", "t2"]
}




def read_validation_data(root_path, dirs, seq_list):
    """Load the image volumes (no labels) of every sample under ``dirs``.

    Each immediate sub-folder of ``root_path + d`` is one sample (sorted by
    path); one NIfTI volume is loaded per entry of ``seq_list``.

    Args:
        root_path: Prefix string-concatenated with each entry of ``dirs``.
        dirs: Iterable of directory names below ``root_path``.
        seq_list: Sequence-name suffixes; file ``*<name>.nii.gz`` is loaded.

    Returns:
        The stacked volumes in channels-last layout
        ``[sample, d1, d2, d3, sequence]``. Both shapes (before and after the
        axis move) are printed.

    Raises:
        FileNotFoundError: If a sample folder lacks a file for a sequence.
    """
    # list of numpy arrays
    volumes = []

    for d in dirs:
        path = root_path + d
        samples = sorted(glob.glob('{}/*/'.format(path)))

        for t in samples:
            x_sample = []
            for s in seq_list:
                matches = glob.glob('{}/*{}.nii.gz'.format(t, s)) # expected to be unique
                if not matches:
                    raise FileNotFoundError(
                        "no '*{}.nii.gz' file found in {}".format(s, t))
                x_sample.append(np.array(nilearn.image.get_data(matches[0])))

            volumes.append(np.array(x_sample))
    X = np.array(volumes)
    print(X.shape)
    X = np.moveaxis(X, [0, 1, 2, 3, 4], [0, 4, 1, 2, 3])
    print(X.shape)
    return X

In [None]:
# NOTE(review): the functions defined in this cell write below the hard-coded
# PREDS_DIR and do not read this variable - confirm whether it is still needed.
root_folder = "loop_preds"

def do_pred_with_ckpt(c, segType, seq_list):
    """Run one checkpoint over the validation set and save the predicted masks.

    Steps: load the model, load the volumes from ``VAL_DIRS``, zero-pad the
    third spatial axis by 5 slices, standardize and predict sample by sample,
    drop the padding, take the arg-max class, map class indices back to BraTS
    labels and write one ``<sample>.nii.gz`` per sample under
    ``loop_preds/predictions5/<checkpoint name>/``.

    Args:
        c: Path of the ``.h5`` checkpoint to load.
        segType: ``"TC"``, ``"ET"`` or ``"ALL"``.
        seq_list: Sequence names to load as input channels.

    Raises:
        ValueError: If ``segType`` is not one of the supported values. This is
            now checked before any model or data is loaded.
    """
    if segType not in ("TC", "ET", "ALL"):
        raise ValueError('invalid segmentation type.')

    # Number of zero slices appended along the third spatial axis.
    # presumably this makes the depth divisible by 16 for the four pooling
    # stages of the network - TODO confirm
    z_pad = 5

    model = keras.models.load_model(c, custom_objects={'dice_coef_loss': dice_coef_loss, 'f1': f1})
    X_validation = read_validation_data(TRAINING_DATA_ROOT_DIR, VAL_DIRS, seq_list)

    n_classes = 4 if segType == "ALL" else 2
    n_channels = len(seq_list)
    n_samples, n_x, n_y, n_z = X_validation.shape[:4]

    preds = np.zeros((n_samples, n_x, n_y, n_z + z_pad, n_classes))

    padded = np.zeros((n_samples, n_x, n_y, n_z + z_pad, n_channels))
    padded[:, :, :, :-z_pad, :] = X_validation
    X_validation = None  # release the unpadded copy

    # Predict one sample at a time.
    # NOTE(review): the standardization here includes the zero padding in the
    # mean/std, unlike the training-time standardization - confirm intended.
    for b in range(n_samples):
        X_normalized, _ = standardize_training_features(np.copy(padded[b:b+1, :, :, :, :]), None)
        logits = model.predict(X_normalized, batch_size=1)
        preds[b, :, :, :, :] += np.squeeze(logits)

    padded = None  # release the padded input

    # Remove the padded slices, then pick the best class per voxel.
    preds = np.argmax(preds[:, :, :, :-z_pad, :], axis=-1)

    # Map class indices back to BraTS label values ("TC" needs no remapping).
    if segType == "ET":
        preds[preds == 1] = 4
    elif segType == "ALL":
        preds[preds == 3] = 4

    for b in range(n_samples):
        if np.count_nonzero(preds[b, :, :, :]) == 0:
            preds[b, 0, 0, 0] = 1 # to avoid invalid data errors when submitting

    # Checkpoint name without ".h5", dots replaced by underscores.
    chkpt_folder_name = c[:-3].replace(".", "_")

    PREDS_DIR = "loop_preds/predictions5/" 
    path = os.path.join(PREDS_DIR, chkpt_folder_name)

    # Sample folder names, in the same sorted order used when loading.
    sample_dirs = sorted(glob.glob('{}/*/'.format(TRAINING_DATA_ROOT_DIR + VAL_DIRS[0])))
    samples = [x.split('/')[-2] for x in sample_dirs]

    os.makedirs(path, exist_ok=True)

    for i in range(n_samples):
        # argmax yields int64, which nibabel does not accept for NIfTI output
        # without an explicit dtype; cast to int32.
        output = preds[i, :, :, :].astype(np.int32)
        filename = os.path.join(path, "{}.nii.gz".format(samples[i]))
        # Saved with an identity affine, not the affine of the input volume.
        ni_img = nib.Nifti1Image(output, affine=np.eye(4))
        nib.save(ni_img, filename)



def do_inference(f, segType):
    """Run inference for every intermediate checkpoint of experiment ``f``.

    Checkpoints are the ``.h5`` files in the current working directory whose
    name starts with ``f``, excluding the final model ``<f>.h5``.

    Args:
        f: Checkpoint filename prefix; also the key into ``SEQUENCES_LIST``.
        segType: Segmentation type forwarded to ``do_pred_with_ckpt``.
    """
    final_model = f + ".h5"
    for candidate in list(glob.glob("*.h5")):
        if not candidate.startswith(f) or candidate == final_model:
            continue
        print(candidate)
        do_pred_with_ckpt(candidate, segType, SEQUENCES_LIST[f])




# Run inference for every configured checkpoint prefix (k: prefix, v: segmentation type).
for k, v in CHECKPOINTS.items():
    do_inference(k, v)

# EVERYTHING AFTER IS THE SAME

In [None]:
# Load a single checkpoint; the custom loss and metric must be supplied by name.
model = keras.models.load_model('et_all_100.60.h5', custom_objects={'dice_coef_loss': dice_coef_loss, 'f1': f1})

In [None]:
# Drop the module-level references to the training arrays.
X = None
y = None

In [None]:
#validation inference

def read_validation_data(root_path, dirs):
    """Load the image volumes (no labels) of every sample under ``dirs``.

    Each immediate sub-folder of ``root_path + d`` is one sample (sorted by
    path); one NIfTI volume is loaded per entry of the module-level
    ``SEQUENCES`` list.

    Args:
        root_path: Prefix string-concatenated with each entry of ``dirs``.
        dirs: Iterable of directory names below ``root_path``.

    Returns:
        The stacked volumes in channels-last layout
        ``[sample, d1, d2, d3, sequence]``. Both shapes (before and after the
        axis move) are printed.

    Raises:
        FileNotFoundError: If a sample folder lacks a file for a sequence.
    """
    # list of numpy arrays
    volumes = []

    for d in dirs:
        path = root_path + d
        samples = sorted(glob.glob('{}/*/'.format(path)))

        for t in samples:
            x_sample = []
            for s in SEQUENCES:
                matches = glob.glob('{}/*{}.nii.gz'.format(t, s)) # expected to be unique
                if not matches:
                    raise FileNotFoundError(
                        "no '*{}.nii.gz' file found in {}".format(s, t))
                x_sample.append(np.array(nilearn.image.get_data(matches[0])))

            volumes.append(np.array(x_sample))
    X = np.array(volumes)
    print(X.shape)
    X = np.moveaxis(X, [0, 1, 2, 3, 4], [0, 4, 1, 2, 3])
    print(X.shape)
    return X


# Load every validation volume into memory (channels last).
X_validation = read_validation_data(TRAINING_DATA_ROOT_DIR, VAL_DIRS)

In [None]:
print(X_validation.shape)
# Output buffer, padded by 5 slices along the third spatial axis.
preds = np.zeros((X_validation.shape[0], X_validation.shape[1], X_validation.shape[2], X_validation.shape[3] + 5, NUM_CLASSES))

# Zero-padded copy of the input.
# presumably the padding makes the depth divisible by 16 for the four pooling
# stages of the network - TODO confirm
tmp111 = np.zeros((X_validation.shape[0], X_validation.shape[1], X_validation.shape[2], X_validation.shape[3] + 5, NUM_CHANNELS))

tmp111[:, :, :, :-5, :] = X_validation

X_validation = tmp111

print(X_validation.shape)

# Standardize and predict one sample at a time.
for b in range(0, X_validation.shape[0]):
    print(b)
    X_normalized, _ = standardize_training_features(np.copy(X_validation[b:b+1, :, :, :, :]), None)
    logits = model.predict(X_normalized, batch_size=1)
    logits = np.squeeze(logits)
    preds[b, :, :, :, :] += logits
     

# Drop the references to the large input buffers.
tmp111 = None
X_validation = None
        
print(preds.shape)

# Remove the 5 padded slices again.
preds = preds[:, :, :, :-5, :]

print(preds.shape)

# Class index with the highest score per voxel.
preds = np.argmax(preds, axis=-1)

In [None]:
# Map the predicted class indices back to BraTS label values.
if TYPE_OF_SEGMENTATION == "TC":
    pass
elif TYPE_OF_SEGMENTATION == "ET":
    preds[preds == 1] = 4
elif TYPE_OF_SEGMENTATION == "ALL":
    preds[preds == 3] = 4
else:
    raise ValueError('invalid segmentation type.')

# An entirely empty mask gets a single voxel set to 1.
for b in range(preds.shape[0]):
    if np.count_nonzero(preds[b, :, :, :]) == 0:
        print("all 0")
        preds[b, 0, 0, 0] = 1 # to avoid invalid data errors when submitting

In [None]:
# Write one NIfTI mask per validation sample into PREDS_DIR.
PREDS_DIR = "data/predictions/" 
path = PREDS_DIR
# Sample folders in the same sorted order used when the volumes were loaded.
samples = sorted(glob.glob('{}/*/'.format(TRAINING_DATA_ROOT_DIR + VAL_DIRS[0])))
#samples1 = sorted(glob.glob('{}/*/'.format(TRAINING_DATA_ROOT_DIR + TRAIN_DIRS[0])))
#samples2 = sorted(glob.glob('{}/*/'.format(TRAINING_DATA_ROOT_DIR + TRAIN_DIRS[1])))
#samples = samples1 + samples2
#samples = sorted(samples)
# Paths end with "/", so the folder name is the second-to-last component.
samples = [x.split('/')[-2] for x in samples ]


os.makedirs(path, exist_ok=True)

print(samples)

for i in range(preds.shape[0]):
    s = samples[i]
    # argmax yields int64, which nibabel does not accept for NIfTI output
    # without an explicit dtype; cast to int32 as the experiment-folder
    # inference code does.
    output = preds[i, :, :, :].astype(np.int32)
    filename = os.path.join(path, "{}.nii.gz".format(s))
    # Saved with an identity affine, not the affine of the input volume.
    ni_img = nib.Nifti1Image(output, affine=np.eye(4))
    nib.save(ni_img, filename)
    print(s, np.count_nonzero(preds[i, :, :, :]))