In [None]:
%load_ext autoreload
%autoreload 2

from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Run the project's logging setup first, then create the logger used by this notebook.
setup_logging()
LOG = logging.getLogger(__name__)

# Only let matplotlib messages of level ERROR and above through; its DEBUG
# logging is very verbose and would clutter the notebook output.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.ERROR)

In [None]:
from pathlib import Path

from torch.utils.data import DataLoader
import torchvision
from tqdm import tqdm
# Best-effort cleanup of progress-bar instances that tqdm may still be tracking
# (e.g. from an earlier, interrupted run of this cell). `_instances` is a private
# tqdm attribute, so any ordinary failure here is deliberately ignored.
# `except Exception` (instead of a bare `except`) keeps KeyboardInterrupt and
# SystemExit propagating, so the cell can still be interrupted.
try:
    tqdm._instances.clear()
except Exception:
    pass

from uncertify.io.models import load_ensemble_models, load_vae_baur_model
from uncertify.data.dataloaders import DatasetType
from uncertify.data.dataloaders import dataloader_factory
from uncertify.evaluation.evaluation_pipeline import run_evaluation_pipeline, print_results
from uncertify.evaluation.configs import EvaluationConfig, PerformanceEvaluationConfig, PixelThresholdSearchConfig
from uncertify.data.datasets import GaussianNoiseDataset
from uncertify.data.default_dataloaders import default_dataloader_dict_factory, filter_dataloader_dict, print_dataloader_dict

from uncertify.common import DATA_DIR_PATH, HD_DATA_PATH, HD_MODELS_PATH

In [None]:
# Ensemble: checkpoints named model_<idx>.ckpt, one per entry in RUN_VERSIONS,
# all located in the 'scheduled_masked_ensembles' directory under HD_MODELS_PATH.
RUN_VERSIONS = [0, 1, 2, 3, 4]
ensemble_models = load_ensemble_models(
    HD_MODELS_PATH / 'scheduled_masked_ensembles',
    [f'model_{version}.ckpt' for version in RUN_VERSIONS],
)

# Stand-alone models, each loaded from the last checkpoint of a lightning_logs run.
masked_model = load_vae_baur_model(
    Path('/mnt/2TB_internal_HD/lightning_logs/schedule_mask/version_3/checkpoints/last.ckpt')
)
non_masked_model = load_vae_baur_model(
    Path('/mnt/2TB_internal_HD/lightning_logs/schedule_mask/version_5/checkpoints/last.ckpt')
)

# `model` is the one that gets evaluated further down; alternatives are kept below for quick switching.
model = non_masked_model
#model = load_vae_baur_model(Path('/mnt/2TB_internal_HD/lightning_logs/beta_test/version_2/checkpoints/last.ckpt'))
#model = load_vae_baur_model(Path('/media/1TB_SSD/lightning_logs/camcan_beta/version_0/checkpoints/last.ckpt'))

In [None]:
# Create the default dataloader dictionary (name -> dataloader).
dataloader_dict = default_dataloader_dict_factory(batch_size=10, num_workers=12, shuffle_val=True)

# Alternative selection via filter_dataloader_dict:
# filtered_dataloader_dict = filter_dataloader_dict(dataloader_dict, contains=['val', 'BraTS T2 HM'], exclude=['Flip', 'CamCAN val', 'T1'])

# Explicit selection by name; un-comment an entry to include that dataloader.
filtered_dataloader_dict = {
    loader_name: dataloader_dict[loader_name]
    for loader_name in [
        'BraTS T2',
        # 'BraTS T2 HM',
        # 'BraTS T1',
        # 'BraTS T1 HM',
        # 'CamCAN T2 lesion',
        # 'MNIST',
        # 'Gaussian noise',
    ]
}
print_dataloader_dict(filtered_dataloader_dict)

In [None]:
# Switches handed to run_evaluation_pipeline to select which parts are run.
DO_SEGMENTATION = False
DO_ANOMALY_DETECTION = False
DO_LOSS_HISTOGRAMS = False
DO_OOD = True
DO_EXAMPLE_IMGS = False
DO_PLOTS = True

# Residual threshold value passed to the evaluation pipeline.
RESIDUAL_THRESHOLD = 0.69

# Number of batches the evaluation config is told to use.
USE_N_BATCHES = 5

# Training-set dataloader, looked up by name in the full dataloader dictionary.
TRAIN_LOADER_NAME = 'CamCAN T2'
TRAIN_DATALOADER = dataloader_dict[TRAIN_LOADER_NAME]

# Evaluation config: start from the defaults and override the fields set here.
eval_cfg = EvaluationConfig()
eval_cfg.do_plots = DO_PLOTS
eval_cfg.use_n_batches = USE_N_BATCHES
eval_cfg.ood_config.metrics = ('dose', )  # alternative: ('waic', 'dose')
eval_cfg.ood_config.dose_statistics = ('rec_err', 'kl_div', 'elbo', 'entropy')

# Run the evaluation pipeline once per selected dataloader and collect the
# outcome under the dataloader's name. The same model, training dataloader,
# config and residual threshold are used for every run.
results = {}
for val_set_name, dataloader in filtered_dataloader_dict.items():
    LOG.info(f'Running evaluation on {val_set_name}!')
    results[val_set_name] = run_evaluation_pipeline(
        model,
        TRAIN_DATALOADER,
        TRAIN_LOADER_NAME,
        dataloader,
        val_set_name,
        eval_cfg,
        RESIDUAL_THRESHOLD,
        run_segmentation=DO_SEGMENTATION,
        run_anomaly_detection=DO_ANOMALY_DETECTION,
        run_loss_histograms=DO_LOSS_HISTOGRAMS,
        run_ood_detection=DO_OOD,
        do_example_imgs=DO_EXAMPLE_IMGS,
        ensemble_models=ensemble_models,
    )

# Print the results only after all runs have finished, one section per dataloader name.
for name, result in results.items():
    print(f'\n\t{name}')
    print_results(result)