In [1]:
import argparse
import logging
import sys

import monai
import torch
from torch.utils.tensorboard import SummaryWriter
from monai.utils import set_determinism
import numpy as np
from glob import glob
import os

from monai.transforms.transform import Transform
from monai.transforms import Activations, EnsureType, Compose
from transforms import CropMRId, BinaryMask, GetLargestComponent, GetLabelsAsOneHotd, Shaped, InverseOneHot
from torch import nn
from transforms_dict import getSegmentationEvalTransformsForMRI
from transforms_dict import getSegmentationInverseTransformForLabels

from seg_data import getDataset
from seg_model import getUNetForSegmentation
from utils import compute_mean_dice, getReducePlateauScheduler, getAdamOptimizer, loadExistingModel
from utils import print_model_output, check_model_name, getDevice

In [2]:
# Parameters -- inference-time switches consumed by later cells.
N4 = False        # forwarded as `N4=` to getSegmentationEvalTransformsForMRI
save_tmp = False  # forwarded as `save=` to getSegmentationEvalTransformsForMRI
# NOTE(review): the two flags below are not referenced in any cell shown here.
resample = get_largest_component = False

In [3]:
#Modelname and device

def getUNETRForSegmentation(in_channels=1, out_channels=4, img_size=(96, 96, 96), device=None):
    """Build a MONAI UNETR segmentation network and move it to a device.

    All arguments default to the values previously hard-coded here, so calling
    ``getUNETRForSegmentation()`` behaves exactly as before.

    Args:
        in_channels: number of input image channels (default 1).
        out_channels: number of output channels (default 4). The inference loop
            in this notebook writes one file per channel for four suffixes, so
            the two must agree.
        img_size: spatial size the network is constructed for (default
            ``(96, 96, 96)``). NOTE(review): presumably must match the size
            produced by the preprocessing transforms -- confirm.
        device: target device; when ``None`` the project's ``getDevice()`` is used.

    Returns:
        The ``monai.networks.nets.UNETR`` instance, moved to ``device``.
    """
    if device is None:
        device = getDevice()
    model = monai.networks.nets.UNETR(
        in_channels=in_channels,
        out_channels=out_channels,
        img_size=img_size,
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="perceptron",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    ).to(device)
    return model

# Seed MONAI/torch RNGs *before* building the network so the whole cell,
# including weight initialisation, is reproducible on Restart & Run All.
set_determinism(seed=0)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
device = getDevice()

model = getUNETRForSegmentation()
modelname = "labels/seg_labels_unetr_1.pth"

# Alternative model/checkpoint pair (plain UNet):
#model = getUNetForSegmentation()
#modelname = "labels/seg_labels_2.pth"

# check_model_name is project-local; presumably validates/resolves the
# checkpoint path -- TODO confirm.
modelname = check_model_name(modelname)


In [4]:
def max_channel_wise(labels):
    """Binarise ``labels`` so only the maximum along axis 1 is kept.

    Every element equal to the maximum over axis 1 becomes 1 and every other
    element becomes 0; tied maxima all stay 1. Axis 1 is assumed to be the
    channel axis, e.g. ``(B, C, H, W, D)`` -- TODO confirm.

    Args:
        labels: a ``torch.Tensor`` (detached and copied to CPU) or anything
            ``np.asarray`` accepts.

    Returns:
        A new NumPy array with the same shape and dtype as ``labels``. The
        input is never modified. (Previously the in-place writes went through
        ``.numpy()``'s shared memory and altered the caller's CPU tensor.)
    """
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    else:
        labels = np.asarray(labels)
    # keepdims=True keeps the reduced axis, so the comparison broadcasts
    # correctly for any batch size rather than only when B == 1.
    channel_max = np.amax(labels, axis=1, keepdims=True)
    return (labels == channel_max).astype(labels.dtype)


class GetMaxChannelWise(Transform):
    """MONAI ``Transform`` adapter around ``max_channel_wise``.

    Calling an instance forwards ``labels`` to ``max_channel_wise`` and
    returns its result unchanged, so the function can sit inside a
    ``Compose`` pipeline.
    """

    def __call__(self, labels):
        return max_channel_wise(labels)

In [5]:
def _composePostProcessing(*middle_steps):
    """Wrap ``middle_steps`` between the shared type-normalisation transforms.

    The returned ``Compose`` runs ``EnsureType()`` first, then each transform
    in ``middle_steps`` in order, and finally
    ``EnsureType(data_type="tensor", device=getDevice())``.
    """
    device = getDevice()
    return Compose(
        [
            EnsureType(),
            *middle_steps,
            EnsureType(data_type="tensor", device=device),
        ]
    )


def getSegmentationPostProcessingForLabel():
    """Pipeline for label tensors: keep only the per-channel maximum (no activation)."""
    return _composePostProcessing(GetMaxChannelWise())


def getSegmentationPostProcessingForLabelOutput():
    """Pipeline for raw model output: softmax over dim 1, then keep the per-channel maximum."""
    return _composePostProcessing(
        Activations(other=nn.Softmax(dim=1)),
        GetMaxChannelWise(),
    )


def getSegmentationPostProcessingForAllLabelsOutput():
    """Pipeline for raw model output: softmax over dim 1, then the project's ``InverseOneHot``.

    NOTE(review): ``InverseOneHot`` is project-local; presumably it collapses
    the channels into a single label map -- confirm.
    """
    return _composePostProcessing(
        Activations(other=nn.Softmax(dim=1)),
        InverseOneHot(),
    )

In [6]:
# Post-processing pipelines and activations used by the inference loop.
device = getDevice()

# Channel-wise softmax over dim 1; also reused inside `test_proba` below.
softmax = Activations(other=nn.Softmax(dim=1))

outputs_processing = getSegmentationPostProcessingForLabelOutput()
all_outputs_processing = getSegmentationPostProcessingForAllLabelsOutput()
labels_processing = getSegmentationPostProcessingForLabel()

# Softmax probabilities only (no binarisation), returned as a tensor on `device`.
test_proba = Compose([EnsureType(), softmax, EnsureType(data_type="tensor", device=device)])

In [7]:
#Model
# Load the trained weights into `model` and switch it to inference mode.
loadExistingModel(model, None, ft=modelname)
model.eval()

# One output file is written per model output channel, in this order.
# Hoisted out of the loop because it never changes.
suffix = ["bg", "c1", "c2", "c3"]
output_dir = "labels_output"
# NOTE(review): it is not visible here whether the save transform creates
# missing directories, so ensure the target exists before the first write.
os.makedirs(output_dir, exist_ok=True)

mris = sorted(glob(os.path.join("dataset", "Painfact-Segmentation", 'MRI', "*.nii.gz")))
for mri in mris:
    mriname = mri
    # Derive the output stem portably. os.path.basename uses the platform's
    # separator, and only the ".nii.gz" extension is stripped, so names that
    # contain other dots are not truncated (split('.')[0] would truncate them).
    stem = os.path.basename(mriname)
    if stem.endswith(".nii.gz"):
        stem = stem[: -len(".nii.gz")]
    outname = os.path.join(output_dir, stem)

    mri_preprocessing = getSegmentationEvalTransformsForMRI(N4=N4, outname=outname, save=save_tmp)

    with torch.no_grad():
        inputs = mri_preprocessing(mriname).to(device)
        outputs = model(inputs).cpu()

        preds = outputs_processing(outputs)          # softmax + per-channel max, binarised
        all_preds = all_outputs_processing(outputs)  # softmax + InverseOneHot
        probs = test_proba(outputs)                  # softmax probabilities (used only by the commented-out export)

        # Fail loudly instead of silently writing a partial or mislabelled set
        # when the model's channel count and the suffix list disagree.
        if preds.shape[1] != len(suffix):
            raise ValueError(
                f"model produced {preds.shape[1]} channels but {len(suffix)} suffixes are configured"
            )

        for i, channel_suffix in enumerate(suffix):
            transform = getSegmentationInverseTransformForLabels(mriname, out_name=outname, suffix=channel_suffix, save=True)
            transform(preds[0, i, :, :, :])
        #for i in range(len(suffix)):
        #    transform = getSegmentationInverseTransformForLabels(mriname, out_name=outname, suffix=suffix[i]+'_prob', save=True)
        #    transform(probs[0, i, :, :, :])
        transform = getSegmentationInverseTransformForLabels(mriname, out_name=outname, suffix=None, save=True)
        transform(all_preds[0, :, :, :])
    
                

=> Saved to labels_output/000_M_I23_out_bg.nii.gz
=> Saved to labels_output/000_M_I23_out_c1.nii.gz
=> Saved to labels_output/000_M_I23_out_c2.nii.gz
=> Saved to labels_output/000_M_I23_out_c3.nii.gz
=> Saved to labels_output/000_M_I23_out.nii.gz
=> Saved to labels_output/101_M_A1_out_bg.nii.gz
=> Saved to labels_output/101_M_A1_out_c1.nii.gz
=> Saved to labels_output/101_M_A1_out_c2.nii.gz
=> Saved to labels_output/101_M_A1_out_c3.nii.gz
=> Saved to labels_output/101_M_A1_out.nii.gz
=> Saved to labels_output/102_M_A2_out_bg.nii.gz
=> Saved to labels_output/102_M_A2_out_c1.nii.gz
=> Saved to labels_output/102_M_A2_out_c2.nii.gz
=> Saved to labels_output/102_M_A2_out_c3.nii.gz
=> Saved to labels_output/102_M_A2_out.nii.gz
=> Saved to labels_output/103_M_I21_out_bg.nii.gz
=> Saved to labels_output/103_M_I21_out_c1.nii.gz
=> Saved to labels_output/103_M_I21_out_c2.nii.gz
=> Saved to labels_output/103_M_I21_out_c3.nii.gz
=> Saved to labels_output/103_M_I21_out.nii.gz
=> Saved to labels_out

=> Saved to labels_output/177_F_N19_out_c3.nii.gz
=> Saved to labels_output/177_F_N19_out.nii.gz
=> Saved to labels_output/178_F_N17_out_bg.nii.gz
=> Saved to labels_output/178_F_N17_out_c1.nii.gz
=> Saved to labels_output/178_F_N17_out_c2.nii.gz
=> Saved to labels_output/178_F_N17_out_c3.nii.gz
=> Saved to labels_output/178_F_N17_out.nii.gz
=> Saved to labels_output/185_F_N18_out_bg.nii.gz
=> Saved to labels_output/185_F_N18_out_c1.nii.gz
=> Saved to labels_output/185_F_N18_out_c2.nii.gz
=> Saved to labels_output/185_F_N18_out_c3.nii.gz
=> Saved to labels_output/185_F_N18_out.nii.gz
=> Saved to labels_output/186_F_N16_out_bg.nii.gz
=> Saved to labels_output/186_F_N16_out_c1.nii.gz
=> Saved to labels_output/186_F_N16_out_c2.nii.gz
=> Saved to labels_output/186_F_N16_out_c3.nii.gz
=> Saved to labels_output/186_F_N16_out.nii.gz
=> Saved to labels_output/187_M_H20_out_bg.nii.gz
=> Saved to labels_output/187_M_H20_out_c1.nii.gz
=> Saved to labels_output/187_M_H20_out_c2.nii.gz
=> Saved to 

=> Saved to labels_output/241_F_AC6_out.nii.gz
=> Saved to labels_output/242_F_AC7_out_bg.nii.gz
=> Saved to labels_output/242_F_AC7_out_c1.nii.gz
=> Saved to labels_output/242_F_AC7_out_c2.nii.gz
=> Saved to labels_output/242_F_AC7_out_c3.nii.gz
=> Saved to labels_output/242_F_AC7_out.nii.gz
=> Saved to labels_output/243_F_AC8_out_bg.nii.gz
=> Saved to labels_output/243_F_AC8_out_c1.nii.gz
=> Saved to labels_output/243_F_AC8_out_c2.nii.gz
=> Saved to labels_output/243_F_AC8_out_c3.nii.gz
=> Saved to labels_output/243_F_AC8_out.nii.gz
=> Saved to labels_output/244_M_R8_out_bg.nii.gz
=> Saved to labels_output/244_M_R8_out_c1.nii.gz
=> Saved to labels_output/244_M_R8_out_c2.nii.gz
=> Saved to labels_output/244_M_R8_out_c3.nii.gz
=> Saved to labels_output/244_M_R8_out.nii.gz
=> Saved to labels_output/246_M_S10_out_bg.nii.gz
=> Saved to labels_output/246_M_S10_out_c1.nii.gz
=> Saved to labels_output/246_M_S10_out_c2.nii.gz
=> Saved to labels_output/246_M_S10_out_c3.nii.gz
=> Saved to label

=> Saved to labels_output/297_M_AK13_out.nii.gz
=> Saved to labels_output/301_F_AR7_out_bg.nii.gz
=> Saved to labels_output/301_F_AR7_out_c1.nii.gz
=> Saved to labels_output/301_F_AR7_out_c2.nii.gz
=> Saved to labels_output/301_F_AR7_out_c3.nii.gz
=> Saved to labels_output/301_F_AR7_out.nii.gz
=> Saved to labels_output/304_F_AQ2_out_bg.nii.gz
=> Saved to labels_output/304_F_AQ2_out_c1.nii.gz
=> Saved to labels_output/304_F_AQ2_out_c2.nii.gz
=> Saved to labels_output/304_F_AQ2_out_c3.nii.gz
=> Saved to labels_output/304_F_AQ2_out.nii.gz
=> Saved to labels_output/306_M_AJ9_out_bg.nii.gz
=> Saved to labels_output/306_M_AJ9_out_c1.nii.gz
=> Saved to labels_output/306_M_AJ9_out_c2.nii.gz
=> Saved to labels_output/306_M_AJ9_out_c3.nii.gz
=> Saved to labels_output/306_M_AJ9_out.nii.gz
=> Saved to labels_output/307_M_AL14_out_bg.nii.gz
=> Saved to labels_output/307_M_AL14_out_c1.nii.gz
=> Saved to labels_output/307_M_AL14_out_c2.nii.gz
=> Saved to labels_output/307_M_AL14_out_c3.nii.gz
=> Save

=> Saved to labels_output/380_F_AU24_out_c3.nii.gz
=> Saved to labels_output/380_F_AU24_out.nii.gz
=> Saved to labels_output/387_M_AG3_out_bg.nii.gz
=> Saved to labels_output/387_M_AG3_out_c1.nii.gz
=> Saved to labels_output/387_M_AG3_out_c2.nii.gz
=> Saved to labels_output/387_M_AG3_out_c3.nii.gz
=> Saved to labels_output/387_M_AG3_out.nii.gz
=> Saved to labels_output/392_F_AU25_out_bg.nii.gz
=> Saved to labels_output/392_F_AU25_out_c1.nii.gz
=> Saved to labels_output/392_F_AU25_out_c2.nii.gz
=> Saved to labels_output/392_F_AU25_out_c3.nii.gz
=> Saved to labels_output/392_F_AU25_out.nii.gz
=> Saved to labels_output/400_M_AX6_out_bg.nii.gz
=> Saved to labels_output/400_M_AX6_out_c1.nii.gz
=> Saved to labels_output/400_M_AX6_out_c2.nii.gz
=> Saved to labels_output/400_M_AX6_out_c3.nii.gz
=> Saved to labels_output/400_M_AX6_out.nii.gz
=> Saved to labels_output/402_F_BI16_out_bg.nii.gz
=> Saved to labels_output/402_F_BI16_out_c1.nii.gz
=> Saved to labels_output/402_F_BI16_out_c2.nii.gz
=>

=> Saved to labels_output/461_M_BB18_out_c2.nii.gz
=> Saved to labels_output/461_M_BB18_out_c3.nii.gz
=> Saved to labels_output/461_M_BB18_out.nii.gz
=> Saved to labels_output/463_f_BG8_out_bg.nii.gz
=> Saved to labels_output/463_f_BG8_out_c1.nii.gz
=> Saved to labels_output/463_f_BG8_out_c2.nii.gz
=> Saved to labels_output/463_f_BG8_out_c3.nii.gz
=> Saved to labels_output/463_f_BG8_out.nii.gz
=> Saved to labels_output/466_F_BH16_out_bg.nii.gz
=> Saved to labels_output/466_F_BH16_out_c1.nii.gz
=> Saved to labels_output/466_F_BH16_out_c2.nii.gz
=> Saved to labels_output/466_F_BH16_out_c3.nii.gz
=> Saved to labels_output/466_F_BH16_out.nii.gz
=> Saved to labels_output/468_M_BE25_out_bg.nii.gz
=> Saved to labels_output/468_M_BE25_out_c1.nii.gz
=> Saved to labels_output/468_M_BE25_out_c2.nii.gz
=> Saved to labels_output/468_M_BE25_out_c3.nii.gz
=> Saved to labels_output/468_M_BE25_out.nii.gz
=> Saved to labels_output/470_F_BG9_out_bg.nii.gz
=> Saved to labels_output/470_F_BG9_out_c1.nii.gz

=> Saved to labels_output/539_F_BX17_out_c1.nii.gz
=> Saved to labels_output/539_F_BX17_out_c2.nii.gz
=> Saved to labels_output/539_F_BX17_out_c3.nii.gz
=> Saved to labels_output/539_F_BX17_out.nii.gz
=> Saved to labels_output/540_M_BL5_out_bg.nii.gz
=> Saved to labels_output/540_M_BL5_out_c1.nii.gz
=> Saved to labels_output/540_M_BL5_out_c2.nii.gz
=> Saved to labels_output/540_M_BL5_out_c3.nii.gz
=> Saved to labels_output/540_M_BL5_out.nii.gz
=> Saved to labels_output/542_M_BO13_out_bg.nii.gz
=> Saved to labels_output/542_M_BO13_out_c1.nii.gz
=> Saved to labels_output/542_M_BO13_out_c2.nii.gz
=> Saved to labels_output/542_M_BO13_out_c3.nii.gz
=> Saved to labels_output/542_M_BO13_out.nii.gz
=> Saved to labels_output/548_F_BW14_out_bg.nii.gz
=> Saved to labels_output/548_F_BW14_out_c1.nii.gz
=> Saved to labels_output/548_F_BW14_out_c2.nii.gz
=> Saved to labels_output/548_F_BW14_out_c3.nii.gz
=> Saved to labels_output/548_F_BW14_out.nii.gz
=> Saved to labels_output/549_F_BV7_out_bg.nii.g

=> Saved to labels_output/595_M_CF16_out_bg.nii.gz
=> Saved to labels_output/595_M_CF16_out_c1.nii.gz
=> Saved to labels_output/595_M_CF16_out_c2.nii.gz
=> Saved to labels_output/595_M_CF16_out_c3.nii.gz
=> Saved to labels_output/595_M_CF16_out.nii.gz
=> Saved to labels_output/597_F_CK11_out_bg.nii.gz
=> Saved to labels_output/597_F_CK11_out_c1.nii.gz
=> Saved to labels_output/597_F_CK11_out_c2.nii.gz
=> Saved to labels_output/597_F_CK11_out_c3.nii.gz
=> Saved to labels_output/597_F_CK11_out.nii.gz
=> Saved to labels_output/600_F_CM6_out_bg.nii.gz
=> Saved to labels_output/600_F_CM6_out_c1.nii.gz
=> Saved to labels_output/600_F_CM6_out_c2.nii.gz
=> Saved to labels_output/600_F_CM6_out_c3.nii.gz
=> Saved to labels_output/600_F_CM6_out.nii.gz
=> Saved to labels_output/602_M_CB7_out_bg.nii.gz
=> Saved to labels_output/602_M_CB7_out_c1.nii.gz
=> Saved to labels_output/602_M_CB7_out_c2.nii.gz
=> Saved to labels_output/602_M_CB7_out_c3.nii.gz
=> Saved to labels_output/602_M_CB7_out.nii.gz
=>

=> Saved to labels_output/670_M_CI25_out.nii.gz
=> Saved to labels_output/673_F_CK15_out_bg.nii.gz
=> Saved to labels_output/673_F_CK15_out_c1.nii.gz
=> Saved to labels_output/673_F_CK15_out_c2.nii.gz
=> Saved to labels_output/673_F_CK15_out_c3.nii.gz
=> Saved to labels_output/673_F_CK15_out.nii.gz
=> Saved to labels_output/675_F_CM10_out_bg.nii.gz
=> Saved to labels_output/675_F_CM10_out_c1.nii.gz
=> Saved to labels_output/675_F_CM10_out_c2.nii.gz
=> Saved to labels_output/675_F_CM10_out_c3.nii.gz
=> Saved to labels_output/675_F_CM10_out.nii.gz
=> Saved to labels_output/678_M_BZ1_out_bg.nii.gz
=> Saved to labels_output/678_M_BZ1_out_c1.nii.gz
=> Saved to labels_output/678_M_BZ1_out_c2.nii.gz
=> Saved to labels_output/678_M_BZ1_out_c3.nii.gz
=> Saved to labels_output/678_M_BZ1_out.nii.gz
=> Saved to labels_output/679_M_BZ2_out_bg.nii.gz
=> Saved to labels_output/679_M_BZ2_out_c1.nii.gz
=> Saved to labels_output/679_M_BZ2_out_c2.nii.gz
=> Saved to labels_output/679_M_BZ2_out_c3.nii.gz
=