# NIH Data Split + Sanity Checks

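This notebook creates patient-wise train/validation/test splits from the NIH metadata (no patient appears in more than one split) and runs basic sanity checks: per-split normal rates, patient leakage, and the most frequent finding labels.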

**Inputs to edit:**
- `CSV_PATH`: path to the NIH `Data_Entry_2017.csv` metadata file
- `IMAGE_ROOT`: folder containing the NIH images
- `OUTPUT_DIR`: directory where the split files are saved (created if missing)


In [None]:
import os
import pandas as pd

from src.data import SplitConfig, create_nih_splits, load_nih_metadata, save_splits

# Inputs. Edit the defaults below, or override them without touching the
# notebook by setting NIH_CSV_PATH, NIH_IMAGE_ROOT and NIH_SPLITS_DIR.
CSV_PATH = os.environ.get("NIH_CSV_PATH", "/path/to/Data_Entry_2017.csv")
IMAGE_ROOT = os.environ.get("NIH_IMAGE_ROOT", "/path/to/images")
OUTPUT_DIR = os.environ.get("NIH_SPLITS_DIR", "splits")

# Create the output directory up front; exist_ok keeps this cell idempotent on re-run.
os.makedirs(OUTPUT_DIR, exist_ok=True)
print("Ready.")


In [None]:
# Load the metadata, then build patient-wise splits with a fixed seed so
# re-running the notebook reproduces the same partition.
df = load_nih_metadata(CSV_PATH, image_root=IMAGE_ROOT)
config = SplitConfig(
    val_fraction=0.1,
    test_fraction=0.2,
    seed=42,
    normal_only_train=True,
)
train_df, val_df, test_df = create_nih_splits(df, config)

# Persist the splits before reporting on them.
save_splits(train_df, val_df, test_df, OUTPUT_DIR)

# Report the size (len) of each split frame.
for split_label, split_frame in (("Train", train_df), ("Val", val_df), ("Test", test_df)):
    print(f"{split_label}:", len(split_frame))


In [None]:
# Sanity checks
# Mean of `is_normal` per split (the normal fraction, assuming the column is
# boolean / 0-1 — TODO confirm against load_nih_metadata).
split_frames = {"train": train_df, "val": val_df, "test": test_df}
for split_name, split_frame in split_frames.items():
    print(f"Normal rate ({split_name}):", split_frame["is_normal"].mean())

# Check for patient leakage. The splits are meant to be patient-wise, so no
# patient_id may appear in more than one split. Assert rather than only print,
# so Restart & Run All stops on a leaky split instead of silently continuing.
train_patients = set(train_df["patient_id"].unique())
val_patients = set(val_df["patient_id"].unique())
test_patients = set(test_df["patient_id"].unique())

patient_overlaps = {
    "Train/Val": train_patients & val_patients,
    "Train/Test": train_patients & test_patients,
    "Val/Test": val_patients & test_patients,
}
for pair_name, shared_patients in patient_overlaps.items():
    print(f"{pair_name} overlap:", len(shared_patients))

leaking_pairs = {pair: len(ids) for pair, ids in patient_overlaps.items() if ids}
assert not leaking_pairs, f"Patient leakage across splits (shared patient counts): {leaking_pairs}"


In [None]:
# Label distribution snapshot: the most frequent individual findings in the
# full metadata frame.
# NOTE(review): the raw NIH CSV stores findings as a "|"-delimited string
# (e.g. "Effusion|Infiltration"). If load_nih_metadata keeps that string form,
# `explode` would be a no-op and we would count label *combinations*. Splitting
# strings here counts per finding either way; list values pass through unchanged.
# Verify what load_nih_metadata actually returns.
TOP_N_LABELS = 10
finding_lists = df["finding_labels"].map(
    lambda labels: labels.split("|") if isinstance(labels, str) else labels
)
label_counts = finding_lists.explode().value_counts().head(TOP_N_LABELS)
label_counts
