In [None]:
import numpy as np
from pathlib import Path
import pickle
import nibabel as nib
import matplotlib.pyplot as plt

In [None]:
# Root folder holding one sub-folder per patient (resolved against the working directory).
DATASET_PATH = Path('.')

# Pickled test split.
TEST_DATASET = "../models/clean/data/test.pickle"

# Pickled prediction files, one per model variant.
# NOTE: the 2D files are named "results_*" while the 3D ones are "result_*"; these match files on disk.
RESULT_PATH_2D = "../models/clean/results/pred/results_2d.pickle"
RESULT_PATH_2D_PREPROCESS = "../models/clean/results/pred/results_2d_preprocess.pickle"
RESULT_PATH_3D = "../models/clean/results/pred/result_3d.pickle"
RESULT_PATH_3D_PREPROCESS = "../models/clean/results/pred/result_3d_preprocess.pickle"

In [None]:
# NOTE: pickle can execute arbitrary code on load; only open files you produced yourself.
test_set = pickle.loads(Path(TEST_DATASET).read_bytes())

# Prediction files to compare, in the same order as the labels in `dic`.
all_preds = [
    RESULT_PATH_2D,
    RESULT_PATH_2D_PREPROCESS,
    RESULT_PATH_3D,
    RESULT_PATH_3D_PREPROCESS,
]

In [None]:
def load_datas(all_preds):
    """Load the 'transformed' entry of every pickled result file.

    Args:
        all_preds: iterable of paths to pickle files, each holding a dict
            with a 'transformed' key.

    Returns:
        list with one 'transformed' value per path, in input order.

    Note: pickle can execute arbitrary code; only load files you trust.
    """
    def read_transformed(result_path):
        with open(result_path, 'rb') as handle:
            return pickle.load(handle)['transformed']

    return [read_transformed(result_path) for result_path in all_preds]

def reshape_datas(all_preds):
    """Move the leading axis of every predicted volume to the end and round it.

    Args:
        all_preds: list (one entry per model) of lists of 3-D arrays whose
            first axis is moved last, i.e. axes (a, b, c) -> (b, c, a).

    Returns:
        Nested list with the same structure holding the transposed,
        element-wise rounded arrays.
    """
    return [
        [np.round(np.transpose(volume, (1, 2, 0))) for volume in model_preds]
        for model_preds in all_preds
    ]

In [None]:
# Build the list from the path constants instead of reading `all_preds`, which this
# cell rebinds to arrays: otherwise re-running the cell would try to open arrays.
all_preds = reshape_datas(load_datas([
    RESULT_PATH_2D,
    RESULT_PATH_2D_PREPROCESS,
    RESULT_PATH_3D,
    RESULT_PATH_3D_PREPROCESS,
]))

In [None]:
# This notebook only shows the results in the original image space.
def update_gt(gt):
    """Turn a ground-truth label volume into a boolean lesion mask.

    Values are rounded to the nearest integer. Anything above 1 or below 0
    after rounding is reset to 0, so only voxels that round to exactly 1
    become True.
    (Presumably label 2 marks a non-WMH class in this dataset — TODO confirm.)

    Args:
        gt: numeric array; it is not modified (``round`` returns a copy).

    Returns:
        Boolean array with the same shape as ``gt``.
    """
    gt = gt.round()
    gt[gt > 1] = 0
    gt[gt < 0] = 0
    # `np.bool` was only a deprecated alias of the builtin and was removed in
    # NumPy 1.24, so use `bool` directly.
    return gt.astype(bool)

def get_data(test_set):
    """Load FLAIR, T1 and boolean ground-truth volumes for every test patient.

    Args:
        test_set: dict whose 'patients' entry lists patient folder names
            relative to DATASET_PATH.

    Returns:
        (flairs, t1s, Y_true): three lists aligned with test_set['patients'].
        The images are float32; Y_true holds masks produced by `update_gt`.
    """
    def read_volume(*parts):
        # Equivalent to DATASET_PATH / parts[0] / parts[1] / ...
        return nib.load(DATASET_PATH.joinpath(*parts)).get_fdata(dtype=np.float32)

    patient_ids = test_set['patients']

    flairs = [read_volume(pid, 'pre', 'FLAIR.nii.gz') for pid in patient_ids]
    t1s = [read_volume(pid, 'pre', 'T1.nii.gz') for pid in patient_ids]
    Y_true = [update_gt(read_volume(pid, 'wmh.nii.gz')) for pid in patient_ids]

    return flairs, t1s, Y_true

In [None]:
# One volume per test patient, in test_set['patients'] order.
(flairs, t1s, Y_true) = get_data(test_set)

In [None]:
# Layout used by the display helpers: flair, t1, ground truth, then one entry
# per model prediction (in `all_preds` order).
datas = [flairs, t1s, Y_true] + list(all_preds)

In [None]:
def from_patient_to_slices(datas):
    """Flatten per-patient 3-D volumes into one array of 2-D slices per modality.

    Args:
        datas: list of modalities; each modality is a list of 3-D arrays
            (one per patient), sliced along the last axis.

    Returns:
        List with one array per modality, holding the slices of all patients
        concatenated in patient order. When all slices share a shape, this is
        a regular (n_slices, h, w) array. When shapes differ (volumes of
        different in-plane size), it is a 1-D object array of 2-D slices that
        still supports fancy indexing.
    """
    def patients_to_slices(patients):
        slices = []
        for patient in patients:
            _, _, nb_slices = patient.shape
            slices.extend(patient[:, :, s] for s in range(nb_slices))

        if len({s.shape for s in slices}) <= 1:
            return np.array(slices)

        # NumPy >= 1.24 raises ValueError when building a ragged array with
        # np.array, so store the differently sized slices as objects instead.
        ragged = np.empty(len(slices), dtype=object)
        for i, s in enumerate(slices):
            ragged[i] = s
        return ragged

    return [patients_to_slices(patients) for patients in datas]

def sort_datas(datas):
    """Reorder every slice array by decreasing number of ground-truth lesion voxels.

    Args:
        datas: list of slice-indexed arrays; datas[2] is the ground truth.

    Returns:
        New list with each array permuted by the same order, so that the
        slices with the most non-zero ground-truth voxels come first.
    """
    gt_slices = datas[2]
    wmh_counts = np.array([np.count_nonzero(gt_slice) for gt_slice in gt_slices])
    order = np.argsort(-wmh_counts)  # negate for a decreasing sort
    return list(map(lambda arr: arr[order], datas))

In [None]:
new_data = sort_datas(from_patient_to_slices(datas))

In [None]:
def correct_incorrect(pred, y):
    """Colour the positive pixels of a predicted 2-D mask against the ground truth.

    Args:
        pred: (h, w) prediction mask (0/1 values, float or bool).
        y: (h, w) ground-truth mask, same shape as ``pred``.

    Returns:
        Float array of shape (h, w, 3) (RGB):
          * green (0, 1, 0) where pred == 1 and y == 1 (correct positives),
          * red   (1, 0, 0) where pred - y == 1 (predicted positive, not in GT),
          * black elsewhere. Missed lesions (pred 0, y 1) are not highlighted.
    """
    h, w = pred.shape
    new_image = np.zeros((h, w, 3))
    # Cast to float first: NumPy refuses `bool - bool`, so boolean masks would crash.
    incorrect = np.asarray(pred, dtype=float) - np.asarray(y, dtype=float)
    correct = (pred == 1) & (y == 1)
    # Vectorised masks replace the former per-pixel Python double loop.
    new_image[correct] = (0, 1, 0)  # R, G, B
    new_image[incorrect == 1] = (1, 0, 0)
    return new_image

def plot_helper(data, title, nb_rows, nb_cols, pos):
    """Show ``data`` as an image in cell ``pos`` of an nb_rows x nb_cols subplot grid.

    The subplot is created on the current figure and titled ``title``.
    """
    axis = plt.subplot(nb_rows, nb_cols, pos)
    plt.imshow(data)
    axis.set_title(title)

In [None]:
# Display name of each prediction, keyed by its position in `all_preds`.
# The former extra entry `4: '2d'` was removed: it would have silently labelled
# a fifth model as "2d".
dic = {
    0: '2d',
    1: '2d_preprocess',
    2: '3d',
    3: '3d_preprocess',
}

In [None]:
def display_results(datas, idx, labels=None):
    """Plot one slice: the inputs on the first row, then one row per prediction.

    Args:
        datas: list of slice-indexed arrays ordered as
            [flair, t1, y_true, pred_0, pred_1, ...].
        idx: index of the slice to show (the same index is used in every array).
        labels: optional mapping from prediction position (0, 1, ...) to a
            display name. Defaults to the global `dic`. Positions without a
            label are shown as "pred <j>" instead of raising KeyError.

    Row 0 shows FLAIR and T1. Every following row shows the ground truth,
    the prediction, and the `correct_incorrect` overlay for that prediction.
    """
    if labels is None:
        labels = dic

    nb_pred = len(datas) - 3
    nb_rows = 1 + nb_pred
    nb_cols = 3
    plt.figure(figsize=(8 * nb_cols, 6 * nb_rows))

    flair, t1, y_true = datas[0][idx], datas[1][idx], datas[2][idx]
    plot_helper(flair, 'flair', nb_rows, nb_cols, 1)
    plot_helper(t1, 't1', nb_rows, nb_cols, 2)

    # Row `row` (1-based, after the input row) holds prediction number row - 1.
    for row, pred_slices in enumerate(datas[3:], start=1):
        j = row - 1
        y_pred = pred_slices[idx]
        first_pos = nb_cols * row
        plot_helper(y_true, 'gt', nb_rows, nb_cols, first_pos + 1)
        plot_helper(y_pred, labels.get(j, f'pred {j}'), nb_rows, nb_cols, first_pos + 2)
        plot_helper(correct_incorrect(y_pred, y_true), 'correct/incorrect',
                    nb_rows, nb_cols, first_pos + 3)

    plt.show()

In [None]:
# Slices are sorted by decreasing lesion count, so index 10 is among the lesion-richest slices.
display_results(new_data, idx=10)

In [None]:
# One index per slice (all arrays in new_data share the same slice count).
indices = np.arange(new_data[0].shape[0])

In [None]:
# Seeded generator so the "random" gallery below shows the same slices on
# every Restart & Run All (the global np.random state was unseeded).
SHUFFLE_SEED = 0
rng = np.random.default_rng(SHUFFLE_SEED)
rng.shuffle(indices)

In [None]:
# Show up to NB_RANDOM_SLICES random slices. Slicing `indices` (rather than
# indexing it with range(30)) avoids an IndexError when there are fewer slices.
NB_RANDOM_SLICES = 30
for idx in indices[:NB_RANDOM_SLICES]:
    display_results(new_data, idx)