In [None]:
%load_ext autoreload
%autoreload 2

from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Run uncertify's logging setup and create a logger for this notebook.
setup_logging()
LOG = logging.getLogger(__name__)

# Raise the matplotlib and numba loggers to WARNING so that their DEBUG
# output does not flood the notebook.
mpl_logger = logging.getLogger('matplotlib')
numba_logger = logging.getLogger('numba')
for noisy_logger in (mpl_logger, numba_logger):
    noisy_logger.setLevel(logging.WARNING)

In [None]:
from pathlib import Path
import operator

import torch
import torchvision
from tqdm import tqdm
# Best-effort reset of progress bars left over from an earlier (interrupted) run.
# `_instances` is a private tqdm attribute and may be absent, so only that
# specific failure is ignored (a bare `except:` would also hide KeyboardInterrupt).
try:
    tqdm._instances.clear()
except AttributeError:
    pass
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from uncertify.data.dataloaders import dataloader_factory, DatasetType
from uncertify.visualization.reconstruction import plot_stacked_scan_reconstruction_batches
from uncertify.evaluation.inference import yield_inference_batches, yield_anomaly_predictions
from uncertify.evaluation.utils import residual_l1, residual_l1_max
from uncertify.visualization.plotting import save_fig
from uncertify.data.datasets import GaussianNoiseDataset
from uncertify.io.models import load_ensemble_models
from uncertify.models.vae import load_vae_baur_model
from uncertify.evaluation.evaluation_pipeline import run_anomaly_detection_performance

from uncertify.common import DATA_DIR_PATH, HD_DATA_PATH

# Load Model and Data

In [None]:
# Load the ensemble checkpoints and use the first member as the default `model`.
# MASKED_TRAINING_MODELS selects which checkpoint directory below DATA_DIR_PATH is read.
MASKED_TRAINING_MODELS = True
model_dir = 'masked_ensemble_models' if MASKED_TRAINING_MODELS else 'ensemble_models'
# Both ensembles use the same run versions (checkpoint files are named model<idx>.ckpt).
RUN_VERSIONS = [1, 2, 3, 4, 5]
# Use the selected directory instead of a hard-coded one so that the flag actually takes effect.
ensemble_models = load_ensemble_models(DATA_DIR_PATH / model_dir, [f'model{idx}.ckpt' for idx in RUN_VERSIONS])
model = ensemble_models[0]

In [None]:
# Bind `model` to the VAE restored from the last checkpoint of the 'schedule_test' run (version 4).
checkpoint_path = HD_DATA_PATH.parent / 'lightning_logs/schedule_test/version_4/checkpoints/last.ckpt'
model = load_vae_baur_model(checkpoint_path)

In [None]:
# Loader configuration shared by all dataloaders created in this cell.
BATCH_SIZE = 8
num_workers = 0
SHUFFLE_VAL = True
SHUFFLE_TRAIN = True

EVAL_DIR_PATH = DATA_DIR_PATH / 'evaluation'
PROCESSED_DIR_PATH = HD_DATA_PATH / 'processed' 

# Locations of the pre-processed HDF5 datasets.
brats_t2_path    = DATA_DIR_PATH  / 'processed/brats17_t2_bc_std_bv3.5.hdf5'
brats_t2_hm_path = PROCESSED_DIR_PATH / 'brats17_t2_hm_bc_std_bv3.5.hdf5'
brats_t1_path    = DATA_DIR_PATH  / 'processed/brats17_t1_bc_std_bv3.5.hdf5'
# NOTE(review): 'bv-3.5' differs from the 'bv3.5' suffix of every other file here - confirm the file name.
brats_t1_hm_path = PROCESSED_DIR_PATH / 'brats17_t1_hm_bc_std_bv-3.5.hdf5'
camcan_t2_val_path   = DATA_DIR_PATH  / 'processed/camcan_val_t2_hm_std_bv3.5_xe.hdf5'
camcan_t2_train_path = DATA_DIR_PATH  / 'processed/camcan_train_t2_hm_std_bv3.5_xe.hdf5'

# BraTS validation loaders (the first element returned by the factory is not needed here).
_, brats_val_t2_dataloader    = dataloader_factory(DatasetType.BRATS17, batch_size=BATCH_SIZE, val_set_path=brats_t2_path, shuffle_val=SHUFFLE_VAL, num_workers=num_workers)
_, brats_val_t1_dataloader    = dataloader_factory(DatasetType.BRATS17, batch_size=BATCH_SIZE, val_set_path=brats_t1_path, shuffle_val=SHUFFLE_VAL, num_workers=num_workers)
_, brats_val_t2_hm_dataloader = dataloader_factory(DatasetType.BRATS17, batch_size=BATCH_SIZE, val_set_path=brats_t2_hm_path, shuffle_val=SHUFFLE_VAL, num_workers=num_workers)
_, brats_val_t1_hm_dataloader = dataloader_factory(DatasetType.BRATS17, batch_size=BATCH_SIZE, val_set_path=brats_t1_hm_path, shuffle_val=SHUFFLE_VAL, num_workers=num_workers)

# CamCAN train / validation loaders.
camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(DatasetType.CAMCAN, batch_size=BATCH_SIZE, 
                                                                    val_set_path=camcan_t2_val_path, train_set_path=camcan_t2_train_path, 
                                                                    shuffle_val=SHUFFLE_VAL, shuffle_train=SHUFFLE_TRAIN, num_workers=num_workers)
# Same CamCAN files, requested with add_gauss_blobs=True. The train shuffle flag now comes from
# SHUFFLE_TRAIN (not the validation flag) and num_workers is passed like for every other factory call.
camcan_lesional_train_dataloader, camcan_lesional_val_dataloader = dataloader_factory(DatasetType.CAMCAN, batch_size=BATCH_SIZE, val_set_path=camcan_t2_val_path, 
                                                                                      train_set_path=camcan_t2_train_path, shuffle_val=SHUFFLE_VAL, shuffle_train=SHUFFLE_TRAIN, 
                                                                                      add_gauss_blobs=True, num_workers=num_workers)

# Additional loaders: Gaussian noise images and MNIST resized to 128x128.
noise_set = GaussianNoiseDataset()
noise_loader = DataLoader(noise_set, batch_size=BATCH_SIZE)

_, mnist_val_dataloader = dataloader_factory(DatasetType.MNIST, batch_size=BATCH_SIZE, transform=torchvision.transforms.Compose([
                                                                        torchvision.transforms.Resize((128, 128)),
                                                                        torchvision.transforms.ToTensor()
                                                                    ])
                         )

# Print a short size summary for each loader.
for name, dataloader in [('BraTS T2 val', brats_val_t2_dataloader), 
                         ('BraTS T1 val', brats_val_t1_dataloader), 
                         ('BraTS T2 HM val', brats_val_t2_hm_dataloader), 
                         ('BraTS T1 HM val', brats_val_t1_hm_dataloader),
                         ('CamCAN train', camcan_train_dataloader),
                         ('Gaussian noise', noise_loader),
                         ('MNIST', mnist_val_dataloader)
                        ]: 
    # n_batches * batch_size overestimates whenever the last batch is only partially filled,
    # so report the dataset length instead (assumes the loader iterates its whole dataset - TODO confirm).
    n_samples = len(dataloader.dataset)
    print(f'{name:15} dataloader: {len(dataloader)} batches (batch_size: {dataloader.batch_size}) -> {n_samples} samples.')

# Plot Inference Reconstruction

In [None]:
# Number of batches to plot per dataset and the directory handed to the plotting helper.
plot_n_batches = 5
reconstruction_dir = DATA_DIR_PATH/'reconstructions'

# Loaders to visualise. Other loaders that can be added here: brats_val_t2_dataloader,
# brats_val_t1_hm_dataloader, brats_val_t1_dataloader, noise_loader, mnist_val_dataloader.
dataloaders = [brats_val_t2_hm_dataloader, camcan_train_dataloader]

# Keyword arguments shared by every plotting call below.
plot_kwargs = dict(nrow=8, cmap='gray', axis='off', figsize=(15, 15), save_dir_path=reconstruction_dir)

for dataloader in dataloaders:
    print(f'Dataset: {dataloader.dataset.name}')
    batch_generator = yield_inference_batches(
        dataloader,
        model,
        residual_fn=residual_l1,
        residual_threshold=0.90,
    )
    plot_stacked_scan_reconstruction_batches(batch_generator, plot_n_batches, **plot_kwargs)


# Pixel-Wise Anomaly Detection Performance (ROC & PRC)

In [None]:
from uncertify.evaluation.configs import EvaluationConfig, EvaluationResult
from uncertify.evaluation.evaluation_pipeline import OUT_DIR_PATH, PixelAnomalyDetectionResult, SliceAnomalyDetectionResults, OODDetectionResults, print_results

In [None]:
# Evaluation configuration: only the first 20 batches are used.
eval_cfg = EvaluationConfig()
eval_cfg.use_n_batches = 20

# Fresh result container with empty sub-results; its output directories are created up front.
results = EvaluationResult(
    OUT_DIR_PATH,
    eval_cfg,
    PixelAnomalyDetectionResult(),
    SliceAnomalyDetectionResults(),
    OODDetectionResults(),
)
results.make_dirs()
# Set the pixel-wise threshold manually before running the evaluation.
results.pixel_anomaly_result.best_threshold = 0.95

# Run the anomaly detection evaluation on the histogram-matched BraTS T2 validation loader.
results = run_anomaly_detection_performance(
    eval_cfg,
    model,
    brats_val_t2_hm_dataloader,
    results,
)
print_results(results)

## Segmentation Scores

In [None]:
from uncertify.evaluation.model_performance import mean_std_dice_scores, mean_std_iou_scores
from uncertify.visualization.model_performance import plot_segmentation_performance_vs_threshold
# Best-effort reset of progress bars left over from an earlier (interrupted) run.
# `_instances` is a private tqdm attribute and may be absent, so only that
# specific failure is ignored (a bare `except:` would also hide KeyboardInterrupt).
try:
    tqdm._instances.clear()
except AttributeError:
    pass

In [None]:
# Summary of the best dice score found by the threshold sweep.
# The names used here are only defined by the sweep cell that follows, so on a fresh
# kernel (Restart & Run All) this cell used to abort with a NameError. Report that
# situation instead of raising.
try:
    print(f'Best dice score: {best_dice_score:.2f}+-{std_dice_scores[best_dice_idx]} with threshold {pixel_thresholds[best_dice_idx]}.')
except NameError:
    print('Dice scores are not available yet - run the threshold sweep cell first.')

In [None]:
# Sweep settings: number of residual thresholds to try and number of batches evaluated per threshold.
n_thresholds = 10
max_n_batches = 10

# Evenly spaced candidate thresholds between 0 and 3 (both ends included).
pixel_thresholds = np.linspace(start=0.0, stop=3.0, num=n_thresholds)
mean_dice_scores, std_dice_scores = mean_std_dice_scores(
    brats_val_t2_hm_dataloader,
    model,
    residual_thresholds=pixel_thresholds,
    max_n_batches=max_n_batches,
)

# Pick the (index, score) pair with the highest mean dice score.
best_dice_idx, best_dice_score = max(enumerate(mean_dice_scores), key=lambda idx_and_score: idx_and_score[1])
print(f'Best dice score: {best_dice_score:.2f}+-{std_dice_scores[best_dice_idx]} with threshold {pixel_thresholds[best_dice_idx]}.')

In [None]:
# Plot the mean dice score (with its standard deviation) against the residual threshold.
fig = plot_segmentation_performance_vs_threshold(pixel_thresholds, dice_scores=mean_dice_scores, dice_stds=std_dice_scores, iou_scores=None, 
                                                    train_set_threshold=None, figsize=(12, 6))
# savefig does not create missing directories, so make sure the target folder exists first.
plots_dir = DATA_DIR_PATH / 'plots'
plots_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(plots_dir / 'dice_iou_vs_threshold.png')

# Sample-wise Loss Term Histograms

In [None]:
from sklearn.neighbors import KernelDensity

from uncertify.visualization.histograms import plot_loss_histograms
# Best-effort reset of progress bars left over from an earlier (interrupted) run.
# `_instances` is a private tqdm attribute and may be absent, so only that
# specific failure is ignored (a bare `except:` would also hide KeyboardInterrupt).
try:
    tqdm._instances.clear()
except AttributeError:
    pass

In [None]:
# Number of batches passed to each inference generator.
max_n_batches = 1

# Display name paired with its dataloader, so the two lists below cannot get out of sync.
named_dataloaders = [
    ('CamCAN T2', camcan_train_dataloader),
    ('BraTS17 T2', brats_val_t2_dataloader),
    ('BraTS17 T1', brats_val_t1_dataloader),
    ('MNIST', mnist_val_dataloader),
    ('Gaussian Noise', noise_loader),
]
dataloaders = [loader for _label, loader in named_dataloaders]
generator_names = [label for label, _loader in named_dataloaders]

# One inference generator per dataset, labelled with the dataset name in its progress bar.
output_generators = [
    yield_inference_batches(loader, model, max_n_batches, progress_bar_suffix=f'{label}')
    for label, loader in named_dataloaders
]

In [None]:
# Two KDE bandwidths are passed; the second is 5.5 times the first.
base_bandwidth = 0.009
figs_axes = plot_loss_histograms(
    output_generators=output_generators,
    names=generator_names,
    figsize=(10, 3),
    ylabel='Frequency',
    plot_density=True,
    show_data_ticks=False,
    kde_bandwidth=[base_bandwidth, base_bandwidth*5.5],
    show_histograms=False,
)

# Save every returned figure under an index-based file name; the axes are not needed here.
plots_dir = DATA_DIR_PATH / 'plots'
for idx, (fig, _) in enumerate(figs_axes):
    target_path = plots_dir / f'loss_term_distributions_{idx}.png'
    save_fig(fig, target_path)

In [None]:
from uncertify.visualization.latent_space_analysis import plot_umap_latent_embedding

In [None]:
# Settings for the inference generators that feed the UMAP embedding.
max_n_batches = 10
residual_threshold = 1.8  # spelling fixed (was `redisual_threshold`)

dataloaders = [brats_val_t1_dataloader, 
               brats_val_t2_dataloader, 
               #mnist_val_dataloader,
               #noise_loader,
               camcan_train_dataloader,
]

# Must stay in the same order (and have the same length) as `dataloaders` above.
generator_names = ['BraTS17 T1',
                   'BraTS17 T2',
                   #'MNIST Val',
                   #'Gaussian Noise',
                   'CamCAN Train T2']

# One inference generator per dataset, labelled with the dataset name in its progress bar.
output_generators = []
for dataloader, name in zip(dataloaders, generator_names):
    output_generators.append(yield_inference_batches(dataloader, model, max_n_batches, residual_threshold, progress_bar_suffix=f'{name}'))

umap_fig = plot_umap_latent_embedding(output_generators, generator_names, figsize=(14, 10))
# savefig does not create missing directories, so make sure the target folder exists first.
plots_dir = DATA_DIR_PATH / 'plots'
plots_dir.mkdir(parents=True, exist_ok=True)
umap_fig.savefig(plots_dir / 'umap_latent_embedding.png')

# Threshold calculation

In [None]:
from uncertify.visualization.threshold_search import plot_fpr_vs_residual_threshold
from uncertify.evaluation.evaluation_pipeline import run_residual_threshold_evaluation, EvaluationResult, PixelAnomalyDetectionResult, SliceAnomalyDetectionResults, OODDetectionResults
from uncertify.evaluation.configs import EvaluationConfig, PixelThresholdSearchConfig
# Best-effort reset of progress bars left over from an earlier (interrupted) run.
# `_instances` is a private tqdm attribute and may be absent, so only that
# specific failure is ignored (a bare `except:` would also hide KeyboardInterrupt).
try:
    tqdm._instances.clear()
except AttributeError:
    pass

In [None]:
# Evaluation configuration for the threshold search: only the first 10 batches are used.
eval_cfg = EvaluationConfig()
eval_cfg.use_n_batches = 10

# Fresh result container rooted at EVAL_DIR_PATH; its output directories are created up front.
results = EvaluationResult(
    EVAL_DIR_PATH,
    eval_cfg,
    PixelAnomalyDetectionResult(),
    SliceAnomalyDetectionResults(),
    OODDetectionResults(),
)
results.make_dirs()

# Run the residual threshold evaluation on the CamCAN training loader.
results = run_residual_threshold_evaluation(
    model,
    camcan_train_dataloader,
    eval_cfg,
    results,
)

# Plot MNIST reconstructions
Run MNIST examples through the model, one digit at a time (each batch contains only samples of a single digit), and plot the inputs together with their reconstructions.

In [None]:
# Plot reconstructions for every MNIST digit separately (one batch per digit).
plot_n_batches = 1
batch_size = 8

# The preprocessing is identical for every digit, so build it once outside the loop.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.ToTensor()])

for digit in range(10):
    # Use a dedicated name for the per-digit loader so that the notebook-wide
    # `mnist_val_dataloader` is not silently replaced by a single-digit loader.
    _, mnist_digit_dataloader = dataloader_factory(DatasetType.MNIST, 
                                                   batch_size=batch_size, 
                                                   transform=mnist_transform,
                                                   mnist_label=digit)
    batch_generator = yield_inference_batches(mnist_digit_dataloader, model, residual_threshold=1.8)
    plot_stacked_scan_reconstruction_batches(batch_generator, plot_n_batches, 
                                             cmap='hot', axis='off', figsize=(15, 15), save_dir_path=DATA_DIR_PATH/'reconstructions')

# Plot latent space sample reconstructions from different locations in latent space

In [None]:
from uncertify.visualization.latent_space_analysis import plot_latent_samples_from_ring

# (inner_radius, outer_radius) pairs of the latent-space rings to sample from.
radii = [(0, 1), (2, 3), (4, 5), (7, 9), (10, 12), (15, 17), (20, 30), (50, 60), (200, 210)]

# Plot 16 latent samples per ring.
for inner_radius, outer_radius in radii:
    fig = plot_latent_samples_from_ring(
        model,
        n_samples=16,
        inner_radius=inner_radius,
        outer_radius=outer_radius,
    )