In [None]:
%load_ext autoreload
%autoreload 2
    
from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Configure logging through the project's helper and create this notebook's logger.
setup_logging()
LOG = logging.getLogger(__name__)

# Matplotlib emits a large volume of DEBUG messages; raise its logger to
# WARNING so that they do not clutter the notebook output.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)

In [None]:
from itertools import islice
from pathlib import Path

import h5py
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

from uncertify.common import DATA_DIR_PATH

from typing import List

In [None]:
# Dataset directories, resolved relative to the project's DATA_DIR_PATH.
# NOTE(review): a cluster-local alternative, Path('/scratch/maheer/datasets/processed/'),
# was previously noted next to the BraTS directory; confirm whether it is still needed.
BRATS_DATA_DIR = DATA_DIR_PATH / 'brats'
CAMCAN_DATA_DIR = DATA_DIR_PATH / 'camcan'

# Explore HDF5 dataset
**As created by scripts/preprocess_brats.py**  
An HDF5 file consists of two major types of objects: Datasets and Groups.

- **Datasets**: multidimensional arrays of a homogeneous type.
- **Groups**: hierarchical structures (file-system-like).

In [None]:
# HDF5 file that is inspected below.
brats_val_path = BRATS_DATA_DIR / 'brats17_t1_hm_bc_scale_l4.hdf5'
# Further dataset files, currently disabled; uncomment to inspect them as well.
#brats_t1_path = BRATS_DATA_DIR / 'brats_t1_no_hm_unbiased.hdf5'
#camcan_val_path = CAMCAN_DATA_DIR / 'camcan_t2_val_set.hdf5'
#camcan_train_path = CAMCAN_DATA_DIR / 'camcan_t2_train_set.hdf5'

def print_dataset_information(dataset_paths: List[Path]) -> 'h5py.File':
    """Print the top-level objects and the metadata of one or more HDF5 files.

    First reports for every path whether it exists, then opens each file
    read-only and prints its top-level objects followed by its file-level
    attributes.

    Args:
        dataset_paths: paths of the HDF5 files to inspect; must not be empty.

    Returns:
        The h5py file object of the last path, still open. Closing it is the
        caller's responsibility. All files opened before it are closed here.

    Raises:
        ValueError: if ``dataset_paths`` is empty.
    """
    dataset_paths = list(dataset_paths)
    if not dataset_paths:
        raise ValueError('dataset_paths must contain at least one path.')

    for path in dataset_paths:
        print(f'{path} does{" not " if not path.exists() else " "}exist!')

    h5py_file = None
    for path in dataset_paths:
        # Only the last file is handed back to the caller, so release the
        # previously inspected one before opening the next.
        if h5py_file is not None:
            h5py_file.close()
        h5py_file = h5py.File(path, 'r')
        print(f'\n --- {path.name} ---')
        for _dataset_name, dataset in h5py_file.items():
            print(dataset)
        # Metadata is printed per file, not only for the last one.
        print('Metadata:')
        for key, val in h5py_file.attrs.items():
            print(f'\t{key:30}: {val}')
    return h5py_file

# Inspect the BraTS file. The returned object is the h5py file that was opened last; it stays open.
h5py_file = print_dataset_information(dataset_paths=[brats_val_path]) # , camcan_val_path, camcan_train_path, brats_t1_path])

In [None]:
from uncertify.visualization.datasets import plot_samples

n_samples = 50
# Open the file in a context manager so that the handle is released once plotting is done.
# NOTE(review): dataset_length is hard-coded; presumably it has to match the number of
# samples stored in this particular file — confirm.
with h5py.File(brats_val_path, 'r') as brats_val_file:
    plot_samples(brats_val_file, n_samples=n_samples, cmap='Greys_r', dataset_length=464, vmin=0, vmax=1)
#plot_samples(h5py.File(brats_t1_path, 'r'), n_samples=n_samples, dataset_length=310, cmap='Greys')
#plot_samples(h5py.File(camcan_val_path, 'r'), n_samples=n_samples, cmap='Greys')

# Explore Dataset using Dataset and Dataloader in PyTorch

In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision

from uncertify.data.datasets import Brats2017HDF5Dataset, CamCanHDF5Dataset
from uncertify.data.dataloaders import dataloader_factory, DatasetType
from uncertify.visualization.datasets import plot_brats_batches, plot_camcan_batches
from uncertify.visualization.datasets import plot_fraction_of_abnormal_pixels

## Using the dataloader factory

In [None]:
# NOTE(review): a batch size of 155 presumably corresponds to one volume per batch — confirm.
batch_size = 155
brats_val_path = BRATS_DATA_DIR / 'brats17_t1_hm_bc_scale_l4.hdf5'
_, brats_val_dataloader = dataloader_factory(DatasetType.BRATS17, batch_size=batch_size, path=brats_val_path, shuffle_val=False)

# Histogram of the scan values at all positions where the mask is non-zero,
# one figure per batch, for at most `stop` batches.
stop = 10
for idx, batch in enumerate(islice(brats_val_dataloader, stop)):
    mask = batch['mask'].cpu().detach().numpy()
    scan = batch['scan'].cpu().detach().numpy()
    plt.hist(scan[mask != 0].flatten(), bins=30)
    plt.title(f'Scan values where mask != 0 (batch {idx})')
    plt.show()

In [None]:
batch_size = 155
brats_val_path = BRATS_DATA_DIR / 'brats17_t1_hm_bc_scale_l4.hdf5'
# Validation dataloader without shuffling, so batches are visited in a fixed order.
_, brats_val_dataloader = dataloader_factory(DatasetType.BRATS17, batch_size=batch_size, path=brats_val_path, shuffle_val=False)
# Alternative dataloaders, currently disabled.
#_, brats_val_dataloader = dataloader_factory(DatasetType.BRATS17, batch_size=batch_size, path=DATA_DIR_PATH / 'brats/brats_all_val.hdf5', shuffle_val=True)
#_, brats_t1_dataloader = dataloader_factory(DatasetType.BRATS17, batch_size=batch_size, path=DATA_DIR_PATH / 'brats/brats_t1_no_hm_unbiased.hdf5', shuffle_val=True)
#camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(DatasetType.CAMCAN, batch_size=batch_size, shuffle_val=True)
plot_n_batches = 100
plot_brats_batches(brats_val_dataloader, plot_n_batches, cmap='hot', figsize=(18, 12), nrow=16, vmin=0)
#plot_brats_batches(brats_t1_dataloader, plot_n_batches, cmap='Greys_r', vmax=4, vmin=-3.5, figsize=(12, 12))
#plot_camcan_batches(camcan_train_dataloader, plot_n_batches)
#plot_camcan_batches(camcan_val_dataloader, plot_n_batches)

In [None]:
# Peek at the first batch only; nothing is printed if the dataloader yields no batch.
for item in islice(brats_val_dataloader, 1):
    print(item)

# Analyze normal / abnormal pixel distribution

In [None]:
from uncertify.evaluation.datasets import get_samples_without_lesions
from uncertify.visualization.datasets import plot_fraction_of_abnormal_pixels
from uncertify.visualization.datasets import plot_abnormal_pixel_distribution
from uncertify.visualization.datasets import boxplot_abnormal_pixel_fraction

In [None]:
# Make sure the output directory exists; savefig does not create missing directories.
plots_dir = DATA_DIR_PATH / 'plots'
plots_dir.mkdir(parents=True, exist_ok=True)

fig, _ = plot_abnormal_pixel_distribution(brats_val_dataloader, figsize=(12, 5),
                                          hist_kwargs=dict(bins=30, density=True))
fig.savefig(plots_dir / 'normal_abnormal_n_pixel_distribution.png')

In [None]:
# Make sure the output directory exists; savefig does not create missing directories.
plots_dir = DATA_DIR_PATH / 'plots'
plots_dir.mkdir(parents=True, exist_ok=True)

fig, _ = plot_fraction_of_abnormal_pixels(brats_val_dataloader, figsize=(12, 5),
                                          hist_kwargs=dict(bins=80, density=True))
fig.savefig(plots_dir / 'abnormal_pixel_fraction.png')

In [None]:
# Make sure the output directory exists; savefig does not create missing directories.
plots_dir = DATA_DIR_PATH / 'plots'
plots_dir.mkdir(parents=True, exist_ok=True)

fig, ax = boxplot_abnormal_pixel_fraction(data_loader=brats_val_dataloader, figsize=(2.5, 5))
fig.savefig(plots_dir / 'boxplot_abnormal_pixel_fraction.png')

In [None]:
# Name the threshold once so that the call and the printed message cannot drift apart.
pixel_ratio_threshold = 0.01
n_samples_without_lesions, n_higher_ratio_threshold, n_samples_total = get_samples_without_lesions(
    brats_val_dataloader, pixel_ratio_threshold=pixel_ratio_threshold)
# NOTE(review): the wording below is derived from the names of the returned values
# (count without lesions, count above the ratio threshold, total count) — confirm
# against get_samples_without_lesions.
print(f'{n_samples_without_lesions} / {n_samples_total} samples have no lesional pixels. '
      f'{n_higher_ratio_threshold} / {n_samples_total} samples have a lesional pixel ratio above '
      f'{pixel_ratio_threshold} within the brain mask.')