In [None]:
%load_ext autoreload
%autoreload 2

from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Project logging helper; presumably installs handlers/levels -- see uncertify.log to confirm.
setup_logging()
# Logger used for this notebook's own messages.
LOG = logging.getLogger(__name__)

# Raise the 'matplotlib' logger's threshold to ERROR so that its very verbose
# lower-level (e.g. DEBUG) records do not flood the cell output.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.ERROR)

In [None]:
from pathlib import Path

from torch.utils.data import DataLoader
import torchvision

from uncertify.models.vae import load_vae_baur_model
from uncertify.data.dataloaders import DatasetType
from uncertify.data.dataloaders import dataloader_factory
from uncertify.evaluation.evaluation_pipeline import run_evaluation_pipeline, print_results
from uncertify.evaluation.configs import EvaluationConfig, PerformanceEvaluationConfig, PixelThresholdSearchConfig
from uncertify.data.datasets import GaussianNoiseDataset
from uncertify.common import DATA_DIR_PATH

In [None]:
# High-level parameters: run sizing first, then the machine-specific locations.
# BATCH_SIZE is used for every dataloader; USE_N_BATCHES is copied into the
# evaluation config (eval_cfg.use_n_batches).
BATCH_SIZE = 155
USE_N_BATCHES = 100

# Checkpoint file handed to load_vae_baur_model().
CHECKPOINT_PATH = Path('/media/juniors/2TB_internal_HD/lightning_logs/train_vae/version_3/checkpoints/last.ckpt')

# Two directories holding processed datasets; PROCESSED_DIR_PATH selects the
# one that is currently in use.
HDD_PROCESSED_DIR_PATH = Path('/media/juniors/2TB_internal_HD/datasets/processed/')
SSD_PROCESSED_DIR_PATH = Path('/home/juniors/code/uncertify/data/processed/')
PROCESSED_DIR_PATH = SSD_PROCESSED_DIR_PATH

In [None]:
# Load the trained VAE and build the evaluation config shared by all runs below.
model = load_vae_baur_model(CHECKPOINT_PATH)
eval_cfg = EvaluationConfig()
eval_cfg.do_plots = False
eval_cfg.use_n_batches = USE_N_BATCHES

# The BraTS files are read from the HDD copy. Reuse the constant from the config
# cell instead of re-typing the path, so there is a single place to edit it.
# NOTE: this deliberately rebinds PROCESSED_DIR_PATH (the config cell selects the SSD copy).
PROCESSED_DIR_PATH = HDD_PROCESSED_DIR_PATH

brats_t2_path    = PROCESSED_DIR_PATH / 'brats17_t2_bc_std_bv3.5_l10.hdf5'
brats_t2_hm_path = PROCESSED_DIR_PATH / 'brats17_t2_hm_bc_std_bv-3.5.hdf5'
brats_t1_path    = PROCESSED_DIR_PATH / 'brats17_t1_bc_std_bv3.5_l10.hdf5'
brats_t1_hm_path = PROCESSED_DIR_PATH / 'brats17_t1_hm_bc_std_bv-3.5.hdf5'
camcan_t2_val_path   = DATA_DIR_PATH  / 'processed/camcan_val_t2_hm_std_bv3.5_xe.hdf5'
camcan_t2_train_path = DATA_DIR_PATH  / 'processed/camcan_train_t2_hm_std_bv3.5_xe.hdf5'


def _load_brats_val(val_set_path):
    """Return the BraTS17 validation dataloader for one HDF5 file, requested with shuffle_val=False."""
    _, val_dataloader = dataloader_factory(DatasetType.BRATS17, batch_size=BATCH_SIZE,
                                           val_set_path=val_set_path, shuffle_val=False)
    return val_dataloader


brats_val_t2_dataloader    = _load_brats_val(brats_t2_path)
brats_val_t1_dataloader    = _load_brats_val(brats_t1_path)
brats_val_t2_hm_dataloader = _load_brats_val(brats_t2_hm_path)
brats_val_t1_hm_dataloader = _load_brats_val(brats_t1_hm_path)

camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(DatasetType.CAMCAN,
                                                                    batch_size=BATCH_SIZE,
                                                                    val_set_path=camcan_t2_val_path,
                                                                    train_set_path=camcan_t2_train_path,
                                                                    shuffle_val=False,
                                                                    shuffle_train=True)

# Synthetic Gaussian-noise loader, built directly with the same batch size.
noise_set = GaussianNoiseDataset()
noise_loader = DataLoader(noise_set, batch_size=BATCH_SIZE)

# MNIST images are resized to 128x128 and converted to tensors.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.ToTensor(),
])
_, mnist_val_dataloader = dataloader_factory(DatasetType.MNIST, batch_size=BATCH_SIZE, transform=mnist_transform)

# Print a short size summary for every loader.
for name, dataloader in [('BraTS T2 val', brats_val_t2_dataloader),
                         ('BraTS T1 val', brats_val_t1_dataloader),
                         ('BraTS T2 HM val', brats_val_t2_hm_dataloader),
                         ('BraTS T1 HM val', brats_val_t1_hm_dataloader),
                         ('CamCAN train', camcan_train_dataloader),
                         ('Gaussian noise', noise_loader),
                         ('MNIST', mnist_val_dataloader)
                        ]:
    # n_batches * batch_size over-counts whenever the last batch is only partially
    # filled, so report the dataset size unless incomplete batches are dropped.
    # Assumes each loader iterates its whole dataset (no subset sampler) -- TODO confirm
    # against dataloader_factory.
    if dataloader.drop_last:
        n_samples = len(dataloader) * dataloader.batch_size
    else:
        n_samples = len(dataloader.dataset)
    print(f'{name:15} dataloader: {len(dataloader)} batches (batch_size: {dataloader.batch_size}) -> {n_samples} samples.')

In [None]:
# Switches selecting which stages of the evaluation pipeline are run.
DO_SEGMENTATION      = True
DO_ANOMALY_DETECTION = True
DO_HISTOGRAMS        = False
DO_OOD               = False
# Threshold forwarded positionally to run_evaluation_pipeline; presumably the
# residual cut-off used by the pipeline -- confirm against its signature.
RESIDUAL_THRESHOLD   = 1.35

# Validation sets to evaluate, in run order, keyed by the label used for logging and reporting.
evaluation_dataloaders = {
    'BraTS T2 HM': brats_val_t2_hm_dataloader,
    'BraTS T1 HM': brats_val_t1_hm_dataloader,
    'BraTS T2': brats_val_t2_dataloader,
    'BraTS T1': brats_val_t1_dataloader,
}

# Run the pipeline once per dataset; the CamCAN training loader is passed to every run.
results = {}
for name, dataloader in evaluation_dataloaders.items():
    LOG.info(f'Running evaluation on {name}!')
    results[name] = run_evaluation_pipeline(model,
                                            camcan_train_dataloader,
                                            dataloader,
                                            eval_cfg,
                                            RESIDUAL_THRESHOLD,
                                            run_segmentation=DO_SEGMENTATION,
                                            run_anomaly_detection=DO_ANOMALY_DETECTION,
                                            run_histograms=DO_HISTOGRAMS,
                                            run_ood_detection=DO_OOD)

# Report the collected results, one labelled section per dataset.
for name, result in results.items():
    print(f'\n\t{name}')
    print_results(result)