In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Configure logging through the project's helper and create this notebook's logger.
setup_logging()
LOG = logging.getLogger(__name__)

# matplotlib and numba log very verbosely at low levels, so raise their
# thresholds to keep the notebook output readable.
mpl_logger, numba_logger = (logging.getLogger(pkg) for pkg in ('matplotlib', 'numba'))
mpl_logger.setLevel(logging.ERROR)
numba_logger.setLevel(logging.WARNING)

In [None]:
from pathlib import Path
import operator

import torch
import torchvision
from tqdm import tqdm
try:
    # Clear tqdm's internal instance registry (presumably to get rid of stale
    # progress bars left over from earlier cell runs).
    tqdm._instances.clear()
except Exception:
    # Best-effort cleanup: ignore failures, but unlike a bare `except:` do not
    # swallow KeyboardInterrupt / SystemExit.
    pass
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from uncertify.data.dataloaders import dataloader_factory, DatasetType
from uncertify.visualization.reconstruction import plot_stacked_scan_reconstruction_batches
from uncertify.evaluation.inference import yield_inference_batches, yield_anomaly_predictions
from uncertify.evaluation.utils import residual_l1, residual_l1_max
from uncertify.visualization.plotting import save_fig
from uncertify.data.datasets import GaussianNoiseDataset
from uncertify.data.default_dataloaders import default_dataloader_dict_factory, filter_dataloader_dict
from uncertify.io.models import load_ensemble_models, load_vae_baur_model
from uncertify.evaluation.evaluation_pipeline import run_anomaly_detection_performance

from uncertify.common import DATA_DIR_PATH, HD_DATA_PATH, HD_MODELS_PATH

# Load Model and Data

In [None]:
# Load the ensemble members: one checkpoint file per run version, all located in
# the 'scheduled_masked_ensembles' directory below HD_MODELS_PATH.
RUN_VERSIONS = [0, 1, 2, 3, 4]
ensemble_models = load_ensemble_models(
    HD_MODELS_PATH / 'scheduled_masked_ensembles',
    [f'model_{run_version}.ckpt' for run_version in RUN_VERSIONS],
)

In [None]:
# NOTE(review): the code below is disabled by being wrapped in a string literal, so
# none of it is executed. It is a keyword-argument variant of loading the ensemble
# checkpoints (model_0.ckpt ... model_4.ckpt) that would bind the result to
# `ensembles`. Because the string is the cell's last expression, Jupyter echoes it
# as the cell output. Consider deleting this cell if it is no longer needed.
"""
model_dir_path = HD_MODELS_PATH / 'scheduled_masked_ensembles'
file_names = [f'model_{idx}.ckpt' for idx in [0, 1, 2, 3, 4]]

ensembles = load_ensemble_models(dir_path=model_dir_path, file_names=file_names)
"""

In [None]:
# Load the single VAE model used by the evaluation cells below.
#
# The checkpoint is referenced at two machine-specific mount points. Previously
# both were loaded back to back (the first result was discarded immediately),
# which wasted a load and failed unless both paths existed. Now the first
# existing candidate is loaded; the SSD path comes first because it is the one
# whose model was kept before.
# NOTE(review): assumes both locations hold the same checkpoint -- confirm.
MODEL_CHECKPOINT_CANDIDATES = [
    Path('/mnt/1TB_SSD/lightning_logs/beta_test/version_2/checkpoints/last.ckpt'),
    Path('/mnt/2TB_internal_HD/lightning_logs/beta_test/version_2/checkpoints/last.ckpt'),
]
model_checkpoint_path = next(
    (candidate for candidate in MODEL_CHECKPOINT_CANDIDATES if candidate.exists()),
    None,
)
if model_checkpoint_path is None:
    checked_paths = ', '.join(str(candidate) for candidate in MODEL_CHECKPOINT_CANDIDATES)
    raise FileNotFoundError(f'No model checkpoint found. Checked: {checked_paths}')
model = load_vae_baur_model(model_checkpoint_path)

In [None]:
# Build the default dataloaders; the result is indexed by names such as
# 'CamCAN train' or 'BraTS T2 val' in the cells below.
dataloader_dict = default_dataloader_dict_factory(
    batch_size=128,
    num_workers=12,
    shuffle_val=False,
)

# Plot Inference Reconstruction

In [None]:
# Number of batches to plot per dataloader.
plot_n_batches = 1

# Hand-picked subset of dataloaders to visualise. Alternative via name filtering:
# plot_dataloader_dict = filter_dataloader_dict(dataloader_dict, contains=['BraTS'], exclude=[])
plot_dataloader_dict = {
    loader_name: dataloader_dict[loader_name]
    for loader_name in ['CamCAN train', 'BraTS T2 val']
}

for dataloader_name, dataloader in plot_dataloader_dict.items():
    print(f'Loader {dataloader_name}, Dataset: {dataloader.dataset.name}')
    # Run inference with the max-based L1 residual and a fixed residual threshold.
    batch_generator = yield_inference_batches(
        dataloader,
        model,
        residual_fn=residual_l1_max,
        residual_threshold=0.60,
        manual_seed_val=None,
    )
    # save_dir_path=None is passed, i.e. no output directory is given here.
    plot_stacked_scan_reconstruction_batches(
        batch_generator,
        plot_n_batches,
        nrow=32,
        cmap='gray',
        axis='off',
        figsize=(15, 15),
        mask_background=False,
        save_dir_path=None,
    )


# Pixel-Wise Anomaly Detection Performance (ROC & PRC)

In [None]:
from uncertify.evaluation.configs import EvaluationConfig, EvaluationResult
from uncertify.evaluation.evaluation_pipeline import OUT_DIR_PATH, PixelAnomalyDetectionResult, SliceAnomalyDetectionResults, OODDetectionResults, print_results

In [None]:
# Evaluation configuration, limited to 20 batches.
eval_cfg = EvaluationConfig()
eval_cfg.use_n_batches = 20

eval_dataloader = dataloader_dict['BraTS T2 HM val']

# Fresh result container with empty sub-results; make_dirs() is called on it
# before the evaluation runs.
results = EvaluationResult(
    OUT_DIR_PATH,
    eval_cfg,
    PixelAnomalyDetectionResult(),
    SliceAnomalyDetectionResults(),
    OODDetectionResults(),
)
results.make_dirs()
# Preset the pixel-wise residual threshold on the result object before the run.
results.pixel_anomaly_result.best_threshold = 0.65

results = run_anomaly_detection_performance(eval_cfg, model, eval_dataloader, results)
print_results(results)

# Segmentation Scores

In [None]:
from uncertify.evaluation.model_performance import mean_std_dice_scores, mean_std_iou_scores
from uncertify.visualization.model_performance import plot_segmentation_performance_vs_threshold
try:
    # Clear tqdm's internal instance registry (presumably to get rid of stale
    # progress bars left over from earlier cell runs).
    tqdm._instances.clear()
except Exception:
    # Best-effort cleanup: ignore failures, but unlike a bare `except:` do not
    # swallow KeyboardInterrupt / SystemExit.
    pass

## Only run with one pre-defined threshold

In [None]:
# Dice score for a single, pre-defined residual threshold.
max_n_batches = 20
residual_threshold = 0.67

eval_dataloader = dataloader_dict['BraTS T2 val']
# Only one threshold is passed, so the single result is read at index 0 below.
best_mean_dice_score, best_std_dice_score = mean_std_dice_scores(
    eval_dataloader,
    model,
    [residual_threshold],
    max_n_batches,
)
LOG.info(
    f'Dice score (t={residual_threshold:.2f}) for {eval_dataloader.dataset.name}: '
    f'{best_mean_dice_score[0]:.2f} +- {best_std_dice_score[0]:.2f}'
)

## Check over multiple thresholds

In [None]:
# Sweep several residual thresholds and report the one with the best mean Dice score.
n_thresholds = 8
max_n_batches = 10

# Evenly spaced candidate thresholds in [0.4, 1.4].
pixel_thresholds = np.linspace(0.4, 1.4, n_thresholds)
eval_dataloader = dataloader_dict['BraTS T2 HM val']
mean_dice_scores, std_dice_scores = mean_std_dice_scores(
    eval_dataloader,
    model,
    residual_thresholds=pixel_thresholds,
    max_n_batches=max_n_batches,
)
# Index and value of the highest mean Dice score.
best_dice_idx, best_dice_score = max(enumerate(mean_dice_scores), key=operator.itemgetter(1))
# Format the std and threshold as well; the threshold would otherwise print as
# a raw linspace value with many decimals.
print(f'Best dice score: {best_dice_score:.2f}+-{std_dice_scores[best_dice_idx]:.2f} '
      f'with threshold {pixel_thresholds[best_dice_idx]:.3f}.')

In [None]:
# Plot the Dice scores over the swept thresholds; no IoU scores and no
# train-set threshold are passed.
fig = plot_segmentation_performance_vs_threshold(
    pixel_thresholds,
    dice_scores=mean_dice_scores,
    dice_stds=std_dice_scores,
    iou_scores=None,
    train_set_threshold=None,
    figsize=(12, 6),
)
# Create the output directory first: savefig fails if it does not exist.
plots_dir_path = DATA_DIR_PATH / 'plots'
plots_dir_path.mkdir(parents=True, exist_ok=True)
fig.savefig(plots_dir_path / 'dice_iou_vs_threshold.png')

# Sample-wise Loss Term Histograms

In [None]:
from sklearn.neighbors import KernelDensity

from uncertify.visualization.histograms import plot_loss_histograms
try:
    # Clear tqdm's internal instance registry (presumably to get rid of stale
    # progress bars left over from earlier cell runs).
    tqdm._instances.clear()
except Exception:
    # Best-effort cleanup: ignore failures, but unlike a bare `except:` do not
    # swallow KeyboardInterrupt / SystemExit.
    pass

In [None]:
# Create one inference-batch generator per selected dataloader for the loss
# histograms below.
max_n_batches = 10
# Previously misspelled as `redisual_threshold`; renamed to match the spelling
# used by the other cells of this notebook.
residual_threshold = 0.67

select_dataloaders = ['CamCAN train', 'BraTS T2 val']

output_generators = []
for dataloader_name in select_dataloaders:
    dataloader = dataloader_dict[dataloader_name]
    # NOTE(review): max_n_batches and residual_threshold are passed positionally;
    # presumably they are the 3rd and 4th parameters of yield_inference_batches --
    # confirm against its signature.
    output_generators.append(
        yield_inference_batches(
            dataloader,
            model,
            max_n_batches,
            residual_threshold,
            progress_bar_suffix=dataloader_name,
            manual_seed_val=10,
        )
    )


In [None]:
# Plot the loss-term distributions of the selected dataloaders as density
# curves (histograms and data ticks are switched off).
figs_axes = plot_loss_histograms(
    output_generators=output_generators,
    names=select_dataloaders,
    figsize=(17, 3.5),
    ylabel='Frequency',
    plot_density=True,
    show_data_ticks=False,
    kde_bandwidth=[0.009, 0.009*5.5],
    show_histograms=False,
)

# Save every returned figure under an index-numbered file name.
for idx, (fig, _) in enumerate(figs_axes):
    save_fig(
        fig,
        DATA_DIR_PATH / 'plots' / f'loss_term_distributions_{idx}.png',
    )

# Threshold calculation

In [None]:
from uncertify.visualization.threshold_search import plot_fpr_vs_residual_threshold
from uncertify.evaluation.evaluation_pipeline import run_residual_threshold_evaluation, EvaluationResult, PixelAnomalyDetectionResult, SliceAnomalyDetectionResults, OODDetectionResults
from uncertify.evaluation.configs import EvaluationConfig, PixelThresholdSearchConfig
from uncertify.evaluation.evaluation_pipeline import OUT_DIR_PATH
try:
    # Clear tqdm's internal instance registry (presumably to get rid of stale
    # progress bars left over from earlier cell runs).
    tqdm._instances.clear()
except Exception:
    # Best-effort cleanup: ignore failures, but unlike a bare `except:` do not
    # swallow KeyboardInterrupt / SystemExit.
    pass

In [None]:
# Residual threshold evaluation on the CamCAN training data, limited to 10 batches.
eval_cfg = EvaluationConfig()
eval_cfg.use_n_batches = 10
# Fresh result container with empty sub-results; make_dirs() is called on it
# before the evaluation runs.
results = EvaluationResult(
    OUT_DIR_PATH,
    eval_cfg,
    PixelAnomalyDetectionResult(),
    SliceAnomalyDetectionResults(),
    OODDetectionResults(),
)
results.make_dirs()

results = run_residual_threshold_evaluation(
    model,
    dataloader_dict['CamCAN train'],
    eval_cfg,
    results,
)

# Plot MNIST reconstructions
Run MNIST examples through the model, one batch per digit (each batch contains only samples of a single digit, 0-9), and plot the inputs together with their reconstructions.

In [None]:
# Plot reconstructions for MNIST samples, one digit label at a time.
plot_n_batches = 1
batch_size = 8

# The transform does not depend on the digit, so build it once instead of
# once per loop iteration: resize to 128x128, then convert to a tensor.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.ToTensor(),
])

for n in range(10):
    # Only the second element returned by the factory is used here.
    _, mnist_val_dataloader = dataloader_factory(
        DatasetType.MNIST,
        batch_size=batch_size,
        transform=mnist_transform,
        mnist_label=n,
    )
    batch_generator = yield_inference_batches(mnist_val_dataloader, model, residual_threshold=1.8)
    # NOTE(review): every digit is given the same save_dir_path -- confirm the
    # plotting helper does not overwrite files from the previous digit.
    plot_stacked_scan_reconstruction_batches(
        batch_generator,
        plot_n_batches,
        cmap='hot',
        axis='off',
        figsize=(15, 15),
        save_dir_path=DATA_DIR_PATH/'reconstructions',
    )