In [1]:
import re
import os
import glob
from pathlib import Path
from typing import Tuple, List

import numpy as np
import tensorflow as tf
import SimpleITK as sitk
from scipy.ndimage import zoom

In [2]:
# Root of the training split; patient folders are found under its *GG sub-directories.
dataset_path = Path("data") / "train"

In [8]:
# One dataset entry per patient directory matching data/train/*GG/* (e.g. HGG/...).
# list_files shuffles its output by default; pinning the seed makes the file
# order, and therefore every downstream output, reproducible after a kernel restart.
LIST_FILES_SEED = 42
list_ds = tf.data.Dataset.list_files(str(dataset_path / '*GG' / '*'), shuffle=True, seed=LIST_FILES_SEED)

In [9]:
# Peek at the first two patient directories the dataset yields (as bytes).
for patient_dir in list_ds.take(2):
    print(patient_dir.numpy())

b'data/train/HGG/Brats18_TCIA02_290_1'
b'data/train/HGG/Brats18_CBICA_APR_1'


In [5]:
def resize(img: np.ndarray, shape: Tuple[int, int, int], mode: str = 'constant', order: int = 3) -> np.ndarray:
    """
    Wrapper for scipy.ndimage.zoom suited for MRI images.

    Computes one zoom factor per axis (target size / current size), so the
    output has the requested shape after scipy rounds ``input_size * factor``.

    Args:
        img (np.ndarray): The image to resize. Must have one axis per entry of ``shape``.
        shape (Tuple[int, int, int]): The shape to resize to.
        mode (str, optional): How ``zoom`` fills values beyond the input boundary
            (e.g. 'constant', 'nearest', 'reflect'). This is NOT the interpolation
            method. Defaults to 'constant'.
        order (int, optional): Spline interpolation order passed to ``zoom``
            (0 = nearest-neighbour, suitable for label masks; 3 = cubic).
            Defaults to 3, which is scipy's own default.

    Returns:
        np.ndarray: The resized image.

    Raises:
        ValueError: If ``shape`` does not have exactly one entry per axis of ``img``.
    """
    if len(shape) != img.ndim:
        raise ValueError(
            f"Output shape {tuple(shape)} must have {img.ndim} dimensions to match the input image"
        )

    # One factor per axis: requested size divided by current size.
    factors = tuple(target / current for target, current in zip(shape, img.shape))

    return zoom(img, factors, mode=mode, order=order)

def load_nii(paths: List[str]) -> np.ndarray:
    """
    Reads the .nii files specified by their paths and stacks their voxel data.

    Args:
        paths (List[str]): Paths to the files to read. Both ``str`` and ``bytes``
            paths are accepted; ``bytes`` paths arise when this runs inside
            ``tf.numpy_function``.

    Returns:
        np.ndarray: A float32 array of shape ``(len(paths), *volume_shape)``, where
        ``volume_shape`` is the shape of the first file's voxel array.

    Raises:
        ValueError: If ``paths`` is empty.
    """
    if len(paths) == 0:
        raise ValueError("load_nii() needs at least one path to read")

    data: np.ndarray = None

    for i, path in enumerate(paths):
        # SimpleITK wants a str path; os.fsdecode accepts both str and bytes.
        volume = sitk.GetArrayFromImage(sitk.ReadImage(os.fsdecode(path)))

        # Allocate the output lazily, once the per-file shape is known.
        if data is None:
            data = np.zeros((len(paths), *volume.shape), dtype=np.float32)

        # Later files must have a shape compatible with the first one.
        data[i] = volume

    return data

def get_files(parent_dir: bytes) -> Tuple[List[str], List[str]]:
    """
    Splits the contents of a directory into data files and segmentation (label) files.

    Args:
        parent_dir (bytes): The directory to list. ``str`` is also accepted; the
            returned paths have the same type as ``parent_dir``.

    Returns:
        Tuple[List[str], List[str]]: The paths of the data files (entries whose name
        does not contain "seg") and the paths of the segmentation mask files
        (entries whose name contains "seg"). Both lists are sorted by file name.
    """
    # os.listdir returns entries in arbitrary, filesystem-dependent order. The data
    # files become the channels of the input volume, so sort them to give every
    # directory the same channel order.
    files = sorted(os.listdir(parent_dir))

    # The substring test must use the same type (bytes vs str) as the listed names.
    marker = b"seg" if isinstance(parent_dir, bytes) else "seg"

    data_files = [os.path.join(parent_dir, name) for name in files if marker not in name]
    seg_file = [os.path.join(parent_dir, name) for name in files if marker in name]

    return data_files, seg_file


def preprocess_data(images: np.ndarray, out_shape: Tuple[int, int, int] = None) -> np.ndarray:
    """
    Optionally resizes each image, then z-score normalizes it.

    Args:
        images (np.ndarray): The images to preprocess, indexed along the first axis.
        out_shape (Tuple[int, int, int], optional): The shape to resize every image to,
            if resizing is required. Defaults to None (images keep their own shape).

    Returns:
        np.ndarray: A float32 array of shape ``(len(images), *image_shape)``. Each image
        is shifted to zero mean and scaled to unit standard deviation; images with
        zero standard deviation are only mean-centred, to avoid NaNs.
    """
    if out_shape is not None:
        target_shape = tuple(out_shape)
    elif len(images) > 0:
        # No resizing requested: keep the images' own shape.
        target_shape = tuple(np.shape(images[0]))
    else:
        target_shape = ()

    # creating the array to store the resultant images
    out_imgs = np.zeros((len(images), *target_shape), dtype=np.float32)

    for i, img in enumerate(images):
        # resize only when an output shape was requested
        _img = img if out_shape is None else resize(img, out_shape)

        # normalizing the image
        mean = _img.mean()
        std = _img.std()
        centred = _img - mean
        # A constant image (std == 0) would otherwise become all NaNs.
        out_imgs[i] = centred / std if std > 0 else centred

    return out_imgs

def preprocess_label(seg_mask: np.ndarray, out_shape: Tuple[int, int, int] = None, mode: str = 'nearest') -> np.ndarray:
    """
    Separates out the 3 labels from the segmentation provided, namely:
    the necrotic and non-enhancing tumor core (NCR/NET — label 1),
    the peritumoral edema (ED — label 2) and the GD-enhancing tumor (ET — label 4).

    Args:
        seg_mask (np.ndarray): The numpy array containing the data of the label.
        out_shape (Tuple[int, int, int], optional): The shape to resize to. Defaults to None.
        mode (str, optional): Boundary mode passed to ``scipy.ndimage.zoom`` when
            resizing (how values beyond the array edge are filled). Interpolation is
            always nearest-neighbour (``order=0``) so the masks stay binary.
            Defaults to 'nearest'.

    Returns:
        np.ndarray: A uint8 array of shape ``(3, *S)`` holding binary masks in the order
        NCR/NET, ED, ET, where ``S`` is ``out_shape`` if given, else ``seg_mask.shape``.
    """
    # One binary mask per label, as uint8 so zoom works on a numeric dtype.
    masks = [(seg_mask == label).astype(np.uint8) for label in (1, 2, 4)]

    if out_shape is not None:
        # zoom's `mode` only controls boundary handling; interpolation is chosen by
        # `order`. order=0 (nearest-neighbour) keeps masks binary, whereas zoom's
        # default cubic spline would interpolate between 0 and 1 and distort the
        # region boundaries once cast back to integers.
        factors = tuple(target / current for target, current in zip(out_shape, seg_mask.shape))
        masks = [zoom(mask, factors, order=0, mode=mode) for mask in masks]

    return np.array(masks, dtype=np.uint8)

def process_paths(parent_dir: str, out_shape: Tuple[int, int, int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Loads the images from the parent directory, processes them and returns the processed images.

    Intended to be called through ``tf.numpy_function``, in which case ``parent_dir``
    arrives as ``bytes`` and ``out_shape`` as a numpy array.

    Args:
        parent_dir (str): The path to the parent directory (``str`` or ``bytes``).
        out_shape (Tuple[int, int, int], optional): The shape to resize to. Defaults to None.

    Returns:
        Tuple[np.ndarray, np.ndarray]: First, the normalized float32 data (one channel
        per non-segmentation file); second, the uint8 label masks.

    Raises:
        ValueError: If the directory does not contain exactly one segmentation file.
    """
    data_files, seg_files = get_files(parent_dir)

    # Without this check a missing label surfaces as an opaque error on `None[0]`,
    # and extra "seg" files would be silently ignored.
    if len(seg_files) != 1:
        raise ValueError(
            f"Expected exactly one segmentation file in {parent_dir!r}, found {len(seg_files)}"
        )

    # Both loads produce plain numpy arrays (not tf.Tensors).
    data: np.ndarray = load_nii(data_files)
    seg: np.ndarray = load_nii(seg_files)[0]

    data = preprocess_data(data, out_shape)
    label = preprocess_label(seg, out_shape)

    return data, label

In [6]:
# Spatial shape every volume (data channels and label masks) is resized to.
TARGET_SHAPE = (80, 96, 64)


def _load_example(path):
    """
    Runs `process_paths` on one directory path inside the tf.data graph.

    tf.numpy_function returns tensors of unknown shape, so the static shapes are
    restored here: the channel count of `data` depends on the number of files in
    the directory (left as None); `label` always has 3 masks (see preprocess_label).
    """
    data, label = tf.numpy_function(process_paths, [path, TARGET_SHAPE], [tf.float32, tf.uint8])
    data.set_shape((None, *TARGET_SHAPE))
    label.set_shape((3, *TARGET_SHAPE))
    return data, label


# Loading + resizing a single example takes seconds, so let tf.data run several
# loads in parallel (map keeps the input order by default).
res = list_ds.map(_load_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [7]:
%%timeit -r1 -n1
for x, y in res:
    print(f"x: {x.shape}")
    print(f"y: {y.shape}")  
    break

x: (4, 80, 96, 64)
y: (3, 80, 96, 64)
6.46 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
