In [None]:
# Load IPython's autoreload extension. Mode 2 reloads all imported modules before each
# cell execution, so edits to the package source are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# NOTE(review): `context` is presumably a local helper module that makes the `uncertify`
# package importable from this notebook's directory — confirm.
from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Run the project's logging setup first, then create the logger used by this notebook.
setup_logging()
LOG = logging.getLogger(__name__)

# Matplotlib emits a large volume of DEBUG records; raise its logger to WARNING so
# those records do not flood the cell output.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)

In [None]:
import os
from pathlib import Path

from torch.utils.data import DataLoader
import torchvision

from uncertify.models.vae import load_vae_baur_model
from uncertify.data.dataloaders import DatasetType
from uncertify.data.dataloaders import dataloader_factory
from uncertify.evaluation.evaluation_pipeline import run_evaluation_pipeline
from uncertify.evaluation.configs import EvaluationConfig, PerformanceEvaluationConfig, PixelThresholdSearchConfig
from uncertify.data.datasets import GaussianNoiseDataset
from uncertify.common import DATA_DIR_PATH

In [None]:
# High-level parameters for this evaluation run.
#
# Both locations default to the paths on the original author's workstation. Either one can
# be redirected by setting the named environment variable before this cell runs, so the
# notebook can be executed on another machine without editing the cell.
CHECKPOINT_PATH = Path(os.environ.get(
    'UNCERTIFY_CHECKPOINT_PATH',
    '/media/juniors/2TB_internal_HD/lightning_logs/train_vae/version_2/checkpoints/last.ckpt',
))
HDD_PROCESSED_DIR_PATH = Path(os.environ.get(
    'UNCERTIFY_HDD_PROCESSED_DIR',
    '/media/juniors/2TB_internal_HD/datasets/processed/',
))

# Batch size passed to every dataloader created further down in this cell.
BATCH_SIZE = 155

# Load the model from the checkpoint file.
model = load_vae_baur_model(CHECKPOINT_PATH)

# Evaluation settings, one nested config per stage.
# NOTE(review): accepted_fpr is presumably the false-positive rate tolerated by the pixel
# threshold search, and use_n_batches=None presumably means "use every batch" — confirm.
eval_cfg = EvaluationConfig(
    thresh_search_config=PixelThresholdSearchConfig(accepted_fpr=0.05),
    performance_config=PerformanceEvaluationConfig(use_n_batches=None),
)

# BraTS17 validation loaders for the T2 and T1 files. No train_set_path is given, and the
# first element of the pair returned by dataloader_factory is discarded.
_, brats_val_t2_dataloader = dataloader_factory(
    DatasetType.BRATS17,
    batch_size=BATCH_SIZE,
    val_set_path=HDD_PROCESSED_DIR_PATH / 'brats17_t2_hm_bc_std_bv-3.5.hdf5',
    shuffle_val=False,
)
_, brats_val_t1_dataloader = dataloader_factory(
    DatasetType.BRATS17,
    batch_size=BATCH_SIZE,
    val_set_path=HDD_PROCESSED_DIR_PATH / 'brats17_t1_hm_bc_std_bv-3.5.hdf5',
    shuffle_val=False,
)

# CamCAN train and validation loaders; only the training loader is shuffled.
camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(
    DatasetType.CAMCAN,
    batch_size=BATCH_SIZE,
    val_set_path=DATA_DIR_PATH / 'processed/camcan_val_t2_hm_std_bv3.5_xe.hdf5',
    train_set_path=DATA_DIR_PATH / 'processed/camcan_train_t2_hm_std_bv3.5_xe.hdf5',
    shuffle_val=False,
    shuffle_train=True,
)
# Gaussian-noise dataset built with its default arguments, wrapped in a plain DataLoader.
noise_set = GaussianNoiseDataset()
noise_loader = DataLoader(noise_set, batch_size=BATCH_SIZE)

# MNIST validation loader. Each image is resized to 128x128 and then converted to a tensor.
# NOTE(review): 128x128 presumably matches the input size the model expects — confirm.
_, mnist_val_dataloader = dataloader_factory(
    DatasetType.MNIST,
    batch_size=BATCH_SIZE,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize((128, 128)),
        torchvision.transforms.ToTensor(),
    ]),
)

In [None]:
# Run the evaluation pipeline once per target dataloader. The model, the CamCAN training
# loader and the evaluation config are the same on every call.
for dataloader in (
    brats_val_t2_dataloader,
    brats_val_t1_dataloader,
    noise_loader,
    mnist_val_dataloader,
):
    run_evaluation_pipeline(model, camcan_train_dataloader, dataloader, eval_cfg)