In [1]:
import argparse
import logging
import sys

import numpy as np
import torch
from monai.transforms import AsDiscrete, MaskIntensity
from monai.utils import set_determinism
from torch.utils.tensorboard import SummaryWriter
import random

import utils_parser
from reg_data import getRegistrationDataset
from reg_model import getRegistrationModel
from utils import compute_mean_dice, getAdamOptimizer, getReducePlateauScheduler, loadExistingModel, getDevice, compute_mean_dice
from utils import print_model_output, print_weights, add_weights_to_name, compute_landmarks_distance_local, compute_csv_distance
from loss import get_deformable_registration_loss_from_weights, get_affine_registration_loss_from_weights
import pandas as pd
import os
from glob import glob
from matplotlib import pyplot as plt
import monai
import torchinfo
from miseval import evaluate
import nibabel as nib

In [3]:
# Evaluate each trained registration checkpoint on the Neatin test split:
# warp the moving region labels with the predicted DDF, save them as NIfTI next to
# the atlas geometry, and report the multi-class region Dice plus binary-mask metrics
# (DSC / AHD / IoU / SENS / SPEC) averaged over the processed samples.
atlas_name = "dataset2/Atlas/Identity_Neatin_MRI_A9.nii.gz"
atlas_nifti = nib.load(atlas_name)  # load once; affine and header come from the same file
atlas_affine = atlas_nifti.affine
atlas_header = atlas_nifti.header

# Display labels and checkpoint files are paired by position: names[k] <-> modelnames[k].
names = [
    #"neatin-1.2",
    "neatin-1.1-continue",
    #"neatin-1.1-masked",
    #"neatin-finetunepainfact",
]
modelnames = [
    #"neatin_scenario1_1.0-0.0-2.0.pth",
    "neatin_scenario1_continue_1.0-0.0-1.0.pth",
    #"neatin_scenario1_maskedloss_1.0-0.0-1.0.pth",
    #"neatin_scenario3_1.0-0.0-1.0.pth",
]
if len(names) != len(modelnames):
    raise ValueError(f"names ({len(names)}) and modelnames ({len(modelnames)}) must be paired one-to-one")

N_REGION_CLASSES = 41      # number of region classes passed to miseval's multi-class DSC
LOSS_WEIGHTS = [1, 0, 1]   # fixed loss weights so the reported loss is comparable between methods

for model_label, modelname in zip(names, modelnames):
    dataset = "neatinaffine"
    batchsize = 1
    registration_type = "local"
    atlas = True
    mask = False
    torch.multiprocessing.set_sharing_strategy('file_system')
    torch.backends.cudnn.benchmark = True

    set_determinism(seed=0)
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    device = getDevice()

    # Checkpoints whose file name contains "ddf" were trained with use_ddf=True.
    use_ddf = "ddf" in modelname
    model = getRegistrationModel(registration_type, img_size=128, channels=32, extract=[0,1,2,3,4], use_ddf=use_ddf)
    weights = loadExistingModel(model, None, ft=modelname, registration=True)
    model.eval()

    dataloaders, size = getRegistrationDataset(dataset=dataset,
                                               batch=batchsize,
                                               training=False,
                                               augment=False,
                                               eval_augment=False,
                                               atlas=atlas,
                                               mask=mask,
                                               )

    # Warped labels go to a per-model folder (previously hard-coded to one model's name).
    output_dir = os.path.join("output", "Neatin", model_label)
    os.makedirs(output_dir, exist_ok=True)

    running_metric = np.zeros(N_REGION_CLASSES)
    running_loss = 0.0
    running_dice = 0.0
    running_hausdorf = 0.0
    running_iou = 0.0
    running_sens = 0.0
    running_spec = 0.0
    n_batches = 0   # divisor for the per-batch region metric
    n_samples = 0   # divisor for the batch-size-weighted scalar metrics
    phase = "test"
    rows = []
    rtype = registration_type.lower()

    with torch.no_grad():
        for i, data in enumerate(dataloaders[phase]):
            print(i, end='\r')
            if rtype in ('affine', 'local'):
                ddf, pred_image, pred_label, _ = model(data)
            elif rtype == 'deformable':
                affine_ddf, ddf, pred_image, pred_label, affine_image, affine_label = model(data)
            else:
                raise ValueError(f"Unknown registration type: {registration_type}")

            pred_image = pred_image.to(device, non_blocking=True)
            pred_label = pred_label.to(device, non_blocking=True)
            pred_mask_np = AsDiscrete(threshold=0.5)(pred_label).cpu().numpy().squeeze()

            fixed_image = data['fixed_image'].to(device, non_blocking=True)
            fixed_label = data['fixed_label'].to(device, non_blocking=True)
            fixed_mask_np = AsDiscrete(threshold=0.5)(fixed_label).cpu().numpy().squeeze()

            fixed_regions_np = data['fixed_regions'].detach().cpu().numpy().squeeze()
            moving_regions = data['moving_regions'].to(device, dtype=torch.float, non_blocking=True)
            # Nearest-neighbour warping keeps the region ids integral.
            pred_regions = model.warp_nearest(moving_regions, ddf)
            pred_regions_np = pred_regions.cpu().detach().numpy().squeeze()

            labelwarped = nib.Nifti1Image(pred_regions_np, atlas_affine, atlas_header)
            nib.save(labelwarped, os.path.join(output_dir, "Label_A" + str(i) + ".nii.gz"))

            img_loss, lbl_loss, ddf_loss = get_deformable_registration_loss_from_weights(pred_image, pred_label,
                                                                                         fixed_image, fixed_label,
                                                                                         ddf, LOSS_WEIGHTS)
            loss = img_loss + lbl_loss + ddf_loss

            metric = evaluate(fixed_regions_np, pred_regions_np, metric="DSC", multi_class=True, n_classes=N_REGION_CLASSES)
            dice = evaluate(fixed_mask_np, pred_mask_np, metric="DSC")
            hausdorf = evaluate(fixed_mask_np, pred_mask_np, metric="AHD")
            iou = evaluate(fixed_mask_np, pred_mask_np, metric="IoU")
            sens = evaluate(fixed_mask_np, pred_mask_np, metric="SENS")
            spec = evaluate(fixed_mask_np, pred_mask_np, metric="SPEC")

            batch_n = fixed_image.size(0)
            n_batches += 1
            n_samples += batch_n

            running_metric += metric
            running_loss += loss.item() * batch_n
            running_dice += dice.item() * batch_n
            running_hausdorf += hausdorf.item() * batch_n
            running_iou += iou.item() * batch_n
            running_sens += sens.item() * batch_n
            running_spec += spec.item() * batch_n

            rows.append({
                'mouse': data['moving_image_meta_dict']['filename_or_obj'][0].split('/')[-1],
                'dsc_' + str(model_label): dice,
                'ahd_' + str(model_label): hausdorf,
                'iou_' + str(model_label): iou,
                'sens_' + str(model_label): sens,
                'spec_' + str(model_label): spec,
            })

    if n_samples == 0:
        raise RuntimeError(f"No samples found in the '{phase}' split for {modelname}")

    # Build the per-mouse table once instead of concatenating row by row.
    mouses_df = pd.DataFrame(rows)

    # Average over what was actually processed (was a hard-coded 9).
    running_metric /= n_batches
    print(running_metric)
    print(np.mean(running_metric))
    running_loss /= n_samples
    running_dice /= n_samples
    running_hausdorf /= n_samples
    running_iou /= n_samples
    running_sens /= n_samples
    running_spec /= n_samples
    print(model_label)
    print(
        "Loss: loss: {:.4f}".format(
            running_loss,
        )
    )
    print(
        "Labels: Dice: {:.4f} / Haussdorf: {:.4f} / IoU: {:.4f} / Sens: {:.4f} / Spec: {:.4f}".format(
            running_dice, running_hausdorf, running_iou, running_sens, running_spec
        )
    )
    # Optional per-mouse export:
    # mouses_df.to_csv("models/" + modelname.split('/')[-1].split('.pth')[0] + '.csv', index=False)
    print('-'*20)

=> Using Neatin affine registered dataset.


Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 43996.20it/s]
Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 74455.10it/s]
Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 84260.57it/s]


[0.99727056 0.86486067 0.58845492 0.84394648 0.4715059  0.74089494
 0.63118029 0.8997552  0.91591427 0.81402929 0.61947485 0.85406698
 0.8056509  0.70263117 0.88513144 0.80211086 0.84504462 0.91612041
 0.79309199 0.83149811 0.58433839 0.84687756 1.         0.77410761
 0.40787836 0.59477173 0.54910177 0.87722295 0.90913353 0.79511465
 1.         0.83446216 0.8135662  0.66164304 0.86709526 0.75799324
 0.82213567 1.         0.79567642 0.80789283 0.4927063 ]
0.7808378420205065
neatin-1.1-continue
Loss: loss: 0.2422
Labels: Dice: 0.9670 / Haussdorf: 23.4120 / IoU: 0.9362 / Sens: 0.9789 / Spec: 0.9961
--------------------


In [27]:
# Load the atlas template (binarised mask + region labels) and collect the sorted
# per-subject mask/label files for three variants: unregistered, affine-registered,
# and affine + deformable registered.
atlas_dir = os.path.join("dataset2", "Atlas")
neatin_dir = os.path.join("dataset2", "Neatin")

template_mask_raw = nib.load(os.path.join(atlas_dir, "Identity_Neatin_Mask_A9.nii.gz")).get_fdata()
template_mask = (
    AsDiscrete(threshold=0.5)(torch.from_numpy(template_mask_raw.reshape(1, 128, 128, 128)))
    .cpu()
    .numpy()
    .squeeze()
)

template_labels = nib.load(os.path.join(atlas_dir, "Identity_Neatin_Label_A9.nii.gz")).get_fdata().squeeze()


def _sorted_niftis(subdir):
    """Return the sorted ``*.nii.gz`` paths found in ``dataset2/Neatin/<subdir>``."""
    return sorted(glob(os.path.join(neatin_dir, subdir, "*.nii.gz")))


original_masks = _sorted_niftis("Mask_Resample_Identity")
original_labels = _sorted_niftis("Label_Resample_Identity")
affine_masks = _sorted_niftis("Mask_Resample_Identity_Affine")
affine_labels = _sorted_niftis("Label_Resample_Identity_Affine")
ants_masks = _sorted_niftis("Mask_Resample_Identity_Affine_Deformable")
ants_labels = _sorted_niftis("Label_Resample_Identity_Affine_Deformable")


In [24]:
# Baseline: compare the atlas template with every subject WITHOUT registration.
# Region labels are scored with the 41-class DSC; the binarised mask with
# DSC / AHD / IoU / SENS / SPEC. All scores are averaged over subjects.
# Masks and labels are paired by position in their sorted file lists, so the
# lists must be non-empty and the same length.
if not original_masks or len(original_masks) != len(original_labels):
    raise ValueError(
        "Expected a non-empty, equal number of mask and label files "
        f"(got {len(original_masks)} masks, {len(original_labels)} labels)"
    )
n_subjects = len(original_masks)

dices = np.zeros(41)
dice = 0
hausdorf = 0
iou = 0
sens = 0
spec = 0
for mask_path, label_path in zip(original_masks, original_labels):
    mask = torch.from_numpy(nib.load(mask_path).get_fdata().reshape(1,128,128,128))
    mask = AsDiscrete(threshold=0.5)(mask)
    mask = mask.cpu().numpy().squeeze()

    labels = nib.load(label_path).get_fdata().squeeze()

    dices += evaluate(template_labels, labels, metric="DSC", multi_class=True, n_classes=41)

    dice += evaluate(template_mask, mask, metric="DSC")
    hausdorf += evaluate(template_mask, mask, metric="AHD")
    iou += evaluate(template_mask, mask, metric="IoU")
    sens += evaluate(template_mask, mask, metric="SENS")
    spec += evaluate(template_mask, mask, metric="SPEC")

dices /= n_subjects
dice /= n_subjects
hausdorf /= n_subjects
iou /= n_subjects
sens /= n_subjects
spec /= n_subjects

print('No registration:')
print(dices)
print(np.mean(dices))
print(dice)
print(hausdorf)
print(iou)
print(sens)
print(spec)
print('-'*20)

No registration:
[0.9962349  0.87878197 0.72694881 0.92326183 0.53373324 0.77756274
 0.61172233 0.91490087 0.83510436 0.84315548 0.54773777 0.86936324
 0.77427977 0.73019434 0.89348308 0.8617775  0.84263668 0.91405621
 0.90206021 0.90736144 0.50032562 0.86562254 1.         0.91422581
 0.5147797  0.75474318 0.50137138 0.90842942 0.83656034 0.81963978
 1.         0.86953462 0.76898843 0.66215499 0.87319535 0.89036901
 0.83584773 1.         0.89232392 0.898246   0.4410209 ]
0.8056520847810733
0.9519918553580354
25.312776468601115
0.9085492909334789
0.9336248444884715
0.9976215063662559
--------------------


In [25]:
# Affine-only registration: compare the atlas template with every affinely
# registered subject. Region labels are scored with the 41-class DSC; the
# binarised mask with DSC / AHD / IoU / SENS / SPEC. All scores are averaged
# over the affine subjects. (The region Dice was previously divided by the
# count of the *unregistered* list.)
if not affine_masks or len(affine_masks) != len(affine_labels):
    raise ValueError(
        "Expected a non-empty, equal number of mask and label files "
        f"(got {len(affine_masks)} masks, {len(affine_labels)} labels)"
    )
n_subjects = len(affine_masks)

dices = np.zeros(41)
dice = 0
hausdorf = 0
iou = 0
sens = 0
spec = 0
for mask_path, label_path in zip(affine_masks, affine_labels):
    mask = torch.from_numpy(nib.load(mask_path).get_fdata().reshape(1,128,128,128))
    mask = AsDiscrete(threshold=0.5)(mask)
    mask = mask.cpu().numpy().squeeze()

    labels = nib.load(label_path).get_fdata().squeeze()

    dices += evaluate(template_labels, labels, metric="DSC", multi_class=True, n_classes=41)

    dice += evaluate(template_mask, mask, metric="DSC")
    hausdorf += evaluate(template_mask, mask, metric="AHD")
    iou += evaluate(template_mask, mask, metric="IoU")
    sens += evaluate(template_mask, mask, metric="SENS")
    spec += evaluate(template_mask, mask, metric="SPEC")

dices /= n_subjects
dice /= n_subjects
hausdorf /= n_subjects
iou /= n_subjects
sens /= n_subjects
spec /= n_subjects

print('Affine:')
print(dices)
print(np.mean(dices))
print(dice)
print(hausdorf)
print(iou)
print(sens)
print(spec)
print('-'*20)

Affine:
[0.99725976 0.85468886 0.61163671 0.87524648 0.50411536 0.75864602
 0.59698906 0.90196798 0.90831266 0.82572163 0.60204873 0.86189295
 0.80367353 0.70932333 0.88813969 0.80123591 0.82939611 0.91673686
 0.79205266 0.84668662 0.58398994 0.86505483 1.         0.80736001
 0.40238604 0.64291276 0.59338206 0.89309892 0.90813504 0.81644566
 1.         0.84251292 0.82479673 0.66216738 0.87832631 0.80904969
 0.80931045 1.         0.80428086 0.82054416 0.53242658]
0.7898036886361889
0.965758134584607
23.44314440785517
0.933808066289877
0.9617893581751535
0.9974108623151634
--------------------


In [28]:
# Affine + deformable registration: compare the atlas template with every subject
# from the *_Affine_Deformable folders. Region labels are scored with the 41-class
# DSC; the binarised mask with DSC / AHD / IoU / SENS / SPEC. All scores are
# averaged over these subjects. (The region Dice was previously divided by the
# count of the *unregistered* list.)
if not ants_masks or len(ants_masks) != len(ants_labels):
    raise ValueError(
        "Expected a non-empty, equal number of mask and label files "
        f"(got {len(ants_masks)} masks, {len(ants_labels)} labels)"
    )
n_subjects = len(ants_masks)

dices = np.zeros(41)
dice = 0
hausdorf = 0
iou = 0
sens = 0
spec = 0
for mask_path, label_path in zip(ants_masks, ants_labels):
    mask = torch.from_numpy(nib.load(mask_path).get_fdata().reshape(1,128,128,128))
    mask = AsDiscrete(threshold=0.5)(mask)
    mask = mask.cpu().numpy().squeeze()

    labels = nib.load(label_path).get_fdata().squeeze()

    dices += evaluate(template_labels, labels, metric="DSC", multi_class=True, n_classes=41)

    dice += evaluate(template_mask, mask, metric="DSC")
    hausdorf += evaluate(template_mask, mask, metric="AHD")
    iou += evaluate(template_mask, mask, metric="IoU")
    sens += evaluate(template_mask, mask, metric="SENS")
    spec += evaluate(template_mask, mask, metric="SPEC")

dices /= n_subjects
dice /= n_subjects
hausdorf /= n_subjects
iou /= n_subjects
sens /= n_subjects
spec /= n_subjects

print('SyN:')
print(dices)
print(np.mean(dices))
print(dice)
print(hausdorf)
print(iou)
print(sens)
print(spec)
print('-'*20)

SyN:
[0.99693309 0.84464335 0.56905685 0.83846545 0.46248764 0.71252112
 0.58013421 0.89486696 0.90598678 0.81734694 0.58328378 0.84370097
 0.78335958 0.65412979 0.8766218  0.76943282 0.84501271 0.89960425
 0.76613274 0.81764587 0.56158182 0.83061914 1.         0.76834595
 0.41988586 0.55867927 0.49024419 0.87658156 0.89362757 0.80950242
 1.         0.81219093 0.81165326 0.66109051 0.86227324 0.7150862
 0.83129874 1.         0.76992943 0.81096145 0.51078193]
0.7672121994997342
0.9614102092945284
23.762033440650765
0.9257737868092324
0.9742582975615126
0.9954546351849349
--------------------
