In [None]:
%load_ext autoreload
%autoreload 2

from context import uncertify

In [None]:
import logging

from uncertify.log import setup_logging

# Configure project logging first, then obtain this notebook's logger.
setup_logging()
LOG = logging.getLogger(__name__)

# Matplotlib's DEBUG-level logging is extremely noisy; only let WARNING and above through.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(level=logging.WARNING)

In [None]:
import os
from pathlib import Path

import torch
import torchvision
from tqdm import tqdm
# Drop tqdm progress bars left over from earlier runs of this cell, presumably so
# new bars render cleanly. `_instances` is a private tqdm attribute that may not
# exist (e.g. before any bar was created); only that specific failure is ignored,
# instead of a bare `except:` that would also swallow KeyboardInterrupt etc.
try:
    tqdm._instances.clear()
except AttributeError:
    pass
import seaborn as sns
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
# Global figure styling for every plot in this notebook: serif typeface, 18 pt base font.
matplotlib.rcParams.update({'font.family': 'serif', 'font.size': 18})

from uncertify.data.np_transforms import NumpyReshapeTransform, Numpy2PILTransform
from uncertify.data.transforms import H_FLIP_TRANSFORM, V_FLIP_TRANSFORM
from uncertify.data.datasets import GaussianNoiseDataset

from uncertify.data.dataloaders import dataloader_factory, DatasetType
from uncertify.models.vae import VariationalAutoEncoder
from uncertify.models.encoder_decoder_baur2020 import BaurDecoder, BaurEncoder

from uncertify.io.models import load_ensemble_models
from uncertify.evaluation.ood_experiments import run_ood_evaluations, run_ood_to_ood_dict

from uncertify.visualization.ood_scores import plot_ood_scores, plot_most_least_ood, plot_samples_close_to_score

from uncertify.common import DATA_DIR_PATH

In [None]:
# Load models
RUN_DIR_PATH = Path('/media/juniors/2TB_internal_HD/lightning_logs/train_vae/')
RUN_VERSIONS = [0, 1, 2, 3]
CHECKPOINT_PATHS = [RUN_DIR_PATH / f'version_{version}/checkpoints/last.ckpt' for version in RUN_VERSIONS]
ensemble_models = load_ensemble_models(DATA_DIR_PATH / 'ensemble_models', [f'model{idx}.ckpt' for idx in RUN_VERSIONS])

MASKED_CHECKPOINT_PATH = DATA_DIR_PATH / 'lightning_logs/vae_kl_test/version_20/checkpoints/last.ckpt'
assert MASKED_CHECKPOINT_PATH.exists(), f'Model checkpoint does not exist!'

checkpoint = torch.load(MASKED_CHECKPOINT_PATH)
model = VariationalAutoEncoder(BaurEncoder(), BaurDecoder())
model.load_state_dict(checkpoint['state_dict'])

In [None]:
batch_size    = 155
USE_N_BATCHES = 10
NUM_WORKERS   = 0

# Root of the preprocessed HDF5 datasets. Defaults to the original local path; set the
# UNCERTIFY_PROCESSED_DIR environment variable to point somewhere else.
PROCESSED_DIR_PATH = Path(os.environ.get('UNCERTIFY_PROCESSED_DIR', '/media/juniors/2TB_internal_HD/datasets/processed/'))

brats_t2_path    = PROCESSED_DIR_PATH / 'brats17_t2_bc_std_bv3.5_l10.hdf5'
brats_t2_hm_path = PROCESSED_DIR_PATH / 'brats17_t2_hm_bc_std_bv-3.5.hdf5'
brats_t1_path    = PROCESSED_DIR_PATH / 'brats17_t1_bc_std_bv3.5_l10.hdf5'
brats_t1_hm_path = PROCESSED_DIR_PATH / 'brats17_t1_hm_bc_std_bv-3.5.hdf5'
camcan_t2_val_path   = DATA_DIR_PATH  / 'processed/camcan_val_t2_hm_std_bv3.5_xe.hdf5'
camcan_t2_train_path = DATA_DIR_PATH  / 'processed/camcan_train_t2_hm_std_bv3.5_xe.hdf5'

# Keyword arguments shared by every BraTS loader below (only the unshuffled val split is kept).
brats_val_kwargs = dict(batch_size=batch_size, shuffle_val=False, num_workers=NUM_WORKERS)

_, brats_val_t2_dataloader    = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t2_path, **brats_val_kwargs)
_, brats_val_t1_dataloader    = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t1_path, **brats_val_kwargs)
_, brats_val_t2_hm_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t2_hm_path, **brats_val_kwargs)
_, brats_val_t1_hm_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t1_hm_path, **brats_val_kwargs)

# BraTS T2 with H_FLIP_TRANSFORM / V_FLIP_TRANSFORM applied.
_, brats_val_t2_hflip_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t2_path, transform=H_FLIP_TRANSFORM, **brats_val_kwargs)
_, brats_val_t2_vflip_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t2_path, transform=V_FLIP_TRANSFORM, **brats_val_kwargs)

camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(DatasetType.CAMCAN, batch_size=batch_size, val_set_path=camcan_t2_val_path, train_set_path=camcan_t2_train_path, shuffle_val=False, shuffle_train=True, num_workers=NUM_WORKERS)

noise_set = GaussianNoiseDataset()
noise_loader = DataLoader(noise_set, batch_size=batch_size)

# MNIST resized to 128x128 (NOTE(review): presumably to match the MRI slice size -- confirm).
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.ToTensor(),
])
_, mnist_val_dataloader = dataloader_factory(DatasetType.MNIST, batch_size=batch_size, transform=mnist_transform)

# Datasets included in the OOD evaluation; uncomment entries to add more.
dataloader_dict = {'BraTS T2 val': brats_val_t2_dataloader, 
                   'BraTS T1 val': brats_val_t1_dataloader, 
                   #'BraTS T2 HM val': brats_val_t2_hm_dataloader, 
                   #'BraTS T1 HM val': brats_val_t1_hm_dataloader,
                   #'CamCAN train': camcan_train_dataloader,
                   #'CamCAN val': camcan_val_dataloader,
                   #'Gaussian noise': noise_loader,
                   #'MNIST': mnist_val_dataloader,
                   #'BraTS T2 HFlip': brats_val_t2_hflip_dataloader,
                   #'BraTS T2 VFlip': brats_val_t2_vflip_dataloader
}
brats_dataloader_dict = {key: val for key, val in dataloader_dict.items() if 'BraTS' in key}

# len(dataloader) * batch_size is only an upper bound on the sample count: the last
# batch may be smaller than batch_size.
for name, dataloader in dataloader_dict.items(): 
    print(f'{name:15} dataloader: {len(dataloader)} batches (batch_size: {dataloader.batch_size}) -> up to {len(dataloader) * dataloader.batch_size} samples.')

# OOD Detection Evaluation for different OOD datasets

In [None]:
# Residual threshold forwarded to run_ood_evaluations
# (NOTE(review): presumably a cutoff on reconstruction residuals -- confirm in uncertify).
RESIDUAL_THRESHOLD = 0.93
run_ood_evaluations(camcan_train_dataloader, dataloader_dict, [model],
                    residual_threshold=RESIDUAL_THRESHOLD, max_n_batches=USE_N_BATCHES)

# OOD Score

In [None]:
NUM_BATCHES = 10
NUM_BACTHES = NUM_BATCHES  # alias for the original misspelled name, kept for any cell still using it
OOD_METRICS = ('dose', )  # 'waic' can presumably be added here
DOSE_STATISTICS = ('elbo', 'kl_div', 'rec_err')

# A dedicated name (instead of re-binding `dataloader_dict`) keeps the OOD evaluation cell
# above reproducible: re-running it after this cell would otherwise use a different dataset set.
ood_score_dataloader_dict = {'BraTS T2 val': brats_val_t2_dataloader, 
                             #'BraTS T1 val': brats_val_t1_dataloader, 
                             #'BraTS T2 HM val': brats_val_t2_hm_dataloader,
                             #'BraTS T1 HM val': brats_val_t1_hm_dataloader,
                             #'CamCAN train': camcan_train_dataloader,
                             #'Gaussian noise': noise_loader,
                             #'MNIST': mnist_val_dataloader
}

metrics_ood_dict = run_ood_to_ood_dict(ood_score_dataloader_dict, [model], camcan_train_dataloader,
                                       num_batches=NUM_BATCHES, ood_metrics=OOD_METRICS,
                                       dose_statistics=DOSE_STATISTICS)

In [None]:
# Score window and sample count for plot_samples_close_to_score.
CLOSE_TO_SCORE_MIN = 0
CLOSE_TO_SCORE_MAX = 24
N_CLOSE_TO_SCORE_SAMPLES = 64

for metric in OOD_METRICS:
    ood_dict = metrics_ood_dict[metric]
    # Alternative views -- uncomment as needed. Note score_label takes the loop variable,
    # not the literal string 'metric'.
    #plot_ood_scores(ood_dict, score_label=metric, dataset_name_filters=[], modes_to_include=['healthy', 'lesional'])

    #plot_most_least_ood(ood_dict, 'BraTS T2 val', n_most=64)
    #plot_most_least_ood(ood_dict, 'BraTS T2 HM val')
    #plot_most_least_ood(ood_dict, 'CamCAN train', do_lesional=False, n_most=64)

    plot_samples_close_to_score(ood_dict, 'BraTS T2 val',
                                min_score=CLOSE_TO_SCORE_MIN, max_score=CLOSE_TO_SCORE_MAX,
                                n=N_CLOSE_TO_SCORE_SAMPLES, show_ground_truth=True)
    #plot_samples_close_to_score(ood_dict, 'BraTS T2 HM val', min_score=0.85, max_score=1.2)
    #plot_samples_close_to_score(ood_dict, 'CamCAN train', do_lesional=False, min_score=0, max_score=25, n=64)


# Slice-Wise Statistics

In [None]:
import pandas as pd
import seaborn as sns

# Switch seaborn to its "poster" plotting context for the figures in this section.
sns.set_context(context="poster")

from uncertify.evaluation.dose import full_pipeline_slice_wise_dose_scores
from uncertify.evaluation.statistics import (
    fit_statistics,
    aggregate_slice_wise_statistics,
)

In [None]:
# Compute slice-wise DoSE statistics, fit the training KDEs, and plot the statistics
# as a lower-triangular pair-plot plus an overlay of their marginal densities.

DOSE_STATISTICS = ['elbo', 'rec_err', 'kl_div']
MAX_N_BATCHES = 10

(DATA_DIR_PATH / 'plots').mkdir(parents=True, exist_ok=True)

test_dataloaders = [camcan_train_dataloader]#, brats_val_t2_dataloader]
for dataloader in test_dataloaders:
    is_training_set = dataloader is camcan_train_dataloader
    statistics_dict = aggregate_slice_wise_statistics(model, dataloader,
                                                      DOSE_STATISTICS, max_n_batches=MAX_N_BATCHES)
    # The KDEs used by the DoSE cells below must describe the training distribution only.
    # Refitting on every loop element would leave them fit on the last (possibly OOD) set.
    if is_training_set:
        kde_func_dict = fit_statistics(statistics_dict)
    stat_df = pd.DataFrame(statistics_dict)
    # sns.pairplot creates its own figure, so no plt.figure() is needed beforehand.
    sns_plot = sns.pairplot(stat_df, vars=DOSE_STATISTICS,
                            corner=True, plot_kws={"s": 10}, palette='viridis',
                            hue=None if is_training_set else 'is_lesional')
    # `fill` replaces the deprecated `shade` argument of sns.kdeplot.
    sns_plot.map_lower(sns.kdeplot, fill=True, thresh=0.05, alpha=0.7)
    sns_plot.map_diag(sns.kdeplot)
    sns_plot.savefig(DATA_DIR_PATH / 'plots' / f'dose_training_statistics_{dataloader.dataset.name}.png')

    # Overlay the 1-D densities of all statistics in one figure.
    plt.figure()
    for stat in DOSE_STATISTICS:
        sns.kdeplot(stat_df[stat], label=stat)
    plt.legend()  # without this the `label=` values above are never shown
    

In [None]:
# Plot the DoSE KDE scores of test sets against the training KDEs in `kde_func_dict`
# (fitted in an earlier cell).
# NOTE: this re-binds MAX_N_BATCHES, which the DoSE cells further down also read.
MAX_N_BATCHES = 5

(DATA_DIR_PATH / 'plots').mkdir(parents=True, exist_ok=True)

test_dataloaders = [brats_val_t1_dataloader]  #, brats_val_t2_dataloader
for dataloader in test_dataloaders:
    dose_scores, dose_kde_dict = full_pipeline_slice_wise_dose_scores(camcan_train_dataloader, dataloader, model,
                                                                    DOSE_STATISTICS, MAX_N_BATCHES, kde_func_dict)
    dose_df = pd.DataFrame(dose_kde_dict)
    sns_plot = sns.pairplot(dose_df, vars=DOSE_STATISTICS,
                            hue=None if dataloader is camcan_train_dataloader else 'is_lesional',
                            corner=True, plot_kws={"s": 10}, palette='viridis')
    # `fill` replaces the deprecated `shade` argument of sns.kdeplot.
    sns_plot.map_lower(sns.kdeplot, fill=True, thresh=0.05, alpha=0.7)
    sns_plot.map_diag(sns.kdeplot)
    sns_plot.savefig(DATA_DIR_PATH / 'plots' / f'dose_kde_{dataloader.dataset.name}.png')

In [None]:
# Re-fit the training KDEs on up to 50 batches of the CamCAN training loader.
# NOTE: overwrites `statistics_dict` / `kde_func_dict` from the earlier pair-plot cell.
statistics_dict = aggregate_slice_wise_statistics(
    model,
    camcan_train_dataloader,
    DOSE_STATISTICS,
    max_n_batches=50,
)
kde_func_dict = fit_statistics(statistics_dict)

In [None]:
# Final DoSE scores on BraTS T2 val, plotted as densities split by the `is_lesional` flag.
# Uses MAX_N_BATCHES and kde_func_dict as last set by earlier cells.
dose_scores, dose_kde_dict = full_pipeline_slice_wise_dose_scores(camcan_train_dataloader, brats_val_t2_dataloader, model, 
                                                                    DOSE_STATISTICS, MAX_N_BATCHES, kde_func_dict)
final_dose_df = pd.DataFrame({'DoSE': dose_scores, 'is_lesional': dose_kde_dict['is_lesional']})
# Long-form call (columns by name). Passing a bare Series as data together with a hue
# vector is treated as wide-form input and rejected by newer seaborn versions.
sns.kdeplot(data=final_dose_df, x='DoSE', hue='is_lesional');

In [None]:
# Same DoSE density plot, computed on the CamCAN training loader itself for reference.
# Uses MAX_N_BATCHES and kde_func_dict as last set by earlier cells.
dose_scores, dose_kde_dict = full_pipeline_slice_wise_dose_scores(camcan_train_dataloader, camcan_train_dataloader, model, 
                                                                    DOSE_STATISTICS, MAX_N_BATCHES, kde_func_dict)
final_dose_df = pd.DataFrame({'DoSE': dose_scores, 'is_lesional': dose_kde_dict['is_lesional']})
# Long-form call (columns by name). Passing a bare Series as data together with a hue
# vector is treated as wide-form input and rejected by newer seaborn versions.
sns.kdeplot(data=final_dose_df, x='DoSE', hue='is_lesional');