# Data Splitting (Train/Val/Test)

In [1]:
import os
from brain_tumor_segmentation.config import PROCESSED_DATA_DIR

# Define base directory
# Root folder scanned by the splitting cells in this notebook: every immediate
# sub-directory of base_dir is treated as one subject, and the per-subject
# NIfTI paths written to the split JSON files are built underneath it.
base_dir = os.path.join(PROCESSED_DATA_DIR, "training_data")

[32m2025-02-16 18:38:27.128[0m | [1mINFO    [0m | [36mbrain_tumor_segmentation.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /home/cepa995/workspace/brain-tumor-segmentation[0m


### Strategy 1: Hold-out Split (70% Train / 15% Validation / 15% Test)

In [2]:
import os
import json
import random
from sklearn.model_selection import train_test_split

# Split proportions. Only val_ratio and test_ratio are read below; the training
# share is whatever remains after the hold-out portion is carved off.
train_ratio = 0.7  # 70% Train
val_ratio = 0.15   # 15% Validation
test_ratio = 0.15  # 15% Test

# Every sub-directory of base_dir is one subject; sorted so the input order to
# the splitter does not depend on filesystem listing order.
subjects = sorted(
    entry for entry in os.listdir(base_dir)
    if os.path.isdir(os.path.join(base_dir, entry))
)

# Keep only subjects that have both input modalities and the segmentation label
valid_subjects = []
for subject in subjects:
    subject_path = os.path.join(base_dir, subject)
    required_files = [
        os.path.join(subject_path, f"{subject}_{suffix}.nii.gz")
        for suffix in ("t1wce", "flair", "seg")
    ]

    if not all(os.path.exists(path) for path in required_files):
        print(f"⚠️ Skipping {subject}: Missing one or more image files")
        continue

    valid_subjects.append(subject)

# Refuse to split a dataset that is too small (threshold is an arbitrary sanity bound)
if len(valid_subjects) < 10:
    raise ValueError(f"Not enough valid subjects ({len(valid_subjects)}) to perform a meaningful split.")

# Seed Python's global RNG; the splits themselves are made reproducible by the
# explicit random_state passed to train_test_split.
random.seed(42)

# Stage 1: carve off the combined validation+test portion, leaving the training set.
holdout_ratio = val_ratio + test_ratio
train_subjects, temp_subjects = train_test_split(
    valid_subjects,
    test_size=holdout_ratio,
    random_state=42,
)

# Stage 2: divide the held-out portion between validation and test.
val_subjects, test_subjects = train_test_split(
    temp_subjects,
    test_size=test_ratio / holdout_ratio,
    random_state=42,
)

# Container for the formatted split entries (filled in further down)
splits = {"train": [], "validation": [], "test": []}

# Function to format subject data
def format_subject(subject, root=None):
    """Build the split-file entry for one subject.

    Args:
        subject: Subject identifier; it is both the name of the subject's
            directory and the prefix of each of its file names.
        root: Directory containing the subject directories. When omitted,
            the notebook-level ``base_dir`` is used, which preserves the
            original single-argument behaviour.

    Returns:
        A dict with the keys ``"subject"`` (the identifier), ``"images"``
        (paths of the T1wCE and FLAIR volumes, in that order) and ``"label"``
        (path of the segmentation volume). File existence is not checked here.
    """
    # Resolve the global lazily so callers that pass `root` do not need base_dir defined.
    subject_path = os.path.join(base_dir if root is None else root, subject)
    return {
        "subject": subject,
        "images": [
            os.path.join(subject_path, f"{subject}_t1wce.nii.gz"),
            os.path.join(subject_path, f"{subject}_flair.nii.gz")
        ],
        "label": os.path.join(subject_path, f"{subject}_seg.nii.gz")
    }

# Store formatted data: one entry (subject id, image paths, label path) per subject
splits["train"] = [format_subject(subj) for subj in train_subjects]
splits["validation"] = [format_subject(subj) for subj in val_subjects]
splits["test"] = [format_subject(subj) for subj in test_subjects]

# Save split info
output_file = "dataset-split-strategies/dataset_splits_strategy-1_v1.json"
# Create the target directory first: open(..., "w") does not create missing
# parent directories and would raise FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w") as f:
    json.dump(splits, f, indent=4)

print(f"Dataset split saved in {output_file}")
print(f"Total valid subjects used: {len(valid_subjects)} (out of {len(subjects)})")
print(f"Train: {len(train_subjects)}, Validation: {len(val_subjects)}, Test: {len(test_subjects)}")


⚠️ Skipping BraTS2021_UPENN-GBM_00157: Missing one or more image files
⚠️ Skipping BraTS2021_UPENN-GBM_00170: Missing one or more image files
⚠️ Skipping BraTS2021_UPENN-GBM_00186: Missing one or more image files
⚠️ Skipping BraTS2021_UPENN-GBM_00337: Missing one or more image files
⚠️ Skipping BraTS2021_UPENN-GBM_00414: Missing one or more image files
⚠️ Skipping BraTS2021_UPENN-GBM_00834: Missing one or more image files
⚠️ Skipping BraTS2021_UPENN-GBM_00839: Missing one or more image files
Dataset split saved in dataset-split-strategies/dataset_splits_strategy-1_v1.json
Total valid subjects used: 553 (out of 560)
Train: 387, Validation: 83, Test: 83


### Strategy 2: K-Fold Split (K=5)

In [3]:
from sklearn.model_selection import KFold

# Configuration
n_folds = 5  # Number of folds
# File-name suffixes of the input modalities, in the order they appear in each
# entry's "images" list.
modality_suffixes = ("t1w", "t1wce", "t2w", "flair")


def _kfold_subject_paths(subject):
    """Return ``(image_paths, label_path)`` for a subject located under base_dir."""
    subject_path = os.path.join(base_dir, subject)
    image_paths = [
        os.path.join(subject_path, f"{subject}_{suffix}.nii.gz")
        for suffix in modality_suffixes
    ]
    label_path = os.path.join(subject_path, f"{subject}_seg.nii.gz")
    return image_paths, label_path


def _kfold_entry(subject):
    """Build the JSON entry (subject id, image paths, label path) for one subject."""
    image_paths, label_path = _kfold_subject_paths(subject)
    return {
        "subject": subject,
        "images": image_paths,
        "label": label_path
    }


# List all subjects (directories)
subjects = sorted([d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))])

# Shuffle subjects for randomness (seeded, so the order is reproducible)
random.seed(42)
random.shuffle(subjects)

# Initialize K-Fold
kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)

# Check for missing files before splitting: keep only subjects that have all
# four modalities and the segmentation label.
valid_subjects = []
for subject in subjects:
    image_paths, label_path = _kfold_subject_paths(subject)

    if all(os.path.exists(path) for path in [*image_paths, label_path]):
        valid_subjects.append(subject)
    else:
        print(f"⚠️ Skipping {subject}: Missing one or more image files")

# Check if we have enough valid subjects for K-Fold
if len(valid_subjects) < n_folds:
    raise ValueError(f"Not enough valid subjects ({len(valid_subjects)}) to create {n_folds} folds.")

# Assign subjects to folds. KFold yields positional indices into valid_subjects.
cross_val_splits = {"folds": []}
for fold_idx, (train_idx, val_idx) in enumerate(kf.split(valid_subjects)):
    cross_val_splits["folds"].append({
        "fold": fold_idx,
        "train": [_kfold_entry(valid_subjects[idx]) for idx in train_idx],
        "val": [_kfold_entry(valid_subjects[idx]) for idx in val_idx]
    })

# Save splits to JSON
output_file = f"dataset-split-strategies/dataset_{n_folds}-folds_strategy_2_v1.json"
# Create the target directory first: open(..., "w") does not create missing
# parent directories and would raise FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w") as f:
    json.dump(cross_val_splits, f, indent=4)

print(f"K-Fold cross-validation dataset saved in {output_file}")
print(f"Total valid subjects used: {len(valid_subjects)} (out of {len(subjects)})")


⚠️ Skipping BraTS2021_UPENN-GBM_00414: Missing one or more image files
⚠️ Skipping BraTS2021_UPENN-GBM_00337: Missing one or more image files
⚠️ Skipping BraTS2021_UPENN-GBM_00157: Missing one or more image files
⚠️ Skipping BraTS2021_UPENN-GBM_00186: Missing one or more image files
⚠️ Skipping BraTS2021_UPENN-GBM_00834: Missing one or more image files
⚠️ Skipping BraTS2021_UPENN-GBM_00839: Missing one or more image files
⚠️ Skipping BraTS2021_UPENN-GBM_00170: Missing one or more image files
K-Fold cross-validation dataset saved in dataset-split-strategies/dataset_5-folds_strategy_2_v1.json
Total valid subjects used: 553 (out of 560)
