# Creation of train/val/test datasets used in `tutorial-segmentation.ipynb`

# 0. Imports and function definitions


In [48]:
import numpy as np
import matplotlib.pyplot as plt
import imageio
import os
import re
import glob
from tqdm import tqdm_notebook as tqdm
 
%matplotlib inline

def get_section_no(name):
    """Extract the section number from a file name (or a list of file names).

    The section number is the last group of four consecutive digits found
    in the string.

    Args:
        name: str, or list of str (lists are processed element-wise, recursively)
    Returns:
        The four-digit section number as str, or a list of them if a list was given.
    """
    if not isinstance(name, list):
        four_digit_groups = re.findall(r'\d\d\d\d', name)
        # the last match is used, so digits in parent folders do not interfere
        return four_digit_groups[-1]
    return [get_section_no(entry) for entry in name]

In [49]:
# randomly crop patches from train and val sections
def crop_patches(imgs, masks, coords, patch_size=(256,256), label_map=None):
    """Crop patches of patch_size from imgs and masks at the given coordinates.

    No bounds checking is done here; the caller has to pass coordinates for
    which the entire patch lies inside the image.

    Args:
        imgs, masks: np.array of shape (num_imgs, h, w) containing images and masks to be cropped
        coords: np.array of shape (num, 3), each row is (image index, row, column)
            of the upper left corner of a patch
        patch_size: (height, width) of the resulting patches
        label_map: dictionary mapping values in masks to label values. Mask values
            that are not a key of label_map get label 0. If None, the mask values
            are used as labels unchanged.
    Returns:
        patches, labels: two lists with one 2D array per coordinate
    """
    patch_h, patch_w = patch_size[0], patch_size[1]
    patches, labels = [], []
    for coord in tqdm(coords):
        img_idx = coord[0]
        rows = slice(coord[1], coord[1] + patch_h)
        cols = slice(coord[2], coord[2] + patch_w)
        patches.append(imgs[img_idx, rows, cols])

        mask_patch = masks[img_idx, rows, cols]
        if label_map is None:
            # no remapping requested: use the raw mask values
            labels.append(mask_patch)
            continue
        # translate mask values to labels; everything unmapped stays 0
        label = np.zeros(patch_size, dtype=np.uint8)
        for mask_value, label_value in label_map.items():
            label[mask_patch == mask_value] = label_value
        labels.append(label)
    return patches, labels

def sample_coordinates(imgs, num, patch_size=(256,256), p=None):
    """
    Draw num patch coordinates from imgs.
    Every coordinate is the upper left corner of a patch of size patch_size;
    the sampled range is restricted such that the entire patch fits into imgs.
    Args:
        imgs: np.array of shape (num_imgs, h, w)
        num: int, number of coordinates that should be sampled
        patch_size: (height, width) of the patches that will be cropped
        p: array of shape (num_imgs, h, w), default is None. Sampling weight of each
            pixel, interpreted as the weight of the patch centered on that pixel.
            If None, coordinates are drawn uniformly.
    Returns:
        coords: np.array of shape (num, 3), sampled (image index, row, column) coordinates
    """
    patch_h, patch_w = patch_size[0], patch_size[1]
    # number of admissible upper left corners along each axis
    valid_shape = [imgs.shape[0], imgs.shape[1] - patch_h, imgs.shape[2] - patch_w]

    if p is None:
        # uniform sampling: draw flat indices and convert them to 3D coordinates
        num_positions = valid_shape[0] * valid_shape[1] * valid_shape[2]
        flat_indices = np.random.randint(low=0, high=num_positions, size=num)
        return np.array(np.unravel_index(flat_indices, valid_shape)).T  # shape (num, 3)

    # weighted sampling: strip half a patch from every border of p, so that an index
    # into the cropped array is the upper left corner of the patch centered on that pixel
    half_h, half_w = patch_h // 2, patch_w // 2
    center_weights = p[:, half_h:p.shape[1] - half_h, half_w:p.shape[2] - half_w]
    return weighted_random_choice(num=num, weights=center_weights)  # uses np.random.choice

def weighted_random_choice(num, weights):
    """
    Sample num coordinates, respecting the given sampling weights.

    Sampling is done in two stages to keep the probability vectors small:
    first the image index (axis 0) of every sample is drawn according to the
    summed weight of each image, then the pixel positions are drawn inside
    each selected image according to that image's weights.

    Args:
        num: int, number of coordinates to sample
        weights: np.array of shape (num_imgs, h, w) with non-negative sampling weights
    Returns:
        np.array of shape (num, 3) with (image index, row, column) coordinates in random order
    """
    num_imgs = weights.shape[0]
    num_pixels = weights.shape[1] * weights.shape[2]

    # First, choose how many values to sample from each image (axis 0).
    # Passing an int to np.random.choice samples from np.arange(n) without first
    # converting a (potentially huge) Python range into an array.
    img_probs = np.sum(weights, axis=(1,2)) / np.sum(weights)
    img_choice = np.random.choice(num_imgs, size=num, p=img_probs)
    inds, counts = np.unique(img_choice, return_counts=True)

    coords = []
    # zip() has no length, so the total is given explicitly for the progress bar
    for i, count in tqdm(zip(inds, counts), total=len(inds)):
        # choose count pixel positions from weights[i]
        pixel_probs = weights[i].ravel() / np.sum(weights[i])
        flat = np.random.choice(num_pixels, size=count, p=pixel_probs)
        for c in np.array(np.unravel_index(flat, weights.shape[1:])).T:
            coords.append([i, c[0], c[1]])

    # samples were generated grouped by image; shuffle to remove that ordering
    np.random.shuffle(coords)
    if not coords:
        # keep the documented (num, 3) layout even if nothing was sampled
        return np.empty((0, 3), dtype=np.intp)
    return np.array(coords)
    
def get_sampling_weights(masks, label_vals):
    """
    Return matrix with sampling probability for each element in masks,
    such that each label in masks will be sampled with equal probability.

    Every pixel of a label gets weight 1/(number of pixels with that label), so
    the total weight per label is identical. Elements whose value is not in
    label_vals get weight 0. Labels from label_vals that do not occur in masks
    are skipped.

    Args:
        masks: np.array containing labels
        label_vals: iterable of label values that should be considered in masks
    Returns:
        np.array of floats with the same shape as masks, summing to 1
    Raises:
        ValueError: if none of label_vals occurs in masks (no valid weights exist)
    """
    # weight matrix, weighing all labels equally (to ensure equal sampling)
    weights = np.zeros(masks.shape)
    for l in label_vals:
        binary_mask = masks == l
        count = np.count_nonzero(binary_mask)
        if count == 0:
            # label not present: nothing to weigh, and avoids a division by zero
            continue
        weights[binary_mask] = 1. / count
    total = np.sum(weights)
    if total == 0:
        raise ValueError("None of the given label values occurs in masks")
    return weights / total

def calculate_patch_dataset(sections, sections_fname, masks_fname, num_patches, label_map, patch_size,
                            balanced_sampling=True):
    """
    Load sections with their masks and sample a dataset of image/label patches.

    Args:
        sections: list of section identifiers, inserted into the file name templates
        sections_fname: str template with one `{}` placeholder, path of the section images
        masks_fname: str template with one `{}` placeholder, path of the mask images
        num_patches: int, number of patches to sample
        label_map: dictionary mapping values in the masks to label values
        patch_size: (height, width) of the patches
        balanced_sampling: if True (default), patch centers are sampled such that every
            mask value in label_map is equally likely. Computing these weights allocates
            a float array of the size of all masks; set to False to sample patch
            positions uniformly with far less memory.
    Returns:
        X, Y: np.arrays with image patches and label patches, with a trailing axis
            of size 1 added at position 3 (shape (num_patches, h, w, 1) assuming
            single-channel images)
    """
    # load images
    print("Loading images")
    masks = np.array([imageio.imread(masks_fname.format(sec)) for sec in sections])
    imgs = np.array([imageio.imread(sections_fname.format(sec)) for sec in sections])
    if balanced_sampling:
        # ensure equal sampling of all classes
        print("Calculating sampling_weights")
        sampling_weights = get_sampling_weights(masks, label_vals=label_map.keys())
    else:
        sampling_weights = None
    # randomly sampled coordinates
    print("Sampling coordinates")
    coords = sample_coordinates(imgs, num_patches, patch_size=patch_size, p=sampling_weights)
    # crop patches
    print("Cropping patches from coordinates")
    patches, labels = crop_patches(imgs, masks, coords, patch_size=patch_size, label_map=label_map)

    # prepare for use as training dataset
    X = np.expand_dims(np.array(patches), 3)
    Y = np.expand_dims(np.array(labels), 3)
    return X, Y

# 1. Download data

In [50]:
# Root directory for all raw downloads (archives and extracted files)
data_dir = 'data/raw'

# download V1 masks
masks_v1_url = 'https://object.cscs.ch/v1/AUTH_227176556f3c4bb38df9feea4b91200c/hbp-d002272_BigBrainCytoMapping-v1-v2_pub/ReferenceDelineations/v1/2019_05_01_v1.zip'
masks_v1_archive = os.path.join(data_dir, 'masks_v1.zip')
masks_v1_dir = os.path.join(data_dir, 'masks_v1')
!mkdir -p {data_dir}
# -q: quiet, -nc: no-clobber (an already existing archive is not downloaded again)
!wget -q -nc {masks_v1_url} -O {masks_v1_archive}
# -q: quiet, -o: overwrite existing files without prompting, -d: extraction target
!unzip -qo {masks_v1_archive} -d {masks_v1_dir}
# presumably the archive unpacks into a `v1/` subfolder: move its contents one
# level up and remove the then-empty folder
!mv {masks_v1_dir}/v1/* {masks_v1_dir}/
!rmdir {masks_v1_dir}/v1
print('V1 masks downloaded to', masks_v1_dir)

V1 masks downloaded to data/raw/masks_v1


In [51]:
# download BigBrain sections for every mask
# `{}` in the URL and in the file name template is filled with the section number
sections_url = 'ftp://bigbrain.loris.ca/BigBrainRelease.2015/2D_Final_Sections/Coronal/Png/Full_Resolution/pm{}o.png'
sections_dir = os.path.join(data_dir, 'sections')
sections_fname = os.path.join(sections_dir, 'B20_{}.png')
!mkdir -p {sections_dir}
num = 0
# one section is requested per entry found in masks_v1_dir
for f in glob.glob(os.path.join(masks_v1_dir, '*')):
    # section number = last group of four digits in the mask path
    section_no = get_section_no(f)
    # -q: quiet, -nc: no-clobber (existing files are not downloaded again)
    !wget -q -nc {sections_url.format(section_no)} -O {sections_fname.format(section_no)}
    # counts processed mask files; the result of wget is not checked
    num += 1
print(num, 'BigBrain sections downloaded to', sections_dir)

39 BigBrain sections downloaded to data/raw/sections


In [52]:
# download gm/wm segmentations (sliced from segmented volume)
masks_gmwm_url = 'https://fz-juelich.sciebo.de/s/YwSi8flpJVtVsKD/download'
masks_gmwm_archive = os.path.join(data_dir, 'masks_gmwm.zip')
masks_gmwm_dir = os.path.join(data_dir, 'masks_gmwm')
!mkdir -p {data_dir}
# -q: quiet, -nc: no-clobber (an already existing archive is not downloaded again)
!wget -q -nc {masks_gmwm_url} -O {masks_gmwm_archive}
# -q: quiet, -o: overwrite existing files without prompting, -d: extraction target
!unzip -qo {masks_gmwm_archive} -d {masks_gmwm_dir} 
# presumably the archive unpacks into a `masks_gmwm/` subfolder: move its contents
# one level up and remove the then-empty folder
!mv {masks_gmwm_dir}/masks_gmwm/* {masks_gmwm_dir}/
!rmdir {masks_gmwm_dir}/masks_gmwm
print('Gray/white matter masks downloaded to', masks_gmwm_dir)

Gray/white matter masks downloaded to data/raw/masks_gmwm


In [53]:
# look at structure of data directory
!ls -R {data_dir}

# File name templates for sections and masks; the `{}` placeholder is filled
# with the section number via str.format (or with '*' to build a glob pattern)
sections_fname = os.path.join(data_dir, 'sections/B20_{}.png')
masks_v1_fname = os.path.join(data_dir, 'masks_v1/B20_{}_v1.png')
masks_gmwm_fname = os.path.join(data_dir, 'masks_gmwm/B20_{}_gmwm.png')

data/raw:
masks_gmwm  masks_gmwm.zip  masks_v1  masks_v1.zip  sections

data/raw/masks_gmwm:
B20_0061_gmwm.png  B20_0661_gmwm.png  B20_1261_gmwm.png  B20_1861_gmwm.png
B20_0121_gmwm.png  B20_0721_gmwm.png  B20_1321_gmwm.png  B20_1921_gmwm.png
B20_0181_gmwm.png  B20_0781_gmwm.png  B20_1381_gmwm.png  B20_1980_gmwm.png
B20_0241_gmwm.png  B20_0841_gmwm.png  B20_1441_gmwm.png  B20_2041_gmwm.png
B20_0301_gmwm.png  B20_0901_gmwm.png  B20_1501_gmwm.png  B20_2101_gmwm.png
B20_0361_gmwm.png  B20_0961_gmwm.png  B20_1561_gmwm.png  B20_2161_gmwm.png
B20_0421_gmwm.png  B20_1021_gmwm.png  B20_1621_gmwm.png  B20_2221_gmwm.png
B20_0481_gmwm.png  B20_1081_gmwm.png  B20_1681_gmwm.png  B20_2281_gmwm.png
B20_0541_gmwm.png  B20_1141_gmwm.png  B20_1741_gmwm.png  B20_2341_gmwm.png
B20_0601_gmwm.png  B20_1201_gmwm.png  B20_1801_gmwm.png

data/raw/masks_v1:
B20_0061_v1.png  B20_0661_v1.png  B20_1261_v1.png  B20_1861_v1.png
B20_0121_v1.png  B20_0721_v1.png  B20_1321_v1.png  B20_1921_v1.png
B20_

# 2. Sample train/val patches from sections

In [54]:
# Split the available sections into train, val and test sections.
# Sections are sorted by file name and assigned by their position within each
# group of six: positions 0, 1, 3, 4 -> train, position 2 -> val, position 5 -> test
imgs = sorted(glob.glob(sections_fname.format('*')))
train_imgs = [path for pos, path in enumerate(imgs) if pos % 6 in (0, 1, 3, 4)]
val_imgs = [path for pos, path in enumerate(imgs) if pos % 6 == 2]
test_imgs = [path for pos, path in enumerate(imgs) if pos % 6 == 5]
print('Number of train/val/test images: {}/{}/{}'.format(len(train_imgs),len(val_imgs),len(test_imgs)))

# reduce the file names to their section numbers
train_sections, val_sections, test_sections = (
    get_section_no(paths) for paths in (train_imgs, val_imgs, test_imgs)
)
print('Train sections', train_sections)
print('Val sections', val_sections)
print('Test sections', test_sections)

Number of train/val/test images: 26/7/6
Train sections ['0061', '0121', '0241', '0301', '0421', '0481', '0601', '0661', '0781', '0841', '0961', '1021', '1141', '1201', '1321', '1381', '1501', '1561', '1681', '1741', '1861', '1921', '2041', '2101', '2221', '2281']
Val sections ['0181', '0541', '0901', '1261', '1621', '1980', '2341']
Test sections ['0361', '0721', '1081', '1441', '1801', '2161']


In [42]:
# Save test sections in extra folder
test_dir = 'data/test'
!mkdir -p {test_dir}
# for every test section, copy the section image and both masks (gm/wm and v1),
# keeping their original file names
for s in test_sections:
    !cp {sections_fname.format(s)} {os.path.join(test_dir, os.path.basename(sections_fname.format(s)))}
    !cp {masks_gmwm_fname.format(s)} {os.path.join(test_dir, os.path.basename(masks_gmwm_fname.format(s)))}
    !cp {masks_v1_fname.format(s)} {os.path.join(test_dir, os.path.basename(masks_v1_fname.format(s)))}
# bundle the test folder into a single (uncompressed) tar archive
!tar -cf data/test.tar {test_dir}

In [44]:
# -- caution, calculating weights and datasets needs around 20GB of RAM
# -- if this step runs out of memory, set weights = None (random sampling of pixels), or use precomputed patches
np.random.seed(42)  # fixed seed, so that the sampled datasets are reproducible

# settings shared by the gray/white matter training and validation datasets
gmwm_kwargs = dict(
    sections_fname=sections_fname,
    masks_fname=masks_gmwm_fname,
    label_map={0:0, 128:1, 255:2},
    patch_size=(268,268),
)

# training dataset for gray/white matter segmentation
num_train = 500
X_train, Y_train = calculate_patch_dataset(sections=train_sections, num_patches=num_train, **gmwm_kwargs)
np.savez("data/train_gmwm.npz", X=X_train, Y=Y_train)

# validation dataset
num_val = 20
X_val, Y_val = calculate_patch_dataset(sections=val_sections, num_patches=num_val, **gmwm_kwargs)
np.savez("data/val_gmwm.npz", X=X_val, Y=Y_val)

Loading images
Calculating sampling_weights
Sampling coordinates


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Cropping patches from coordinates


HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

Loading images
Calculating sampling_weights
Sampling coordinates


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Cropping patches from coordinates


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

In [55]:
# settings shared by the v1 training and validation datasets
v1_kwargs = dict(
    sections_fname=sections_fname,
    masks_fname=masks_v1_fname,
    label_map={0:0, 255:1},
    patch_size=(268,268),
)

# training dataset for v1 segmentation
num_train = 500
X_train, Y_train = calculate_patch_dataset(sections=train_sections, num_patches=num_train, **v1_kwargs)
np.savez("data/train_v1.npz", X=X_train, Y=Y_train)

# validation dataset
num_val = 20
X_val, Y_val = calculate_patch_dataset(sections=val_sections, num_patches=num_val, **v1_kwargs)
np.savez("data/val_v1.npz", X=X_val, Y=Y_val)

Loading images
Calculating sampling_weights
Sampling coordinates


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Cropping patches from coordinates


HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

Loading images
Calculating sampling_weights
Sampling coordinates


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Cropping patches from coordinates


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

In [47]:
# sample training and validation datasets for v1 segmentation with 4 classes: bg, gm, wm, v1
label_map = {0:0, 128:1, 255:2, 64:3}
patch_size = (268,268)

# The same pipeline is run once per split; only the sections, the number of
# patches and the output file differ. Training is sampled first, then validation.
v1gmwm_datasets = {}
for split, sections, num_patches, out_fname in (
        ("train", train_sections, 500, "data/train_v1gmwm.npz"),
        ("val", val_sections, 20, "data/val_v1gmwm.npz")):
    # load images
    print("Loading images")
    masks = np.array([imageio.imread(masks_gmwm_fname.format(sec)) for sec in sections])
    masks_v1 = np.array([imageio.imread(masks_v1_fname.format(sec)) for sec in sections])
    # merge both masks: pixels marked (255) in the v1 mask get the value 64 in
    # the gm/wm mask, which label_map translates to the fourth class
    masks[masks_v1==255] = 64
    imgs = np.array([imageio.imread(sections_fname.format(sec)) for sec in sections])
    # ensure equal sampling of all classes
    print("Calculating sampling_weights")
    sampling_weights = get_sampling_weights(masks, label_vals=label_map.keys())
    # randomly sampled coordinates
    print("Sampling coordinates")
    coords = sample_coordinates(imgs, num_patches, patch_size=patch_size, p=sampling_weights)
    # crop patches
    print("Cropping patches from coordinates")
    patches, labels = crop_patches(imgs, masks, coords, patch_size=patch_size, label_map=label_map)

    # prepare for use as dataset: add a trailing axis of size 1
    X = np.expand_dims(np.array(patches), 3)
    Y = np.expand_dims(np.array(labels), 3)

    np.savez(out_fname, X=X, Y=Y)
    v1gmwm_datasets[split] = (X, Y)

# expose the results under the same names as the other dataset cells
X_train, Y_train = v1gmwm_datasets["train"]
X_val, Y_val = v1gmwm_datasets["val"]

Loading images
Calculating sampling_weights
Sampling coordinates


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Cropping patches from coordinates


HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

Loading images
Calculating sampling_weights
Sampling coordinates


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Cropping patches from coordinates


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))