In [None]:
%load_ext autoreload
%autoreload 2

from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
setup_logging()
LOG = logging.getLogger(name=__name__)

# matplotlib emits a lot of DEBUG/INFO records once logging is configured;
# raise its logger threshold so only WARNING and above get through.
mpl_logger = logging.getLogger(name='matplotlib')
mpl_logger.setLevel(level=logging.WARNING)

In [None]:
import torch

from uncertify.data.dataloaders import dataloader_factory, DatasetType
from uncertify.io.models import load_ensemble_models
from uncertify.evaluation.inference import yield_inference_batches
from uncertify.visualization.reconstruction import plot_stacked_scan_reconstruction_batches
from uncertify.common import DATA_DIR_PATH, HD_DATA_PATH

In [None]:
# Load one checkpoint per run version as an ensemble; the first member is the
# model used for inference further down.
RUN_VERSIONS = [1, 2, 3, 4, 5]
ensemble_models = load_ensemble_models(
    DATA_DIR_PATH / 'masked_ensemble_models',
    list(map('model{}.ckpt'.format, RUN_VERSIONS)),
)
model = ensemble_models[0]

In [None]:
# --- Dataloader configuration ---
batch_size    = 12
USE_N_BATCHES = 10  # NOTE(review): not referenced in the visible cells
NUM_WORKERS   = 0
SHUFFLE_VAL   = True

brats_t2_path        = HD_DATA_PATH / 'processed/brats17_t2_bc_std_bv3.5.hdf5'
brats_t2_hm_path     = HD_DATA_PATH / 'processed/brats17_t2_hm_bc_std_bv3.5.hdf5'
camcan_t2_val_path   = DATA_DIR_PATH / 'processed/camcan_val_t2_hm_std_bv3.5_xe.hdf5'
camcan_t2_train_path = DATA_DIR_PATH / 'processed/camcan_train_t2_hm_std_bv3.5_xe.hdf5'

# For BraTS only the second (validation) loader returned by the factory is kept.
_, brats_t2_dataloader = dataloader_factory(DatasetType.BRATS17, batch_size=batch_size,
                                            val_set_path=brats_t2_path, shuffle_val=SHUFFLE_VAL,
                                            num_workers=NUM_WORKERS)
_, brats_val_t2_hm_dataloader = dataloader_factory(DatasetType.BRATS17, batch_size=batch_size,
                                                   val_set_path=brats_t2_hm_path, shuffle_val=SHUFFLE_VAL,
                                                   num_workers=NUM_WORKERS)
camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(DatasetType.CAMCAN, batch_size=batch_size,
                                                                    val_set_path=camcan_t2_val_path,
                                                                    train_set_path=camcan_t2_train_path,
                                                                    shuffle_val=SHUFFLE_VAL, shuffle_train=True,
                                                                    num_workers=NUM_WORKERS)

# Summarise every loader built above (the original listed only one of the four).
dataloader_dict = {
    'BraTS T2 val': brats_t2_dataloader,
    'BraTS T2 HM val': brats_val_t2_hm_dataloader,
    'CamCAN T2 train': camcan_train_dataloader,
    'CamCAN T2 val': camcan_val_dataloader,
}

name_width = max(map(len, dataloader_dict))
for name, dataloader in dataloader_dict.items():
    n_batches = len(dataloader)
    # Unless drop_last is set, the final batch may be partial, so n_batches * batch_size
    # over-counts; the sampler length is the number of samples actually yielded per epoch.
    n_samples = n_batches * dataloader.batch_size if dataloader.drop_last else len(dataloader.sampler)
    print(f'{name:{name_width}} dataloader: {n_batches} batches (batch_size: {dataloader.batch_size}) -> {n_samples} samples.')

In [None]:
# --- Plot and save reconstructions for a few batches of each dataset ---
plot_n_batches = 3
RESIDUAL_THRESHOLD = 0.95  # forwarded to yield_inference_batches (semantics defined there)
GRID_NROW = 16             # forwarded as `nrow` to the plotting helper
RECONSTRUCTION_SEED = 0    # fixes which shuffled batches are drawn, so saved figures are reproducible

reconstructions_dir = DATA_DIR_PATH / 'reconstructions'
# Create the output directory up front so saving cannot fail on a fresh checkout (no-op if it exists).
reconstructions_dir.mkdir(parents=True, exist_ok=True)

dataloaders = [brats_val_t2_hm_dataloader, brats_t2_dataloader, camcan_train_dataloader]

for dataloader in dataloaders:
    print(f'Dataset: {dataloader.dataset.name}')
    # Shuffled torch DataLoaders draw their permutation from torch's global RNG when iteration
    # starts; re-seeding per dataset makes the plotted batches (and any torch-RNG use during
    # inference) identical across runs and independent of the loader order above.
    torch.manual_seed(RECONSTRUCTION_SEED)
    batch_generator = yield_inference_batches(dataloader, model, residual_threshold=RESIDUAL_THRESHOLD)
    plot_stacked_scan_reconstruction_batches(batch_generator, plot_n_batches, nrow=GRID_NROW, show_mask=False,
                                             cmap='gray', axis='off', figsize=(20, 15),
                                             save_dir_path=reconstructions_dir)