In [None]:
# Notebook setup: enable the autoreload extension (mode 2) so edits to imported
# modules are picked up without restarting the kernel, and render matplotlib
# figures inline in the cell output.
%load_ext autoreload
%autoreload 2
%matplotlib inline

from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
setup_logging()
# Matplotlib emits a flood of DEBUG records; only let ERROR and above through.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.ERROR)

# Notebook-level logger, named after the executing module.
LOG = logging.getLogger(__name__)

In [None]:
from pathlib import Path

from tqdm import tqdm
# Best-effort reset of progress bars left over from a previous (interrupted) run.
# Only the "attribute does not exist" case is expected here; a bare `except:`
# would also swallow KeyboardInterrupt/SystemExit and hide unrelated errors.
try:
    tqdm._instances.clear()
except AttributeError:
    pass
import seaborn as sns
# sns.set_context("poster")
import numpy as np
import matplotlib.pyplot as plt
plt.style.use({'figure.facecolor':'white'})

import torch

from uncertify.utils.python_helpers import print_dict_tree, get_idx_of_closest_value
from uncertify.evaluation.ood_experiments import run_ood_evaluations, run_ood_to_ood_dict
from uncertify.evaluation.anomaly_detection import slice_wise_lesion_detection_dose_kde, slice_wise_lesion_detection_waic
from uncertify.evaluation.model_performance import calculate_roc, calculate_prc
from uncertify.visualization.ood_scores import plot_ood_samples_over_range
from uncertify.visualization.ood_scores import plot_ood_scores, plot_most_least_ood, plot_samples_close_to_score
from uncertify.data.default_dataloaders import default_dataloader_dict_factory, filter_dataloader_dict, print_dataloader_dict
from uncertify.io.models import load_ensemble_models, load_vae_baur_model
from uncertify.data.dataloaders import print_dataloader_info
from uncertify.io.models import load_ensemble_models, load_vae_baur_model

from uncertify.common import DATA_DIR_PATH, HD_DATA_PATH, HD_MODELS_PATH

In [None]:
# Ensemble members: one checkpoint file per run version.
RUN_VERSIONS = [0, 1, 2, 3, 4]
ensemble_models = load_ensemble_models(
    HD_MODELS_PATH / 'scheduled_masked_ensembles',
    [f'model_{idx}.ckpt' for idx in RUN_VERSIONS],
)

# Stand-alone model used by the single-model cells below.
# NOTE(review): absolute, machine-specific checkpoint path.
masked_model = load_vae_baur_model(
    Path('/mnt/2TB_internal_HD/lightning_logs/schedule_mask/version_3/checkpoints/last.ckpt')
)
model = masked_model

# Alternative checkpoints that were tried (kept for reference):
#   non_masked_model: /mnt/2TB_internal_HD/lightning_logs/schedule_mask/version_5/checkpoints/last.ckpt
#   model:            /mnt/2TB_internal_HD/lightning_logs/beta_test/version_2/checkpoints/last.ckpt

In [None]:
# Build the default dataloaders once; every later cell looks loaders up in this dict.
dataloader_dict = default_dataloader_dict_factory(
    batch_size=128,
    num_workers=0,
    shuffle_val=True,
)

# Training subset: 'CamCAN' loaders without 'art'
# (presumably matched against the dict keys - confirm in filter_dataloader_dict).
train_loader_dict = filter_dataloader_dict(
    dataloader_dict,
    contains=['CamCAN'],
    exclude=['art'],
)

# OOD Detection Evaluation for different OOD datasets

In [None]:
# BraTS loaders with the 'Flip', 'HM' and 'T1' variants filtered out
# (presumably matched against the dict keys - confirm in filter_dataloader_dict).
brats_t2_dict = filter_dataloader_dict(
    dataloader_dict,
    contains=['BraTS'],
    exclude=['Flip', 'HM', 'T1'],
)
print_dataloader_dict(brats_t2_dict)

In [None]:
# Run the full OOD evaluation (which stores its output) for all datasets and
# metrics in the evaluation config, using 'CamCAN train' as the first loader
# and the filtered BraTS loaders as the second argument.
run_ood_evaluations(
    dataloader_dict['CamCAN train'],
    brats_t2_dict,
    [model],
    residual_threshold=0.65,
    max_n_batches=10,
)

# OOD Score

In [None]:
# Configuration for the OOD score computation.
NUM_BATCHES = 20
NUM_BACTHES = NUM_BATCHES  # misspelled legacy name, kept as an alias in case other cells still use it
OOD_METRICS = ('waic', )  # tuple of metric names; use ('dose', 'waic') to compute both
DOSE_STATISTICS = ('entropy', 'rec_err', 'kl_div', 'elbo')  # any subset of ('entropy', 'elbo', 'kl_div', 'rec_err')

# Datasets to score. Previously the dict was first built from the full list
#   ['CamCAN T2', 'BraTS T2', 'CamCAN T2 lesion', 'BraTS T2 HM', 'BraTS T2 HFlip', 'BraTS T2 VFlip']
# and then immediately overwritten; only the names actually used are looked up now.
OOD_DATASET_NAMES = ['CamCAN T2', 'BraTS T2']
ood_dataloader_dict = {name: dataloader_dict[name] for name in OOD_DATASET_NAMES}


print_dataloader_dict(ood_dataloader_dict)
metrics_ood_dict = run_ood_to_ood_dict(test_dataloader_dict=ood_dataloader_dict,
                                       ensemble_models=ensemble_models,
                                       train_dataloader=train_loader_dict['CamCAN T2'],
                                       num_batches=NUM_BATCHES,
                                       ood_metrics=OOD_METRICS,
                                       dose_statistics=DOSE_STATISTICS)
print_dict_tree(metrics_ood_dict)

In [None]:
# Show 'BraTS T2' samples across the 0.1 .. 0.8 value range (16 values), one row of 16 images.
plot_ood_samples_over_range(
    metrics_ood_dict,
    'BraTS T2',
    mode='waic',
    stat_type='entropy',
    start_val=0.1,
    end_val=0.8,
    n_values=16,
    axis='off',
    add_colorbar=False,
    figsize=(12, 12),
    cmap='gray',
    nrow=16,
)

In [None]:
# For each DoSE KDE statistic, check whether it separates healthy from unhealthy slices.
# NOTE(review): OOD_METRICS is set to ('waic', ) in the OOD score cell, so
# metrics_ood_dict may not contain DoSE results - confirm before running this cell.
print('DoSE Lesion detection')
slice_wise_lesion_detection_dose_kde(
    ood_dataloader_dict,
    'CamCAN T2',
    metrics_ood_dict,
    DOSE_STATISTICS,
    predict_mode='kde',
    do_plots=True,
    show_title=False,
    show_legend=True,
)

In [None]:
# Same slice-wise lesion detection, driven by the WAIC results instead of DoSE.
print('WAIC Lesion detection')
slice_wise_lesion_detection_waic(
    ood_dataloader_dict,
    'CamCAN T2',
    metrics_ood_dict,
    do_plots=True,
    show_title=False,
    show_legend=True,
)

In [None]:
# One set of plots per configured OOD metric.
for metric in OOD_METRICS:
    n = 16  # number of samples shown by the plots below
    ood_dict = metrics_ood_dict[metric]

    # Score distributions, split into healthy and lesional.
    plot_ood_scores(
        ood_dict,
        score_label=metric,
        dataset_name_filters=[],
        modes_to_include=['healthy', 'lesional'],
    )

    # The extremes of the score distribution for 'BraTS T2 val'.
    plot_most_least_ood(ood_dict, 'BraTS T2 val', n_most=n)

    # Samples whose score lies within [0, 25].
    plot_samples_close_to_score(
        ood_dict,
        'BraTS T2 val',
        min_score=0,
        max_score=25,
        n=n,
        show_ground_truth=False,
        print_score=True,
    )
    # The in-distribution counterparts were previously tried by calling both
    # plotters with 'CamCAN train' and do_lesional=False.


# Slice-Wise Statistics

In [None]:
import pandas as pd
import seaborn as sns

from uncertify.evaluation.dose import full_pipeline_slice_wise_dose_scores
from uncertify.evaluation.statistics import fit_statistics, aggregate_slice_wise_statistics
from uncertify.visualization.dose import do_pair_plot_statistics

In [None]:
# Aggregate slice-wise statistics per dataset, draw them as a lower triangular
# pair-plot and add an entropy histogram for each dataset.
DOSE_STATISTICS = ['elbo', 'rec_err', 'kl_div', 'entropy']
MAX_N_BATCHES = 10

stat_dataloader_dict = {key: dataloader_dict[key] for key in ('BraTS T2 val', 'CamCAN train')}
stat_frames = {}

for name, dataloader in stat_dataloader_dict.items():
    # The CamCAN training loader is plotted without a lesional/healthy split.
    if dataloader is dataloader_dict['CamCAN train']:
        hue = None
    else:
        hue = 'is_lesional'

    # NOTE: `statistics_dict` intentionally keeps its name - after the loop it
    # still holds the last dataset's statistics, which a later cell reads.
    statistics_dict = aggregate_slice_wise_statistics(
        model,
        dataloader,
        DOSE_STATISTICS,
        max_n_batches=MAX_N_BATCHES,
    )
    stat_df = pd.DataFrame(statistics_dict)
    do_pair_plot_statistics(statistics_dict, DOSE_STATISTICS, dataloader.dataset.name, hue=hue)
    stat_frames[name] = stat_df

    # Entropy histogram on its own figure, with a fixed x-range.
    plt.figure()
    sns.histplot(stat_df, x='entropy', hue=hue, kde=True)
    ax = plt.gca()
    ax.set_xlim([0, 0.16])

In [None]:
# Reconstruction-error histogram per dataset; datasets whose name contains 'Cam'
# are drawn without the lesional/healthy split.
for name, stat_df in stat_frames.items():
    rec_err_hue = None if 'Cam' in name else 'is_lesional'
    plt.figure()
    sns.histplot(stat_df, x='rec_err', hue=rec_err_hue, kde=True)
    ax = plt.gca()
    # Optional fixed x-range: ax.set_xlim([0.7, 1.0])

In [None]:
from uncertify.utils.python_helpers import get_idx_of_closest_value

# For a sweep of reference entropy values, show the lesional and the healthy
# BraTS slice whose entropy is closest to each reference value.

ref_values = np.linspace(0.05, 0.25, 40)

brats_stats = stat_frames['BraTS T2 val']
is_lesional = brats_stats['is_lesional']
is_healthy = np.invert(is_lesional)

lesional_scans = list(brats_stats['scans'][is_lesional])
healthy_scans = list(brats_stats['scans'][is_healthy])
lesional_entropy = list(brats_stats['entropy'][is_lesional])
healthy_entropy = list(brats_stats['entropy'][is_healthy])

# (label, scans, entropies) per group; lesional is drawn before healthy.
sample_groups = (
    ('lesional', lesional_scans, lesional_entropy),
    ('healthy', healthy_scans, healthy_entropy),
)

for ref_val in ref_values:
    for label, scans, entropies in sample_groups:
        closest_idx = get_idx_of_closest_value(entropies, ref_val)

        plt.subplots(figsize=(2, 2))
        # First element of the scan entry is what gets displayed.
        plt.imshow(scans[closest_idx][0].numpy())
        plt.title(f'[{label}]\n entropy={entropies[closest_idx]:.3f}')
        plt.axis('off')
        plt.show()

In [None]:
# Entropy histogram per dataset with a wider x-range. The hue values are paired
# with stat_frames by position, so this relies on the dict's insertion order
# (first frame split by 'is_lesional', second frame unsplit).
frame_hues = ('is_lesional', None)
for stat_df, hue in zip(stat_frames.values(), frame_hues):
    plt.figure()
    sns.histplot(stat_df, x='entropy', hue=hue, kde=True)
    ax = plt.gca()
    ax.set_xlim([0, 0.26])

In [None]:
# Plot the DoSE KDE scores on a fitted training distribution
MAX_N_BATCHES = 5

# These loaders are not assigned in any visible cell, so on a fresh kernel this
# cell raised NameError. Bind them explicitly from the dataloader dict, using
# the same keys as the slice-wise statistics cell.
camcan_train_dataloader = dataloader_dict['CamCAN train']
brats_val_t2_dataloader = dataloader_dict['BraTS T2 val']

test_dataloaders = [camcan_train_dataloader, brats_val_t2_dataloader]  # optionally add a BraTS T1 val loader
has_lesional_data = [False, True]  # one flag per entry of test_dataloaders

# NOTE(review): `statistics_dict` is whatever an earlier cell assigned last
# (the final iteration of the slice-wise statistics loop) - confirm that these
# are the training statistics the KDEs should be fitted on.
kde_func_dict = fit_statistics(statistics_dict)
for dataloader, has_lesions in zip(test_dataloaders, has_lesional_data):
    dose_scores, dose_kde_dict = full_pipeline_slice_wise_dose_scores(camcan_train_dataloader, dataloader, model,
                                                                      DOSE_STATISTICS, MAX_N_BATCHES, kde_func_dict)
    do_pair_plot_statistics(dose_kde_dict, DOSE_STATISTICS, dataloader.dataset.name, 'is_lesional' if has_lesions else None)

# Plotting Final DoSE Statistics

In [None]:
# These loaders are not assigned in any visible cell, so on a fresh kernel this
# cell raised NameError. Bind them explicitly from the dataloader dict.
camcan_train_dataloader = dataloader_dict['CamCAN train']
brats_val_t2_dataloader = dataloader_dict['BraTS T2 val']

# Fitting statistics on training data
statistics_dict = aggregate_slice_wise_statistics(model, camcan_train_dataloader, DOSE_STATISTICS, max_n_batches=50)
kde_func_dict = fit_statistics(statistics_dict)

# Computing dose scores on OOD dataset
dose_scores, dose_kde_dict = full_pipeline_slice_wise_dose_scores(camcan_train_dataloader, brats_val_t2_dataloader, model,
                                                                    DOSE_STATISTICS, MAX_N_BATCHES, kde_func_dict)
final_dose_df = pd.DataFrame({'DoSE': dose_scores, 'is_lesional': dose_kde_dict['is_lesional']})
sns.kdeplot(final_dose_df.DoSE, hue=final_dose_df.is_lesional)

# Computing dose scores on ID dataset (only 3 batches).
# NOTE(review): no new figure is created, so this KDE is drawn in the same cell
# output as the OOD one above - confirm the overlay is intended.
dose_scores, dose_kde_dict = full_pipeline_slice_wise_dose_scores(camcan_train_dataloader, camcan_train_dataloader, model,
                                                                    DOSE_STATISTICS, 3, kde_func_dict)
final_dose_df = pd.DataFrame({'DoSE': dose_scores, 'is_lesional': dose_kde_dict['is_lesional']})
sns.kdeplot(final_dose_df.DoSE, hue=final_dose_df.is_lesional)

# Plot Entropy Values over Range

In [None]:
from uncertify.visualization.entropy_experiments import plot_image_and_entropy, plot_images_and_entropy
from uncertify.evaluation.inference import yield_inference_batches, residual_l1, residual_l1_max
from uncertify.visualization.entropy_experiments import plot_image_and_entropy
from uncertify.evaluation.statistics import rec_error_entropy_batch_stat

# Run inference on up to three 'BraTS T2 HM val' batches with a fixed seed and
# plot each batch's residual images together with the per-batch entropy statistic.
hm_inference_batches = yield_inference_batches(
    dataloader_dict['BraTS T2 HM val'],
    model,
    residual_fn=residual_l1_max,
    residual_threshold=0.65,
    manual_seed_val=15,
    max_batches=3,
)
for batch in hm_inference_batches:
    entropy_batch = rec_error_entropy_batch_stat(batch)
    plot_images_and_entropy(
        batch.residual.squeeze(1),
        entropy_array=entropy_batch,
        figsize=(13, 13),
    )