In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os, re
from matplotlib import colors
from tqdm.notebook import tqdm
from IPython.display import clear_output
import SimpleITK as sitk

In [2]:
"""
Data load info. loaders:
- input: abs. path to dir. with data
- output: a dictionary with
    - key: img #
    - val: absolute dir. to each img
"""

# load data
def load_data_to_dictionaries(path):
    """
    output dictionaries that record the input file paths
    """
    data = {}
    for dirname, _, filenames in os.walk(path):
        for filename in filenames:
            if filename.split('.')[1] == 'csv':
                continue
            name_split = filename.split('_')
            idx = int(name_split[2])
            file_type = name_split[3].split('.')[0]
            if file_type not in data.keys():
                data[file_type] = {idx: os.path.join(dirname, filename)}
            else:
                data[file_type].update({idx: os.path.join(dirname, filename)})
    return data

In [3]:
train_data = load_data_to_dictionaries('/scratch/ec2684/cv/data/brats20/nii/train')
val_data = load_data_to_dictionaries('/scratch/ec2684/cv/data/brats20/nii/val')

In [4]:
train_data.keys()

dict_keys(['seg', 't1', 't2', 't1ce', 'flair'])

In [9]:
# check the data stored in 'seg', 't1', 't2', 't1ce', 'flair' have no missing data
# (check there exists 369 scans for each scan types)
def check_data(data):    
    all_match = True

    for key1 in tqdm(data.keys()):
        for key2 in tqdm(data.keys()):
            all_match = all_match and \
            (sorted(data[key1].keys()) == sorted(data[key2].keys()))

    print(all_match)

check_data(train_data)
check_data(val_data)

In [10]:
# check the data stored in 'seg', 't1', 't2', 't1ce', 'flair' have (1,1,1) spacing

def check_unit_spacing(data):    
    all_match = True

    for key1 in tqdm(data.keys()):
        for key2 in tqdm(data[key1].keys()):
            vol = sitk.ReadImage(data[key1][key2], sitk.sitkInt16)
            all_match = all_match and \
            (vol.GetSpacing() == (1.,1.,1.))

    print(all_match)
    
check_unit_spacing(train_data)
check_unit_spacing(val_data)

In [13]:
# check if data stored in 'seg', 't1', 't2', 't1ce', 'flair' all have (155,240,240) shape

uniform_shape = True

for key1 in tqdm(train_data.keys()):
    for key2 in tqdm(train_data[key1].keys()):
        data = sitk.ReadImage(train_data[key1][key2], sitk.sitkInt16)
        data_array = sitk.GetArrayFromImage(data)
        uniform_shape = uniform_shape and (data_array.shape == (155,240,240))

print(uniform_shape)

In [None]:
'''
##### Imaging Data Description: #####
All BraTS multimodal scans are:
- in NIfTI format (.nii.gz)
- describe:
    a) native (T1)
    b) post-contrast T1-weighted (T1Gd)
    c) T2-weighted (T2)
    d) T2 Fluid Attenuated Inversion Recovery (T2-FLAIR) volumes
- acquired with:
    - different clinical protocols
    - various scanners from multiple (n=19) institutions

- all the imaging datasets have been segmented manually, 
by one to four raters, following the same annotation protocol.
- their annotations were approved by experienced neuro-radiologists. 
- Annotations comprise:
    (NCR/NET — label 1) the necrotic and non-enhancing tumor core
    (ED — label 2) the peritumoral edema 
    (ET — label 4) the GD-enhancing tumor
    
- The provided data are distributed after their pre-processing, 
i.e., co-registered to the same anatomical template, interpolated 
to the same resolution (1 mm^3) and skull-stripped.
''';

In [14]:
# check min max voxel values for 't1', 't2', 't1ce', 'flair'
# where 'seg' values are greater than 1

def voxel_range(data):    
    max_voxel = -1E5
    min_voxel = 1E5
    ranges = {}
    for key in data.keys():
        if key == 'seg':
            continue
        ranges[key] = [min_voxel,max_voxel]
    for key1 in tqdm(data.keys()):
        if key1 == 'seg':
            continue
        for key2 in tqdm(data[key1].keys()):
            seg = sitk.ReadImage(data['seg'][key2], sitk.sitkInt8)
            seg_array = sitk.GetArrayFromImage(seg)
            vol = sitk.ReadImage(data[key1][key2], sitk.sitkInt16)
            vol_array = sitk.GetArrayFromImage(vol)
            vol_array = vol_array[seg_array > 0]
            ranges[key1][0] = min(ranges[key1][0],np.min(vol_array))
            ranges[key1][1] = max(ranges[key1][1],np.max(vol_array))
    return ranges

In [16]:
train_ranges = voxel_range(train_data)

In [None]:
train_ranges

{'t1': [0, 21113], 't2': [0, 31404], 't1ce': [0, 18011], 'flair': [0, 29422]}

In [17]:
# count voxel values for 't1', 't2', 't1ce', 'flair'
# where 'seg' values are greater than 1

def voxel_cnt(data):
    ranges = {}
    for key in data.keys():
        if key == 'seg':
            continue
        ranges[key] = np.zeros(31405)
    for key1 in tqdm(data.keys()):
        if key1 == 'seg':
            continue
        for key2 in tqdm(data[key1].keys()):
            seg = sitk.ReadImage(data['seg'][key2], sitk.sitkInt8)
            seg_array = sitk.GetArrayFromImage(seg)
            vol = sitk.ReadImage(data[key1][key2], sitk.sitkInt16)
            vol_array = sitk.GetArrayFromImage(vol)
            vol_array = vol_array[seg_array > 0]
            unique, counts = np.unique(vol_array, return_counts=True)
            for idx, num in enumerate(unique):
                ranges[key1][num] = ranges[key1][num] + counts[idx]
    return ranges

In [19]:
train_cnt = voxel_cnt(train_data)

In [None]:
# get the 95% voxel values range

for key in train_ranges.keys():
    N = np.sum(train_cnt[key])
    a = 0.95
    max_range = a*N
    print(key)
    print(f'mean: {np.dot(np.arange(len(train_cnt[key])),train_cnt[key])/N}')
    print(f'95% voxel: {np.where(np.cumsum(train_cnt[key]) < max_range)[0][-1]}')
    print()

t1
mean: 694.837275449542
95% voxel: 2197

t2
mean: 1181.4516712400637
95% voxel: 2858

t1ce
mean: 830.4526936526679
95% voxel: 2969

flair
mean: 815.7834883957967
95% voxel: 1333



In [None]:
# decided to clip each scans by the 95% voxel values above before normalizing

In [20]:
!pwd

/scratch/ec2684/cv/brats20
