<a href="https://colab.research.google.com/github/dylanbforde/Cardiac-MRI-Segmentation/blob/DataLoading/Load_and_split_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import nibabel as nib
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
import shutil

In [2]:
# Short-axis (SA) run settings. They are kept inline in the notebook because an
# external config file was judged likely to be messy.
SEED = 42          # passed as random_state to both train_test_split calls
MODALITY = 'SA'    # matched against file names as "_SA_"

# TEST_SPLIT is the share held out first; VAL_SPLIT is later divided by
# (1 - TEST_SPLIT), i.e. it is expressed as a share of the whole dataset.
# NOTE(review): with these values the result is roughly 40% train / 40% val /
# 20% test — confirm that is intended.
TEST_SPLIT, VAL_SPLIT = 0.2, 0.4

OUTPUT_DIR = Path('./processed_data')
RAW_DATA_DIR = Path(r'/content/drive/Shareddrives/Segmentation Group Assignment/Data/MnM2/dataset')


In [3]:
# Pre-create OUTPUT_DIR/<split>/<MODALITY>/{images,labels} so the export cell
# has somewhere to write. exist_ok=True makes this cell safe to re-run.
for split_dir in ('train', 'val', 'test'):
    modality_root = OUTPUT_DIR / split_dir / MODALITY
    for kind in ('images', 'labels'):
        (modality_root / kind).mkdir(parents=True, exist_ok=True)


In [4]:
# Collect (image_path, label_path) pairs for the current MODALITY from every
# patient folder under RAW_DATA_DIR.
#
# Folders and files are visited in sorted order: directory listing order is
# filesystem-dependent, and the seeded split that consumes all_pairs is only
# reproducible if the list is built in a stable order.
all_pairs = []
nifti_ext = '.nii.gz'

for patient_folder in sorted(RAW_DATA_DIR.iterdir()):
    if not patient_folder.is_dir():
        continue

    for file in sorted(os.listdir(patient_folder)):
        # Only compressed NIfTI files of this modality qualify. Without the
        # suffix check, a name that does not contain '.nii.gz' would come out
        # of the label-name substitution unchanged and the file would be
        # paired with itself as its own label.
        if not file.endswith(nifti_ext) or f"_{MODALITY}_" not in file:
            continue
        # Skip cine sequences and the ground-truth masks themselves.
        if file.endswith('CINE.nii.gz') or 'gt' in file:
            continue

        image_path = patient_folder / file
        label_file = file[:-len(nifti_ext)] + '_gt' + nifti_ext
        label_path = patient_folder / label_file
        # Images without a matching mask on disk are left out.
        if label_path.exists():
            all_pairs.append((image_path, label_path))

In [None]:
# Split by patient folder rather than by individual volume: a folder can
# contribute more than one (image, label) pair, and splitting the pairs
# directly lets volumes of the same patient land in both train and test.
# Folders are sorted first so the seeded split does not depend on the order in
# which all_pairs was built.
patient_folders = sorted({image_path.parent for image_path, _ in all_pairs})
train_val_patients, test_patients = train_test_split(
    patient_folders, test_size=TEST_SPLIT, random_state=SEED
)
# VAL_SPLIT is a share of the whole dataset, so rescale it to the remainder.
train_patients, val_patients = train_test_split(
    train_val_patients, test_size=VAL_SPLIT / (1 - TEST_SPLIT), random_state=SEED
)

split_patients = {
    'train': set(train_patients),
    'val': set(val_patients),
    'test': set(test_patients),
}
splits = {
    name: [pair for pair in all_pairs if pair[0].parent in folders]
    for name, folders in split_patients.items()
}
train, val, test = splits['train'], splits['val'], splits['test']

# Export every slice along the third axis of each volume as its own NIfTI file
# (a trailing singleton axis is added). Image and label slices share the same
# file name under the 'images' and 'labels' directories.
for split_name, pairs in splits.items():
    split_root = OUTPUT_DIR / split_name / MODALITY

    for image_path, label_path in pairs:
        image = nib.load(str(image_path))
        label = nib.load(str(label_path))

        # A mismatched pair would otherwise fail mid-loop or silently drop
        # label slices.
        if image.shape != label.shape:
            raise ValueError(
                f"Shape mismatch: {image_path} {image.shape} vs {label_path} {label.shape}"
            )

        # Read each volume once, outside the slice loop.
        image_data = image.get_fdata()
        # Keep the mask in the array type nibabel returns for its stored data
        # and write it with its own affine/header, instead of converting it to
        # float64 and borrowing the image's header.
        label_data = np.asanyarray(label.dataobj)

        stem = image_path.stem.replace('.nii', '')

        for i in range(image.shape[2]):
            img_slice = image_data[:, :, i]
            lbl_slice = label_data[:, :, i]

            img_out = nib.Nifti1Image(img_slice[..., np.newaxis], image.affine, image.header)
            lbl_out = nib.Nifti1Image(lbl_slice[..., np.newaxis], label.affine, label.header)

            base_name = stem + f'_slice{i:03}.nii.gz'

            nib.save(img_out, split_root / 'images' / base_name)
            nib.save(lbl_out, split_root / 'labels' / base_name)

print("Done")

In [None]:
# Long-axis (LA) run settings. Same values as the short-axis run except for
# MODALITY; redefined here so the cells below pick up 'LA'.
SEED = 42          # passed as random_state to both train_test_split calls
MODALITY = 'LA'    # matched against file names as "_LA_"

# TEST_SPLIT is the share held out first; VAL_SPLIT is later divided by
# (1 - TEST_SPLIT), i.e. it is expressed as a share of the whole dataset.
TEST_SPLIT, VAL_SPLIT = 0.2, 0.4

OUTPUT_DIR = Path('./processed_data')
RAW_DATA_DIR = Path(r'/content/drive/Shareddrives/Segmentation Group Assignment/Data/MnM2/dataset')


In [None]:

# Pre-create OUTPUT_DIR/<split>/<MODALITY>/{images,labels} for the current
# modality. exist_ok=True makes this cell safe to re-run.
for split_dir in ('train', 'val', 'test'):
    modality_root = OUTPUT_DIR / split_dir / MODALITY
    for kind in ('images', 'labels'):
        (modality_root / kind).mkdir(parents=True, exist_ok=True)


In [None]:
# Collect (image_path, label_path) pairs for the current MODALITY from every
# patient folder under RAW_DATA_DIR.
#
# Folders and files are visited in sorted order: directory listing order is
# filesystem-dependent, and the seeded split that consumes all_pairs is only
# reproducible if the list is built in a stable order.
all_pairs = []
nifti_ext = '.nii.gz'

for patient_folder in sorted(RAW_DATA_DIR.iterdir()):
    if not patient_folder.is_dir():
        continue

    for file in sorted(os.listdir(patient_folder)):
        # Only compressed NIfTI files of this modality qualify. Without the
        # suffix check, a name that does not contain '.nii.gz' would come out
        # of the label-name substitution unchanged and the file would be
        # paired with itself as its own label.
        if not file.endswith(nifti_ext) or f"_{MODALITY}_" not in file:
            continue
        # Skip cine sequences and the ground-truth masks themselves.
        if file.endswith('CINE.nii.gz') or 'gt' in file:
            continue

        image_path = patient_folder / file
        label_file = file[:-len(nifti_ext)] + '_gt' + nifti_ext
        label_path = patient_folder / label_file
        # Images without a matching mask on disk are left out.
        if label_path.exists():
            all_pairs.append((image_path, label_path))

In [None]:

# Split by patient folder rather than by individual volume: a folder can
# contribute more than one (image, label) pair, and splitting the pairs
# directly lets volumes of the same patient land in both train and test.
# Folders are sorted first so the seeded split does not depend on the order in
# which all_pairs was built.
patient_folders = sorted({image_path.parent for image_path, _ in all_pairs})
train_val_patients, test_patients = train_test_split(
    patient_folders, test_size=TEST_SPLIT, random_state=SEED
)
# VAL_SPLIT is a share of the whole dataset, so rescale it to the remainder.
train_patients, val_patients = train_test_split(
    train_val_patients, test_size=VAL_SPLIT / (1 - TEST_SPLIT), random_state=SEED
)

split_patients = {
    'train': set(train_patients),
    'val': set(val_patients),
    'test': set(test_patients),
}
splits = {
    name: [pair for pair in all_pairs if pair[0].parent in folders]
    for name, folders in split_patients.items()
}
train, val, test = splits['train'], splits['val'], splits['test']

# Volumes are exported whole, so copy the files byte-for-byte instead of
# decoding them to float64 and re-encoding them (which also wrote the label
# with the image's header). The label is stored under the image's file name so
# the two can be matched by name across 'images' and 'labels'.
for split_name, pairs in splits.items():
    split_root = OUTPUT_DIR / split_name / MODALITY

    for image_path, label_path in pairs:
        base_name = image_path.stem.replace('.nii', '') + '.nii.gz'

        shutil.copyfile(image_path, split_root / 'images' / base_name)
        shutil.copyfile(label_path, split_root / 'labels' / base_name)

print("Done")