# 04 — Longitudinal EMR Generator (V2)

**Purpose**: Generate realistic, multi-year patient charts from a structured RUQ pain dataset.

| Component | Engine | Notes |
|-----------|--------|-------|
| Patient Hx parsing | Regex | Extracts demographics, vitals, labs, PE, PMHx from structured text |
| Timeline spine | Rule-based | 3–7 yr history with disease progression |
| Prior encounter notes | LLM (MedGemma 27B) | Full SOAP-format clinical notes |
| Current encounter | LLM (MedGemma 27B) | HPI, PE, A&P, nursing notes |
| Imaging order | From dataset | Order placed (status "Ordered") — no results yet |
| Labs & vitals | Dataset + Rule-based | Dataset values merged with trending history |
| Export | JSON | Individual EMR files + summary CSV |

**Key design**: The provider does **not** know the actual diagnosis — the EMR reflects
diagnostic uncertainty. The actual diagnosis is used behind the scenes for clinical
consistency (labs, vitals, PE findings), but provider-facing notes present the case
as a workup. The EMR ends with the provider ordering the imaging protocol from the dataset.

## 1. Setup and Install Dependencies

In [None]:
import os, sys, torch

# Mount Google Drive to access project files
import subprocess
from google.colab import drive
drive.mount('/content/drive')

# Auto-find project root on Drive
_result = subprocess.run(['find', '/content/drive', '-maxdepth', '5', '-name', 'generated_emrs_27b', '-type', 'd'],
                        capture_output=True, text=True, timeout=30)
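# Keep only hits located at <root>/data/generated_emrs_27b, then drop those two path components
_found = [os.path.dirname(os.path.dirname(p)) for p in _result.stdout.strip().split('\n')
          if p.endswith('/data/generated_emrs_27b')]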
REPO_ROOT = _found[0] if _found else "/content/drive/MyDrive/medgemma-protocol-generator"

os.chdir(REPO_ROOT)
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

print(f"Project root: {REPO_ROOT}")
print(f"CWD set to:   {os.getcwd()}")
print(f"src exists:   {os.path.isdir('src')}")
print(f"torch:        {torch.__version__}")
print(f"CUDA:         {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU:          {torch.cuda.get_device_name(0)}")

# Install dependencies
!pip -q install -r requirements.txt

In [None]:
from huggingface_hub import login

# Authenticate with HuggingFace
# Uses Colab Secrets (key icon in left sidebar) — add HF_TOKEN there
try:
    from google.colab import userdata
    login(token=userdata.get('HF_TOKEN'))
    print("\u2713 Authenticated via Colab Secret (HF_TOKEN)")
except Exception:
    # Fallback: non-interactive login with token string
    # Replace with your token from https://huggingface.co/settings/tokens
    login(token="YOUR_TOKEN_HERE", add_to_git_credential=False)
    print("\u2713 Authenticated via token string")

## 2. Import the `src/` Package

All V2 models and generators live in the `src/` package.

In [None]:
# REPO_ROOT was set in the setup cell; add to path for src/ imports
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

print("Project root:", REPO_ROOT)
print("src exists:", os.path.isdir("src"))

In [None]:
import json
import re
import importlib
import pandas as pd

# Force-reload src modules so edits on disk are picked up without runtime restart
import src.emr_models
import src.clinical_knowledge
import src.emr_narrative
import src.vignette_parser
import src.longitudinal_generator

importlib.reload(src.emr_models)
importlib.reload(src.clinical_knowledge)
importlib.reload(src.emr_narrative)
importlib.reload(src.vignette_parser)
importlib.reload(src.longitudinal_generator)

# Core V2 imports (now from freshly-reloaded modules)
from src.emr_models import (
    ClinicalVignette,
    ParsedVignette,
    LongitudinalEMR,
    EncounterRecord,
    ClinicalNote,
    ProblemListEntry,
    MedicationChange,
)

from src.emr_narrative import MedGemmaBackend
from src.vignette_parser import (
    VignetteParser,
    _extract_modality,
    _extract_body_region,
    _extract_contrast,
)
from src.longitudinal_generator import LongitudinalEMRGenerator

print("\u2713 All V2 modules imported successfully (force-reloaded)")

## 3. Load MedGemma 27B Model

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

MODEL_ID = "google/medgemma-27b-text-it"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print(f"Loading MedGemma 27B model ({MODEL_ID})...")
print(f"  Note: Requires ~16 GB VRAM with 4-bit quantization (A100/L4 recommended)")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, quantization_config=bnb_config, device_map="auto"
)
processor = AutoTokenizer.from_pretrained(MODEL_ID)

print(f"\u2713 MedGemma 27B loaded!")
if torch.cuda.is_available():
    print(f"  GPU memory used: {torch.cuda.memory_allocated() / 1e9:.1f} GB")

In [None]:
# Create MedGemma backend for narrative generation
backend = MedGemmaBackend(model=model, processor=processor)

# Create the longitudinal EMR generator
# MedGemma is used for ALL encounter notes (prior and current)
generator = LongitudinalEMRGenerator(
    narrative_backend=backend,
)

print("\u2713 Longitudinal EMR Generator initialized")
print(f"  Backend: {backend.name}")
print(f"  All encounters use MedGemma for clinical notes")

## 4. Load RUQ Pain Dataset

Load the structured dataset with 92 RUQ pain cases. Each row provides:
- **Setting**: Clinical setting (ED, Inpatient, Outpatient Clinic, ICU, etc.)
- **Actual Diagnosis**: Ground truth (provider does NOT know this)
- **Patient Hx**: Objective chart data (demographics, vitals, labs, PE, PMHx)
- **Protocol Ordered**: The imaging protocol the provider orders
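
Because the later cells index these four columns by name, a quick pre-flight check can catch a renamed or missing header before any parsing starts. The following is a minimal sketch: `REQUIRED_COLUMNS` and `missing_columns` are hypothetical names introduced here for illustration (they are not part of `src/`), and the header strings are the ones this notebook uses when reading the dataset.

```python
# Hypothetical helper for illustration; not part of the src/ package.
REQUIRED_COLUMNS = [
    "Setting",
    "Actual Diagnosis",
    "Patient Hx (Objective Chart Data)",
    "Protocol Ordered",
]


def missing_columns(columns):
    """Return the required column names that are absent from `columns`."""
    present = {str(c) for c in columns}
    return [c for c in REQUIRED_COLUMNS if c not in present]


# Example: a header list that lacks the protocol column
print(missing_columns(["Setting", "Actual Diagnosis", "Patient Hx (Objective Chart Data)"]))
```

After loading, `missing_columns(df.columns)` should return an empty list; anything else points to a header mismatch in the spreadsheet.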

In [None]:
# Load the RUQ pain dataset
dataset_path = os.path.join(REPO_ROOT, "data", "ruq_pain_dataset_2.xlsx")
df = pd.read_excel(dataset_path)

print(f"\u2713 Loaded {len(df)} cases from ruq_pain_dataset_2.xlsx")
print(f"  Columns used: {['Setting', 'Actual Diagnosis', 'Patient Hx (Objective Chart Data)', 'Protocol Ordered']}")
print(f"\n  Settings:   {df['Setting'].nunique()} unique")
print(f"  Diagnoses:  {df['Actual Diagnosis'].nunique()} unique")
print(f"  Protocols:  {df['Protocol Ordered'].nunique()} unique")
print(f"\nSample rows:")
df[['Setting', 'Actual Diagnosis', 'Patient Hx (Objective Chart Data)', 'Protocol Ordered']].head(5)

## 5. Parse Patient Hx into Structured Data

Regex-based parsers extract demographics, vitals, labs, PE findings, and PMHx from
the "Patient Hx (Objective Chart Data)" column, then build `ParsedVignette` objects
that feed directly into the EMR generator (bypassing MedGemma-based vignette extraction).
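
To illustrate the kind of extraction involved, here is a small, self-contained sketch run on a made-up history string. The sample text and the helper name `extract_value` are assumptions for illustration, not rows from the dataset or functions from `src/`. It shows two details that matter for this text format: a `\b` word boundary keeps a single-letter label such as `T` from matching the tail of `AST` or `ALT`, and stripping a trailing period keeps a sentence-final value such as `1.1.` parseable.

```python
import re


def extract_value(label_pattern, text):
    """Return the first number that follows `label_pattern`, or None if absent."""
    m = re.search(label_pattern + r'\s+(\d[\d.]*)', text)
    if not m:
        return None
    # A sentence-ending period is captured by [\d.]*, so drop it before float()
    return float(m.group(1).rstrip('.'))


# Made-up example text in the "45F. ..." style described above
sample = ("45F. AST 120, ALT 95, T.bili 1.1. T 101.2F, HR 104, BP 118/72. "
          "PE: RUQ tenderness. PMHx: HTN, T2DM.")

temp_naive = extract_value(r'T', sample)       # picks up the "T" inside "AST 120"
temp_bounded = extract_value(r'\bT', sample)   # matches only the standalone "T"
tbili = extract_value(r'T\.?\s*bili', sample)  # "1.1." becomes 1.1

print(temp_naive, temp_bounded, tbili)
```

The unbounded pattern returns the AST value as a temperature whenever labs precede vitals in the text, which is why short labels are safest with an explicit boundary.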

In [None]:
# ── Patient Hx Parsing Functions ──────────────────────────────────────────────

def parse_demographics(hx_text: str):
    """Extract age and sex from text like '45F.' or '52M.'"""
    m = re.match(r'(\d+)\s*([MF])', hx_text.strip())
    if m:
        age = int(m.group(1))
        sex = "Female" if m.group(2) == "F" else "Male"
        return age, sex
    return None, None


def parse_vitals(hx_text: str) -> dict:
    """Extract vitals: T, HR, BP, RR, SpO2."""
    vitals = {}
    # \b prevents matching the trailing "T" of "AST"/"ALT" when labs precede vitals
    m = re.search(r'\bT\s+([\d.]+)\s*(?:°F|F)?', hx_text)
    if m:
        vitals['temperature_f'] = float(m.group(1).rstrip('.'))
    m = re.search(r'HR\s+(\d+)', hx_text)
    if m:
        vitals['heart_rate'] = int(m.group(1))
    m = re.search(r'BP\s+(\d+)/(\d+)', hx_text)
    if m:
        vitals['blood_pressure_systolic'] = int(m.group(1))
        vitals['blood_pressure_diastolic'] = int(m.group(2))
    m = re.search(r'RR\s+(\d+)', hx_text)
    if m:
        vitals['respiratory_rate'] = int(m.group(1))
    m = re.search(r'SpO2\s+([\d.]+)', hx_text)
    if m:
        vitals['oxygen_saturation'] = float(m.group(1).rstrip('.'))
    return vitals


def parse_labs(hx_text: str) -> dict:
    """Extract lab values from Patient Hx text.

    Note: Values like 'T.bili 1.1.' have a trailing sentence period that the
    regex [\\d.]+ captures. We rstrip('.') on all captured values before float().
    """
    labs = {}

    # WBC: handle "14,000", "14K", "14,200", "14.2K", "14.2"
    m = re.search(r'WBC\s+([\d.,]+)\s*K?', hx_text, re.IGNORECASE)
    if m:
        val_str = m.group(1).replace(',', '').rstrip('.')
        val = float(val_str)
        if val > 100:       # raw count like 14000
            val = val / 1000
        labs['wbc'] = val

    # Standard numeric labs — rstrip('.') handles trailing sentence periods
    lab_patterns = {
        'ast':              r'AST\s+([\d.]+)',
        'alt':              r'ALT\s+([\d.]+)',
        'alp':              r'(?:ALP|Alk\s*Phos)\s+([\d.]+)',
        'bilirubin_total':  r'T\.?\s*bili\s+([\d.]+)',
        'bilirubin_direct': r'D\.?\s*bili\s+([\d.]+)',
        'lipase':           r'[Ll]ipase\s+([\d.]+)',
        'creatinine':       r'Cr\s+([\d.]+)',
        'glucose':          r'[Gg]lucose\s+([\d.]+)',
        'hemoglobin':       r'(?:Hgb|Hb|Hemoglobin)\s+([\d.]+)',
        'platelets':        r'(?:Plt|Platelets)\s+([\d.]+)',
        'inr':              r'INR\s+([\d.]+)',
        'lactate':          r'[Ll]actate\s+([\d.]+)',
        'albumin':          r'[Aa]lbumin\s+([\d.]+)',
        'troponin_i':       r'[Tt]roponin\s+([\d.]+)',
    }
    for key, pattern in lab_patterns.items():
        m = re.search(pattern, hx_text)
        if m:
            labs[key] = float(m.group(1).rstrip('.'))

    # CRP (can be very high)
    m = re.search(r'CRP\s+([\d.]+)', hx_text)
    if m:
        labs['crp'] = float(m.group(1).rstrip('.'))

    # GFR (may have > prefix)
    m = re.search(r'(?:GFR|eGFR)\s+>?\s*([\d.]+)', hx_text, re.IGNORECASE)
    if m:
        labs['gfr'] = float(m.group(1).rstrip('.'))

    # Amylase
    m = re.search(r'[Aa]mylase\s+([\d.]+)', hx_text)
    if m:
        labs['amylase'] = float(m.group(1).rstrip('.'))

    return labs


def parse_pe_findings(hx_text: str) -> str:
    """Extract PE findings section."""
    m = re.search(r'PE:\s*(.*?)(?:PMHx:|$)', hx_text, re.IGNORECASE | re.DOTALL)
    if m:
        return m.group(1).strip().rstrip('.')
    return ""


def parse_pmhx(hx_text: str) -> list:
    """Extract past medical history conditions."""
    m = re.search(r'PMHx:\s*(.*?)(?:\.\s*$|$)', hx_text, re.IGNORECASE | re.DOTALL)
    if m:
        text = m.group(1).strip().rstrip('.')
        conditions = [c.strip() for c in text.split(',') if c.strip()]
        return conditions
    return []


def parse_surgical_history(hx_text: str, pmhx_conditions: list) -> list:
    """Extract surgical history from Patient Hx text and PMHx conditions.

    Looks for surgical procedures in the full text and PMHx list.
    Returns list of procedure strings (e.g., ["cholecystectomy (2018)", "appendectomy"]).
    """
    surgical_procedures = []
    combined = (hx_text + " " + " ".join(pmhx_conditions)).lower()

    # Surgical procedure patterns with optional year
    surgery_patterns = [
        (r'cholecystectomy\s*(?:\((\d{4})\))?', 'Cholecystectomy'),
        (r'appendectomy\s*(?:\((\d{4})\))?', 'Appendectomy'),
        (r'hysterectomy\s*(?:\((\d{4})\))?', 'Hysterectomy'),
        (r'c[\-\s]?section\s*(?:\((\d{4})\))?', 'Cesarean section'),
        (r'cabg\s*(?:\((\d{4})\))?', 'CABG'),
        (r'coronary\s+(?:artery\s+)?bypass\s*(?:\((\d{4})\))?', 'CABG'),
        (r'mastectomy\s*(?:\((\d{4})\))?', 'Mastectomy'),
        (r'(?:aaa|aortic\s+aneurysm)\s+repair\s*(?:\((\d{4})\))?', 'AAA repair'),
        (r's/p\s+stent\s*(?:\((\d{4})\))?', 'Coronary stent placement'),
        (r'stent\s+(?:placement|to)\s*(?:\((\d{4})\))?', 'Coronary stent placement'),
        (r'colectomy\s*(?:\((\d{4})\))?', 'Colectomy'),
        (r'(?:hip|knee)\s+replacement\s*(?:\((\d{4})\))?', None),  # special handling
        (r'gastric\s+(?:bypass|sleeve|band)\s*(?:\((\d{4})\))?', 'Bariatric surgery'),
        (r'ercp\s+with\s+(?:sphincterotomy|stent)\s*(?:\((\d{4})\))?', 'ERCP with sphincterotomy'),
        (r'prior\s+ercp\s*(?:\((\d{4})\))?', 'ERCP'),
        (r'splenectomy\s*(?:\((\d{4})\))?', 'Splenectomy'),
        (r'whipple\s*(?:\((\d{4})\))?', 'Whipple procedure'),
        (r'nephrectomy\s*(?:\((\d{4})\))?', 'Nephrectomy'),
        (r'liver\s+transplant\s*(?:\((\d{4})\))?', 'Liver transplant'),
    ]

    found_procedures = set()
    for pattern, proc_name in surgery_patterns:
        m = re.search(pattern, combined)
        if m:
            if proc_name is None:
                # Hip/knee replacement — capture the joint
                joint_match = re.search(r'(hip|knee)\s+replacement', combined)
                proc_name = f"{joint_match.group(1).title()} replacement" if joint_match else "Joint replacement"
            year = m.group(1) if m.lastindex and m.group(1) else None
            entry = f"{proc_name} ({year})" if year else proc_name
            if proc_name not in found_procedures:
                found_procedures.add(proc_name)
                surgical_procedures.append(entry)

    return surgical_procedures


def detect_contrast_allergy_flags(hx_text: str, pmhx_conditions: list) -> list:
    """Detect contrast allergy or related safety flags from Patient Hx.

    Returns list of safety flag strings for ParsedVignette.safety_flags_mentioned.
    """
    flags = []
    combined = (hx_text + " " + " ".join(pmhx_conditions)).lower()

    allergy_patterns = [
        (r'(?:severe\s+)?contrast\s+allergy', 'contrast allergy'),
        (r'iodinated\s+contrast\s+(?:allergy|reaction)', 'iodinated contrast allergy'),
        (r'gadolinium\s+(?:allergy|reaction|sensitivity)', 'gadolinium allergy'),
        (r'shellfish\s+allergy', 'shellfish allergy'),
        (r'(?:previous|prior|h/o)\s+contrast\s+reaction', 'prior contrast reaction'),
        (r'contrast[\-\s]?induced\s+nephropathy', 'contrast-induced nephropathy risk'),
        (r'nsf\s+risk', 'gadolinium NSF risk'),
    ]

    for pattern, flag_name in allergy_patterns:
        if re.search(pattern, combined):
            flags.append(flag_name)

    return flags


def derive_chief_complaint(pe_findings: str, hx_text: str) -> str:
    """Derive chief complaint from PE findings or full text."""
    combined = (pe_findings + " " + hx_text).lower()
    if 'ruq' in combined or 'right upper quadrant' in combined:
        return "Right upper quadrant pain"
    if 'epigastric' in combined:
        return "Epigastric pain"
    if 'jaundice' in combined:
        return "Jaundice with abdominal pain"
    return "Abdominal pain"


def derive_acuity(vitals: dict, labs: dict, setting: str) -> str:
    """Derive clinical acuity from vitals, labs, and setting."""
    s = setting.lower()
    temp = vitals.get('temperature_f', 98.6)
    hr = vitals.get('heart_rate', 80)
    wbc = labs.get('wbc', 8.0)
    sbp = vitals.get('blood_pressure_systolic', 120)

    if 'icu' in s:  # substring check also covers SICU and MICU
        return 'septic'
    if sbp < 90 or (temp >= 101.0 and wbc > 15 and hr > 100):
        return 'shock' if sbp < 90 else 'septic'
    if temp >= 100.4 or wbc > 15:
        return 'febrile'
    if hr > 90 or wbc > 11:
        return 'mild'
    return 'stable'


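# Surgical procedure terms to remove from PMHx (avoid duplicate representation).
# Kept in step with the procedures that parse_surgical_history() recognizes.
SURGICAL_KEYWORDS = {
    'cholecystectomy', 'appendectomy', 'hysterectomy', 'c-section', 'c section',
    'cabg', 'coronary bypass', 'coronary artery bypass',
    'mastectomy', 'colectomy', 'splenectomy', 'nephrectomy',
    'aaa repair', 'aortic aneurysm repair', 'hip replacement', 'knee replacement',
    'gastric bypass', 'gastric sleeve', 'gastric band', 'whipple', 'liver transplant',
}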


def parse_dataset_row(row, row_index: int):
    """
    Parse one row from the Excel dataset into a ParsedVignette + metadata.

    Returns:
        (ParsedVignette, metadata_dict)
    """
    setting = str(row['Setting'])
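    # A blank cell loads as NaN; record it with the sentinel that the validation
    # summary counts instead of the literal string 'nan'
    dx_raw = row['Actual Diagnosis']
    actual_diagnosis = '(Not documented)' if pd.isna(dx_raw) else str(dx_raw)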
    patient_hx = str(row['Patient Hx (Objective Chart Data)'])
    protocol_ordered = str(row['Protocol Ordered'])

    age, sex = parse_demographics(patient_hx)
    vitals = parse_vitals(patient_hx)
    labs = parse_labs(patient_hx)
    pe_findings = parse_pe_findings(patient_hx)
    pmhx = parse_pmhx(patient_hx)
    chief_complaint = derive_chief_complaint(pe_findings, patient_hx)
    acuity = derive_acuity(vitals, labs, setting)

    # Extract surgical history and contrast allergy flags
    surgical_history = parse_surgical_history(patient_hx, pmhx)
    safety_flags = detect_contrast_allergy_flags(patient_hx, pmhx)

    # Remove surgical procedures from PMHx conditions to avoid duplicate
    # representation (they'll appear in surgical_history instead)
    filtered_pmhx = []
    for cond in pmhx:
        cond_lower = cond.lower().strip()
        is_surgical = any(kw in cond_lower for kw in SURGICAL_KEYWORDS)
        # Also check for "s/p" prefix which often indicates prior surgery
        if cond_lower.startswith('s/p '):
            is_surgical = True
        if not is_surgical:
            filtered_pmhx.append(cond)

    parsed = ParsedVignette(
        age=age,
        sex=sex,
        diagnosis=actual_diagnosis,
        chief_complaint=chief_complaint,
        history_conditions=filtered_pmhx,
        surgical_history=surgical_history,
        vitals_mentioned=vitals,
        labs_mentioned=labs,
        exam_findings=pe_findings,
        ordered_study=protocol_ordered,
        imaging_modality=_extract_modality(protocol_ordered),
        imaging_body_region=_extract_body_region(protocol_ordered),
        imaging_contrast=_extract_contrast(protocol_ordered),
        clinical_setting=setting,
        acuity=acuity,
        extraction_confidence=1.0,
    )

    # Detect special populations
    if 'pregnan' in patient_hx.lower():
        parsed.special_populations.append('pregnant')
    if 'immunocompromised' in patient_hx.lower() or 'transplant' in patient_hx.lower():
        parsed.special_populations.append('immunocompromised')

    # Set safety flags from contrast allergy detection
    parsed.safety_flags_mentioned = safety_flags

    metadata = {
        'case_id': f'RUQ-{row_index + 1:03d}',
        'setting': setting,
        'actual_diagnosis': actual_diagnosis,
        'protocol_ordered': protocol_ordered,
        'raw_patient_hx': patient_hx,
        'pe_findings': pe_findings,
    }
    return parsed, metadata


print("\u2713 Patient Hx parsing functions defined")
print(f"  Includes: surgical history parser, contrast allergy detector")


In [None]:
# Parse all rows into ParsedVignettes
parsed_cases = []
parse_errors = []

for idx, row in df.iterrows():
    try:
        parsed, metadata = parse_dataset_row(row, idx)
        parsed_cases.append((parsed, metadata))
    except Exception as e:
        parse_errors.append((idx, str(e)))

print(f"\u2713 Parsed {len(parsed_cases)}/{len(df)} cases successfully")
if parse_errors:
    print(f"\u2717 Parse errors: {len(parse_errors)}")
    for idx, err in parse_errors:
        print(f"  Row {idx}: {err}")

# Validation summary
ages_found = sum(1 for p, m in parsed_cases if p.age is not None)
vitals_found = sum(1 for p, m in parsed_cases if len(p.vitals_mentioned) > 0)
labs_found = sum(1 for p, m in parsed_cases if len(p.labs_mentioned) > 0)
pmhx_found = sum(1 for p, m in parsed_cases if len(p.history_conditions) > 0)
dx_documented = sum(1 for p, m in parsed_cases if m['actual_diagnosis'] != '(Not documented)')
surg_found = sum(1 for p, m in parsed_cases if len(p.surgical_history) > 0)
safety_found = sum(1 for p, m in parsed_cases if len(p.safety_flags_mentioned) > 0)

print(f"\nData completeness:")
print(f"  Age/Sex extracted:       {ages_found}/{len(parsed_cases)}")
print(f"  Vitals extracted:        {vitals_found}/{len(parsed_cases)}")
print(f"  Labs extracted:          {labs_found}/{len(parsed_cases)}")
print(f"  PMHx extracted:          {pmhx_found}/{len(parsed_cases)}")
print(f"  Diagnosis documented:    {dx_documented}/{len(parsed_cases)}")
print(f"  Surgical history found:  {surg_found}/{len(parsed_cases)}")
print(f"  Safety flags detected:   {safety_found}/{len(parsed_cases)}")

# Show surgical history details
if surg_found > 0:
    print(f"\nSurgical procedures found:")
    for p, m in parsed_cases:
        if p.surgical_history:
            print(f"  {m['case_id']}: {', '.join(p.surgical_history)}")

if safety_found > 0:
    print(f"\nSafety flags detected:")
    for p, m in parsed_cases:
        if p.safety_flags_mentioned:
            print(f"  {m['case_id']}: {', '.join(p.safety_flags_mentioned)}")

# Show summary table
summary_rows = []
for p, m in parsed_cases[:10]:
    summary_rows.append({
        'Case ID': m['case_id'],
        'Setting': m['setting'][:20],
        'Age': p.age,
        'Sex': (p.sex or '?')[0],
        'Vitals': len(p.vitals_mentioned),
        'Labs': len(p.labs_mentioned),
        'PMHx': len(p.history_conditions),
        'Surg': len(p.surgical_history),
        'Acuity': p.acuity,
        'Protocol': m['protocol_ordered'][:25],
        'Diagnosis': m['actual_diagnosis'][:30],
    })
print(f"\nFirst 10 cases:")
pd.DataFrame(summary_rows)


## 6. Generate a Single Longitudinal EMR (Demo)

Generate one EMR to verify the pipeline. The provider does **not** know the actual
diagnosis — notes reflect diagnostic uncertainty. The EMR ends with the imaging order.

In [None]:
# Pick the first case for demo
parsed, metadata = parsed_cases[0]

print(f"Case: {metadata['case_id']}")
print(f"Setting: {metadata['setting']}")
print(f"Actual Diagnosis (ground truth): {metadata['actual_diagnosis']}")
print(f"Protocol Ordered: {metadata['protocol_ordered']}")
print(f"Patient Hx: {metadata['raw_patient_hx'][:120]}...")
print(f"\nParsed: age={parsed.age}, sex={parsed.sex}, acuity={parsed.acuity}")
print(f"  Vitals: {parsed.vitals_mentioned}")
print(f"  Labs: {parsed.labs_mentioned}")
print(f"  PMHx: {parsed.history_conditions}")
print(f"  Chief complaint: {parsed.chief_complaint}")
print(f"  Imaging: {parsed.imaging_modality} {parsed.imaging_body_region} {parsed.imaging_contrast}")

# Create a ClinicalVignette for metadata tracking
vignette = ClinicalVignette(
    vignette_text=metadata['raw_patient_hx'],
    case_id=metadata['case_id'],
    case_name=f"{metadata['setting']} - {metadata['actual_diagnosis']}",
    ground_truth_diagnosis=metadata['actual_diagnosis'],
    ground_truth_appropriate_study=metadata['protocol_ordered'],
    seed=42,
)

print(f"\n{'='*70}")
print("Generating longitudinal EMR (diagnosis masked in narratives)...\n")

emr = generator.generate_from_parsed(
    parsed,
    vignette=vignette,
    seed=42,
    mask_diagnosis_in_narratives=True,
)

print(f"\n\u2713 Longitudinal EMR generated!")
print(f"  Patient: {emr.patient.first_name} {emr.patient.last_name}")
print(f"  Age: {emr.patient.age} | Sex: {emr.patient.sex}")
print(f"  Total encounters: {len(emr.encounter_history)}")
print(f"  Problem list entries: {len(emr.problem_list)}")
print(f"  Medication changes: {len(emr.medication_history)}")

In [None]:
# Display the encounter timeline
print("=" * 80)
print("ENCOUNTER TIMELINE")
print("=" * 80)

for i, record in enumerate(emr.encounter_history):
    enc = record.encounter
    is_current = (i == len(emr.encounter_history) - 1)
    marker = ">>> CURRENT <<<" if is_current else ""

    print(f"\n{'='*60}")
    print(f"Encounter {i+1}: {enc.encounter_type} | {enc.admission_datetime[:10]}  {marker}")
    print(f"  Facility: {enc.facility}")
    print(f"  Department: {enc.department}")
    print(f"  Provider: {enc.attending_provider}")
    print(f"  Chief Complaint: {enc.chief_complaint}")

    if record.diagnoses:
        print(f"  Diagnoses: {', '.join(record.diagnoses)}")

    if record.vital_signs:
        v = record.vital_signs[0]
        parts = []
        if v.temperature_f: parts.append(f"T {v.temperature_f}\u00b0F")
        if v.heart_rate: parts.append(f"HR {v.heart_rate}")
        if v.blood_pressure_systolic: parts.append(f"BP {v.blood_pressure_systolic}/{v.blood_pressure_diastolic}")
        if parts:
            print(f"  Vitals: {' | '.join(parts)}")

    if record.lab_results:
        total_labs = sum(len(p.results) for p in record.lab_results)
        abnormal = [r for p in record.lab_results for r in p.results if r.flag]
        print(f"  Labs: {total_labs} results ({len(abnormal)} abnormal)")

    if record.imaging_orders:
        for order in record.imaging_orders:
            print(f"  Imaging Order: {order.modality} {order.body_region} {order.contrast} [Status: {order.status}]")

    if record.clinical_notes:
        note_types = [n.note_type for n in record.clinical_notes]
        print(f"  Notes: {', '.join(note_types)}")

In [None]:
# Problem List and Medications
print("=" * 70)
print("PROBLEM LIST")
print("=" * 70)
print(f"{'Condition':<35} {'ICD-10':<10} {'Status':<10} {'Date Added':<12}")
print("-" * 70)
for entry in emr.problem_list:
    icd = entry.icd10 or ""
    date = entry.date_added[:10] if entry.date_added else ""
    print(f"{entry.condition[:33]:<35} {icd:<10} {entry.status:<10} {date:<12}")

print(f"\n{'='*70}")
print("CURRENT MEDICATIONS")
print("=" * 70)
print("\nHome Medications:")
for m in emr.current_medications.home_medications:
    print(f"  - {m.name} {m.dose} {m.route} {m.frequency} ({m.indication or ''})")
if emr.current_medications.inpatient_medications:
    print("\nInpatient Medications:")
    for m in emr.current_medications.inpatient_medications:
        print(f"  - {m.name} {m.dose} {m.route} {m.frequency} ({m.indication or ''})")

In [None]:
current = emr.current_encounter

# Demographics
print("=" * 70)
print("PATIENT DEMOGRAPHICS")
print("=" * 70)
p = emr.patient
print(f"Name: {p.first_name} {p.last_name}")
print(f"MRN: {p.mrn}")
print(f"DOB: {p.date_of_birth} (Age: {p.age})")
print(f"Sex: {p.sex}")
print(f"Insurance: {p.insurance}")

# Vital Signs
print("\n" + "=" * 70)
print("VITAL SIGNS (Current Encounter)")
print("=" * 70)
for v in current.vital_signs:
    print(f"\n[{v.source}] {v.timestamp}")
    print(f"  Temp: {v.temperature_f}\u00b0F | HR: {v.heart_rate} | BP: {v.blood_pressure_systolic}/{v.blood_pressure_diastolic}")
    print(f"  RR: {v.respiratory_rate} | SpO2: {v.oxygen_saturation}% | Pain: {v.pain_scale}/10")

In [None]:
# Lab Results
print("=" * 70)
print("LABORATORY RESULTS (Current Encounter)")
print("=" * 70)
for panel in current.lab_results:
    print(f"\n--- {panel.panel_name} ({panel.timestamp}) ---")
    print(f"{'Test':<22} {'Value':>8} {'Unit':<15} {'Ref Range':<15} {'Flag':<8}")
    print("-" * 70)
    for r in panel.results:
        ref_str = f"{r.reference_low}-{r.reference_high}" if r.reference_low is not None else ""
        flag_str = r.flag or ""
        print(f"{r.test_name:<22} {r.value:>8.1f} {r.unit:<15} {ref_str:<15} {flag_str:<8}")

In [None]:
# Clinical Notes
print("=" * 70)
print("CLINICAL NOTES (Current Encounter)")
print("=" * 70)

for note in current.clinical_notes:
    print(f"\n{'='*60}")
    print(f"{note.note_type.upper()} | {note.timestamp} | {note.author}")
    print(f"{'='*60}")
    print(note.note_text)

In [None]:
# Imaging Orders (Current Encounter) — should be the final action
print("=" * 70)
print("IMAGING ORDERS (Current Encounter)")
print("=" * 70)
if current.imaging_orders:
    for order in current.imaging_orders:
        print(f"\nOrder: {order.modality} {order.body_region} {order.contrast}")
        print(f"  ID: {order.order_id} | Urgency: {order.urgency} | Status: {order.status}")
        print(f"  Indication: {order.indication}")
        print(f"  Ordering Provider: {order.ordering_provider}")
    print(f"\n  Note: Imaging ordered but not yet performed.")
    print(f"  This will be evaluated for appropriateness in the next notebook.")
else:
    print("  No imaging orders placed.")

# Patient History
print("\n" + "=" * 70)
print("PATIENT HISTORY")
print("=" * 70)

print("\nSurgical History:")
for s in emr.surgical_history:
    print(f"  - {s.procedure} ({s.year})")
if not emr.surgical_history:
    print("  None")

print("\nAllergies:")
for a in emr.allergies:
    print(f"  - {a.allergen} ({a.allergy_type}) - {a.reaction}, {a.severity}")
if not emr.allergies:
    print("  NKDA")

print(f"\nSocial History:")
sh = emr.social_history
print(f"  Smoking: {sh.smoking_status}")
print(f"  Alcohol: {sh.alcohol_use}")
print(f"  Occupation: {sh.occupation or 'Not specified'}")

print(f"\nFamily History:")
for fh in emr.family_history:
    print(f"  - {fh}")

## 7. Verify Diagnosis Masking & Output Structure

Confirm that provider-facing notes do NOT reveal the actual diagnosis, and that the
EMR ends with the imaging order (no results).
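
One simple way to check the masking claim is to search each note for the ground-truth diagnosis phrase. The sketch below is an assumption-laden illustration rather than the project's own validator: `find_diagnosis_leaks` is a hypothetical helper, the sample notes are made up, and an exact-phrase search will not catch paraphrases or synonyms of the diagnosis.

```python
import re


def find_diagnosis_leaks(diagnosis, note_texts):
    """Return indices of notes whose text contains the diagnosis phrase.

    Matching is case-insensitive and tolerates extra whitespace or line
    breaks between the words of the diagnosis.
    """
    words = diagnosis.split()
    if not words:
        return []
    pattern = re.compile(r'\s+'.join(re.escape(w) for w in words), re.IGNORECASE)
    return [i for i, text in enumerate(note_texts) if pattern.search(text or "")]


# Made-up notes for illustration
notes = [
    "RUQ pain with fever; differential remains broad. Imaging ordered.",
    "Assessment: findings consistent with Acute\ncholecystitis.",
]
print(find_diagnosis_leaks("acute cholecystitis", notes))
```

With the objects used in this notebook, the call would look like `find_diagnosis_leaks(metadata['actual_diagnosis'], [n.note_text for n in current.clinical_notes])`; an empty list means no note contains the phrase verbatim.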

In [None]:
# Verify EMR output structure
print(f"Actual diagnosis (ground truth): {metadata['actual_diagnosis']}")
print(f"Protocol ordered: {metadata['protocol_ordered']}")
print()

# Verify imaging order is present and status is "Ordered" (no results)
if current.imaging_orders:
    order = current.imaging_orders[0]
    if str(order.status).lower() == "ordered":
        print(f"  ✓ Imaging order present — {order.modality} {order.body_region} [{order.status}]")
    else:
        print(f"  ⚠ WARNING: Imaging order status is '{order.status}' (expected 'Ordered')")
else:
    print(f"  ⚠ WARNING: No imaging order found!")

if current.imaging_reports:
    print(f"  ⚠ WARNING: Imaging reports present (should be empty — order only)")
else:
    print(f"  ✓ No imaging results (order only — awaiting evaluation)")

# Verify current encounter diagnoses use working description
print(f"\n  Current encounter diagnoses: {current.diagnoses}")


## 8. Batch Generation — All 92 Cases (with Checkpoint/Resume)

Generate longitudinal EMRs for all cases in the dataset. Each EMR is **saved to disk
immediately** after generation, so if the runtime disconnects or the cell is interrupted,
all completed cases are preserved. Re-running this cell **skips already-generated cases**
and picks up where it left off.

- Output directory: `data/generated_emrs_27b/`
- Each case saved as `{case_id}.json`
- Summary CSV updated after every case
- Previous 4B-generated EMRs preserved in `data/generated_emrs/`
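
Resuming by filename relies on every `{case_id}.json` on disk being complete. The sketch below shows two safeguards that support this, under stated assumptions: `save_json_atomically` and `completed_case_ids` are hypothetical helpers (not part of `src/`), the payload is assumed to be a JSON string, and the rename is atomic on local POSIX filesystems but may behave differently on a mounted Drive folder.

```python
import json
import os
import tempfile


def save_json_atomically(path, payload):
    """Write `payload` (a JSON string) to a temp file, then rename it to `path`."""
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_path, path)  # readers see the old file or the new one, never half of it
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def completed_case_ids(directory, prefix="RUQ-"):
    """Return IDs of cases whose JSON file exists and parses cleanly."""
    done = set()
    for fname in os.listdir(directory):
        if not (fname.startswith(prefix) and fname.endswith(".json")):
            continue
        try:
            with open(os.path.join(directory, fname), encoding="utf-8") as f:
                json.load(f)
        except (OSError, ValueError):
            continue  # truncated or unreadable file: treat the case as not done
        done.add(fname[:-len(".json")])
    return done


# Demonstration in a throwaway directory
with tempfile.TemporaryDirectory() as demo_dir:
    save_json_atomically(os.path.join(demo_dir, "RUQ-001.json"), json.dumps({"ok": True}))
    with open(os.path.join(demo_dir, "RUQ-002.json"), "w", encoding="utf-8") as f:
        f.write('{"truncated": ')  # simulates an interrupted write
    print(sorted(completed_case_ids(demo_dir)))
```

Validating files on resume costs one JSON parse per completed case, which is small next to the time needed to regenerate a case with the 27B model.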

In [None]:
import time

output_dir = os.path.join(REPO_ROOT, "data", "generated_emrs_27b")
os.makedirs(output_dir, exist_ok=True)


def _update_summary_csv(out_dir, meta, emr_r):
    """Append or update the summary CSV with one new case."""
    csv_path = os.path.join(out_dir, "generation_summary.csv")

    new_row = {
        'case_id': meta['case_id'],
        'setting': meta['setting'],
        'actual_diagnosis': meta['actual_diagnosis'],
        'protocol_ordered': meta['protocol_ordered'],
        'patient_name': f"{emr_r.patient.first_name} {emr_r.patient.last_name}",
        'age': emr_r.patient.age,
        'sex': emr_r.patient.sex,
        'encounters': len(emr_r.encounter_history),
        'problems': len(emr_r.problem_list),
        'med_changes': len(emr_r.medication_history),
        'emr_file': f"{meta['case_id']}.json",
    }

    if os.path.exists(csv_path):
        existing = pd.read_csv(csv_path)
        # Remove old row for this case_id if re-generating
        existing = existing[existing['case_id'] != meta['case_id']]
        updated = pd.concat([existing, pd.DataFrame([new_row])], ignore_index=True)
    else:
        updated = pd.DataFrame([new_row])

    updated.to_csv(csv_path, index=False)


# ── Checkpoint/Resume: detect already-completed cases ────────────────────────
completed_ids = set()
for fname in os.listdir(output_dir):
    if fname.startswith("RUQ-") and fname.endswith(".json"):
        completed_ids.add(fname.replace(".json", ""))

if completed_ids:
    print(f"▶ Resuming: found {len(completed_ids)} already-generated EMRs on disk")
    print(f"  Completed: {sorted(completed_ids)[:5]}{'...' if len(completed_ids) > 5 else ''}")
else:
    print("▶ Starting fresh batch generation")

print(f"  Output: {output_dir}")

# ── Batch generation loop ────────────────────────────────────────────────────
NUM_CASES = len(parsed_cases)
results = []       # Accumulates (meta, emr) for in-memory summary
errors = []
skipped = 0
start_time = time.time()

for idx, (parsed_v, meta) in enumerate(parsed_cases[:NUM_CASES]):
    case_id = meta['case_id']

    # Skip if already generated
    if case_id in completed_ids:
        skipped += 1
        continue

    print(f"\n{'='*70}")
    print(f"Case {idx+1}/{NUM_CASES}: {case_id} | {meta['setting']} | {meta['actual_diagnosis'][:40]}")
    print(f"{'='*70}")

    vig = ClinicalVignette(
        vignette_text=meta['raw_patient_hx'],
        case_id=case_id,
        case_name=f"{meta['setting']} - {meta['actual_diagnosis']}",
        ground_truth_diagnosis=meta['actual_diagnosis'],
        ground_truth_appropriate_study=meta['protocol_ordered'],
        seed=42 + idx,
    )

    try:
        emr_result = generator.generate_from_parsed(
            parsed_v,
            vignette=vig,
            seed=42 + idx,
            mask_diagnosis_in_narratives=True,
        )

        # ── Save immediately to disk ─────────────────────────────────────
        emr_path = os.path.join(output_dir, f"{case_id}.json")
        with open(emr_path, 'w') as f:
            f.write(emr_result.to_json())

        results.append((meta, emr_result))
        completed_ids.add(case_id)

        print(f"  ✓ {emr_result.patient.first_name} {emr_result.patient.last_name}, "
              f"{emr_result.patient.age}yo {emr_result.patient.sex}")
        print(f"    Encounters: {len(emr_result.encounter_history)} | "
              f"Problems: {len(emr_result.problem_list)} | "
              f"Protocol: {meta['protocol_ordered'][:30]}")
        print(f"    Saved → {case_id}.json")

        # ── Update summary CSV after every case (incremental checkpoint) ─
        _update_summary_csv(output_dir, meta, emr_result)

    except Exception as e:
        errors.append((case_id, str(e)))
        print(f"  ✗ ERROR: {e}")
        import traceback
        traceback.print_exc()

elapsed = time.time() - start_time
print(f"\n{'='*70}")
print("BATCH COMPLETE")
print(f"{'='*70}")
print(f"  Generated this run:  {len(results)}")
print(f"  Skipped (resumed):   {skipped}")
print(f"  Errors:              {len(errors)}")
print(f"  Total on disk:       {len(completed_ids)}/{NUM_CASES}")
print(f"  Time this run:       {elapsed/60:.1f} minutes")
if errors:
    print("\nErrors:")
    for cid, err in errors:
        print(f"  {cid}: {err}")
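
The batch loop above writes each EMR with a plain `open(..., 'w')`, and the resume check counts any `RUQ-*.json` file as finished, so a runtime disconnect in the middle of a write could leave a truncated file that is later skipped. One possible safeguard is to write to a temporary file and rename it into place. The sketch below assumes the output folder behaves like a local POSIX filesystem (a mounted Drive folder may not give the same guarantee). `write_text_atomic` is a hypothetical helper, not part of the project code.

```python
import os
import tempfile


def write_text_atomic(path, text):
    """Write text to path so a reader never sees a half-written file.

    Hypothetical helper for illustration; not part of the project code.
    """
    directory = os.path.dirname(os.path.abspath(path))
    # Create the temp file next to the target so the rename stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp-", suffix=".part")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(text)
            f.flush()
            os.fsync(f.fileno())  # push the bytes to disk before renaming
        os.replace(tmp_path, path)  # replaces any existing file in one step
    except BaseException:
        # Clean up the partial temp file, then let the caller see the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

In the loop, `write_text_atomic(emr_path, emr_result.to_json())` would stand in for the `with open(...)` block. The temporary name does not match the `RUQ-*.json` pattern, so an interrupted run leaves nothing the resume check would count as complete.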

In [None]:
# Verify saved EMRs on disk (works even after runtime restart)
output_dir = os.path.join(REPO_ROOT, "data", "generated_emrs_27b")

json_files = sorted(f for f in os.listdir(output_dir) if f.endswith(".json"))
csv_path = os.path.join(output_dir, "generation_summary.csv")

print(f"✓ {len(json_files)} EMR JSON files in {output_dir}/")
if os.path.exists(csv_path):
    summary_df = pd.read_csv(csv_path)
    print(f"✓ Summary CSV: {len(summary_df)} rows")
else:
    print("⚠ No summary CSV found — re-run batch cell to regenerate")

print("\nFiles (first 10):")
for f in json_files[:10]:
    size_kb = os.path.getsize(os.path.join(output_dir, f)) / 1024
    print(f"  {f}  ({size_kb:.0f} KB)")
if len(json_files) > 10:
    print(f"  ... and {len(json_files) - 10} more")

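# Check for any missing cases (needs parsed_cases, so this part is skipped after a restart)
if 'parsed_cases' in globals():
    # Expected IDs come from the same metadata the batch loop uses for file names
    all_case_ids = {meta['case_id'] for _, meta in parsed_cases}
    saved_case_ids = {os.path.splitext(f)[0] for f in json_files}
    missing = sorted(all_case_ids - saved_case_ids)
    if missing:
        print(f"\n⚠ Missing cases ({len(missing)}): {missing}")
        print("  Re-run the batch cell above to generate these.")
    else:
        print(f"\n✓ All {len(all_case_ids)} cases generated!")
else:
    print("\n⚠ parsed_cases is not loaded; re-run the parsing cell to check for missing cases")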

## 9. Generation Summary

In [None]:
# Summary table — reads from summary CSV so it works after resume/restart
output_dir = os.path.join(REPO_ROOT, "data", "generated_emrs_27b")
csv_path = os.path.join(output_dir, "generation_summary.csv")

if os.path.exists(csv_path):
    summary_df = pd.read_csv(csv_path).sort_values('case_id')

    print(f"{'='*100}")
    print("GENERATION SUMMARY")
    print(f"{'='*100}")
    print(f"{'ID':<10} {'Patient':<20} {'Age':>4} {'Sex':<4} {'Setting':<18} {'Enc':>4} {'Protocol':<25}")
    print("-" * 100)
    for _, row in summary_df.iterrows():
        print(f"{row['case_id']:<10} "
              f"{str(row['patient_name'])[:18]:<20} "
              f"{int(row['age']):>4} "
              f"{str(row['sex'])[0]:<4} "
              f"{str(row['setting'])[:16]:<18} "
              f"{int(row['encounters']):>4} "
              f"{str(row['protocol_ordered'])[:23]:<25}")

    print(f"\n{'='*100}")
    print(f"Total cases: {len(summary_df)}")
    print(f"Unique settings: {summary_df['setting'].nunique()}")
    print(f"Unique diagnoses: {summary_df['actual_diagnosis'].nunique()}")
    print(f"Unique protocols: {summary_df['protocol_ordered'].nunique()}")
else:
    print("No summary CSV found. Run the batch generation cell first.")

## Architecture Summary

### Pipeline
```
ruq_pain_dataset_2.xlsx
  -> parse_dataset_row()          [Regex: demographics, vitals, labs, PE, PMHx]
  -> ParsedVignette               [Structured data, bypasses MedGemma extraction]
  -> generate_from_parsed()       [mask_diagnosis_in_narratives=True]
     -> _build_timeline()         [Condition progression from PMHx]
     -> _generate_prior_encounter()  [SOAP notes via MedGemma]
     -> _generate_current_encounter()
        -> generate_clinical_data()  [Uses REAL diagnosis for consistency]
        -> generate_hpi/pe/ap()      [Uses MASKED diagnosis for uncertainty]
        -> ImagingOrder              [From dataset, status="Ordered", no results]
  -> LongitudinalEMR              [Multi-year chart ending at imaging order]
     -> JSON export               [For downstream protocol evaluation]
```
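
The JSON export at the end of the pipeline is what the downstream evaluation consumes. As a minimal sketch of that hand-off, the hypothetical `load_emrs` helper below reads every `RUQ-*.json` file from an output folder into a dictionary keyed by case ID. It assumes only the file naming used in this notebook and makes no assumption about the fields inside each EMR.

```python
import json
from pathlib import Path


def load_emrs(out_dir, pattern="RUQ-*.json"):
    """Map case_id -> parsed EMR dict for every matching file in out_dir.

    Hypothetical helper for illustration; the real downstream notebook may differ.
    """
    emrs = {}
    for path in sorted(Path(out_dir).glob(pattern)):
        with path.open(encoding="utf-8") as f:
            emrs[path.stem] = json.load(f)  # stem is the file name without ".json"
    return emrs
```

Other files in the same folder, such as `generation_summary.csv`, are ignored because they do not match the pattern.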

### Key Design Decisions
- **Diagnosis masking**: Real diagnosis drives clinical data generation (behind the scenes), but provider-facing notes reflect diagnostic uncertainty
- **No imaging results**: EMR stops at the imaging order — evaluation happens in the next notebook
- **Direct parsing**: Regex-based Patient Hx parsing bypasses MedGemma extraction for speed and accuracy
- **Clinical setting**: Each case keeps the clinical setting recorded in the dataset, which spans 28 distinct settings (ED, Inpatient, Outpatient, ICU, etc.)
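
The diagnosis-masking decision can be pictured as two views of the same case: one that sees the real diagnosis and feeds clinical data generation, and one that withholds it and feeds the note-writing prompts. The sketch below only illustrates that split. `CaseFacts`, `clinical_data_inputs`, and `narrative_inputs` are hypothetical names, the placeholder wording is invented, and the example diagnosis is a stand-in; the project's own masking logic sits behind `mask_diagnosis_in_narratives=True` and may work differently.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CaseFacts:
    """Hypothetical container for the two facts this sketch needs."""
    chief_complaint: str
    actual_diagnosis: str


def clinical_data_inputs(case):
    """Inputs for labs, vitals, and PE findings: the real diagnosis is visible."""
    return {"chief_complaint": case.chief_complaint,
            "diagnosis": case.actual_diagnosis}


def narrative_inputs(case, mask_diagnosis=True):
    """Inputs for HPI and A&P prompts: the diagnosis is withheld when masked."""
    if mask_diagnosis:
        diagnosis = "undetermined, workup in progress"  # invented placeholder text
    else:
        diagnosis = case.actual_diagnosis
    return {"chief_complaint": case.chief_complaint, "diagnosis": diagnosis}


# Illustrative case only; not taken from the dataset.
case = CaseFacts(chief_complaint="RUQ pain", actual_diagnosis="Acute cholecystitis")
masked = narrative_inputs(case)
unmasked = clinical_data_inputs(case)
```

Keeping the two views in separate functions makes it harder for the real diagnosis to leak into a narrative prompt by accident, since the masked path never returns it.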
