<a href="https://colab.research.google.com/github/dylanbforde/Cardiac-MRI-Segmentation/blob/main/Augmented_data_loaders.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchio

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchio as tio
from torch.utils.data import DataLoader


In [None]:
def get_subjects(image_dir, label_dir):
    """Pair image volumes with label volumes and wrap each pair in a TorchIO subject.

    Every ``*.nii.gz`` file directly inside each directory is collected, both
    listings are sorted, and the i-th image is paired with the i-th label.
    NOTE(review): positional pairing presumes image and label files are named
    so that both directories sort in the same order -- confirm against the
    preprocessing step that writes them.

    Args:
        image_dir: Directory containing the image volumes (``str`` or ``Path``).
        label_dir: Directory containing the label volumes (``str`` or ``Path``).

    Returns:
        A list of ``tio.Subject`` objects, each holding the image under the
        key ``mri`` (``tio.ScalarImage``) and the label under the key
        ``heart`` (``tio.LabelMap``). The list is empty when neither
        directory contains a matching file.

    Raises:
        AssertionError: If the two directories contain a different number of
            ``*.nii.gz`` files.
    """
    # Path(...) lets callers pass plain strings as well as Path objects.
    image_paths = sorted(Path(image_dir).glob('*.nii.gz'))
    label_paths = sorted(Path(label_dir).glob('*.nii.gz'))
    assert len(image_paths) == len(label_paths), (
        f"Found {len(image_paths)} images in {image_dir} "
        f"but {len(label_paths)} labels in {label_dir}"
    )

    return [
        tio.Subject(
            mri=tio.ScalarImage(img_path),
            heart=tio.LabelMap(lbl_path),
        )
        for img_path, lbl_path in zip(image_paths, label_paths)
    ]

In [None]:
def get_loaders(modality='SA', base_path='./processed_data', batch_size=4, num_workers=4,
                patch_size=(256, 256, 1), augment=True):
    """Build the train, validation and test data loaders for one modality.

    Data is read from ``<base_path>/<split>/<modality>/images`` and
    ``<base_path>/<split>/<modality>/labels`` for each split in
    ``train``, ``val`` and ``test``.

    Args:
        modality: Either ``'SA'`` or ``'LA'``; selects the sub-directory.
        base_path: Root directory that contains the split directories.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker count passed to all three loaders.
        patch_size: Target shape handed to ``tio.CropOrPad``.
        augment: When true, random flip, noise and blur are appended to the
            training transform. Validation and test never get them.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple. Only the training
        loader shuffles.

    Raises:
        AssertionError: If ``modality`` is not ``'SA'`` or ``'LA'``.

    Prints ``Done`` once the loaders have been created.
    """
    assert modality in ['SA', 'LA'], "Modality must be 'SA' or 'LA'"

    splits = ('train', 'val', 'test')
    root = Path(base_path)
    split_dirs = {
        split: {kind: root / split / modality / kind for kind in ('images', 'labels')}
        for split in splits
    }

    # Preprocessing shared by every split. More TorchIO transforms are listed
    # at https://docs.torchio.org/transforms/transforms.html
    preprocessing = [
        tio.CropOrPad(patch_size),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        tio.RescaleIntensity(out_min_max=(0, 1)),
    ]

    # Random augmentations, applied to the training split only.
    augmentations = []
    if augment:
        augmentations = [
            tio.RandomFlip(axes=(0, 1), flip_probability=0.5),
            tio.RandomNoise(mean=0.0, std=0.01),
            tio.RandomBlur(std=(0.1, 1.0)),
        ]

    train_transform = tio.Compose(preprocessing + augmentations)
    eval_transform = tio.Compose(preprocessing)
    transforms = {'train': train_transform, 'val': eval_transform, 'test': eval_transform}

    # Each stage is completed for all splits before the next stage starts:
    # first every subject list, then every dataset, then every loader.
    subjects = {
        split: get_subjects(dirs['images'], dirs['labels'])
        for split, dirs in split_dirs.items()
    }
    datasets = {
        split: tio.SubjectsDataset(subjects[split], transform=transforms[split])
        for split in splits
    }
    loaders = tuple(
        DataLoader(
            datasets[split],
            batch_size=batch_size,
            shuffle=(split == 'train'),
            num_workers=num_workers,
        )
        for split in splits
    )

    print("Done")

    return loaders

In [None]:
# Create the three loaders for the 'SA' modality with augmentation enabled on
# the training split. A small batch and zero workers keep this smoke run light.
train_loader, val_loader, test_loader = get_loaders(
    modality='SA', base_path='./processed_data',
    batch_size=2, num_workers=0,
    patch_size=(256, 256, 1), augment=True,
)


In [None]:
# Sanity check: pull a single training batch and plot its first sample to make
# sure everything works end to end.

batch = next(iter(train_loader))
images = batch['mri'][tio.DATA]
labels = batch['heart'][tio.DATA]

print(f"Images shape: {images.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Subject keys: {batch.keys()}")

# First sample, first channel, first index along the last axis.
img_slice = images[0, 0, :, :, 0].numpy()
lbl_slice = labels[0, 0, :, :, 0].numpy()

fig, (ax_img, ax_lbl) = plt.subplots(1, 2, figsize=(10, 4))

ax_img.imshow(img_slice, cmap='gray')
ax_img.set_title("Data augmented Image")

# Draw the image first and put the label on top so this panel really is an
# overlay. Zero-valued label pixels are masked out so the image stays visible.
# NOTE(review): assumes label value 0 means background -- confirm.
ax_lbl.imshow(img_slice, cmap='gray')
ax_lbl.imshow(
    np.ma.masked_where(lbl_slice == 0, lbl_slice),
    cmap='Reds',
    alpha=0.7,
    vmin=0,
    vmax=max(float(lbl_slice.max()), 1.0),  # keeps the colour scale valid for an empty label
)
ax_lbl.set_title("Label overlay")

fig.tight_layout()
plt.show()