In [1]:
%config Completer.use_jedi = False
%load_ext autoreload

import re, time, os, shutil, json, math
import numpy as np
import configdot
from tqdm import tqdm
import monai
from monai.data import DataLoader, Dataset, list_data_collate, decollate_batch
import nibabel as nib
import nilearn
from nilearn import plotting
from matplotlib.pyplot import figure
import matplotlib.pyplot as plt
plt.rcParams['font.size'] = 20

from collections import defaultdict
from IPython.core.debugger import set_trace
import pandas as pd
from nibabel.freesurfer.io import read_morph_data

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
import torch.optim as optim
from models.v2v import V2VModel

from losses import *
from dataset import setup_dataloaders, create_datafile, setup_datafiles, setup_transformations
from utils import save, get_capacity, calc_gradient_norm, get_label, get_latest_weights


from metrics import calculate_metrics

import warnings
warnings.filterwarnings("ignore")
%autoreload 2

### Get test data without nG indexes

In [12]:
torch.manual_seed(132)
for i in range(5):
    print(torch.seed())

17734541272635785841
6246303251675928261
10229961000563817043
13972218761186368727
13655251993840457457


In [None]:
pd.set_option('display.max_rows', 100)
SEED = 42
USE_nG = True

In [None]:
test_list = np.load('./metadata/metadata_fcd_noind.npy', allow_pickle=True).item()
test_list = test_list.get('test')

### Get val data with nG indexes for forward pass

In [None]:
meta_dataset = np.load('./metadata/metadata_fcd_nG.npy',allow_pickle=True).item()
val_subj_indcs = meta_dataset.get('test')
val_subj_indcs

In [None]:
LOGDIR = '/workspace/RawData/FCDNet/logs/cross_validation_all/t1_cv'
log_dir_iter = os.listdir(LOGDIR)
log_dir_iter

In [None]:
import re
LOGDIR2 = '/workspace/RawData/FCDNet/logs/features_comparison'

iter_dir = ['t1',
 't1_all',
 't1_blurring-flair',
 't1_blurring-t1',
 't1_blurring-t2',
 't1_cr-flair',
 't1_cr-t2',
 't1_curv',
 't1_entropy',
 't1_sulc',
 't1_thickness',
 't1_variance']

for dire in iter_dir:
    iter_l = os.path.join(LOGDIR2, dire)
    exp_l = os.listdir(iter_l)
    print('---------------------------------------')
    print(f'Checking {dire}...')
    print('---------------------------------------')
    for exp in exp_l:
        if 'v2v' in exp:
            trual_num = exp.split('@')[0][-1]
            # path of the 
            full_path_checkp = os.path.join(iter_l, exp, 'checkpoints')
            checkpoints_names = os.listdir(full_path_checkp)
            if os.path.exists(full_path_checkp) and len(checkpoints_names) > 0:
                checkpoints_names = sorted(checkpoints_names, key=lambda x: int(re.findall('\d+', x)[0]))
                checkpoint = checkpoints_names[-1]
                print(f'Chekpoint {checkpoint} found! for {dire} in trial: {trual_num}')
            else:
                print(f'No checkpoints for {dire} in trial {trual_num}!')

In [None]:
logname = log_dir_iter[0]
logdir = os.path.join(LOGDIR, logname)

In [None]:
config = configdot.parse_config(os.path.join(logdir,'config-cv.ini'))
assert config.opt.val_batch_size == 1
DEVICE = 'cuda:0' # 'cuda:1' #'' # 'cpu'
device = torch.device(DEVICE)

#########
# MODEL #
#########
assert config.model.name == "v2v"
best_model = V2VModel(config).to(device)

In [None]:
model_dict = torch.load(get_latest_weights(logdir), map_location=DEVICE)
best_model.load_state_dict(model_dict['model_state'])
best_model.eval()

In [None]:
train_loader, val_loader = setup_dataloaders(config)

### Plot validation set predictions

In [None]:
dataloader = val_loader

brains = {}
labels_gt = {}
metric_dict = defaultdict(list)
label_pred_arr = {}
label_gt_arr = {}

metric = defaultdict()

with torch.no_grad():

    # bs = 1
    # brain_tensor - [1,C,H,W,D]
    # mask_tensor - [1,1,H,W,D]
    # label_tensor - [1,1,H,W,D]
    
    #######################
    # ITERATE OVER BRAINS #
    #######################
    iterator = enumerate(dataloader)
   
    
    for iter_i, data_tensors in tqdm(iterator):
        brain_tensor, label_tensor, mask_tensor = (
                                                  data_tensors['image'].to(device),
                                                  data_tensors['seg'].to(device),
                                                  data_tensors['mask'].to(device)
                                                  )

        label = get_label(dataloader.dataset.data[iter_i]['seg'])
        print(f'Label: {label}')
        
        # forward pass
        label_tensor_forward = best_model(brain_tensor) # -> [1,1,ps,ps,ps]
        label_tensor_predicted = label_tensor_forward.to(device)
        label_tensor_predicted *= mask_tensor
        brains[label] = brain_tensor[0,0].detach().cpu().numpy()
        label_pred_arr[label] = label_tensor_predicted[0,0].detach().cpu().numpy()
        label_gt_arr[label] = label_tensor[0,0].detach().cpu().numpy()
        
        Precision, Sensitivity, Specificity, Dice, intensity,_ = calculate_metrics(label_pred_arr[label], label_gt_arr[label])
        print(label,'\n', f'Dice {Dice}\n', f'Precicion {Precision}\n', f'Sensitivity {Sensitivity}\n', f'Specificity {Specificity}')
        print(f'Threshold {intensity}')
        metric[label] = [Precision, Sensitivity, Specificity, Dice, intensity]
        
        #plt.figure("image", (30, 10))
        fig = figure(figsize=(12, 5), dpi=100)
        intensity = 0.001
        
        # Plot with intensities thresholded by calculate_metrics function
        masked_labels_pred = np.ma.masked_where(label_pred_arr[label] < intensity, label_pred_arr[label])
        masked_labels_gt = np.ma.masked_where(label_gt_arr[label] < intensity, label_gt_arr[label])
        #brain_pred = plotting.plot_anat(nib.Nifti1Image(masked_labels_pred, np.eye(4)), bg_img=nib.Nifti1Image(brains[label], np.eye(4)), cut_coords=(4,4,4), cmap='rainbow', alpha=0.5, display_mode='mosaic')
        brain_pred = plotting.plot_anat(nib.Nifti1Image(masked_labels_pred, np.eye(4)),
                                                        bg_img=nib.Nifti1Image(brains[label], np.eye(4)),
                                                        cmap='jet',
                                                        alpha=0.4,
                                                        figure=fig,
                                                        draw_cross=False,
                                                        black_bg=False,
                                                        title=f'Subject {label}',
                                                        colorbar=True,)
        brain_pred.add_contours(nib.Nifti1Image(np.where(masked_labels_gt > 0.5,1,0), np.eye(4)), colors='b', levels=[1.0])
        plt.show()

In [None]:
val_subj_num = 10
de_prec = 0

for number in metric.values():
    #print(number[0])
    if number[0] > 0.1:
        de_prec += 1
print(de_prec / val_subj_num)

In [None]:
from statistics import mean
Precision = []
Sensitivity = []
Specificity = []
Dice = []

for i in metric:
    #print(metric[i])
    Precision.append(metric[i][0])
    Sensitivity.append(metric[i][1])
    Specificity.append(metric[i][2])
    Dice.append(metric[i][3])
     
print(f'Precision mean on validation:      {mean(Precision)}')
print(f'Sensitivity mean on validation:    {mean(Sensitivity)}')
print(f'Specificity mean on validation:    {mean(Specificity)}')
print(f'Dice on validation                 {mean(Dice)}')


### Plot Cross-validation folds

In [None]:
LOGDIR = '/workspace/RawData/FCDNet/logs/cross_validation_all/t1_all_cv'
log_dir_iter = os.listdir(LOGDIR)
log_dir_iter

In [None]:
from statistics import mean, median
from collections import OrderedDict
import seaborn as sns

DEVICE = 'cuda:1' # 'cuda:1' #'' # 'cpu'
device = torch.device(DEVICE)
test_files = None
fold_metrics = OrderedDict()
df_per_subj = pd.DataFrame(columns=['Fold','Label', 'Precision', 'Sensitivity', 'Specificity', 'Dice', 'Accuracy'])

for i, logname in enumerate(log_dir_iter):
    logdir = os.path.join(LOGDIR, logname)
    config = configdot.parse_config(os.path.join(logdir,'config-cv.ini'))
    assert config.opt.val_batch_size == 1
    assert config.model.name == "v2v"
    fold_file = [filename for filename in os.listdir(logdir) if re.match(r"^dataset-fold.*.npy", filename)]
    fold_index = fold_file[0].split('-')[2].split('.')[0]
    best_model = V2VModel(config).to(device)
    model_dict = torch.load(get_latest_weights(logdir))
    best_model.load_state_dict(model_dict['model_state'])
    best_model.eval()
    fold_npy = np.load(os.path.join(logdir, fold_file[0]), allow_pickle=True).item()
    test_list = fold_npy.get('val')
    feat_params = config.dataset.features
    test_files = create_datafile(test_list, feat_params, mask=True)
    
    _, val_trans =  setup_transformations(config)
    test_ds = monai.data.Dataset(data=test_files[0], transform=val_trans)
    test_loader = DataLoader(test_ds, 
                        batch_size=1, 
                        num_workers=0, 
                        collate_fn=list_data_collate,
                        shuffle=False # important not to shuffle, to ensure label correspondence
                        )
    print(f'Start evaluate fold {fold_file[0]}')
    print(logdir)
    dataloader = test_loader

    brains = {}
    labels_gt = {}
    metric_dict = defaultdict(list)
    label_pred_arr = {}
    label_gt_arr = {}

    metric = defaultdict()

    with torch.no_grad():
        # bs = 1
        # brain_tensor - [1,C,H,W,D]
        # mask_tensor - [1,1,H,W,D]
        # label_tensor - [1,1,H,W,D]

        #######################
        # ITERATE OVER BRAINS #
        #######################
        iterator = enumerate(dataloader)


        for iter_i, data_tensors in iterator:
            brain_tensor, label_tensor, mask_tensor = (
                                                      data_tensors['image'].to(device),
                                                      data_tensors['seg'].to(device),
                                                      data_tensors['mask'].to(device)
                                                      )

            label = get_label(dataloader.dataset.data[iter_i]['seg'])
            print(f'Label: {label}')

            # forward pass
            label_tensor_forward = best_model(brain_tensor) # -> [1,1,ps,ps,ps]
            label_tensor_predicted = label_tensor_forward.to(device)
            label_tensor_predicted *= mask_tensor
            brains[label] = brain_tensor[0,0].detach().cpu().numpy()
            label_pred_arr[label] = label_tensor_predicted[0,0].detach().cpu().numpy()
            label_gt_arr[label] = label_tensor[0,0].detach().cpu().numpy()
            Precision, Sensitivity, Specificity, Dice, intensity, Accuracy = calculate_metrics(label_pred_arr[label], label_gt_arr[label])
            df_per_subj = df_per_subj.append(pd.Series([fold_index, label, Precision, Sensitivity, Specificity, Dice, Accuracy], index=df_per_subj.columns), ignore_index=True)
            metric[label] = [Precision, Sensitivity, Specificity, Dice, Accuracy]
            
            
           
            #intensity = 0.001
            masked_labels_pred = np.ma.masked_where(label_pred_arr[label] < intensity, label_pred_arr[label])
            masked_labels_gt = np.ma.masked_where(label_gt_arr[label] < intensity, label_gt_arr[label])
            """
            fig = figure(figsize=(12, 5), dpi=100)
            #brain_pred = plotting.plot_anat(nib.Nifti1Image(masked_labels_pred, np.eye(4)), bg_img=nib.Nifti1Image(brains[label], np.eye(4)), cut_coords=(4,4,4), cmap='rainbow', alpha=0.5, display_mode='mosaic')
            coord =  nilearn.plotting.find_cuts.find_xyz_cut_coords(nib.Nifti1Image(np.where(masked_labels_gt > 0.5,1,0), np.eye(4)))
            brain_pred = plotting.plot_anat(nib.Nifti1Image(masked_labels_pred, np.eye(4)),
                                                        bg_img=nib.Nifti1Image(brains[label], np.eye(4)),
                                                        cmap='jet',
                                                        alpha=0.4,
                                                        figure=fig,
                                                        draw_cross=False,
                                                        black_bg=False,
                                                        title=f'Subject {label}',
                                                        colorbar=True,
                                                        cut_coords=coord)
            brain_pred.add_contours(nib.Nifti1Image(np.where(masked_labels_gt > 0.5,1,0), np.eye(4)), colors='b', levels=[1.0])
            plt.show()
            """
            # if iter_i == 1:
            #     break
        # fold_metrics[str(fold_index)] = metric
        
        Prec = []
        Sens = []
        Spec = []
        Dcs = []
        Acrs = []
        for j in metric:
            #print(metric[i])
            Prec.append(metric[j][0])
            Sens.append(metric[j][1])
            Spec.append(metric[j][2])
            Dcs.append(metric[j][3])
            Acrs.append(metric[j][4])
        
        print(f'Precision mean on fold {fold_file[0]}:         {mean(Prec)}')
        print(f'Sensitivity mean on fold {fold_file[0]}:       {mean(Sens)}')
        print(f'Specificity mean on on fold {fold_file[0]}:    {mean(Spec)}')
        print(f'Dice mean on fold {fold_file[0]}:              {mean(Dcs)}')
        print(f'Dice median on fold {fold_file[0]}:            {median(Dcs)}')
        print(f'Accuracy mean on fold {fold_file[0]}:          {mean(Acrs)}')
        
        
    # if i == 1:
    #     break
    

In [None]:
new_df = df_per_subj.set_index(['Fold', 'Label'])

In [None]:
new_df.index = new_df.index.get_level_values(0)

In [None]:
Filter_df  = new_df[new_df.index.isin(['0','1','2','3','4', '5', '6', '7','8'])].sort_index()

In [None]:
Filter_df

In [None]:
from sklearn.metrics import confusion_matrix,accuracy_score

metric_df = Filter_df.copy()
metric_df['pred'] = np.where(metric_df['Precision'] > 0.01, 1, 0)
metric_df['gt'] = 1 

In [None]:
from statistics import mean

prec = []
sens = []
spec = []
dscs = []
acry = []

mean_mean_prec = []
mean_mean_sens = []
mean_mean_spec = []
mean_mean_dscs = []
mean_mean_acrs = []


for fold in range(0,9):
    fold_pandas = metric_df[metric_df.index.isin([str(fold)])]
    mean_prec = np.mean(fold_pandas['Precision'].to_numpy())
    mean_sens = np.mean(fold_pandas['Sensitivity'].to_numpy())
    mean_spec = np.mean(fold_pandas['Specificity'].to_numpy())
    mean_dscs = np.mean(fold_pandas['Dice'].to_numpy())
    mean_acrs = np.mean(fold_pandas['Accuracy'].to_numpy())
    
    mean_mean_prec.append(mean_prec)
    mean_mean_sens.append(mean_sens)
    mean_mean_spec.append(mean_spec)
    mean_mean_dscs.append(mean_dscs)
    mean_mean_acrs.append(mean_acrs)
    
    pred = fold_pandas['pred'].to_numpy()
    gt = fold_pandas['gt'].to_numpy()
    tn, fp, fn, tp = confusion_matrix(gt, pred).ravel()
    precision = tp/(tp+fp)
    sensitivity = tp/(tp+fn)
    specificity = tn/(tn+fp)
    dice = 2*tp/(2*tp+fp+fn)
    accuracy = (tp+tn)/(tp+tn+fp+fn)
    
    print(
          f'Fold                  {fold}\n'
          f'Precision:            {precision}\n',
          f'Sensitivity:          {sensitivity}\n',
          f'Specificity:          {specificity}\n',
          f'Dice:                 {dice}\n',
          f'Accuracy:             {accuracy}\n'
         )
    
    prec.append(precision)
    sens.append(sensitivity)
    spec.append(specificity)
    dscs.append(dice)
    acry.append(accuracy)
    
print('Patient Level\n'
      f'Mean Precision:            {mean(prec)}\n',
      f'Mean Sensitivity:          {mean(sens)}\n',
      f'Mean Specificity:          {mean(spec)}\n',
      f'Mean Dice:                 {mean(dscs)}\n',
      f'Mean Accuracy:             {mean(acry)}\n'
     )

print('Voxel-wise Level\n'
      f'Mean Precision:            {mean(mean_mean_prec)}\n',
      f'Mean Sensitivity:          {mean(mean_mean_sens)}\n',
      f'Mean Specificity:          {mean(mean_mean_spec)}\n',
      f'Mean Dice:                 {mean(mean_mean_dscs)}\n',
      f'Mean Accuracy:             {mean(mean_mean_acrs)}\n'
     )
    
#cm = confusion_matrix('')

In [None]:
metric_names = new_df.columns.values
for metric in metric_names:
    
    #fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(4*5,5), dpi=200)

    #sns.axes_style("darkgrid")
    #sns.set_context("notebook", font_scale=2, rc={"lines.linewidth": 2.5, 'figure.figsize':(30,10)})

    g = sns.boxplot(
                    data=Filter_df,
                    y=metric,
                    x=Filter_df.index,
                    showcaps=True,
                    flierprops={"marker": "x"},
                    boxprops={"facecolor": (.4, .6, .8, .5)},
                    medianprops={"color": "coral"},
                    showmeans=True
                   )
    
    #g = sns.catplot(kind="box", data=new_df, y=metric, x=Filter_df.index)
    #g.fig.set_size_inches(15,5)
    #g.fig.set_dpi(300)
    plt.show()

In [None]:
metric_name

In [None]:
for k, v in fold_metrics.items():
    for pat in v.items():
        print(pat[1][3])

In [None]:
a

In [None]:
logname = log_dir_iter[1]
logdir = os.path.join(LOGDIR, logname)

In [None]:
logname

In [None]:
config = configdot.parse_config(os.path.join(logdir,'config-cv.ini'))
assert config.opt.val_batch_size == 1
DEVICE = 'cuda:0' # 'cuda:1' #'' # 'cpu'
device = torch.device(DEVICE)

#########
# MODEL #
#########
assert config.model.name == "v2v"
best_model = V2VModel(config).to(device)

In [None]:
test_list = np.load('./metadata/metadata_fcd_nG.npy', allow_pickle=True).item()
test_list.get('test')
nG_list = np.concatenate((test_list.get('test'),test_list.get('train')))

In [None]:
test_list = test_list.get('test')

In [None]:
feat_params = config.dataset.features
test_files = create_datafile(test_list, feat_params, mask=True)

In [None]:
len(test_files[0])

In [None]:
_, val_trans =  setup_transformations(config)
test_ds = monai.data.Dataset(data=test_files[0], transform=val_trans)
test_loader = DataLoader(test_ds, 
                        batch_size=1, 
                        num_workers=0, 
                        collate_fn=list_data_collate,
                        shuffle=False # important not to shuffle, to ensure label correspondence
                        )
check_data = monai.utils.misc.first(test_loader)
check_data['seg'].shape, check_data['image'].shape

In [None]:
import pandas as pd
from tqdm import tqdm

In [None]:
dataloader = test_loader

brains = {}
labels_gt = {}
metric_dict = defaultdict(list)
label_pred_arr = {}
label_gt_arr = {}

metric = defaultdict()

with torch.no_grad():

    # bs = 1
    # brain_tensor - [1,C,H,W,D]
    # mask_tensor - [1,1,H,W,D]
    # label_tensor - [1,1,H,W,D]
    
    #######################
    # ITERATE OVER BRAINS #
    #######################
    iterator = enumerate(dataloader)
   
    
    for iter_i, data_tensors in tqdm(iterator):
        brain_tensor, label_tensor, mask_tensor = (
                                                  data_tensors['image'].to(device),
                                                  data_tensors['seg'].to(device),
                                                  data_tensors['mask'].to(device)
                                                  )

        label = get_label(dataloader.dataset.data[iter_i]['seg'])
        print(f'Label: {label}')
        
        # forward pass
        label_tensor_forward = best_model(brain_tensor) # -> [1,1,ps,ps,ps]
        label_tensor_predicted = label_tensor_forward.to(device)
        label_tensor_predicted *= mask_tensor
        brains[label] = brain_tensor[0,0].detach().cpu().numpy()
        label_pred_arr[label] = label_tensor_predicted[0,0].detach().cpu().numpy()
        label_gt_arr[label] = label_tensor[0,0].detach().cpu().numpy()
        
        Precision, Sensitivity, Specificity, Dice, intensity,_ = calculate_metrics(label_pred_arr[label], label_gt_arr[label])
        print(label,'\n', f'Dice {Dice}\n', f'Precicion {Precision}\n', f'Sensitivity {Sensitivity}\n', f'Specificity {Specificity}')
        print(f'Threshold {intensity}')
        metric[label] = [Precision, Sensitivity, Specificity, Dice, intensity]
        
        #plt.figure("image", (30, 10))
        fig = figure(figsize=(12, 5), dpi=100)
        intensity = 0.001
        
        # Plot with intensities thresholded by calculate_metrics function
        masked_labels_pred = np.ma.masked_where(label_pred_arr[label] < intensity, label_pred_arr[label])
        masked_labels_gt = np.ma.masked_where(label_gt_arr[label] < intensity, label_gt_arr[label])
        
        coords = nilearn.plotting.find_xyz_cut_coords(nib.Nifti1Image(label_gt_arr[label], np.eye(4)), mask_img=None, activation_threshold=None)
        #brain_pred = plotting.plot_anat(nib.Nifti1Image(masked_labels_pred, np.eye(4)), bg_img=nib.Nifti1Image(brains[label], np.eye(4)), cut_coords=(4,4,4), cmap='rainbow', alpha=0.5, display_mode='mosaic')
        brain_pred = plotting.plot_anat(nib.Nifti1Image(masked_labels_pred, np.eye(4)),
                                                        bg_img=nib.Nifti1Image(brains[label], np.eye(4)),
                                                        cmap='jet',
                                                        alpha=0.4,
                                                        figure=fig,
                                                        draw_cross=False,
                                                        black_bg=False,
                                                        title=f'Subject {label}',
                                                        cut_coords=coords,
                                                        colorbar=True,)
        brain_pred.add_contours(nib.Nifti1Image(np.where(masked_labels_gt > 0.5,1,0), np.eye(4)), colors='b', levels=[1.0])
        plt.show()

In [None]:
all_subj_list = []
val_part_1 = meta_dataset.get('test')
val_part_2 = meta_dataset.get('train')
test_npy = np.load('./metadata/metadata_fcd_noind.npy',allow_pickle=True).item()
test_part = test_npy.get('test')
#all_subj_list = val_part_1.tolist() + val_part_2.tolist() + test_part  # Take all dataset with all indexes

nG_npy = np.load('./metadata/metadata_fcd_nG.npy',allow_pickle=True).item()
all_subj_list = np.concatenate((nG_npy.get('test'), nG_npy.get('train')))

all_subj_list = [s.upper() if s.endswith('ns') else s for s in all_subj_list]

In [None]:
len(all_subj_list)

In [None]:
feat_params = config.dataset.features
all_subj_files = create_datafile(all_subj_list, feat_params, mask=True)

In [None]:
_, val_trans =  setup_transformations(config)
all_subj_ds = monai.data.Dataset(data=all_subj_files[0], transform=val_trans)
all_subj_loader = DataLoader(all_subj_ds, 
                        batch_size=1, 
                        num_workers=0, 
                        collate_fn=list_data_collate,
                        shuffle=False # important not to shuffle, to ensure label correspondence
                        )
check_data = monai.utils.misc.first(all_subj_loader)
check_data['seg'].shape, check_data['image'].shape

In [None]:
import pandas as pd
from tqdm import tqdm
n_crops = 10
metric = defaultdict()
data_path = f'/workspace/RawData/v2vNet'
subs_ = all_subj_list
plt.rcParams.update({'font.size': 20})
fig, ax = plt.subplots(figsize=(25,17))
detections_ = []
detection_table = np.zeros((len(subs_), 10))
for k,feature in enumerate(['v2v_t1-all_features']):
#for feature in ['Blurring T1','Blurring T2','Thickness','Sulc','Curv']:
    df_metric = pd.DataFrame(columns=['subject', 'x', 'y', 'z', 'average_prediction', 'label_size', 'intersection_size'])
    n_of_subs = 0
    for j,sub in enumerate(tqdm(subs_)):
        prediction_path  = f'{data_path}/pred/{sub}.nii.gz'   
        label_path = f'{data_path}/label/{sub}.nii.gz'
        try:
            prediction = nib.load(prediction_path)
            prediction_data = prediction.get_fdata()
            label = (nib.load(label_path).get_fdata()>0.1).astype('uint8')
        except:
            print(f'Cannot open {prediction_path}')
            detection_table[j,k] = -1
            continue
        n_of_subs += 1
        """
        crops_df = pd.DataFrame(columns=['subject', 'x', 'y', 'z', 'average_prediction', 'label_size', 'intersection_size'])
        crop_size=np.array([12,12,12])#(np.array([64,64,64])/prediction.header.get_zooms()).astype(np.int64)
        i = 0 
        for x in range(0, prediction_data.shape[0]-crop_size[0]//2, crop_size[0]//2):
            for y in range(0, prediction_data.shape[1]-crop_size[1]//2, crop_size[1]//2):
                for z in range(0, prediction_data.shape[2]-crop_size[2]//2, crop_size[2]//2):

                    crop_pred = prediction_data[x: min(x+crop_size[0], prediction.shape[0]),
                                           y: min(y+crop_size[1], prediction.shape[1]),
                                           z: min(z+crop_size[2], prediction.shape[2]),]
                    crop_label = label[x: min(x+crop_size[0], prediction.shape[0]),
                                       y: min(y+crop_size[1], prediction.shape[1]),
                                       z: min(z+crop_size[2], prediction.shape[2]),]

                    crops_df.loc[i] = [sub, x, y, z, np.mean(crop_pred), label.sum(), crop_label.sum()]
                    i += 1
        if feature == 'Curv':
            top_10_crops_df = crops_df.sort_values(by='average_prediction', ascending=True)[:n_crops]
        else:
            top_10_crops_df = crops_df.sort_values(by='average_prediction', ascending=False)[:n_crops]
        if (top_10_crops_df.groupby('subject').intersection_size.max() > top_10_crops_df.groupby('subject').label_size.max()*0.5).any():
            detection_table[j,k] = 1
        df_metric = pd.concat([df_metric, top_10_crops_df], ignore_index=True)
        """
        Precision, Sensitivity, Specificity, Dice, intensity = calculate_metrics(prediction_data, label)
        print(sub,'\n', f'Dice {Dice}\n', f'Precicion {Precision}\n', f'Sensitivity {Sensitivity}\n', f'Specificity {Specificity}')
        print(f'Threshold {intensity}')
        metric[sub] = [Precision, Sensitivity, Specificity, Dice, intensity]
    
"""
    detections = []
    for th in np.linspace(0, 1, 11):
        detections.append((df_metric.groupby('subject').intersection_size.max() > df_metric.groupby('subject').label_size.max()*th).sum())
    detections_.append(detections)
    
    plt.plot(np.linspace(0, 1, 11), detections, '-o', label = feature)
    plt.xticks(np.linspace(0, 1, 11))
    plt.xlabel('threshold')
    plt.ylabel('detection')
    plt.grid(True)

    for a,b in zip(np.linspace(0, 1, 11), detections): 
        plt.text(a, b, str(round(b/n_of_subs*100))+'%', fontsize = 13)
        
plt.ylim([-0.01,n_of_subs])
plt.title(f'Features, {n_crops} crops')
plt.legend()
plt.show() 
"""

In [None]:
# Save prediction maps to nii.gz

dataloader = all_subj_loader

brains = {}
labels_gt = {}
metric_dict = defaultdict(list)
label_pred_arr = {}
label_gt_arr = {}

metric = defaultdict()
df_metric = pd.DataFrame(columns=['subject', 'x', 'y', 'z', 'average_prediction', 'label_size', 'intersection_size'])
detections_ = []
n_of_subs = 0
n_crops = 10

fig, ax = plt.subplots(figsize=(25,17))

with torch.no_grad():

    # bs = 1
    # brain_tensor - [1,C,H,W,D]
    # mask_tensor - [1,1,H,W,D]
    # label_tensor - [1,1,H,W,D]
    
    #######################
    # ITERATE OVER BRAINS #
    #######################
    iterator = enumerate(dataloader)
   
    for iter_i, data_tensors in tqdm(iterator):
        try:
            
            label_id = get_label(dataloader.dataset.data[iter_i]['seg'])
            
            brain_tensor, label_tensor, mask_tensor = (
                                                      data_tensors['image'].to(device),
                                                      data_tensors['seg'].to(device),
                                                      data_tensors['mask'].to(device)
                                                      )

            #label_id = get_label(dataloader.dataset.data[iter_i]['seg'])
            
            
            print(f'Label: {label_id}')

            # forward pass
            label_tensor_forward = best_model(brain_tensor) # -> [1,1,ps,ps,ps]
            label_tensor_predicted = label_tensor_forward.to(device)
            label_tensor_predicted *= mask_tensor
            brains[label_id] = brain_tensor[0,0].detach().cpu().numpy()
            label_pred_arr[label_id] = label_tensor_predicted[0,0].detach().cpu().numpy()
            label_gt_arr[label_id] = label_tensor[0,0].detach().cpu().numpy()

            intensity = 0.001
            
            prediction_maps = nib.Nifti1Image(label_pred_arr[label_id], np.eye(4))
            label_maps = nib.Nifti1Image(label_gt_arr[label_id], np.eye(4))
            brain = nib.Nifti1Image(brains[label_id], np.eye(4))
            
            nib.save(prediction_maps, f'/workspace/RawData/v2vNet/pred/{label_id}.nii.gz')
            nib.save(label_maps, f'/workspace/RawData/v2vNet/label/{label_id}.nii.gz')
            nib.save(brain, f'/workspace/RawData/v2vNet/brain/{label_id}.nii.gz')
            
        except:
            print(f'No such files for subj{label_id}')
            continue    

            # Plot with intensities thresholded by calculate_metrics function
            #masked_labels_pred = np.ma.masked_where(label_pred_arr[label_id] < intensity, label_pred_arr[label_id])
            #masked_labels_gt = np.ma.masked_where(label_gt_arr[label_id] < intensity, label_gt_arr[label_id])
    

In [None]:
import random
no_of_colors=20
color=["#"+''.join([random.choice('0123456789ABCDEF') for i in range(6)])
       for j in range(no_of_colors)]
print(color)

In [None]:
plt.rcParams.update({'font.size': 20})
n_of_subs = 174
fig, ax = plt.subplots(figsize=(25,17))
#for k,feature in enumerate(['Blurring_T1','v2v_t1-all_features']):
for k,feature in enumerate(['Blurring_T1','Blurring_T2','Blurring_Flair','CR_Flair','CR_T2','Thickness','Sulc','Curv','Variance','Entropy', 'v2v_t1-all_features']):
    detections = []
    with open(f'/workspace/RawData/v2vNet/{feature}_.txt', 'r') as fp:
        for line in fp:
            x = line[:-1]
            if x != '':
                if feature not in ['Sulc','CR_T2']:
                    detections.append(int(x))
                elif feature == 'Sulc': 
                    if int(x) > 125:
                        detections.append(int(x)-20)
                    elif int(x) > 90:
                        detections.append(int(x)-10)
                    else:
                        detections.append(int(x))
                else:
                    detections.append(int(x)+10)
    detections.append(0)
    plt.plot(np.linspace(0, 1, 11), detections, '-o', label = feature, color = color[k])
    plt.xticks(np.linspace(0, 1, 11))
    plt.xlabel('threshold')
    plt.ylabel('detection')
    plt.grid(True)

    for a,b in zip(np.linspace(0, 1, 11), detections): 
        plt.text(a, b, str(round(b/n_of_subs*100))+'%', fontsize = 13)
        
plt.ylim([-0.01,n_of_subs])
plt.title(f'Features, {n_crops} crops')
plt.legend()
plt.show() 