In [None]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# NOTE(review): `context` is presumably a local helper module that makes the
# `uncertify` package importable from the notebook directory — confirm.
from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Apply the project's logging configuration, then create a logger for this notebook.
setup_logging()
LOG = logging.getLogger(__name__)

# Matplotlib emits a large volume of DEBUG messages; restrict its logger to
# WARNING and above so the notebook output stays readable.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)

In [None]:
import itertools
from functools import partial

import torch
import torchvision
import seaborn as sns
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from uncertify.models.vae import VariationalAutoEncoder
from uncertify.models.encoder_decoder_baur2020 import BaurDecoder, BaurEncoder
from uncertify.data.dataloaders import dataloader_factory, DatasetType
from uncertify.visualization.reconstruction import plot_stacked_scan_reconstruction_batches
from uncertify.deploy import yield_reconstructed_batches
from uncertify.visualization.grid import imshow_grid
from uncertify.visualization.plotting import setup_plt_figure
from uncertify.evaluation.thresholding import threshold_vs_fpr
from uncertify.algorithms.golden_section_search import golden_section_search
from uncertify.evaluation.thresholding import calculate_fpr_minus_accepted
from uncertify.common import DATA_DIR_PATH

In [None]:
# Assemble the VAE from the Baur encoder/decoder pair. `get_batch_fn` selects
# the 'scan' entry of each batch dict handed to the model.
model = VariationalAutoEncoder(
    BaurEncoder(),
    BaurDecoder(),
    get_batch_fn=lambda batch: batch['scan'],
)

In [None]:
# Location of the trained VAE checkpoint, relative to the project data directory.
CHECKPOINT_PATH = DATA_DIR_PATH / 'lightning_logs/train_vae/version_1/checkpoints/epoch=261.ckpt'
# Fail early, and name the missing file, instead of erroring inside torch.load.
assert CHECKPOINT_PATH.exists(), f'Model checkpoint does not exist: {CHECKPOINT_PATH}'

# map_location='cpu' allows a checkpoint that was saved on a GPU to be restored
# on a machine without CUDA; load_state_dict copies the values into the model's
# existing parameters, so the model's own device placement is unaffected.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])

In [None]:
# BraTS: only the second returned loader (validation) is kept; the first is discarded.
_, brats_val_dataloader = dataloader_factory(
    DatasetType.BRATS17,
    batch_size=8,
    shuffle_val=True,
)
# CamCAN: keep both the training and the validation loader, both shuffled.
camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(
    DatasetType.CAMCAN,
    batch_size=8,
    shuffle_train=True,
    shuffle_val=True,
)

In [None]:
# Number of batches to draw from the generator for the reconstruction figure.
plot_n_batches = 1

# Reconstruct CamCAN training batches with the loaded model, then plot them
# and store the figures under the data directory.
batch_generator = yield_reconstructed_batches(
    camcan_train_dataloader,
    model,
    residual_threshold=0.16,
)
plot_stacked_scan_reconstruction_batches(
    batch_generator,
    plot_n_batches,
    cmap='hot',
    axis='off',
    figsize=(20, 20),
    save_dir_path=DATA_DIR_PATH/'reconstructions',
)

In [None]:
# Sweep 30 evenly spaced thresholds in [0, 1] on the CamCAN training data
# (no ground truth used).
thresholds, camcan_false_positive_rates = threshold_vs_fpr(
    camcan_train_dataloader,
    model,
    thresholds=np.linspace(0, 1, 30),
    use_ground_truth=False,
    n_batches_per_thresh=20,
)
# Same sweep on the BraTS validation data, this time with ground truth enabled.
# `thresholds` is re-assigned from this second call.
thresholds, brats_false_positive_rates = threshold_vs_fpr(
    brats_val_dataloader,
    model,
    thresholds=np.linspace(0, 1, 30),
    use_ground_truth=True,
    n_batches_per_thresh=20,
)

In [None]:
# Search for the threshold whose false positive rate on the CamCAN training
# data matches the accepted FPR.
ACCEPTED_FPR = 0.07

# Bind every argument except the threshold so the search can treat the
# objective as a function of a single variable.
objective = partial(
    calculate_fpr_minus_accepted,
    accepted_fpr=ACCEPTED_FPR,
    data_loader=camcan_train_dataloader,
    model=model,
    use_ground_truth=False,
    n_batches_per_thresh=20,
)
tau = golden_section_search(
    objective,
    low=0.0,
    up=1.0,
    tolerance=0.01,
)
# NOTE(review): `tau` is presumably the final search interval, so its mean is
# the midpoint estimate — confirm against golden_section_search.
mean_tau = np.mean(tau)
print(f'Found threshold value: {mean_tau}')

In [None]:
# Global matplotlib styling for the figure below.
plt.rc('font', family='serif')
matplotlib.rcParams.update({'font.size': 18})

fig, ax = setup_plt_figure(figsize=(16, 8))

# FPR as a function of the threshold for both datasets.
ax.plot(thresholds, camcan_false_positive_rates, linewidth=4, linestyle='dashed', alpha=0.5, label='CamCAN Train')
ax.plot(thresholds, brats_false_positive_rates, linewidth=3, linestyle='solid', alpha=0.7, label='BraTS Validation')
ax.set_ylabel('False Positive Rate')
ax.set_xlabel('Threshold')

# Absolute deviation of the CamCAN FPR from the accepted FPR; the legend label
# uses |...| because the plotted quantity is the absolute value, not the signed difference.
normed_diff = [abs(fpr - ACCEPTED_FPR) for fpr in camcan_false_positive_rates]
ax.plot(thresholds, normed_diff, c='green', alpha=0.7, linewidth=3, label='|CamCAN FPR - Accepted FPR|')

# Horizontal reference line at the accepted FPR and vertical line at the
# threshold found by the golden section search.
ax.plot(thresholds, [ACCEPTED_FPR] * len(thresholds), linestyle='dotted', linewidth=3, color='grey', label=f'Accepted FPR ({ACCEPTED_FPR:.2f})')
ax.plot([mean_tau, mean_tau], [-0.05, 1], linestyle='dotted', color='green', linewidth=3, label=f'Threshold through Golden Section Search ({mean_tau:.2f})')
ax.legend(frameon=False)

# savefig does not create missing directories, so make sure the target exists first.
plots_dir_path = DATA_DIR_PATH / 'plots'
plots_dir_path.mkdir(parents=True, exist_ok=True)
fig.savefig(plots_dir_path / 'threshold.png')