In [2]:
import json
import os
import random

import nibabel as nib
import numpy as np
import torch

from PIL import Image
from torch.utils.data import Dataset
from torchvision.io import read_image

In [3]:
def ct_to_slices(path):
    """
    Loads the CT scan stored at `path` (NIfTI format, nii.gz) and splits it
    along its last axis.

    Returns a list holding one numpy array per slice, in index order.
    The volume must have exactly three dimensions, otherwise unpacking its
    shape raises a ValueError.
    """
    volume = nib.load(path).get_fdata()
    _, _, depth = volume.shape

    scan_slices = []
    for index in range(depth):
        scan_slices.append(volume[..., index])

    return scan_slices

In [4]:
def calc_hounsfield(slices):
    """
    Returns the smallest and the largest value found across all of the
    passed CT scan slices, as a (minimum, maximum) pair.

    An empty collection of slices raises a ValueError.
    """
    per_slice_minima = [np.min(ct_slice) for ct_slice in slices]
    lowest = min(per_slice_minima)

    per_slice_maxima = [np.max(ct_slice) for ct_slice in slices]
    highest = max(per_slice_maxima)

    return lowest, highest

In [5]:
def normalize_slice(slice, hounsfield_min, hounsfield_max):
    """
    Normalizes data on the Hounsfield scale to a [0, 1] interval.

    Values outside [hounsfield_min, hounsfield_max] are clipped to that window
    before scaling. The passed array is never modified; a new array is returned.

    If the window has zero width (hounsfield_min == hounsfield_max) there is
    nothing to scale, so an all-zero array of the same shape is returned
    instead of dividing by zero.
    """
    # NOTE: the parameter name shadows the builtin `slice`; it is kept because
    # callers may pass it by keyword.
    window_width = hounsfield_max - hounsfield_min

    # np.clip returns a new array, so the caller's data stays intact.
    clipped = np.clip(slice, hounsfield_min, hounsfield_max)

    if window_width == 0:
        return np.zeros_like(clipped, dtype=float)

    return (clipped - hounsfield_min) / window_width

In [6]:
def convert_ct_dataset_to_slices(dataset_dir, train_dir, val_dir, test_dir, val_split=0.2, test_split=0.2, negative_downsampling_rate=1, positive_downsampling_rate=1, seed=None):
    """
    Converts a Medical Segmentation Decathlon CT scan training dataset to
    the sliced images form, with optional downsampling of negative (no tumors present)
    or positive (with tumors present) slices in the training set.

    The slices are represented as single channel grayscale .png images, written to
    `<split_dir>/images` and `<split_dir>/labels` as `<scan number>_<slice number>.png`
    (both numbers are 1-based).

    Due to the Medical Segmentation Decathlon test dataset labels (ground truths) not being provided,
    we'll have to use parts of the set they provided for training as our validation and testing sets.

    Parameters:
        dataset_dir: directory containing `dataset.json`; the scan paths listed under
            its 'training' key are resolved relative to this directory.
        train_dir, val_dir, test_dir: output directories of the three splits.
        val_split, test_split: fraction of the scans (not slices) assigned to the
            validation / test split; the scan counts are rounded down.
        negative_downsampling_rate, positive_downsampling_rate: a rate `r` keeps each
            matching training slice with a probability of roughly 1 / r (rates of 1 or
            less keep everything). Validation and test slices are never downsampled.
        seed: optional seed making the split and the downsampling reproducible.
            When None, the global `random` module state is used.

    Raises:
        ValueError: if a downsampling rate is not greater than zero.
        FileExistsError: if any of the output directories already exists.
    """
    if negative_downsampling_rate <= 0 or positive_downsampling_rate <= 0:
        raise ValueError('Downsampling rates must be greater than zero.')

    # A dedicated generator keeps seeded runs independent of the global random state.
    rng = random.Random(seed) if seed is not None else random

    # Read the scan list before touching the output directories, so that a wrong
    # dataset path does not leave empty directories behind.
    with open(os.path.join(dataset_dir, 'dataset.json'), 'r') as dataset_info:
        training_scans = json.load(dataset_info)['training']

    image_paths = [os.path.join(dataset_dir, scan['image']) for scan in training_scans]
    label_paths = [os.path.join(dataset_dir, scan['label']) for scan in training_scans]

    # os.makedirs raises FileExistsError for an existing directory. This is kept on
    # purpose: writing into directories of an earlier run could mix slices of the
    # same scan into different splits, because the split is random.
    output_dirs = {}
    for split_name, split_dir in (('train', train_dir), ('val', val_dir), ('test', test_dir)):
        images_dir = os.path.join(split_dir, 'images')
        labels_dir = os.path.join(split_dir, 'labels')
        os.makedirs(images_dir)
        os.makedirs(labels_dir)
        output_dirs[split_name] = (images_dir, labels_dir)

    num_of_scans = len(image_paths)
    num_val = int(num_of_scans * val_split)
    num_test = int(num_of_scans * test_split)

    # The split is done per scan, so slices of one scan never end up in two splits.
    scan_ids = list(range(num_of_scans))
    val_scans = set(rng.sample(scan_ids, num_val))
    remaining_ids = [scan_id for scan_id in scan_ids if scan_id not in val_scans]
    test_scans = set(rng.sample(remaining_ids, num_test))

    for scan_idx, (image_path, label_path) in enumerate(zip(image_paths, label_paths)):
        if scan_idx in val_scans:
            split_name = 'val'
        elif scan_idx in test_scans:
            split_name = 'test'
        else:
            split_name = 'train'

        target_images_dir, target_labels_dir = output_dirs[split_name]
        is_in_train = split_name == 'train'

        label_slices = ct_to_slices(label_path)
        image_slices = ct_to_slices(image_path)

        # The normalization window is computed per scan, over all of its slices.
        hounsfield_min, hounsfield_max = calc_hounsfield(image_slices)

        for slice_idx, label_slice in enumerate(label_slices):
            image_slice = image_slices[slice_idx]

            # Downsample only for the training set
            if is_in_train:
                # A slice counts as positive when its label contains the value 1.0
                has_tumor = 1.0 in label_slice
                rate = positive_downsampling_rate if has_tumor else negative_downsampling_rate
                if rng.random() > 1 / rate:
                    continue

            image_slice = normalize_slice(image_slice, hounsfield_min, hounsfield_max) * 255
            label_slice = label_slice * 255

            file_name = f'{scan_idx + 1}_{slice_idx + 1}.png'
            Image.fromarray(image_slice).convert('L').save(os.path.join(target_images_dir, file_name))
            Image.fromarray(label_slice).convert('L').save(os.path.join(target_labels_dir, file_name))

In [7]:
class CTDataset(Dataset):
    """
    Dataset of (image, label) pairs read lazily from two directories.

    Images and labels are paired by their position in the alphabetically
    sorted directory listings, so both directories are expected to contain
    matching file names.
    """

    def __init__(self, image_dir, label_dir, image_transform=None, label_transform=None):
        """
        Collects the file paths; no image data is read here.

        Parameters:
            image_dir: directory holding the image files.
            label_dir: directory holding the label files.
            image_transform: optional callable applied to every loaded image.
            label_transform: optional callable applied to every loaded label.

        Raises:
            ValueError: if the two directories hold a different number of files,
                in which case pairing by position cannot be correct.
        """
        image_names = sorted(os.listdir(image_dir))
        label_names = sorted(os.listdir(label_dir))

        if len(image_names) != len(label_names):
            raise ValueError(
                f'Number of images ({len(image_names)}) does not match '
                f'number of labels ({len(label_names)}).'
            )

        self.image_paths = [os.path.join(image_dir, image_name) for image_name in image_names]
        self.label_paths = [os.path.join(label_dir, label_name) for label_name in label_names]

        self.image_transform = image_transform
        self.label_transform = label_transform

    def __len__(self):
        """Returns the number of (image, label) pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Reads the image and the label at position `idx` with `read_image`,
        applies the configured transforms and returns them as (image, label).
        """
        image = read_image(self.image_paths[idx])
        label = read_image(self.label_paths[idx])

        if self.image_transform:
            image = self.image_transform(image)

        if self.label_transform:
            label = self.label_transform(label)

        return image, label

In [8]:
class CTDatasetMultiSlices(Dataset):
    """
    Dataset whose samples are 3-channel images built from a slice and its
    two neighbours (previous, current, next), paired with the label of the
    current slice.

    All images and labels are loaded into memory when the dataset is created.
    """

    def __init__(self, root_dir, image_transform=None, label_transform=None):
        """
        Loads every slice found in `<root_dir>/images` and `<root_dir>/labels`.

        File names are expected to start with the scan ID followed by an
        underscore (for example `<scan>_<slice>.png`); the part before the
        first underscore is used to group slices by scan.

        Parameters:
            root_dir: directory containing the `images` and `labels` subdirectories.
            image_transform: optional callable applied to an image in __getitem__.
            label_transform: optional callable applied to a label in __getitem__.

        Raises:
            ValueError: if the number of image files and label files differs.
        """
        self.image_transform = image_transform
        self.label_transform = label_transform

        self.images = []
        self.labels = []

        images_root = os.path.join(root_dir, 'images')
        labels_root = os.path.join(root_dir, 'labels')

        # NOTE(review): `numerical_sort` is not defined in this class; it is
        # expected to be available at module level - confirm it is defined
        # before this dataset is instantiated.
        image_files = sorted(os.listdir(images_root), key=numerical_sort)
        label_files = sorted(os.listdir(labels_root), key=numerical_sort)

        # Images and labels are paired by position, so the counts have to agree.
        if len(image_files) != len(label_files):
            raise ValueError(
                f'Number of images ({len(image_files)}) does not match '
                f'number of labels ({len(label_files)}).'
            )

        # Group image and label files by scan ID
        scan_slices = {}
        for img_file, lbl_file in zip(image_files, label_files):
            scan_id = img_file.split('_')[0]
            scan_slices.setdefault(scan_id, []).append((img_file, lbl_file))

        # Merge slices into 3-channel images and store them
        for slices in scan_slices.values():
            # Read every slice of the scan once; each one is reused as the
            # previous / current / next channel of its neighbours.
            slice_images = [read_image(os.path.join(images_root, img_file)) for img_file, _ in slices]
            last_index = len(slices) - 1

            for i, (_, lbl_file) in enumerate(slices):
                # Neighbours are taken by list position; the first and the last
                # slice of a scan use themselves as the missing neighbour.
                prev_image = slice_images[max(i - 1, 0)]
                curr_image = slice_images[i]
                next_image = slice_images[min(i + 1, last_index)]

                # Stack images into a 3-channel tensor
                image_3ch = torch.cat([prev_image, curr_image, next_image], dim=0)

                # Load label
                label = read_image(os.path.join(labels_root, lbl_file))

                # Store the merged image and label
                self.images.append(image_3ch)
                self.labels.append(label)

    def __len__(self):
        """Returns the number of stored samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """
        Returns the (image, label) pair at position `idx`, with the configured
        transforms applied.
        """
        image = self.images[idx]
        label = self.labels[idx]

        if self.image_transform:
            image = self.image_transform(image)

        if self.label_transform:
            label = self.label_transform(label)

        return image, label
