# This notebook prepares the NUH dataset for model training. It converts 3D cardiac MRI scans into separate 2D slices. Each slice is saved together with its corresponding scar mask and myocardium mask. The saved slices are then split, patient by patient, into training, validation, and test sets.


In [None]:
!pip install SimpleITK

Collecting SimpleITK
  Downloading simpleitk-2.5.2-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.2 kB)
Downloading simpleitk-2.5.2-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (52.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.6/52.6 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.5.2


In [None]:
from sklearn.model_selection import train_test_split
from scipy.ndimage import distance_transform_edt
import matplotlib.pyplot as plt
import scipy.ndimage as ndi
import SimpleITK as sitk
import nibabel as nib
from glob import glob
import pandas as pd
import numpy as np
import cv2
import os



# Mount the google drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)


Mounted at /content/drive


In [None]:
def compute_total_scar(mask_volume):
    """
    Return how many elements of the label volume carry the scar label (value 3).
    """
    is_scar = mask_volume == 3
    return np.sum(is_scar)



def load_nifti(path):
    """
    Open the NIfTI file at `path` with nibabel and return its voxel data
    (the result of `get_fdata()`).
    """
    nifti_image = nib.load(path)
    return nifti_image.get_fdata()


def get_new_spacing(shape):
    """
    Look up the target voxel spacing for a scan from its in-plane size.

    Only the first two entries of `shape` are used; they are swapped so the
    lookup below works on (height, width). Returns a 3-tuple of spacings, or
    None (after printing a warning) when the size is not one of the known ones.
    """
    # Swap to (height, width) - the tables below are written in that order
    hw = (shape[1], shape[0])
    height, width = hw

    if hw in ((336, 336), (256, 256), (320, 320)):
        return (0.81, 0.81, 10)

    if hw in ((208, 256), (256, 208)):
        return (0.84, 0.84, 10)

    in_height_range = 198 <= height <= 224 and width == 224
    explicitly_listed = hw in ((172, 192), (224, 222), (224, 198), (224, 210), (224, 216))
    if in_height_range or explicitly_listed:
        return (0.89, 0.89, 10)

    # Size not recognised: warn and signal it with None
    print(f"!!!!!!!!! Warning: shape {hw}")
    return None


def plot_images(original_image, original_mask):
    """
    Show an image and its mask next to each other in one matplotlib figure.
    """
    plt.figure(figsize=(12, 6))

    # (data, title) for each panel, drawn left to right in a 2x3 grid
    panels = (
        (original_image, 'Original Image'),
        (original_mask, 'Original Mask'),
    )
    for position, (picture, heading) in enumerate(panels, start=1):
        plt.subplot(2, 3, position)
        plt.imshow(picture, cmap='gray')
        plt.title(heading)
        plt.axis('off')

    plt.tight_layout()
    plt.show()


def save_slice_as_npz(output_base, volume_id, slice_idx, image_slice, mask_slice):
    """
    Write one 2D slice and its two binary masks to a compressed .npz file.

    Steps:
      - Build binary masks from the label slice (scar = 3, myocardium = 2)
      - Re-orient image and masks (90° clockwise rotation, then horizontal flip)
      - Store them under the keys `image`, `mask` and `myocardium_mask` in
        `<output_base>/<volume_id up to first '.'>_slice_<idx, 3 digits>.npz`
    """
    def reorient(plane):
        # Rotate 90° clockwise, then mirror left-right
        return np.fliplr(np.rot90(plane, k=3))

    # Binary masks as uint8: 1 where the label matches, 0 elsewhere
    scar = (mask_slice == 3).astype(np.uint8)
    myocardium = (mask_slice == 2).astype(np.uint8)

    oriented_image = reorient(image_slice)
    oriented_scar = reorient(scar)
    oriented_myocardium = reorient(myocardium)

    # File name: volume id without extension + zero-padded slice index
    stem = volume_id.split('.')[0]
    target = os.path.join(output_base, f"{stem}_slice_{slice_idx:03d}.npz")

    np.savez_compressed(
        target,
        image=oriented_image.astype(np.float32),
        mask=oriented_scar,
        myocardium_mask=oriented_myocardium,
    )


def process_volume(image, mask, volume_id, output_base):
    """
    Save every slice along the first axis of a 3D volume as its own .npz file.

    Each image slice is paired with the mask slice at the same index and
    handed to `save_slice_as_npz`.
    """
    print(f"Scan shape is: {image.shape}")

    # Walk the first axis; the mask is indexed with the same slice number
    for slice_no, image_plane in enumerate(image):
        save_slice_as_npz(output_base, volume_id, slice_no, image_plane, mask[slice_no])

    print(f"Processed volume {volume_id}")


In [None]:
def resample_image(itk_image, new_spacing, interpolator):
    """
    Resample a SimpleITK image to `new_spacing` using the given interpolator.

    The output size is derived from the spacing ratio
    (size * old_spacing / new_spacing, rounded), while direction and origin
    are copied from the input. Samples outside the input are filled with 0.
    """
    original_size = itk_image.GetSize()
    original_spacing = itk_image.GetSpacing()

    # Compute new size based on spacing ratio
    new_size = [
        int(round(osz * ospc / nspc))
        for osz, ospc, nspc in zip(original_size, original_spacing, new_spacing)
    ]

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(new_spacing)
    resample.SetSize(new_size)
    resample.SetOutputDirection(itk_image.GetDirection())
    resample.SetOutputOrigin(itk_image.GetOrigin())
    resample.SetInterpolator(interpolator)
    resample.SetDefaultPixelValue(0)

    return resample.Execute(itk_image)


# Folder Setup
scans_folder = "drive/MyDrive/Training_and_Ablation_Studies/NUH"               # Original scans
output_folder = "drive/MyDrive/Training_and_Ablation_Studies/preprocessed"  # Output for preprocessed slices

os.makedirs(output_folder, exist_ok=True)

# One (patient_id, total_scar) row per processed scan; turned into a DataFrame after the loop
scar_records = []

# Loop over all patient scan folders
for scan_name in sorted(os.listdir(scans_folder)):

    scan_path = os.path.join(scans_folder, scan_name)
    if not os.path.isdir(scan_path):
        continue

    # The first two files (alphabetically) are used as image and label.
    # NOTE(review): assumes the image file sorts before the label file - confirm naming.
    files = sorted(os.listdir(scan_path))
    if len(files) < 2:
        print(f"!!!!!!!!! Warning: {scan_name} has fewer than 2 files, skipping")
        continue

    img_path = os.path.join(scan_path, files[0])
    label_path = os.path.join(scan_path, files[1])

    # Read with SimpleITK for resampling
    sitk_img = sitk.ReadImage(img_path)
    sitk_label = sitk.ReadImage(label_path)

    # Target spacing depends on the in-plane size of the scan
    new_spacing = get_new_spacing(sitk_img.GetSize())  # (w, h, d)
    if new_spacing is None:
        # Unknown size: resampling is impossible without a target spacing, so
        # skip the scan (and keep it out of patient_scar, since no slices exist for it)
        print(f"!!!!!!!!! Warning: skipping {scan_name}, no target spacing for its shape")
        continue

    # Compute total scar pixel count on the original (non-resampled) label volume
    scar_pixels = compute_total_scar(load_nifti(label_path))
    scar_records.append((scan_name, int(scar_pixels)))

    # Resample image (linear) and label (nearest neighbor, so label values are not blended)
    resampled_img = resample_image(sitk_img, new_spacing, sitk.sitkLinear)
    resampled_label = resample_image(sitk_label, new_spacing, sitk.sitkNearestNeighbor)

    # Verify orientation to check resampling consistency
    print("Original image orientation:", sitk_img.GetDirection())
    print("Resampled image orientation:", resampled_img.GetDirection())

    # Convert to NumPy arrays (Z, Y, X)
    resampled_img_np = sitk.GetArrayFromImage(resampled_img)
    resampled_label_np = sitk.GetArrayFromImage(resampled_label)

    # Extract patient/volume ID: the file name up to '_image'
    volume_id = os.path.basename(img_path).split('_image')[0]

    # Process entire volume into individual 2D slices and save as .npz
    process_volume(resampled_img_np, resampled_label_np, volume_id, output_folder)

# DataFrame with total scar counts per patient (built once instead of row-by-row appends)
patient_scar = pd.DataFrame(scar_records, columns=['patient_id', 'total_scar'])


In [None]:
# Bin the scar amounts into (up to) three quantile-based categories.
# - pd.to_numeric: total_scar can end up with object dtype when rows were
#   appended one by one, so it is coerced before binning.
# - duplicates='drop': without it qcut raises when several quantile edges
#   coincide (e.g. many patients with the same scar count).
patient_scar['scar_bin'] = pd.qcut(
    pd.to_numeric(patient_scar['total_scar']),
    q=3,
    labels=False,
    duplicates='drop',
)

# First split into train (80%) and temp (20%).
# NOTE: the split is NOT stratified by scar_bin; scar_bin is only used for the plot below.
train_df, temp_df = train_test_split(
    patient_scar,
    test_size=0.2,
    random_state=42
)

# Split temp into test and validation (10% each of total)
test_df, val_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42
)

# Drop the scar_bin column from the split tables (it stays on patient_scar for plotting)
train_df = train_df.drop(columns='scar_bin')
val_df = val_df.drop(columns='scar_bin')
test_df = test_df.drop(columns='scar_bin')

# Plot the spread of scar_bin
bin_counts = patient_scar['scar_bin'].value_counts().sort_index()

fig, ax = plt.subplots(figsize=(5, 4))
ax.bar(bin_counts.index.astype(str), bin_counts.values)

ax.set_xlabel('scar_bin label')
ax.set_ylabel('Number of subjects')
ax.set_title(f'Distribution of scar bins (n = {len(patient_scar)})')
ax.grid(axis='y', linestyle='--', alpha=0.4)

plt.tight_layout()
plt.show()

# Number of patients (rows) in each split
print(val_df.shape)
print(test_df.shape)
print(train_df.shape)

In [None]:
# Folder that holds the preprocessed .npz slices. This must be the folder the
# preprocessing step writes to; listing any other location moves nothing.
preprocessed_folder = "drive/MyDrive/Training_and_Ablation_Studies/preprocessed"

train_folder = os.path.join(preprocessed_folder, "train")
val_folder = os.path.join(preprocessed_folder, "val")
test_folder = os.path.join(preprocessed_folder, "test")

os.makedirs(train_folder, exist_ok=True)
os.makedirs(val_folder, exist_ok=True)
os.makedirs(test_folder, exist_ok=True)

# Patient ids per split, built once as sets for fast membership tests.
# Checked in the same order as before: train, then val, then test.
split_targets = [
    (set(train_df['patient_id']), train_folder),
    (set(val_df['patient_id']), val_folder),
    (set(test_df['patient_id']), test_folder),
]
split_dir_names = {"train", "val", "test"}

# Move each patient's slices to the directory of its split
for entry in sorted(os.listdir(preprocessed_folder)):
    # The split directories live in the same folder; they are not slices
    if entry in split_dir_names:
        continue

    # The first 12 characters of the name are matched against patient_id.
    # NOTE(review): assumes patient ids are exactly 12 characters long - confirm.
    patient_key = entry[:12]

    for patient_ids, destination in split_targets:
        if patient_key in patient_ids:
            os.rename(os.path.join(preprocessed_folder, entry), os.path.join(destination, entry))
            break
    else:
        # Not part of any split
        print(f"!!!!!!!!! Warning: {entry}")