In [1]:
import SimpleITK as sitk
import numpy as np
import os

def normalize_one_volume(volume):
    """Z-score normalize the non-zero voxels of a volume.

    Voxels equal to 0 are treated as background: they are left out of the
    mean / standard-deviation estimate and remain 0 in the result.

    Args:
        volume: numpy array of any shape.

    Returns:
        A new array (numpy's default float dtype) with the shape of `volume`
        in which every non-zero voxel is replaced by ``(value - mean) / std``,
        where mean and std are computed over the non-zero voxels only.
        If there are no non-zero voxels, or they all share a single value
        (std == 0), an all-zero array is returned instead of NaN values.
    """
    new_volume = np.zeros(volume.shape)
    foreground = volume != 0
    if not foreground.any():
        # Nothing to normalize; avoids np.mean / np.std on an empty selection.
        return new_volume

    values = volume[foreground]
    mean = np.mean(values)
    std = np.std(values)
    if std == 0:
        # Constant foreground: every centred value is 0, so skip the 0/0 division.
        return new_volume

    new_volume[foreground] = (values - mean) / std
    return new_volume

def merge_volumes(*volumes):
    """Stack the given equally-shaped arrays along a new leading axis.

    Returns:
        An array whose index 0 selects the i-th positional argument.
    """
    stacked = np.stack(volumes, axis=0)
    return stacked
    
# Axes of the stacked (modality, depth, height, width) array that are reversed
# for each supported `flip` code. Codes not listed here (e.g. 0) mean "no flip".
_FLIP_AXES = {
    1: (1,),
    2: (2,),
    3: (3,),
    4: (2, 3),
    5: (1, 2, 3),
}


def _pad_depth(volume, desired_depth):
    """Append all-zero slices along axis 0 until `volume` has `desired_depth` slices."""
    missing = desired_depth - volume.shape[0]
    if missing <= 0:
        return volume
    # Size the padding from the array itself rather than assuming 155x240x240.
    padding = np.zeros((missing,) + tuple(volume.shape[1:]))
    return np.concatenate([volume, padding], axis=0)


def _reverse_axes(array, axes):
    """Return a view of `array` with each axis in `axes` reversed."""
    index = [slice(None)] * array.ndim
    for axis in axes:
        index[axis] = slice(None, None, -1)
    return array[tuple(index)]


def get_volume(root, patient, desired_depth,
              desired_height, desired_width,
              normalize_flag, 
              flip=0):
    """Load the four MRI modalities of one patient as a single stacked array.

    Reads ``<root>/<patient>/<patient>_{flair,t2,t1ce,t1}.nii.gz`` with
    SimpleITK, zero-pads each modality along axis 0 up to `desired_depth`,
    optionally normalizes each modality with `normalize_one_volume`, stacks
    them in the order (flair, t2, t1ce, t1) and optionally flips the result.

    Args:
        root: directory containing one sub-directory per patient.
        patient: patient identifier (sub-directory name and file-name prefix).
        desired_depth: minimum number of slices along axis 0 of each modality;
            shorter volumes are padded with zeros at the end, longer ones are
            left unchanged.
        desired_height: accepted for interface compatibility; currently unused.
        desired_width: accepted for interface compatibility; currently unused.
        normalize_flag: when truthy, each modality is normalized separately.
        flip: 0 = no flip; 1, 2, 3 = reverse axis 1, 2 or 3 of the stacked
            4-D array; 4 = reverse axes 2 and 3; 5 = reverse axes 1, 2 and 3.
            Any other value leaves the data unflipped.

    Returns:
        Array of shape (1, 4, depth, height, width). For flip != 0 this is a
        reversed view, so callers should copy it before handing it to code
        that needs contiguous memory.
    """
    channels = []
    # The stacking order below is the channel order of the returned array.
    for modality in ("flair", "t2", "t1ce", "t1"):
        path = os.path.join(root, patient, patient + "_" + modality + ".nii.gz")
        array = sitk.GetArrayFromImage(sitk.ReadImage(path))
        array = _pad_depth(array, desired_depth)
        if normalize_flag:
            # Zero padding does not disturb the statistics: zeros are excluded.
            array = normalize_one_volume(array)
        channels.append(array)

    out = merge_volumes(*channels)

    axes = _FLIP_AXES.get(flip, ())
    if axes:
        out = _reverse_axes(out, axes)

    return np.expand_dims(out, axis=0)

In [2]:
import sys
import numpy as np
# Drop the user-level site-packages directory from the import search path
# (presumably so the imports below are not resolved from there - confirm).
# The membership check keeps the cell re-runnable: list.remove raises
# ValueError when the entry is already gone or never existed on this machine.
_user_site_packages = '/home/sentic/.local/lib/python3.6/site-packages'
if _user_site_packages in sys.path:
    sys.path.remove(_user_site_packages)

import torch
# --- Run configuration ---------------------------------------------------
root = "/home/sentic/MICCAI/data/train/"
n_epochs = 300
batch_size = 1
use_amp = False

# GPU selection. `use_gpu` is the switch the later cells check before moving
# tensors to CUDA, so it has to be settled before any CUDA call is made.
device_id = 0
use_gpu = True
if use_gpu and not torch.cuda.is_available():
    # Degrade to CPU instead of failing inside torch.cuda.set_device().
    print("CUDA is not available - falling back to CPU")
    use_gpu = False

# Enable cuDNN benchmark mode (algorithm auto-tuning).
torch.backends.cudnn.benchmark=True

if use_gpu:
    torch.cuda.set_device(device_id)
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.optim.lr_scheduler import LambdaLR

from torch.utils.data import DataLoader

from model import LargeCascadedModel
from dataset import BraTS
from losses import DiceLoss, DiceLossLoss
from tqdm import tqdm_notebook, tqdm
# Checkpoint files 190 through 199 (inclusive); the inference cell sums the
# predictions of every model listed here.
paths_model = ["./checkpoints/checkpoint_{}.pt".format(index) for index in range(190, 200)]
# Loss object from the project-local `losses` module.
# NOTE(review): none of the cells visible in this notebook read `diceLoss` -
# confirm whether it is still needed for inference.
diceLoss = DiceLossLoss()

In [3]:
# Validation cases of special interest, built from their three-digit case numbers.
# NOTE(review): this cell does not read the list; the post-processing cell defines its own copy.
intresting_patients = [
    'BraTS20_Validation_%03d' % case_number
    for case_number in (67, 68, 69, 72, 83, 77, 76, 74, 85, 87, 88, 89, 91, 92, 99, 103)
]


patients_path = "/home/sentic/MICCAI/data/train/"
masks_dir = "./val_masks_np/"
# np.save does not create directories; make sure the target exists up front so
# a long inference run cannot fail at the very first save.
os.makedirs(masks_dir, exist_ok=True)

# Test-time augmentation: one entry per flip, listing the axes of a
# (channel, depth, height, width) array that are reversed. The entries follow
# the `flip` codes 0..5 of get_volume.
tta_flip_axes = [(), (1,), (2,), (3,), (2, 3), (1, 2, 3)]
# Number of predictions summed per patient; derived rather than hard-coded so
# the average stays correct when the checkpoint list or the flips change.
n_predictions = len(paths_model) * len(tta_flip_axes)
# Load checkpoints onto the device that will actually be used.
checkpoint_device = 'cuda:0' if use_gpu else 'cpu'


def _reverse_axes(array, axes):
    """Return a view of `array` with each axis in `axes` reversed."""
    index = [slice(None)] * array.ndim
    for axis in axes:
        index[axis] = slice(None, None, -1)
    return array[tuple(index)]


for patient_name in tqdm(os.listdir(patients_path)):
    if not patient_name.startswith('BraTS'):
        continue

    # Read and normalize the four modalities once per patient; the flipped
    # variants are cheap views of this array, so the files are not re-read
    # for every (checkpoint, flip) combination.
    base_volume = get_volume(patients_path, patient_name, 160, 240, 240, True)[0]
    output = np.zeros((3, 155, 240, 240))

    for path_model in paths_model:
        model = LargeCascadedModel(inplanes_encoder_1=4, channels_encoder_1=16, num_classes_1=3,
                       inplanes_encoder_2=7, channels_encoder_2=32, num_classes_2=3)
        model.load_state_dict(torch.load(path_model, map_location=checkpoint_device)['state_dict'])
        if use_gpu:
            model = model.to("cuda")

        model.eval()

        with torch.no_grad():
            for axes in tta_flip_axes:
                flipped = _reverse_axes(base_volume, axes)
                # Copy: the reversed view has negative strides; add the batch axis.
                volume = torch.FloatTensor(flipped[np.newaxis].copy())
                if use_gpu:
                    volume = volume.to("cuda")
                _, _, decoded_region3, _ = model(volume)
                decoded_region3 = decoded_region3.detach().cpu().numpy().squeeze()
                # Undo the augmentation so every prediction is in the original orientation.
                decoded_region3 = _reverse_axes(decoded_region3, axes)

                # Drop the slices that were only padding (input padded from 155 to 160).
                output += decoded_region3[:, :155, :, :]

    np_array = output / n_predictions

    # np.save appends ".npy", so the file on disk is "<patient>.np.npy".
    np.save(masks_dir + patient_name + ".np", np_array)

100%|██████████| 371/371 [13:16:08<00:00, 128.76s/it]  


In [11]:
import os
import numpy as np
import SimpleITK as sitk

# Probability cut-offs applied to the averaged ensemble output, one per region
# (wt / tc / et - presumably whole tumor, tumor core, enhancing tumor).
threshold_wt, threshold_tc, threshold_et = 0.7, 0.6, 0.7
# Lower cut-off for the et channel: voxels between this and threshold_et are
# promoted into the tc channel by the loop below.
low_threshold_et = 0.6
# The et region is kept only when it has more voxels than this.
threshold_num_pixels_et = 150

patients_path = "/home/sentic/MICCAI/data/train/"
# Cases that get a distinct marker ("--->") in the per-patient printout.
intresting_patients = [
    'BraTS20_Validation_%03d' % case_number
    for case_number in (67, 68, 69, 72, 83, 77, 76, 74, 85, 87, 88, 89, 91, 92, 99, 103)
]


def _probabilities_to_labels(np_array):
    """Threshold averaged region probabilities into a label map.

    Args:
        np_array: array indexed as [region, ...]; region 2 is where the
            enhanced tumor is (label 4), region 1 is the location of labels
            1 and 4, region 0 is the location of labels 1 + 2 + 4.

    Returns:
        (seg_image, num_pixels_et): a uint8 label map with values in
        {0, 1, 2, 4} shaped like one region of `np_array`, and the number of
        voxels whose region-2 probability is >= threshold_et.
    """
    # Work on copies so the caller's array is left untouched.
    label_1 = np_array[2].copy()  # where the enhanced tumor is
    label_2 = np_array[1].copy()  # location of the 1s and 4s
    label_3 = np_array[0].copy()  # location of 1 + 2 + 4

    location_pixels_et = np.where(label_1 >= threshold_et)
    num_pixels_et = location_pixels_et[0].shape[0]

    if num_pixels_et > threshold_num_pixels_et:
        # More than the minimum number of voxels: keep them as enhanced tumor.
        label_1[location_pixels_et] = 1
    else:
        # Too few voxels: drop the enhanced-tumor label ...
        label_1[location_pixels_et] = 0
        # ... but count those voxels as tumor core.
        label_2[location_pixels_et] = 1

    # Voxels between the low and the high et threshold are promoted to tumor core.
    label_2[(label_1 < threshold_et) & (label_1 >= low_threshold_et)] = 1

    label_1[label_1 < threshold_et] = 0
    # Integer labels: a label map does not need float64 storage.
    seg_image = np.zeros(label_1.shape, dtype=np.uint8)
    seg_image[label_1 != 0] = 4

    label_2[label_2 >= threshold_tc] = 1
    label_2[label_2 < threshold_tc] = 0
    seg_image[(label_2 != 0) & (label_1 == 0)] = 1

    label_3[label_3 >= threshold_wt] = 1
    label_3[label_3 < threshold_wt] = 0
    # NOTE(review): this mask does not exclude label_1, so a voxel already set
    # to 4 whose tumor-core probability is below threshold_tc is overwritten
    # with 2 here. Kept as in the original - confirm that this is intended.
    seg_image[(label_3 != 0) & (label_2 == 0)] = 2

    return seg_image, num_pixels_et


final_masks_dir = "./final_masks"
# Create the output directory up front so writing the masks cannot fail on a missing path.
os.makedirs(final_masks_dir, exist_ok=True)

for patient_name in os.listdir(patients_path):
    if not patient_name.startswith('BraTS'):
        continue

    # The FLAIR volume is read only for its geometry (direction, spacing, origin).
    path_big_volume = os.path.join(patients_path, patient_name, patient_name + "_flair.nii.gz")
    image = sitk.ReadImage(path_big_volume)

    # The inference cell saves "<patient>.np"; np.save appended ".npy".
    np_array = np.load("./val_masks_np/" + patient_name + ".np.npy")

    seg_image, num_pixels_et = _probabilities_to_labels(np_array)

    marker = "--->" if patient_name in intresting_patients else "***->"
    print(patient_name, marker, num_pixels_et)

    out_image = sitk.GetImageFromArray(seg_image)
    out_image.SetDirection(image.GetDirection())
    out_image.SetSpacing(image.GetSpacing())
    out_image.SetOrigin(image.GetOrigin())

    sitk.WriteImage(out_image, os.path.join(final_masks_dir, patient_name + ".nii.gz"))

BraTS20_Training_178 ***-> 45141
BraTS20_Training_202 ***-> 23822
BraTS20_Training_201 ***-> 29321
BraTS20_Training_284 ***-> 2501
BraTS20_Training_300 ***-> 964
BraTS20_Training_189 ***-> 20474
BraTS20_Training_124 ***-> 39207
BraTS20_Training_118 ***-> 7987
BraTS20_Training_283 ***-> 10327
BraTS20_Training_196 ***-> 38012
BraTS20_Training_148 ***-> 16934
BraTS20_Training_141 ***-> 526
BraTS20_Training_130 ***-> 33760
BraTS20_Training_014 ***-> 19049
BraTS20_Training_340 ***-> 29587
BraTS20_Training_215 ***-> 24154
BraTS20_Training_062 ***-> 34123
BraTS20_Training_095 ***-> 17780
BraTS20_Training_057 ***-> 7855
BraTS20_Training_204 ***-> 8566
BraTS20_Training_094 ***-> 934
BraTS20_Training_266 ***-> 126
BraTS20_Training_047 ***-> 31243
BraTS20_Training_149 ***-> 60544
BraTS20_Training_209 ***-> 19756
BraTS20_Training_067 ***-> 34527
BraTS20_Training_282 ***-> 6863
BraTS20_Training_354 ***-> 57606
BraTS20_Training_352 ***-> 8302
BraTS20_Training_342 ***-> 50681
BraTS20_Training_127 ***

BraTS20_Training_220 ***-> 16996
BraTS20_Training_097 ***-> 36051
BraTS20_Training_092 ***-> 30258
BraTS20_Training_200 ***-> 4140
BraTS20_Training_305 ***-> 496
BraTS20_Training_297 ***-> 0
BraTS20_Training_262 ***-> 0
BraTS20_Training_093 ***-> 23133
BraTS20_Training_025 ***-> 17020
BraTS20_Training_261 ***-> 7458
BraTS20_Training_110 ***-> 2454
BraTS20_Training_043 ***-> 2416
BraTS20_Training_078 ***-> 4765
BraTS20_Training_335 ***-> 295
BraTS20_Training_323 ***-> 45
BraTS20_Training_231 ***-> 36132
BraTS20_Training_061 ***-> 6909
BraTS20_Training_252 ***-> 37049
BraTS20_Training_055 ***-> 9316
BraTS20_Training_327 ***-> 14739
BraTS20_Training_269 ***-> 41
BraTS20_Training_082 ***-> 4227
BraTS20_Training_248 ***-> 33368
BraTS20_Training_167 ***-> 17414
BraTS20_Training_245 ***-> 62457
BraTS20_Training_310 ***-> 8
BraTS20_Training_060 ***-> 14873
BraTS20_Training_115 ***-> 13534
BraTS20_Training_239 ***-> 5990
BraTS20_Training_273 ***-> 1043
BraTS20_Training_068 ***-> 13654
BraTS20_T