In [None]:
%load_ext autoreload
%autoreload 2

from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Initialise logging through the project's setup_logging helper, then create the
# logger used by this notebook.
setup_logging()
LOG = logging.getLogger(__name__)

# Matplotlib is extremely verbose at DEBUG level, so only let WARNING and above through.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)

In [None]:
from itertools import islice

import h5py
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

from uncertify.common import DATA_DIR_PATH

In [None]:
# Per-dataset sub-directories below the project's data directory (DATA_DIR_PATH).
BRATS_DATA_DIR = DATA_DIR_PATH / 'brats'
CAMCAN_DATA_DIR = DATA_DIR_PATH / 'camcan'

# Explore dataset
An HDF5 file consists of two major types of objects: datasets and groups.

- Datasets: multidimensional arrays of a homogeneous type.
- Groups: hierarchical, file-system-like containers that hold datasets and other groups.

In [None]:
# Locations of the HDF5 files explored in this notebook.
brats_val_path = BRATS_DATA_DIR / 'brats_all_val.hdf5'
camcan_val_path = CAMCAN_DATA_DIR / 'camcan_t2_val_set.hdf5'
camcan_train_path = CAMCAN_DATA_DIR / 'camcan_t2_train_set.hdf5'

# Report up front which of the files are actually present on disk.
for path in (brats_val_path, camcan_val_path, camcan_train_path):
    print(f'{path} does{" " if path.exists() else " not "}exist!')

In [None]:
def print_datasets_info(h5py_file: h5py.File) -> None:
    """Print the repr of every top-level member of an HDF5 file, one per line.

    Args:
        h5py_file: open HDF5 file whose top-level members are listed
    """
    for member in h5py_file.values():
        print(member)

# Print the top-level contents of every HDF5 file. The context manager closes each
# file handle again instead of leaving it open for the lifetime of the kernel.
for path in [brats_val_path, camcan_val_path, camcan_train_path]:
    name = path.name
    with h5py.File(path, 'r') as h5py_file:
        print(f'\n --- {name} ---')
        print_datasets_info(h5py_file)

In [None]:
def plot_samples(h5py_file: h5py.File, n_samples: int = 3, dataset_length: int = 4000, cmap: str = 'Greys_r',
                 image_shape: tuple = (200, 200)) -> None:
    """Plot randomly chosen entries of every dataset in an HDF5 file side by side.

    One figure is produced per sampled index. Each figure has one column per top-level key
    of the file (sorted alphabetically) showing that dataset's entry at the sampled index.
    Summary statistics of every plotted entry are printed, prefixed with the dataset name.

    Args:
        h5py_file: open HDF5 file whose top-level members are indexable datasets
        n_samples: number of indices to draw (with replacement, so duplicates are possible)
        dataset_length: exclusive upper bound of the sampled indices; capped at the length of
            the shortest dataset so that every sampled index is valid for every column
        cmap: matplotlib colormap name passed to imshow
        image_shape: 2D shape each entry is reshaped to before plotting

    Raises:
        ValueError: if the file has no top-level keys
    """
    keys = sorted(h5py_file.keys())
    if not keys:
        raise ValueError('HDF5 file contains no datasets to plot.')
    print(f'Columns: {"-".join(keys)}')

    # Never sample beyond the end of the shortest dataset, whatever dataset_length says.
    max_index = min(dataset_length, *(len(h5py_file[key]) for key in keys))
    sample_indices = np.random.choice(max_index, n_samples)

    for idx in sample_indices:
        # squeeze=False keeps `axes` a 2D array even for a single column, so it stays iterable.
        fig, axes = plt.subplots(ncols=len(keys), nrows=1, figsize=(10, 10), squeeze=False)
        for key, ax in zip(keys, axes[0]):
            sample = h5py_file[key][idx]  # read from disk once, reuse for plot and statistics
            ax.imshow(np.reshape(sample, image_shape), cmap=cmap)
            ax.axis('off')
            print(f'{key}: {stats.describe(sample)}')
        plt.tight_layout()
        plt.show()
        # Release the figure explicitly so that many samples do not accumulate in memory.
        plt.close(fig)

n_samples = 3
# Plot a few random samples from each validation set. Each file is opened in a context
# manager so the handle is closed again once its figures have been drawn.
for hdf5_path in (brats_val_path, camcan_val_path):
    with h5py.File(hdf5_path, 'r') as hdf5_file:
        plot_samples(hdf5_file, n_samples=n_samples, cmap='hot')

# Explore Dataset using Dataset and Dataloader in PyTorch

In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision

from uncertify.data.transforms import Flat2ImgTransform
from uncertify.data.datasets import Brats2017HDF5Dataset
from uncertify.visualization.grid import imshow_grid

In [None]:
# Per-sample preprocessing: Flat2ImgTransform is configured with a 200x200 target shape,
# followed by torchvision's ToTensor conversion.
transform = torchvision.transforms.Compose([
    Flat2ImgTransform(new_shape=(200, 200)),
    torchvision.transforms.ToTensor(),
])

# BraTS validation set served in shuffled mini-batches of four samples.
brats_val_dataset = Brats2017HDF5Dataset(
    hdf5_file_path=brats_val_path,
    transform=transform,
)
brats_val_dataloader = DataLoader(
    brats_val_dataset,
    batch_size=4,
    shuffle=True,
)

In [None]:
# Visualise the first three batches: scan and segmentation are cast to float and joined
# along dim 2 before being tiled into one image grid per batch.
# NOTE(review): dim 2 is presumably the image height of a (batch, channel, height, width)
# tensor, which would stack each scan above its segmentation - confirm against the dataset.
for sample in islice(brats_val_dataloader, 3):
    scans = sample['scan'].type(torch.FloatTensor)
    segmentations = sample['seg'].type(torch.FloatTensor)
    stacked = torch.cat((scans, segmentations), dim=2)
    grid = make_grid(stacked)
    imshow_grid(grid, one_channel=True, plt_show=True, cmap='hot', figsize=(9, 8), axis='off')