# Preparing Training Data for nnU-Net Segmentation

## Overview
This notebook prepares training data for an nnU-Net segmentation task by:
1. Verifying that each image and its segmentation have matching dimensions and voxel spacing.
2. Copying each image/segmentation pair into its own case folder (`<case_id>/imaging.nii.gz` and `<case_id>/segmentation.nii.gz`), a one-folder-per-case layout like that of the KiTS19 dataset, in preparation for nnU-Net.

## Workflow
### **Cell 1: Loading File Paths**
- Collects sorted lists of all `.nii.gz` image and segmentation paths from the specified directories.
- Prints the number of image and segmentation files so that you can confirm they match.
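
Cell 1 pairs images and segmentations only by their position in the two sorted lists. If one file is missing or named differently, every later pair shifts by one. The sketch below pairs files by name instead. It assumes, which the notebook does not confirm, that an image and its segmentation share the same filename in their respective directories. `pair_by_filename` is a hypothetical helper.

```python
import os
from typing import List, Tuple


def pair_by_filename(image_paths: List[str], label_paths: List[str]) -> List[Tuple[str, str]]:
    """Pair image and segmentation paths by identical basename (hypothetical helper).

    Assumes each image and its segmentation share the same filename,
    e.g. 'img/case_01.nii.gz' <-> 'seg/case_01.nii.gz'.
    """
    images_by_name = {os.path.basename(p): p for p in image_paths}
    labels_by_name = {os.path.basename(p): p for p in label_paths}
    if len(images_by_name) != len(image_paths) or len(labels_by_name) != len(label_paths):
        raise ValueError("Duplicate filenames within one directory listing")

    # Any name present on only one side has no partner
    unpaired = sorted(set(images_by_name) ^ set(labels_by_name))
    if unpaired:
        raise ValueError(f"Files without a partner: {unpaired}")

    # Sort by basename so the pairing order is deterministic
    return [(images_by_name[n], labels_by_name[n]) for n in sorted(images_by_name)]


pairs = pair_by_filename(
    ["img/case_02.nii.gz", "img/case_01.nii.gz"],
    ["seg/case_01.nii.gz", "seg/case_02.nii.gz"],
)
print(pairs)
# [('img/case_01.nii.gz', 'seg/case_01.nii.gz'), ('img/case_02.nii.gz', 'seg/case_02.nii.gz')]
```

In Cell 1 this would be called as `pair_by_filename(train_images, train_labels)`. An unpaired file then raises an error instead of silently shifting the pairing.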

### **Cell 2: Validation of Image and Segmentation Files**
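- Uses the SimpleITK library to:
  - Read each `.nii.gz` file into an array.
  - Check that each image and its segmentation have the same shape and voxel spacing.
- Images and segmentations are paired by their position in the sorted lists, so corresponding files must sort in the same order.
- If a mismatch is found, an exception is raised and validation stops at that case.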
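
SimpleITK reports image size in (x, y, z) order through `GetSize()`, whereas `GetArrayFromImage` returns a NumPy array indexed (z, y, x). Shapes should therefore only be compared within one convention. The sketch below compares geometry as plain tuples. It also allows a small tolerance on spacing instead of exact float equality, because files written by different tools can disagree in the last decimal places. `check_geometry` and the `1e-5` tolerance are assumptions, not part of the original notebook.

```python
import math
from typing import Sequence


def check_geometry(
    img_size: Sequence[int],
    seg_size: Sequence[int],
    img_spacing: Sequence[float],
    seg_spacing: Sequence[float],
    abs_tol: float = 1e-5,
) -> None:
    """Raise ValueError if sizes differ or spacings differ by more than abs_tol.

    Hypothetical helper: both sizes and both spacings must use the same axis
    order, e.g. all taken from SimpleITK's GetSize() / GetSpacing().
    """
    if tuple(img_size) != tuple(seg_size):
        raise ValueError(f"Shape mismatch: {tuple(img_size)} vs {tuple(seg_size)}")
    if len(img_spacing) != len(seg_spacing) or not all(
        math.isclose(a, b, rel_tol=0.0, abs_tol=abs_tol)
        for a, b in zip(img_spacing, seg_spacing)
    ):
        raise ValueError(f"Spacing mismatch: {tuple(img_spacing)} vs {tuple(seg_spacing)}")


# Passes: the spacings differ only by floating-point noise
check_geometry((512, 512, 40), (512, 512, 40), (0.78, 0.78, 3.0), (0.7800001, 0.78, 3.0))
```

With SimpleITK images loaded, the call would be `check_geometry(img_sitk.GetSize(), seg_sitk.GetSize(), img_sitk.GetSpacing(), seg_sitk.GetSpacing())`. Using `GetSize()` also avoids converting the voxel data into a NumPy array just to read its shape.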

### **Cell 3: Extracting Case Identifiers**
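- Derives a case identifier from each image filename and prints it alongside the paired file paths, as a dry run before any files are copied. The identifier is the prefix before the second underscore, or the prefix before the third underscore when the name contains more than two underscores and does not contain `_3d`.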
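
The naming rule used in Cells 3 and 4 can be isolated into one function, which makes it easy to test on sample names before copying anything. This is a sketch that reproduces the rule exactly as coded in the notebook. The helper name `case_id_from_filename` and the example filenames are hypothetical. The samples deliberately include two names that map to the same ID, because Cell 4 would silently overwrite one of those cases with the other.

```python
from collections import Counter


def case_id_from_filename(img_name: str) -> str:
    """Return the case ID using the same rule as Cells 3 and 4 (hypothetical helper).

    The ID is the prefix before the second underscore; if the name contains
    more than two underscores and no '_3d', the prefix before the third
    underscore is used instead.
    """
    indices = [i for i, c in enumerate(img_name) if c == "_"]
    if len(indices) < 2:
        raise ValueError(f"Expected at least two underscores in {img_name!r}")
    if len(indices) > 2 and "_3d" not in img_name:
        return img_name[: indices[2]]
    return img_name[: indices[1]]


# Hypothetical filenames, chosen only to exercise each branch of the rule
names = ["abc_01_img.nii.gz", "abc_01_02_img.nii.gz", "abc_01_3d_img.nii.gz"]
ids = [case_id_from_filename(n) for n in names]
print(ids)  # ['abc_01', 'abc_01_02', 'abc_01']

# IDs produced by more than one file would collide in Cell 4
dupes = [cid for cid, count in Counter(ids).items() if count > 1]
print(dupes)  # ['abc_01']
```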

### **Cell 4: Organizing Files**
- Creates a target directory structure where each case has its own folder.
- Copies and renames the images as `imaging.nii.gz` and segmentations as `segmentation.nii.gz` for nnU-Net compatibility.
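
The copy step for a single case can be written as a small function. This sketch differs from Cell 4 in two deliberate ways. It creates any missing parent directories (`parents=True`), and it refuses to overwrite an existing case folder, which would otherwise happen silently if two images mapped to the same case ID. `copy_case` is a hypothetical name, and the demo uses empty placeholder files in a temporary directory.

```python
import shutil
import tempfile
from pathlib import Path


def copy_case(img_path: Path, seg_path: Path, target_dir: Path, case_id: str) -> Path:
    """Copy one image/segmentation pair into target_dir/case_id/ (hypothetical helper).

    Files are renamed to imaging.nii.gz and segmentation.nii.gz.
    An existing case folder is treated as an error, so duplicate IDs are caught.
    """
    case_dir = target_dir / case_id
    if case_dir.exists():
        raise FileExistsError(f"Case directory already exists: {case_dir}")
    case_dir.mkdir(parents=True)
    shutil.copy(img_path, case_dir / "imaging.nii.gz")
    shutil.copy(seg_path, case_dir / "segmentation.nii.gz")
    return case_dir


# Demo with empty placeholder files in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    src_img = root / "abc_01_img.nii.gz"
    src_seg = root / "abc_01_seg.nii.gz"
    src_img.touch()
    src_seg.touch()
    out_dir = copy_case(src_img, src_seg, root / "out", "abc_01")
    created = sorted(p.name for p in out_dir.iterdir())

print(created)  # ['imaging.nii.gz', 'segmentation.nii.gz']
```

Because existing case folders raise an error, the target directory has to be emptied before the copy is re-run. That is the trade-off for catching overwrites.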

## Required Libraries
- `os` and `glob` for file and directory management.
- `shutil` for file operations.
- `SimpleITK` for medical image processing.

## Output
- Case-specific directories containing properly named image and segmentation files.
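
A quick way to confirm the output is complete is to list the case folders that lack either expected file. This is a sketch under the assumption that the target directory contains only case folders and perhaps stray files. `find_incomplete_cases` is a hypothetical helper, and the demo builds a throwaway layout in a temporary directory.

```python
import tempfile
from pathlib import Path
from typing import List


def find_incomplete_cases(target_dir: str) -> List[str]:
    """Return names of case folders missing imaging.nii.gz or segmentation.nii.gz."""
    expected = {"imaging.nii.gz", "segmentation.nii.gz"}
    incomplete = []
    for case_dir in sorted(Path(target_dir).iterdir()):
        if not case_dir.is_dir():
            continue  # ignore stray files at the top level
        present = {p.name for p in case_dir.iterdir()}
        if not expected <= present:
            incomplete.append(case_dir.name)
    return incomplete


# Demo: case_a is complete, case_b has no segmentation
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "case_a").mkdir()
    (root / "case_a" / "imaging.nii.gz").touch()
    (root / "case_a" / "segmentation.nii.gz").touch()
    (root / "case_b").mkdir()
    (root / "case_b" / "imaging.nii.gz").touch()
    result = find_incomplete_cases(tmp)

print(result)  # ['case_b']
```

After Cell 4 has run, `find_incomplete_cases(targ_dir)` should return an empty list.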

## Note
- Set `img_dir_train`, `seg_dir_train`, and `targ_dir` to the correct input and target directories before running.
- Verify that the `SimpleITK` library is installed (e.g. `pip install SimpleITK`) before running this notebook.
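
The installation check can be done from inside the notebook without triggering an import error. This minimal sketch uses the standard-library `importlib.util.find_spec`. The helper name `is_installed` is hypothetical.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be found in the current environment."""
    return importlib.util.find_spec(module_name) is not None


if not is_installed("SimpleITK"):
    print("SimpleITK is missing; install it with: pip install SimpleITK")
```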


In [None]:
import os, glob, shutil

# Directories containing training images and segmentations
img_dir_train = 'nnUNet_raw_data_base/nnUNet_train_data_raw/img_in_nii_L2-L5/'
seg_dir_train = 'nnUNet_raw_data_base/nnUNet_train_data_raw/seg_in_nii_L2-L5/'

# Retrieve sorted lists of image and segmentation files;
# images and labels are later paired by their position in these lists
train_images = sorted(glob.glob(os.path.join(img_dir_train, '*.nii.gz')))
train_labels = sorted(glob.glob(os.path.join(seg_dir_train, '*.nii.gz')))

# Print the number of images and labels for verification
print(f"Number of Images: {len(train_images)}, Number of Labels: {len(train_labels)}")

# Stop early if nothing was found or the counts differ
if not train_images:
    raise FileNotFoundError(f"No .nii.gz files found in {img_dir_train}")
if len(train_images) != len(train_labels):
    raise ValueError("Image and segmentation counts differ")

# Print an example pair for verification
print(f"Example Image: {train_images[-1]}, Example Label: {train_labels[-1]}")


In [None]:
import math
import SimpleITK as sitk

# Images and segmentations are paired by position in the sorted lists
if len(train_images) != len(train_labels):
    raise ValueError("Image and segmentation counts differ")

# Validate that image and segmentation dimensions and spacing match
for idx, (img_path, seg_path) in enumerate(zip(train_images, train_labels)):
    # Load image and segmentation using SimpleITK
    img_sitk = sitk.ReadImage(img_path)
    img_array = sitk.GetArrayFromImage(img_sitk)

    seg_sitk = sitk.ReadImage(seg_path)
    seg_array = sitk.GetArrayFromImage(seg_sitk)

    # Print the shapes and spacings for debugging
    print(f"{idx}: Image Shape: {img_array.shape}, Segmentation Shape: {seg_array.shape}")
    print(f"Image Spacing: {img_sitk.GetSpacing()}, Segmentation Spacing: {seg_sitk.GetSpacing()}")

    # Raise an error if shapes differ or spacings differ beyond floating-point noise
    if img_array.shape != seg_array.shape:
        raise ValueError(f"Shape Mismatch: {img_path}")
    if not all(math.isclose(a, b, abs_tol=1e-5)
               for a, b in zip(img_sitk.GetSpacing(), seg_sitk.GetSpacing())):
        raise ValueError(f"Spacing Mismatch: {img_path}")


In [None]:
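# Derive a case identifier from each image filename (dry run, nothing is copied)
case_ids = []
for img_path, seg_path in zip(train_images, train_labels):
    img_name = os.path.basename(img_path)  # Extract filename
    indices = [i for i, c in enumerate(img_name) if c == '_']  # Locate underscores
    if len(indices) < 2:
        raise ValueError(f"Unexpected filename (fewer than two underscores): {img_name}")

    # Case ID is the prefix before the second underscore, or before the third
    # underscore when there are more than two and the name has no '_3d'
    case_id = img_name[:indices[1]]
    if len(indices) > 2 and '_3d' not in img_name:
        case_id = img_name[:indices[2]]
    case_ids.append(case_id)

    # Print case ID and corresponding file paths for debugging
    print(f"Case ID: {case_id}")
    print(f"Image Path: {img_path}, Segmentation Path: {seg_path}")

# Duplicate IDs would make Cell 4 overwrite one case with another
duplicates = sorted({cid for cid in case_ids if case_ids.count(cid) > 1})
if duplicates:
    raise ValueError(f"Duplicate case IDs: {duplicates}")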


In [None]:
# Target directory for organized training data
targ_dir = 'nnUNet_raw_data_base/nnUNet_train_data/Task515_muscle_raw/'

# Organize files into case-specific directories
for img_path, seg_path in zip(train_images, train_labels):
    img_name = os.path.basename(img_path)
    indices = [i for i, c in enumerate(img_name) if c == '_']
    if len(indices) < 2:
        raise ValueError(f"Unexpected filename (fewer than two underscores): {img_name}")

    # Extract case ID (same rule as Cell 3)
    case_id = img_name[:indices[1]]
    if len(indices) > 2 and '_3d' not in img_name:
        case_id = img_name[:indices[2]]

    # Create the case directory (and any missing parents) if it doesn't exist
    case_dir = os.path.join(targ_dir, case_id)
    os.makedirs(case_dir, exist_ok=True)

    # Define target paths for image and segmentation files
    targ_img_path = os.path.join(case_dir, 'imaging.nii.gz')
    targ_seg_path = os.path.join(case_dir, 'segmentation.nii.gz')

    # Copy and rename files into the target directory
    shutil.copy(img_path, targ_img_path)
    shutil.copy(seg_path, targ_seg_path)

    # Print paths for debugging
    print(f"Copied Image: {targ_img_path}, Copied Segmentation: {targ_seg_path}")
