In [None]:
%load_ext autoreload

import re, time, os, shutil, json, math
import numpy as np
import configdot
from tqdm import tqdm
import monai
from monai.data import DataLoader, Dataset, list_data_collate, decollate_batch

from collections import defaultdict
from IPython.core.debugger import set_trace

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
import torch.optim as optim
from models.v2v import V2VModel

from losses import *
from dataset import setup_dataloaders, create_datafile, setup_datafiles, setup_transformations
from utils import save, get_capacity, calc_gradient_norm, get_label, get_latest_weights
import matplotlib.pyplot as plt
plt.rcParams['font.size'] = 20

%autoreload 2

In [None]:
LOGDIR = '/workspace/RawData/FCDNet/logs/features_comparison/t1_all/'
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort the
# run directories so that an index chosen below always refers to the same run.
log_dir_iter = sorted(os.listdir(LOGDIR))
log_dir_iter

In [None]:
# Position of the run to analyse within the run listing displayed above.
# (The original `log_dir_iter[]` was a SyntaxError.) Change to pick another run.
RUN_INDEX = 0
logname = log_dir_iter[RUN_INDEX]
logdir = os.path.join(LOGDIR, logname)

In [None]:
# Load the best validation predictions saved under the run directory, one file
# per subject label. Later cells index `best_val_preds[label]`, so this cell must
# run (it was previously commented out, breaking Restart & Run All).
val_preds_path = os.path.join(logdir, 'best_val_preds')
best_val_preds = {}
for label in os.listdir(val_preds_path):
    val_preds_label_path = os.path.join(val_preds_path, label)
    # map_location='cpu' lets predictions saved from a GPU run load on a CPU-only machine;
    # [0, 0] drops the leading batch/channel axes (assumes a [1,1,...] layout — TODO confirm).
    best_val_preds[label] = torch.load(val_preds_label_path, map_location='cpu')[0, 0]

### Load config and model, set up dataloaders

In [None]:
# Parse the experiment configuration saved alongside this run's logs.
config_path = os.path.join(logdir, 'config.ini')
config = configdot.parse_config(config_path)

In [None]:
# Display the feature configuration this run was set up with.
dataset_features = config.dataset.features
dataset_features

In [None]:
# The samplers only work with one brain per validation batch; fail fast otherwise.
assert config.opt.val_batch_size == 1

# Inference device. Alternatives used before: 'cuda:1'.
DEVICE = 'cpu'
device = torch.device(DEVICE)

# ----- model -----
assert config.model.name == "v2v"
# NOTE(review): this freshly initialised `model` is not referenced by the cells
# shown here; evaluation below uses `best_model` with loaded weights.
model = V2VModel(config)
model = model.to(device)

# ----- datasets -----
train_loader, val_loader = setup_dataloaders(config)

print(
    'val dataloder len: ', len(val_loader),
    'train dataloder len: ', len(train_loader),
)

# Forward pass

In [None]:
# When True, the evaluation loops below write each prediction to <logdir>/predictions.
SAVE_PREDS = False

if SAVE_PREDS:
    predictions_path = os.path.join(logdir, 'predictions')
    # exist_ok avoids the check-then-create race and makes the cell idempotent.
    os.makedirs(predictions_path, exist_ok=True)

In [None]:
# Sanity-check the shapes of all loaded validation predictions. Iterating over the
# dict avoids the KeyError a hard-coded subject label raises when it is absent.
{label: tuple(pred.shape) for label, pred in best_val_preds.items()}

In [None]:
# Score the stored best validation predictions against the ground truth and plot
# slices along the last axis for each brain (GT in red, prediction in green).
dataloader = val_loader

brains = {}
labels_gt = {}
metric_dict = defaultdict(list)
labels_pred = {}
labels_ref = {}

N_SLICES = 12     # number of slices plotted per brain
VOXEL_EPS = 1e-4  # voxel values above this are treated as foreground

with torch.no_grad():

    # bs = 1
    # brain_tensor - [1,C,H,W,D]
    # label_tensor - [1,1,H,W,D]

    # enumerate() hides the loader length from tqdm, so pass it explicitly.
    for iter_i, data_tensors in tqdm(enumerate(dataloader), total=len(dataloader)):
        brain_tensor = data_tensors['image'].to(device)
        label_tensor = data_tensors['seg'].to(device)

        # Valid only because batch size is 1 and the loader is not shuffled,
        # so the batch index equals the dataset index.
        label = get_label(dataloader.dataset.data[iter_i]['seg'])
        # as_tensor accepts tensors or arrays without torch.tensor's copy-construct warning.
        label_tensor_predicted = torch.as_tensor(best_val_preds[label][None, None, ...]).to(device)
        print(f'Label: {label}')

        cov = coverage(label_tensor_predicted, label_tensor).item()
        fp = false_positive(label_tensor_predicted, label_tensor).item()
        fn = false_negative(label_tensor_predicted, label_tensor).item()
        dice = dice_score(label_tensor_predicted.detach() > VOXEL_EPS, label_tensor.detach()).item()

        metric_dict['coverage'].append(cov)  # a.k.a recall
        metric_dict['false_positive'].append(fp)
        metric_dict['false_negative'].append(fn)
        metric_dict['dice_score'].append(dice)

        print(label, dice)
        brains[label] = brain_tensor[0, 0].detach().cpu().numpy()
        labels_gt[label] = label_tensor[0, 0].detach().cpu().numpy()
        label_tensor_predicted = label_tensor_predicted[0, 0].detach().cpu()
        labels_pred[label] = label_tensor_predicted.numpy()

        # Mask background voxels so only lesion voxels are drawn over the brain.
        masked_labels_pred = np.ma.masked_where(labels_pred[label] < VOXEL_EPS, labels_pred[label])
        masked_labels_gt = np.ma.masked_where(labels_gt[label] < VOXEL_EPS, labels_gt[label])

        # Spread slice positions over the actual depth instead of assuming 128 slices.
        depth = brains[label].shape[-1]
        fig, axes = plt.subplots(1, N_SLICES, figsize=(30, 10), squeeze=False)
        for i, ax in enumerate(axes[0]):
            ind = math.floor((depth - 1) * (i / N_SLICES))
            ax.imshow(brains[label][:, :, ind], cmap='gray')
            ax.imshow(masked_labels_gt[:, :, ind], cmap='Reds')
            ax.imshow(masked_labels_pred[:, :, ind], cmap='Greens', alpha=0.5)

        if SAVE_PREDS:
            torch.save(label_tensor_predicted, os.path.join(predictions_path, f'{label}'))
        plt.show()
        # Release the figure so memory does not grow with the number of brains.
        plt.close(fig)

In [None]:
# Display the checkpoint path that the loading cell below will use.
latest_weights_path = get_latest_weights(logdir)
latest_weights_path

In [None]:
# Model that receives the trained weights. It must live on the same device as
# the input tensors, which the evaluation loop moves to `device`.
best_model = V2VModel(config).to(device)

In [None]:
# Restore the most recent checkpoint. map_location lets a GPU-saved checkpoint be
# loaded when DEVICE is 'cpu'.
model_dict = torch.load(get_latest_weights(logdir), map_location=device)
best_model.load_state_dict(model_dict['model_state'])
# Switch the module to inference mode; the trailing ';' suppresses the long module repr.
best_model.eval();

In [None]:
# Subject identifiers used as the test set.
test_data_indcs = [
    '6', '7', '40', '45', '75', '84', '14', '42',
    '68NS', '71NS',
]
# Reuse the feature configuration the run was trained with.
feat_params = config.dataset.features
test_files = create_datafile(test_data_indcs, feat_params, mask=True)

In [None]:
# Test pipeline: validation-time transforms, one brain per batch, fixed order.
_, val_trans = setup_transformations(config)
test_ds = Dataset(data=test_files[0], transform=val_trans)
test_loader = DataLoader(
    test_ds,
    batch_size=1,
    num_workers=0,
    collate_fn=list_data_collate,
    # keep dataset order: later cells map the batch index back to a dataset entry for its label
    shuffle=False,
)

In [None]:
# Grab the first test batch for a quick shape check.
check_data = monai.utils.first(test_loader)

In [None]:
# Shapes of the first test batch: segmentation first, then image.
tuple(check_data[key].shape for key in ('seg', 'image'))

In [None]:
# Run the trained model on the test subjects, score each prediction against its
# segmentation and plot slices along the last axis (GT in red, prediction in green).
dataloader = test_loader

brains = {}
labels_gt = {}
metric_dict = defaultdict(list)
labels_pred = {}
labels_ref = {}

N_SLICES = 12     # number of slices plotted per brain
VOXEL_EPS = 1e-4  # voxel values above this are treated as foreground

with torch.no_grad():

    # bs = 1
    # brain_tensor - [1,C,H,W,D]
    # label_tensor - [1,1,H,W,D]

    # enumerate() hides the loader length from tqdm, so pass it explicitly.
    for iter_i, data_tensors in tqdm(enumerate(dataloader), total=len(dataloader)):
        brain_tensor = data_tensors['image'].to(device)
        label_tensor = data_tensors['seg'].to(device)

        # Valid only because batch size is 1 and the loader is not shuffled,
        # so the batch index equals the dataset index.
        label = get_label(dataloader.dataset.data[iter_i]['seg'])
        print(f'Label: {label}')

        # forward pass; the output is produced on the same device as the input
        label_tensor_predicted = best_model(brain_tensor)  # -> [1,1,ps,ps,ps]

        cov = coverage(label_tensor_predicted, label_tensor).item()
        fp = false_positive(label_tensor_predicted, label_tensor).item()
        fn = false_negative(label_tensor_predicted, label_tensor).item()
        dice = dice_score(label_tensor_predicted.detach() > VOXEL_EPS, label_tensor.detach()).item()

        metric_dict['coverage'].append(cov)  # a.k.a recall
        metric_dict['false_positive'].append(fp)
        metric_dict['false_negative'].append(fn)
        metric_dict['dice_score'].append(dice)

        print(label, dice)
        brains[label] = brain_tensor[0, 0].detach().cpu().numpy()
        labels_gt[label] = label_tensor[0, 0].detach().cpu().numpy()
        label_tensor_predicted = label_tensor_predicted[0, 0].detach().cpu()
        labels_pred[label] = label_tensor_predicted.numpy()

        # Mask background voxels so only lesion voxels are drawn over the brain.
        masked_labels_pred = np.ma.masked_where(labels_pred[label] < VOXEL_EPS, labels_pred[label])
        masked_labels_gt = np.ma.masked_where(labels_gt[label] < VOXEL_EPS, labels_gt[label])

        # Spread slice positions over the actual depth instead of assuming 128 slices.
        depth = brains[label].shape[-1]
        fig, axes = plt.subplots(1, N_SLICES, figsize=(30, 10), squeeze=False)
        for i, ax in enumerate(axes[0]):
            ind = math.floor((depth - 1) * (i / N_SLICES))
            ax.imshow(brains[label][:, :, ind], cmap='gray')
            ax.imshow(masked_labels_gt[:, :, ind], cmap='Reds')
            ax.imshow(masked_labels_pred[:, :, ind], cmap='Greens', alpha=0.8)

        if SAVE_PREDS:
            torch.save(label_tensor_predicted, os.path.join(predictions_path, f'{label}'))
        plt.show()
        # Release the figure so memory does not grow with the number of brains.
        plt.close(fig)

In [None]:
# plt.ion()
# plt.figure(figsize=(10,5),dpi=200)
# plt.bar(metric_dict['dice_score'].keys(), metric_dict['dice_score'].values(), alpha=0.5, label='V2V')
# exp_name = logdir.split('/')[1]
# plt.title(f'Val')
# plt.ylabel('Dice score')
# plt.xticks(rotation=45)
# plt.legend()
# plt.show()

In [None]:
# plt.ion()
# plt.figure(figsize=(10,5),dpi=200)
# plt.bar(metric_dict['coverage'].keys(), metric_dict['coverage'].values(), alpha=0.5, label='V2V')
# exp_name = logdir.split('/')[1]
# plt.title(f'Val')
# plt.ylabel('Recall')
# plt.xticks(rotation=45)
# plt.legend()
# plt.show()

In [None]:
# dices = np.array(list(metric_dict['dice_score'].values()))
# labels = np.array(list(metric_dict['dice_score'].keys()))

In [None]:
# Output directory for validation images; exist_ok makes the cell idempotent and race-free.
image_path = os.path.join(logdir, 'images_val')
os.makedirs(image_path, exist_ok=True)

In [None]:
# For every evaluated brain, compare GT and V2V prediction on the slice along
# axis 1 that contains the most ground-truth voxels.
for label in brains.keys():

    # brains[label] is already a 3-D volume (the evaluation loops store
    # brain_tensor[0, 0]); the former extra [0] reduced it to 2-D, which made the
    # 3-value shape unpacking and the 3-axis slicing below fail.
    brain = brains[label]
    label_gt = labels_gt[label]
    label_pred = labels_pred[label]

    fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, dpi=300)

    # Index of the axis-1 slice with the largest GT voxel sum (0 when GT is empty).
    y_slice_pos = label_gt.sum(axis=(0, -1)).argmax()

    brain_slice = brain[:, y_slice_pos, :]
    label_gt_slice = label_gt[:, y_slice_pos, :]
    ax[0].imshow(brain_slice, cmap='gray')
    ax[0].imshow(label_gt_slice, cmap='Reds', alpha=0.5)
    ax[0].set_title('GT')

    label_pred_slice = label_pred[:, y_slice_pos, :].astype(float)
    ax[1].imshow(brain_slice, cmap='gray')
    ax[1].imshow(label_pred_slice, cmap='Reds', alpha=0.5)
    ax[1].set_title('V2V')

    fig.suptitle(f'Label: {label}', y=.85)

    plt.show()
    # Free each high-dpi figure so memory does not accumulate across brains.
    plt.close(fig)
    

In [None]:
from batch_metrics import our_metric