In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell execution and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# NOTE(review): `context` is presumably a local helper next to this notebook that
# makes the `uncertify` package importable — confirm.
from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Configure logging through the project's helper, then create this notebook's own logger.
setup_logging()
LOG = logging.getLogger(__name__)

# Raise the matplotlib logger to WARNING so its DEBUG/INFO records are not emitted
# into the cell outputs.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)

In [None]:
from pathlib import Path

import torch
import torchvision
from torchvision.transforms.transforms import Compose
from tqdm import tqdm
try:
    # Forget progress bars still registered from earlier runs in this kernel
    # (common workaround for stale bars after an interrupted cell).
    tqdm._instances.clear()
except AttributeError:
    # `_instances` is a private attribute that may not exist depending on the tqdm
    # version or before the first bar is created; nothing to clear in that case.
    # Deliberately narrow: a bare `except:` would also swallow KeyboardInterrupt.
    pass
import seaborn as sns
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
# Global figure styling: serif font family and a base font size of 18.
matplotlib.rc('font', family='serif')
plt.rcParams.update({'font.size': 18})

from uncertify.data.np_transforms import NumpyReshapeTransform, Numpy2PILTransform
from uncertify.data.datasets import GaussianNoiseDataset

from uncertify.data.dataloaders import dataloader_factory, DatasetType

from uncertify.io.models import load_ensemble_models
from uncertify.evaluation.ood_experiments import run_ood_evaluations, run_ood_to_ood_dict
from uncertify.evaluation.configs import OODDetectionResults

from uncertify.visualization.histograms import plot_multi_histogram
from uncertify.visualization.ood_scores import plot_ood_scores, plot_most_least_ood, plot_samples_close_to_score
from uncertify.visualization.grid import imshow_grid

from uncertify.common import DATA_DIR_PATH

In [None]:
# Training-run directory, the run versions to use, and the 'last.ckpt' checkpoint
# of each of those versions (absolute, machine-specific location).
RUN_DIR_PATH = Path('/media/juniors/2TB_internal_HD/lightning_logs/train_vae/')
RUN_VERSIONS = range(2)
CHECKPOINT_PATHS = [
    RUN_DIR_PATH.joinpath(f'version_{version}/checkpoints/last.ckpt')
    for version in RUN_VERSIONS
]

In [None]:
# Hand load_ensemble_models the 'ensemble_models' folder inside the project data
# directory together with one file name per run version ('model0.ckpt', 'model1.ckpt', ...).
# NOTE(review): the return value is assumed to be the collection of loaded models
# consumed by the evaluation cells below — confirm in uncertify.io.models.
ensemble_models = load_ensemble_models(DATA_DIR_PATH / 'ensemble_models', [f'model{idx}.ckpt' for idx in RUN_VERSIONS])

In [None]:
# Batch size handed to every loader created in this cell; USE_N_BATCHES is the
# intended per-dataset batch budget (assumed from its name).
batch_size = 155
USE_N_BATCHES = 3

# Directory holding the pre-processed HDF5 files (absolute, machine-specific location).
PROCESSED_DIR_PATH = Path('/media/juniors/2TB_internal_HD/datasets/processed/')

# BraTS17 T2 / T1 files, each in a plain and an 'hm' variant.
brats_t2_path = PROCESSED_DIR_PATH.joinpath('brats17_t2_bc_std_bv3.5_l10.hdf5')
brats_t2_hm_path = PROCESSED_DIR_PATH.joinpath('brats17_t2_hm_bc_std_bv-3.5.hdf5')
brats_t1_path = PROCESSED_DIR_PATH.joinpath('brats17_t1_bc_std_bv3.5_l10.hdf5')
brats_t1_hm_path = PROCESSED_DIR_PATH.joinpath('brats17_t1_hm_bc_std_bv-3.5.hdf5')
# CamCAN T2 train/validation files, resolved against the project's data directory.
camcan_t2_train_path = DATA_DIR_PATH / 'processed/camcan_train_t2_hm_std_bv3.5_xe.hdf5'
camcan_t2_val_path = DATA_DIR_PATH / 'processed/camcan_val_t2_hm_std_bv3.5_xe.hdf5'

# Keyword arguments common to the four BraTS17 loaders below.
_brats_loader_kwargs = dict(batch_size=batch_size, shuffle_val=False, num_workers=12)

# dataloader_factory returns a pair; only its second element is kept here and the
# first is discarded into `_`.
_, brats_val_t2_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t2_path, **_brats_loader_kwargs)
_, brats_val_t1_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t1_path, **_brats_loader_kwargs)
_, brats_val_t2_hm_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t2_hm_path, **_brats_loader_kwargs)
_, brats_val_t1_hm_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t1_hm_path, **_brats_loader_kwargs)

# Two flip pipelines that differ only in the flip step: reshape to (200, 200),
# convert to PIL, resize to (128, 128), flip with probability 1.0, convert to tensor.
# Each pipeline gets its own freshly constructed transform instances.
hflip_transform, vflip_transform = (
    Compose([
        NumpyReshapeTransform((200, 200)),
        Numpy2PILTransform(),
        torchvision.transforms.Resize((128, 128)),
        flip_cls(p=1.0),
        torchvision.transforms.ToTensor(),
    ])
    for flip_cls in (
        torchvision.transforms.RandomHorizontalFlip,
        torchvision.transforms.RandomVerticalFlip,
    )
)

# The BraTS T2 file once more, this time read through the flip pipelines.
_, brats_val_t2_hflip_dataloader = dataloader_factory(
    DatasetType.BRATS17,
    batch_size=batch_size,
    val_set_path=brats_t2_path,
    shuffle_val=False,
    num_workers=12,
    transform=hflip_transform,
)
_, brats_val_t2_vflip_dataloader = dataloader_factory(
    DatasetType.BRATS17,
    batch_size=batch_size,
    val_set_path=brats_t2_path,
    shuffle_val=False,
    num_workers=12,
    transform=vflip_transform,
)

# CamCAN loaders: both elements of the returned pair are kept; the training
# loader is shuffled, the validation loader is not.
camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(
    DatasetType.CAMCAN,
    batch_size=batch_size,
    val_set_path=camcan_t2_val_path,
    train_set_path=camcan_t2_train_path,
    shuffle_val=False,
    shuffle_train=True,
    num_workers=12,
)

# Gaussian noise dataset wrapped in a plain torch DataLoader.
noise_set = GaussianNoiseDataset()
noise_loader = DataLoader(dataset=noise_set, batch_size=batch_size)

# MNIST loader whose images are resized to (128, 128) before tensor conversion.
_, mnist_val_dataloader = dataloader_factory(
    DatasetType.MNIST,
    batch_size=batch_size,
    transform=Compose([
        torchvision.transforms.Resize((128, 128)),
        torchvision.transforms.ToTensor(),
    ]),
)

# Display name -> loader for the datasets to evaluate; uncomment an entry to include it.
dataloader_dict = {
    'BraTS T2 val': brats_val_t2_dataloader,
    # 'BraTS T1 val': brats_val_t1_dataloader,
    # 'BraTS T2 HM val': brats_val_t2_hm_dataloader,
    # 'BraTS T1 HM val': brats_val_t1_hm_dataloader,
    # 'CamCAN train': camcan_train_dataloader,
    # 'Gaussian noise': noise_loader,
    # 'MNIST': mnist_val_dataloader,
    # 'BraTS T2 HFlip': brats_val_t2_hflip_dataloader,
    # 'BraTS T2 VFlip': brats_val_t2_vflip_dataloader,
}

# Subset of the selected loaders whose display name contains 'BraTS'.
brats_dataloader_dict = dict(
    (loader_name, loader)
    for loader_name, loader in dataloader_dict.items()
    if 'BraTS' in loader_name
)

# Print a one-line summary per selected loader.
# `len(dataloader) * batch_size` alone over-reports the sample count whenever the
# last batch is only partially filled (len(dataloader) rounds up unless drop_last
# is set), so the product is capped at the size of the underlying dataset.
for name, dataloader in dataloader_dict.items():
    num_samples = min(len(dataloader) * dataloader.batch_size, len(dataloader.dataset))
    print(f'{name:15} dataloader: {len(dataloader)} batches (batch_size: {dataloader.batch_size}) -> {num_samples} samples.')

# OOD Detection Evaluation for different OOD datasets

In [None]:
# Run the OOD evaluations for every loader in dataloader_dict, passing the CamCAN
# training loader as first argument. The batch limit comes from the USE_N_BATCHES
# constant of the data-loading cell (value 3) instead of a duplicated literal, so
# it is tuned in one place.
# NOTE(review): residual_threshold=1.35 is an unexplained constant — document its origin.
run_ood_evaluations(camcan_train_dataloader, dataloader_dict, ensemble_models, residual_threshold=1.35, max_n_batches=USE_N_BATCHES)

# WAIC Score Histograms

In [None]:
# Value passed as `num_batches` to run_ood_to_ood_dict below.
NUM_BATCHES = 3
NUM_BACTHES = NUM_BATCHES  # misspelled legacy name, kept so existing references keep working

# NOTE: this rebinds `dataloader_dict` from the data-loading cell with a different
# selection; any cell executed afterwards sees this version.
dataloader_dict = {
    'BraTS T2 val': brats_val_t2_dataloader,
    # 'BraTS T1 val': brats_val_t1_dataloader,
    'BraTS T2 HM val': brats_val_t2_hm_dataloader,
    # 'BraTS T1 HM val': brats_val_t1_hm_dataloader,
    'CamCAN train': camcan_train_dataloader,
    # 'Gaussian noise': noise_loader,
    # 'MNIST': mnist_val_dataloader,
}

# Result object consumed by the plotting cells that follow.
waic_dict = run_ood_to_ood_dict(dataloader_dict, ensemble_models, num_batches=NUM_BATCHES)

In [None]:
# Plot the scores held in waic_dict with the label 'WAIC', restricted through the
# filters to the 'BraTS T2 val' dataset and the 'healthy' mode.
# NOTE(review): exact filter semantics live in uncertify.visualization.ood_scores — confirm.
plot_ood_scores(waic_dict, score_label='WAIC', dataset_name_filters=['BraTS T2 val'], modes_to_include=['healthy'])

In [None]:
# One plot_most_least_ood call per dataset stored in waic_dict.
# do_lesional is switched off only for 'CamCAN train' — presumably because that
# dataset has no lesional samples; confirm.
plot_most_least_ood(waic_dict, 'BraTS T2 val')
plot_most_least_ood(waic_dict, 'BraTS T2 HM val')
plot_most_least_ood(waic_dict, 'CamCAN train', do_lesional=False)

In [None]:
# One plot_samples_close_to_score call per dataset, each with its own hand-picked
# min_score / max_score pair (presumably the score window samples are drawn from — confirm).
# As above, do_lesional is disabled for 'CamCAN train' only.
plot_samples_close_to_score(waic_dict, 'BraTS T2 val', min_score=0.95, max_score=1.2)
plot_samples_close_to_score(waic_dict, 'BraTS T2 HM val', min_score=0.85, max_score=1.2)
plot_samples_close_to_score(waic_dict, 'CamCAN train', do_lesional=False, min_score=0.8, max_score=0.9)