In [1]:
import argparse
import logging
import sys

import monai
import torch
from torch.utils.tensorboard import SummaryWriter
from monai.utils import set_determinism
import numpy as np
from glob import glob
import os

from monai.transforms.transform import Transform
from monai.transforms import Activations, EnsureType, Compose
from transforms import CropMRId, BinaryMask, GetLargestComponent, GetLabelsAsOneHotd, Shaped, InverseOneHot
from torch import nn
from transforms_dict import getSegmentationEvalTransformsForMRI
from transforms_dict import getSegmentationInverseTransformForLabels

from seg_data import getDataset
from seg_model import getUNetForSegmentation
from utils import compute_mean_dice, getReducePlateauScheduler, getAdamOptimizer, loadExistingModel
from utils import print_model_output, check_model_name, getDevice

In [2]:
#Parameters
N4 = False
resample = False
save_tmp = False
get_largest_component = False

In [3]:
#Modelname and device

def getUNETRForSegmentation():
    device = getDevice()
    model = monai.networks.nets.UNETR(
        in_channels=1,
        out_channels=4,
        img_size=(128, 128, 128),
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="perceptron",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    ).to(device)
    return model

#model = getUNETRForSegmentation()
#modelname = "labels/seg_labels_unetr_1.pth"

model = getUNetForSegmentation()
modelname = "labels/seg_labels_2.pth"

modelname = check_model_name(modelname)
set_determinism(seed=0)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
device = getDevice()


In [4]:
def max_channel_wise(labels):
    labels = labels.detach().cpu().numpy()
    labels[np.where(labels == np.amax(labels, axis=1))] = 1
    labels[labels != 1] = 0
    return labels


class GetMaxChannelWise(Transform):
    def __call__(self, labels):
        labels = max_channel_wise(labels)
        return labels

In [5]:
def getSegmentationPostProcessingForLabel():
    device = getDevice()
    postprocessing_transforms = Compose(
        [
            EnsureType(),
            GetMaxChannelWise(),
            EnsureType(data_type="tensor", device=device),
        ]
    )
    return postprocessing_transforms


def getSegmentationPostProcessingForLabelOutput():
    device = getDevice()
    postprocessing_transforms = Compose(
        [
            EnsureType(),
            Activations(other=nn.Softmax(dim=1)),
            GetMaxChannelWise(),
            EnsureType(data_type="tensor", device=device),
        ]
    )
    return postprocessing_transforms

def getSegmentationPostProcessingForAllLabelsOutput():
    device = getDevice()
    postprocessing_transforms = Compose(
        [
            EnsureType(),
            Activations(other=nn.Softmax(dim=1)),
            InverseOneHot(),
            EnsureType(data_type="tensor", device=device),
        ]
    )
    return postprocessing_transforms

In [6]:
#postprocessing
device = getDevice()
outputs_processing = getSegmentationPostProcessingForLabelOutput()
all_outputs_processing = getSegmentationPostProcessingForAllLabelsOutput()
labels_processing = getSegmentationPostProcessingForLabel()
test_proba = Compose([
    EnsureType(),
    Activations(other=nn.Softmax(dim=1)),
    EnsureType(data_type="tensor", device=device)
])

#activations
softmax = Activations(other=nn.Softmax(dim=1))

In [7]:
#Metric MUNet
def metric_munet(preds, labels):  
    #labels[np.where(labels == np.amax(labels, axis=0))] = 1
    #labels[labels != 1] = 0
    dice=2*np.sum(labels*preds,(0,1,2))/(np.sum((labels+preds),(0,1,2))+1)    
    return dice

def dice(a,b):
 intersection = np.logical_and(a, b)
 union = np.logical_or(a, b)
 dice = (2*np.sum(intersection))/(np.sum(union)+np.sum(intersection))
 return dice

In [14]:
#Model
import nibabel as nib
from transforms import GetLabelsAsOneHot, LoadNibabel, NibabelToNumpy

loadExistingModel(model, None, ft=modelname)
model.eval()

mris = sorted(glob(os.path.join("dataset", "Painfact-Segmentation", 'MRI', "*.nii.gz")))
labels = sorted(glob(os.path.join("dataset", "Painfact-Segmentation", 'Label', "*.nii.gz")))

metrics = [[], [], [], []]
for i, mri in enumerate(mris):
    print(str(i) + '/' + str(len(mris)), end='\r')
    if i>=0:
        mriname = mri
        outname = "labels_output_unet/" + mriname.split('/')[-1].split('.')[0]
        

        mri_preprocessing = getSegmentationEvalTransformsForMRI(N4=N4, outname=outname, save=save_tmp)

        with torch.no_grad():
            inputs = mri_preprocessing(mriname).to(device)
            outputs = model(inputs).cpu()  

            preds = outputs_processing(outputs)
            #all_preds = all_outputs_processing(outputs)
            #probs = test_proba(outputs)

            suffix = ["bg", "c1", "c2", "c3"]
            preds_resized = []
            for i in range(len(suffix)):
                transform = getSegmentationInverseTransformForLabels(mriname, out_name=outname, suffix=suffix[i], save=True)
                x = transform(preds[0, i, :, :, :]) 
                preds_resized.append(x)



            x = labels[i]
            x = LoadNibabel()(x)
            x = NibabelToNumpy()(x)
            x = GetLabelsAsOneHot(get=True, skip=False)(x)

            y = torch.stack([torch.from_numpy(z) for z in preds_resized]).numpy()

            for j in range(len(preds_resized)):
                metric = dice(y[j,:,:,:], x[j,:,:,:])  
                metrics[j].append(metric)     

            #for i in range(len(suffix)):
            #    transform = getSegmentationInverseTransformForLabels(mriname, out_name=outname, suffix=suffix[i]+'_prob', save=True)
            #    transform(probs[0, i, :, :, :])
            #transform = getSegmentationInverseTransformForLabels(mriname, out_name=outname, suffix=None, save=False)
            #transform(all_preds[0,:,:,:])

epoch_metrics = [np.mean(x) for x in metrics]
epoch_metric = np.mean(epoch_metrics)


print(
        "dice: {:.4f}".format(
            epoch_metric
        )
    )
print("dices: {:.4f}, {:.4f}, {:.4f}, {:.4f}".format(epoch_metrics[0], epoch_metrics[1], epoch_metrics[2], epoch_metrics[3]))


                

=> Saved to labels_output_unet/000_M_I23_out_bg.nii.gz
=> Saved to labels_output_unet/000_M_I23_out_c1.nii.gz
=> Saved to labels_output_unet/000_M_I23_out_c2.nii.gz
=> Saved to labels_output_unet/000_M_I23_out_c3.nii.gz
=> Saved to labels_output_unet/101_M_A1_out_bg.nii.gz
=> Saved to labels_output_unet/101_M_A1_out_c1.nii.gz
=> Saved to labels_output_unet/101_M_A1_out_c2.nii.gz
=> Saved to labels_output_unet/101_M_A1_out_c3.nii.gz
=> Saved to labels_output_unet/102_M_A2_out_bg.nii.gz
=> Saved to labels_output_unet/102_M_A2_out_c1.nii.gz
=> Saved to labels_output_unet/102_M_A2_out_c2.nii.gz
=> Saved to labels_output_unet/102_M_A2_out_c3.nii.gz
=> Saved to labels_output_unet/103_M_I21_out_bg.nii.gz
=> Saved to labels_output_unet/103_M_I21_out_c1.nii.gz
=> Saved to labels_output_unet/103_M_I21_out_c2.nii.gz
=> Saved to labels_output_unet/103_M_I21_out_c3.nii.gz
=> Saved to labels_output_unet/109_F_K2_out_bg.nii.gz
=> Saved to labels_output_unet/109_F_K2_out_c1.nii.gz
=> Saved to labels_o

=> Saved to labels_output_unet/187_M_H20_out_c3.nii.gz
=> Saved to labels_output_unet/189_M_J25_out_bg.nii.gz
=> Saved to labels_output_unet/189_M_J25_out_c1.nii.gz
=> Saved to labels_output_unet/189_M_J25_out_c2.nii.gz
=> Saved to labels_output_unet/189_M_J25_out_c3.nii.gz
=> Saved to labels_output_unet/195_F_O21_out_bg.nii.gz
=> Saved to labels_output_unet/195_F_O21_out_c1.nii.gz
=> Saved to labels_output_unet/195_F_O21_out_c2.nii.gz
=> Saved to labels_output_unet/195_F_O21_out_c3.nii.gz
=> Saved to labels_output_unet/196_F_O22_out_bg.nii.gz
=> Saved to labels_output_unet/196_F_O22_out_c1.nii.gz
=> Saved to labels_output_unet/196_F_O22_out_c2.nii.gz
=> Saved to labels_output_unet/196_F_O22_out_c3.nii.gz
=> Saved to labels_output_unet/197_M_E11_out_bg.nii.gz
=> Saved to labels_output_unet/197_M_E11_out_c1.nii.gz
=> Saved to labels_output_unet/197_M_E11_out_c2.nii.gz
=> Saved to labels_output_unet/197_M_E11_out_c3.nii.gz
=> Saved to labels_output_unet/198_M_E14_out_bg.nii.gz
=> Saved t

=> Saved to labels_output_unet/253_M_T12_out_bg.nii.gz
=> Saved to labels_output_unet/253_M_T12_out_c1.nii.gz
=> Saved to labels_output_unet/253_M_T12_out_c2.nii.gz
=> Saved to labels_output_unet/253_M_T12_out_c3.nii.gz
=> Saved to labels_output_unet/254_M_U13_out_bg.nii.gz
=> Saved to labels_output_unet/254_M_U13_out_c1.nii.gz
=> Saved to labels_output_unet/254_M_U13_out_c2.nii.gz
=> Saved to labels_output_unet/254_M_U13_out_c3.nii.gz
=> Saved to labels_output_unet/257_F_AD11_out_bg.nii.gz
=> Saved to labels_output_unet/257_F_AD11_out_c1.nii.gz
=> Saved to labels_output_unet/257_F_AD11_out_c2.nii.gz
=> Saved to labels_output_unet/257_F_AD11_out_c3.nii.gz
=> Saved to labels_output_unet/258_F_AD12_out_bg.nii.gz
=> Saved to labels_output_unet/258_F_AD12_out_c1.nii.gz
=> Saved to labels_output_unet/258_F_AD12_out_c2.nii.gz
=> Saved to labels_output_unet/258_F_AD12_out_c3.nii.gz
=> Saved to labels_output_unet/259_F_AD1_out_bg.nii.gz
=> Saved to labels_output_unet/259_F_AD1_out_c1.nii.gz
=>

=> Saved to labels_output_unet/328_F_AT16_out_bg.nii.gz
=> Saved to labels_output_unet/328_F_AT16_out_c1.nii.gz
=> Saved to labels_output_unet/328_F_AT16_out_c2.nii.gz
=> Saved to labels_output_unet/328_F_AT16_out_c3.nii.gz
=> Saved to labels_output_unet/329_F_AS11_out_bg.nii.gz
=> Saved to labels_output_unet/329_F_AS11_out_c1.nii.gz
=> Saved to labels_output_unet/329_F_AS11_out_c2.nii.gz
=> Saved to labels_output_unet/329_F_AS11_out_c3.nii.gz
=> Saved to labels_output_unet/333_M_AN18_out_bg.nii.gz
=> Saved to labels_output_unet/333_M_AN18_out_c1.nii.gz
=> Saved to labels_output_unet/333_M_AN18_out_c2.nii.gz
=> Saved to labels_output_unet/333_M_AN18_out_c3.nii.gz
=> Saved to labels_output_unet/335_F_AT17_out_bg.nii.gz
=> Saved to labels_output_unet/335_F_AT17_out_c1.nii.gz
=> Saved to labels_output_unet/335_F_AT17_out_c2.nii.gz
=> Saved to labels_output_unet/335_F_AT17_out_c3.nii.gz
=> Saved to labels_output_unet/336_F_AS12_out_bg.nii.gz
=> Saved to labels_output_unet/336_F_AS12_out_c1

=> Saved to labels_output_unet/421_F_BI18_out_c3.nii.gz
=> Saved to labels_output_unet/424_F_BF1_out_bg.nii.gz
=> Saved to labels_output_unet/424_F_BF1_out_c1.nii.gz
=> Saved to labels_output_unet/424_F_BF1_out_c2.nii.gz
=> Saved to labels_output_unet/424_F_BF1_out_c3.nii.gz
=> Saved to labels_output_unet/425_M_AV3_out_bg.nii.gz
=> Saved to labels_output_unet/425_M_AV3_out_c1.nii.gz
=> Saved to labels_output_unet/425_M_AV3_out_c2.nii.gz
=> Saved to labels_output_unet/425_M_AV3_out_c3.nii.gz
=> Saved to labels_output_unet/426_M_BA13_out_bg.nii.gz
=> Saved to labels_output_unet/426_M_BA13_out_c1.nii.gz
=> Saved to labels_output_unet/426_M_BA13_out_c2.nii.gz
=> Saved to labels_output_unet/426_M_BA13_out_c3.nii.gz
=> Saved to labels_output_unet/429_M_BB16_out_bg.nii.gz
=> Saved to labels_output_unet/429_M_BB16_out_c1.nii.gz
=> Saved to labels_output_unet/429_M_BB16_out_c2.nii.gz
=> Saved to labels_output_unet/429_M_BB16_out_c3.nii.gz
=> Saved to labels_output_unet/430_F_BI19_out_bg.nii.gz


=> Saved to labels_output_unet/513_M_BK2_out_c3.nii.gz
=> Saved to labels_output_unet/514_F_BW11_out_bg.nii.gz
=> Saved to labels_output_unet/514_F_BW11_out_c1.nii.gz
=> Saved to labels_output_unet/514_F_BW11_out_c2.nii.gz
=> Saved to labels_output_unet/514_F_BW11_out_c3.nii.gz
=> Saved to labels_output_unet/516_M_BS22_out_bg.nii.gz
=> Saved to labels_output_unet/516_M_BS22_out_c1.nii.gz
=> Saved to labels_output_unet/516_M_BS22_out_c2.nii.gz
=> Saved to labels_output_unet/516_M_BS22_out_c3.nii.gz
=> Saved to labels_output_unet/518_F_BU2_out_bg.nii.gz
=> Saved to labels_output_unet/518_F_BU2_out_c1.nii.gz
=> Saved to labels_output_unet/518_F_BU2_out_c2.nii.gz
=> Saved to labels_output_unet/518_F_BU2_out_c3.nii.gz
=> Saved to labels_output_unet/521_M_BK3_out_bg.nii.gz
=> Saved to labels_output_unet/521_M_BK3_out_c1.nii.gz
=> Saved to labels_output_unet/521_M_BK3_out_c2.nii.gz
=> Saved to labels_output_unet/521_M_BK3_out_c3.nii.gz
=> Saved to labels_output_unet/523_M_BO11_out_bg.nii.gz
=

=> Saved to labels_output_unet/584_M_BN10_out_c3.nii.gz
=> Saved to labels_output_unet/587_F_BY21_out_bg.nii.gz
=> Saved to labels_output_unet/587_F_BY21_out_c1.nii.gz
=> Saved to labels_output_unet/587_F_BY21_out_c2.nii.gz
=> Saved to labels_output_unet/587_F_BY21_out_c3.nii.gz
=> Saved to labels_output_unet/588_F_BY22_out_bg.nii.gz
=> Saved to labels_output_unet/588_F_BY22_out_c1.nii.gz
=> Saved to labels_output_unet/588_F_BY22_out_c2.nii.gz
=> Saved to labels_output_unet/588_F_BY22_out_c3.nii.gz
=> Saved to labels_output_unet/589_F_BY23_out_bg.nii.gz
=> Saved to labels_output_unet/589_F_BY23_out_c1.nii.gz
=> Saved to labels_output_unet/589_F_BY23_out_c2.nii.gz
=> Saved to labels_output_unet/589_F_BY23_out_c3.nii.gz
=> Saved to labels_output_unet/590_F_BY24_out_bg.nii.gz
=> Saved to labels_output_unet/590_F_BY24_out_c1.nii.gz
=> Saved to labels_output_unet/590_F_BY24_out_c2.nii.gz
=> Saved to labels_output_unet/590_F_BY24_out_c3.nii.gz
=> Saved to labels_output_unet/591_M_BM7_out_bg.

=> Saved to labels_output_unet/661_F_CN25_out_c2.nii.gz
=> Saved to labels_output_unet/661_F_CN25_out_c3.nii.gz
=> Saved to labels_output_unet/663_F_CM9_out_bg.nii.gz
=> Saved to labels_output_unet/663_F_CM9_out_c1.nii.gz
=> Saved to labels_output_unet/663_F_CM9_out_c2.nii.gz
=> Saved to labels_output_unet/663_F_CM9_out_c3.nii.gz
=> Saved to labels_output_unet/668_M_CE15_out_bg.nii.gz
=> Saved to labels_output_unet/668_M_CE15_out_c1.nii.gz
=> Saved to labels_output_unet/668_M_CE15_out_c2.nii.gz
=> Saved to labels_output_unet/668_M_CE15_out_c3.nii.gz
=> Saved to labels_output_unet/669_M_CG20_out_bg.nii.gz
=> Saved to labels_output_unet/669_M_CG20_out_c1.nii.gz
=> Saved to labels_output_unet/669_M_CG20_out_c2.nii.gz
=> Saved to labels_output_unet/669_M_CG20_out_c3.nii.gz
=> Saved to labels_output_unet/670_M_CI25_out_bg.nii.gz
=> Saved to labels_output_unet/670_M_CI25_out_c1.nii.gz
=> Saved to labels_output_unet/670_M_CI25_out_c2.nii.gz
=> Saved to labels_output_unet/670_M_CI25_out_c3.nii

In [16]:
#labels_output
labels = sorted(glob(os.path.join("dataset", "Painfact-Segmentation-Raw", 'Label-0.2', "*.nii.gz")))
preds_bgs = sorted(glob(os.path.join("labels_output_unet", "*_out_bg.nii.gz")))
preds_c1s = sorted(glob(os.path.join("labels_output_unet", "*_out_c1.nii.gz")))
preds_c2s = sorted(glob(os.path.join("labels_output_unet", "*_out_c2.nii.gz")))
preds_c3s = sorted(glob(os.path.join("labels_output_unet", "*_out_c3.nii.gz")))

preds = [preds_bgs, preds_c1s, preds_c2s, preds_c3s]

metrics = [[], [], [], []]

for i in range(len(preds_bgs)):
    if i>(165+55):# and i<(165+55):
        y = labels[i]
        y = nib.load(y).get_fdata()  
        y = GetLabelsAsOneHot(get=True, skip=False)(y)
        for j in range(4):
            truth = y[j,:,:,:]
            x = preds[j][i]
            x = nib.load(x).get_fdata()
            metric = dice(truth,x)
            metrics[j].append(metric)
    
epoch_metrics = [np.mean(x) for x in metrics]
epoch_metric = np.mean(epoch_metrics)


print(
        "dice: {:.4f}".format(
            epoch_metric
        )
    )
print("dices: {:.4f}, {:.4f}, {:.4f}, {:.4f}".format(epoch_metrics[0], epoch_metrics[1], epoch_metrics[2], epoch_metrics[3]))


    

dice: 0.7377
dices: 0.9924, 0.8901, 0.8134, 0.2550


In [17]:
#labels_output
labels = sorted(glob(os.path.join("dataset", "Painfact-Segmentation-Raw", 'Label-0.2', "*.nii.gz")))
preds_bgs = sorted(glob(os.path.join("labels_output", "*_out_bg.nii.gz")))
preds_c1s = sorted(glob(os.path.join("labels_output", "*_out_c1.nii.gz")))
preds_c2s = sorted(glob(os.path.join("labels_output", "*_out_c2.nii.gz")))
preds_c3s = sorted(glob(os.path.join("labels_output", "*_out_c3.nii.gz")))

preds = [preds_bgs, preds_c1s, preds_c2s, preds_c3s]

metrics = [[], [], [], []]

for i in range(len(preds_bgs)):
    if i>(165+55):# and i<(165+55):
        y = labels[i]
        y = nib.load(y).get_fdata()  
        y = GetLabelsAsOneHot(get=True, skip=False)(y)
        for j in range(4):
            truth = y[j,:,:,:]
            x = preds[j][i]
            x = nib.load(x).get_fdata()
            metric = dice(truth,x)
            metrics[j].append(metric)
    
epoch_metrics = [np.mean(x) for x in metrics]
epoch_metric = np.mean(epoch_metrics)


print(
        "dice: {:.4f}".format(
            epoch_metric
        )
    )
print("dices: {:.4f}, {:.4f}, {:.4f}, {:.4f}".format(epoch_metrics[0], epoch_metrics[1], epoch_metrics[2], epoch_metrics[3]))


    

dice: 0.7498
dices: 0.9927, 0.9034, 0.8180, 0.2851


In [20]:
#labels_output
labels = sorted(glob(os.path.join("dataset", "Painfact-Segmentation-Raw", 'Label-0.2', "*.nii.gz")))
preds_bgs = sorted(glob(os.path.join("labels_output_new", "*_out_bg.nii.gz")))
preds_c1s = sorted(glob(os.path.join("labels_output_new", "*_out_c1.nii.gz")))
preds_c2s = sorted(glob(os.path.join("labels_output_new", "*_out_c2.nii.gz")))
preds_c3s = sorted(glob(os.path.join("labels_output_new", "*_out_c3.nii.gz")))

preds = [preds_bgs, preds_c1s, preds_c2s, preds_c3s]

metrics = [[], [], [], []]

for i in range(len(preds_bgs)):
    if i>(165+55):# and i<(165+55):
        y = labels[i]
        y = nib.load(y).get_fdata()  
        y = GetLabelsAsOneHot(get=True, skip=False)(y)
        for j in range(4):
            truth = y[j,:,:,:]
            x = preds[j][i]
            x = nib.load(x).get_fdata()
            metric = dice(truth,x)
            metrics[j].append(metric)
    
epoch_metrics = [np.mean(x) for x in metrics]
epoch_metrics.append(0.9913)
epoch_metric = np.mean(epoch_metrics[1:4])

print(
        "dice: {:.4f}".format(
            epoch_metric
        )
    )
print("dices: {:.4f}, {:.4f}, {:.4f}, {:.4f}".format(epoch_metrics[4], epoch_metrics[1], epoch_metrics[2], epoch_metrics[3]))


    

dice: 0.6171
dices: 0.9913, 0.8649, 0.6857, 0.3008
