In [None]:
%load_ext autoreload
%autoreload 2

from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Install the project's logging configuration and create this notebook's logger.
setup_logging()
LOG = logging.getLogger(__name__)

# Matplotlib produces very verbose DEBUG output; only let WARNING and above through.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(level=logging.WARNING)

In [None]:
from tqdm import tqdm
# Drop progress bars left registered by interrupted loops of earlier runs.
# `_instances` is a private tqdm attribute that may be absent, so only that case is ignored;
# a bare `except:` would also have swallowed KeyboardInterrupt/SystemExit.
try:
    tqdm._instances.clear()
except AttributeError:
    pass
import seaborn as sns
# sns.set_context("poster")
import numpy as np
import matplotlib.pyplot as plt
# Render all figures on a white background.
plt.rcParams.update({'figure.facecolor': 'white'})

from uncertify.utils.python_helpers import print_dict_tree, get_idx_of_closest_value

from uncertify.evaluation.ood_experiments import run_ood_evaluations, run_ood_to_ood_dict
from uncertify.evaluation.anomaly_detection import slice_wise_lesion_detection_dose_kde
from uncertify.evaluation.model_performance import calculate_roc, calculate_prc

from uncertify.visualization.ood_scores import plot_ood_scores, plot_most_least_ood, plot_samples_close_to_score


from uncertify.visualization.ood_scores import plot_ood_scores
from uncertify.visualization.entropy import plot_entropy_samples_over_range
from uncertify.data.default_dataloaders import default_dataloader_dict_factory, filter_dataloader_dict, print_dataloader_dict
from uncertify.io.models import load_ensemble_models, load_vae_baur_model
from uncertify.data.dataloaders import print_dataloader_info

from uncertify.common import DATA_DIR_PATH, HD_DATA_PATH

In [None]:
# Ensemble members: one checkpoint file per run version in the masked ensemble directory.
# Note: the cells below only use the stand-alone `model`; `ensemble_models` is not referenced.
RUN_VERSIONS = [1, 2, 3, 4, 5]
ensemble_models = load_ensemble_models(
    DATA_DIR_PATH / 'masked_ensemble_models',
    [f'model{run_version}.ckpt' for run_version in RUN_VERSIONS],
)

# Stand-alone VAE model (last checkpoint of the schedule_mask run).
model = load_vae_baur_model(
    HD_DATA_PATH.parent / 'lightning_logs/schedule_mask/version_0/checkpoints/last.ckpt'
)

In [None]:
# Build all default dataloaders and derive the subsets used throughout this notebook.
dataloader_dict = default_dataloader_dict_factory()
print_dataloader_dict(dataloader_dict)

# CamCAN training loaders (entries whose name contains 'art' are excluded).
train_loader_dict = filter_dataloader_dict(dataloader_dict, contains=['CamCAN', 'train'], exclude=['art'])
print_dataloader_dict(train_loader_dict)

# BraTS T2 loaders without the flipped ('Flip') variants.
brats_t2_dict = filter_dataloader_dict(dataloader_dict, contains=['BraTS', 'T2'], exclude=['Flip'])
print_dataloader_dict(brats_t2_dict)

# Named handles used by the slice-wise statistics cells further down. These names were used
# there but not defined anywhere in the notebook, so a fresh Restart & Run All failed.
camcan_train_dataloader = dataloader_dict['CamCAN train']
brats_val_t2_dataloader = dataloader_dict['BraTS T2 val']

# OOD Detection Evaluation for different OOD datasets

In [None]:
# Run the full OOD evaluation (which stores its output) for all datasets and metrics
# listed in the evaluation config.
run_ood_evaluations(
    dataloader_dict['CamCAN train'],  # in-distribution training data
    brats_t2_dict,                    # BraTS T2 test datasets
    [model],                          # single stand-alone model passed as a one-member list
    residual_threshold=0.65,
    max_n_batches=2,
)

# OOD Score

In [None]:
# Compute OOD scores on the BraTS T2 datasets with the stand-alone model, using the
# CamCAN training loader as the in-distribution reference.
NUM_BATCHES = 15
OOD_METRICS = ('dose', )  # e.g. ('waic', 'dose') to also compute WAIC
DOSE_STATISTICS = ('entropy', 'elbo', 'kl_div', 'rec_err')  # all four available DoSE statistics

metrics_ood_dict = run_ood_to_ood_dict(test_dataloader_dict=brats_t2_dict,
                                       ensemble_models=[model],
                                       train_dataloader=train_loader_dict['CamCAN train'],
                                       num_batches=NUM_BATCHES,
                                       ood_metrics=OOD_METRICS,
                                       dose_statistics=DOSE_STATISTICS)
print_dict_tree(metrics_ood_dict)

In [None]:
# Sanity check: value ranges of the 'lesional' entries and of the DoSE KDE entropy
# scores for BraTS T2 val. Each line is labelled so the output is readable on its own.
brats_val_dose = metrics_ood_dict['dose']['BraTS T2 val']
for label, values in [('lesional', brats_val_dose['lesional']),
                      ('dose_kde_lesional/entropy', brats_val_dose['dose_kde_lesional']['entropy'])]:
    print(f'{label}: max={max(values)}, min={min(values)}')

In [None]:
# Show BraTS T2 val samples spread over the entropy range 0..6 (64 values), in grayscale.
plot_entropy_samples_over_range(
    metrics_ood_dict,
    'BraTS T2 val',
    start_val=0,
    end_val=6,
    n_values=64,
    axis='off',
    add_colorbar=False,
    figsize=(12, 12),
    cmap='gray',
)

In [None]:
# For each DoSE KDE statistic, check whether it separates healthy from lesional slices.
slice_wise_lesion_detection_dose_kde(
    brats_t2_dict,
    metrics_ood_dict,
    DOSE_STATISTICS,
    do_plots=True,
)

In [None]:
# Plot healthy vs. lesional OOD score distributions for every computed OOD metric.
# To inspect individual slices, `plot_most_least_ood` / `plot_samples_close_to_score`
# (imported at the top) can be called on `ood_dict`.
for metric in OOD_METRICS:
    ood_dict = metrics_ood_dict[metric]
    plot_ood_scores(ood_dict, score_label=metric, dataset_name_filters=[], modes_to_include=['healthy', 'lesional'])


# Slice-Wise Statistics

In [None]:
import pandas as pd
import seaborn as sns

from uncertify.evaluation.dose import full_pipeline_slice_wise_dose_scores
from uncertify.evaluation.statistics import fit_statistics, aggregate_slice_wise_statistics
from uncertify.visualization.dose import do_pair_plot_statistics

In [None]:
# Aggregate slice-wise statistics for each dataset, show them as a lower-triangular pair-plot
# and plot the per-slice entropy distribution.
DOSE_STATISTICS = ['elbo', 'rec_err', 'kl_div', 'entropy']
MAX_N_BATCHES = 10
ENTROPY_XLIM = (0, 0.16)

# Further datasets (BraTS T1, the HM variants, Gaussian noise, MNIST) can be added here
# once their dataloaders are defined.
stat_dataloader_dict = {
    'BraTS T2': brats_val_t2_dataloader,
    'CamCAN train': camcan_train_dataloader,
}

stat_frames = {}

for name, dataloader in stat_dataloader_dict.items():
    # The CamCAN training data has no lesional slices, so it is not split by hue.
    hue = None if dataloader is camcan_train_dataloader else 'is_lesional'
    statistics_dict = aggregate_slice_wise_statistics(model, dataloader,
                                                      DOSE_STATISTICS, max_n_batches=MAX_N_BATCHES)
    stat_df = pd.DataFrame(statistics_dict)
    stat_frames[name] = stat_df
    do_pair_plot_statistics(statistics_dict, DOSE_STATISTICS, dataloader.dataset.name, hue=hue)

    fig, ax = plt.subplots()
    sns.histplot(stat_df, x='entropy', hue=hue, kde=True, ax=ax)
    ax.set_xlim(ENTROPY_XLIM)
    ax.set_title(f'{name}: slice-wise entropy')

In [None]:
# Entropy histogram per dataset; CamCAN (no lesional slices) is not split by hue.
for dataset_name, frame in stat_frames.items():
    lesion_hue = None if 'Cam' in dataset_name else 'is_lesional'
    fig, ax = plt.subplots()
    sns.histplot(frame, x='entropy', hue=lesion_hue, kde=True, ax=ax)
    ax.set_xlim([0, 0.26])

In [None]:
from uncertify.utils.python_helpers import get_idx_of_closest_value

# Plot lesional and healthy BraTS T2 slices whose entropy is closest to a sweep of
# reference values, side by side in one figure per reference value.


def _show_closest_slice(ax, scans, entropies, ref_val, label):
    """Draw on ``ax`` the scan whose entropy is closest to ``ref_val``, titled with label and entropy."""
    idx = get_idx_of_closest_value(entropies, ref_val)
    # NOTE(review): assumes each scan is a tensor with a leading channel dimension — TODO confirm.
    ax.imshow(scans[idx][0].numpy())
    ax.set_title(f'[{label}]\n entropy={entropies[idx]:.3f}')
    ax.axis('off')


ref_values = np.linspace(0.05, 0.25, 40)

brats_stats = stat_frames['BraTS T2']
is_lesional = brats_stats['is_lesional']
is_healthy = np.invert(is_lesional)
lesional_scans = list(brats_stats['scans'][is_lesional])
healthy_scans = list(brats_stats['scans'][is_healthy])
lesional_entropy = list(brats_stats['entropy'][is_lesional])
healthy_entropy = list(brats_stats['entropy'][is_healthy])

if not lesional_entropy or not healthy_entropy:
    raise ValueError('Need both lesional and healthy BraTS T2 slices; '
                     'increase MAX_N_BATCHES in the slice-wise statistics cell.')

for ref_val in ref_values:
    fig, (lesional_ax, healthy_ax) = plt.subplots(1, 2, figsize=(4, 2))
    _show_closest_slice(lesional_ax, lesional_scans, lesional_entropy, ref_val, 'lesional')
    _show_closest_slice(healthy_ax, healthy_scans, healthy_entropy, ref_val, 'healthy')
    plt.show()
    plt.close(fig)

In [None]:
# Entropy histogram per dataset. The hue is chosen from the dataset name (CamCAN has no
# lesional slices) rather than zipping with a fixed hue list, which silently skipped every
# dataset beyond the second and mismatched hues if the dict order changed.
for name, stat_df in stat_frames.items():
    hue = None if 'Cam' in name else 'is_lesional'
    fig, ax = plt.subplots()
    sns.histplot(stat_df, x='entropy', hue=hue, kde=True, ax=ax)
    ax.set_xlim([0, 0.26])
    ax.set_title(f'{name}: slice-wise entropy')

In [None]:
# Plot the DoSE KDE scores of each test set, using KDEs fitted on the CamCAN training distribution.
MAX_N_BATCHES = 5
# Number of CamCAN training batches used to fit the KDEs (same as the slice-wise statistics cell).
KDE_FIT_N_BATCHES = 10

# (test dataloader, whether it contains lesional slices -> colour the pair-plot by 'is_lesional').
# BraTS T1 can be added as (brats_val_t1_dataloader, True) once that dataloader is defined.
test_sets = [
    (camcan_train_dataloader, False),
    (brats_val_t2_dataloader, True),
]

# Fit the KDEs explicitly on training statistics instead of reusing `statistics_dict` from the
# slice-wise statistics loop, which holds whichever dataset happened to be iterated last.
train_statistics_dict = aggregate_slice_wise_statistics(model, camcan_train_dataloader, DOSE_STATISTICS,
                                                        max_n_batches=KDE_FIT_N_BATCHES)
kde_func_dict = fit_statistics(train_statistics_dict)

for dataloader, has_lesions in test_sets:
    dose_scores, dose_kde_dict = full_pipeline_slice_wise_dose_scores(camcan_train_dataloader, dataloader, model,
                                                                      DOSE_STATISTICS, MAX_N_BATCHES, kde_func_dict)
    do_pair_plot_statistics(dose_kde_dict, DOSE_STATISTICS, dataloader.dataset.name,
                            'is_lesional' if has_lesions else None)

# Plotting Final DoSE Statistics

In [None]:
# Fit the per-statistic KDEs on CamCAN training data, then compare final DoSE score
# distributions on OOD (BraTS T2 val) and ID (CamCAN train) data.
TRAIN_FIT_N_BATCHES = 50
# Previously read from MAX_N_BATCHES, which holds 5 when the cells run top to bottom.
OOD_N_BATCHES = 5
ID_N_BATCHES = 3

statistics_dict = aggregate_slice_wise_statistics(model, camcan_train_dataloader, DOSE_STATISTICS,
                                                  max_n_batches=TRAIN_FIT_N_BATCHES)
kde_func_dict = fit_statistics(statistics_dict)


def _final_dose_frame(test_dataloader, n_batches):
    """Compute slice-wise DoSE scores for ``test_dataloader`` against the fitted training KDEs.

    Returns a DataFrame with columns ``DoSE`` (score) and ``is_lesional``.
    """
    scores, kde_dict = full_pipeline_slice_wise_dose_scores(camcan_train_dataloader, test_dataloader, model,
                                                            DOSE_STATISTICS, n_batches, kde_func_dict)
    return pd.DataFrame({'DoSE': scores, 'is_lesional': kde_dict['is_lesional']})


# Separate axes: drawing both KDE plots onto the same axes reused the same hue palette,
# making the OOD and ID curves indistinguishable. Explicit long-form data/x/hue arguments
# keep the call valid across seaborn versions.
fig, (ood_ax, id_ax) = plt.subplots(1, 2, figsize=(12, 4), sharex=True)
for ax, test_dataloader, n_batches, title in [
        (ood_ax, brats_val_t2_dataloader, OOD_N_BATCHES, 'OOD: BraTS T2 val'),
        (id_ax, camcan_train_dataloader, ID_N_BATCHES, 'ID: CamCAN train')]:
    final_dose_df = _final_dose_frame(test_dataloader, n_batches)
    sns.kdeplot(data=final_dose_df, x='DoSE', hue='is_lesional', ax=ax)
    ax.set_title(title)