# Data Augmentation Pipeline v2.0

**For Research Paper**: Complete pipeline for augmenting the symptom-disease dataset.

## Pipeline Overview

| Stage | Description | Output |
|-------|-------------|--------|
| **0** | Expand symptom vocabulary with Mayo Clinic symptoms | `symptom_vocabulary.json` (updated in place) |
| **1** | Generate synthetic samples for rare diseases (<20 samples) | `symptoms_augmented_no_demographics.csv` |
| **2** | Add demographic variables (age, sex) | `symptoms_augmented_with_demographics.csv` |

## Requirements
- `data/rare_diseases_symptoms_template.json` - Filled with Mayo Clinic symptoms
- `data/final_disease_demographics.json` - Demographics from ChatGPT + synthetic rules

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
import json
import re
import random
import sys
import gc
from collections import Counter

# Add project root to path
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from utils.symptom_normalizer import normalize_symptom

# Paths
# Input files
data_path = project_root / "data" / "processed" / "symptoms" / "symptoms_to_disease_cleaned.csv"
symptom_cols_path = project_root / "data" / "symptom_vocabulary.json"
template_path = project_root / "data" / "rare_diseases_symptoms_template.json"
category_map_path = project_root / "data" / "disease_mapping.json"
demographics_path = project_root / "data" / "final_disease_demographics.json"

# Output files - vocabulary saved to SAME file (overwrites original)
expanded_vocab_path = symptom_cols_path  # Overwrites original
output_no_demo_path = project_root / "data" / "processed" / "symptoms" / "symptoms_augmented_no_demographics.csv"
output_with_demo_path = project_root / "data" / "processed" / "symptoms" / "symptoms_augmented_with_demographics.csv"

print(f"Project root: {project_root}")
print(f"\nInput files:")
print(f"  Data: {data_path.exists()} - {data_path}")
print(f"  Vocab: {symptom_cols_path.exists()} - {symptom_cols_path}")
print(f"  Template: {template_path.exists()} - {template_path}")
print(f"  Demographics: {demographics_path.exists()} - {demographics_path}")


Project root: c:\Users\henry\Desktop\Programming\Python\Multimodal_Diagnosis

Input files:
  Data: True - c:\Users\henry\Desktop\Programming\Python\Multimodal_Diagnosis\data\processed\symptoms\symptoms_to_disease_cleaned.csv
  Vocab: True - c:\Users\henry\Desktop\Programming\Python\Multimodal_Diagnosis\data\symptom_vocabulary.json
  Template: True - c:\Users\henry\Desktop\Programming\Python\Multimodal_Diagnosis\data\rare_diseases_symptoms_template.json
  Demographics: True - c:\Users\henry\Desktop\Programming\Python\Multimodal_Diagnosis\data\final_disease_demographics.json


---
# Stage 0: Expand Symptom Vocabulary

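The vocabulary loaded in this run has 458 symptoms (see the cell output below). Many Mayo Clinic symptoms have no exact match in it.

**Strategy**: Add clinically important new symptoms to the vocabulary, keeping only those that appear in at least `MIN_DISEASE_COUNT` diseases (set to 2 in the configuration cell).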
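
A minimal sketch of this strategy on toy data: count, for each out-of-vocabulary symptom, the number of diseases that list it, then keep the symptoms that reach the threshold. The names `toy_template`, `toy_vocab`, and `frequent_new_symptoms` are hypothetical, the template is flattened to disease -> symptom list, and plain lower-casing stands in for the project's `normalize_symptom`.

```python
from collections import Counter

# Hypothetical toy data, not taken from the real template or vocabulary.
toy_template = {
    "disease a": ["fever", "neck stiffness", "cloudy urine"],
    "disease b": ["fever", "neck stiffness"],
    "disease c": ["neck stiffness", "black stool"],
}
toy_vocab = {"fever"}


def frequent_new_symptoms(template, vocab, min_disease_count):
    """Return out-of-vocabulary symptoms listed for at least `min_disease_count` diseases."""
    counts = Counter()
    for symptoms in template.values():
        # Count each symptom once per disease, even if a list repeats it.
        for sym in {s.strip().lower() for s in symptoms}:
            if sym not in vocab:
                counts[sym] += 1
    return sorted(s for s, c in counts.items() if c >= min_disease_count)


print(frequent_new_symptoms(toy_template, toy_vocab, min_disease_count=2))
# ['neck stiffness']
```

This sketch counts a symptom once per disease, so a symptom repeated within a single disease's list cannot reach the threshold on its own.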

In [2]:
# Load current vocabulary
with open(symptom_cols_path) as f:
    ORIGINAL_VOCAB = json.load(f)
ORIGINAL_SET = set(s.lower() for s in ORIGINAL_VOCAB)

print(f"Original vocabulary: {len(ORIGINAL_VOCAB)} symptoms")

# Load template with Mayo Clinic symptoms
with open(template_path) as f:
    template = json.load(f)

print(f"Diseases in template: {len(template)}")

Original vocabulary: 458 symptoms
Diseases in template: 135


In [3]:
# Extract all unique symptoms from template
all_mayo_symptoms = set()
symptom_counts = Counter()

print("Extracting and normalizing symptoms from template...")
for disease, info in template.items():
    mayo = info.get("mayo_clinic_symptoms", [])
    for sym in mayo:
        # Normalize symptom using project specific normalizer
        sym_norm = normalize_symptom(sym)
        
        if sym_norm and not sym_norm.startswith('"notes"'):
            all_mayo_symptoms.add(sym_norm)
            symptom_counts[sym_norm] += 1

print(f"Total unique mayo symptoms: {len(all_mayo_symptoms)}")

# Find symptoms NOT in current vocabulary
new_symptoms = [s for s in all_mayo_symptoms if s not in ORIGINAL_SET]
print(f"New symptoms (not in vocabulary): {len(new_symptoms)}")


Extracting and normalizing symptoms from template...
Total unique mayo symptoms: 690
New symptoms (not in vocabulary): 525


In [4]:
# Show most common new symptoms
new_symptom_counts = [(s, symptom_counts[s]) for s in new_symptoms]
new_symptom_counts.sort(key=lambda x: -x[1])

print("Most common new symptoms (top 50):")
print("-" * 60)
for sym, count in new_symptom_counts[:50]:
    print(f"  [{count:2d}x] {sym}")

Most common new symptoms (top 50):
------------------------------------------------------------
  [ 1x] enlargement of the breast tissue
  [ 1x] heavy menstrual bleeding
  [ 1x] black stool
  [ 1x] lines of rash
  [ 1x] language difficulty
  [ 1x] swelling of ankles and legs
  [ 1x] poor balance or coordination
  [ 1x] bowel obstruction
  [ 1x] discoloration
  [ 1x] neck stiffness
  [ 1x] red discolored skin
  [ 1x] urinary tract infections
  [ 1x] blisters on chest
  [ 1x] heart disorders
  [ 1x] long-lasting cough with thick mucus
  [ 1x] red streaks on skin
  [ 1x] issues with cognitive development
  [ 1x] bowed or bent bones
  [ 1x] uterine tenderness
  [ 1x] coordination problems
  [ 1x] iris that jiggles
  [ 1x] cloudy urine
  [ 1x] trouble with speech
  [ 1x] musculoskeletal pain
  [ 1x] webbed nect
  [ 1x] changes in alertness
  [ 1x] increased sensitivity to cold
  [ 1x] memory fog
  [ 1x] enlarged head in infants
  [ 1x] fluid buildup around the lungs
  [ 1x] involuntary spas

In [5]:
# CONFIGURATION: Filter new symptoms
# Only add symptoms that appear in at least MIN_DISEASE_COUNT diseases
MIN_DISEASE_COUNT = 2  # Minimum number of diseases a new symptom must appear in

symptoms_to_add = [s for s, count in new_symptom_counts if count >= MIN_DISEASE_COUNT]

print(f"Symptoms appearing in >= {MIN_DISEASE_COUNT} diseases: {len(symptoms_to_add)}")
print("\nSymptoms to add:")
for s in symptoms_to_add:
    print(f"  - {s}")

Symptoms appearing in >= 2 diseases: 0

Symptoms to add:


In [6]:
# Create expanded vocabulary
EXPANDED_VOCAB = ORIGINAL_VOCAB + symptoms_to_add
EXPANDED_SET = set(s.lower() for s in EXPANDED_VOCAB)

print(f"Original vocabulary: {len(ORIGINAL_VOCAB)} symptoms")
print(f"Expanded vocabulary: {len(EXPANDED_VOCAB)} symptoms (+{len(symptoms_to_add)})")

# Save expanded vocabulary (overwrites original)
with open(expanded_vocab_path, 'w') as f:
    json.dump(EXPANDED_VOCAB, f, indent=2)

print(f"\nSaved expanded vocabulary to: {expanded_vocab_path}")
print("NOTE: Original vocabulary file has been updated with new symptoms.")

Original vocabulary: 458 symptoms
Expanded vocabulary: 458 symptoms (+0)

Saved expanded vocabulary to: c:\Users\henry\Desktop\Programming\Python\Multimodal_Diagnosis\data\symptom_vocabulary.json
NOTE: Original vocabulary file has been updated with new symptoms.
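
Because the vocabulary is written back to the file it was loaded from, a rerun starts from the already-expanded list. One possible safeguard, sketched here under the assumption that a sibling `.bak` file is acceptable, copies the previous file before overwriting and drops duplicate entries while preserving order. The helper name `save_vocab_with_backup` is hypothetical and not part of the project.

```python
import json
import shutil
import tempfile
from pathlib import Path


def save_vocab_with_backup(vocab, path):
    """Write `vocab` to `path`, first copying any existing file to <name>.bak."""
    path = Path(path)
    backup = path.with_name(path.name + ".bak")
    if path.exists():
        shutil.copy2(path, backup)
    # dict.fromkeys removes duplicates while keeping first-seen order.
    deduped = list(dict.fromkeys(vocab))
    with open(path, "w", encoding="utf-8") as f:
        json.dump(deduped, f, indent=2)
    return backup


# Demonstration in a temporary directory so no project file is touched.
with tempfile.TemporaryDirectory() as tmp:
    vocab_file = Path(tmp) / "symptom_vocabulary.json"
    vocab_file.write_text(json.dumps(["fever", "cough"]), encoding="utf-8")
    backup_file = save_vocab_with_backup(["fever", "cough", "rash", "rash"], vocab_file)
    print(json.loads(vocab_file.read_text(encoding="utf-8")))   # ['fever', 'cough', 'rash']
    print(json.loads(backup_file.read_text(encoding="utf-8")))  # ['fever', 'cough']
```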


---
# Stage 1: Symptom Mapping & Synthetic Data Generation

1. Map Mayo Clinic symptoms to expanded vocabulary
2. Generate synthetic samples for diseases with <20 samples

In [7]:
# Load original dataset
df = pd.read_csv(data_path)
print(f"Original dataset: {len(df):,} rows, {df['diseases'].nunique()} diseases")

# Get disease counts
counts = df['diseases'].value_counts()
rare_diseases = counts[counts < 20]
print(f"\nDiseases with <20 samples: {len(rare_diseases)}")
print(f"Total samples in rare diseases: {rare_diseases.sum():,}")

Original dataset: 206,267 rows, 627 diseases

Diseases with <20 samples: 128
Total samples in rare diseases: 1,085


In [8]:
# Load category mapping
with open(category_map_path) as f:
    category_map = json.load(f)

# Create disease -> category lookup
disease_to_category = {}
for cat, diseases in category_map.items():
    for d in diseases:
        disease_to_category[d] = cat

print(f"Loaded {len(disease_to_category)} disease -> category mappings")

Loaded 542 disease -> category mappings


In [9]:
def map_symptoms_to_vocab(symptoms_list, vocab_set):
    """
    Map a list of symptoms to the vocabulary.
    Returns symptoms that exist in the vocabulary.
    """
    mapped = []
    for sym in symptoms_list:
        sym_clean = sym.strip().lower()
        if sym_clean in vocab_set:
            mapped.append(sym_clean)
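    # Remove duplicates; sort so that seeded sampling is reproducible across runs
    # (set iteration order for strings can change between Python processes)
    return sorted(set(mapped))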


def generate_synthetic_samples(disease: str, symptoms: list, n_samples: int,
                               all_symptoms: list, min_sym: int = 4, max_sym: int = 8) -> list:
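    """
    Generate synthetic samples for a disease.
    Each sample holds a random subset of min_sym to max_sym symptoms (4-8 by default),
    capped at the number of symptoms available. Requires len(symptoms) >= min_sym.
    The category is looked up in the module-level disease_to_category dict.
    """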
    samples = []
    category = disease_to_category.get(disease, "Unknown Type")
    
    for _ in range(n_samples):
        # Select random symptoms
        n_sym = random.randint(min_sym, min(max_sym, len(symptoms)))
        selected = random.sample(symptoms, n_sym)
        
        # Create row with all symptoms as 0
        row = {col: 0 for col in all_symptoms}
        
        # Set selected symptoms to 1
        for sym in selected:
            if sym in row:
                row[sym] = 1
        
        row['diseases'] = disease
        row['disease_category'] = category
        row['symptoms'] = ", ".join(selected)
        
        samples.append(row)
    
    return samples

print("Defined mapping and generation functions")

Defined mapping and generation functions
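
The sampling step can be isolated into a small sketch that makes two properties explicit: the subset size is capped by the number of available symptoms, and a disease with exactly `min_sym` mapped symptoms always yields the same symptom set. The helper names `sample_symptom_subset` and `to_binary_row` are hypothetical, and a local `random.Random` instance is used here instead of the module-level generator.

```python
import random


def sample_symptom_subset(symptoms, rng, min_sym=4, max_sym=8):
    """Pick a random subset of between min_sym and max_sym distinct symptoms."""
    if len(symptoms) < min_sym:
        raise ValueError(f"need at least {min_sym} symptoms, got {len(symptoms)}")
    upper = min(max_sym, len(symptoms))
    n_sym = rng.randint(min_sym, upper)  # randint is inclusive on both ends
    return rng.sample(symptoms, n_sym)


def to_binary_row(selected, all_symptoms):
    """Encode the selected symptoms as a 0/1 dict over all symptom columns."""
    selected_set = set(selected)
    return {col: int(col in selected_set) for col in all_symptoms}


rng = random.Random(42)
profile = ["fever", "headache", "muscle pain", "nausea", "rash"]  # toy profile
columns = profile + ["cough"]
subset = sample_symptom_subset(profile, rng)
row = to_binary_row(subset, columns)
print(len(subset), sum(row.values()), row["cough"])
```

With only four symptoms available, `randint(4, 4)` always returns 4, so every synthetic row for such a disease carries the identical symptom set.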


In [10]:
# Map symptoms for each disease in template
disease_mapped_symptoms = {}
mapping_stats = {'total': 0, 'mapped': 0, 'diseases_ready': 0}

for disease, info in template.items():
    mayo = info.get("mayo_clinic_symptoms", [])
    if not mayo:
        continue
    
    mapped = map_symptoms_to_vocab(mayo, EXPANDED_SET)
    mapping_stats['total'] += len(mayo)
    mapping_stats['mapped'] += len(mapped)
    
    if len(mapped) >= 4:  # Minimum for synthetic generation
        disease_mapped_symptoms[disease] = mapped
        mapping_stats['diseases_ready'] += 1

print(f"Symptom mapping results:")
print(f"  Total mayo symptoms: {mapping_stats['total']}")
print(f"  Mapped to vocabulary: {mapping_stats['mapped']} ({100*mapping_stats['mapped']/mapping_stats['total']:.1f}%)")
print(f"  Diseases ready for synthesis (>=4 symptoms): {mapping_stats['diseases_ready']}")

Symptom mapping results:
  Total mayo symptoms: 1255
  Mapped to vocabulary: 578 (46.1%)
  Diseases ready for synthesis (>=4 symptoms): 72
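
To see why fewer than half of the symptoms map, it can help to list the unmapped symptoms per disease. The sketch below is one way to do that on toy data; `mapping_report` is a hypothetical helper, the template is flattened to disease -> symptom list, and the default normalizer (strip and lower-case) is an assumption that could be swapped for the project's `normalize_symptom`.

```python
def mapping_report(template, vocab_set, normalize=lambda s: s.strip().lower(), min_mapped=4):
    """Per disease: mapped symptoms, unmapped symptoms, and readiness for synthesis."""
    report = {}
    for disease, symptoms in template.items():
        normalized = {normalize(s) for s in symptoms}
        mapped = sorted(normalized & vocab_set)
        unmapped = sorted(normalized - vocab_set)
        report[disease] = {
            "mapped": mapped,
            "unmapped": unmapped,
            "ready": len(mapped) >= min_mapped,
        }
    return report


# Hypothetical toy inputs for illustration only.
toy_vocab = {"fever", "rash", "headache", "nausea"}
toy_template = {
    "disease a": ["Fever", "rash ", "Headache", "nausea", "lines of rash"],
    "disease b": ["fever", "neck stiffness"],
}
for disease, entry in mapping_report(toy_template, toy_vocab).items():
    print(disease, entry["ready"], entry["unmapped"])
# disease a True ['lines of rash']
# disease b False ['neck stiffness']
```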


In [11]:
# Generate synthetic samples
random.seed(42)
TARGET_SAMPLES = 25  # Minimum samples per disease

all_synthetic = []
generation_log = []

for disease, symptoms in disease_mapped_symptoms.items():
    current_count = counts.get(disease, 0)
    
    if current_count >= TARGET_SAMPLES:
        continue
    
    n_new = TARGET_SAMPLES - current_count
    samples = generate_synthetic_samples(disease, symptoms, n_new, EXPANDED_VOCAB)
    all_synthetic.extend(samples)
    
    generation_log.append({
        'disease': disease,
        'original': current_count,
        'added': n_new,
        'symptoms_available': len(symptoms)
    })

print(f"Generated {len(all_synthetic):,} synthetic samples for {len(generation_log)} diseases")
print("\nGeneration details:")
for log in generation_log[:20]:
    print(f"  {log['disease']}: {log['original']} -> {log['original'] + log['added']} (+{log['added']}, {log['symptoms_available']} symptoms available)")
if len(generation_log) > 20:
    print(f"  ... and {len(generation_log) - 20} more diseases")

Generated 1,154 synthetic samples for 72 diseases

Generation details:
  rocky mountain spotted fever: 1 -> 25 (+24, 9 symptoms available)
  myocarditis: 1 -> 25 (+24, 5 symptoms available)
  kaposi sarcoma: 1 -> 25 (+24, 5 symptoms available)
  chronic ulcer: 1 -> 25 (+24, 5 symptoms available)
  gas gangrene: 1 -> 25 (+24, 5 symptoms available)
  thalassemia: 1 -> 25 (+24, 4 symptoms available)
  typhoid fever: 1 -> 25 (+24, 5 symptoms available)
  rheumatic fever: 2 -> 25 (+23, 4 symptoms available)
  human immunodeficiency virus infection (hiv): 2 -> 25 (+23, 8 symptoms available)
  hashimoto thyroiditis: 2 -> 25 (+23, 9 symptoms available)
  sporotrichosis: 3 -> 25 (+22, 6 symptoms available)
  cat scratch disease: 3 -> 25 (+22, 6 symptoms available)
  dengue fever: 3 -> 25 (+22, 7 symptoms available)
  adrenal cancer: 3 -> 25 (+22, 6 symptoms available)
  necrotizing fasciitis: 3 -> 25 (+22, 5 symptoms available)
  connective tissue disorder: 3 -> 25 (+22, 6 symptoms available)
 

In [12]:
# Create expanded base dataset with new symptom columns
# Optimization: use int8 and avoid fragmentation by adding all columns in a single concat instead of per-column inserts

# Identify new columns to add
new_cols_to_add = [sym for sym in symptoms_to_add if sym not in df.columns]

if new_cols_to_add:
    print(f"Adding {len(new_cols_to_add)} new symptom columns (int8)...")
    # Create a separate DataFrame for new columns
    new_data = pd.DataFrame(0, index=df.index, columns=new_cols_to_add, dtype='int8')
    
    # Concatenate once
    df = pd.concat([df, new_data], axis=1)
else:
    print("No new columns to add.")

print(f"Expanded original dataset to {len(df.columns)} columns")


No new columns to add.
Expanded original dataset to 377 columns


In [13]:
# Combine original + synthetic
if all_synthetic:
    df_synthetic = pd.DataFrame(all_synthetic)
    
    # Ensure synthetic dataframe has all columns
    # Optimization: Use reindex which is faster/cleaner
    df_synthetic = df_synthetic.reindex(columns=df.columns, fill_value=0).astype(df.dtypes)
    
    # Concatenate
    df_augmented = pd.concat([df, df_synthetic], ignore_index=True)
    
    print(f"Original samples: {len(df):,}")
    print(f"Synthetic samples: {len(df_synthetic):,}")
    print(f"Total augmented: {len(df_augmented):,}")
    
    # cleanup
    del df_synthetic
else:
    df_augmented = df
    print("No synthetic samples generated")

# Free up memory
del df
gc.collect()
print("Memory cleanup: deleted original df")


Original samples: 206,267
Synthetic samples: 1,154
Total augmented: 207,421
Memory cleanup: deleted original df
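
The vocabulary printed earlier has 458 entries while the dataset has 377 columns, so some vocabulary symptoms may have no matching column. Aligning synthetic rows to the dataset's columns discards such entries, which can leave a synthetic row with fewer active symptoms than intended. A small check along these lines, run on the row dictionaries before alignment, would surface that; `symptoms_lost_on_alignment` and the toy rows are hypothetical.

```python
def symptoms_lost_on_alignment(synthetic_rows, dataset_columns):
    """Count, per symptom, the rows where it is active but has no dataset column."""
    columns = set(dataset_columns)
    lost = {}
    for row in synthetic_rows:
        for name, value in row.items():
            # Only 0/1 symptom flags can equal 1; string fields are skipped.
            if value == 1 and name not in columns:
                lost[name] = lost.get(name, 0) + 1
    return lost


# Hypothetical toy rows: "symptom x" is in the vocabulary but not in the dataset.
toy_rows = [
    {"fever": 1, "rash": 0, "symptom x": 1, "diseases": "disease a"},
    {"fever": 0, "rash": 1, "symptom x": 1, "diseases": "disease a"},
]
toy_columns = ["diseases", "fever", "rash"]
print(symptoms_lost_on_alignment(toy_rows, toy_columns))
# {'symptom x': 2}
```

An empty result means every active symptom survives the alignment.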


In [14]:
# Verify rare disease counts improved
new_counts = df_augmented['diseases'].value_counts()
new_rare = new_counts[new_counts < 20]

print(f"Before augmentation: {len(rare_diseases)} diseases with <20 samples")
print(f"After augmentation: {len(new_rare)} diseases with <20 samples")
print(f"\nDiseases still below 20 samples:")
for d, c in new_rare.items():
    print(f"  {c:2d}  {d}")

Before augmentation: 128 diseases with <20 samples
After augmentation: 59 diseases with <20 samples

Diseases still below 20 samples:
  19  otosclerosis
  18  cyst of the eyelid
  14  fibrocystic breast disease
  14  pneumoconiosis
  13  congenital malformation syndrome
  13  hpv
  13  factitious disorder
  12  raynaud disease
  12  moyamoya disease
  11  galactorrhea of unknown cause
  11  zenker diverticulum
  11  myoclonus
  11  pulmonic valve disease
  10  testicular cancer
  10  vesicoureteral reflux
  10  avascular necrosis
  10  reactive arthritis
  10  decubitus ulcer
  10  optic neuritis
  10  granuloma inguinale
   9  vacterl syndrome
   9  lichen planus
   9  hemarthrosis
   9  placenta previa
   9  aphakia
   9  thyroid cancer
   8  hypercholesterolemia
   8  spinocerebellar ataxia
   7  omphalitis
   7  pemphigus
   6  empyema
   6  priapism
   6  vulvar cancer
   6  breast cancer
   6  hypertrophic obstructive cardiomyopathy (hocm)
   6  tuberous sclerosis
   6  g6pd enzy

In [15]:
# Save dataset WITHOUT demographics
df_augmented.to_csv(output_no_demo_path, index=False)

print(f"Saved dataset WITHOUT demographics:")
print(f"  Path: {output_no_demo_path}")
print(f"  Size: {output_no_demo_path.stat().st_size / 1024 / 1024:.1f} MB")
print(f"  Rows: {len(df_augmented):,}")
print(f"  Columns: {len(df_augmented.columns)}")

Saved dataset WITHOUT demographics:
  Path: c:\Users\henry\Desktop\Programming\Python\Multimodal_Diagnosis\data\processed\symptoms\symptoms_augmented_no_demographics.csv
  Size: 156.8 MB
  Rows: 207,421
  Columns: 377


---
# Stage 2: Add Demographics (Age, Sex)

Using merged demographics from ChatGPT + synthetic rules.

In [16]:
# Load demographics
with open(demographics_path) as f:
    demographics = json.load(f)

print(f"Loaded demographics for {len(demographics)} diseases")

# Check coverage
all_diseases = set(df_augmented['diseases'].unique())
demo_diseases = set(demographics.keys())
covered = all_diseases & demo_diseases
missing = all_diseases - demo_diseases

print(f"\nDemographic coverage:")
print(f"  Total diseases in dataset: {len(all_diseases)}")
print(f"  Covered by demographics: {len(covered)} ({100*len(covered)/len(all_diseases):.1f}%)")
print(f"  Missing (will use defaults): {len(missing)}")

Loaded demographics for 667 diseases

Demographic coverage:
  Total diseases in dataset: 630
  Covered by demographics: 623 (98.9%)
  Missing (will use defaults): 7


In [17]:
# Default demographics for missing diseases
DEFAULT_DEMO = {
    "age_min": 10,
    "age_max": 80,
    "age_peak": 45,
    "male_pct": 50
}

def sample_age(demo: dict) -> int:
    """Sample age from triangular distribution."""
    age_min = demo.get('age_min', 10)
    age_max = demo.get('age_max', 80)
    age_peak = demo.get('age_peak', 45)
    
    # Handle edge cases
    if age_min == age_max:
        return int(age_min)
    
    age_peak = max(age_min, min(age_peak, age_max))
    
    if age_min == age_peak or age_peak == age_max:
        age = np.random.uniform(age_min, age_max)
    else:
        age = np.random.triangular(age_min, age_peak, age_max)
    
    return int(np.clip(age, 0, 100))


def sample_sex(demo: dict) -> str:
    """Sample sex from Bernoulli distribution."""
    male_pct = demo.get('male_pct', 50)
    return 'M' if np.random.random() * 100 < male_pct else 'F'

print("Defined demographic sampling functions")

Defined demographic sampling functions
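
When reading the later verification output, note that the mean of a triangular distribution is (min + peak + max) / 3, which equals the peak only for symmetric ranges. The sketch below checks this numerically with the standard library; the age range used is hypothetical, and `random.triangular` takes its arguments as (low, high, mode), unlike NumPy's (left, mode, right).

```python
import random


def triangular_mean(age_min, age_peak, age_max):
    """Analytic mean of a triangular distribution."""
    return (age_min + age_peak + age_max) / 3


def simulated_mean(age_min, age_peak, age_max, n=100_000, seed=0):
    """Monte Carlo estimate of the same mean, using a local generator."""
    rng = random.Random(seed)
    # Standard-library argument order is (low, high, mode).
    draws = [rng.triangular(age_min, age_max, age_peak) for _ in range(n)]
    return sum(draws) / n


# Hypothetical skewed range: the peak is 30 but the expected age is 40.
print(triangular_mean(10, 30, 80))  # 40.0
print(abs(simulated_mean(10, 30, 80) - 40.0) < 0.5)  # True
```

Truncating each draw with `int()` lowers the observed mean further, by roughly half a year.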


In [18]:
# Generate demographics for all rows
np.random.seed(42)
ages = []
sexes = []

# Iterate over the disease column only; this avoids building a full row Series per record
for idx, disease in enumerate(df_augmented['diseases']):
    demo = demographics.get(disease, DEFAULT_DEMO)
    
    ages.append(sample_age(demo))
    sexes.append(sample_sex(demo))
    
    if idx % 50000 == 0:
        print(f"Processed {idx:,} rows...")

df_augmented['age'] = ages
df_augmented['sex'] = sexes

print(f"\nGenerated demographics for {len(df_augmented):,} rows")

Processed 0 rows...
Processed 50,000 rows...
Processed 100,000 rows...
Processed 150,000 rows...
Processed 200,000 rows...

Generated demographics for 207,421 rows



In [19]:
# Summary statistics
print("Demographics Summary:")
print("=" * 50)
print(f"Age: min={df_augmented['age'].min()}, max={df_augmented['age'].max()}, mean={df_augmented['age'].mean():.1f}")
print(f"Sex: {(df_augmented['sex'] == 'M').mean() * 100:.1f}% male, {(df_augmented['sex'] == 'F').mean() * 100:.1f}% female")
print(f"\n{df_augmented['sex'].value_counts()}")

Demographics Summary:
Age: min=0, max=99, mean=43.4
Sex: 47.2% male, 52.8% female

sex
F    109440
M     97981
Name: count, dtype: int64


In [20]:
# Verify key diseases
print("Verification - Sample diseases:")
print("=" * 80)

verify_diseases = ["prostate cancer", "preeclampsia", "migraine", "pyloric stenosis", "diabetes"]

for disease in verify_diseases:
    subset = df_augmented[df_augmented['diseases'] == disease]
    if len(subset) > 0:
        male_pct = (subset['sex'] == 'M').mean() * 100
        mean_age = subset['age'].mean()
        expected = demographics.get(disease, DEFAULT_DEMO)
        
        print(f"{disease:25} Age: {mean_age:5.1f} (exp: {expected.get('age_peak', '?'):>3}), "
              f"Male: {male_pct:5.1f}% (exp: {expected.get('male_pct', '?'):>3}%), "
              f"n={len(subset)}")

Verification - Sample diseases:
prostate cancer           Age:  70.2 (exp:  70), Male: 100.0% (exp: 100%), n=135
preeclampsia              Age:  30.3 (exp:  30), Male:   0.0% (exp:   0%), n=217
migraine                  Age:  32.0 (exp:  30), Male:  30.3% (exp:  30%), n=221
pyloric stenosis          Age:   0.0 (exp:   0), Male:  64.0% (exp:  80%), n=25
diabetes                  Age:  43.0 (exp:  55), Male: 100.0% (exp:  52%), n=1
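
Gaps between observed and expected male share are easier to judge against sampling noise. As a rough sketch using the normal approximation to the binomial, the hypothetical helper below converts a gap into a z-score; for the pyloric stenosis row (64% observed, 80% expected, n=25) the standard error is 8 percentage points, so the gap is about two standard errors.

```python
import math


def sex_share_z_score(observed_pct, expected_pct, n):
    """z-score of an observed male share against the expected share (normal approximation)."""
    p = expected_pct / 100
    if n == 0 or p in (0.0, 1.0):
        return None  # no sampling variation to compare against
    se = math.sqrt(p * (1 - p) / n)
    return (observed_pct / 100 - p) / se


print(round(sex_share_z_score(64.0, 80.0, 25), 2))   # pyloric stenosis row: -2.0
print(round(sex_share_z_score(30.3, 30.0, 221), 2))  # migraine row: well within noise
print(sex_share_z_score(100.0, 100.0, 135))          # prostate cancer row: None
```

For very small groups, such as the single diabetes row, the approximation is not meaningful and the comparison says little either way.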


In [21]:
# Reorder columns: diseases, category, age, sex, then symptoms
cols = df_augmented.columns.tolist()

# Move key columns to front
key_cols = ['diseases', 'disease_category', 'age', 'sex']
symptom_cols = [c for c in cols if c not in key_cols + ['symptoms']]
final_order = key_cols + symptom_cols + ['symptoms']

# Only include columns that exist
final_order = [c for c in final_order if c in df_augmented.columns]

df_final = df_augmented[final_order]
print(f"Reordered columns: {len(df_final.columns)} total")
print(f"First 10: {df_final.columns[:10].tolist()}")

Reordered columns: 379 total
First 10: ['diseases', 'disease_category', 'age', 'sex', 'anxiety and nervousness', 'depression', 'shortness of breath', 'depressive or psychotic symptoms', 'sharp chest pain', 'dizziness']


In [22]:
# Save dataset WITH demographics
df_final.to_csv(output_with_demo_path, index=False)

print(f"Saved dataset WITH demographics:")
print(f"  Path: {output_with_demo_path}")
print(f"  Size: {output_with_demo_path.stat().st_size / 1024 / 1024:.1f} MB")
print(f"  Rows: {len(df_final):,}")
print(f"  Columns: {len(df_final.columns)}")

Saved dataset WITH demographics:
  Path: c:\Users\henry\Desktop\Programming\Python\Multimodal_Diagnosis\data\processed\symptoms\symptoms_augmented_with_demographics.csv
  Size: 157.8 MB
  Rows: 207,421
  Columns: 379


---
# Summary

## Output Files

In [23]:
print("=" * 80)
print("DATA AUGMENTATION COMPLETE")
print("=" * 80)

print("\n1. EXPANDED VOCABULARY:")
print(f"   {expanded_vocab_path}")
print(f"   Original: {len(ORIGINAL_VOCAB)} symptoms")
print(f"   Expanded: {len(EXPANDED_VOCAB)} symptoms (+{len(symptoms_to_add)})")

print("\n2. DATASET WITHOUT DEMOGRAPHICS:")
print(f"   {output_no_demo_path}")
if output_no_demo_path.exists():
    df_check = pd.read_csv(output_no_demo_path, nrows=1)
    print(f"   Rows: {len(df_augmented):,}, Columns: {len(df_check.columns)}")
    print(f"   Size: {output_no_demo_path.stat().st_size / 1024 / 1024:.1f} MB")

print("\n3. DATASET WITH DEMOGRAPHICS:")
print(f"   {output_with_demo_path}")
if output_with_demo_path.exists():
    df_check = pd.read_csv(output_with_demo_path, nrows=1)
    print(f"   Rows: {len(df_final):,}, Columns: {len(df_check.columns)}")
    print(f"   Size: {output_with_demo_path.stat().st_size / 1024 / 1024:.1f} MB")
    print(f"   Includes: age, sex columns")

DATA AUGMENTATION COMPLETE

1. EXPANDED VOCABULARY:
   c:\Users\henry\Desktop\Programming\Python\Multimodal_Diagnosis\data\symptom_vocabulary.json
   Original: 458 symptoms
   Expanded: 458 symptoms (+0)

2. DATASET WITHOUT DEMOGRAPHICS:
   c:\Users\henry\Desktop\Programming\Python\Multimodal_Diagnosis\data\processed\symptoms\symptoms_augmented_no_demographics.csv
   Rows: 207,421, Columns: 377
   Size: 156.8 MB

3. DATASET WITH DEMOGRAPHICS:
   c:\Users\henry\Desktop\Programming\Python\Multimodal_Diagnosis\data\processed\symptoms\symptoms_augmented_with_demographics.csv
   Rows: 207,421, Columns: 379
   Size: 157.8 MB
   Includes: age, sex columns


---
## Documentation for Research Paper

> **Data Augmentation Pipeline**
>
> **Stage 0 - Vocabulary Expansion:**
> 1. Collected symptom lists from Mayo Clinic and Cleveland Clinic for 135 rare diseases
> 2. Identified symptoms absent from the original vocabulary that appear in at least `MIN_DISEASE_COUNT` diseases (2 in the recorded run)
> 3. In the recorded run no new symptom met the threshold, so the vocabulary stayed at 458 symptoms (file rewritten in place)
>
> **Stage 1 - Synthetic Symptom Data:**
> 1. Mapped Mayo Clinic symptoms to expanded vocabulary
> 2. For template diseases with fewer than 25 samples and at least 4 mapped symptoms, generated synthetic samples up to a target of 25
> 3. Each synthetic sample: random 4-8 symptom subset from the disease's symptom profile
> 4. This raised 72 diseases to 25 samples each (1,154 synthetic rows); 59 diseases remained below 20 samples
>
> **Stage 2 - Demographic Variables (Age/Sex):**
> 1. Collected epidemiological demographics via GPT-4 queries
> 2. Applied category-level defaults with keyword-based overrides for sex-specific diseases
> 3. Age sampled from triangular distribution (min, peak, max)
> 4. Sex sampled from Bernoulli distribution based on disease-specific male percentage
>
> Two output datasets were created:
> - `symptoms_augmented_no_demographics.csv`: For symptom-only models
> - `symptoms_augmented_with_demographics.csv`: For multimodal models incorporating age/sex