# NIH Chest X-Ray Dataset Preparation (Top 10 Diseases)

## Overview
This notebook prepares the NIH Chest X-Ray dataset for multi-label classification over the 10 most frequent labels (nine disease findings plus `No Finding`). The five least frequent findings are excluded so that every remaining class has enough training samples.

## Key Features
- **Patient-Aware Splitting**: Ensures no patient appears in multiple splits (train/val/test), preventing data leakage critical for medical AI validation
- **Multi-Label Support**: Handles images with multiple disease classifications
- **Flat Folder Structure**: Organizes images in a format compatible with multi-label learning
- **Comprehensive Validation**: Verifies data integrity and distribution balance
- **Production-Ready**: Implements best practices for medical imaging ML pipelines
- **Top 10 Focus**: Uses only the 10 most frequent labels; the five excluded findings range from 227 (Hernia) to 2,516 (Emphysema) labeled images

## Dataset Information
- **Source**: NIH Clinical Center Chest X-Ray Dataset
- **Total Images**: 112,120 frontal-view X-rays
- **Classes**: Top 10 most common labels, including `No Finding` (excludes Emphysema, Edema, Fibrosis, Pneumonia, Hernia)
- **Sample Size**: 40,000 rows drawn at random; `resample` samples with replacement by default, so these rows cover 33,599 unique images
- **Splits**: 70% train, 15% validation, 15% test

---

In [14]:
# ------------------------------------------------------------------------------
# LOAD DATA
# ------------------------------------------------------------------------------

import pandas as pd
import os

# Load dataset metadata containing image labels and patient information
# Updated path: Data moved to C:\xray_data for faster access (outside OneDrive)
csv_path = r'C:\xray_data\archive (1)\Data_Entry_2017.csv'
df = pd.read_csv(csv_path)

print(f"Loaded {len(df):,} chest X-ray image records")
print(f"Data location: C:\\xray_data (local SSD)")

Loaded 112,120 chest X-ray image records
Data location: C:\xray_data (local SSD)


## 1. Data Loading
Load the complete dataset metadata from CSV files containing image labels and bounding box annotations.

In [15]:
# ------------------------------------------------------------------------------
# PREPARE DISEASE CATEGORIES - TOP 10 ONLY
# ------------------------------------------------------------------------------

from collections import Counter

# Parse disease labels (pipe-separated for multi-label cases)
all_diseases = []
for labels in df['Finding Labels']:
    diseases = labels.split('|')
    all_diseases.extend(diseases)

disease_counts = Counter(all_diseases)

# Get top 10 most common diseases
top_10_diseases = [disease for disease, count in disease_counts.most_common(10)]

# Define class names (top 10 only, sorted for consistency)
CLASS_NAMES = sorted(top_10_diseases)

print(f"\nUsing TOP 10 Most Common Diseases:")
print(f"Disease Categories: {len(CLASS_NAMES)}")
for i, class_name in enumerate(CLASS_NAMES):
    count = disease_counts[class_name]
    print(f"  [{i:2}] {class_name:25} - {count:6,} images")

print(f"\nExcluded diseases (rare cases):")
excluded = [d for d in disease_counts.keys() if d not in CLASS_NAMES]
for disease in sorted(excluded, key=lambda x: disease_counts[x], reverse=True):
    count = disease_counts[disease]
    print(f"  {disease:25} - {count:6,} images")


Using TOP 10 Most Common Diseases:
Disease Categories: 10
  [ 0] Atelectasis               - 11,559 images
  [ 1] Cardiomegaly              -  2,776 images
  [ 2] Consolidation             -  4,667 images
  [ 3] Effusion                  - 13,317 images
  [ 4] Infiltration              - 19,894 images
  [ 5] Mass                      -  5,782 images
  [ 6] No Finding                - 60,361 images
  [ 7] Nodule                    -  6,331 images
  [ 8] Pleural_Thickening        -  3,385 images
  [ 9] Pneumothorax              -  5,302 images

Excluded diseases (rare cases):
  Emphysema                 -  2,516 images
  Edema                     -  2,303 images
  Fibrosis                  -  1,686 images
  Pneumonia                 -  1,431 images
  Hernia                    -    227 images


In [16]:
# ------------------------------------------------------------------------------
# PREPARE MULTI-LABEL DATA
# ------------------------------------------------------------------------------

# Preserve all data (no filtering)
multilabel_df = df.copy()

# Analyze multi-label distribution
has_multiple = df['Finding Labels'].str.contains('|', regex=False)
single_count = (~has_multiple).sum()
multi_count = has_multiple.sum()

print(f"\nDataset Composition:")
print(f"Total images: {len(df):,}")
print(f"  Single disease: {single_count:,} ({single_count/len(df)*100:.1f}%)")
print(f"  Multiple diseases: {multi_count:,} ({multi_count/len(df)*100:.1f}%)")


Dataset Composition:
Total images: 112,120
  Single disease: 91,324 (81.5%)
  Multiple diseases: 20,796 (18.5%)


## 2. Dataset Statistics & Multi-Label Analysis
Analyze the complete dataset to understand disease distribution and multi-label characteristics. This informs our approach to handling cases where patients have multiple conditions.
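
The per-image label count behind this analysis can be sketched on a toy frame. The three rows below are made-up stand-ins for the `Finding Labels` column, and `labels_per_image` and `label_counts` are hypothetical names used only for illustration.

```python
import pandas as pd

# Toy stand-in for the 'Finding Labels' column (illustrative values only)
toy = pd.DataFrame({
    "Finding Labels": ["No Finding", "Effusion|Mass", "Atelectasis|Effusion|Nodule"]
})

# Split each pipe-separated string into a list, then count labels per image
labels_per_image = toy["Finding Labels"].str.split("|").str.len()

# One row per (image, label) pair gives per-label frequencies
label_counts = toy["Finding Labels"].str.split("|").explode().value_counts()

print(labels_per_image.tolist())
print(label_counts["Effusion"])
```

Counting labels per image, rather than only testing for a `|`, also exposes how many findings co-occur on a single X-ray.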

In [17]:
# ------------------------------------------------------------------------------
# LOAD SEGMENTATION DATA
# ------------------------------------------------------------------------------

bbox_path = r'C:\xray_data\archive (1)\BBox_List_2017.csv'
bbox_df = pd.read_csv(bbox_path)

# Flag images with bounding box annotations
multilabel_df['Has_BBox'] = multilabel_df['Image Index'].isin(bbox_df['Image Index'])

print(f"\nSegmentation Data:")
print(f"  Total annotations: {len(bbox_df):,}")
print(f"  Images with bounding boxes: {multilabel_df['Has_BBox'].sum():,}")


Segmentation Data:
  Total annotations: 984
  Images with bounding boxes: 880


In [18]:
# ------------------------------------------------------------------------------
# SAMPLE DATASET
# ------------------------------------------------------------------------------

from sklearn.utils import resample

# Configuration
MAX_TOTAL_IMAGES = 40000  # Results in ~28k train, ~6k val, ~6k test

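# Random sampling with reproducibility.
# Note: sklearn's resample defaults to replace=True (bootstrap sampling), so the
# sample can contain repeated rows; pass replace=False to draw unique images only.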
if len(multilabel_df) > MAX_TOTAL_IMAGES:
    sampled_df = resample(multilabel_df, n_samples=MAX_TOTAL_IMAGES, random_state=42)
else:
    sampled_df = multilabel_df

# Shuffle for randomness
sampled_df = sampled_df.sample(frac=1, random_state=42).reset_index(drop=True)

print(f"\nSampling Results:")
print(f"  Total: {len(sampled_df):,} images")
print(f"  Single disease: {(~sampled_df['Finding Labels'].str.contains('|', regex=False)).sum():,}")
print(f"  Multiple diseases: {sampled_df['Finding Labels'].str.contains('|', regex=False).sum():,}")


Sampling Results:
  Total: 40,000 images
  Single disease: 32,651
  Multiple diseases: 7,349


## 3. Dataset Sampling
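Draw a random subset of 40,000 rows to balance computational efficiency with statistical validity. Random sampling approximately preserves the original label prevalence and multi-label ratio (18.4% multi-label rows in the sample versus 18.5% in the full dataset). Because `resample` samples with replacement by default, some images appear more than once in the sample.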
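
The difference between sampling with and without replacement can be checked on a toy frame. This is a minimal sketch that uses pandas' `DataFrame.sample` rather than the notebook's `resample` call; the 100 file names are invented for the example.

```python
import pandas as pd

# Toy metadata: 100 hypothetical image names
toy = pd.DataFrame({"Image Index": [f"{i:08d}_000.png" for i in range(100)]})

# With replacement: the same row may be drawn more than once
with_repl = toy.sample(n=50, replace=True, random_state=42)

# Without replacement (the pandas default): every drawn row is distinct
without_repl = toy.sample(n=50, replace=False, random_state=42)

n_unique_with = with_repl["Image Index"].nunique()
n_unique_without = without_repl["Image Index"].nunique()

print(n_unique_with, n_unique_without)
```

Comparing `nunique()` against the requested sample size is a quick way to see whether a sample contains repeated images.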

In [19]:
# -------------------------------------------------------------------------------
# PATIENT-AWARE TRAIN/VALIDATION/TEST SPLITTING (NO DATA LEAKAGE)
# -------------------------------------------------------------------------------

from sklearn.model_selection import train_test_split

# Extract patient IDs (assumes first part of 'Image Index' is patient ID)
sampled_df['PatientID'] = sampled_df['Image Index'].str.split('_').str[0]

# Get unique patient IDs
unique_patients = sampled_df['PatientID'].unique()

# Split patient IDs into train/val/test
train_patients, temp_patients = train_test_split(unique_patients, test_size=0.3, random_state=42)
val_patients, test_patients = train_test_split(temp_patients, test_size=0.5, random_state=42)

# Assign images to splits based on patient ID
train_df = sampled_df[sampled_df['PatientID'].isin(train_patients)].reset_index(drop=True)
val_df = sampled_df[sampled_df['PatientID'].isin(val_patients)].reset_index(drop=True)
test_df = sampled_df[sampled_df['PatientID'].isin(test_patients)].reset_index(drop=True)

print(f"Train: {len(train_df):,} images ({train_df['Has_BBox'].sum():,} with bboxes)")
print(f"Val:   {len(val_df):,} images ({val_df['Has_BBox'].sum():,} with bboxes)")
print(f"Test:  {len(test_df):,} images ({test_df['Has_BBox'].sum():,} with bboxes)")

Train: 28,153 images (221 with bboxes)
Val:   5,887 images (61 with bboxes)
Test:  5,960 images (35 with bboxes)


## 4. Patient-Aware Data Splitting
**Critical for Medical AI**: Split data by patient ID rather than by individual images. This prevents data leakage where multiple images from the same patient could appear in both training and test sets, which would artificially inflate model performance and compromise validation integrity.

This approach ensures:
- No patient overlap between train/validation/test sets
- Realistic evaluation of model generalization to new patients
- Compliance with best practices for medical ML research
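
A minimal sketch of the same idea on toy data follows. It shuffles patient IDs with NumPy instead of calling `train_test_split`, and the 20 patients with three images each are hypothetical. The point is the disjointness check at the end.

```python
import numpy as np
import pandas as pd

# Toy frame: 20 hypothetical patients with 3 images each
toy = pd.DataFrame({
    "Image Index": [f"{p:08d}_{k:03d}.png" for p in range(20) for k in range(3)]
})
toy["PatientID"] = toy["Image Index"].str.split("_").str[0]

# Shuffle the unique patient IDs, then cut at 70% and 85%
rng = np.random.default_rng(42)
patients = rng.permutation(toy["PatientID"].unique())
n_train = len(patients) * 70 // 100
n_val = len(patients) * 85 // 100
train_ids = set(patients[:n_train])
val_ids = set(patients[n_train:n_val])
test_ids = set(patients[n_val:])

# Leakage check: no patient may appear in more than one split
assert train_ids.isdisjoint(val_ids)
assert train_ids.isdisjoint(test_ids)
assert val_ids.isdisjoint(test_ids)

# Images follow their patient into the split
train_toy = toy[toy["PatientID"].isin(train_ids)]
print(len(train_ids), len(val_ids), len(test_ids), len(train_toy))
```

Because the 70/15/15 ratio applies to patients, the image-level ratio only approximates it when patients have different numbers of images.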

In [20]:
# ------------------------------------------------------------------------------
# CREATE BINARY DISEASE COLUMNS
# ------------------------------------------------------------------------------

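# Generate binary features for all splits
# (loop variable is named split_df so the full metadata frame `df` is not overwritten)
for split_df in [train_df, val_df, test_df]:
    for disease in CLASS_NAMES:
        split_df[disease] = split_df['Finding Labels'].str.contains(disease, regex=False).astype(int)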

print(f"\nFeature Engineering Complete:")
print(f"  Created {len(CLASS_NAMES)} binary disease indicators per image")


Feature Engineering Complete:
  Created 10 binary disease indicators per image



## 5. Feature Engineering
Create binary indicator columns for each class to enable multi-label classification. Each image gets a 10-dimensional binary vector indicating the presence or absence of each class.
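
As an alternative to substring matching, pandas can build the indicator columns by splitting on the separator, which matches whole label names. The sketch below assumes a hypothetical three-class subset (`CLASSES`) and toy rows.

```python
import pandas as pd

CLASSES = ["Effusion", "Mass", "No Finding"]  # hypothetical subset for illustration
toy = pd.DataFrame({"Finding Labels": ["No Finding", "Effusion|Mass", "Mass"]})

# str.get_dummies splits on '|' and creates one 0/1 column per distinct label
indicators = toy["Finding Labels"].str.get_dummies(sep="|")

# Fix the column order; classes absent from the data become all-zero columns,
# and labels outside CLASSES are dropped
indicators = indicators.reindex(columns=CLASSES, fill_value=0)

label_matrix = indicators.to_numpy()
print(label_matrix.tolist())
```

Whole-name matching matters when one label is a substring of another; for the ten class names used here the two approaches give the same columns.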

In [21]:
# ------------------------------------------------------------------------------
# CREATE FOLDER STRUCTURE
# ------------------------------------------------------------------------------

import shutil
from pathlib import Path

# Use local C: drive location (outside OneDrive for faster access)
base_dir = Path(r'C:\xray_data\data')

# Create top-level split directories
# data/train/, data/val/, data/test/ (no disease-specific subfolders)
for split in ['train', 'val', 'test']:
    split_dir = base_dir / split
    split_dir.mkdir(parents=True, exist_ok=True)

print("\nDirectory Structure:")
print(f"  ✓ {base_dir / 'train'}  (all training images)")
print(f"  ✓ {base_dir / 'val'}    (all validation images)")
print(f"  ✓ {base_dir / 'test'}   (all test images)")
print(f"\nNote: Labels stored in CSV metadata (multi-label compatible)")
print(f"Location: C:\\xray_data (local SSD, outside OneDrive)")


Directory Structure:
  ✓ C:\xray_data\data\train  (all training images)
  ✓ C:\xray_data\data\val    (all validation images)
  ✓ C:\xray_data\data\test   (all test images)

Note: Labels stored in CSV metadata (multi-label compatible)
Location: C:\xray_data (local SSD, outside OneDrive)


## 6. File Organization
Organize images into a flat folder structure compatible with multi-label classification:
- `data/train/` - Training images
- `data/val/` - Validation images  
- `data/test/` - Test images

Labels are stored in CSV metadata files rather than folder structure, as images can have multiple disease labels.
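
With a flat layout, a training loop pairs each file with its label vector by reading the metadata row. The helper `sample_for`, the class subset, and the toy rows below are hypothetical; this is a sketch of the lookup, not the notebook's loader.

```python
from pathlib import Path

import numpy as np
import pandas as pd

CLASSES = ["Effusion", "Mass", "No Finding"]  # hypothetical subset for illustration

# Toy metadata in the shape of train_metadata.csv (file name plus binary columns)
meta = pd.DataFrame({
    "Image Index": ["00000001_000.png", "00000002_000.png"],
    "Effusion": [1, 0],
    "Mass": [1, 0],
    "No Finding": [0, 1],
})

def sample_for(row, split_dir):
    """Return (image path, multi-hot label vector) for one metadata row."""
    path = Path(split_dir) / row["Image Index"]
    labels = row[CLASSES].to_numpy(dtype=np.float32)
    return path, labels

path, labels = sample_for(meta.iloc[0], "data/train")
print(path.name, labels.tolist())
```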

In [22]:
# ------------------------------------------------------------------------------
# ORGANIZE FILES (PARALLEL PROCESSING)
# ------------------------------------------------------------------------------

from tqdm import tqdm
import concurrent.futures
from functools import lru_cache
import glob

# Check if data is already organized
if (base_dir / 'train').exists() and len(list((base_dir / 'train').rglob('*.png'))) > 0:
    print(f"Data already organized: {len(list((base_dir / 'train').rglob('*.png')))} images found")
else:
    # Cache the find_image results to avoid repeated glob searches
    @lru_cache(maxsize=None)
    def find_image_cached(image_name):
        """Find which folder contains a specific image (cached)"""
        # Updated path for local C: drive
        pattern = f"C:/xray_data/archive (1)/images_*/images/{image_name}"
        matches = glob.glob(pattern)
        if matches:
            return matches[0]
        return None
    
    def copy_single_image(args):
        """Copy a single image (for parallel processing)"""
        image_name, split_name = args
        
        # Find source image
        source_path = find_image_cached(image_name)
        if source_path is None:
            return f"Warning: Could not find {image_name}"
        
        # FLAT STRUCTURE: All images directly in split folder
        dest_path = base_dir / split_name / image_name
        try:
            shutil.copy2(source_path, dest_path)
            return None  # Success
        except Exception as e:
            return f"Error copying {image_name}: {e}"
    
    def copy_images_to_folders_parallel(dataframe, split_name, max_workers=8):
        """Copy images using parallel processing for speed"""
        # FLAT STRUCTURE: No class subfolders needed
        args_list = [
            (row['Image Index'], split_name)
            for _, row in dataframe.iterrows()
        ]
        
        # Use ThreadPoolExecutor for I/O-bound file copying
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Process files in parallel with progress bar
            results = list(tqdm(
                executor.map(copy_single_image, args_list),
                total=len(args_list),
                desc=f"{split_name}"
            ))
            # Collect errors
            errors = [r for r in results if r is not None]
        
        return len(dataframe) - len(errors), len(errors)
    
    # Copy all splits with parallel processing
    import time
    start_time = time.time()
    
    train_success, train_errors = copy_images_to_folders_parallel(train_df, 'train', max_workers=8)
    val_success, val_errors = copy_images_to_folders_parallel(val_df, 'val', max_workers=8)
    test_success, test_errors = copy_images_to_folders_parallel(test_df, 'test', max_workers=8)
    
    elapsed = time.time() - start_time
    total_copied = train_success + val_success + test_success
    
    print(f"\nOrganized {total_copied:,} images in {elapsed:.1f}s ({total_copied/elapsed:.0f} img/s)")
    if train_errors + val_errors + test_errors > 0:
        print(f"Errors: {train_errors + val_errors + test_errors}")

Data already organized: 23593 images found


## 7. Parallel Image Processing
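Copy the images referenced by the three splits into their folders using a thread pool, since file copying is I/O-bound. Repeated rows in the sample map to the same file, so the number of files on disk can be lower than the 40,000 sampled rows. The implementation uses `concurrent.futures.ThreadPoolExecutor` and skips the copy when `data/train/` already contains PNG files.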
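
One possible variation, sketched below, walks the source tree once to build a name-to-path index instead of running a glob per image. `build_index` and `copy_all` are hypothetical helpers, and the demo runs on throwaway files in a temporary directory rather than on the real archive.

```python
import shutil
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

def build_index(root):
    """Map file name -> full path with a single directory walk."""
    return {p.name: p for p in Path(root).rglob("*.png")}

def copy_all(names, index, dest_dir, max_workers=8):
    """Copy the named files into dest_dir; return the names that were not found."""
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)

    def copy_one(name):
        src = index.get(name)
        if src is None:
            return name  # report the missing file instead of raising
        shutil.copy2(src, dest_dir / name)
        return None

    # Threads suit this workload because copying is I/O-bound
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return [r for r in pool.map(copy_one, names) if r is not None]

# Demo on throwaway files
with tempfile.TemporaryDirectory() as tmp:
    src_dir = Path(tmp) / "images_001" / "images"
    src_dir.mkdir(parents=True)
    for name in ["a.png", "b.png"]:
        (src_dir / name).write_bytes(b"fake")
    index = build_index(tmp)
    missing = copy_all(["a.png", "b.png", "c.png"], index, Path(tmp) / "out")
    copied = sorted(p.name for p in (Path(tmp) / "out").glob("*.png"))

print(missing, copied)
```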

In [23]:
# ------------------------------------------------------------------------------
# DATA VALIDATION
# ------------------------------------------------------------------------------
print("\n" + "=" * 80)
print("VALIDATING DATASET ORGANIZATION")
print("=" * 80)

splits_to_validate = ['train', 'val', 'test']
validation_results = {}

for split_name in splits_to_validate:
    print(f"\n{split_name.upper()} Split:")
    
    # Get expected images from CSV
    if split_name == 'train':
        expected_df = train_df
    elif split_name == 'val':
        expected_df = val_df
    else:
        expected_df = test_df
    
    expected_images = set(expected_df['Image Index'].tolist())
    expected_count = len(expected_images)
    
    # Get actual images from folder
    split_folder = base_dir / split_name
    if not split_folder.exists():
        print(f"ERROR: Folder does not exist: {split_folder}")
        validation_results[split_name] = {'status': 'FAILED', 'reason': 'folder_missing'}
        continue
    
    actual_images = {f.name for f in split_folder.glob('*.png')}
    actual_count = len(actual_images)
    
    print(f"  Expected: {expected_count:,} images")
    print(f"  Actual:   {actual_count:,} images")
    
    if expected_count != actual_count:
        print(f"  WARNING: Count mismatch")
        missing = expected_images - actual_images
        extra = actual_images - expected_images
        if missing:
            print(f"  Missing {len(missing)} images")
        if extra:
            print(f"  Extra {len(extra)} images")
        validation_results[split_name] = {'status': 'WARNING', 'reason': 'count_mismatch'}
    else:
        missing = expected_images - actual_images
        if missing:
            print(f"  ERROR: {len(missing)} expected images missing")
            validation_results[split_name] = {'status': 'FAILED', 'reason': 'missing_images'}
        else:
            extra = actual_images - expected_images
            if extra:
                print(f"  WARNING: {len(extra)} unexpected extra images")
                validation_results[split_name] = {'status': 'WARNING', 'reason': 'extra_images'}
            else:
                print(f"  Status: PASSED")
                validation_results[split_name] = {'status': 'PASSED'}

print("\n" + "=" * 80)
all_passed = all(v['status'] == 'PASSED' for v in validation_results.values())
if all_passed:
    print("ALL VALIDATIONS PASSED - Dataset ready for training")
else:
    print("VALIDATION WARNINGS/ERRORS - Review above for details")
print("=" * 80)


VALIDATING DATASET ORGANIZATION

TRAIN Split:
  Expected: 23,593 images
  Actual:   23,593 images
  Status: PASSED

VAL Split:
  Expected: 4,969 images
  Actual:   4,969 images
  Status: PASSED

TEST Split:
  Expected: 5,037 images
  Actual:   5,037 images
  Status: PASSED

ALL VALIDATIONS PASSED - Dataset ready for training


## 8. Data Validation
Comprehensive validation to ensure data integrity:
- Verify all expected images are present in organized folders
- Check for missing or extra files
- Validate image counts match metadata
- Confirm successful file organization
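
The core of these checks is a set comparison between the file names in the metadata and the file names on disk. The sketch below uses a hypothetical helper, `compare_sets`, with a simplified status rule: any missing file fails, and extra files only warn.

```python
def compare_sets(expected, actual):
    """Compare expected vs. actual file names and summarize the difference."""
    expected, actual = set(expected), set(actual)
    missing = sorted(expected - actual)  # listed in metadata, absent on disk
    extra = sorted(actual - expected)    # present on disk, not in metadata
    if missing:
        status = "FAILED"
    elif extra:
        status = "WARNING"
    else:
        status = "PASSED"
    return {"status": status, "missing": missing, "extra": extra}

ok = compare_sets(["a.png", "b.png"], ["b.png", "a.png"])
lost = compare_sets(["a.png", "b.png"], ["a.png"])
stray = compare_sets(["a.png"], ["a.png", "z.png"])
print(ok["status"], lost["status"], stray["status"])
```

Comparing sets rather than counts catches the case where a missing file and an extra file cancel out in the totals.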

In [24]:
# ------------------------------------------------------------------------------
# CLASS DISTRIBUTION ANALYSIS
# ------------------------------------------------------------------------------
print("\n" + "=" * 80)
print("CLASS DISTRIBUTION ANALYSIS")
print("=" * 80)

# Analyze disease distribution across splits
disease_stats = []

for disease in CLASS_NAMES:
    train_count = train_df[disease].sum()
    val_count = val_df[disease].sum()
    test_count = test_df[disease].sum()
    total = train_count + val_count + test_count
    
    disease_stats.append({
        'Disease': disease,
        'Train': train_count,
        'Val': val_count,
        'Test': test_count,
        'Total': total,
        'Train%': 100 * train_count / total if total > 0 else 0,
        'Val%': 100 * val_count / total if total > 0 else 0,
        'Test%': 100 * test_count / total if total > 0 else 0
    })

# Convert to DataFrame for easy viewing
dist_df = pd.DataFrame(disease_stats)
dist_df = dist_df.sort_values('Total', ascending=False)

print("\nDisease Distribution Across Splits:")
print("=" * 120)
print(f"{'Disease':<25} {'Train':>8} {'Val':>8} {'Test':>8} {'Total':>8} {'Train%':>8} {'Val%':>8} {'Test%':>8}")
print("-" * 120)

for _, row in dist_df.iterrows():
    print(f"{row['Disease']:<25} {row['Train']:>8.0f} {row['Val']:>8.0f} {row['Test']:>8.0f} "
          f"{row['Total']:>8.0f} {row['Train%']:>7.1f}% {row['Val%']:>7.1f}% {row['Test%']:>7.1f}%")

# Check for balanced splits
avg_train_pct = dist_df['Train%'].mean()
avg_val_pct = dist_df['Val%'].mean()
avg_test_pct = dist_df['Test%'].mean()

print(f"\nAverage split percentages: Train={avg_train_pct:.1f}%, Val={avg_val_pct:.1f}%, Test={avg_test_pct:.1f}%")
print(f"Target: Train=70%, Val=15%, Test=15%")

# Check if any disease has severely unbalanced split
issues = []
for _, row in dist_df.iterrows():
    if abs(row['Train%'] - 70) > 10 or abs(row['Val%'] - 15) > 5 or abs(row['Test%'] - 15) > 5:
        issues.append(f"  {row['Disease']}: Train={row['Train%']:.1f}%, Val={row['Val%']:.1f}%, Test={row['Test%']:.1f}%")

if issues:
    print("\nWARNING: Some diseases have unbalanced splits (>10 points off target for train, >5 for val/test):")
    for issue in issues:
        print(issue)

# Multi-label statistics
print("\n" + "=" * 80)
print("MULTI-LABEL STATISTICS")
print("=" * 80)

for split_name, split_df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]:
    diseases_per_image = split_df[CLASS_NAMES].sum(axis=1)
    
    print(f"\n{split_name}:")
    print(f"  Total images: {len(split_df):,}")
    print(f"  Single disease:  {(diseases_per_image == 1).sum():,} ({100*(diseases_per_image == 1).sum()/len(split_df):.1f}%)")
    print(f"  Multi-disease: {(diseases_per_image > 1).sum():,} ({100*(diseases_per_image > 1).sum()/len(split_df):.1f}%)")
    print(f"  Max diseases/image: {diseases_per_image.max():.0f}")
    print(f"  Avg diseases/image: {diseases_per_image.mean():.2f}")

print("\n" + "=" * 80)


CLASS DISTRIBUTION ANALYSIS

Disease Distribution Across Splits:
Disease                      Train      Val     Test    Total   Train%     Val%    Test%
------------------------------------------------------------------------------------------------------------------------
No Finding                   15272     3176     3225    21673    70.5%    14.7%    14.9%
Infiltration                  5022      934     1049     7005    71.7%    13.3%    15.0%
Effusion                      3266      677      668     4611    70.8%    14.7%    14.5%
Atelectasis                   2871      627      653     4151    69.2%    15.1%    15.7%
Nodule                        1564      355      302     2221    70.4%    16.0%    13.6%
Mass                          1419      353      339     2111    67.2%    16.7%    16.1%
Pneumothorax                  1215      320      282     1817    66.9%    17.6%    15.5%
Consolidation                 1204      228      242     1674    71.9%    13.6%    14.5%
Pleural_Thic

## 9. Class Distribution Analysis
Analyze disease distribution across splits to ensure:
- Balanced representation of all disease categories
- Consistent split ratios (70/15/15) across diseases
- Documentation of multi-label characteristics
- Statistical validation of sampling approach
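
The per-class split shares can also be computed in one step with a group-by. This is a sketch on invented data: the `split` column, the two classes, and the ten rows are assumptions made for the example.

```python
import pandas as pd

# Toy data: 6 train rows, 2 val rows, 2 test rows with two binary class columns
toy = pd.DataFrame({
    "split":    ["train"] * 6 + ["val"] * 2 + ["test"] * 2,
    "Effusion": [1, 1, 1, 0, 0, 0, 1, 0, 0, 0],
    "Mass":     [1, 1, 0, 0, 0, 0, 0, 1, 1, 0],
})

# Positive count per class and split
counts = toy.groupby("split")[["Effusion", "Mass"]].sum()

# Share of each class's positives that landed in each split, in percent
share = (100 * counts / counts.sum()).round(1)

# Flag classes whose train share is more than 10 points away from the 70% target
flagged = list(share.columns[(share.loc["train"] - 70).abs() > 10])
print(share)
print(flagged)
```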

In [25]:
# ------------------------------------------------------------------------------
# SAVE SEGMENTATION ANNOTATIONS
# ------------------------------------------------------------------------------

import json

# Save bounding box annotations for each split
for split_name, split_df in [('train', train_df), ('val', val_df), ('test', test_df)]:
    # Get images with bounding boxes
    images_with_bbox = split_df[split_df['Has_BBox']]['Image Index'].tolist()
    
    # Filter bbox data for this split
    split_bbox = bbox_df[bbox_df['Image Index'].isin(images_with_bbox)]
    
    bbox_dict = {}
    for _, row in split_bbox.iterrows():
        img_name = row['Image Index']
        if img_name not in bbox_dict:
            bbox_dict[img_name] = []
        bbox_dict[img_name].append({
            'class': row['Finding Label'],
            'x': row['Bbox [x'],
            'y': row['y'],
            'w': row['w'],
            'h': row['h]']
        })
    
    # Save to file (updated path for local C: drive)
    output_path = f'C:/xray_data/data/{split_name}_bboxes.json'
    with open(output_path, 'w') as f:
        json.dump(bbox_dict, f, indent=2)

## 10. Export Metadata & Annotations
Generate comprehensive metadata files for model training:
- **Class mapping**: Disease ID to name mapping
- **CSV metadata**: Train/val/test labels and patient IDs
- **Bounding boxes**: Localization annotations for the annotated subset of images (JSON format)
- **Summary statistics**: Complete dataset documentation including patient-level statistics and validation results
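
One detail to keep in mind when reading the class mapping back: JSON object keys are always strings, so integer class IDs need to be converted on load. The two-entry mapping below is a shortened stand-in for `class_mapping.json`.

```python
import json

# Shortened stand-in for the saved class mapping (integer ID -> class name)
class_mapping = {0: "Atelectasis", 1: "Cardiomegaly"}

# Serializing turns the integer keys into strings
text = json.dumps(class_mapping, indent=2)
loaded = json.loads(text)

# Convert the keys back to int to recover the original mapping
restored = {int(k): v for k, v in loaded.items()}

print(list(loaded.keys()), restored == class_mapping)
```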

In [26]:
# ------------------------------------------------------------------------------
# SAVE DATASET METADATA
# ------------------------------------------------------------------------------

import json

# Save class mapping (updated path for local C: drive)
class_mapping = {i: class_name for i, class_name in enumerate(CLASS_NAMES)}
with open('C:/xray_data/data/class_mapping.json', 'w') as f:
    json.dump(class_mapping, f, indent=2)

# Save metadata for each split
train_df.to_csv('C:/xray_data/data/train_metadata.csv', index=False)
val_df.to_csv('C:/xray_data/data/val_metadata.csv', index=False)
test_df.to_csv('C:/xray_data/data/test_metadata.csv', index=False)

# Create comprehensive summary with patient-level and validation stats

# Collect patient IDs per split from the 'PatientID' column created during splitting
# (the prefix of 'Image Index' before the underscore)
train_patients = set(train_df['PatientID'])
val_patients = set(val_df['PatientID'])
test_patients = set(test_df['PatientID'])

# Count images per disease
disease_counts = {}
for disease in CLASS_NAMES:
    disease_counts[disease] = {
        'train': int(train_df[disease].sum()),
        'val': int(val_df[disease].sum()),
        'test': int(test_df[disease].sum()),
        'total': int(train_df[disease].sum() + val_df[disease].sum() + test_df[disease].sum())
    }

summary = {
    'num_classes': len(CLASS_NAMES),
    'class_names': CLASS_NAMES,
    'total_images': len(sampled_df),
    'train_images': len(train_df),
    'val_images': len(val_df),
    'test_images': len(test_df),
    'images_with_bboxes_train': int(train_df['Has_BBox'].sum()),
    'images_with_bboxes_val': int(val_df['Has_BBox'].sum()),
    'images_with_bboxes_test': int(test_df['Has_BBox'].sum()),
    'date_processed': pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S'),
    'task': 'multi-label classification (10 classes - top diseases only) + segmentation',
    'multi_label': True,
    'single_label_count': int((~sampled_df['Finding Labels'].str.contains('|', regex=False)).sum()),
    'multi_label_count': int(sampled_df['Finding Labels'].str.contains('|', regex=False).sum()),
    # Patient-level statistics (CRITICAL for medical AI)
    'patient_splitting': {
        'enabled': True,
        'unique_patients_train': len(train_patients),
        'unique_patients_val': len(val_patients),
        'unique_patients_test': len(test_patients),
        'unique_patients_total': len(train_patients) + len(val_patients) + len(test_patients),
        'patient_overlap_train_val': len(train_patients & val_patients),
        'patient_overlap_train_test': len(train_patients & test_patients),
        'patient_overlap_val_test': len(val_patients & test_patients),
        'avg_images_per_patient': len(sampled_df) / (len(train_patients) + len(val_patients) + len(test_patients))
    },
    # Validation results
    'validation_status': validation_results,
    'all_validations_passed': all(v['status'] == 'PASSED' for v in validation_results.values()),
    # Disease distribution
    'disease_distribution': disease_counts,
    # Folder structure
    'folder_structure': 'flat',  # All images directly in split folders (train/, val/, test/)
    'labels_stored_in': 'CSV files only (multi-label compatible)',
    'data_location': 'C:\\xray_data (local SSD, outside OneDrive for faster training)'
}

with open('C:/xray_data/data/dataset_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("\nMetadata saved to C:\\xray_data\\data\\")
print("  - class_mapping.json")
print("  - train_metadata.csv, val_metadata.csv, test_metadata.csv")
print("  - train_bboxes.json, val_bboxes.json, test_bboxes.json")
print("  - dataset_summary.json")
print(f"\nPatient-aware splitting:")
print(f"  Train: {len(train_patients):,} patients")
print(f"  Val:   {len(val_patients):,} patients")
print(f"  Test:  {len(test_patients):,} patients")
print(f"  No patient overlap between splits")
print(f"\nMulti-label classification (TOP 10 diseases):")
print(f"  Single disease: {summary['single_label_count']:,} images")
print(f"  Multi-disease: {summary['multi_label_count']:,} images")
print(f"\nFolder structure: Flat (all images in C:\\xray_data\\data\\train\\ val\\ test\\)")
print(f"Labels: Stored in CSV files (multi-label compatible)")
print(f"Location: C:\\xray_data (local SSD, outside OneDrive)")


Metadata saved to C:\xray_data\data\
  - class_mapping.json
  - train_metadata.csv, val_metadata.csv, test_metadata.csv
  - train_bboxes.json, val_bboxes.json, test_bboxes.json
  - dataset_summary.json

Patient-aware splitting:
  Train: 10,653 patients
  Val:   2,283 patients
  Test:  2,283 patients
  No patient overlap between splits

Multi-label classification (TOP 10 diseases):
  Single disease: 32,651 images
  Multi-disease: 7,349 images

Folder structure: Flat (all images in C:\xray_data\data\train\ val\ test\)
Labels: Stored in CSV files (multi-label compatible)
Location: C:\xray_data (local SSD, outside OneDrive)
