In [None]:
# Imports and global notebook settings for evaluating the trained V2V runs below.
# Load IPython's autoreload extension; `%autoreload 2` at the end of this cell makes
# edited project modules (models, losses, dataset, utils, metrics) reload before each cell runs.
%load_ext autoreload

import re, time, os, shutil, json, math
import numpy as np
import configdot
from tqdm import tqdm
import monai
from monai.data import DataLoader, Dataset, list_data_collate, decollate_batch
import nibabel as nib
import nilearn
from nilearn import plotting
from matplotlib.pyplot import figure
import matplotlib.pyplot as plt
# Larger default font size for all figures in this notebook.
plt.rcParams['font.size'] = 20

from collections import defaultdict
from IPython.core.debugger import set_trace
import pandas as pd
from nibabel.freesurfer.io import read_morph_data

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
import torch.optim as optim
from models.v2v import V2VModel

# NOTE(review): star import — it is not clear from this notebook which names are used
# from `losses`; consider importing only what is needed.
from losses import *
from dataset import setup_dataloaders, create_datafile, setup_datafiles, setup_transformations
from utils import save, get_capacity, calc_gradient_norm, get_label, get_latest_weights


from metrics import calculate_metrics

import warnings
# NOTE(review): this silences every warning, including pandas deprecation warnings and
# numpy "Mean of empty slice" RuntimeWarnings that would flag problems in the evaluation
# and aggregation cells below; consider narrowing the filter to specific categories.
warnings.filterwarnings("ignore")
%autoreload 2

In [None]:
# Show up to 100 rows when a DataFrame is displayed (the per-subject tables below).
pd.set_option('display.max_rows', 100)

# Seed the numpy and torch global RNGs (torch.manual_seed covers all devices) so any
# stochastic step that draws from them is reproducible across reruns.
# NOTE(review): MONAI random transforms keep their own RandomState; if the validation
# transforms are stochastic, also call monai.utils.set_determinism(seed=SEED).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): USE_nG is not read anywhere in this notebook; presumably it mirrors the
# "nG" setting of the training runs (cf. metadata_fcd_nG.npy) — confirm or remove.
USE_nG = True

In [None]:
# Root directory that holds one sub-directory per training run. It can be overridden
# with the FCDNET_LOGDIR environment variable; otherwise the original path is used.
LOGDIR = os.environ.get('FCDNET_LOGDIR', '/workspace/RawData/FCDNet/logs')

# Archived runs, one per input-feature combination, laid out as 'stash/<feature>/<run>'
# (paths relative to LOGDIR). Set `log_dir_iter = STASHED_LOG_DIRS` to evaluate all of them.
STASHED_LOG_DIRS = [
    'stash/t1_all/v2v-IN_autocast_DICE_lr-1e-3_nG-bs2-AUG-MASK-to-all-imgch-t1-all_scaler-trial4@29.07.2022-11',
    'stash/t1_cr-flair/v2v-IN_autocast_DICE_lr-1e-3_nG-bs2-AUG-MASK-to-all-imgch-t1-cr-flair_scaler-trial4@19.07.2022-09',
    'stash/t1_blurring-flair/v2v-IN_autocast_DICE_lr-1e-3_nG-bs2-AUG-MASK-to-all-imgch-t1-blurring-flair_scaler-trial3@18.07.2022-10',
    'stash/t1_entropy/v2v-IN_autocast_DICE_lr-1e-3_nG-bs2-AUG-MASK-to-all-imgch-t1-entropy_scaler-trial5@24.07.2022-08',
    'stash/t1_blurring-t2/v2v-IN_autocast_DICE_lr-1e-3_nG-bs2-AUG-MASK-to-all-imgch-t1-blurring-t2_scaler-trial2@25.07.2022-07',
    'stash/t1_blurring-t1/v2v-IN_autocast_DICE_lr-1e-3_nG-bs2-AUG-MASK-to-all-imgch-t1-blurring-t1_scaler-trial3@17.07.2022-15',
    'stash/t1_variance/v2v-IN_autocast_DICE_lr-1e-3_nG-bs2-AUG-MASK-to-all-imgch-t1-variance_scaler-trial3@22.07.2022-10',
    'stash/t1_sulc/v2v-IN_autocast_DICE_lr-1e-3_nG-bs2-AUG-MASK-to-all-imgch-t1-sulc_scaler-rerun_trail4@10.08.2022-20',
    'stash/t1_thickness/v2v-IN_autocast_DICE_lr-1e-3_nG-bs2-AUG-MASK-to-all-imgch-t1-thickness_scaler-trial1@19.07.2022-09',
    'stash/t1_curv/v2v-IN_autocast_DICE_lr-1e-3_nG-bs2-AUG-MASKint-t1-curv_scaler-trial3@22.08.2022-16',
]

# Runs evaluated by the next cell (paths relative to LOGDIR).
log_dir_iter = ['debug/loss/v2v-IN_autocast_TL-d-0.9_lr-1e-3_nG-bs2-AUG-MASKint-t1-all_scaler_minmax-prep@30.09.2022-14']

In [None]:
from statistics import mean, median
from collections import OrderedDict
import seaborn as sns

DEVICE = 'cuda:1'  # e.g. 'cuda:0', 'cuda:1' or 'cpu'
device = torch.device(DEVICE)
PLOT_PREDICTIONS = False  # set True to draw every prediction over its input volume
PER_SUBJ_COLUMNS = ['Feature', 'Label', 'Precision', 'Sensitivity', 'Specificity', 'Dice', 'Accuracy']
test_files = None
fold_metrics = OrderedDict()  # fold_index -> {label: [Precision, Sensitivity, Specificity, Dice, Accuracy]}
per_subj_records = []         # one row (PER_SUBJ_COLUMNS order) per evaluated subject, all runs
fold_file = './metadata/metadata_fcd_nG.npy'

# The split file does not depend on the run being evaluated, so read it once.
# NOTE(review): allow_pickle=True can execute code stored in the file; only load trusted metadata.
fold_npy = np.load(fold_file, allow_pickle=True).item()
test_list = fold_npy.get('test')


def _plot_prediction(brain, pred, gt, intensity, label):
    """Overlay a predicted map on its input volume and outline the ground truth in blue.

    Voxels of `pred` / `gt` below `intensity` (the threshold returned by
    `calculate_metrics`) are masked out before plotting.
    """
    masked_labels_pred = np.ma.masked_where(pred < intensity, pred)
    masked_labels_gt = np.ma.masked_where(gt < intensity, gt)
    gt_binary = nib.Nifti1Image(np.where(masked_labels_gt > 0.5, 1, 0), np.eye(4))
    fig = figure(figsize=(12, 5), dpi=100)
    coord = nilearn.plotting.find_cuts.find_xyz_cut_coords(gt_binary)
    display = plotting.plot_anat(nib.Nifti1Image(masked_labels_pred, np.eye(4)),
                                 bg_img=nib.Nifti1Image(brain, np.eye(4)),
                                 cmap='jet',
                                 alpha=0.4,
                                 figure=fig,
                                 draw_cross=False,
                                 black_bg=False,
                                 title=f'Subject {label}',
                                 colorbar=True,
                                 cut_coords=coord)
    display.add_contours(gt_binary, colors='b', levels=[1.0])
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when plotting inside the loop


for logname in log_dir_iter:
    logdir = os.path.join(LOGDIR, logname)
    config = configdot.parse_config(os.path.join(logdir, 'config.ini'))
    # Labels are recovered via dataset.data[iter_i] below, which requires batch size 1.
    assert config.opt.val_batch_size == 1
    assert config.model.name == "v2v"

    # NOTE(review): component [2] is the run-directory name. For 'stash/<feature>/<run>'
    # paths the feature name is component [1], and the later cells filter on feature
    # names such as 't1_all' — confirm which component is intended here.
    fold_index = logname.split('/')[2]

    best_model = V2VModel(config).to(device)
    model_dict = torch.load(get_latest_weights(logdir), map_location=DEVICE)
    best_model.load_state_dict(model_dict['model_state'])
    best_model.eval()
    print(fold_index)

    feat_params = config.dataset.features
    test_files = create_datafile(test_list, feat_params, mask=True)
    _, val_trans = setup_transformations(config)
    test_ds = monai.data.Dataset(data=test_files[0], transform=val_trans)
    # shuffle=False keeps the loader order aligned with test_ds.data (label correspondence).
    dataloader = test_loader = DataLoader(
        test_ds,
        batch_size=1,
        num_workers=0,
        collate_fn=list_data_collate,
        shuffle=False,
    )
    print(f'Start evaluate fold {fold_index}')
    print(logdir)

    brains = {}          # label -> channel 0 of the input volume
    label_pred_arr = {}  # label -> masked prediction volume
    label_gt_arr = {}    # label -> ground-truth segmentation volume
    metric = {}          # label -> [Precision, Sensitivity, Specificity, Dice, Accuracy]

    with torch.no_grad():
        for iter_i, data_tensors in enumerate(dataloader):
            brain_tensor = data_tensors['image'].to(device)  # [1, C, H, W, D]
            label_tensor = data_tensors['seg'].to(device)    # [1, 1, H, W, D]
            mask_tensor = data_tensors['mask'].to(device)    # [1, 1, H, W, D]

            label = get_label(dataloader.dataset.data[iter_i]['seg'])
            print(f'Label: {label}')

            # Forward pass; zero out predictions outside the mask.
            label_tensor_predicted = best_model(brain_tensor) * mask_tensor
            brains[label] = brain_tensor[0, 0].detach().cpu().numpy()
            label_pred_arr[label] = label_tensor_predicted[0, 0].detach().cpu().numpy()
            label_gt_arr[label] = label_tensor[0, 0].detach().cpu().numpy()

            Precision, Sensitivity, Specificity, Dice, intensity, Accuracy = calculate_metrics(
                label_pred_arr[label], label_gt_arr[label]
            )
            metric[label] = [Precision, Sensitivity, Specificity, Dice, Accuracy]
            # Collect rows in a list: DataFrame.append was removed in pandas 2.0, and
            # growing a frame row by row is quadratic.
            per_subj_records.append([fold_index, label, *metric[label]])

            if PLOT_PREDICTIONS:
                _plot_prediction(brains[label], label_pred_arr[label], label_gt_arr[label],
                                 intensity, label)

    fold_metrics[str(fold_index)] = metric
    if not metric:
        # statistics.mean / median raise on empty input.
        print(f'No test subjects were evaluated for fold {fold_index}')
        continue

    # Transpose {label: [p, se, sp, d, a]} into one list per metric.
    Prec, Sens, Spec, Dcs, Acrs = (list(values) for values in zip(*metric.values()))

    print(f'Precision mean on fold {fold_index}:         {mean(Prec)}')
    print(f'Sensitivity mean on fold {fold_index}:       {mean(Sens)}')
    print(f'Specificity mean on fold {fold_index}:       {mean(Spec)}')
    print(f'Dice mean on fold {fold_index}:              {mean(Dcs)}')
    print(f'Dice median on fold {fold_index}:            {median(Dcs)}')
    print(f'Accuracy mean on fold {fold_index}:          {mean(Acrs)}')

df_per_subj = pd.DataFrame(per_subj_records, columns=PER_SUBJ_COLUMNS)
    

In [None]:
# Per-subject metrics for every evaluated run (one row per test subject).
df_per_subj

In [None]:
# Index the per-subject metrics by the 'Feature' column only and drop the subject label.
# This is the same frame as set_index(['Feature', 'Label']) followed by keeping index level 0.
new_df = df_per_subj.drop(columns='Label').set_index('Feature')

# Feature runs whose subjects enter the aggregate metrics below (each listed once).
SELECTED_FEATURES = [
    't1_all',
    't1_cr-flair',
    't1_blurring-flair',
    't1_entropy',
    't1_cr-t2',
    't1_blurring-t2',
    't1_blurring-t1',
    't1_variance',
    't1_sulc',
    't1_thickness',
    't1_curv',
]
Filter_df = new_df[new_df.index.isin(SELECTED_FEATURES)]
if Filter_df.empty:
    # E.g. when 'Feature' holds run-directory names rather than feature names.
    print('NOTE: no rows of df_per_subj match SELECTED_FEATURES; downstream aggregates will be empty.')

In [None]:
# Inspect the per-subject metrics that feed the aggregation cells below.
Filter_df

In [None]:
from sklearn.metrics import confusion_matrix, accuracy_score

# Per-subject detection flag: `pred` is 1 when the voxel-wise precision exceeds 0.01,
# else 0. `gt` is fixed at 1 for every row (presumably every test subject carries a
# lesion — confirm against the split metadata).
metric_df = Filter_df.assign(
    pred=np.where(Filter_df['Precision'] > 0.01, 1, 0),
    gt=1,
)

In [None]:
# Feature runs averaged in the next cell. Each name appears exactly once so that no
# run is weighted twice in the mean-of-means ('t1_cr-t2' used to be listed twice).
# 't1_all' is not part of this list.
fold_l = [
    't1_cr-flair',
    't1_blurring-flair',
    't1_entropy',
    't1_cr-t2',
    't1_blurring-t2',
    't1_blurring-t1',
    't1_variance',
    't1_sulc',
    't1_thickness',
    't1_curv',
]

In [None]:
from statistics import mean

# Voxel-wise metrics: average each metric over the subjects of one feature run, then
# average those per-run means across runs.
METRIC_COLUMNS = ['Precision', 'Sensitivity', 'Specificity', 'Dice', 'Accuracy']

mean_mean_prec = []
mean_mean_sens = []
mean_mean_spec = []
mean_mean_dscs = []
mean_mean_acrs = []

# dict.fromkeys drops repeated run names (keeping order) so no run is counted twice.
for fold in dict.fromkeys(fold_l):
    fold_pandas = metric_df[metric_df.index.isin([fold])]
    if fold_pandas.empty:
        # np.mean of an empty array is nan and would turn every mean-of-means into nan.
        print(f'Fold                  {fold}\n'
              'No subjects for this fold; skipped.\n')
        continue

    # dtype=float guards against object-typed metric columns.
    mean_prec, mean_sens, mean_spec, mean_dscs, mean_acrs = (
        np.mean(fold_pandas[column].to_numpy(dtype=float)) for column in METRIC_COLUMNS
    )

    mean_mean_prec.append(mean_prec)
    mean_mean_sens.append(mean_sens)
    mean_mean_spec.append(mean_spec)
    mean_mean_dscs.append(mean_dscs)
    mean_mean_acrs.append(mean_acrs)

    # A single concatenated string: separate print() arguments would insert a stray
    # space at the start of every line after the second.
    print(
        f'Fold                  {fold}\n'
        f'Precision:            {mean_prec}\n'
        f'Sensitivity:          {mean_sens}\n'
        f'Specificity:          {mean_spec}\n'
        f'Dice:                 {mean_dscs}\n'
        f'Accuracy:             {mean_acrs}\n'
    )

# NOTE: no patient-level confusion-matrix summary is computed here. metric_df['gt'] is
# constant 1, so tn and fp are always 0 (specificity is undefined), and
# confusion_matrix(gt, pred).ravel() yields a single value rather than four when every
# prediction is also 1.

if mean_mean_dscs:
    print('Voxel-wise Level\n'
          f'Mean Precision:            {mean(mean_mean_prec)}\n'
          f'Mean Sensitivity:          {mean(mean_mean_sens)}\n'
          f'Mean Specificity:          {mean(mean_mean_spec)}\n'
          f'Mean Dice:                 {mean(mean_mean_dscs)}\n'
          f'Mean Accuracy:             {mean(mean_mean_acrs)}\n')
else:
    # statistics.mean raises on an empty list.
    print('Voxel-wise Level\nNo folds with evaluated subjects; nothing to aggregate.')