In [None]:
%load_ext autoreload

import re, time, os, shutil, json, math
import numpy as np
import configdot
from tqdm import tqdm
import monai
from monai.data import DataLoader, Dataset, list_data_collate, decollate_batch
import nibabel as nib
import nilearn
from nilearn import plotting
from matplotlib.pyplot import figure
import matplotlib.pyplot as plt
plt.rcParams['font.size'] = 20

from collections import defaultdict
from IPython.core.debugger import set_trace
import pandas as pd
from nibabel.freesurfer.io import read_morph_data

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
import torch.optim as optim
from models.v2v import V2VModel

from losses import *
from dataset import setup_dataloaders, create_datafile, setup_datafiles, setup_transformations
from utils import save, get_capacity, calc_gradient_norm, get_label, get_latest_weights


from metrics import calculate_metrics

import warnings
warnings.filterwarnings("ignore")
%autoreload 2

In [None]:
# Location of the subject table: a Google Sheets document exported as CSV.
sheet_id = "1MDleLmQ0Nlcg62x95e3xnkc5_j_i4IK_KQEHccDosG8"
sheet_name = "clean"
# The gviz endpoint serves the sheet named `sheet_name` in CSV form.
url = (
    "https://docs.google.com/spreadsheets/d/"
    f"{sheet_id}"
    "/gviz/tq?tqx=out:csv"
    f"&sheet={sheet_name}"
)

In [None]:
# Notebook-wide settings.
pd.options.display.max_rows = 100  # let pandas render up to 100 rows
SEED = 42
# Toggles the patient-ID filter applied when the subject table is loaded.
USE_nG = True

In [None]:
# Load the subject table and keep the rows flagged as good.
#
# Produces:
#   df      - the (optionally filtered) table with the four columns below
#   df_good - rows of `df` whose `is_good` column equals the string "1"

# Placeholder; the actual list is built from df_good in the following cell.
subj_list = []

# Every column is read as text, so `is_good` is compared against "1" (a string).
df = pd.read_csv(
    url,
    header=0,
    usecols=['patient', 'is_good', 'localization', 'comments'],
    index_col=None,
    dtype={'patient': str, 'is_good': str, 'localization': str, 'comments': str},
)

if USE_nG:
    # Keep only patients whose ID does NOT start with 'G' or 'n'.
    # (To keep only those subjects instead, use the pattern r'^[Gn]'.)
    # - Inside a character class '|' is a literal pipe, not alternation, so it
    #   is left out of the class.
    # - na=False treats rows with a missing patient ID as non-matching; without
    #   it the mask contains NaN and boolean indexing raises.
    # NOTE(review): with USE_nG = True this branch *drops* the G*/n* subjects,
    # which reads as inverted relative to the flag name — confirm the intended cohort.
    df = df[df['patient'].str.contains(r'^[^Gn]', na=False)]

df_good = df.query('is_good == "1"')

In [None]:
# Patient IDs of all rows that passed the quality filter, as a plain Python list.
subj_list = df_good['patient'].tolist()

In [None]:
# Display the selected patient IDs for a visual check.
subj_list

In [None]:
# Load the stored dataset metadata and show the entry saved under 'test'.
# The .npy file holds a pickled Python object (hence allow_pickle=True and
# .item() to unwrap it).
# NOTE(review): unpickling executes arbitrary code — only load trusted files.
metadata_path = './metadata/metadata_fcd_nG.npy'
meta_dataset = np.load(metadata_path, allow_pickle=True).item()
# .get returns None when the 'test' key is absent.
val_subj_indcs = meta_dataset.get('test')
val_subj_indcs

In [None]:
# Directory holding the runs of the 't1_all' feature configuration.
LOGDIR = '/workspace/RawData/FCDNet/logs/features_comparison/t1_all/'
# os.listdir returns entries in arbitrary order; sort them so that positional
# indexing into this list selects the same run on every machine and re-run.
log_dir_iter = sorted(os.listdir(LOGDIR))
log_dir_iter

In [None]:
# Report, for every feature configuration, the latest checkpoint of each v2v trial.
LOGDIR2 = '/workspace/RawData/FCDNet/logs/features_comparison'

# Sub-directories of LOGDIR2 to inspect, one per feature configuration.
iter_dir = ['t1',
 't1_all',
 't1_blurring-flair',
 't1_blurring-t1',
 't1_blurring-t2',
 't1_cr-flair',
 't1_cr-t2',
 't1_curv',
 't1_entropy',
 't1_sulc',
 't1_thickness',
 't1_variance']


def _checkpoint_index(filename):
    """Return the first integer found in `filename`, or -1 when it has no digits.

    Used as a sort key, so files without a number sort before numbered ones
    instead of aborting the whole scan with an IndexError.
    """
    numbers = re.findall(r'\d+', filename)
    return int(numbers[0]) if numbers else -1


for dire in iter_dir:
    iter_l = os.path.join(LOGDIR2, dire)
    # Sorted so the report order does not depend on the filesystem.
    exp_l = sorted(os.listdir(iter_l))
    print('---------------------------------------')
    print(f'Checking {dire}...')
    print('---------------------------------------')
    for exp in exp_l:
        if 'v2v' not in exp:
            continue
        # NOTE(review): only the last character before '@' is kept, so a
        # multi-digit trial number would be truncated — confirm the naming scheme.
        trual_num = exp.split('@')[0][-1]
        full_path_checkp = os.path.join(iter_l, exp, 'checkpoints')
        # Check for the directory BEFORE listing it: listing a missing
        # directory raises, which would make the "No checkpoints" branch
        # unreachable for runs that never saved anything.
        if os.path.isdir(full_path_checkp):
            checkpoints_names = os.listdir(full_path_checkp)
        else:
            checkpoints_names = []
        if checkpoints_names:
            # Latest checkpoint = highest embedded number.
            checkpoints_names = sorted(checkpoints_names, key=_checkpoint_index)
            checkpoint = checkpoints_names[-1]
            print(f'Chekpoint {checkpoint} found! for {dire} in trial: {trual_num}')
        else:
            print(f'No checkpoints for {dire} in trial {trual_num}!')

In [None]:
# Select the run to evaluate by its position in the LOGDIR listing and build its path.
# NOTE(review): the index 3 is hard-coded — confirm it points at the intended run.
logname = log_dir_iter[3]
logdir = os.path.join(LOGDIR, logname)

In [None]:
# Parse the configuration stored alongside the selected run.
config_path = os.path.join(logdir, 'config.ini')
config = configdot.parse_config(config_path)

# This notebook only supports a validation batch size of 1 and the v2v model.
assert config.opt.val_batch_size == 1
assert config.model.name == "v2v"

# Target device for the model and all tensors (alternative: 'cpu').
DEVICE = 'cuda:1'
device = torch.device(DEVICE)

# Instantiate the network from the config and move it to the target device.
best_model = V2VModel(config)
best_model = best_model.to(device)

In [None]:
# Restore the weights returned by get_latest_weights for the selected run.
# map_location remaps every stored tensor onto `device`; without it torch.load
# restores tensors to the device they were saved from, which fails or wastes
# memory when that device is absent or differs from the one used here.
model_dict = torch.load(get_latest_weights(logdir), map_location=device)
best_model.load_state_dict(model_dict['model_state'])
# Inference mode (affects layers such as dropout / batch norm, if the model has any).
best_model.eval()

In [None]:
# Build the train and validation loaders from the run's config.
train_loader, val_loader = setup_dataloaders(config)

### Plot validation set predictions

In [None]:
# Run the model on every item of the validation loader, print its metrics and
# plot the thresholded prediction with the ground-truth outline.
dataloader = val_loader

# Per-subject arrays, keyed by the subject label.
brains = {}
labels_gt = {}
metric_dict = defaultdict(list)
label_pred_arr = {}
label_gt_arr = {}

metric = []

with torch.no_grad():

    # bs = 1
    # brain_tensor - [1,C,H,W,D]
    # mask_tensor - [1,1,H,W,D]
    # label_tensor - [1,1,H,W,D]

    #######################
    # ITERATE OVER BRAINS #
    #######################
    iterator = enumerate(dataloader)

    # total= gives tqdm a real progress bar; enumerate() alone has no length.
    for iter_i, data_tensors in tqdm(iterator, total=len(dataloader)):
        brain_tensor = data_tensors['image'].to(device)
        label_tensor = data_tensors['seg'].to(device)
        mask_tensor = data_tensors['mask'].to(device)

        # The label is looked up by position in the underlying dataset.
        # NOTE(review): this assumes the loader yields items in dataset order
        # (no shuffling) — confirm for this loader.
        label = get_label(dataloader.dataset.data[iter_i]['seg'])
        print(f'Label: {label}')

        # forward pass; the output is already on `device`, so no transfer is needed
        label_tensor_predicted = best_model(brain_tensor) # -> [1,1,ps,ps,ps]
        # Multiply the prediction by the mask, in place.
        label_tensor_predicted *= mask_tensor

        # Keep the first channel of the single batch item as numpy arrays.
        brains[label] = brain_tensor[0,0].detach().cpu().numpy()
        label_pred_arr[label] = label_tensor_predicted[0,0].detach().cpu().numpy()
        label_gt_arr[label] = label_tensor[0,0].detach().cpu().numpy()

        # calculate_metrics returns four scores plus the intensity threshold it used.
        Precision, Sensitivity, Specificity, Dice, intensity = calculate_metrics(label_pred_arr[label], label_gt_arr[label])
        print(label,'\n', f'Dice {Dice}\n', f'Precicion {Precision}\n', f'Sensitivity {Sensitivity}\n', f'Specificity {Specificity}')
        print(f'Threshold {intensity}')

        fig = figure(figsize=(12, 5), dpi=100)

        # Plot with intensities thresholded by calculate_metrics function
        masked_labels_pred = np.ma.masked_where(label_pred_arr[label] < intensity, label_pred_arr[label])
        masked_labels_gt = np.ma.masked_where(label_gt_arr[label] < intensity, label_gt_arr[label])

        # Identity affine: the arrays are plotted in voxel coordinates.
        affine = np.eye(4)
        brain_pred = plotting.plot_anat(nib.Nifti1Image(masked_labels_pred, affine),
                                        bg_img=nib.Nifti1Image(brains[label], affine),
                                        cmap='jet',
                                        alpha=0.8,
                                        figure=fig,
                                        draw_cross=False,
                                        title=f'Subject {label}',
                                        colorbar=True)

        # Binary ground-truth volume for the outline. np.where(..., 1, 0) yields
        # the platform default integer (int64 on Linux), which recent nibabel
        # releases refuse to wrap in a Nifti1Image; uint8 is accepted everywhere.
        gt_binary = np.where(masked_labels_gt > 0.5, 1, 0).astype(np.uint8)
        brain_pred.add_contours(nib.Nifti1Image(gt_binary, affine), colors='b', levels=[1.0])
        plt.show()
        # Release the figure so a long loop does not accumulate open figures.
        plt.close(fig)

### Plot test set predictions

In [None]:
# Build the file list for every subject in subj_list, using the feature set named in the run's config.
# NOTE(review): subj_list comes from the spreadsheet filter, not from val_subj_indcs —
# confirm this is the intended "test" cohort.
feat_params = config.dataset.features
test_files = create_datafile(subj_list, feat_params, mask=True)

In [None]:
# Only the second value returned by setup_transformations is used here.
_, val_trans = setup_transformations(config)

# Wrap the first element of test_files in a MONAI dataset with those transforms.
test_ds = monai.data.Dataset(data=test_files[0], transform=val_trans)

loader_options = dict(
    batch_size=1,
    num_workers=0,
    collate_fn=list_data_collate,
    shuffle=False,  # important not to shuffle, to ensure label correspondence
)
test_loader = DataLoader(test_ds, **loader_options)

# Sanity check: fetch the first batch and show the label / image shapes.
check_data = monai.utils.misc.first(test_loader)
(check_data['seg'].shape, check_data['image'].shape)

In [None]:
# Run the model on every item of the test loader, print its metrics and plot
# the thresholded prediction with the ground-truth outline.
dataloader = test_loader

# Per-subject arrays, keyed by the subject label.
brains = {}
labels_gt = {}
metric_dict = defaultdict(list)
label_pred_arr = {}
label_gt_arr = {}

metric = []

with torch.no_grad():

    # bs = 1
    # brain_tensor - [1,C,H,W,D]
    # mask_tensor - [1,1,H,W,D]
    # label_tensor - [1,1,H,W,D]

    #######################
    # ITERATE OVER BRAINS #
    #######################
    iterator = enumerate(dataloader)

    # total= gives tqdm a real progress bar; enumerate() alone has no length.
    for iter_i, data_tensors in tqdm(iterator, total=len(dataloader)):
        brain_tensor = data_tensors['image'].to(device)
        label_tensor = data_tensors['seg'].to(device)
        mask_tensor = data_tensors['mask'].to(device)

        # The label is looked up by position in the underlying dataset, which
        # is why the loader must not shuffle.
        label = get_label(dataloader.dataset.data[iter_i]['seg'])
        print(f'Label: {label}')

        # forward pass; the output is already on `device`, so no transfer is needed
        label_tensor_predicted = best_model(brain_tensor) # -> [1,1,ps,ps,ps]
        # Multiply the prediction by the mask, in place.
        label_tensor_predicted *= mask_tensor

        # Keep the first channel of the single batch item as numpy arrays.
        brains[label] = brain_tensor[0,0].detach().cpu().numpy()
        label_pred_arr[label] = label_tensor_predicted[0,0].detach().cpu().numpy()
        label_gt_arr[label] = label_tensor[0,0].detach().cpu().numpy()

        # calculate_metrics returns four scores plus the intensity threshold it used.
        Precision, Sensitivity, Specificity, Dice, intensity = calculate_metrics(label_pred_arr[label], label_gt_arr[label])
        print(label, f'Dice {Dice}', f'Precicion {Precision}', f'Sensitivity {Sensitivity}', f'Specificity {Specificity}')
        print(f'Threshold {intensity}')

        fig = figure(figsize=(12, 5), dpi=100)

        # Plot with intensities thresholded by calculate_metrics function
        masked_labels_pred = np.ma.masked_where(label_pred_arr[label] < intensity, label_pred_arr[label])
        masked_labels_gt = np.ma.masked_where(label_gt_arr[label] < intensity, label_gt_arr[label])

        # Identity affine: the arrays are plotted in voxel coordinates.
        affine = np.eye(4)
        brain_pred = plotting.plot_anat(nib.Nifti1Image(masked_labels_pred, affine),
                                        bg_img=nib.Nifti1Image(brains[label], affine),
                                        cmap='jet',
                                        alpha=0.8,
                                        figure=fig,
                                        draw_cross=False,
                                        title=f'Subject {label}',
                                        colorbar=True)

        # Binary ground-truth volume for the outline. np.where(..., 1, 0) yields
        # the platform default integer (int64 on Linux), which recent nibabel
        # releases refuse to wrap in a Nifti1Image; uint8 is accepted everywhere.
        gt_binary = np.where(masked_labels_gt > 0.5, 1, 0).astype(np.uint8)
        brain_pred.add_contours(nib.Nifti1Image(gt_binary, affine), colors='b', levels=[1.0])
        plt.show()
        # Release the figure so a long loop does not accumulate open figures.
        plt.close(fig)