In [None]:
import argparse
import logging
import sys

import monai
import torch
from torch.utils.tensorboard import SummaryWriter
from monai.utils import set_determinism
import numpy as np

from seg_data import getDataset
from seg_model import getUNetForSegmentation, getUNETRForSegmentation
from transforms_dict import getSegmentationPostProcessingForLabel, getSegmentationPostProcessingForLabelOutput, getSegmentationPostProcessingForAllLabelsOutput
from utils import compute_mean_dice, getReducePlateauScheduler, getAdamOptimizer, loadExistingModel
from utils import print_model_output, check_model_name, getDevice

In [None]:
# Parameters
# Input volume to segment.
mriname = "000_M_I23.nii.gz"
# Output base name: everything in `mriname` before its first '.'.
outname = mriname.partition(".")[0]
# Boolean switches, all disabled. `N4` and `save_tmp` are forwarded to the MRI
# preprocessing transform in the inference cell; `resample` and
# `get_largest_component` are not referenced in the cells shown here.
N4 = resample = save_tmp = get_largest_component = False

In [None]:
# Model, checkpoint name, seeding, logging and device.

# Active architecture: UNETR. To switch to the plain U-Net, use instead:
#   model = getUNetForSegmentation()
#   modelname = check_model_name("labels/seg_labels_2.pth")
model = getUNETRForSegmentation()
modelname = check_model_name("labels/seg_labels_unetr_1.pth")

# Fixed seed for MONAI's determinism helper.
set_determinism(seed=0)
# Send INFO-level (and above) log records to stdout.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
device = getDevice()


In [None]:
from monai.transforms import Activations, EnsureType, Compose
from torch import nn
from transforms_dict import getSegmentationEvalTransformsForMRI, getSegmentationPostProcessingForMaskOutput
from transforms_dict import getSegmentationInverseTransformForLabels, SaveTransformForMRI

In [None]:
# Post-processing helpers; the inference cell applies the first two and
# `test_proba` to the raw model outputs.
device = getDevice()
outputs_processing = getSegmentationPostProcessingForLabelOutput()
all_outputs_processing = getSegmentationPostProcessingForAllLabelsOutput()
labels_processing = getSegmentationPostProcessingForLabel()

# Probability pipeline: softmax along dim 1, result returned as a tensor
# placed on `device`.
test_proba = Compose(
    [
        EnsureType(),
        Activations(other=nn.Softmax(dim=1)),
        EnsureType(data_type="tensor", device=device),
    ]
)

# Stand-alone softmax activation over dim 1.
softmax = Activations(other=nn.Softmax(dim=1))

In [None]:
# Inference: run the model on one MRI and save the outputs.
loadExistingModel(model, None, ft=modelname)
# The input below is moved to `device`, so the network has to be on the same
# device; otherwise the forward pass raises a device-mismatch error on GPU
# machines. `Module.to` is a no-op when the model is already there.
model.to(device)
model.eval()

mri_preprocessing = getSegmentationEvalTransformsForMRI(N4=N4, outname=outname, save=save_tmp)

with torch.no_grad():
    inputs = mri_preprocessing(mriname).to(device)
    # Bring the raw outputs back to the CPU before post-processing.
    outputs = model(inputs).cpu()

    preds = outputs_processing(outputs)
    all_preds = all_outputs_processing(outputs)
    probs = test_proba(outputs)

    # suffix[i] names the file written for index i along the second axis of
    # `preds` / `probs` (presumably the class channel; confirm against the model).
    suffix = ["background", "grey", "white", "csf"]

    # One output per entry of `suffix`, taken from the post-processed predictions.
    for i, label_name in enumerate(suffix):
        transform = getSegmentationInverseTransformForLabels(mriname, out_name=outname, suffix=label_name, save=True)
        transform(preds[0, i, :, :, :])

    # Same again for the softmax probabilities, saved with a "_prob" suffix.
    for i, label_name in enumerate(suffix):
        transform = getSegmentationInverseTransformForLabels(mriname, out_name=outname, suffix=label_name + '_prob', save=True)
        transform(probs[0, i, :, :, :])

    # Finally the combined output from `all_outputs_processing`, saved without a suffix.
    transform = getSegmentationInverseTransformForLabels(mriname, out_name=outname, suffix=None, save=True)
    transform(all_preds[0, :, :, :])
    
                