In [None]:
# imports and constants
import os
import glob
import nibabel as nib
import nilearn
import numpy as np
import tensorflow as tf
import random
import scipy
import shutil
from tensorflow import keras
from keras import Sequential, Model
from keras.layers import Dense, Flatten, Dropout, Reshape, Conv3D, Conv2D, Activation, UpSampling3D, \
                         MaxPooling3D, SpatialDropout3D, BatchNormalization, Conv3DTranspose
from keras.layers import concatenate, Input
# from keras.engine import Model
from keras.optimizers import Adam
from keras import backend as K
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from nilearn import image
from datetime import datetime
from scipy import ndimage
import keras_unet
import imageio


###### MODIFY HPARAMS ###########

# When truthy, run a short smoke test: fewer epochs and a fixed experiment name.
DEBUG = True

# Fold held out from training, or "ALLFOLDS" to train on every fold.
K_FOLD_VAL_NUM = "ALLFOLDS"

# Sequence suffixes used as input channels; NUM_CHANNELS is derived from its length.
SEQUENCES = ["t1ce", "flair", "t1", "t2"]

# Segmentation target; "ALL" selects 4 output classes, any other value selects 2.
TYPE_OF_SEGMENTATION = "ET"

##################################


# Use plain truthiness so this agrees with the `if DEBUG:` check that picks the
# experiment name; `DEBUG == False` treated values such as None as "debug on".
N_EPOCHS = 2 if DEBUG else 60


# NOTE(review): presumably the number of samples used in debug runs; it is not
# referenced in the cells visible here -- confirm before relying on it.
DEBUG_SAMPLES = 20

# tf.test.is_gpu_available() is deprecated; ask for the visible physical GPUs instead.
print(len(tf.config.list_physical_devices('GPU')) > 0)

# TRAINING_DATA_ROOT_DIR = "data/"
# TRAIN_DIRS = ["train/LGG", "train/HGG"]

TRAINING_DATA_ROOT_DIR = "data_folds/"

# Train on every fold except the hold-out one. The comparison is done on strings
# so that both an int hold-out (1) and a string hold-out ("1") exclude the fold;
# previously a string hold-out never matched and the validation fold leaked into
# training. The "ALLFOLDS" sentinel matches nothing, so all five folds are kept.
TRAIN_DIRS = [f"fold_{fold}" for fold in range(1, 6) if str(fold) != str(K_FOLD_VAL_NUM)]

print(TRAIN_DIRS)

VAL_DIRS = ["val"]


# NOTE(review): presumably the file-name suffix of the ground-truth volume; it is
# not referenced in the cells visible here -- confirm.
GROUND_TRUTH_FILE = "seg"



# NOTE(review): presumably the full volume size along each spatial axis -- confirm.
N_X = 240
N_Y = 240
N_Z = 155
NUM_CHANNELS = len(SEQUENCES)
NUM_CLASSES = 4 if TYPE_OF_SEGMENTATION == "ALL" else 2


PATCH_X = 80
PATCH_Y = 80
PATCH_Z = 80


BATCH_SIZE = 2

PERCENTILE_95 = 10.204082 * 2 # 20 pixels

SEQUENCES_JOINED = ",".join(SEQUENCES)

# The experiment name encodes target, sequences, epochs and hold-out fold; the
# inference cells parse the first two "_"-separated fields back out of it.
EXPERIMENT_NAME = f"{TYPE_OF_SEGMENTATION}_{SEQUENCES_JOINED}_{N_EPOCHS}-epochs_fold-{K_FOLD_VAL_NUM}-holdout"

if DEBUG:
    # Debug runs always share one folder so they never clobber real experiments.
    EXPERIMENT_NAME = "DEBUG_EXPERIMENT"


print(EXPERIMENT_NAME)

EXPERIMENT_FOLDER = os.path.join("experiments", EXPERIMENT_NAME)

CKPT_PATH_PREFIX = os.path.join(EXPERIMENT_FOLDER, EXPERIMENT_NAME + "_ckpt")

print(CKPT_PATH_PREFIX)

In [None]:
import sys

import tensorflow.keras
import pandas as pd
import sklearn as sk
import tensorflow as tf

# Log the library versions and GPU visibility of this kernel so a run can be
# reproduced later.
print(f"Tensor Flow Version: {tf.__version__}")
# end="\n\n" emits the blank separator line after the Keras version.
print(f"Keras Version: {tensorflow.keras.__version__}", end="\n\n")
print(f"Python {sys.version}")
print(f"Pandas {pd.__version__}")
print(f"Scikit-Learn {sk.__version__}")
# A non-empty device list means at least one GPU is visible to TensorFlow.
gpu = bool(tf.config.list_physical_devices('GPU'))
print("GPU is", "available" if gpu else "NOT AVAILABLE")

In [None]:
# preprocessing

def _standardize_nii_file(x):
    ret = np.copy(x)
    ret = (ret - x.mean())/x.std() if x.std() > 0 else (ret - x.mean())
    return ret

def standardize_training_features(X, y):
    """Standardize every sample of a 5-D batch in place.

    Each ``X[i]`` is replaced by ``_standardize_nii_file(X[i])``, i.e. the
    statistics are taken over the whole sample, all of its trailing axes
    together. ``X`` is overwritten and returned; ``y`` is passed through
    unchanged.

    Returns:
        The tuple ``(X, y)``.
    """
    n_samples = X.shape[0]
    for idx in range(n_samples):
        sample = X[idx, :, :, :, :]
        # Slice assignment copies the standardized values back into X.
        X[idx, :, :, :, :] = _standardize_nii_file(sample)
    return X, y


In [None]:
# metrics

def recall(y, y_hat):
    """Recall metric: true positives over ground-truth positives.

    Both the product ``y * y_hat`` and ``y`` are clipped to [0, 1] and rounded
    before counting; ``K.epsilon()`` guards against an empty denominator.
    """
    hit_mask = K.round(K.clip(y * y_hat, 0, 1))
    actual_mask = K.round(K.clip(y, 0, 1))
    n_hits = K.sum(hit_mask)
    n_actual = K.sum(actual_mask)
    return n_hits/(n_actual + K.epsilon())

def precision(y, y_hat):
    """Precision metric: true positives over predicted positives.

    Both the product ``y * y_hat`` and ``y_hat`` are clipped to [0, 1] and
    rounded before counting; ``K.epsilon()`` guards against an empty denominator.
    """
    hit_mask = K.round(K.clip(y * y_hat, 0, 1))
    predicted_mask = K.round(K.clip(y_hat, 0, 1))
    n_hits = K.sum(hit_mask)
    n_predicted = K.sum(predicted_mask)
    return n_hits/(n_predicted + K.epsilon())

def f1(y, y_hat):
    """F1 metric: harmonic mean of ``precision`` and ``recall``.

    ``K.epsilon()`` in the denominator keeps the value finite when both
    precision and recall are zero.
    """
    rec = recall(y, y_hat)
    prec = precision(y, y_hat)
    ratio = (prec*rec)/(prec + rec + K.epsilon())
    return 2*ratio


def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient, reduced over axes 0-3.

    Computes ``(2*sum(|t*p|) + smooth) / (sum(t^2) + sum(p^2) + smooth)`` with
    every sum taken over axes ``(0, 1, 2, 3)``; any remaining axis is kept.
    Assumes the inputs have at least four axes, e.g. (batch, x, y, z, class)
    -- TODO confirm the layout against the model output.
    """
    reduce_axes = (0, 1, 2, 3)
    overlap = K.sum(K.abs(y_true * y_pred), axis=reduce_axes)
    true_sq = K.sum(K.square(y_true), axis=reduce_axes)
    pred_sq = K.sum(K.square(y_pred), axis=reduce_axes)
    return (2. * overlap + smooth)/(true_sq + pred_sq + smooth)


def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus the mean of the values returned by ``dice_coef``."""
    mean_dice = K.mean(dice_coef(y_true, y_pred))
    return 1 - mean_dice


# Statistics Inference

In [None]:
#validation inference

# Experiments to run inference for. Each key is an experiment folder name under
# `root_folder`. The key is split on "_": the first field is used as the
# segmentation type by the driver loop, and the second field, split on ",", is used
# as the sequence list by do_inference. The values are bound by the driver loop but
# never read in the cells visible here. Commented-out entries are earlier
# experiments kept for reference.
CHECKPOINTS = {
#     "TC_flair_60-epochs_fold-1-holdout": "TC",
#     "TC_t1ce_60-epochs_fold-1-holdout": "TC",
#     "TC_flair_60-epochs_fold-2-holdout": "NotNeeded",
#     "TC_flair_60-epochs_fold-3-holdout": "NotNeeded",
#     "TC_flair_60-epochs_fold-4-holdout": "NotNeeded",
#     "TC_flair_60-epochs_fold-5-holdout": "NotNeeded",
#     "TC_t1ce_60-epochs_fold-2-holdout": "NotNeeded",
#     "TC_t1ce_60-epochs_fold-3-holdout": "NotNeeded",
#     "TC_t1ce_60-epochs_fold-4-holdout": "NotNeeded",
#     "TC_t1ce_60-epochs_fold-5-holdout": "NotNeeded",
#     "TC_t1ce,flair_60-epochs_fold-1-holdout": "NotNeeded",
#     "TC_t1ce,flair_60-epochs_fold-2-holdout": "NotNeeded",
#     "TC_t1ce,flair_60-epochs_fold-3-holdout": "NotNeeded",
#     "TC_t1ce,flair_60-epochs_fold-4-holdout": "NotNeeded",
#     "TC_t1ce,flair_60-epochs_fold-5-holdout": "NotNeeded",
#     "TC_t1ce,flair,t1,t2_60-epochs_fold-1-holdout": "NotNeeded",
#     "TC_t1ce,flair,t1,t2_60-epochs_fold-2-holdout": "NotNeeded",
#     "TC_t1ce,flair,t1,t2_60-epochs_fold-3-holdout": "NotNeeded",
#     "TC_t1ce,flair,t1,t2_60-epochs_fold-4-holdout": "NotNeeded",
#     "TC_t1ce,flair,t1,t2_60-epochs_fold-5-holdout": "NotNeeded",
#     "ET_flair_60-epochs_fold-1-holdout": "NotNeeded",
#     "ET_t1ce_60-epochs_fold-1-holdout": "NotNeeded",
#     "ET_flair_60-epochs_fold-2-holdout": "NotNeeded",
#     "ET_flair_60-epochs_fold-3-holdout": "NotNeeded",
#     "ET_flair_60-epochs_fold-4-holdout": "NotNeeded",
#     "ET_flair_60-epochs_fold-5-holdout": "NotNeeded",
#     "ET_t1ce_60-epochs_fold-2-holdout": "NotNeeded",
#     "ET_t1ce_60-epochs_fold-3-holdout": "NotNeeded",
#     "ET_t1ce_60-epochs_fold-4-holdout": "NotNeeded",
#     "ET_t1ce_60-epochs_fold-5-holdout": "NotNeeded",
#     "ET_t1ce,flair,t1,t2_60-epochs_fold-1-holdout": "NotNeeded",
#     "ET_t1ce,flair,t1,t2_60-epochs_fold-2-holdout": "NotNeeded",
#     "ET_t1ce,flair,t1,t2_60-epochs_fold-3-holdout": "NotNeeded",
#     "ET_t1ce,flair,t1,t2_60-epochs_fold-4-holdout": "NotNeeded",
#     "ET_t1ce,flair,t1,t2_60-epochs_fold-5-holdout": "NotNeeded",
#     "ET_t1ce,flair_60-epochs_fold-1-holdout": "NotNeeded",
#     "ET_t1ce,flair_60-epochs_fold-2-holdout": "NotNeeded",
#     "ET_t1ce,flair_60-epochs_fold-3-holdout": "NotNeeded",
#     "ET_t1ce,flair_60-epochs_fold-4-holdout": "NotNeeded",
#     "ET_t1ce,flair_60-epochs_fold-5-holdout": "NotNeeded",
      "ET_t1ce,flair,t1,t2_60-epochs_fold-ALLFOLDS-holdout": "NotNeeded",
}

# SEQUENCES_LIST = {
#     "TC_t1ce_60-epochs_fold-1-holdout": ["t1ce"],
#     "TC_flair_60-epochs_fold-1-holdout": ["flair"],
# }




def read_validation_data(root_path, dirs, seq_list):
    """Load every sample found under ``dirs`` into one channels-last array.

    Each immediate sub-directory of ``root_path + d`` (for ``d`` in ``dirs``) is
    treated as one sample. For every entry ``s`` of ``seq_list`` the file
    matching ``*{s}.nii.gz`` inside the sample directory is read with
    ``nilearn.image.get_data``; the volumes of one sample are stacked as channels.

    Args:
        root_path: Path prefix. It is concatenated with each entry of ``dirs``
            as-is, so it normally ends with a path separator.
        dirs: Directory names, relative to ``root_path``, to scan for samples.
        seq_list: Sequence suffixes to load, in the desired channel order.

    Returns:
        ``(sample_names, X)``. ``sample_names`` lists the sample directory paths
        in load order (sorted within each entry of ``dirs``). ``X`` has shape
        ``(n_samples, *volume_shape, len(seq_list))``.

    Raises:
        FileNotFoundError: If no sample directory is found at all, or if a
            sample directory has no file for one of the requested sequences.
    """
    volumes = []
    sample_names = []

    for d in dirs:
        path = root_path + d
        for sample_dir in sorted(glob.glob('{}/*/'.format(path))):
            channels = []
            for s in seq_list:
                pattern = '{}/*{}.nii.gz'.format(sample_dir, s)
                matches = glob.glob(pattern)
                if not matches:
                    # Fail with the offending pattern instead of a bare IndexError.
                    raise FileNotFoundError(
                        "no file matching {!r} for sequence {!r}".format(pattern, s)
                    )
                # The suffix is expected to identify a single file per sample.
                channels.append(np.array(nilearn.image.get_data(matches[0])))

            volumes.append(np.array(channels))
            sample_names.append(sample_dir)

    if not volumes:
        # Without this guard the axis move below fails with an opaque AxisError.
        raise FileNotFoundError(
            "no sample directories found under {!r} for dirs {!r}".format(root_path, list(dirs))
        )

    X = np.array(volumes)
    print(X.shape)
    # (sample, channel, ...) -> (sample, ..., channel)
    X = np.moveaxis(X, 1, -1)
    print(X.shape)
    return sample_names, X

In [None]:
import gc

# Directory holding one sub-folder per experiment: do_inference looks for .h5
# checkpoints in it and do_pred_with_ckpt writes segmentations beneath it.
root_folder = "experiments"

def do_pred_with_ckpt(exp_name, c, segType, seq_list, train_folders, z_pad=5):
    """Run one checkpoint over a data split and save one NIfTI label map per sample.

    Args:
        exp_name: Experiment folder name under ``root_folder``; the output
            folder is created beneath it.
        c: Path of the ``.h5`` checkpoint to load.
        segType: ``"TC"``, ``"ET"`` or ``"ALL"``. Selects the expected number of
            classes (4 for ``"ALL"``, otherwise 2) and the label remapping:
            class 1 -> 4 for ``"ET"``, class 3 -> 4 for ``"ALL"``, none for ``"TC"``.
        seq_list: Sequence suffixes to load as input channels.
        train_folders: ``["val"]`` reads ``data/val/val`` and writes to
            ``val_segmentations``; any other list is read relative to
            ``TRAINING_DATA_ROOT_DIR`` and written to ``segmentations_fixed``.
        z_pad: Number of zero slices appended along the third spatial axis
            before prediction and cropped from the prediction afterwards.
            NOTE(review): presumably needed so the depth suits the network --
            confirm against the model definition.

    Raises:
        ValueError: If ``segType`` is not supported (checked before any model
            or data is loaded), or if the model output does not have the
            expected shape.
    """
    # Fail fast: this used to be detected only after predicting every sample.
    if segType not in ("TC", "ET", "ALL"):
        raise ValueError('invalid segmentation type.')

    model = keras.models.load_model(c, custom_objects={'dice_coef_loss': dice_coef_loss, 'f1': f1})

    if train_folders != ["val"]:
        samples, X_validation = read_validation_data(TRAINING_DATA_ROOT_DIR, train_folders, seq_list)
    else:
        samples, X_validation = read_validation_data("data/", ["val/val"], seq_list)

    print("X_validation shape: ", X_validation.shape)

    num_classes = 4 if segType == "ALL" else 2
    num_channels = len(seq_list)
    n_samples, n_x, n_y, n_z = X_validation.shape[:4]
    expected_logits_shape = (n_x, n_y, n_z + z_pad, num_classes)

    # Only the class index per voxel is kept, so pad, predict and argmax one
    # sample at a time instead of holding float64 logits and a padded copy of
    # the whole data set in memory.
    preds = np.zeros((n_samples, n_x, n_y, n_z), dtype=np.intp)

    for b in range(n_samples):
        K.clear_session()
        _ = gc.collect()

        padded = np.zeros((1, n_x, n_y, n_z + z_pad, num_channels))
        padded[:, :, :, :n_z, :] = X_validation[b:b+1, :, :, :, :]
        # Standardization runs on the padded volume, padding slices included.
        X_normalized, _ = standardize_training_features(padded, None)

        logits = np.squeeze(model.predict(X_normalized, batch_size=1, verbose=0))
        if logits.shape != expected_logits_shape:
            raise ValueError(
                "model output shape {} does not match expected {}".format(
                    logits.shape, expected_logits_shape
                )
            )
        # Drop the padding slices, then pick the most likely class per voxel.
        preds[b, :, :, :] = np.argmax(logits[:, :, :n_z, :], axis=-1)

    # Release the input volumes before writing the outputs.
    X_validation = None

    if segType == "ET":
        preds[preds == 1] = 4
    elif segType == "ALL":
        preds[preds == 3] = 4
    # "TC" keeps the raw class indices.

    for b in range(n_samples):
        if np.count_nonzero(preds[b, :, :, :]) == 0:
            preds[b, 0, 0, 0] = 1 # to avoid invalid data errors when submitting

    # Checkpoint file name without ".h5", dots replaced by "_".
    chkpt_folder_name = c[:-3].replace(".", "_").split("/")[-1]

    if train_folders == ["val"]:
        chkpt_folder_name += "_val"
        segmentations_subpath = "val_segmentations"
    else:
        segmentations_subpath = "segmentations_fixed"

    path = os.path.join(root_folder, exp_name, segmentations_subpath, chkpt_folder_name)
    os.makedirs(path, exist_ok=True)

    for i in range(n_samples):
        # Sample paths end with "/", so the directory name is the second-to-last field.
        s = samples[i].split("/")[-2]
        output = preds[i, :, :, :].astype(np.int32)
        filename = os.path.join(path, "{}.nii.gz".format(s))
        # NOTE(review): saved with an identity affine rather than the source
        # image's affine -- confirm the consumer does not rely on orientation.
        ni_img = nib.Nifti1Image(output, affine=np.eye(4))
        nib.save(ni_img, filename)



def do_inference(f, segType):
    """Predict the "val" split with the checkpoints of experiment ``f``.

    Checkpoints are the ``*.h5`` files in ``root_folder/f`` whose path starts
    with ``root_folder/f/f`` and does not contain ``"ckpt.h5"``. For each one,
    ``do_pred_with_ckpt`` is called with the sequence list parsed from the
    second "_"-separated field of ``f`` (split on ",").

    Args:
        f: Experiment name; also the folder name under ``root_folder``.
        segType: Segmentation type forwarded to ``do_pred_with_ckpt``.
    """
    chkpt_folder = os.path.join(root_folder, f)
    print(chkpt_folder)

    name_prefix = chkpt_folder + "/" + f
    chkpts = []
    for candidate in glob.glob(f"{chkpt_folder}/*.h5"):
        if "ckpt.h5" in candidate:
            continue
        if candidate.startswith(name_prefix):
            chkpts.append(candidate)

    for c in chkpts:
        print(c)
        # TODO: also predict the training folds, filtering out the hold-out fold.
        for fold in ["val"]:
            # Reset Keras state and collect garbage before each checkpoint run.
            K.clear_session()
            _ = gc.collect()
            print(f, fold)
            seq_list = f.split("_")[1].split(",")
            do_pred_with_ckpt(f, c, segType, seq_list, [fold])


# Run inference for every configured experiment. The segmentation type is the
# part of the key before the first underscore; the dict value is not used.
for k, v in CHECKPOINTS.items():
    segType = k.partition("_")[0]
    do_inference(k, segType)