In [None]:
# Re-import edited modules automatically before executing each cell, so changes
# to the uncertify package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# NOTE(review): `context` is presumably a local helper module that makes the
# `uncertify` package importable from this notebook directory (e.g. by
# extending sys.path) — confirm.
from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging

# Install the project's logging configuration first, then fetch this notebook's logger.
setup_logging()
LOG = logging.getLogger(__name__)

# matplotlib is extremely chatty at DEBUG level; only let warnings and errors through.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(level=logging.WARNING)

In [None]:
from itertools import islice

import h5py
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

from uncertify.common import DATA_DIR_PATH

In [None]:
# Root folders that hold the BraTS and CamCAN HDF5 files.
BRATS_DATA_DIR, CAMCAN_DATA_DIR = DATA_DIR_PATH / 'brats', DATA_DIR_PATH / 'camcan'

# Explore dataset
An HDF5 file consists of two major types of objects: datasets and groups.

Datasets: multidimensional arrays of homogeneous types.
Groups: Hierarchical structures (file system-like).

In [None]:
brats_val_path = BRATS_DATA_DIR / 'brats_all_val.hdf5'
camcan_val_path = CAMCAN_DATA_DIR / 'camcan_t2_val_set.hdf5'
camcan_train_path = CAMCAN_DATA_DIR / 'camcan_t2_train_set.hdf5'

# Sanity-check that each HDF5 file is present before anything tries to open it.
for path in (brats_val_path, camcan_val_path, camcan_train_path):
    status = 'does exist' if path.exists() else 'does not exist'
    print(f'{path} {status}!')

In [None]:
def print_datasets_info(h5py_file: h5py.File) -> None:
    """Print every top-level member (dataset or group) of an open HDF5 file.

    Args:
        h5py_file: An open ``h5py.File`` whose direct children are printed.
    """
    # Only the member objects are printed; their names are not needed separately.
    for dataset in h5py_file.values():
        print(dataset)


for path in [brats_val_path, camcan_val_path, camcan_train_path]:
    # The context manager closes the file handle once inspection is done.
    with h5py.File(path, 'r') as h5py_file:
        print(f'\n --- {path.name} ---')
        print_datasets_info(h5py_file)

In [None]:
from uncertify.visualization.datasets import plot_samples

n_samples = 1
# Open each file in a context manager so its handle is released after plotting
# (the original opened both files inline and never closed them).
for hdf5_path in (brats_val_path, camcan_val_path):
    with h5py.File(hdf5_path, 'r') as hdf5_file:
        plot_samples(hdf5_file, n_samples=n_samples, cmap='hot')

# Explore the data using PyTorch `Dataset` and `DataLoader`

In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision

from uncertify.data.datasets import Brats2017HDF5Dataset, CamCanHDF5Dataset
from uncertify.data.dataloaders import dataloader_factory, DatasetType
from uncertify.visualization.datasets import plot_brats_batches, plot_camcan_batches
from uncertify.visualization.datasets import plot_fraction_of_abnormal_pixels

In [None]:
batch_size = 4
do_shuffle = True
# Both loaders share the same batching / shuffling settings.
loader_kwargs = dict(batch_size=batch_size, shuffle=do_shuffle)

brats_val_dataset = Brats2017HDF5Dataset(hdf5_file_path=brats_val_path)
brats_val_dataloader = DataLoader(brats_val_dataset, **loader_kwargs)

camcan_train_dataset = CamCanHDF5Dataset(hdf5_file_path=camcan_train_path)
camcan_train_dataloader = DataLoader(camcan_train_dataset, **loader_kwargs)

## Using the dataset factory

In [None]:
# Rebuild the loaders via the project factory (overwriting the manually built ones),
# reusing `batch_size` from the manual setup cell instead of repeating a literal.
_, brats_val_dataloader = dataloader_factory(DatasetType.BRATS17, batch_size=batch_size, shuffle_val=True)
camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(DatasetType.CAMCAN, batch_size=batch_size)

plot_n_batches = 1
plot_brats_batches(brats_val_dataloader, plot_n_batches, cmap='hot')
plot_camcan_batches(camcan_train_dataloader, plot_n_batches)
plot_camcan_batches(camcan_val_dataloader, plot_n_batches)

# Analyze normal / abnormal pixel distribution

In [None]:
from uncertify.evaluation.datasets import get_samples_without_lesions
from uncertify.visualization.datasets import plot_fraction_of_abnormal_pixels
from uncertify.visualization.datasets import plot_abnormal_pixel_distribution
from uncertify.visualization.datasets import boxplot_abnormal_pixel_fraction

In [None]:
plots_dir = DATA_DIR_PATH / 'plots'
# savefig does not create missing directories, so make sure the target exists.
plots_dir.mkdir(parents=True, exist_ok=True)

fig, _ = plot_abnormal_pixel_distribution(brats_val_dataloader, figsize=(12, 5),
                                          hist_kwargs=dict(bins=30, density=True))
fig.savefig(plots_dir / 'normal_abnormal_n_pixel_distribution.png')

In [None]:
plots_dir = DATA_DIR_PATH / 'plots'
# savefig does not create missing directories, so make sure the target exists.
plots_dir.mkdir(parents=True, exist_ok=True)

fig, _ = plot_fraction_of_abnormal_pixels(brats_val_dataloader, figsize=(12, 5),
                                          hist_kwargs=dict(bins=30, density=True))
fig.savefig(plots_dir / 'abnormal_pixel_fraction.png')

In [None]:
plots_dir = DATA_DIR_PATH / 'plots'
# savefig does not create missing directories, so make sure the target exists.
plots_dir.mkdir(parents=True, exist_ok=True)

# The axes handle is not needed here; only the figure is saved.
fig, _ = boxplot_abnormal_pixel_fraction(data_loader=brats_val_dataloader, figsize=(2.5, 5))
fig.savefig(plots_dir / 'boxplot_abnormal_pixel_fraction.png')

In [None]:
n_samples_without_lesions, n_samples_total = get_samples_without_lesions(brats_val_dataloader)
# The first value counts lesion-FREE samples (per the function and variable names);
# the previous message reported it as the number of samples WITH lesional pixels.
n_samples_with_lesions = n_samples_total - n_samples_without_lesions
print(f'{n_samples_without_lesions} / {n_samples_total} samples have no lesional pixels.')
print(f'{n_samples_with_lesions} / {n_samples_total} samples have lesional pixels.')