In [None]:
%load_ext autoreload
%autoreload 2

from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Initialise logging through the project's helper, then create this notebook's logger.
setup_logging()
LOG = logging.getLogger(__name__)

# Matplotlib emits a large volume of DEBUG records; only let WARNING and above through
# so that cell output stays readable.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)

In [None]:
from pathlib import Path

import torch
import torchvision
from torchvision.transforms.transforms import Compose
from tqdm import tqdm
try:
    # Forget progress bars still registered from an earlier (possibly interrupted) run of
    # this kernel, so that new bars do not render below stale ones.
    tqdm._instances.clear()
except AttributeError:
    # Some tqdm versions only create the `_instances` registry together with the first
    # progress bar; in that case there is nothing to clear. Any other exception is
    # unexpected and is allowed to propagate instead of being silently swallowed.
    pass
import seaborn as sns
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
plt.rc('font', family='serif')
matplotlib.rcParams.update({'font.size': 18})

from uncertify.data.np_transforms import NumpyReshapeTransform, Numpy2PILTransform
from uncertify.data.datasets import GaussianNoiseDataset

from uncertify.data.dataloaders import dataloader_factory, DatasetType

from uncertify.evaluation.ood_metrics import sample_wise_waic_scores
from uncertify.evaluation.ood_metrics import load_ensemble_models
from uncertify.evaluation.evaluation_pipeline import run_ood_detection_performance, EvaluationConfig, EvaluationResult, PixelAnomalyDetectionResult, SliceAnomalyDetectionResults
from uncertify.evaluation.evaluation_pipeline import print_results_from_evaluation_dirs
from uncertify.evaluation.configs import OODDetectionResults
from uncertify.utils.python_helpers import get_indices_of_n_largest_items, get_indices_of_n_smallest_items, get_idx_of_closest_value

from uncertify.visualization.histograms import plot_multi_histogram
from uncertify.visualization.ood_scores import plot_ood_scores
from uncertify.visualization.grid import imshow_grid

from uncertify.common import DATA_DIR_PATH

In [None]:
# Training-run directory and the `last.ckpt` checkpoint path of every run version
RUN_DIR_PATH = Path('/media/juniors/2TB_internal_HD/lightning_logs/train_vae/')
RUN_VERSIONS = list(range(4))
CHECKPOINT_PATHS = [
    RUN_DIR_PATH.joinpath(f'version_{version}/checkpoints/last.ckpt')
    for version in RUN_VERSIONS
]

In [None]:
# Load the ensemble from the checkpoint files model0.ckpt ... model3.ckpt located in
# DATA_DIR_PATH / 'ensemble_models'.
ensemble_models = load_ensemble_models(DATA_DIR_PATH / 'ensemble_models', [f'model{idx}.ckpt' for idx in range(4)])

In [None]:
# Batch settings shared by every dataloader / evaluation run below
batch_size = 155
USE_N_BATCHES = 10

# Directory holding the pre-processed BraTS HDF5 files on the local disk
PROCESSED_DIR_PATH = Path('/media/juniors/2TB_internal_HD/datasets/processed/')

# BraTS 2017 files (T2 / T1, each in a plain and an "hm" variant)
brats_t2_path = PROCESSED_DIR_PATH / 'brats17_t2_bc_std_bv3.5_l10.hdf5'
brats_t2_hm_path = PROCESSED_DIR_PATH / 'brats17_t2_hm_bc_std_bv-3.5.hdf5'
brats_t1_path = PROCESSED_DIR_PATH / 'brats17_t1_bc_std_bv3.5_l10.hdf5'
brats_t1_hm_path = PROCESSED_DIR_PATH / 'brats17_t1_hm_bc_std_bv-3.5.hdf5'

# CamCAN T2 files live below the project data directory
camcan_t2_val_path = DATA_DIR_PATH / 'processed/camcan_val_t2_hm_std_bv3.5_xe.hdf5'
camcan_t2_train_path = DATA_DIR_PATH / 'processed/camcan_train_t2_hm_std_bv3.5_xe.hdf5'

# Every BraTS validation loader is built with the same settings; only the file differs.
# The first element returned by the factory is not needed here and is discarded.
brats_val_kwargs = dict(batch_size=batch_size, shuffle_val=False, num_workers=12)
_, brats_val_t2_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t2_path, **brats_val_kwargs)
_, brats_val_t1_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t1_path, **brats_val_kwargs)
_, brats_val_t2_hm_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t2_hm_path, **brats_val_kwargs)
_, brats_val_t1_hm_dataloader = dataloader_factory(DatasetType.BRATS17, val_set_path=brats_t1_hm_path, **brats_val_kwargs)

def _make_flip_transform(flip_cls):
    """Build the reshape -> PIL -> resize -> flip -> tensor pipeline around `flip_cls`.

    `flip_cls` is a torchvision random-flip transform class. It is instantiated with
    p=1.0, i.e. the flip is applied to every sample.
    """
    return torchvision.transforms.Compose([
        NumpyReshapeTransform((200, 200)),
        Numpy2PILTransform(),
        torchvision.transforms.Resize((128, 128)),
        flip_cls(p=1.0),
        torchvision.transforms.ToTensor()
    ])


hflip_transform = _make_flip_transform(torchvision.transforms.RandomHorizontalFlip)
vflip_transform = _make_flip_transform(torchvision.transforms.RandomVerticalFlip)

# Flipped versions of the BraTS T2 validation set; identical settings apart from the transform
flip_loader_kwargs = dict(batch_size=batch_size, val_set_path=brats_t2_path, shuffle_val=False, num_workers=12)
_, brats_val_t2_hflip_dataloader = dataloader_factory(DatasetType.BRATS17, transform=hflip_transform, **flip_loader_kwargs)
_, brats_val_t2_vflip_dataloader = dataloader_factory(DatasetType.BRATS17, transform=vflip_transform, **flip_loader_kwargs)

# CamCAN: both the (shuffled) training loader and the (unshuffled) validation loader are kept
camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(
    DatasetType.CAMCAN,
    batch_size=batch_size,
    val_set_path=camcan_t2_val_path,
    train_set_path=camcan_t2_train_path,
    shuffle_val=False,
    shuffle_train=True,
    num_workers=12,
)

# Gaussian noise dataset wrapped in a plain torch DataLoader
noise_set = GaussianNoiseDataset()
noise_loader = DataLoader(noise_set, batch_size=batch_size)

# MNIST images resized to 128x128 before conversion to tensors
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.ToTensor()
])
_, mnist_val_dataloader = dataloader_factory(DatasetType.MNIST, batch_size=batch_size, transform=mnist_transform)

# Datasets used by the OOD evaluation; uncomment an entry to include it
dataloader_dict = {
    # 'BraTS T2 val': brats_val_t2_dataloader,
    # 'BraTS T1 val': brats_val_t1_dataloader,
    # 'BraTS T2 HM val': brats_val_t2_hm_dataloader,
    # 'BraTS T1 HM val': brats_val_t1_hm_dataloader,
    # 'CamCAN train': camcan_train_dataloader,
    # 'Gaussian noise': noise_loader,
    # 'MNIST': mnist_val_dataloader,
    'BraTS T2 HFlip': brats_val_t2_hflip_dataloader,
    'BraTS T2 VFlip': brats_val_t2_vflip_dataloader,
}

# Subset of the selected loaders whose name mentions BraTS
brats_dataloader_dict = {
    loader_name: loader
    for loader_name, loader in dataloader_dict.items()
    if 'BraTS' in loader_name
}

# Summary per selected loader. The sample count is batches * batch_size, which is an
# upper bound whenever the final batch is not full.
for name, dataloader in dataloader_dict.items():
    print(f'{name:15} dataloader: {len(dataloader)} batches (batch_size: {dataloader.batch_size}) -> {len(dataloader) * dataloader.batch_size} samples.')

# OOD Detection Evaluation for different OOD datasets

In [None]:
# Run the OOD detection evaluation once per selected dataset and remember the run number
# of every evaluation so that all results can be printed together at the end.
run_numbers = []
for name, dataloader in dataloader_dict.items():
    LOG.info(f'OOD evaluation for {name}...')
    # Fresh configuration per dataset: plots enabled, limited to USE_N_BATCHES batches
    eval_cfg = EvaluationConfig()
    eval_cfg.do_plots = True
    eval_cfg.use_n_batches = USE_N_BATCHES
    # Result container rooted at DATA_DIR_PATH / 'evaluation' with empty sub-results
    results = EvaluationResult(DATA_DIR_PATH / 'evaluation', eval_cfg, PixelAnomalyDetectionResult(), SliceAnomalyDetectionResults(), OODDetectionResults())
    # Hard-coded pixel-wise threshold, set before the evaluation is run
    results.pixel_anomaly_result.best_threshold = 1.35
    # NOTE(review): presumably creates the output directories for this run and determines
    # `run_number` - confirm in EvaluationResult.
    results.make_dirs()
    run_numbers.append(results.run_number)
    # The CamCAN training loader is always passed as the first loader, the dataset under
    # test as the second one.
    run_ood_detection_performance(ensemble_models, camcan_train_dataloader, dataloader, eval_cfg, results)
    # The test set name is attached after the run and before the results are serialised
    results.test_set_name = name
    results.to_json()
print_results_from_evaluation_dirs(DATA_DIR_PATH / 'evaluation', run_numbers, print_results_only=True)

# WAIC Score Histograms

In [None]:
NUM_BACTHES = 10

# Datasets for which WAIC scores are computed; uncomment an entry to include it
dataloader_dict = {
    'BraTS T2 val': brats_val_t2_dataloader,
    # 'BraTS T1 val': brats_val_t1_dataloader,
    'BraTS T2 HM val': brats_val_t2_hm_dataloader,
    # 'BraTS T1 HM val': brats_val_t1_hm_dataloader,
    'CamCAN train': camcan_train_dataloader,
    # 'Gaussian noise': noise_loader,
    # 'MNIST': mnist_val_dataloader,
}

# dataset name -> {'all', 'healthy', 'lesional', 'healthy_scans', 'lesional_scans'}
waic_dict = {}
for name, data_loader in dataloader_dict.items():
    LOG.info(f'WAIC score calculation for {name} ({NUM_BACTHES * data_loader.batch_size} slices)...')
    slice_wise_waic_scores, slice_wise_is_lesional, scans = sample_wise_waic_scores(
        models=ensemble_models,
        data_loader=data_loader,
        max_n_batches=NUM_BACTHES,
        return_slices=True,
    )

    # Split scores (and slices, when they were returned) by the per-slice lesional flag
    healthy_waics, lesional_waics = [], []
    healthy_scans, lesional_scans = [], []
    for idx in range(len(slice_wise_waic_scores)):
        if slice_wise_is_lesional[idx]:
            waic_bucket, scan_bucket = lesional_waics, lesional_scans
        else:
            waic_bucket, scan_bucket = healthy_waics, healthy_scans
        waic_bucket.append(slice_wise_waic_scores[idx])
        if scans is not None:
            scan_bucket.append(scans[idx])

    waic_dict[name] = {
        'all': slice_wise_waic_scores,
        'healthy': healthy_waics,
        'lesional': lesional_waics,
        'healthy_scans': healthy_scans,
        'lesional_scans': lesional_scans,
    }

In [None]:
# Plot the WAIC scores, restricted via the filters to the 'healthy' entries of 'BraTS T2 val'
plot_ood_scores(waic_dict, score_label='WAIC', dataset_name_filters=['BraTS T2 val'], modes_to_include=['healthy'])

In [None]:
# Display the dataset names for which WAIC scores have been collected
waic_dict.keys()

In [None]:
# Name of one of the datasets stored in `waic_dict`.
# NOTE(review): not referenced by any of the visible cells - the plotting calls pass the
# dataset names as literals; confirm before relying on or removing it.
DATASET = 'BraTS T2 val'


def plot_most_least_ood(waic_dict: dict, dataset_name: str, n_most: int = 16, do_lesional: bool = True) -> None:
    """For healthy and lesional samples, plot the ones which are most and least OOD.

    Arguments
    ---------
        waic_dict: maps a dataset name to a dict holding the score sequences under
                   'healthy' / 'lesional' and the matching slices under
                   'healthy_scans' / 'lesional_scans'
        dataset_name: key into `waic_dict` selecting the dataset to visualise
        n_most: number of slices requested for each grid
        do_lesional: also plot the lesional samples; healthy samples are always plotted

    A category without scores or without stored slices is skipped with a printed notice,
    because `torchvision.utils.make_grid` cannot build a grid from an empty list.
    """
    ood_dict = waic_dict[dataset_name]

    def create_ood_grids(category: str):
        """Return (most OOD grid, least OOD grid) for `category`, or None if it is empty."""
        scores = ood_dict[category]
        slices = ood_dict[f'{category}_scans']
        if len(scores) == 0 or len(slices) == 0:
            print(f'No {category} samples with stored slices for {dataset_name}, skipping.')
            return None
        print(f'Creating {category} grids...')
        largest_slices = [slices[idx] for idx in get_indices_of_n_largest_items(scores, n_most)]
        smallest_slices = [slices[idx] for idx in get_indices_of_n_smallest_items(scores, n_most)]
        largest_grid = torchvision.utils.make_grid(largest_slices, padding=0, normalize=False)
        smallest_grid = torchvision.utils.make_grid(smallest_slices, padding=0, normalize=False)
        return largest_grid, smallest_grid

    categories = ['healthy', 'lesional'] if do_lesional else ['healthy']

    # Build all grids first and show them afterwards, so that the progress messages are
    # printed before the first figure appears.
    grids_per_category = [(category, create_ood_grids(category)) for category in categories]

    for category, grids in grids_per_category:
        if grids is None:
            continue
        most_ood_grid, least_ood_grid = grids
        label = category.capitalize()
        imshow_grid(most_ood_grid, one_channel=True, figsize=(12, 8), title=f'Most OOD {label} {dataset_name}', axis='off')
        imshow_grid(least_ood_grid, one_channel=True, figsize=(12, 8), title=f'Least OOD {label} {dataset_name}', axis='off')

# Most / least OOD grids per dataset; CamCAN is plotted without lesional grids
for dataset_name, include_lesional in (
    ('BraTS T2 val', True),
    ('BraTS T2 HM val', True),
    ('CamCAN train', False),
):
    plot_most_least_ood(waic_dict, dataset_name, do_lesional=include_lesional)

In [None]:
def plot_samples_close_to_score(ood_dict: dict, dataset_name: str, min_score: float, max_score: float, n: int = 32, do_lesional: bool = True) -> None:
    """Plot slices whose OOD scores lie closest to evenly spaced reference scores.

    `n` reference scores are spread linearly over [min_score, max_score]. For every
    reference score the slice picked by `get_idx_of_closest_value` is added to the grid,
    so the same slice can appear more than once.

    Arguments
    ---------
        ood_dict: maps a dataset name to a dict holding the score sequences under
                  'healthy' / 'lesional' and the matching slices under
                  'healthy_scans' / 'lesional_scans'
        dataset_name: key into `ood_dict` selecting the dataset to visualise
        min_score: smallest reference score
        max_score: largest reference score
        n: number of reference scores, i.e. number of slices per grid
        do_lesional: also plot the lesional samples; healthy samples are always plotted

    A category without scores or without stored slices is skipped with a printed notice,
    because `torchvision.utils.make_grid` cannot build a grid from an empty list.
    """
    # Separate name on purpose: the `ood_dict` argument itself is left untouched
    dataset_scores = ood_dict[dataset_name]
    ref_scores = np.linspace(min_score, max_score, n)

    def create_ood_grid(category: str):
        """Return the grid for `category`, or None if it has no samples."""
        scores = dataset_scores[category]
        slices = dataset_scores[f'{category}_scans']
        if len(scores) == 0 or len(slices) == 0:
            print(f'No {category} samples with stored slices for {dataset_name}, skipping.')
            return None
        closest_slices = [slices[get_idx_of_closest_value(scores, ref_score)] for ref_score in ref_scores]
        return torchvision.utils.make_grid(closest_slices, padding=0, normalize=False)

    categories = ['healthy', 'lesional'] if do_lesional else ['healthy']

    # Build all grids before showing any of them, as in the original plot order
    grid_per_category = [(category, create_ood_grid(category)) for category in categories]

    for category, grid in grid_per_category:
        if grid is None:
            continue
        imshow_grid(grid, one_channel=True, figsize=(12, 8), title=f'{category.capitalize()} {dataset_name} {min_score}-{max_score}', axis='off')

# Score window inspected per dataset; CamCAN is plotted without a lesional grid
for dataset_name, call_kwargs in (
    ('BraTS T2 val', dict(min_score=0.95, max_score=1.2)),
    ('BraTS T2 HM val', dict(min_score=0.85, max_score=1.2)),
    ('CamCAN train', dict(do_lesional=False, min_score=0.8, max_score=0.9)),
):
    plot_samples_close_to_score(waic_dict, dataset_name, **call_kwargs)