In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules are
# reloaded before each cell execution and edits to the project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# NOTE(review): `context` is presumably a local helper module that makes the
# `uncertify` package importable from the notebook directory - confirm.
from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Apply the project's logging configuration (defined in `uncertify.log`) before
# creating the logger used by this notebook.
setup_logging()
LOG = logging.getLogger(__name__)

# matplotlib and numba are extremely verbose at DEBUG level; raise both loggers to
# WARNING so their messages do not flood the notebook output.
mpl_logger = logging.getLogger('matplotlib')
numba_logger = logging.getLogger('numba')
for noisy_logger in (mpl_logger, numba_logger):
    noisy_logger.setLevel(logging.WARNING)

In [None]:
import itertools
from functools import partial
from pathlib import Path

import torch
import torchvision
from torchvision.transforms.transforms import Compose
from tqdm import tqdm
# Best-effort reset of tqdm's registry of live progress bars (presumably so that bars
# left over from an interrupted cell do not garble new ones). Only a missing
# `_instances` attribute is tolerated: a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and hide genuine errors.
try:
    tqdm._instances.clear()
except AttributeError:
    pass
import seaborn as sns
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
# Global figure defaults for this notebook: serif font family at size 18.
matplotlib.rcParams['font.family'] = 'serif'
matplotlib.rcParams['font.size'] = 18

from uncertify.models.vae import VariationalAutoEncoder
from uncertify.models.encoder_decoder_baur2020 import BaurDecoder, BaurEncoder
from uncertify.data.dataloaders import dataloader_factory, DatasetType
from uncertify.visualization.reconstruction import plot_stacked_scan_reconstruction_batches
from uncertify.deploy import yield_reconstructed_batches
from uncertify.visualization.grid import imshow_grid
from uncertify.visualization.plotting import setup_plt_figure
from uncertify.evaluation.thresholding import threshold_vs_fpr
from uncertify.algorithms.golden_section_search import golden_section_search
from uncertify.evaluation.thresholding import calculate_fpr_minus_accepted
from uncertify.data.datasets import GaussianNoiseDataset

from uncertify.common import DATA_DIR_PATH

# Load Model and Data

In [None]:
# Assemble the VAE from the encoder/decoder pair of `encoder_decoder_baur2020`,
# both built with their default constructor arguments.
encoder = BaurEncoder()
decoder = BaurDecoder()
model = VariationalAutoEncoder(encoder, decoder)

In [None]:
# Restore the trained weights from the last saved checkpoint.
CHECKPOINT_PATH = DATA_DIR_PATH / 'models/last.ckpt'
assert CHECKPOINT_PATH.exists(), f'Model checkpoint does not exist: {CHECKPOINT_PATH}'

# `map_location='cpu'` lets a checkpoint that was saved on a GPU machine load on a
# CPU-only one; `load_state_dict` copies the values into the model's existing
# parameters either way.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
# The weights are stored under the 'state_dict' key of the checkpoint.
# NOTE(review): the model is never explicitly switched to eval mode in this notebook -
# confirm that the evaluation helpers take care of it.
model.load_state_dict(checkpoint['state_dict'])

In [None]:
# Batch size shared by every dataloader created in this cell.
batch_size = 155

# NOTE(review): machine-specific absolute path which is not used anywhere in this cell.
HDD_PROCESSED_DIR_PATH = Path('/media/juniors/2TB_internal_HD/datasets/processed/')
SSD_PROCESSED_DIR_PATH = DATA_DIR_PATH / 'processed'

# BraTS17 validation sets (T2 and T1, shuffled); the first loader returned by the
# factory is discarded.
_, brats_val_t2_dataloader = dataloader_factory(
    DatasetType.BRATS17,
    batch_size=batch_size,
    val_set_path=SSD_PROCESSED_DIR_PATH / 'brats17_t2_hm_bc_std_bv-3.5.hdf5',
    shuffle_val=True,
)
_, brats_val_t1_dataloader = dataloader_factory(
    DatasetType.BRATS17,
    batch_size=batch_size,
    val_set_path=SSD_PROCESSED_DIR_PATH / 'brats17_t1_hm_bc_std_bv-3.5.hdf5',
    shuffle_val=True,
)

# CamCAN T2 train (shuffled) and validation (fixed order) sets.
camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(
    DatasetType.CAMCAN,
    batch_size=batch_size,
    val_set_path=SSD_PROCESSED_DIR_PATH / 'camcan_val_t2_hm_std_bv3.5_xe.hdf5',
    train_set_path=SSD_PROCESSED_DIR_PATH / 'camcan_train_t2_hm_std_bv3.5_xe.hdf5',
    shuffle_val=False,
    shuffle_train=True,
)

# Gaussian noise dataset with default settings, wrapped in a plain torch DataLoader.
noise_set = GaussianNoiseDataset()
noise_loader = DataLoader(noise_set, batch_size=batch_size)

# MNIST validation set, resized to 128x128 and converted to tensors.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.ToTensor(),
])
_, mnist_val_dataloader = dataloader_factory(
    DatasetType.MNIST,
    batch_size=batch_size,
    transform=mnist_transform,
)

# Print a one-line size summary for every dataloader.
for name, dataloader in [('BraTS T2 val', brats_val_t2_dataloader),
                         ('BraTS T1 val', brats_val_t1_dataloader),
                         ('CamCAN train', camcan_train_dataloader),
                         ('Gaussian noise', noise_loader),
                         ('MNIST', mnist_val_dataloader)
                        ]:
    # `n_batches * batch_size` over-counts whenever the last batch is only partially
    # filled, so cap it by the number of samples the sampler provides per epoch.
    # NOTE(review): assumes every entry is a torch DataLoader (has `.sampler`) - confirm
    # for the loaders returned by `dataloader_factory`.
    n_samples = min(len(dataloader) * dataloader.batch_size, len(dataloader.sampler))
    print(f'{name:15} dataloader: {len(dataloader)} batches (batch_size: {dataloader.batch_size}) -> {n_samples} samples.')

# Latent samples from different locations

In [None]:
from uncertify.visualization.latent_space_analysis import plot_latent_samples_from_ring

In [None]:
# (inner_radius, outer_radius) pairs of the rings to draw latent samples from,
# ordered from small to large radii.
radii = [(0, 1), (2, 3), (4, 5), (50, 70), (200, 210), (240, 250)]

# One figure with 8 samples per ring.
for inner_radius, outer_radius in radii:
    fig = plot_latent_samples_from_ring(
        model,
        n_samples=8,
        inner_radius=inner_radius,
        outer_radius=outer_radius,
    )

# Plot Results

In [None]:
# Input MNIST digits: reconstruct and plot one small batch per label 0-9.

plot_n_batches = 1
# Dedicated names keep this cell from overwriting the notebook-wide `batch_size` and
# the `mnist_val_dataloader` (all labels) that later cells use as "MNIST Val".
mnist_digit_batch_size = 8
mnist_digit_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.ToTensor(),
])
for digit in range(10):
    # NOTE(review): `mnist_label` presumably restricts the dataset to a single digit - confirm.
    _, mnist_digit_dataloader = dataloader_factory(DatasetType.MNIST,
                                                   batch_size=mnist_digit_batch_size,
                                                   transform=mnist_digit_transform,
                                                   mnist_label=digit)
    batch_generator = yield_reconstructed_batches(mnist_digit_dataloader, model, residual_threshold=0.3)
    plot_stacked_scan_reconstruction_batches(batch_generator, plot_n_batches,
                                             cmap='hot', axis='off', figsize=(15, 15), save_dir_path=DATA_DIR_PATH/'reconstructions')

In [None]:
plot_n_batches = 1

# Reconstruct and plot `plot_n_batches` batch(es) of every dataset with identical
# settings: CamCAN validation first, then BraTS T2, BraTS T1, Gaussian noise and MNIST.
reconstruction_dataloaders = [
    camcan_val_dataloader,
    brats_val_t2_dataloader,
    brats_val_t1_dataloader,
    noise_loader,
    mnist_val_dataloader,
]
for reconstruction_dataloader in reconstruction_dataloaders:
    batch_generator = yield_reconstructed_batches(reconstruction_dataloader, model, residual_threshold=0.3)
    plot_stacked_scan_reconstruction_batches(
        batch_generator,
        plot_n_batches,
        cmap='hot',
        axis='off',
        figsize=(20, 20),
        save_dir_path=DATA_DIR_PATH/'reconstructions',
    )

# Model Performance

## Segmentation Scores

In [None]:
from uncertify.evaluation.model_performance import mean_std_dice_scores, mean_std_iou_scores
from uncertify.visualization.model_performance import plot_segmentation_performance_vs_threshold
# Best-effort reset of tqdm's registry of live progress bars before the long-running
# evaluation cells. Only a missing `_instances` attribute is tolerated: a bare
# `except:` would also swallow KeyboardInterrupt/SystemExit and hide genuine errors.
try:
    tqdm._instances.clear()
except AttributeError:
    pass

In [None]:
# Dice scores on BraTS T2 for a sweep of residual thresholds, using at most
# `max_n_batches` batches. With a single step and identical end points the "sweep"
# currently consists of the one value 2.6.
# NOTE(review): other cells use residual thresholds of 0.25-0.3 - confirm 2.6 is intended.
n_thresholds = 1
max_n_batches = 10

threshold_low, threshold_high = 2.6, 2.6
pixel_thresholds = np.linspace(threshold_low, threshold_high, n_thresholds)
mean_dice_scores, std_dice_scores = mean_std_dice_scores(
    brats_val_t2_dataloader,
    model,
    residual_thresholds=pixel_thresholds,
    max_n_batches=max_n_batches,
)
# IoU scores are currently not computed (the `mean_std_iou_scores` call is disabled).

In [None]:
# Show the evaluated thresholds with the mean and std of the Dice scores, one per line.
print(pixel_thresholds, mean_dice_scores, std_dice_scores, sep='\n')

In [None]:
# Plot the Dice scores against the residual threshold (no IoU curve, no train-set
# threshold marker) and save the figure.
fig = plot_segmentation_performance_vs_threshold(pixel_thresholds, dice_scores=mean_dice_scores, dice_stds=std_dice_scores, iou_scores=None,
                                                 train_set_threshold=None, figsize=(12, 6))
# Create the output directory on demand; `savefig` fails if it does not exist.
plots_dir_path = DATA_DIR_PATH / 'plots'
plots_dir_path.mkdir(parents=True, exist_ok=True)
fig.savefig(plots_dir_path / 'dice_iou_vs_threshold.png')

## Pixel-wise anomaly detection / classification scores

In [None]:
from uncertify.deploy import yield_y_true_y_pred
from uncertify.evaluation.model_performance import calculate_confusion_matrix
from uncertify.visualization.model_performance import plot_confusion_matrix
from scikitplot.metrics import plot_precision_recall, plot_roc
from sklearn.metrics import average_precision_score, roc_auc_score

In [None]:
# Number of BraTS T2 batches used for the pixel-wise classification metrics.
use_n_batches = 10

# Un-normalized confusion matrix at a residual threshold of 0.26.
confusion_matrix = calculate_confusion_matrix(brats_val_t2_dataloader, model, residual_threshold=0.26, max_n_batches=use_n_batches, normalize=None)

fig, _ = plot_confusion_matrix(confusion_matrix, categories=['normal', 'anomaly'], cbar=False, cmap='YlOrRd_r', figsize=(10, 9))
# Create the output directory on demand; `savefig` fails if it does not exist.
plots_dir_path = DATA_DIR_PATH / 'plots'
plots_dir_path.mkdir(parents=True, exist_ok=True)
fig.savefig(plots_dir_path / 'confusion_matrix.png')

In [None]:
# Collect ground-truth labels and predictions from at most `use_n_batches` BraTS T2
# batches. The following cells read the anomaly-class scores from `y_pred[:, 1]`.
# NOTE(review): presumably pixel-wise values flattened over all batches - confirm in
# `uncertify.deploy`.
y_true, y_pred = yield_y_true_y_pred(brats_val_t2_dataloader, model, max_n_batches=use_n_batches)

In [None]:
auprc = average_precision_score(y_true, y_pred[:, 1])
%time ax = plot_precision_recall(y_true, y_pred, figsize=(6, 4), classes_to_plot=[1], plot_micro=False, title=f'PR Curve Pixel-wise Anomaly Detection')
plt.savefig(DATA_DIR_PATH / 'plots' / 'precision_recall_curve.png')

In [None]:
auroc = roc_auc_score(y_true, y_pred[:, 1])
%time ax = plot_roc(y_true, y_pred, figsize=(6, 4), plot_micro=False, plot_macro=False, classes_to_plot=[1], title=f'ROC Curve Pixel-wise Anomaly Detection')
plt.savefig(DATA_DIR_PATH / 'plots' / 'roc_curve.png')

# Sample-wise Loss Term Histograms

In [None]:
from sklearn.neighbors import KernelDensity

from uncertify.visualization.histograms import plot_loss_histograms
# Best-effort reset of tqdm's registry of live progress bars before the histogram
# cells. Only a missing `_instances` attribute is tolerated: a bare `except:` would
# also swallow KeyboardInterrupt/SystemExit and hide genuine errors.
try:
    tqdm._instances.clear()
except AttributeError:
    pass

In [None]:
# Settings passed positionally to `yield_reconstructed_batches` below.
max_n_batches = 30
residual_threshold = 0.25

# Keep every display name next to its dataloader so the two cannot drift out of sync.
named_dataloaders = [('CamCAN Train T2', camcan_train_dataloader),
                     ('BraTS17 T1', brats_val_t1_dataloader),
                     ('BraTS17 T2', brats_val_t2_dataloader),
                     ('MNIST Val', mnist_val_dataloader),
                     ('Gaussian Noise', noise_loader)]

generator_names = [name for name, _ in named_dataloaders]
dataloaders = [dataloader for _, dataloader in named_dataloaders]

# One output generator per dataset, labelled with the dataset name in the progress bar.
# NOTE(review): if these are plain generators (as the function name suggests) they are
# exhausted after one pass, so re-run this cell before re-running the plotting cell.
output_generators = [
    yield_reconstructed_batches(dataloader, model, max_n_batches, residual_threshold, progress_bar_suffix=name)
    for name, dataloader in named_dataloaders
]

In [None]:
# Plot the loss-term distributions of all datasets as density estimates only
# (histogram bars and data ticks are switched off).
figs_axes = plot_loss_histograms(output_generators=output_generators, names=generator_names,
                                 figsize=(12, 6), ylabel='Normalized Frequency', plot_density=True, show_data_ticks=False, kde_bandwidth=[0.009, 0.009*5.5], show_histograms=False)

# Create the output directory on demand; `savefig` fails if it does not exist.
plots_dir_path = DATA_DIR_PATH / 'plots'
plots_dir_path.mkdir(parents=True, exist_ok=True)
# Save each returned figure under its running index.
for idx, (fig, _) in enumerate(figs_axes):
    fig.savefig(plots_dir_path / f'loss_term_distributions_{idx}.png')

In [None]:
from uncertify.visualization.latent_space_analysis import plot_umap_latent_embedding

In [None]:
# Settings passed positionally to `yield_reconstructed_batches` below.
max_n_batches = 10
residual_threshold = 0.25

# Keep every display name next to its dataloader so the two cannot drift out of sync.
named_dataloaders = [('BraTS17 T1', brats_val_t1_dataloader),
                     ('BraTS17 T2', brats_val_t2_dataloader),
                     ('MNIST Val', mnist_val_dataloader),
                     ('Gaussian Noise', noise_loader),
                     ('CamCAN Train T2', camcan_train_dataloader)]

generator_names = [name for name, _ in named_dataloaders]
dataloaders = [dataloader for _, dataloader in named_dataloaders]

# Fresh output generators for this plot, one per dataset.
output_generators = [
    yield_reconstructed_batches(dataloader, model, max_n_batches, residual_threshold, progress_bar_suffix=name)
    for name, dataloader in named_dataloaders
]

umap_fig = plot_umap_latent_embedding(output_generators, generator_names, figsize=(14, 10))
# Create the output directory on demand; `savefig` fails if it does not exist.
plots_dir_path = DATA_DIR_PATH / 'plots'
plots_dir_path.mkdir(parents=True, exist_ok=True)
umap_fig.savefig(plots_dir_path / 'umap_latent_embedding.png')

# Threshold calculation

In [None]:
%%time
from uncertify.visualization.threshold_search import plot_fpr_vs_residual_threshold
# Best-effort reset of tqdm's registry of live progress bars before the threshold
# search. Only a missing `_instances` attribute is tolerated: a bare `except:` would
# also swallow KeyboardInterrupt/SystemExit and hide genuine errors.
try:
    tqdm._instances.clear()
except AttributeError:
    pass

In [None]:
# False positive rate as a function of the residual threshold, evaluated at four
# evenly spaced thresholds in [0, 1] with 10 batches per threshold.
pixel_thresholds = np.linspace(0, 1, 4)
thresholds, camcan_false_positive_rates = threshold_vs_fpr(camcan_train_dataloader, model, thresholds=pixel_thresholds,
                                                    use_ground_truth=False, n_batches_per_thresh=10)
# `brats_val_dataloader` is never defined in this notebook (NameError); the BraTS T2
# validation loader is used instead, matching the T2 CamCAN training data.
thresholds, brats_false_positive_rates = threshold_vs_fpr(brats_val_t2_dataloader, model, thresholds=pixel_thresholds,
                                                    use_ground_truth=False, n_batches_per_thresh=10)

In [None]:
# Search for the residual threshold at which the FPR on the CamCAN training data
# matches the accepted value.
ACCEPTED_FPR = 0.05

# Fixed arguments of the objective; the search supplies the remaining one.
objective_settings = dict(
    accepted_fpr=ACCEPTED_FPR,
    data_loader=camcan_train_dataloader,
    model=model,
    use_ground_truth=False,
    n_batches_per_thresh=10,
)
objective = partial(calculate_fpr_minus_accepted, **objective_settings)

# Golden-section search on [0, 1]; the mean of the returned values is used as the
# final threshold.
best_thresholds = golden_section_search(objective, low=0.0, up=1.0, tolerance=0.003)
best_threshold = np.mean(best_thresholds)
print(f'Found threshold value: {best_threshold}')

In [None]:
# Plot the CamCAN (train) and BraTS (val) false positive rates over the evaluated
# thresholds, together with the accepted FPR and the threshold found by the search.
fig = plot_fpr_vs_residual_threshold(
    thresholds=thresholds,
    fpr_train=camcan_false_positive_rates,
    fpr_val=brats_false_positive_rates,
    accepted_fpr=ACCEPTED_FPR,
    calculated_threshold=best_threshold,
)