In [1]:
import SimpleITK as sitk
import numpy as np
import os

def normalize_one_volume(volume):
    new_volume = np.zeros(volume.shape)
    location = np.where(volume != 0)
    mean = np.mean(volume[location])
    var = np.std(volume[location])
    new_volume[location] = (volume[location] - mean) / var

    return new_volume

def merge_volumes(*volumes):
    return np.stack(volumes, axis=0)
    
def get_volume(root, patient, desired_depth,
              desired_height, desired_width,
              normalize_flag, 
              flip=0):
    flair_suffix = "_flair.nii.gz"
    t1_suffix = "_t1.nii.gz"
    t1ce_suffix = "_t1ce.nii.gz"
    t2_suffix = "_t2.nii.gz"
    
    path_flair = os.path.join(root, patient, patient + flair_suffix)
    path_t1 = os.path.join(root, patient, patient + t1_suffix)
    path_t2 = os.path.join(root, patient, patient + t2_suffix)
    path_t1ce = os.path.join(root, patient, patient + t1ce_suffix)
    
    flair = sitk.GetArrayFromImage(sitk.ReadImage(path_flair))
    t1 = sitk.GetArrayFromImage(sitk.ReadImage(path_t1))
    t2 = sitk.GetArrayFromImage(sitk.ReadImage(path_t2))
    t1ce = sitk.GetArrayFromImage(sitk.ReadImage(path_t1ce))

    if desired_depth > 155:
        flair = np.concatenate([flair, np.zeros((desired_depth - 155, 240, 240))], axis=0)
        t1 = np.concatenate([t1, np.zeros((desired_depth - 155, 240, 240))], axis=0)
        t2 = np.concatenate([t2, np.zeros((desired_depth - 155, 240, 240))], axis=0)
        t1ce = np.concatenate([t1ce, np.zeros((desired_depth - 155, 240, 240))], axis=0)

    if normalize_flag == True:
        out = merge_volumes(normalize_one_volume(flair), normalize_one_volume(t2), normalize_one_volume(t1ce), 
                            normalize_one_volume(t1))
    else:
        out = merge_volumes(flair, t2, t1ce, t1)
    
    if flip == 1:
        out = out[:, ::-1, :, :]
    elif flip == 2:
        out = out[:, :, ::-1, :]
    elif flip == 3:
        out = out[:, :, :, ::-1]
    elif flip == 4:
        out = out[:, :, ::-1, ::-1]
    elif flip == 5:
        out = out[:, ::-1, ::-1, ::-1]
    
    return np.expand_dims(out, axis=0)

In [2]:
import sys
import numpy as np
sys.path.remove('/home/sentic/.local/lib/python3.6/site-packages')

import torch
torch.backends.cudnn.benchmark=True

device_id = 1
torch.cuda.set_device(device_id)

root = "/home/sentic/MICCAI/data/val/"
use_gpu = True
n_epochs = 300
batch_size = 1
use_amp = False
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.optim.lr_scheduler import LambdaLR

from torch.utils.data import DataLoader

from model import LargeCascadedModel
from dataset import BraTS
from losses import DiceLoss, DiceLossLoss
from tqdm import tqdm_notebook, tqdm
path = "./checkpoints/checkpoint_69.pt"

model = LargeCascadedModel(inplanes_encoder_1=4, channels_encoder_1=16, num_classes_1=3,
                           inplanes_encoder_2=7, channels_encoder_2=32, num_classes_2=3)
model.load_state_dict(torch.load(path, map_location='cuda:1')['state_dict'])
if use_gpu:
    model = model.to("cuda")

model.eval()
diceLoss = DiceLossLoss()

In [3]:
intresting_patients = ['BraTS20_Validation_067', 'BraTS20_Validation_068', 'BraTS20_Validation_069', 'BraTS20_Validation_072',
                      'BraTS20_Validation_083', 'BraTS20_Validation_077', 'BraTS20_Validation_076', 'BraTS20_Validation_074',
                      'BraTS20_Validation_085', 'BraTS20_Validation_087', 'BraTS20_Validation_088', 'BraTS20_Validation_089',
                      'BraTS20_Validation_091', 'BraTS20_Validation_092', 'BraTS20_Validation_099', 'BraTS20_Validation_103']


threshold_wt = 0.7
threshold_tc = 0.55
threshold_et = 0.7
low_threshold_et = 0.6
threshold_num_pixels_et = 150

patients_path = "/home/sentic/MICCAI/data/val/"
for patient_name in tqdm(os.listdir(patients_path)):
    if patient_name.startswith('BraTS'):
        output = np.zeros((3, 155, 240, 240))

        with torch.no_grad():
            for flip in range(0, 6):
                volume = get_volume(patients_path, patient_name, 160, 240, 240, True, flip)
                volume = torch.FloatTensor(volume.copy())
                if use_gpu:
                    volume = volume.to("cuda")
                _, _, decoded_region3, _ = model(volume)
                decoded_region3 = decoded_region3.detach().cpu().numpy()
                decoded_region3 = decoded_region3.squeeze()
                if flip == 1:
                    decoded_region3 = decoded_region3[:, ::-1, :, :]
                elif flip == 2:
                    decoded_region3 = decoded_region3[:, :, ::-1, :]
                elif flip == 3:
                    decoded_region3 = decoded_region3[:, :, :, ::-1]
                elif flip == 4:
                    decoded_region3 = decoded_region3[:, :, ::-1, ::-1]
                elif flip == 5:
                    decoded_region3 = decoded_region3[:, ::-1, ::-1, ::-1]

                output += decoded_region3[:, :155, :, :]

        np_array = output
        np_array = np_array / 6.0

        np.save("./val_masks_np/" + patient_name + ".np", np_array)

100%|██████████| 129/129 [27:30<00:00, 12.80s/it]


In [4]:
for patient_name in os.listdir(patients_path):
    if patient_name.startswith('BraTS'):
        path_big_volume = os.path.join(patients_path, patient_name, patient_name + "_flair.nii.gz")
        np_array = np.load("./val_masks_np/" + patient_name + ".np.npy")
        image = sitk.ReadImage(path_big_volume)
        direction = image.GetDirection()
        spacing = image.GetSpacing()
        origin = image.GetOrigin()

        seg_image = np.zeros((155, 240, 240))
        label_1 = np_array[2, :, :, :] # where the enhanced tumor is
        location_pixels_et = np.where(label_1 > threshold_et)
        num_pixels_et = location_pixels_et[0].shape[0]
        label_2 = np_array[1, :, :, :] # locatia 1-urilor si 4-urilor
        label_3 = np_array[0, :, :, :] # locatia 1 + 2 + 4

        if patient_name in intresting_patients:
            print(patient_name, "--->", num_pixels_et)
        else:
            print(patient_name, "***->", num_pixels_et)

        if num_pixels_et > threshold_num_pixels_et: # if there are at least num of pixels
            label_1[label_1 > threshold_et] = 1 # put them in et category
        else:
            label_1[label_1 > threshold_et] = 0 # don't put them
            label_2[location_pixels_et] = 1 # but put them on tumor core

        label_2[(label_1 < threshold_et) & (label_1 > low_threshold_et)] = 1

        label_1[label_1 < threshold_et] = 0
        location_1 = np.where(label_1 != 0)
        seg_image[location_1] = 4

        label_2[label_2 > threshold_tc] = 1
        label_2[label_2 < threshold_tc] = 0
        location_2 = np.where((label_2 != 0) & (label_1 == 0))
        seg_image[location_2] = 1


        label_3[label_3 > threshold_wt] = 1
        label_3[label_3 < threshold_wt] = 0
        location_3 = np.where((label_3 != 0) & (label_2 == 0))
        seg_image[location_3] = 2

        out_image = sitk.GetImageFromArray(seg_image)
        out_image.SetDirection(direction)
        out_image.SetSpacing(spacing)
        out_image.SetOrigin(origin)

        sitk.WriteImage(out_image, os.path.join("./final_masks", patient_name + ".nii.gz"))

BraTS20_Validation_075 ***-> 554
BraTS20_Validation_046 ***-> 18696
BraTS20_Validation_026 ***-> 2889
BraTS20_Validation_114 ***-> 10160
BraTS20_Validation_102 ***-> 3470
BraTS20_Validation_055 ***-> 27786
BraTS20_Validation_107 ***-> 366
BraTS20_Validation_044 ***-> 34019
BraTS20_Validation_101 ***-> 7543
BraTS20_Validation_125 ***-> 10513
BraTS20_Validation_066 ***-> 9464
BraTS20_Validation_010 ***-> 34874
BraTS20_Validation_116 ***-> 15515
BraTS20_Validation_083 ---> 468
BraTS20_Validation_008 ***-> 41347
BraTS20_Validation_094 ***-> 20040
BraTS20_Validation_105 ***-> 102265
BraTS20_Validation_034 ***-> 68103
BraTS20_Validation_097 ***-> 0
BraTS20_Validation_072 ---> 567
BraTS20_Validation_069 ---> 1038
BraTS20_Validation_007 ***-> 26728
BraTS20_Validation_063 ***-> 41096
BraTS20_Validation_012 ***-> 5329
BraTS20_Validation_001 ***-> 18940
BraTS20_Validation_093 ***-> 12558
BraTS20_Validation_002 ***-> 8283
BraTS20_Validation_115 ***-> 23772
BraTS20_Validation_009 ***-> 3121
BraTS20