In [None]:
import glob
import os
import numpy as np
import nilearn
from nilearn import image
import nibabel as nib
from tensorflow import keras
from keras import backend as K
from collections import defaultdict
import json
from matplotlib import pyplot as plt

# Path templates; the "{}" placeholders are filled in with str.format.
# SEGMENTATION_PATH takes an experiment name, GROUND_TRUTH_PATH takes the
# sample name twice (folder name and file-name prefix).
SEGMENTATION_PATH = "experiments/{}/segmentations_fixed"
GROUND_TRUTH_PATH = "data/combined/{}/{}_seg.nii.gz"


def get_segmentation_folder(exp_name):
    """Return the segmentation folder of ``exp_name`` built from SEGMENTATION_PATH."""
    folder = SEGMENTATION_PATH.format(exp_name)
    return folder


# Segmentation target evaluated by this notebook; _convert_labels accepts
# "TC", "ET" and "ALL".
TYPE_OF_SEGMENTATION = "TC"
OUT_DIR = f"results/{TYPE_OF_SEGMENTATION.lower()}"

# JSON file that receives the aggregated scores.
OUT_FILE = os.path.join(OUT_DIR, "{}_results.json".format(TYPE_OF_SEGMENTATION))



In [None]:
def dice_coefficient(truth, prediction, eps=0.01):
    """Return the smoothed Dice score of two masks.

    Computes ``(2 * sum(truth * prediction) + eps) / (sum(truth) + sum(prediction) + eps)``.

    Args:
        truth: Ground-truth mask (array supporting element-wise ``*``).
        prediction: Predicted mask, multiplied element-wise with ``truth``.
        eps: Smoothing term added to numerator and denominator; with the
            default it makes two all-zero masks score 1 instead of dividing
            by zero.

    Returns:
        The Dice score as a NumPy scalar.
    """
    overlap = np.sum(truth * prediction)
    mask_total = np.sum(truth) + np.sum(prediction)
    return (2 * overlap + eps) / (mask_total + eps)



In [None]:
def _convert_labels(y):
    # This section of code modifies the labels in the training dataset.
    # The BraTS dataset provides 4 different labels (0, 1, 2, 4) on a pixel level
    # But we don't care about this level of detail. We only want to identify whether
    # a pixel is tumorous (0/1 classification).
    # In particular, we only care if it is an "enhancing tumor structure"
    # This is the light blue section here: https://www.med.upenn.edu/sbia/brats2018.html
    # From here (https://arxiv.org/pdf/1811.02629.pdf), we see that the labels correspond to the following
    # For BraTS 2017 and above (note Label 3 has been combined with Label 1):
    # Label 1 (+ 3): NCR -- necrotic core, and NET -- Non enhancing tumor
    # Label 2: ED -- Edema
    # Label 4: AT -- Enhancing regions within the gross tumor abnormality
    # Thus, we only care about Label 4. We should therefore set labels 1 and 2 to label 0, and then
    # set label 4 to 1 to achieve what we want.
    
    # Note: 3 is not present in this dataset
    
    # Update: to segment TC, set 2 -> 0, 4 -> 1. 
    # to segment ET, set 1->0, 2->0, 4->1
    
    
    if TYPE_OF_SEGMENTATION == "TC":
        y[y == 2] = 0
        y[y == 4] = 1
    elif TYPE_OF_SEGMENTATION == "ET":
        y[y == 1] = 0
        y[y == 2] = 0
        y[y == 4] = 1
    elif TYPE_OF_SEGMENTATION == "ALL":
        y[y == 4] = 3
    else:
        raise ValueError('invalid segmentation type.')
        
        
    return y

In [None]:
# Evaluate every checkpoint of every experiment whose folder name starts with
# TYPE_OF_SEGMENTATION and collect the mean / median Dice score per
# (segs, epoch, fold) in `results`.
# Sorting makes the processing order, and therefore the key order of the
# JSON written later, independent of the file-system listing order.
all_experiments = sorted(glob.glob("experiments/*/"))

results = {}


for exp in all_experiments:
    # normpath drops the trailing separator produced by the "*/" pattern, so
    # basename yields the experiment folder name regardless of the separator.
    exp_name = os.path.basename(os.path.normpath(exp))
    if "DEBUG_EXPERIMENT" in exp or not exp_name.startswith(TYPE_OF_SEGMENTATION):
        continue

    print(exp)
    segmentation_path = get_segmentation_folder(exp_name)
    assert os.path.isdir(segmentation_path), \
        "missing segmentation folder: {}".format(segmentation_path)

    ckpt_segs = sorted(glob.glob('{}/*/'.format(segmentation_path)))
    for ckpt_seg in ckpt_segs:

        print(ckpt_seg)
        # Any checkpoint folder whose path contains "val" is not evaluated.
        if "val" in ckpt_seg:
            continue

        seg_files = sorted(glob.glob('{}/*'.format(ckpt_seg)))
        if not seg_files:
            # Nothing to score; the mean would otherwise divide by zero.
            print("no segmentation files found, skipping")
            continue

        # `scores` and `files` are parallel lists: files[k] is the sample
        # name that produced scores[k].
        scores = []
        files = []

        for segment_path in seg_files:
            # Sample name = file name up to the first dot; it selects the
            # matching ground-truth volume.
            train_sample_name = os.path.basename(segment_path).split('.')[0]
            train_sample_ground_truth_path = GROUND_TRUTH_PATH.format(
                train_sample_name, train_sample_name)

            # np.array(...) copies the data because _convert_labels edits its
            # argument in place.
            seg_mask = _convert_labels(np.array(nilearn.image.get_data(segment_path)))
            gt_mask = _convert_labels(
                np.array(nilearn.image.get_data(train_sample_ground_truth_path)))

            scores.append(dice_coefficient(gt_mask, seg_mask))
            files.append(train_sample_name)

        print(len(scores))
        mean = float(sum(scores) / len(scores))
        # np.median averages the two middle values for an even count; the
        # previous hand-written version read indices n//2 and n//2 + 1, which
        # is off by one and raises IndexError for two samples.
        median = float(np.median(scores))

        # The checkpoint folder name is split on "_" and read positionally:
        # field 0 = segmentation type, 1 = segs, 3 = fold, 5 = checkpoint id.
        # NOTE(review): presumably "<TYPE>_<segs>_fold_<k>_ckpt_<n>" -- confirm
        # against the code that writes these folders.
        info_string = os.path.basename(os.path.normpath(ckpt_seg))
        info_list = info_string.split("_")
        seg_type = info_list[0]
        segs = info_list[1]
        fold = info_list[3]
        epoch = "ckpt_" + info_list[5]

        assert seg_type == TYPE_OF_SEGMENTATION

        fold_results = results.setdefault(segs, {}).setdefault(epoch, {})
        fold_results[fold] = {"mean": mean, "median": median}

        print(segs, epoch, fold, results[segs][epoch][fold])

    print()

# Persist the aggregated scores as JSON. The output folder is created on
# demand so that a fresh checkout without OUT_DIR does not fail on open().
os.makedirs(OUT_DIR, exist_ok=True)

with open(OUT_FILE, 'w') as f:
    json.dump(results, f)