In [15]:
import json
import os
import random
import shutil

import nibabel as nib
import numpy as np

from PIL import Image

In [2]:
def ct_to_slices(path):
    """
    Loads the NIfTI (nii.gz) scan stored at `path` and splits it along its
    third (last) axis.

    Returns a list of 2D numpy arrays, one per index of that axis, in order.
    The loaded volume must have exactly three axes; any other rank makes the
    shape unpacking below raise a ValueError.
    """
    volume = nib.load(path).get_fdata()
    _rows, _cols, depth = volume.shape

    slice_list = []
    for index in range(depth):
        # Index the last axis to cut out a single 2D plane
        slice_list.append(volume[..., index])

    return slice_list

In [3]:
def calc_hounsfield(slices):
    """
    Finds the value range covered by a whole scan.

    Takes the collection of slices (numpy arrays) that make up one CT scan and
    returns a (minimum, maximum) tuple holding the smallest and the largest
    value found across all of them.
    """
    per_slice_minima = map(np.min, slices)
    per_slice_maxima = map(np.max, slices)

    return min(per_slice_minima), max(per_slice_maxima)

In [5]:
def normalize_slice(slice, hounsfield_min, hounsfield_max):
    """
    Normalizes data on the Hounsfield scale to a [0, 1] interval.

    Values below `hounsfield_min` / above `hounsfield_max` are clamped to those
    bounds before scaling, so `hounsfield_min` maps to 0 and `hounsfield_max` to 1.

    The input array is left untouched; a new array is returned. If the two
    bounds are equal there is no range to scale by, so an all-zero array of the
    same shape is returned instead of dividing by zero.
    """
    # np.clip allocates a new array, so the caller's data is never modified
    clipped = np.clip(slice, hounsfield_min, hounsfield_max)
    value_range = hounsfield_max - hounsfield_min

    if value_range == 0:
        # Degenerate window (e.g. a constant-valued scan): avoid 0/0 -> NaN
        return np.zeros_like(clipped, dtype=float)

    return (clipped - hounsfield_min) / value_range

In [7]:
def convert_ct_dataset_to_slices(dataset_dir, output_dir, negative_downsampling_rate=1, positive_downsampling_rate=1):
    """
    Converts a Medical Segmentation Decathlon CT scan training dataset to
    the sliced images form, with optional downsampling of negative (no tumors present)
    or positive (with tumors present) slices.

    The slices are represented as single channel grayscale .png images, written to
    `<output_dir>/images` and `<output_dir>/labels` under the same file name
    `<scan number>_<slice number>.png` (both numbers are 1-based).

    A slice counts as positive when its label contains the value 1. A slice of either
    kind is kept with probability `1 / <its downsampling rate>`, so a rate of 1 keeps
    everything and a rate of 4 keeps roughly a quarter.

    Raises:
        ValueError: if a downsampling rate is not positive, or if a scan and its
            label volume do not have the same number of slices.
        FileExistsError: if `output_dir` already exists.
    """
    if negative_downsampling_rate <= 0 or positive_downsampling_rate <= 0:
        raise ValueError('Downsampling rates must be positive numbers.')

    # Loop-invariant probabilities of keeping a slice of each kind
    negative_keep_probability = 1 / negative_downsampling_rate
    positive_keep_probability = 1 / positive_downsampling_rate

    # Read the dataset description before touching the output location, so that a
    # missing or malformed dataset.json does not leave empty directories behind
    with open(os.path.join(dataset_dir, 'dataset.json'), 'r') as dataset_info:
        data = json.load(dataset_info)
    scan_pairs = [
        (os.path.join(dataset_dir, scan['image']), os.path.join(dataset_dir, scan['label']))
        for scan in data['training']
    ]

    images_dir = os.path.join(output_dir, 'images')
    labels_dir = os.path.join(output_dir, 'labels')
    # makedirs also creates missing parent directories; it still raises if
    # output_dir already exists, so an earlier conversion is never overwritten
    os.makedirs(output_dir)
    os.mkdir(images_dir)
    os.mkdir(labels_dir)

    for scan_number, (image_path, label_path) in enumerate(scan_pairs, start=1):
        label_slices = ct_to_slices(label_path)
        image_slices = ct_to_slices(image_path)

        if len(label_slices) != len(image_slices):
            raise ValueError(
                f'Slice count mismatch between {image_path} ({len(image_slices)}) '
                f'and {label_path} ({len(label_slices)}).'
            )

        # The intensity window is computed once per scan, so every slice of the
        # same scan is normalized with the same bounds
        hounsfield_min, hounsfield_max = calc_hounsfield(image_slices)

        for slice_number, (label_slice, image_slice) in enumerate(zip(label_slices, image_slices), start=1):
            # Check if the slice contains any traces of tumors
            has_tumor = (label_slice == 1.0).any()
            keep_probability = positive_keep_probability if has_tumor else negative_keep_probability

            # Exactly one random draw per slice decides whether it is dropped
            if random.random() > keep_probability:
                continue

            label_pixels = label_slice * 255
            image_pixels = normalize_slice(image_slice, hounsfield_min, hounsfield_max) * 255

            file_name = f'{scan_number}_{slice_number}.png'
            Image.fromarray(label_pixels).convert('L').save(os.path.join(labels_dir, file_name))
            Image.fromarray(image_pixels).convert('L').save(os.path.join(images_dir, file_name))

In [18]:
def train_val_test_split(sliced_dataset_dir, train_dataset_dir, val_dataset_dir, test_dataset_dir, val_split=0.2, test_split=0.2):
    """
    Splits a dataset (obtained by a call to convert_ct_dataset_to_slices) into
    a train, validation and test subset.

    Due to the Medical Segmentation Decathlon test dataset labels (ground truths) not being provided,
    we'll have to use part of the set they provided for training as our testing set.

    The split is done per scan, not per slice: the scan id is the part of the file
    name before the first underscore, and all slices sharing an id end up in the same
    subset. `val_split` and `test_split` are fractions of the number of scans (rounded
    down); the remaining scans form the training subset.

    Files are moved (not copied) out of `<sliced_dataset_dir>/images` and
    `<sliced_dataset_dir>/labels` into the `images` / `labels` sub-directories of the
    three target directories, which are created here and must not exist yet.
    """
    images_dir = os.path.join(sliced_dataset_dir, 'images')
    labels_dir = os.path.join(sliced_dataset_dir, 'labels')

    # Everything directly inside images_dir that is not a sub-directory
    image_names = [
        name for name in os.listdir(images_dir)
        if not os.path.isdir(os.path.join(images_dir, name))
    ]
    scan_ids = {name.split('_')[0] for name in image_names}

    num_val = int(len(scan_ids) * val_split)
    num_test = int(len(scan_ids) * test_split)

    # Sorting makes the draw depend only on the set of ids (and the random seed),
    # not on the order in which the file system lists the files
    val_scans = set(random.sample(sorted(scan_ids), num_val))
    test_scans = set(random.sample(sorted(scan_ids.difference(val_scans)), num_test))

    train_images_dir = os.path.join(train_dataset_dir, 'images')
    train_labels_dir = os.path.join(train_dataset_dir, 'labels')

    val_images_dir = os.path.join(val_dataset_dir, 'images')
    val_labels_dir = os.path.join(val_dataset_dir, 'labels')

    test_images_dir = os.path.join(test_dataset_dir, 'images')
    test_labels_dir = os.path.join(test_dataset_dir, 'labels')

    for directory in (
        train_images_dir, train_labels_dir,
        val_images_dir, val_labels_dir,
        test_images_dir, test_labels_dir,
    ):
        os.makedirs(directory)

    for name in image_names:
        scan_id = name.split('_')[0]
        if scan_id in val_scans:
            target_images_dir, target_labels_dir = val_images_dir, val_labels_dir
        elif scan_id in test_scans:
            target_images_dir, target_labels_dir = test_images_dir, test_labels_dir
        else:
            target_images_dir, target_labels_dir = train_images_dir, train_labels_dir

        # shutil.move falls back to copy + delete when the target lives on another
        # file system, where a plain os.rename would fail
        shutil.move(os.path.join(images_dir, name), os.path.join(target_images_dir, name))
        shutil.move(os.path.join(labels_dir, name), os.path.join(target_labels_dir, name))