In [2]:
import argparse
import logging
import sys

import monai
import torch
from torch.utils.tensorboard import SummaryWriter
from monai.utils import set_determinism
import numpy as np

from seg_data import getSegmentationDataset
from seg_model import getUNetForSegmentation, getUNETRForSegmentation
from transforms_dict import getSegmentationPostProcessingForLabel, getSegmentationPostProcessingForLabelOutput
from utils import compute_mean_dice, getReducePlateauScheduler, getAdamOptimizer, loadExistingModel
from utils import print_model_output, check_model_name, getDevice, getWorst, getBest

from monai.inferers import sliding_window_inference
from monai.data import decollate_batch

In [3]:
#Parameters
dataset = "IRIS"
phase = "test"
number = 10
verbose = True
augment = False
n4 = False

In [4]:
#Modelname and device
torch.multiprocessing.set_sharing_strategy('file_system')
set_determinism(seed=0)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
device = getDevice()

In [5]:
from monai.transforms import Activations
from torch import nn

#Loss
def loss_munet(preds, labels):
    dice = 1-torch.div(
        torch.sum(torch.mul(torch.mul(labels,preds),2)),
        torch.sum(torch.mul(preds,preds)) + torch.sum(torch.mul(labels,labels))
        )    
    return dice

loss_GDice = monai.losses.GeneralizedDiceLoss(other_act=nn.Softmax(dim=1))
loss_DiceCE = monai.losses.DiceCELoss(other_act=nn.Softmax(dim=1))

def loss_CE(input, target):
        n_pred_ch, n_target_ch = input.shape[1], target.shape[1]
        if n_pred_ch == n_target_ch:
            target = torch.argmax(target, dim=1)
        else:
            target = torch.squeeze(target, dim=1)
        target = target.long()
        return nn.CrossEntropyLoss(reduction="mean")(input, target)
    
def loss_GDiceCE(input, target, lambda_gdice=1.0, lambda_ce=0.5):    
    GDice = loss_GDice(input, target)
    CE = loss_CE(input, target)
    GDiceCELoss = lambda_gdice*GDice + lambda_ce*CE
    return GDiceCELoss

In [6]:
#Metric MUNet
def metric_munet(preds, labels):  
    labels = labels.detach().cpu().numpy()
    preds = preds.detach().cpu().numpy()    
    labels[np.where(labels == np.amax(labels, axis=0))] = 1
    labels[labels != 1] = 0
    dice=2*np.sum(labels*preds,(1,2,3))/(np.sum((labels+preds),(1,2,3))+1)    
    return dice

In [7]:
#postprocessing
outputs_processing = getSegmentationPostProcessingForLabelOutput()
labels_processing = getSegmentationPostProcessingForLabel()

#activations
softmax = Activations(other=nn.Softmax(dim=1))

In [8]:
#dataloaders
dataloaders, size = getSegmentationDataset(dataset=dataset, batch=1, augment=augment, training=False, n4=n4, labels=True, eval_augment=False)

=> Using IRIS dataset.


Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 248/248 [00:00<00:00, 397929.38it/s]
Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 82/82 [00:00<00:00, 266408.15it/s]
Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 82/82 [00:00<00:00, 494156.51it/s]


In [13]:
model = getUNetForSegmentation()
modelname = "test_seg_labels.pth"
#modelnames = 


#########################################

modelname = check_model_name(modelname)
loadExistingModel(model, None, ft=modelname)
model.eval()

scores_list = []

with torch.no_grad():
    metrics = [[], [], [], []]
    running_loss = 0.0
    for i, data in enumerate(dataloaders[phase]):
        print("{}/{}".format(
            i, len(dataloaders[phase])), end='\r'
        )
        inputs, labels = data["img"].to(device), data["seg"].to(device)
        filename = data['img_meta_dict']['filename_or_obj'][0].split('/')[-1]
        labels = labels.squeeze(2)
        outputs = sliding_window_inference(inputs, (96, 96, 96), 4, model)                
        loss = loss_GDiceCE(outputs, labels)
        #loss = loss_DiceCE(outputs, labels)
        preds = [outputs_processing(pred) for pred in decollate_batch(outputs)]
        labels = [labels_processing(label) for label in decollate_batch(labels)]
        running_loss += loss.item() * inputs.size(0)                       
        
        for j in range(len(preds)):
            metric = metric_munet(preds[j], labels[j])                     
            for k in range(4):
                metrics[k].append(metric[k])     
        metric_mean = np.mean(metric)
        scores_list.append([filename, loss.item(), metric_mean])
            
    running_loss /= size[phase]
    epoch_metrics = [np.mean(x) for x in metrics]
    epoch_metric = np.mean(epoch_metrics)
    print(
        "{}: loss: {:.4f}, dice: {:.4f}".format(
            phase, running_loss, epoch_metric
        )
    )
    print("dices: {:.4f}, {:.4f}, {:.4f}, {:.4f}".format(epoch_metrics[0], epoch_metrics[1], epoch_metrics[2], epoch_metrics[3]))

getWorst(scores_list, number)
getBest(scores_list, number)

0/82

AssertionError: ground truth has differing shape ((1, 1, 128, 128, 128)) from input ((1, 4, 128, 128, 128))