In [None]:
# Automatically re-import edited project modules before each cell executes,
# so changes to the `uncertify` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# NOTE(review): presumably `context` puts the project root on sys.path so that
# `uncertify` (and the `uncertify.*` imports below) resolve — TODO confirm.
from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
setup_logging()
LOG = logging.getLogger(__name__)

# Matplotlib emits a large volume of DEBUG messages; raise its threshold to
# WARNING so the notebook's own log output stays readable.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)

# Best-effort cleanup of tqdm progress bars still registered from a previous
# run of the notebook. `tqdm` must be imported here: previously it was never
# imported, so this line always raised NameError, which a bare `except:`
# silently swallowed, and the cleanup never actually ran.
# `_instances` is a private tqdm attribute, hence the defensive handling.
try:
    from tqdm import tqdm
    tqdm._instances.clear()
except (ImportError, AttributeError):
    # tqdm is not installed, or this tqdm version has no `_instances` registry
    # (yet); either way there are no stale bars to clear.
    pass

In [None]:
import torch
import torchvision

from uncertify.io.models import load_ensemble_models
from uncertify.common import HD_MODELS_PATH
from uncertify.evaluation.ensembles import yield_inference_batches
from uncertify.data.default_dataloaders import default_dataloader_dict_factory
from uncertify.evaluation.ensembles import infer_ensembles, combine_ensemble_results, visualize_ensemble_predictions
from uncertify.visualization.reconstruction import plot_stacked_scan_reconstruction_batches
from uncertify.visualization.grid import imshow_grid

In [None]:
# Ensemble checkpoints are stored as model_0.ckpt ... model_{N-1}.ckpt
# under HD_MODELS_PATH / 'scheduled_masked_ensembles'.
N_ENSEMBLE_MODELS = 5
BATCH_SIZE = 8

model_dir_path = HD_MODELS_PATH / 'scheduled_masked_ensembles'
file_names = [f'model_{idx}.ckpt' for idx in range(N_ENSEMBLE_MODELS)]

ensembles = load_ensemble_models(dir_path=model_dir_path, file_names=file_names)
# Fail fast instead of silently running a smaller ensemble than intended.
assert len(ensembles) == len(file_names), \
    f'Expected {len(file_names)} ensemble models, but loaded {len(ensembles)}.'

dataloader_dict = default_dataloader_dict_factory(batch_size=BATCH_SIZE, shuffle_val=True, num_workers=1)

print(f'Loaded {len(ensembles)} ensemble models.')
# list(...) gives a readable key list instead of a `dict_keys([...])` repr.
print(f'Loaded {len(dataloader_dict)} dataloaders: {list(dataloader_dict.keys())}')

In [None]:
# Run every ensemble member on a slice of the BraTS T2 validation set and
# visualize the combined predictions.
SEED = 0
EVAL_DATASET = 'BraTS T2 val'
N_EVAL_BATCHES = 10        # keep inference cheap; results are only visualized here
RESIDUAL_THRESHOLD = 0.65  # passed through to infer_ensembles — TODO document how it was chosen

# The validation dataloader shuffles (shuffle_val=True). Without an explicit
# generator, PyTorch's DataLoader/RandomSampler derive their randomness from the
# global torch RNG, so seeding here makes the visualized batches (and any other
# torch RNG use during inference) reproducible when the cell is re-run.
torch.manual_seed(SEED)

dataloader = dataloader_dict[EVAL_DATASET]
# NOTE(review): because the loader shuffles, confirm infer_ensembles feeds the
# *same* batches to every member; separate passes over the loader would each
# be reshuffled — TODO verify.
model_result_generators = infer_ensembles(ensembles, dataloader,
                                          use_n_batches=N_EVAL_BATCHES,
                                          residual_threshold=RESIDUAL_THRESHOLD)

ensemble_results = combine_ensemble_results(model_result_generators)
visualize_ensemble_predictions(ensemble_results, figsize=(12, 12), cmap='viridis', axis='off')