In [4]:
import argparse
import logging
import sys
import numpy as np
import torch
from utils import compute_mean_dice
import pandas as pd
import os
from glob import glob
from matplotlib import pyplot as plt
import monai
import torchinfo
import nibabel as nib
from monai.transforms import AsDiscrete
from miseval import evaluate

In [10]:
template_mask = os.path.join("dataset2", "Atlas", "Identity_Feminad_Template_Mask.nii.gz")
template_mask = torch.from_numpy(nib.load(template_mask).get_fdata().reshape(1,128,128,128))
template_mask = AsDiscrete(threshold=0.5)(template_mask)
template_mask = template_mask.cpu().numpy().squeeze()
original_masks = sorted(glob(os.path.join("dataset2", "Feminad", "Mask_Resample_Identity", "*.nii.gz")))
affine_masks = sorted(glob(os.path.join("dataset2", "Feminad", "Mask_Resample_Identity_Affine", "*.nii.gz")))
ants_masks = sorted(glob(os.path.join("dataset2", "Feminad", "Mask_Resample_Identity_Deformable", "*.nii.gz")))


names = ["No registration", "Affine", "SyN"]
masks_lists = [original_masks, affine_masks, ants_masks]
dice = 0
hausdorf = 0
iou = 0
sens = 0
spec = 0
for z, masks_list in enumerate(masks_lists):
    for i in range(len(masks_list)):
        print(i, end='\r')
        mask = torch.from_numpy(nib.load(masks_list[i]).get_fdata().reshape(1,128,128,128))
        mask = AsDiscrete(threshold=0.5)(mask)
        mask = mask.cpu().numpy().squeeze()

        dice += evaluate(template_mask, mask, metric="DSC") 
        hausdorf += evaluate(template_mask, mask, metric="AHD")  
        iou += evaluate(template_mask, mask, metric="IoU")    
        sens += evaluate(template_mask, mask, metric="SENS")
        spec += evaluate(template_mask, mask, metric="SPEC")
    dice /= len(masks_list)
    hausdorf /= len(masks_list)
    iou /= len(masks_list)
    sens /= len(masks_list)
    spec /= len(masks_list)
    print(names[z])
    print(
        "Labels: Dice: {:.4f} / Haussdorf: {:.4f} / IoU: {:.4f} / Sens: {:.4f} / Spec: {:.4f}".format(
            dice, hausdorf, iou, sens, spec
        )
    )
    print('-'*20)


No registration
Labels: Dice: 0.8880 / Haussdorf: 10.6292 / IoU: 0.8002 / Sens: 0.8953 / Spec: 0.9818
--------------------
Affine
Labels: Dice: 0.9772 / Haussdorf: 8.7946 / IoU: 0.9299 / Sens: 0.9876 / Spec: 1.0205
--------------------
SyN
Labels: Dice: 0.9884 / Haussdorf: 8.5734 / IoU: 0.9495 / Sens: 0.9970 / Spec: 1.0233
--------------------


In [9]:
#dice = 0
#hausdorf = 0
#iou = 0
#sens = 0
#spec = 0
#for i in range(len(affine_masks)):
#    print(i, end='\r')
#    mask = torch.from_numpy(nib.load(affine_masks[i]).get_fdata().reshape(1,128,128,128))
#    mask = AsDiscrete(threshold=0.5)(mask)
#    mask = mask.cpu().numpy().squeeze()
#    
#    dice += evaluate(template_mask, mask, metric="DSC")  
#    hausdorf += evaluate(template_mask, mask, metric="AHD")  
#    iou += evaluate(template_mask, mask, metric="IoU")    
#    sens += evaluate(template_mask, mask, metric="SENS")
#    spec += evaluate(template_mask, mask, metric="SPEC")
#dice /= len(affine_masks)
#hausdorf /= len(affine_masks)
#iou /= len(affine_masks)
#sens /= len(affine_masks)
#spec /= len(affine_masks)
#
#print('Affine:')
#print(
#    "Labels: Dice: {:.4f} / Haussdorf: {:.4f} / IoU: {:.4f} / Sens: {:.4f} / Spec: {:.4f}".format(
#        dice, hausdorf, iou, sens, spec
#    )
#)
#print('-'*20)
#
#dice = 0
#hausdorf = 0
#iou = 0
#sens = 0
#spec = 0
#for i in range(len(ants_masks)):
#    print(i, end='\r')
#    mask = torch.from_numpy(nib.load(ants_masks[i]).get_fdata().reshape(1,128,128,128))
#    mask = AsDiscrete(threshold=0.5)(mask)
#    mask = mask.cpu().numpy().squeeze()
#    
#    dice += evaluate(template_mask, mask, metric="DSC")   
#    hausdorf += evaluate(template_mask, mask, metric="AHD")  
#    iou += evaluate(template_mask, mask, metric="IoU")    
#    sens += evaluate(template_mask, mask, metric="SENS")
#    spec += evaluate(template_mask, mask, metric="SPEC")
#    
#dice /= len(ants_masks)
#hausdorf /= len(ants_masks)
#iou /= len(ants_masks)
#sens /= len(ants_masks)
#spec /= len(ants_masks)
#
#print('SyN:')
#print(
#    "Labels: Dice: {:.4f} / Haussdorf: {:.4f} / IoU: {:.4f} / Sens: {:.4f} / Spec: {:.4f}".format(
#        dice, hausdorf, iou, sens, spec
#    )
#)
#print('-'*20)