# nnU-Net Data Preparation

## Introduction
This notebook combines all the necessary steps for preparing data for nnU-Net, including fetching and splitting the original data, downsampling images and labels, remapping label values, and creating the `dataset.json` file.

## Imports and Setup

In [10]:
import os
import random
import json
import shutil
import SimpleITK as sitk
import numpy as np

# Define variables for dataset types
ascending = "Ascending"
sinuses = "Sinuses"


## Downsampling
### Resample Images and Labels

In [6]:
# Define target spacing
target_spacing = (2.0, 2.0, 2.0)

def resample_image(image_path, output_path, is_label=False):
    image = sitk.ReadImage(image_path)
    
    original_spacing = image.GetSpacing()
    original_size = image.GetSize()
    
    new_size = [
        int(round(original_size[i] * (original_spacing[i] / target_spacing[i])))
        for i in range(3)
    ]
    
    interpolator = sitk.sitkNearestNeighbor if is_label else sitk.sitkLinear

    resampled_image = sitk.Resample(
        image,
        new_size,
        sitk.Transform(),
        interpolator,
        image.GetOrigin(),
        target_spacing,
        image.GetDirection(),
        0,
        image.GetPixelID()
    )
    
    sitk.WriteImage(resampled_image, output_path)

def downsample_images():
    images_dirs = ["../../LSS/ECE5995_AorticWall/Normal/images", "../../LSS/ECE5995_AorticWall/Diseased/images"]
    output_images_dir = "Resampled/Common_2mm_images"

    os.makedirs(output_images_dir, exist_ok=True)

    for img_dir in images_dirs:
        for image_file in sorted(os.listdir(img_dir)):
            if image_file.endswith(".nii.gz"):
                image_path = os.path.join(img_dir, image_file)
                output_image_path = os.path.join(output_images_dir, image_file)
                resample_image(image_path, output_image_path, is_label=False)

    print(f"Resampling Finished! Images are saved in {output_images_dir}")

def downsample_labels(dataset_type):
    labels_dirs = ["../../LSS/ECE5995_AorticWall/Normal/labels", "../../LSS/ECE5995_AorticWall/Diseased/labels"]
    output_labels_dir = f"Resampled/{dataset_type}_2mm_labels"

    os.makedirs(output_labels_dir, exist_ok=True)

    for lbl_dir in labels_dirs:
        label_files = sorted(os.listdir(lbl_dir))
        for label_file in label_files:
            if label_file.endswith(f"_{dataset_type}.nii.gz"):
                label_path = os.path.join(lbl_dir, label_file)
                output_label_path = os.path.join(output_labels_dir, label_file)
                resample_image(label_path, output_label_path, is_label=True)

    print(f"Resampling Finished! Labels are saved in {output_labels_dir}")

def downsample_dataset():
    # Downsample images once
    downsample_images()

    # Downsample labels for Ascending dataset
    downsample_labels(ascending)

    # Downsample labels for Sinuses dataset
    downsample_labels(sinuses)

downsample_dataset()

Resampling Finished! Images are saved in Resampled/Common_2mm_images
Resampling Finished! Labels are saved in Resampled/Ascending_2mm_labels
Resampling Finished! Labels are saved in Resampled/Sinuses_2mm_labels


## Data Preparation
### Fetch and Split Downsampled Data

In [23]:
def setup_directories(dataset_type):
    
    images_dir = f"Resampled/Common_2mm_images"
    labels_dir = f"Resampled/{dataset_type}_2mm_labels"

    print(f"dataset_type: {dataset_type}")  # Debugging print statement

    if dataset_type == ascending:
        dataset_number = "001"
    elif dataset_type == sinuses:
        dataset_number = "002"
    else:
        dataset_number = "000"  # Default or other types
    
    output_dir = f"nnUNet_raw/Dataset{dataset_number}_{dataset_type}"
    print(f"output_dir: {output_dir}")  # Debugging print statement

    images_train_dir = os.path.join(output_dir, "imagesTr")
    images_test_dir = os.path.join(output_dir, "imagesTs")
    labels_train_dir = os.path.join(output_dir, "labelsTr")
    labels_test_dir = os.path.join(output_dir, "labelsTs")
    train_test_record_file = os.path.join(output_dir, "train_test_split.json")

    os.makedirs(images_train_dir, exist_ok=True)
    os.makedirs(images_test_dir, exist_ok=True)
    os.makedirs(labels_train_dir, exist_ok=True)
    os.makedirs(labels_test_dir, exist_ok=True)

    return images_dir, labels_dir, images_train_dir, images_test_dir, labels_train_dir, labels_test_dir, train_test_record_file


In [32]:
def prepare_data(dataset_type):
    images_dir, labels_dir, images_train_dir, images_test_dir, labels_train_dir, labels_test_dir, train_test_record_file = setup_directories(dataset_type)
    
    all_images = sorted([f for f in os.listdir(images_dir) if f.endswith(".nii.gz")])

    random.seed(42)
    random.shuffle(all_images)

    dataset_usage_ratio = 0.2  
    train_test_ratio = 0.8     

    subset_count = int(len(all_images) * dataset_usage_ratio)
    subset_images = all_images[:subset_count]

    train_count = int(len(subset_images) * train_test_ratio)

    train_files = subset_images[:train_count]
    test_files = subset_images[train_count:]

    train_test_record = {
        "dataset_usage_ratio": dataset_usage_ratio,
        "train_test_ratio": train_test_ratio,
        "train_count": len(train_files),
        "test_count": len(test_files),
        "train": [],
        "test": [],
        "filename_mapping": {}  
    }

    for idx, image_filename in enumerate(train_files):
        image_path = os.path.join(images_dir, image_filename)

        new_filename = f"{dataset_type}_{idx:05d}_0000.nii.gz"
        train_test_record["train"].append(new_filename)
        train_test_record["filename_mapping"][new_filename] = image_filename  

        shutil.copy(image_path, os.path.join(images_train_dir, new_filename))

        label_filename = image_filename.replace(".nii.gz", f"_{dataset_type}.nii.gz")
        label_path = os.path.join(labels_dir, label_filename)

        if os.path.exists(label_path):
            new_label_filename = f"{dataset_type}_{idx:05d}.nii.gz"
            shutil.copy(label_path, os.path.join(labels_train_dir, new_label_filename))
            train_test_record["filename_mapping"][new_label_filename] = label_filename 

    for idx, image_filename in enumerate(test_files):
        image_path = os.path.join(images_dir, image_filename)

        new_filename = f"{dataset_type}_{idx+train_count:05d}_0000.nii.gz"
        train_test_record["test"].append(new_filename)
        train_test_record["filename_mapping"][new_filename] = image_filename

        shutil.copy(image_path, os.path.join(images_test_dir, new_filename))

        label_filename = image_filename.replace(".nii.gz", f"_{dataset_type}.nii.gz")
        label_path = os.path.join(labels_dir, label_filename)

        if os.path.exists(label_path):
            new_label_filename = f"{dataset_type}_{idx+train_count:05d}.nii.gz"
            shutil.copy(label_path, os.path.join(labels_test_dir, new_label_filename))
            train_test_record["filename_mapping"][new_label_filename] = label_filename  

    with open(train_test_record_file, "w") as f:
        json.dump(train_test_record, f, indent=4)

    print(f"Data split! Using {dataset_usage_ratio * 100}% of the original dataset")
    print(f"Training set: {len(train_files)} images and {len(os.listdir(labels_train_dir))} labels")
    print(f"Testing set: {len(test_files)} images and {len(os.listdir(labels_test_dir))} labels")
    print(f"Data report is saved at {train_test_record_file}")

# Prepare data for Ascending dataset
prepare_data(ascending)

# Prepare data for Sinuses dataset
prepare_data(sinuses)

dataset_type: Ascending
output_dir: nnUNet_raw/Dataset001_Ascending
Data split! Using 20.0% of the original dataset
Training set: 34 images and 34 labels
Testing set: 9 images and 9 labels
Data report is saved at nnUNet_raw/Dataset001_Ascending/train_test_split.json
dataset_type: Sinuses
output_dir: nnUNet_raw/Dataset002_Sinuses
Data split! Using 20.0% of the original dataset
Training set: 34 images and 34 labels
Testing set: 9 images and 9 labels
Data report is saved at nnUNet_raw/Dataset002_Sinuses/train_test_split.json


## Label Remapping
### Remap Label Values

In [33]:
def remap_labels(label_path):
    print(f"Processing {label_path}...")
    
    image = sitk.ReadImage(label_path)
    array = sitk.GetArrayFromImage(image)

    array[array == 4] = 1  

    new_image = sitk.GetImageFromArray(array)
    new_image.SetSpacing(image.GetSpacing())
    new_image.SetOrigin(image.GetOrigin())
    new_image.SetDirection(image.GetDirection())

    sitk.WriteImage(new_image, label_path)
    print(f"Updated {label_path}")

def process_labels(dataset_type):
    _, _, _, _, labels_train_dir, labels_test_dir, _ = setup_directories(dataset_type)
    
    for filename in os.listdir(labels_train_dir):
        if filename.endswith(".nii.gz"):
            remap_labels(os.path.join(labels_train_dir, filename))
    
    for filename in os.listdir(labels_test_dir):
        if filename.endswith(".nii.gz"):
            remap_labels(os.path.join(labels_test_dir, filename))
    
    print("All the label 4 are converted to label 1.")

# Process labels for Ascending dataset
process_labels(ascending)

# Process labels for Sinuses dataset
process_labels(sinuses)

dataset_type: Ascending
output_dir: nnUNet_raw/Dataset001_Ascending
Processing nnUNet_raw/Dataset001_Ascending/labelsTr/Ascending_00007.nii.gz...
Updated nnUNet_raw/Dataset001_Ascending/labelsTr/Ascending_00007.nii.gz
Processing nnUNet_raw/Dataset001_Ascending/labelsTr/Ascending_00003.nii.gz...
Updated nnUNet_raw/Dataset001_Ascending/labelsTr/Ascending_00003.nii.gz
Processing nnUNet_raw/Dataset001_Ascending/labelsTr/Ascending_00011.nii.gz...
Updated nnUNet_raw/Dataset001_Ascending/labelsTr/Ascending_00011.nii.gz
Processing nnUNet_raw/Dataset001_Ascending/labelsTr/Ascending_00032.nii.gz...
Updated nnUNet_raw/Dataset001_Ascending/labelsTr/Ascending_00032.nii.gz
Processing nnUNet_raw/Dataset001_Ascending/labelsTr/Ascending_00031.nii.gz...
Updated nnUNet_raw/Dataset001_Ascending/labelsTr/Ascending_00031.nii.gz
Processing nnUNet_raw/Dataset001_Ascending/labelsTr/Ascending_00010.nii.gz...
Updated nnUNet_raw/Dataset001_Ascending/labelsTr/Ascending_00010.nii.gz
Processing nnUNet_raw/Dataset001

## Dataset JSON Creation
### Create dataset.json

In [34]:
def create_dataset_json(dataset_type):
    _, _, images_train_dir, images_test_dir, labels_train_dir, labels_test_dir, _ = setup_directories(dataset_type)

    train_cases = sorted([f.replace(".nii.gz", "") for f in os.listdir(images_train_dir) if f.endswith(".nii.gz")])
    test_cases = sorted([f.replace(".nii.gz", "") for f in os.listdir(images_test_dir) if f.endswith(".nii.gz")])

    dataset_info = {
        "channel_names": {"0": "CT"},  
        "labels": {
            "background": 0,
            "root_1": 1,
            "root_2": 2,
            "root_3": 3
        },
        "numTraining": len(train_cases),
        "file_ending": ".nii.gz",      
        "training": [{"image": f"./imagesTr/{case}.nii.gz", "label": f"./labelsTr/{case}.nii.gz"} for case in train_cases],
        "test": [{"image": f"./imagesTs/{case}.nii.gz", "label": f"./labelsTs/{case}.nii.gz"} for case in test_cases]
    }

    dataset_dir = os.path.dirname(images_train_dir)
    json_path = os.path.join(dataset_dir, "dataset.json")
    with open(json_path, "w") as json_file:
        json.dump(dataset_info, json_file, indent=4)
    
    print(f"Dataset JSON created at {json_path}")

# Create dataset JSON for Ascending dataset
create_dataset_json(ascending)

# Create dataset JSON for Sinuses dataset
create_dataset_json(sinuses)

dataset_type: Ascending
output_dir: nnUNet_raw/Dataset001_Ascending
Dataset JSON created at nnUNet_raw/Dataset001_Ascending/dataset.json
dataset_type: Sinuses
output_dir: nnUNet_raw/Dataset002_Sinuses
Dataset JSON created at nnUNet_raw/Dataset002_Sinuses/dataset.json


## Summary and Next Steps
### In this notebook, we have prepared the data for nnU