In [None]:
%load_ext autoreload
%autoreload 2

from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
setup_logging()
LOG = logging.getLogger(__name__)

# Matplotlib is extremely verbose at DEBUG level; only let its warnings and errors through.
mpl_logger = logging.getLogger(name='matplotlib')
mpl_logger.setLevel('WARNING')

In [None]:
from pathlib import Path

import torch
import torchvision
from torchvision.transforms.transforms import Compose
from tqdm import tqdm
# Clear tqdm's registry of live progress bars, presumably so bars left over from a
# previous run of this cell do not linger. `_instances` is a private attribute that may
# not exist yet, so a missing attribute is the only failure deliberately ignored here;
# anything else (e.g. KeyboardInterrupt) is no longer swallowed.
try:
    tqdm._instances.clear()
except AttributeError:
    pass
import seaborn as sns
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
# Use a serif font at 18pt for every figure produced by this notebook.
matplotlib.rcParams.update({'font.family': 'serif', 'font.size': 18})

from uncertify.data.np_transforms import NumpyReshapeTransform, Numpy2PILTransform
from uncertify.data.datasets import GaussianNoiseDataset

from sklearn.metrics import roc_curve, precision_recall_curve, precision_recall_fscore_support
from sklearn.metrics import average_precision_score, roc_auc_score, confusion_matrix

from uncertify.models.vae import load_vae_baur_model
from uncertify.models.vae import VariationalAutoEncoder
from uncertify.models.encoder_decoder_baur2020 import BaurDecoder, BaurEncoder

from uncertify.data.dataloaders import dataloader_factory, DatasetType

from uncertify.evaluation.inference import yield_inference_batches
from uncertify.evaluation.ood_metrics import sample_wise_waic_scores
from uncertify.evaluation.ood_metrics import load_ensemble_models
from uncertify.evaluation.evaluation_pipeline import run_ood_detection_performance, EvaluationConfig, EvaluationResult, PixelAnomalyDetectionResult, SliceAnomalyDetectionResults, OODDetectionResult


from uncertify.visualization.grid import imshow_grid
from uncertify.visualization.model_performance import plot_roc_curve, plot_precision_recall_curve, plot_confusion_matrix
from uncertify.visualization.plotting import setup_plt_figure
from uncertify.visualization.histograms import plot_multi_histogram

from uncertify.common import DATA_DIR_PATH

In [None]:
# Training-run directory and the run versions whose final checkpoints are of interest.
# NOTE(review): machine-specific absolute path; CHECKPOINT_PATHS is not used by the visible cells.
RUN_DIR_PATH = Path('/media/juniors/2TB_internal_HD/lightning_logs/train_vae/')
RUN_VERSIONS = list(range(4))
# One `last.ckpt` per version: <RUN_DIR_PATH>/version_<v>/checkpoints/last.ckpt
CHECKPOINT_PATHS = [RUN_DIR_PATH.joinpath(f'version_{run_version}', 'checkpoints', 'last.ckpt')
                    for run_version in RUN_VERSIONS]

In [None]:
# The ensemble members are stored as model0.ckpt ... model3.ckpt under DATA_DIR_PATH/ensemble_models.
ensemble_checkpoint_names = [f'model{member_idx}.ckpt' for member_idx in range(4)]
ensemble_models = load_ensemble_models(DATA_DIR_PATH / 'ensemble_models', ensemble_checkpoint_names)

In [None]:
# Inference settings shared by all dataloaders below.
batch_size = 155
USE_N_BATCHES = 5

# Locally preprocessed BraTS'17 HDF5 files.
# NOTE(review): machine-specific absolute path.
PROCESSED_DIR_PATH = Path('/media/juniors/2TB_internal_HD/datasets/processed/')

# Plain and "_hm" variants (presumably histogram-matched — TODO confirm) for T2 and T1.
brats_t2_path = PROCESSED_DIR_PATH.joinpath('brats17_t2_bc_std_bv3.5_l10.hdf5')
brats_t2_hm_path = PROCESSED_DIR_PATH.joinpath('brats17_t2_hm_bc_std_bv-3.5.hdf5')
brats_t1_path = PROCESSED_DIR_PATH.joinpath('brats17_t1_bc_std_bv3.5_l10.hdf5')
brats_t1_hm_path = PROCESSED_DIR_PATH.joinpath('brats17_t1_hm_bc_std_bv-3.5.hdf5')
# CamCAN T2 validation / training splits stored under the project data directory.
camcan_t2_val_path = DATA_DIR_PATH / 'processed' / 'camcan_val_t2_hm_std_bv3.5_xe.hdf5'
camcan_t2_train_path = DATA_DIR_PATH / 'processed' / 'camcan_train_t2_hm_std_bv3.5_xe.hdf5'

# Un-shuffled BraTS'17 validation loaders, one per modality / preprocessing variant.
# The factory also returns a train loader, which is discarded here.
brats_val_loader_kwargs = dict(batch_size=batch_size, shuffle_val=False, num_workers=12)
_, brats_val_t2_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t2_path, **brats_val_loader_kwargs)
_, brats_val_t1_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t1_path, **brats_val_loader_kwargs)
_, brats_val_t2_hm_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t2_hm_path, **brats_val_loader_kwargs)
_, brats_val_t1_hm_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t1_hm_path, **brats_val_loader_kwargs)

def _make_flip_transform(flip_transform_cls, reshape_to=(200, 200), resize_to=(128, 128)):
    """Build a preprocessing pipeline that flips every sample.

    The pipeline does the following, in order:
    - reshapes the numpy sample to ``reshape_to``;
    - converts it to a PIL image;
    - resizes it to ``resize_to``;
    - applies ``flip_transform_cls(p=1.0)``, so the flip is always applied;
    - converts the result to a tensor.

    Args:
        flip_transform_cls: a torchvision flip transform class accepting ``p``,
            e.g. ``RandomHorizontalFlip`` or ``RandomVerticalFlip``.
        reshape_to: shape passed to ``NumpyReshapeTransform``.
        resize_to: output size passed to ``torchvision.transforms.Resize``.

    Returns:
        A ``torchvision.transforms.Compose`` pipeline.
    """
    return torchvision.transforms.Compose([
        NumpyReshapeTransform(reshape_to),
        Numpy2PILTransform(),
        torchvision.transforms.Resize(resize_to),
        flip_transform_cls(p=1.0),
        torchvision.transforms.ToTensor(),
    ])


hflip_transform = _make_flip_transform(torchvision.transforms.RandomHorizontalFlip)
vflip_transform = _make_flip_transform(torchvision.transforms.RandomVerticalFlip)

# Flipped BraTS validation sets (the flip transforms use p=1.0).
# NOTE(review): the horizontal flip is built from the T2 HM file but the vertical flip
# from the T1 HM file — confirm this asymmetry is intended.
_, brats_val_t2_hm_hflip_dataloader = dataloader_factory(
    DatasetType.BRATS17,
    batch_size=batch_size,
    val_set_path=brats_t2_hm_path,
    shuffle_val=False,
    num_workers=12,
    transform=hflip_transform,
)
_, brats_val_t1_hm_vflip_dataloader = dataloader_factory(
    DatasetType.BRATS17,
    batch_size=batch_size,
    val_set_path=brats_t1_hm_path,
    shuffle_val=False,
    num_workers=12,
    transform=vflip_transform,
)

# CamCAN loaders: the training split is shuffled, the validation split is not.
camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(
    DatasetType.CAMCAN,
    batch_size=batch_size,
    val_set_path=camcan_t2_val_path,
    train_set_path=camcan_t2_train_path,
    shuffle_val=False,
    shuffle_train=True,
    num_workers=12,
)

# Synthetic Gaussian-noise samples used as an additional OOD reference set.
noise_set = GaussianNoiseDataset()
noise_loader = DataLoader(dataset=noise_set, batch_size=batch_size)

# MNIST digits resized to 128x128; only the validation loader is kept.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.ToTensor(),
])
_, mnist_val_dataloader = dataloader_factory(DatasetType.MNIST, batch_size=batch_size, transform=mnist_transform)

# Display name -> dataloader for every dataset summarised below and scored with WAIC later.
dataloader_dict = dict([
    ('BraTS T2 val', brats_val_t2_dataloader),
    ('BraTS T1 val', brats_val_t1_dataloader),
    ('BraTS T2 HM val', brats_val_t2_hm_dataloader),
    ('BraTS T1 HM val', brats_val_t1_hm_dataloader),
    ('CamCAN train', camcan_train_dataloader),
    ('Gaussian noise', noise_loader),
    ('MNIST', mnist_val_dataloader),
])

# Summarise each loader. The sample count comes from the underlying dataset instead of
# `len(loader) * batch_size`, which overcounts whenever the final batch is only partially
# filled. NOTE(review): assumes the loaders iterate the whole dataset (no subsetting
# sampler, no drop_last) — confirm for dataloader_factory.
for name, dataloader in dataloader_dict.items():
    n_batches = len(dataloader)
    n_samples = len(dataloader.dataset)
    print(f'{name:15} dataloader: {n_batches} batches (batch_size: {dataloader.batch_size}) -> {n_samples} samples.')

# OOD Detection Evaluation for different OOD datasets

In [None]:
# Pixel-level anomaly threshold written into the results before running the pipeline.
# NOTE(review): the provenance of 1.35 is not documented here — TODO record it.
PIXEL_ANOMALY_THRESHOLD = 1.35

# Datasets to evaluate for OOD detection. The un-flipped BraTS loaders from
# `dataloader_dict` can be added here as well.
brats_dataloaders = {
    'BraTS T2 HM Hor. Flip': brats_val_t2_hm_hflip_dataloader,
    # This loader is built from the T1 HM file, so it is labelled T1 (not T2).
    'BraTS T1 HM Vert. Flip': brats_val_t1_hm_vflip_dataloader,
}

for name, dataloader in brats_dataloaders.items():
    LOG.info(f'OOD evaluation for {name}...')
    # Fresh config and result containers per dataset so evaluations do not share objects.
    eval_cfg = EvaluationConfig()
    eval_cfg.do_plots = True
    eval_cfg.use_n_batches = USE_N_BATCHES
    results = EvaluationResult(DATA_DIR_PATH / 'evaluation', PixelAnomalyDetectionResult(),
                               SliceAnomalyDetectionResults(), OODDetectionResult())
    results.pixel_anomaly_result.best_threshold = PIXEL_ANOMALY_THRESHOLD
    results.make_dirs()
    LOG.info(f'{name}: evaluation run number {results.current_run_number}')

    # `camcan_train_dataloader` is passed as the reference loader, `dataloader` as the one under test.
    run_ood_detection_performance(ensemble_models, camcan_train_dataloader, dataloader, eval_cfg, results)

# WAIC Score Histograms

In [None]:
# Number of batches per dataloader scored for the WAIC histograms.
NUM_BATCHES = 5

# Dataloader name -> sample-wise WAIC scores from the model ensemble.
waic_dict = {}
for name, data_loader in dataloader_dict.items():
    # Loaders with fewer than NUM_BATCHES batches (or a partial last batch) yield fewer
    # samples, hence "up to". Samples are slices / images, not patients.
    max_samples = min(NUM_BATCHES, len(data_loader)) * data_loader.batch_size
    LOG.info(f'WAIC score calculation for {name} (up to {max_samples} samples)...')
    waic_scores = sample_wise_waic_scores(models=ensemble_models, data_loader=data_loader, max_n_batches=NUM_BATCHES)
    waic_dict[name] = waic_scores

In [None]:
# Overlay one WAIC histogram per dataloader and save the figure.
fig, _ = plot_multi_histogram(waic_dict.values(), list(waic_dict.keys()), plot_density=False,
                              figsize=(12, 6), xlabel='WAIC', ylabel='Slice-wise frequency',
                              hist_kwargs={'bins': 15})
# savefig does not create missing directories, so make sure the output folder exists first.
plots_dir_path = DATA_DIR_PATH / 'plots'
plots_dir_path.mkdir(parents=True, exist_ok=True)
fig.savefig(plots_dir_path / 'waic_scores.png')