In [None]:
%load_ext autoreload
%autoreload 2

from context import uncertify

In [None]:
import logging

from uncertify.log import setup_logging

# Configure logging through the project's helper, then create this notebook's logger.
setup_logging()
LOG = logging.getLogger(__name__)

# matplotlib is extremely chatty at DEBUG level; only let warnings and above through.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel('WARNING')

In [None]:
import itertools
from functools import partial

import torch
import torchvision
from torchvision.transforms.transforms import Compose
import seaborn as sns
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# Global figure typography (serif font, size 18) for every plot in this notebook.
matplotlib.rcParams.update({'font.family': 'serif', 'font.size': 18})

from uncertify.models.vae import VariationalAutoEncoder
from uncertify.models.encoder_decoder_baur2020 import BaurDecoder, BaurEncoder
from uncertify.data.dataloaders import dataloader_factory, DatasetType
from uncertify.visualization.reconstruction import plot_stacked_scan_reconstruction_batches
from uncertify.deploy import yield_reconstructed_batches
from uncertify.visualization.grid import imshow_grid
from uncertify.visualization.plotting import setup_plt_figure
from uncertify.evaluation.thresholding import threshold_vs_fpr
from uncertify.algorithms.golden_section_search import golden_section_search
from uncertify.evaluation.thresholding import calculate_fpr_minus_accepted
from uncertify.evaluation.model_performance import calculate_mean_dice_score, calculate_mean_dice_scores
from uncertify.common import DATA_DIR_PATH

# Load Model and Data

In [None]:
def _get_scan(batch):
    """Extract the image tensor from a dict-style batch (used with the BraTS / CamCAN loaders)."""
    return batch['scan']


def _get_first_item(batch):
    """Extract the image from a tuple-style batch (presumably torchvision MNIST's (image, label))."""
    return batch[0]


# Same Baur encoder/decoder architecture; the two wrappers differ only in how they pull images out of a batch.
model = VariationalAutoEncoder(BaurEncoder(), BaurDecoder(), get_batch_fn=_get_scan)
model_mnist = VariationalAutoEncoder(BaurEncoder(), BaurDecoder(), get_batch_fn=_get_first_item)

In [None]:
CHECKPOINT_PATH = DATA_DIR_PATH / 'lightning_logs/train_vae/version_1/checkpoints/epoch=261.ckpt'
assert CHECKPOINT_PATH.exists(), f'Model checkpoint does not exist: {CHECKPOINT_PATH}'

# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a CPU-only machine;
# load_state_dict then copies the weights into the models' (CPU) parameters.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')

# Both wrappers share the same encoder/decoder architecture and differ only in get_batch_fn,
# so the trained weights must be loaded into both. Otherwise model_mnist (used for the MNIST
# loss histograms) would evaluate with its random initial weights.
for vae in (model, model_mnist):
    vae.load_state_dict(checkpoint['state_dict'])
    # Inference only: switch dropout / batch-norm layers (if any) to eval behaviour.
    # NOTE(review): harmless if the uncertify deploy/evaluation helpers already call .eval().
    vae.eval()

In [None]:
# The validation loaders are shuffled and most evaluations below only consume their first few
# batches (max_n_batches / max_batches), so seed torch's global RNG to make Restart & Run All
# reproducible. Assumes shuffling draws from torch's default generator (as
# torch.utils.data.RandomSampler does when no generator is given) — confirm for dataloader_factory.
SEED = 0
torch.manual_seed(SEED)

BATCH_SIZE = 8

# MNIST digits are resized to 128x128 (presumably the input size the VAE was trained on).
mnist_transform = Compose([
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.ToTensor(),
])

_, brats_val_dataloader = dataloader_factory(DatasetType.BRATS17, batch_size=BATCH_SIZE, shuffle_val=True)
camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(DatasetType.CAMCAN, batch_size=BATCH_SIZE,
                                                                    shuffle_train=True, shuffle_val=True)
mnist_train_dataloader, mnist_val_dataloader = dataloader_factory(DatasetType.MNIST, batch_size=BATCH_SIZE,
                                                                  shuffle_train=True, shuffle_val=True,
                                                                  transform=mnist_transform)

# Plot Results

In [None]:
# Visualise reconstructions of CamCAN training batches; save_dir_path points the helper at
# DATA_DIR_PATH/reconstructions for writing the figures.
plot_n_batches = 1
reconstruction_dir = DATA_DIR_PATH / 'reconstructions'

batch_generator = yield_reconstructed_batches(camcan_train_dataloader, model, residual_threshold=0.16)
plot_stacked_scan_reconstruction_batches(
    batch_generator,
    plot_n_batches,
    cmap='hot',
    axis='off',
    figsize=(20, 20),
    save_dir_path=reconstruction_dir,
)

# Model Performance

## Segmentation Scores

In [None]:
from uncertify.evaluation.model_performance import calculate_mean_dice_scores, calculate_mean_iou_scores
from uncertify.visualization.model_performance import plot_segmentation_performance_vs_threshold

In [None]:
# Mean Dice and IoU on (at most) 5 BraTS validation batches, for each of 10 residual
# thresholds evenly spaced over [0, 1].
performance_calc_thresholds = np.linspace(0, 1, 10)
score_kwargs = dict(residual_thresholds=performance_calc_thresholds, max_n_batches=5)

dice_scores = calculate_mean_dice_scores(brats_val_dataloader, model, **score_kwargs)
iou_scores = calculate_mean_iou_scores(brats_val_dataloader, model, **score_kwargs)

In [None]:
# Dice / IoU as a function of the residual threshold; 0.17 is passed as the train-set reference threshold.
fig, _ = plot_segmentation_performance_vs_threshold(
    performance_calc_thresholds,
    dice_scores=dice_scores,
    iou_scores=iou_scores,
    train_set_threshold=0.17,
    figsize=(12, 6),
)
plots_dir = DATA_DIR_PATH / 'plots'
fig.savefig(plots_dir / 'dice_iou_vs_threshold.png')

## Pixel-wise anomaly detection / classification scores

In [None]:
from uncertify.evaluation.model_performance import calculate_confusion_matrix
from uncertify.visualization.model_performance import plot_confusion_matrix

In [None]:
# Pixel-wise confusion matrix (normal vs. anomaly) on up to 10 BraTS validation batches at residual
# threshold 0.17; normalize=None is forwarded unchanged to the helper.
confusion_matrix = calculate_confusion_matrix(
    brats_val_dataloader,
    model,
    residual_threshold=0.17,
    max_n_batches=10,
    normalize=None,
)
fig, _ = plot_confusion_matrix(
    confusion_matrix,
    categories=['normal', 'anomaly'],
    cbar=False,
    cmap='YlGn',
    figsize=(7, 6),
)
plots_dir = DATA_DIR_PATH / 'plots'
fig.savefig(plots_dir / 'confusion_matrix.png')

In [None]:
from uncertify.deploy import yield_y_true_y_pred
from scikitplot.metrics import plot_precision_recall, plot_roc
from sklearn.metrics import average_precision_score, roc_auc_score

In [None]:
# Pixel-wise ground truth and predictions from up to 5 BraTS validation batches (used by the PR / ROC cells).
y_true, y_pred = yield_y_true_y_pred(
    brats_val_dataloader,
    model,
    max_n_batches=5,
)

In [None]:
auprc = average_precision_score(y_true, y_pred[:, 1])
%time ax = plot_precision_recall(y_true, y_pred, figsize=(12, 8), classes_to_plot=[1], plot_micro=False, title=f'PR Curve Pixel-wise Anomaly Detection')
plt.savefig(DATA_DIR_PATH / 'plots' / 'precision_recall_curve.png')

In [None]:
auroc = roc_auc_score(y_true, y_pred[:, 1])
%time ax = plot_roc(y_true, y_pred, figsize=(12, 8), plot_micro=False, plot_macro=False, classes_to_plot=[1], title=f'ROC Curve Pixel-wise Anomaly Detection')
plt.savefig(DATA_DIR_PATH / 'plots' / 'roc_curve.png')

# Sample-wise Loss Term Histograms

In [None]:
from sklearn.neighbors import KernelDensity

from uncertify.visualization.histograms import plot_loss_histograms

In [None]:
# Compare per-sample loss-term distributions of the BraTS, CamCAN and MNIST validation sets.
max_n_batches = 5
common_kwargs = dict(residual_threshold=0.16, max_batches=max_n_batches)

brats_val_generator = yield_reconstructed_batches(brats_val_dataloader, model, **common_kwargs)
camcan_val_generator = yield_reconstructed_batches(camcan_val_dataloader, model, **common_kwargs)
# MNIST batches are indexed positionally (presumably (image, label) tuples), hence batch[0].
mnist_val_generator = yield_reconstructed_batches(mnist_val_dataloader, model_mnist,
                                                  get_batch_fn=lambda batch: batch[0], **common_kwargs)

output_generators = [brats_val_generator, camcan_val_generator, mnist_val_generator]
generator_names = ['BraTS17 Val', 'CamCAN Val', 'MNIST Val']
figs_axes = plot_loss_histograms(output_generators=output_generators, names=generator_names,
                                 figsize=(12, 4), ylabel='Normalized Frequency', plot_density=True)

plots_dir = DATA_DIR_PATH / 'plots'
for idx, (fig, _) in enumerate(figs_axes):
    fig.savefig(plots_dir / f'loss_term_distributions_{idx}.png')

In [None]:
# Number of pixels in one 128x128 image (the size MNIST is resized to in the dataloader cell);
# presumably a reference value for the mask `numel()` inspected in the next cell.
n_pixels_per_slice = 128 * 128
n_pixels_per_slice

In [None]:
# Sanity check on one BraTS validation batch: for each segmentation mask, report its shape,
# pixel count and how many pixels are marked (mask > 0), instead of dumping whole tensors.
seg_batch = next(iter(brats_val_dataloader))['seg']
for sample_idx, seg_mask in enumerate(seg_batch):
    n_pixels = seg_mask.numel()
    n_marked = int(torch.sum(seg_mask > 0))
    marked_fraction = n_marked / max(n_pixels, 1)  # guard against an empty mask
    print(f'sample {sample_idx}: shape={tuple(seg_mask.shape)}, numel={n_pixels}, '
          f'pixels > 0: {n_marked} ({marked_fraction:.2%})')

# Threshold calculation

In [None]:
%%time
thresholds, camcan_false_positive_rates = threshold_vs_fpr(camcan_train_dataloader, model, thresholds=np.linspace(0, 1, 20),
                                                    use_ground_truth=False, n_batches_per_thresh=200)
thresholds, brats_false_positive_rates = threshold_vs_fpr(brats_val_dataloader, model, thresholds=np.linspace(0, 1, 20),
                                                    use_ground_truth=False, n_batches_per_thresh=200)

In [None]:
# Search [0, 1] (golden section, tolerance 0.003) for the residual threshold whose false-positive
# rate on the CamCAN training data matches the accepted FPR.
ACCEPTED_FPR = 0.05

objective_kwargs = dict(
    accepted_fpr=ACCEPTED_FPR,
    data_loader=camcan_train_dataloader,
    model=model,
    use_ground_truth=False,
    n_batches_per_thresh=10,
)
objective = partial(calculate_fpr_minus_accepted, **objective_kwargs)
tau = golden_section_search(objective, low=0.0, up=1.0, tolerance=0.003)

# np.mean collapses tau to a single value (tau may be a bracketing interval rather than a scalar — confirm).
mean_tau = np.mean(tau)
print(f'Found threshold value: {mean_tau}')

In [None]:
# FPR vs. residual threshold for CamCAN (train) and BraTS (val), together with the accepted FPR
# and the threshold found by golden section search. Font settings come from the import cell.
fig, ax = setup_plt_figure(figsize=(16, 8))
ax.plot(thresholds, camcan_false_positive_rates, linewidth=4, linestyle='dashed', alpha=0.5, label='CamCAN Train')
ax.plot(thresholds, brats_false_positive_rates, linewidth=3, linestyle='solid', alpha=0.7, label='BraTS Validation')
ax.set_ylabel('False Positive Rate')
ax.set_xlabel('Threshold')

# Absolute distance between the CamCAN FPR and the accepted FPR; it is near zero at thresholds
# that achieve the accepted FPR on CamCAN. The label states the absolute value explicitly.
fpr_abs_diff = np.abs(np.asarray(camcan_false_positive_rates) - ACCEPTED_FPR)
ax.plot(thresholds, fpr_abs_diff, c='green', alpha=0.7, linewidth=3, label='|CamCAN FPR - Accepted FPR|')
ax.axhline(ACCEPTED_FPR, linestyle='dotted', linewidth=3, color='grey', label=f'Accepted FPR ({ACCEPTED_FPR:.2f})')
# Drawn in data coordinates (y from -0.05 to 1) so it also sets the lower y-limit slightly below 0.
ax.plot([mean_tau, mean_tau], [-0.05, 1], linestyle='dotted', color='green', linewidth=3,
        label=f'Threshold through Golden Section Search ({mean_tau:.2f})')
ax.legend(frameon=False)
fig.savefig(DATA_DIR_PATH / 'plots' / 'threshold.png')