# Train/Val/Test Split Generation

This notebook creates train/validation and test splits from a metadata CSV.

- **Input**: A single metadata CSV containing columns such as `filepath`, `patient_id`, `pathology`, `region`, and `depth`.
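- **Output**: CSV files for a holdout test set (`test.csv`) and per-fold cross-validation train/val splits (`fold_<i>_train.csv`, `fold_<i>_val.csv`), written to the `data_splits/` directory.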

The utility functions `create_splits_from_single_csv` and `analyze_splits` are imported from `src.data.utils`.

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
from pathlib import Path
from src.data.utils import create_splits_from_single_csv, analyze_splits

In [None]:
# Path to the metadata CSV.
# NOTE(review): this looks like a placeholder — point it at the real file before running.
csv_path = 'path/to/metadata.csv'

# Directory where split CSVs will be saved.
# parents=True lets this also work when output_dir is changed to a nested
# path (e.g. 'outputs/data_splits'); exist_ok=True keeps re-runs harmless.
output_dir = Path('data_splits')
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Build the cross-validation folds and the holdout test set from the metadata CSV.
# `folds` is iterated below as (train_df, val_df) pairs; the fixed seed is passed
# so the helper can produce the same split on every run.
# NOTE(review): presumably holdout_frac=0.2 reserves 20% of the data for the
# test set — confirm against src.data.utils.
folds, test_df = create_splits_from_single_csv(
    csv_path,
    n_splits=5,
    holdout_frac=0.2,
    seed=42,
)

# Write the holdout test set to disk.
test_path = output_dir / 'test.csv'
test_df.to_csv(test_path, index=False)
print(f'Saved test set to {test_path} ({len(test_df)} samples)')

# Write one train CSV and one val CSV per fold, then report both.
for i, (train_df, val_df) in enumerate(folds):
    train_path = output_dir / f'fold_{i}_train.csv'
    train_df.to_csv(train_path, index=False)

    val_path = output_dir / f'fold_{i}_val.csv'
    val_df.to_csv(val_path, index=False)

    print(f'Fold {i}: Train={len(train_df)} -> {train_path}')
    print(f'Fold {i}: Val={len(val_df)} -> {val_path}')

# Summarise the splits; the plotting cell reads this result.
analysis = analyze_splits(folds, test_df)

In [None]:
# Plot class distributions across splits
# Label names for which one distribution chart is drawn each.
label_cols = [
    'pathology',
    'region',
    'depth',
]

def plot_distribution(label):
    """Draw a grouped bar chart of per-class counts for ``label`` in every split.

    Reads the module-level ``analysis`` dict: the test-set entry under
    ``analysis['test_analysis']`` and the ``(train, val)`` pairs under
    ``analysis['fold_analyses']``. For each of them, the class counts are
    looked up at ``['class_distributions'][label]['counts']``; a missing
    label or missing ``'counts'`` entry contributes no bars.

    Args:
        label: Name of the label whose class distribution is plotted.

    Returns:
        None. Shows the figure, or prints a notice and returns early when
        no counts were found for ``label`` in any split.
    """

    def rows_for(split_name, split_analysis):
        # One record per class for a single split, in the order the counts
        # mapping yields them.
        per_class = split_analysis['class_distributions'].get(label, {}).get('counts', {})
        return [
            {'split': split_name, 'class': cls, 'count': cnt}
            for cls, cnt in per_class.items()
        ]

    # Test split first, then train/val of each fold — this order is the
    # order in which the splits first appear in the plotted data.
    records = rows_for('test', analysis['test_analysis'])
    for i, (train_a, val_a) in enumerate(analysis['fold_analyses']):
        records.extend(rows_for(f'fold{i}_train', train_a))
        records.extend(rows_for(f'fold{i}_val', val_a))

    if not records:
        print(f'No data for {label}')
        return

    dist_df = pd.DataFrame(records)

    plt.figure(figsize=(8, 4))
    sns.barplot(x='class', y='count', hue='split', data=dist_df)
    plt.title(f'{label} distribution across splits')
    plt.ylabel('Count')
    plt.xticks(rotation=45)
    plt.show()

# Draw one distribution chart per configured label column.
for lbl in label_cols:
    plot_distribution(label=lbl)