In [1]:
import argparse
import logging
import sys

import monai
import torch
from torch.utils.tensorboard import SummaryWriter
from monai.utils import set_determinism
import numpy as np
from glob import glob
import os
import nibabel as nib

from monai.transforms.transform import Transform
from monai.transforms import Activations, EnsureType, Compose, MaskIntensity, AddChannel, ToTensor, MaskIntensity
from transforms import CropMRId, BinaryMask, GetLargestComponent, GetLabelsAsOneHotd, Shaped, InverseOneHot
from torch import nn
from transforms_dict import getSegmentationEvalTransformsForMRI, getSegmentationPostProcessingForLabelOutput, getSegmentationPostProcessingForAllLabelsOutput
from transforms_dict import getSegmentationInverseTransformForLabels, getSegmentationPostProcessingForLabel

from seg_data import getDataset
from seg_model import getUNetForSegmentation
from utils import compute_mean_dice, getReducePlateauScheduler, getAdamOptimizer, loadExistingModel
from utils import print_model_output, check_model_name, getDevice


from monai.inferers import sliding_window_inference
from monai.data import decollate_batch

In [2]:
# Parameters — every optional processing step is switched off for this run.
#   N4       : forwarded to getSegmentationEvalTransformsForMRI (presumably N4 bias-field
#              correction — confirm in transforms_dict).
#   save_tmp : forwarded as `save=` to the same preprocessing helper.
#   resample, get_largest_component : not referenced by any cell in this notebook.
N4, resample, save_tmp, get_largest_component = False, False, False, False

In [3]:
#Modelname and device

def getUNETRForSegmentation(in_channels=1, out_channels=4, img_size=(96, 96, 96),
                            feature_size=16, device=None):
    """Build a MONAI UNETR for 3D segmentation and move it to a device.

    Calling it with no arguments reproduces the configuration used by this notebook.

    Args:
        in_channels: number of input image channels (the inference loop feeds
            single-channel volumes of shape [1, 1, H, W, D]).
        out_channels: number of output classes; the default 4 matches the
            ["bg", "c1", "c2", "c3"] outputs written by the inference loop.
        img_size: spatial input size the network is built for; it must equal the
            roi_size passed to sliding_window_inference.
        feature_size: UNETR feature-size hyperparameter.
        device: target device; when None, getDevice() is used.

    Returns:
        The UNETR module, already moved to `device`.
    """
    if device is None:
        device = getDevice()
    model = monai.networks.nets.UNETR(
        in_channels=in_channels,
        out_channels=out_channels,
        img_size=img_size,
        feature_size=feature_size,
        # Transformer encoder hyperparameters are kept fixed for this checkpoint.
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="perceptron",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    ).to(device)
    return model

# Configure logging and seeding before anything else runs in this cell:
# - logging.basicConfig is a no-op once the root logger has a handler, and any
#   logging.info/warning call made earlier (e.g. during setup) installs one
#   implicitly, which would silently discard this stdout/INFO configuration;
# - seeding first also covers any randomness used while the model is constructed.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
set_determinism(seed=0)

model = getUNETRForSegmentation()
modelname = "labels/seg_labels_unetr_noresize.pth"

# Alternative: evaluate the plain UNet checkpoint instead (swap both lines).
#model = getUNetForSegmentation()
#modelname = "labels/seg_labels_2.pth"

# check_model_name (from utils) may adjust the checkpoint path; its result is what gets loaded.
modelname = check_model_name(modelname)
device = getDevice()


In [4]:
#postprocessing
device = getDevice()

# Post-processing pipelines applied to the raw network output (axis=1 is passed to each helper).
outputs_processing = getSegmentationPostProcessingForLabelOutput(axis=1)
all_outputs_processing = getSegmentationPostProcessingForAllLabelsOutput(axis=1)
labels_processing = getSegmentationPostProcessingForLabel(axis=1)

#activations
# Softmax over dim=1 of a batched tensor. The wrapper is stateless, so the same
# instance is reused inside the probability pipeline below.
softmax = Activations(other=nn.Softmax(dim=1))

# Probability pipeline: ensure a tensor, apply the softmax, then place the result on `device`.
test_proba = Compose([EnsureType(), softmax, EnsureType(data_type="tensor", device=device)])

In [6]:
#Model
# Load the trained checkpoint into `model` and switch it to inference mode.
loadExistingModel(model, None, ft=modelname)
model.eval()

# Inference settings.
# ROI_SIZE must match the img_size the network was built with (96^3 for the UNETR above).
ROI_SIZE = (96, 96, 96)
SW_BATCH_SIZE = 1
OUTPUT_DIR = "labels_output_new_xd"
# One entry per output channel, in channel order; used as output file-name suffixes.
CLASS_SUFFIXES = ["bg", "c1", "c2", "c3"]
# When True, also write per-class softmax maps (costs an extra device copy per subject).
SAVE_PROBS = False

DATA_ROOT = os.path.join("dataset", "Painfact-Segmentation")
mris = sorted(glob(os.path.join(DATA_ROOT, 'MRI', "*.nii.gz")))
# Label volumes are only listed here: the MaskIntensity step that would consume them
# is disabled, so loading each full volume per subject would be wasted I/O.
masks = sorted(glob(os.path.join(DATA_ROOT, 'Label', "*.nii.gz")))

os.makedirs(OUTPUT_DIR, exist_ok=True)

for mriname in mris:
    # basename (not split('/')) so the stem is right whatever the path separator;
    # split('.')[0] drops the ".nii.gz" extension.
    outname = os.path.join(OUTPUT_DIR, os.path.basename(mriname).split('.')[0])

    mri_preprocessing = getSegmentationEvalTransformsForMRI(N4=N4, outname=outname, save=save_tmp, no_resize=True)

    with torch.no_grad():
        inputs = mri_preprocessing(mriname).to(device)
        print(inputs.shape)
        outputs = sliding_window_inference(inputs, ROI_SIZE, SW_BATCH_SIZE, model).cpu()
        print(outputs.shape)

        preds = outputs_processing(outputs)
        all_preds = all_outputs_processing(outputs)

        # One file per class channel; enumerate keeps the channel index separate
        # from any other loop variable.
        for class_idx, class_suffix in enumerate(CLASS_SUFFIXES):
            transform = getSegmentationInverseTransformForLabels(mriname, out_name=outname, suffix=class_suffix, save=True, no_resize=True)
            transform(preds[0, class_idx, :, :, :])

        # Probabilities are only computed when they will actually be written.
        if SAVE_PROBS:
            probs = test_proba(outputs)
            for class_idx, class_suffix in enumerate(CLASS_SUFFIXES):
                transform = getSegmentationInverseTransformForLabels(mriname, out_name=outname, suffix=class_suffix + '_prob', save=True, no_resize=True)
                transform(probs[0, class_idx, :, :, :])

        # suffix=None writes the "<subject>_out" file (presumably the combined label map).
        transform = getSegmentationInverseTransformForLabels(mriname, out_name=outname, suffix=None, save=True, no_resize=True)
        transform(all_preds[0, :, :, :])
                

torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/000_M_I23_out_bg.nii.gz
=> Saved to labels_output_new_xd/000_M_I23_out_c1.nii.gz
=> Saved to labels_output_new_xd/000_M_I23_out_c2.nii.gz
=> Saved to labels_output_new_xd/000_M_I23_out_c3.nii.gz
=> Saved to labels_output_new_xd/000_M_I23_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/101_M_A1_out_bg.nii.gz
=> Saved to labels_output_new_xd/101_M_A1_out_c1.nii.gz
=> Saved to labels_output_new_xd/101_M_A1_out_c2.nii.gz
=> Saved to labels_output_new_xd/101_M_A1_out_c3.nii.gz
=> Saved to labels_output_new_xd/101_M_A1_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/102_M_A2_out_bg.nii.gz
=> Saved to labels_output_new_xd/102_M_A2_out_c1.nii.gz
=> Saved to labels_output_new_xd/102_M_A2_out_c2.nii.gz
=> Saved to labels_output_new_xd/102_M_A2_out_c3.nii.gz
=> Saved to l

=> Saved to labels_output_new_xd/150_M_F15_out_c3.nii.gz
=> Saved to labels_output_new_xd/150_M_F15_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/151_F_M11_out_bg.nii.gz
=> Saved to labels_output_new_xd/151_F_M11_out_c1.nii.gz
=> Saved to labels_output_new_xd/151_F_M11_out_c2.nii.gz
=> Saved to labels_output_new_xd/151_F_M11_out_c3.nii.gz
=> Saved to labels_output_new_xd/151_F_M11_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/152_F_M15_out_bg.nii.gz
=> Saved to labels_output_new_xd/152_F_M15_out_c1.nii.gz
=> Saved to labels_output_new_xd/152_F_M15_out_c2.nii.gz
=> Saved to labels_output_new_xd/152_F_M15_out_c3.nii.gz
=> Saved to labels_output_new_xd/152_F_M15_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/159_M_G17_out_bg.nii.gz
=> Saved to labels_output_new_xd/159_M_G17_out_c1.nii.gz
=> Save

torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/207_M_P1_out_bg.nii.gz
=> Saved to labels_output_new_xd/207_M_P1_out_c1.nii.gz
=> Saved to labels_output_new_xd/207_M_P1_out_c2.nii.gz
=> Saved to labels_output_new_xd/207_M_P1_out_c3.nii.gz
=> Saved to labels_output_new_xd/207_M_P1_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/208_M_Q4_out_bg.nii.gz
=> Saved to labels_output_new_xd/208_M_Q4_out_c1.nii.gz
=> Saved to labels_output_new_xd/208_M_Q4_out_c2.nii.gz
=> Saved to labels_output_new_xd/208_M_Q4_out_c3.nii.gz
=> Saved to labels_output_new_xd/208_M_Q4_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/209_M_P2_out_bg.nii.gz
=> Saved to labels_output_new_xd/209_M_P2_out_c1.nii.gz
=> Saved to labels_output_new_xd/209_M_P2_out_c2.nii.gz
=> Saved to labels_output_new_xd/209_M_P2_out_c3.nii.gz
=> Saved to labels_output_new_xd/209_M_P2_out.nii.gz

=> Saved to labels_output_new_xd/246_M_S10_out_c2.nii.gz
=> Saved to labels_output_new_xd/246_M_S10_out_c3.nii.gz
=> Saved to labels_output_new_xd/246_M_S10_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/247_M_Z22_out_bg.nii.gz
=> Saved to labels_output_new_xd/247_M_Z22_out_c1.nii.gz
=> Saved to labels_output_new_xd/247_M_Z22_out_c2.nii.gz
=> Saved to labels_output_new_xd/247_M_Z22_out_c3.nii.gz
=> Saved to labels_output_new_xd/247_M_Z22_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/250_F_AC9_out_bg.nii.gz
=> Saved to labels_output_new_xd/250_F_AC9_out_c1.nii.gz
=> Saved to labels_output_new_xd/250_F_AC9_out_c2.nii.gz
=> Saved to labels_output_new_xd/250_F_AC9_out_c3.nii.gz
=> Saved to labels_output_new_xd/250_F_AC9_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/251_F_AC10_out_bg.nii.gz
=> Sav

=> Saved to labels_output_new_xd/284_M_AK11_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/285_M_AK12_out_bg.nii.gz
=> Saved to labels_output_new_xd/285_M_AK12_out_c1.nii.gz
=> Saved to labels_output_new_xd/285_M_AK12_out_c2.nii.gz
=> Saved to labels_output_new_xd/285_M_AK12_out_c3.nii.gz
=> Saved to labels_output_new_xd/285_M_AK12_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/287_F_AR6_out_bg.nii.gz
=> Saved to labels_output_new_xd/287_F_AR6_out_c1.nii.gz
=> Saved to labels_output_new_xd/287_F_AR6_out_c2.nii.gz
=> Saved to labels_output_new_xd/287_F_AR6_out_c3.nii.gz
=> Saved to labels_output_new_xd/287_F_AR6_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/289_F_AQ_out_bg.nii.gz
=> Saved to labels_output_new_xd/289_F_AQ_out_c1.nii.gz
=> Saved to labels_output_new_xd/289_F_AQ_out_c2.nii.gz
=> S

torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/340_M_AN19_out_bg.nii.gz
=> Saved to labels_output_new_xd/340_M_AN19_out_c1.nii.gz
=> Saved to labels_output_new_xd/340_M_AN19_out_c2.nii.gz
=> Saved to labels_output_new_xd/340_M_AN19_out_c3.nii.gz
=> Saved to labels_output_new_xd/340_M_AN19_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/342_M_AO22_out_bg.nii.gz
=> Saved to labels_output_new_xd/342_M_AO22_out_c1.nii.gz
=> Saved to labels_output_new_xd/342_M_AO22_out_c2.nii.gz
=> Saved to labels_output_new_xd/342_M_AO22_out_c3.nii.gz
=> Saved to labels_output_new_xd/342_M_AO22_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/344_M_AG2_out_bg.nii.gz
=> Saved to labels_output_new_xd/344_M_AG2_out_c1.nii.gz
=> Saved to labels_output_new_xd/344_M_AG2_out_c2.nii.gz
=> Saved to labels_output_new_xd/344_M_AG2_out_c3.nii.gz
=> Saved to labels_output_ne

=> Saved to labels_output_new_xd/409_M_AZ12_out_c1.nii.gz
=> Saved to labels_output_new_xd/409_M_AZ12_out_c2.nii.gz
=> Saved to labels_output_new_xd/409_M_AZ12_out_c3.nii.gz
=> Saved to labels_output_new_xd/409_M_AZ12_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/410_M_AX7_out_bg.nii.gz
=> Saved to labels_output_new_xd/410_M_AX7_out_c1.nii.gz
=> Saved to labels_output_new_xd/410_M_AX7_out_c2.nii.gz
=> Saved to labels_output_new_xd/410_M_AX7_out_c3.nii.gz
=> Saved to labels_output_new_xd/410_M_AX7_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/412_F_BI17_out_bg.nii.gz
=> Saved to labels_output_new_xd/412_F_BI17_out_c1.nii.gz
=> Saved to labels_output_new_xd/412_F_BI17_out_c2.nii.gz
=> Saved to labels_output_new_xd/412_F_BI17_out_c3.nii.gz
=> Saved to labels_output_new_xd/412_F_BI17_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227]

=> Saved to labels_output_new_xd/456_F_BF3_out_c3.nii.gz
=> Saved to labels_output_new_xd/456_F_BF3_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/458_M_BE24_out_bg.nii.gz
=> Saved to labels_output_new_xd/458_M_BE24_out_c1.nii.gz
=> Saved to labels_output_new_xd/458_M_BE24_out_c2.nii.gz
=> Saved to labels_output_new_xd/458_M_BE24_out_c3.nii.gz
=> Saved to labels_output_new_xd/458_M_BE24_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/461_M_BB18_out_bg.nii.gz
=> Saved to labels_output_new_xd/461_M_BB18_out_c1.nii.gz
=> Saved to labels_output_new_xd/461_M_BB18_out_c2.nii.gz
=> Saved to labels_output_new_xd/461_M_BB18_out_c3.nii.gz
=> Saved to labels_output_new_xd/461_M_BB18_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/463_f_BG8_out_bg.nii.gz
=> Saved to labels_output_new_xd/463_f_BG8_out_c1.nii.

torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/516_M_BS22_out_bg.nii.gz
=> Saved to labels_output_new_xd/516_M_BS22_out_c1.nii.gz
=> Saved to labels_output_new_xd/516_M_BS22_out_c2.nii.gz
=> Saved to labels_output_new_xd/516_M_BS22_out_c3.nii.gz
=> Saved to labels_output_new_xd/516_M_BS22_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/518_F_BU2_out_bg.nii.gz
=> Saved to labels_output_new_xd/518_F_BU2_out_c1.nii.gz
=> Saved to labels_output_new_xd/518_F_BU2_out_c2.nii.gz
=> Saved to labels_output_new_xd/518_F_BU2_out_c3.nii.gz
=> Saved to labels_output_new_xd/518_F_BU2_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/521_M_BK3_out_bg.nii.gz
=> Saved to labels_output_new_xd/521_M_BK3_out_c1.nii.gz
=> Saved to labels_output_new_xd/521_M_BK3_out_c2.nii.gz
=> Saved to labels_output_new_xd/521_M_BK3_out_c3.nii.gz

=> Saved to labels_output_new_xd/559_F_BX18_out_c1.nii.gz
=> Saved to labels_output_new_xd/559_F_BX18_out_c2.nii.gz
=> Saved to labels_output_new_xd/559_F_BX18_out_c3.nii.gz
=> Saved to labels_output_new_xd/559_F_BX18_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/563_M_BP15_out_bg.nii.gz
=> Saved to labels_output_new_xd/563_M_BP15_out_c1.nii.gz
=> Saved to labels_output_new_xd/563_M_BP15_out_c2.nii.gz
=> Saved to labels_output_new_xd/563_M_BP15_out_c3.nii.gz
=> Saved to labels_output_new_xd/563_M_BP15_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216, 227])
=> Saved to labels_output_new_xd/564_M_BR19_out_bg.nii.gz
=> Saved to labels_output_new_xd/564_M_BR19_out_c1.nii.gz
=> Saved to labels_output_new_xd/564_M_BR19_out_c2.nii.gz
=> Saved to labels_output_new_xd/564_M_BR19_out_c3.nii.gz
=> Saved to labels_output_new_xd/564_M_BR19_out.nii.gz
torch.Size([1, 1, 305, 216, 227])
torch.Size([1, 4, 305, 216,


KeyboardInterrupt

