In [None]:
import os
import shutil
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold
from collections import defaultdict

In [None]:
# --- Paths ---------------------------------------------------------------
# Input images are read from DATASET_DIR/<class>/<participant>/*.jpg;
# split folders and the summary CSV are written under OUTPUT_BASE_DIR.
DATASET_DIR = "data/1-processed/"
OUTPUT_BASE_DIR = "data/2-splits/"
CSV_OUTPUT_PATH = os.path.join(OUTPUT_BASE_DIR, "sgkf05", "split_summary.csv")

# --- Cross-validation schemes --------------------------------------------
# Each entry: repeat an n_splits-fold StratifiedGroupKFold for n_seeds seeds.
CONFIGS = {"sgkf05": {"n_seeds": 5, "n_splits": 5}}

# Seeds consumed in order: repetition k (1-based) uses SEED_LIST[k - 1].
SEED_LIST = [0, 42, 77, 123, 2025]

# --- Labels --------------------------------------------------------------
# Class directory name -> numeric label.
class_map = dict(NRM=0, SJD=1)

In [None]:
# Collect participant data.
# Walks DATASET_DIR/<class>/<participant>/*.jpg and builds three parallel
# arrays: image path, numeric class label, and participant id. The participant
# id is later used as the CV group so a participant never spans two splits.
groups, labels, file_paths = [], [], []

print("Collecting dataset...")
for class_name, label in class_map.items():
    class_dir = os.path.join(DATASET_DIR, class_name)
    if not os.path.isdir(class_dir):
        # Make the gap visible instead of silently dropping a whole class.
        print(f"WARNING: class directory not found, skipping: {class_dir}")
        continue
    # Sorted listings make the sample order (and hence the index arrays the
    # splitters return) independent of filesystem enumeration order.
    for participant in sorted(os.listdir(class_dir)):
        participant_path = os.path.join(class_dir, participant)
        if not os.path.isdir(participant_path):
            continue
        for image_file in sorted(os.listdir(participant_path)):
            if image_file.endswith(".jpg"):
                file_paths.append(os.path.join(participant_path, image_file))
                labels.append(label)
                groups.append(participant)

file_paths = np.array(file_paths)
labels = np.array(labels)
groups = np.array(groups)

if file_paths.size == 0:
    raise FileNotFoundError(f"No .jpg images found under {DATASET_DIR!r}")

# A participant id present under two classes would be treated as ONE group
# with mixed labels by StratifiedGroupKFold; fail early instead.
labels_per_group = defaultdict(set)
for grp, lbl in zip(groups, labels):
    labels_per_group[grp].add(int(lbl))
mixed_groups = sorted(g for g, lbls in labels_per_group.items() if len(lbls) > 1)
if mixed_groups:
    raise ValueError(f"Participant ids found under more than one class: {mixed_groups}")

print(f"Collected {len(file_paths)} images from {len(labels_per_group)} participants.")

summary_records = []

In [None]:
# Perform splitting for each config.
# For every (config, seed), an outer StratifiedGroupKFold yields the test fold.
# The remaining train+val samples are split again by an inner
# StratifiedGroupKFold whose first fold becomes the validation set. Grouping
# by participant keeps each participant's images inside a single split.
# Images are copied to OUTPUT_BASE_DIR/<config>/seedNN/foldNN/<split>/<class>/.

INNER_N_SPLITS = 4      # val = first of 4 inner folds (roughly 1/4 of train+val)
INNER_SEED_OFFSET = 99  # inner splitter seed = outer seed + this offset

# Numeric label -> class directory name (e.g. 0 -> "NRM").
label_to_class = {lbl: name for name, lbl in class_map.items()}

# Re-initialised here so re-running this cell does not append duplicate rows.
summary_records = []


def _copy_split_files(fold_path, split_dict):
    """Copy the images of each split into ``<fold_path>/<split>/<class>/``.

    Args:
        fold_path: Output directory of one fold; expected to be freshly created.
        split_dict: Mapping of split name -> array of indices into ``file_paths``.

    Raises:
        FileExistsError: If two images would be copied to the same destination
            (same basename within the same split and class). Without this
            check, one would silently overwrite the other.
    """
    for split, indices in split_dict.items():
        split_dir = os.path.join(fold_path, split)
        os.makedirs(split_dir, exist_ok=True)
        for idx in indices:
            src_path = file_paths[idx]
            # Class folder comes from the sample's label, so it does not
            # depend on how participant folders happen to be named.
            dest_dir = os.path.join(split_dir, label_to_class[labels[idx]])
            os.makedirs(dest_dir, exist_ok=True)
            dest_path = os.path.join(dest_dir, os.path.basename(src_path))
            if os.path.exists(dest_path):
                raise FileExistsError(
                    f"Destination already exists (duplicate image basename?): "
                    f"{dest_path} <- {src_path}"
                )
            shutil.copy(src_path, dest_path)


for config_name, cfg in CONFIGS.items():
    n_seeds = cfg["n_seeds"]
    n_splits = cfg["n_splits"]

    for seed_idx in range(1, n_seeds + 1):
        # Use the listed seeds in order; fall back to the 1-based index when
        # more repetitions are requested than seeds are listed.
        seed = SEED_LIST[seed_idx - 1] if seed_idx - 1 < len(SEED_LIST) else seed_idx
        sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)

        for fold_idx, (train_val_idx, test_idx) in enumerate(sgkf.split(file_paths, labels, groups), start=1):
            inner_sgkf = StratifiedGroupKFold(
                n_splits=INNER_N_SPLITS, shuffle=True, random_state=seed + INNER_SEED_OFFSET
            )
            inner_train, inner_val = next(
                inner_sgkf.split(file_paths[train_val_idx], labels[train_val_idx], groups[train_val_idx])
            )
            # Inner indices are positions within train_val_idx; map them back
            # to indices into the full arrays.
            train_idx = train_val_idx[inner_train]
            val_idx = train_val_idx[inner_val]

            split_dict = {"train": train_idx, "val": val_idx, "test": test_idx}

            # Ensure no subject overlap
            participant_sets = {split: set(groups[indices]) for split, indices in split_dict.items()}
            assert len(participant_sets["train"] & participant_sets["val"]) == 0, "Overlap between train and val!"
            assert len(participant_sets["train"] & participant_sets["test"]) == 0, "Overlap between train and test!"
            assert len(participant_sets["val"] & participant_sets["test"]) == 0, "Overlap between val and test!"

            # Splits are disjoint (asserted above), so this is the number of
            # distinct participants in the fold.
            total_participants = sum(len(s) for s in participant_sets.values())

            print(f"[{config_name}] Seed {seed_idx:02d} - Fold {fold_idx:02d}:")
            for split in ["train", "val", "test"]:
                indices = split_dict[split]
                n_images = len(indices)
                n_participants = len(participant_sets[split])
                pct = (n_participants / total_participants) * 100

                split_labels = labels[indices]
                n_nrm = np.sum(split_labels == class_map["NRM"])
                n_sjd = np.sum(split_labels == class_map["SJD"])

                print(f"  {split.capitalize()}: {n_images} images, {n_participants} participants "
                      f"({pct:.2f}%) [NRM: {n_nrm}, SJD: {n_sjd}]")

                summary_records.append({
                    "config": config_name,
                    "seed": f"seed{seed_idx:02d}",
                    "fold": f"fold{fold_idx:02d}",
                    "split": split,
                    "n_images": n_images,
                    "n_participants": n_participants,
                    "pct_participants": round(pct, 2),
                    "n_nrm": int(n_nrm),
                    "n_sjd": int(n_sjd)
                })

            # Start from an empty fold directory. Leftovers from an earlier run
            # (e.g. with different seeds) would otherwise leak into this split.
            fold_path = os.path.join(OUTPUT_BASE_DIR, config_name, f"seed{seed_idx:02d}", f"fold{fold_idx:02d}")
            if os.path.isdir(fold_path):
                shutil.rmtree(fold_path)
            os.makedirs(fold_path)

            # Copy files to corresponding folders
            _copy_split_files(fold_path, split_dict)


In [None]:
# Save summary CSV: one row per (config, seed, fold, split) recorded by the
# splitting cell.
summary_df = pd.DataFrame(summary_records)
os.makedirs(os.path.dirname(CSV_OUTPUT_PATH), exist_ok=True)
summary_df.to_csv(CSV_OUTPUT_PATH, index=False)
print("\n✅ All dataset splits created successfully!")
print(f"📄 Summary CSV saved to: {CSV_OUTPUT_PATH}")