# RUQ Pain Imaging Protocol Triage — Batch EMR Evaluation

This notebook evaluates **generated longitudinal EMR JSONs** (from notebook 04) against **ACR Appropriateness Criteria** using MedGemma 27B for clinical variant classification.

## Pipeline
1. Load generated EMR JSON files from `data/generated_emrs_27b/`
2. **Deterministically extract** structured `PatientContext` from each EMR (no LLM needed, since the data is already structured)
3. Load `Protocol Ordered` and `Provider Indication` from the dataset
4. Use **MedGemma 27B** to classify into 1 of 5 ACR clinical variants
5. Perform **deterministic ACR lookup** for appropriateness rating (1-9)
6. Apply **safety checks** (contrast allergy, renal function, pregnancy, coagulation, MRI device safety, metformin)
7. Triage into GREEN / YELLOW / RED / PURPLE
8. **Verify** results against dataset verification columns
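
Steps 5 to 7 can be sketched as below. The rating bands (7-9, 4-6, 1-3) are the ACR appropriateness categories used throughout this notebook; the mapping from band to triage colour and the helper name `sketch_triage` are assumptions for illustration only, and the notebook's own triage logic may apply additional rules such as safety overrides.

```python
from typing import Optional


def sketch_triage(rating: Optional[int], info_sufficient: bool = True) -> str:
    """Hypothetical helper: map an ACR rating (1-9) to a triage colour."""
    # No usable rating or missing clinical data: cannot evaluate.
    if not info_sufficient or rating is None:
        return "PURPLE"
    if not 1 <= rating <= 9:
        raise ValueError(f"ACR rating must be 1-9, got {rating}")
    if rating >= 7:   # usually appropriate
        return "GREEN"
    if rating >= 4:   # may be appropriate
        return "YELLOW"
    return "RED"      # usually not appropriate


for r in (9, 5, 2, None):
    print(r, sketch_triage(r))
```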

## Prerequisites
- HuggingFace account with MedGemma 27B access
- Colab with GPU enabled (L4 or higher)
- Generated EMR JSONs in `data/generated_emrs_27b/` (from notebook 04)
- `ruq_pain_dataset_2 (1).xlsx` with extended columns


## 1. Setup and Install Dependencies

In [1]:
# Install dependencies
# The version spec is quoted so the shell does not treat ">" as an output redirect.
!pip install -q "transformers>=4.50.0" accelerate bitsandbytes torch pillow huggingface_hub openpyxl pandas

In [2]:
from huggingface_hub import login

# Authenticate with HuggingFace
# Get your token at: https://huggingface.co/settings/tokens
# Make sure you've accepted terms at: https://huggingface.co/google/medgemma-27b-text-it
login()


In [None]:
import torch

# Check GPU availability
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {total_gb:.1f} GB")
    if total_gb < 20:
        print("WARNING: GPU has limited VRAM. MedGemma 27B quantized needs ~16 GB. "
              "Go to Runtime > Change runtime type > L4 GPU or higher")
    else:
        print("GPU has sufficient VRAM for MedGemma 27B quantized (~16 GB).")
else:
    print("WARNING: No GPU detected! Go to Runtime > Change runtime type > L4 GPU")
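
The ~16 GB figure in the check above can be sanity-checked with a rough weights-only estimate. This sketch assumes 4-bit quantization (`bitsandbytes` is installed, but the quantization config is not shown in this section); `estimate_weight_memory_gb` is a hypothetical helper, and the gap between the estimate and ~16 GB is runtime overhead such as activations and the KV cache.

```python
def estimate_weight_memory_gb(n_params_billion: float, bits_per_param: int) -> float:
    """Approximate memory for model weights only, in GB (1 GB = 1e9 bytes)."""
    total_bytes = n_params_billion * 1e9 * bits_per_param / 8
    return total_bytes / 1e9


weights_4bit = estimate_weight_memory_gb(27, 4)    # assumed 4-bit quantization
weights_bf16 = estimate_weight_memory_gb(27, 16)   # unquantized bfloat16, for comparison
print(f"4-bit weights: ~{weights_4bit:.1f} GB")
print(f"bf16 weights:  ~{weights_bf16:.1f} GB")
```

At 16 bits per parameter the weights alone exceed the memory of a single L4, which is why the quantized variant is the one sized against this check.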


## 2. Define ACR Criteria and Data Structures

This section defines the enums, dataclasses, and ACR criteria tables inline so the notebook runs standalone in Colab.

In [None]:
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
from datetime import datetime
import json
import re

# =============================================================================
# ENUMS
# =============================================================================

class TriagePriority(Enum):
    """Triage output categories for radiologist workflow."""
    GREEN = "likely_appropriate"
    YELLOW = "possibly_inappropriate"
    RED = "definitely_inappropriate"
    PURPLE = "insufficient_information"

class AppropriatenessCategory(Enum):
    """ACR Appropriateness rating categories."""
    USUALLY_APPROPRIATE = "usually_appropriate"
    MAY_BE_APPROPRIATE = "may_be_appropriate"
    USUALLY_NOT_APPROPRIATE = "usually_not_appropriate"

class Modality(Enum):
    """Imaging modalities from ACR Appropriateness Criteria for RUQ Pain."""
    US_ABDOMEN = "us_abdomen"
    CT_ABDOMEN_WITH_CONTRAST = "ct_abdomen_with_iv_contrast"
    CT_ABDOMEN_WITHOUT_CONTRAST = "ct_abdomen_without_contrast"
    CT_ABDOMEN_WITHOUT_AND_WITH_CONTRAST = "ct_abdomen_without_and_with_contrast"
    MRI_ABDOMEN_WITHOUT_CONTRAST = "mri_abdomen_without_contrast"
    MRI_ABDOMEN_WITHOUT_CONTRAST_WITH_MRCP = "mri_abdomen_without_contrast_with_mrcp"
    MRI_ABDOMEN_WITHOUT_AND_WITH_CONTRAST_WITH_MRCP = "mri_abdomen_without_and_with_contrast_with_mrcp"
    NUCLEAR_HIDA = "nuclear_medicine_hepatobiliary_scan"
    RADIOGRAPHY_ABDOMEN = "radiography_abdomen"
    IMAGE_GUIDED_CHOLECYSTOSTOMY = "image_guided_cholecystostomy"

class RUQVariant(Enum):
    """ACR RUQ Pain Clinical Variants."""
    VARIANT_1 = "unknown_etiology_initial"
    VARIANT_2 = "suspected_biliary_initial"
    VARIANT_3 = "suspected_biliary_afebrile_negative_us"
    VARIANT_4 = "suspected_biliary_febrile_negative_us"
    VARIANT_5 = "suspected_acalculous_cholecystitis"

VARIANT_DESCRIPTIONS = {
    RUQVariant.VARIANT_1: "Right upper quadrant pain, unknown etiology, initial imaging",
    RUQVariant.VARIANT_2: "Right upper quadrant pain, suspected biliary disease, initial imaging",
    RUQVariant.VARIANT_3: "Right upper quadrant pain, suspected biliary disease, no fever and normal WBC, negative or equivocal ultrasound",
    RUQVariant.VARIANT_4: "Right upper quadrant pain, suspected biliary disease, fever and/or elevated WBC, negative or equivocal ultrasound",
    RUQVariant.VARIANT_5: "Right upper quadrant pain, suspected acalculous cholecystitis, negative or equivocal ultrasound",
}

print("✓ Enums defined")
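
Because the classifier returns text, its answer has to be mapped back onto one of the five enum members. A minimal sketch of that mapping follows; `parse_variant_sketch`, the regex, and the `VariantSketch` copy of `RUQVariant` (renamed so it does not shadow the real enum) are illustrative assumptions, not the notebook's actual parser.

```python
import re
from enum import Enum
from typing import Optional


class VariantSketch(Enum):
    """Stand-in copy of RUQVariant so this sketch runs on its own."""
    VARIANT_1 = "unknown_etiology_initial"
    VARIANT_2 = "suspected_biliary_initial"
    VARIANT_3 = "suspected_biliary_afebrile_negative_us"
    VARIANT_4 = "suspected_biliary_febrile_negative_us"
    VARIANT_5 = "suspected_acalculous_cholecystitis"


def parse_variant_sketch(text: str) -> Optional[VariantSketch]:
    """Return the first 'Variant N' (N in 1-5) mentioned in text, else None."""
    match = re.search(r"variant[\s_]*([1-5])\b", text, flags=re.IGNORECASE)
    if match is None:
        return None
    # Look the member up by name, e.g. "VARIANT_3".
    return VariantSketch[f"VARIANT_{match.group(1)}"]


print(parse_variant_sketch("Classification: Variant 3"))
print(parse_variant_sketch("unclear"))
```

Returning `None` for unparseable output lets the caller decide how to handle it instead of silently defaulting to a variant.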

In [None]:
# =============================================================================
# ACR CRITERIA DATA STRUCTURES
# =============================================================================

@dataclass
class ModalityRating:
    """ACR appropriateness rating for a specific imaging modality."""
    modality: Modality
    rating: int
    relative_radiation: int = 0
    notes: Optional[str] = None

    @property
    def category(self) -> AppropriatenessCategory:
        if self.rating >= 7:
            return AppropriatenessCategory.USUALLY_APPROPRIATE
        elif self.rating >= 4:
            return AppropriatenessCategory.MAY_BE_APPROPRIATE
        else:
            return AppropriatenessCategory.USUALLY_NOT_APPROPRIATE

@dataclass
class ClinicalVariantCriteria:
    """Complete ACR criteria for a specific clinical variant."""
    variant: RUQVariant
    description: str
    ratings: list = field(default_factory=list)

    @property
    def usually_appropriate(self):
        return [r for r in self.ratings if r.rating >= 7]

    @property
    def may_be_appropriate(self):
        return [r for r in self.ratings if 4 <= r.rating <= 6]

    @property
    def usually_not_appropriate(self):
        return [r for r in self.ratings if r.rating <= 3]

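    def get_rating(self, modality: Modality) -> Optional[ModalityRating]:
        """Return the rating for `modality`, or None if it is not rated for this variant."""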
        for r in self.ratings:
            if r.modality == modality:
                return r
        return None

# =============================================================================
# ACR CRITERIA DATABASE
# =============================================================================

ACR_RUQ_CRITERIA = {
    RUQVariant.VARIANT_1: ClinicalVariantCriteria(
        variant=RUQVariant.VARIANT_1,
        description=VARIANT_DESCRIPTIONS[RUQVariant.VARIANT_1],
        ratings=[
            ModalityRating(Modality.US_ABDOMEN, 9),
            ModalityRating(Modality.CT_ABDOMEN_WITH_CONTRAST, 8),
            ModalityRating(Modality.MRI_ABDOMEN_WITHOUT_AND_WITH_CONTRAST_WITH_MRCP, 6),
            ModalityRating(Modality.MRI_ABDOMEN_WITHOUT_CONTRAST_WITH_MRCP, 5),
            ModalityRating(Modality.RADIOGRAPHY_ABDOMEN, 5),
            ModalityRating(Modality.CT_ABDOMEN_WITHOUT_CONTRAST, 5),
            ModalityRating(Modality.NUCLEAR_HIDA, 3),
            ModalityRating(Modality.CT_ABDOMEN_WITHOUT_AND_WITH_CONTRAST, 2),
        ],
    ),
    RUQVariant.VARIANT_2: ClinicalVariantCriteria(
        variant=RUQVariant.VARIANT_2,
        description=VARIANT_DESCRIPTIONS[RUQVariant.VARIANT_2],
        ratings=[
            ModalityRating(Modality.US_ABDOMEN, 9),
            ModalityRating(Modality.CT_ABDOMEN_WITH_CONTRAST, 6),
            ModalityRating(Modality.MRI_ABDOMEN_WITHOUT_AND_WITH_CONTRAST_WITH_MRCP, 6),
            ModalityRating(Modality.MRI_ABDOMEN_WITHOUT_CONTRAST_WITH_MRCP, 5),
            ModalityRating(Modality.NUCLEAR_HIDA, 5),
            ModalityRating(Modality.CT_ABDOMEN_WITHOUT_CONTRAST, 5),
            ModalityRating(Modality.CT_ABDOMEN_WITHOUT_AND_WITH_CONTRAST, 3),
        ],
    ),
    RUQVariant.VARIANT_3: ClinicalVariantCriteria(
        variant=RUQVariant.VARIANT_3,
        description=VARIANT_DESCRIPTIONS[RUQVariant.VARIANT_3],
        ratings=[
            ModalityRating(Modality.MRI_ABDOMEN_WITHOUT_AND_WITH_CONTRAST_WITH_MRCP, 8),
            ModalityRating(Modality.MRI_ABDOMEN_WITHOUT_CONTRAST_WITH_MRCP, 7),
            ModalityRating(Modality.CT_ABDOMEN_WITH_CONTRAST, 7),
            ModalityRating(Modality.NUCLEAR_HIDA, 5),
            ModalityRating(Modality.CT_ABDOMEN_WITHOUT_CONTRAST, 4),
            ModalityRating(Modality.CT_ABDOMEN_WITHOUT_AND_WITH_CONTRAST, 3),
        ],
    ),
    RUQVariant.VARIANT_4: ClinicalVariantCriteria(
        variant=RUQVariant.VARIANT_4,
        description=VARIANT_DESCRIPTIONS[RUQVariant.VARIANT_4],
        ratings=[
            ModalityRating(Modality.MRI_ABDOMEN_WITHOUT_AND_WITH_CONTRAST_WITH_MRCP, 8),
            ModalityRating(Modality.NUCLEAR_HIDA, 7),
            ModalityRating(Modality.CT_ABDOMEN_WITH_CONTRAST, 7),
            ModalityRating(Modality.MRI_ABDOMEN_WITHOUT_CONTRAST_WITH_MRCP, 6),
            ModalityRating(Modality.CT_ABDOMEN_WITHOUT_CONTRAST, 5),
            ModalityRating(Modality.CT_ABDOMEN_WITHOUT_AND_WITH_CONTRAST, 3),
        ],
    ),
    RUQVariant.VARIANT_5: ClinicalVariantCriteria(
        variant=RUQVariant.VARIANT_5,
        description=VARIANT_DESCRIPTIONS[RUQVariant.VARIANT_5],
        ratings=[
            ModalityRating(Modality.NUCLEAR_HIDA, 8),
            ModalityRating(Modality.IMAGE_GUIDED_CHOLECYSTOSTOMY, 6),
            ModalityRating(Modality.MRI_ABDOMEN_WITHOUT_AND_WITH_CONTRAST_WITH_MRCP, 6),
            ModalityRating(Modality.CT_ABDOMEN_WITH_CONTRAST, 5),
            ModalityRating(Modality.MRI_ABDOMEN_WITHOUT_CONTRAST_WITH_MRCP, 5),
            ModalityRating(Modality.CT_ABDOMEN_WITHOUT_CONTRAST, 4),
            ModalityRating(Modality.CT_ABDOMEN_WITHOUT_AND_WITH_CONTRAST, 3),
        ],
    ),
}

print("✓ ACR Criteria database loaded")
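
The lookup itself is a plain table read once the variant is known. The sketch below uses a string-keyed stand-in for a slice of `ACR_RUQ_CRITERIA` (ratings copied from the table above) so it runs on its own; `RATINGS_SKETCH`, `band_sketch`, and `lookup_sketch` are hypothetical names used only here.

```python
from typing import Optional, Tuple

# Stand-in for part of ACR_RUQ_CRITERIA, keyed by plain strings.
RATINGS_SKETCH = {
    "VARIANT_2": {
        "us_abdomen": 9,
        "ct_abdomen_with_iv_contrast": 6,
        "nuclear_medicine_hepatobiliary_scan": 5,
        "ct_abdomen_without_and_with_contrast": 3,
    },
    "VARIANT_3": {
        "mri_abdomen_without_and_with_contrast_with_mrcp": 8,
        "ct_abdomen_with_iv_contrast": 7,
        "ct_abdomen_without_contrast": 4,
    },
}


def band_sketch(rating: int) -> str:
    """Same thresholds as ModalityRating.category: 7-9, 4-6, 1-3."""
    if rating >= 7:
        return "usually_appropriate"
    if rating >= 4:
        return "may_be_appropriate"
    return "usually_not_appropriate"


def lookup_sketch(variant: str, modality: str) -> Optional[Tuple[int, str]]:
    """Return (rating, band), or None if the modality is not rated for the variant."""
    rating = RATINGS_SKETCH.get(variant, {}).get(modality)
    if rating is None:
        return None
    return rating, band_sketch(rating)


print(lookup_sketch("VARIANT_2", "ct_abdomen_with_iv_contrast"))
print(lookup_sketch("VARIANT_3", "ct_abdomen_with_iv_contrast"))
print(lookup_sketch("VARIANT_3", "us_abdomen"))
```

The same CT with IV contrast order rates 6 under Variant 2 and 7 under Variant 3, which is why variant classification has to precede the lookup. A `None` result means the modality is not listed for that variant and needs separate handling.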

In [None]:
# =============================================================================
# DATA MODELS
# =============================================================================

@dataclass
class OrderedProtocol:
    """Represents the imaging order from the primary physician."""
    modality: Modality
    order_id: Optional[str] = None
    ordering_provider: Optional[str] = None
    ordering_provider_specialty: Optional[str] = None
    urgency: Optional[str] = None  # "Stat", "Urgent", "Routine"

    def requires_contrast(self) -> bool:
        contrast_modalities = {
            Modality.CT_ABDOMEN_WITH_CONTRAST,
            Modality.CT_ABDOMEN_WITHOUT_AND_WITH_CONTRAST,
            Modality.MRI_ABDOMEN_WITHOUT_AND_WITH_CONTRAST_WITH_MRCP,
        }
        return self.modality in contrast_modalities

    def requires_gadolinium(self) -> bool:
        return self.modality in {
            Modality.MRI_ABDOMEN_WITHOUT_AND_WITH_CONTRAST_WITH_MRCP,
        }

    def requires_iodinated_contrast(self) -> bool:
        return self.modality in {
            Modality.CT_ABDOMEN_WITH_CONTRAST,
            Modality.CT_ABDOMEN_WITHOUT_AND_WITH_CONTRAST,
        }

    def involves_radiation(self) -> bool:
        return self.modality in {
            Modality.CT_ABDOMEN_WITH_CONTRAST,
            Modality.CT_ABDOMEN_WITHOUT_CONTRAST,
            Modality.CT_ABDOMEN_WITHOUT_AND_WITH_CONTRAST,
            Modality.NUCLEAR_HIDA,
            Modality.RADIOGRAPHY_ABDOMEN,
        }


@dataclass
class AllergyDetail:
    """Structured allergy information."""
    allergen: str
    severity: Optional[str] = None  # "Mild", "Moderate", "Severe"
    reaction: Optional[str] = None  # "Anaphylaxis/urticaria", "Rash", etc.
    allergy_type: Optional[str] = None


@dataclass
class PriorImagingStudy:
    """Represents a prior imaging study from the patient's history."""
    modality: str  # e.g., "Ultrasound Abdomen", "CT Abdomen and Pelvis"
    date: Optional[str] = None
    result: Optional[str] = None  # "negative", "equivocal", "positive"
    impression: Optional[str] = None
    findings: Optional[str] = None


@dataclass
class PatientContext:
    """Patient EHR data relevant for RUQ pain triage."""
    patient_id: Optional[str] = None
    age: Optional[int] = None
    sex: Optional[str] = None
    weight_kg: Optional[float] = None
    height_cm: Optional[float] = None
    bmi: Optional[float] = None

    # --- Core labs ---
    wbc: Optional[float] = None
    wbc_elevated: Optional[bool] = None
    ast: Optional[float] = None
    alt: Optional[float] = None
    alp: Optional[float] = None
    bilirubin_total: Optional[float] = None
    bilirubin_direct: Optional[float] = None
    lipase: Optional[float] = None
    creatinine: Optional[float] = None
    gfr: Optional[float] = None

    # --- Coagulation labs ---
    platelets: Optional[float] = None
    inr: Optional[float] = None
    pt: Optional[float] = None
    ptt: Optional[float] = None

    # --- Additional labs ---
    hemoglobin: Optional[float] = None
    glucose: Optional[float] = None
    tsh: Optional[float] = None
    crp: Optional[float] = None
    d_dimer: Optional[float] = None

    # --- Vitals ---
    temperature: Optional[float] = None
    has_fever: Optional[bool] = None
    heart_rate: Optional[int] = None
    blood_pressure_systolic: Optional[int] = None
    blood_pressure_diastolic: Optional[int] = None
    respiratory_rate: Optional[int] = None
    spo2: Optional[float] = None

    # --- Prior imaging (all modalities) ---
    prior_us_performed: bool = False
    prior_us_result: Optional[str] = None   # "negative", "equivocal", "positive"
    prior_us_findings: Optional[str] = None
    prior_ct_performed: bool = False
    prior_ct_result: Optional[str] = None
    prior_ct_findings: Optional[str] = None
    prior_mri_performed: bool = False
    prior_mri_result: Optional[str] = None
    prior_mri_findings: Optional[str] = None
    prior_hida_performed: bool = False
    prior_hida_result: Optional[str] = None
    prior_hida_findings: Optional[str] = None
    prior_imaging_studies: list = field(default_factory=list)  # List[PriorImagingStudy]

    # --- Allergies (structured) ---
    has_contrast_allergy: Optional[bool] = None
    iodinated_contrast_allergy: Optional[bool] = None
    gadolinium_allergy: Optional[bool] = None
    contrast_allergy_severity: Optional[str] = None  # "Mild", "Moderate", "Severe"
    contrast_allergy_reaction: Optional[str] = None   # "Anaphylaxis", "Rash", etc.
    shellfish_allergy: Optional[bool] = None
    allergy_list: list = field(default_factory=list)  # List[AllergyDetail]

    # --- Medical history (structured) ---
    medical_history: list = field(default_factory=list)       # condition names
    medical_history_icd10: list = field(default_factory=list)  # (condition, icd10) tuples
    surgical_history: list = field(default_factory=list)       # procedure names

    # --- Pregnancy ---
    is_pregnant: Optional[bool] = None
    pregnancy_trimester: Optional[str] = None  # "first", "second", "third"

    # --- Medications ---
    home_medications: list = field(default_factory=list)  # med name strings
    inpatient_medications: list = field(default_factory=list)
    on_anticoagulation: Optional[bool] = None
    anticoagulant_name: Optional[str] = None
    on_metformin: Optional[bool] = None
    on_nephrotoxic_drugs: Optional[bool] = None
    nephrotoxic_drug_names: list = field(default_factory=list)

    # --- Implants & devices ---
    has_cardiac_device: Optional[bool] = None  # pacemaker, ICD, etc.
    has_metallic_implants: Optional[bool] = None
    device_details: list = field(default_factory=list)

    # --- Clinical state ---
    is_icu_patient: Optional[bool] = None
    is_critically_ill: Optional[bool] = None
    is_post_operative: Optional[bool] = None
    days_post_op: Optional[int] = None
    on_tpn: Optional[bool] = None
    is_septic: Optional[bool] = None
    clinical_setting: Optional[str] = None  # "ed", "icu", "inpatient", "outpatient"

    # --- Notes ---
    clinical_notes: Optional[str] = None


@dataclass
class TriageInput:
    """Complete input for the RUQ protocol reviewer."""
    ordered_protocol: OrderedProtocol
    clinical_indication: str
    patient_context: PatientContext

@dataclass
class RecommendedProtocol:
    """A recommended alternative imaging protocol."""
    modality: Modality
    acr_rating: int
    rationale: str
    is_top_recommendation: bool = False

@dataclass
class IndicationQuality:
    """Assessment of clinical indication quality and coherence with EMR data."""
    quality_tier: str            # "GOOD", "QUESTIONABLE", "POOR", "JUNK"
    quality_score: float         # 0.0 (junk) to 1.0 (excellent)
    is_junk: bool = False
    junk_reasons: list = field(default_factory=list)
    coherence_score: float = 1.0
    coherence_flags: list = field(default_factory=list)
    llm_indication_assessment: Optional[str] = None
    llm_suggested_indication: Optional[str] = None
    summary: str = ""

    @property
    def requires_attention(self) -> bool:
        return self.quality_tier in ("QUESTIONABLE", "POOR", "JUNK")

    def to_dict(self) -> dict:
        return {
            "quality_tier": self.quality_tier,
            "quality_score": self.quality_score,
            "is_junk": self.is_junk,
            "junk_reasons": self.junk_reasons,
            "coherence_score": self.coherence_score,
            "coherence_flags": self.coherence_flags,
            "llm_indication_assessment": self.llm_indication_assessment,
            "llm_suggested_indication": self.llm_suggested_indication,
            "summary": self.summary,
            "requires_attention": self.requires_attention,
        }

@dataclass
class TriageOutput:
    """Complete output from the RUQ protocol reviewer."""
    priority: TriagePriority
    confidence: float
    classified_variant: RUQVariant
    variant_confidence: float
    variant_reasoning: str
    ordered_protocol_rating: Optional[int]
    ordered_protocol_category: Optional[AppropriatenessCategory]
    appropriateness_assessment: str
    recommended_protocols: list = field(default_factory=list)
    has_safety_concern: bool = False
    contrast_contraindicated: bool = False
    contrast_warning: Optional[str] = None
    renal_function_concern: bool = False
    renal_warning: Optional[str] = None
    pregnancy_concern: bool = False
    pregnancy_warning: Optional[str] = None
    coagulation_concern: bool = False
    coagulation_warning: Optional[str] = None
    mri_safety_concern: bool = False
    mri_safety_warning: Optional[str] = None
    metformin_warning: Optional[str] = None
    llm_reasoning: Optional[str] = None
    insufficient_information: bool = False
    missing_fields: list = field(default_factory=list)
    safety_warnings: list = field(default_factory=list)  # All safety warnings collected
    indication_quality: Optional[IndicationQuality] = None

    @property
    def priority_display(self) -> str:
        display_map = {
            TriagePriority.GREEN: "\U0001f7e2 LIKELY APPROPRIATE",
            TriagePriority.YELLOW: "\U0001f7e1 REQUIRES REVIEW",
            TriagePriority.RED: "\U0001f534 INAPPROPRIATE - ACTION REQUIRED",
            TriagePriority.PURPLE: "\U0001f7e3 INSUFFICIENT INFORMATION",
        }
        return display_map.get(self.priority, str(self.priority))

    @property
    def action_required(self) -> str:
        action_map = {
            TriagePriority.GREEN: "Low priority for manual review. Order appears appropriate per ACR criteria.",
            TriagePriority.YELLOW: "Manual review recommended. Order may be appropriate but requires radiologist assessment.",
            TriagePriority.RED: "Immediate action required. Order is inappropriate per ACR criteria. Contact ordering provider or change protocol.",
            TriagePriority.PURPLE: "Cannot evaluate appropriateness. Critical clinical information is missing from the EMR. Request additional clinical data from ordering provider before proceeding.",
        }
        return action_map.get(self.priority, "Review required")

    def to_radiologist_summary(self) -> str:
        lines = [
            "=" * 60,
            f"TRIAGE RESULT: {self.priority_display}",
            f"Confidence: {self.confidence:.0%}",
            "=" * 60,
        ]

        if self.insufficient_information:
            lines.append("")
            lines.append("\u26a0\ufe0f  MISSING CRITICAL INFORMATION:")
            for f in self.missing_fields:
                lines.append(f"  \u2022 {f}")
            lines.append("")
            lines.append(f"ACTION: {self.action_required}")
            lines.append("")
            lines.append("-" * 60)
            lines.append("Reference: ACR Appropriateness Criteria: Right Upper Quadrant Pain (2022)")
            return "\n".join(lines)

        lines.extend([
            "",
            f"CLINICAL VARIANT: {self.classified_variant.name}",
            f"  {VARIANT_DESCRIPTIONS.get(self.classified_variant, '')}",
            f"  Classification confidence: {self.variant_confidence:.0%}",
            "",
            "ORDERED PROTOCOL ASSESSMENT:",
        ])
        if self.ordered_protocol_rating is not None:
            cat_str = self.ordered_protocol_category.value if self.ordered_protocol_category else 'N/A'
            lines.append(f"  ACR Rating: {self.ordered_protocol_rating}/9 ({cat_str})")
        lines.append(f"  {self.appropriateness_assessment}")

        # Show ALL safety warnings
        all_warnings = self.safety_warnings or []
        if not all_warnings:
            # Backward compat: collect individual warnings
            if self.contrast_warning:
                all_warnings.append(self.contrast_warning)
            if self.renal_warning:
                all_warnings.append(self.renal_warning)
            if self.pregnancy_warning:
                all_warnings.append(self.pregnancy_warning)
            if self.coagulation_warning:
                all_warnings.append(self.coagulation_warning)
            if self.mri_safety_warning:
                all_warnings.append(self.mri_safety_warning)
            if self.metformin_warning:
                all_warnings.append(self.metformin_warning)

        if all_warnings:
            lines.append("")
            lines.append("\u26a0\ufe0f  SAFETY CONCERNS:")
            for w in all_warnings:
                lines.append(f"  \u2022 {w}")

        if self.indication_quality and self.indication_quality.requires_attention:
            lines.append("")
            iq = self.indication_quality
            tier_marker = {"JUNK": "\U0001f6ab", "POOR": "\u26a0\ufe0f", "QUESTIONABLE": "\u2753"}.get(iq.quality_tier, "")
            lines.append(f"{tier_marker}  INDICATION QUALITY: {iq.quality_tier} ({iq.quality_score:.0%})")
            if iq.coherence_flags:
                for flag in iq.coherence_flags[:3]:
                    lines.append(f"  \u2022 {flag}")
            if iq.junk_reasons:
                for reason in iq.junk_reasons[:2]:
                    lines.append(f"  \u2022 {reason}")
            if iq.llm_suggested_indication:
                lines.append(f"  Suggested indication: {iq.llm_suggested_indication}")
            lines.append(f"  {iq.summary}")

        if self.recommended_protocols:
            lines.append("")
            lines.append("RECOMMENDED ALTERNATIVES:")
            for rec in self.recommended_protocols:
                marker = "\u2192" if rec.is_top_recommendation else "\u2022"
                lines.append(f"  {marker} {rec.modality.value} (ACR {rec.acr_rating}/9)")
                lines.append(f"    {rec.rationale}")

        lines.append("")
        lines.append(f"ACTION: {self.action_required}")
        lines.append("")
        lines.append("-" * 60)
        lines.append("Reference: ACR Appropriateness Criteria: Right Upper Quadrant Pain (2022)")

        return "\n".join(lines)

print("\u2713 Data models defined (expanded PatientContext with coag, meds, prior imaging, pregnancy, devices)")

In [None]:
# =============================================================================
# SAFETY CHECKS (Expanded)
# =============================================================================

@dataclass
class SafetyFlags:
    """Safety findings for one ordered protocol, as produced by assess_safety()."""
    contrast_ordered: bool = False
    contrast_contraindicated: bool = False
    contrast_warning: Optional[str] = None
    renal_concern: bool = False
    renal_warning: Optional[str] = None
    pregnancy_concern: bool = False
    pregnancy_warning: Optional[str] = None
    coagulation_concern: bool = False
    coagulation_warning: Optional[str] = None
    mri_safety_concern: bool = False
    mri_safety_warning: Optional[str] = None
    metformin_warning: Optional[str] = None
    has_any_concern: bool = False
    all_warnings: list = field(default_factory=list)

def assess_safety(ordered_protocol: OrderedProtocol, patient_context: PatientContext) -> SafetyFlags:
    """Assess safety concerns for the ordered protocol."""
    flags = SafetyFlags()

    flags.contrast_ordered = ordered_protocol.requires_contrast()
    is_mri = ordered_protocol.modality in {
        Modality.MRI_ABDOMEN_WITHOUT_CONTRAST,
        Modality.MRI_ABDOMEN_WITHOUT_CONTRAST_WITH_MRCP,
        Modality.MRI_ABDOMEN_WITHOUT_AND_WITH_CONTRAST_WITH_MRCP,
    }
    involves_radiation = ordered_protocol.involves_radiation()

    # --- 1. Contrast allergy checks ---
    if flags.contrast_ordered:
        if ordered_protocol.requires_iodinated_contrast():
            if patient_context.iodinated_contrast_allergy:
                flags.contrast_contraindicated = True
                sev = f" ({patient_context.contrast_allergy_severity})" if patient_context.contrast_allergy_severity else ""
                rxn = f" — prior reaction: {patient_context.contrast_allergy_reaction}" if patient_context.contrast_allergy_reaction else ""
                flags.contrast_warning = f"CONTRAINDICATED: Iodinated contrast allergy{sev}{rxn}"
                flags.all_warnings.append(flags.contrast_warning)
            elif patient_context.shellfish_allergy:
                flags.contrast_warning = "Shellfish allergy noted — verify iodinated contrast tolerance before proceeding"
                flags.all_warnings.append(flags.contrast_warning)

        if ordered_protocol.requires_gadolinium():
            if patient_context.gadolinium_allergy:
                flags.contrast_contraindicated = True
                flags.contrast_warning = "CONTRAINDICATED: Gadolinium allergy/prior reaction"
                flags.all_warnings.append(flags.contrast_warning)

        if patient_context.has_contrast_allergy and not flags.contrast_contraindicated:
            if not flags.contrast_warning:
                flags.contrast_warning = "Patient reports contrast allergy — verify type before proceeding"
                flags.all_warnings.append(flags.contrast_warning)

    # --- 2. Renal function checks ---
    if flags.contrast_ordered or patient_context.on_nephrotoxic_drugs:
        gfr = patient_context.gfr
        cr = patient_context.creatinine
        if gfr is not None:
            if gfr < 30:
                flags.renal_concern = True
                flags.renal_warning = f"HIGH RISK: GFR {gfr:.0f} — significant renal impairment"
                if ordered_protocol.requires_gadolinium():
                    flags.renal_warning += ". NSF risk with gadolinium"
                flags.all_warnings.append(flags.renal_warning)
            elif gfr < 60:
                flags.renal_concern = True
                flags.renal_warning = f"MODERATE RISK: GFR {gfr:.0f} — reduced renal function"
                flags.all_warnings.append(flags.renal_warning)
        elif cr is not None and cr > 1.5:
            flags.renal_concern = True
            flags.renal_warning = f"Elevated creatinine ({cr:.1f}) — check GFR before contrast"
            flags.all_warnings.append(flags.renal_warning)

    # --- 3. Pregnancy checks ---
    if patient_context.is_pregnant:
        flags.pregnancy_concern = True
        trimester_str = f" ({patient_context.pregnancy_trimester} trimester)" if patient_context.pregnancy_trimester else ""
        if involves_radiation:
            flags.pregnancy_warning = (
                f"SAFETY: Patient is pregnant{trimester_str}. "
                f"Ordered protocol ({ordered_protocol.modality.value}) involves ionizing radiation. "
                f"Consider US or MRI without gadolinium."
            )
        elif ordered_protocol.requires_gadolinium():
            flags.pregnancy_warning = (
                f"CAUTION: Patient is pregnant{trimester_str}. "
                f"Gadolinium crosses the placenta — use only if benefit outweighs risk."
            )
        else:
            flags.pregnancy_warning = f"Note: Patient is pregnant{trimester_str}. US preferred when possible."
        flags.all_warnings.append(flags.pregnancy_warning)

    # --- 4. Coagulation checks (relevant for invasive procedures) ---
    if ordered_protocol.modality == Modality.IMAGE_GUIDED_CHOLECYSTOSTOMY:
        if patient_context.inr is not None and patient_context.inr > 1.5:
            flags.coagulation_concern = True
            flags.coagulation_warning = f"Elevated INR ({patient_context.inr:.1f}) — correct coagulopathy before cholecystostomy"
            flags.all_warnings.append(flags.coagulation_warning)
        if patient_context.platelets is not None and patient_context.platelets < 50:
            flags.coagulation_concern = True
            warning = f"Thrombocytopenia (platelets {patient_context.platelets:.0f}K) — transfuse before cholecystostomy"
            if flags.coagulation_warning:
                flags.coagulation_warning += f"; {warning}"
            else:
                flags.coagulation_warning = warning
            flags.all_warnings.append(warning)
        if patient_context.on_anticoagulation:
            flags.coagulation_concern = True
            ac_name = f" ({patient_context.anticoagulant_name})" if patient_context.anticoagulant_name else ""
            warning = f"Patient on anticoagulation{ac_name} — hold/reverse before invasive procedure"
            if flags.coagulation_warning:
                flags.coagulation_warning += f"; {warning}"
            else:
                flags.coagulation_warning = warning
            flags.all_warnings.append(warning)

    # --- 5. MRI safety checks (metallic implants, cardiac devices) ---
    if is_mri:
        if patient_context.has_cardiac_device:
            flags.mri_safety_concern = True
            details = f": {', '.join(patient_context.device_details)}" if patient_context.device_details else ""
            flags.mri_safety_warning = f"Cardiac device present{details} — verify MRI compatibility"
            flags.all_warnings.append(flags.mri_safety_warning)
        elif patient_context.has_metallic_implants:
            flags.mri_safety_concern = True
            flags.mri_safety_warning = "Metallic implants noted — verify MRI compatibility"
            flags.all_warnings.append(flags.mri_safety_warning)

    # --- 6. Metformin + contrast interaction ---
    if flags.contrast_ordered and patient_context.on_metformin:
        if patient_context.gfr is not None and patient_context.gfr < 60:
            flags.metformin_warning = (
                f"Patient on metformin with GFR {patient_context.gfr:.0f} — "
                "hold metformin 48h post-contrast, recheck renal function"
            )
        else:
            flags.metformin_warning = "Patient on metformin — standard post-contrast monitoring"
        flags.all_warnings.append(flags.metformin_warning)

    flags.has_any_concern = (
        flags.contrast_contraindicated or flags.renal_concern or
        flags.pregnancy_concern or flags.coagulation_concern or
        flags.mri_safety_concern or bool(flags.metformin_warning)
    )
    return flags

# =============================================================================
# INFORMATION COMPLETENESS CHECK
# =============================================================================

def assess_information_completeness(patient_context: PatientContext, clinical_indication: str = "") -> tuple:
    """
    Check whether the PatientContext has enough critical information for
    reliable ACR variant classification and appropriateness triage.

    Returns:
        Tuple of (is_sufficient: bool, missing_fields: list[str], missing_critical_count: int)
    """
    missing = []

    # --- Clinical indication adequacy check ---
    # Vague indications (e.g., "Pain", "Eval", "Imaging requested") provide
    # no diagnostic guidance and make variant classification unreliable.
    VAGUE_INDICATIONS = {
        "pain", "discomfort", "eval", "evaluation",
        "imaging requested", "doctor requested", "physician requested",
        "screening", "pcp referral", "referral",
        "follow-up", "follow-up imaging", "followup",
        "outside hospital transfer", "osh transfer",
        "ed workup", "workup",
        "altered mental status",
        "not documented", "(not documented)",
    }
    RUQ_CLINICAL_KEYWORDS = [
        "ruq", "right upper quadrant", "gallbladder", "gallstone", "cholecyst",
        "biliary", "bile", "hepat", "liver", "pancrea", "jaundice",
        "bilirubin", "choledocho", "cholangi", "murphy",
        "r/o", "rule out", "suspect", "concern for",
        "mass", "lesion", "abscess", "obstruction", "dilation",
        "fatty liver", "steatosis", "cirrhosis", "nausea", "vomiting",
        "cbd", "hcc", "pud", "cholelithiasis", "stone", "colic", "epigastric",
    ]

    indication_lower = clinical_indication.strip().lower()
    indication_is_vague = False

    if not indication_lower or indication_lower in VAGUE_INDICATIONS:
        indication_is_vague = True
    elif len(indication_lower.split()) < 4:
        # Short indications without any RUQ-relevant clinical term
        if not any(kw in indication_lower for kw in RUQ_CLINICAL_KEYWORDS):
            indication_is_vague = True

    if indication_is_vague:
        missing.append("Clinical indication (vague or non-specific)")

    # --- PatientContext field checks ---
    # Must-have fields: if any one of these is missing, triage is unreliable
    if patient_context.age is None:
        missing.append("Patient age")
    if patient_context.sex is None:
        missing.append("Patient sex")

    # Fever/temperature status -- critical for distinguishing Variants 3 vs 4
    has_fever_info = (patient_context.has_fever is not None or
                      patient_context.temperature is not None)
    if not has_fever_info:
        missing.append("Fever status or temperature")

    # WBC status -- critical for distinguishing Variants 3 vs 4
    has_wbc_info = (patient_context.wbc is not None or
                    patient_context.wbc_elevated is not None)
    if not has_wbc_info:
        missing.append("WBC count or leukocytosis status")

    # Decision-critical fields (need at least some of these)
    decision_missing = 0
    decision_fields = []

    if patient_context.is_icu_patient is None:
        decision_missing += 1
        decision_fields.append("ICU status")

    if patient_context.is_post_operative is None:
        decision_missing += 1
        decision_fields.append("Post-operative status")

    if patient_context.gfr is None and patient_context.creatinine is None:
        decision_missing += 1
        decision_fields.append("Renal function (GFR or creatinine)")

    # If all three decision-critical fields are missing, report them as well
    if decision_missing >= 3:
        missing.extend(decision_fields)

    # Must-have threshold: any must-have missing -> insufficient
    must_have_missing = sum([
        patient_context.age is None,
        patient_context.sex is None,
        not has_fever_info,
        not has_wbc_info,
        indication_is_vague,
    ])

    is_sufficient = must_have_missing == 0 and decision_missing < 3

    return (is_sufficient, missing, must_have_missing + decision_missing)



# =============================================================================
# INDICATION QUALITY ASSESSMENT (Tiers 1-3)
# =============================================================================

import re as _re_iq

# Medical terminology reference set for junk detection
MEDICAL_TERM_SET = frozenset([
    "ruq", "right upper quadrant", "gallbladder", "gallstone", "cholecyst",
    "biliary", "bile", "hepat", "liver", "pancrea", "jaundice",
    "bilirubin", "choledocho", "cholangi", "murphy",
    "r/o", "rule out", "suspect", "concern for",
    "mass", "lesion", "abscess", "obstruction", "dilation",
    "fatty liver", "steatosis", "cirrhosis", "nausea", "vomiting",
    "cbd", "hcc", "pud", "cholelithiasis", "stone", "colic", "epigastric",
    "pain", "acute", "chronic", "fever", "tenderness", "elevated",
    "abnormal", "workup", "eval", "imaging", "follow", "screening",
    "abdominal", "abdomen", "chest", "pelvis", "flank",
    "cancer", "tumor", "metasta", "carcinoma", "lymphoma",
    "infection", "sepsis", "inflam", "edema", "effusion",
    "fracture", "trauma", "injury", "bleed", "hemorrh",
    "obstruct", "stricture", "hernia", "perforation",
    "renal", "kidney", "ureter", "bladder",
    "ct", "mri", "ultrasound", "hida", "mrcp",
    "post-op", "pre-op", "status post", "s/p",
    "hx", "pmh", "history", "dx", "diagnosis",
])


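def detect_junk_indication(indication):
    """Tier 1: Detect whether a clinical indication is junk/garbage text.

    Returns:
        Tuple of (is_junk: bool, reasons: list[str]). Format problems alone
        do not mark text as junk if it contains a recognized medical term.
    """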
    reasons = []
    text = indication.strip()
    if not text:
        return True, ["Empty indication"]
    if len(text) <= 2:
        return True, [f"Indication too short ({len(text)} characters)"]

    alpha_chars = sum(1 for c in text if c.isalpha())
    total_chars = len(text.replace(" ", ""))
    if total_chars > 0 and alpha_chars / total_chars < 0.5:
        reasons.append(f"High non-alphabetic character ratio ({alpha_chars}/{total_chars})")

    if _re_iq.search(r'(.)\1{4,}', text):
        reasons.append("Repeated character pattern detected")

    if _re_iq.search(r'https?://|www\.|\.com|\.org|@.*\.\w+|\\\\|//\w', text):
        reasons.append("Contains URL/email/file path")

    if _re_iq.match(r'^[\d\s\-\.]+$', text):
        reasons.append("Numeric-only text — likely MRN or order number")

    text_lower = text.lower()
    has_medical_term = any(term in text_lower for term in MEDICAL_TERM_SET)

    if not has_medical_term and len(text.split()) >= 2:
        reasons.append("No recognizable medical terminology")

    is_junk = len(reasons) >= 1 and not has_medical_term
    return is_junk, reasons


def _match_term(text, terms):
    """Return the first entry of `terms` found in `text`, or "unknown" if none match."""
    for t in terms:
        if t in text:
            return t
    return "unknown"


def assess_indication_coherence(indication, patient_context):
    """Tier 2: Cross-reference indication against EMR data for clinical coherence."""
    flags = []
    indication_lower = indication.lower().strip()

    # RULE 1: Acute infection but labs/vitals disagree
    ACUTE_INFECTION_TERMS = ["cholecystitis", "cholangitis", "appendicitis", "abscess", "sepsis", "infected", "acute infection"]
    claims_infection = any(t in indication_lower for t in ACUTE_INFECTION_TERMS)
    if claims_infection:
        has_fever = patient_context.has_fever
        has_elevated_wbc = patient_context.wbc_elevated
        if has_fever is False and has_elevated_wbc is False:
            matched = _match_term(indication_lower, ACUTE_INFECTION_TERMS)
            flags.append(f"Indication suggests acute infection ('{matched}') but patient is afebrile with normal WBC")
        elif has_fever is False and has_elevated_wbc is None:
            matched = _match_term(indication_lower, ACUTE_INFECTION_TERMS)
            flags.append(f"Indication suggests infection ('{matched}') but patient is afebrile — verify WBC")

    # RULE 2: Jaundice/obstruction but bilirubin normal
    OBSTRUCTIVE_TERMS = ["jaundice", "obstructive", "obstruction", "choledocholithiasis", "cbd stone", "common bile duct", "biliary obstruction"]
    if any(t in indication_lower for t in OBSTRUCTIVE_TERMS):
        bili = patient_context.bilirubin_total
        alp = patient_context.alp
        if bili is not None and bili < 1.2:
            flags.append(f"Indication mentions jaundice/obstruction but total bilirubin is {bili} mg/dL (normal)")
        if alp is not None and alp < 120 and bili is not None and bili < 1.2:
            flags.append(f"Indication suggests biliary obstruction but ALP ({alp}) and bilirubin ({bili}) both normal")

    # RULE 3: Pancreatitis but lipase normal
    if "pancreatitis" in indication_lower:
        lipase = patient_context.lipase
        if lipase is not None and lipase < 60:
            flags.append(f"Indication mentions pancreatitis but lipase is {lipase} U/L (normal)")

    # RULE 4: Hepatitis/liver disease but LFTs normal
    HEPATIC_TERMS = ["hepatitis", "liver disease", "cirrhosis", "liver failure"]
    if any(t in indication_lower for t in HEPATIC_TERMS):
        ast, alt = patient_context.ast, patient_context.alt
        if ast is not None and alt is not None and ast < 40 and alt < 40:
            flags.append(f"Indication suggests hepatic disease but AST ({ast}) and ALT ({alt}) both normal")

    # RULE 5: Gallbladder indication in post-cholecystectomy patient
    CHOLECYSTITIS_TERMS = ["cholecystitis", "gallbladder", "gallstone", "cholelithiasis"]
    if any(t in indication_lower for t in CHOLECYSTITIS_TERMS):
        surgical_hx_lower = " ".join(patient_context.surgical_history).lower()
        if "cholecystectomy" in surgical_hx_lower:
            flags.append("Indication references gallbladder/cholecystitis but patient has prior cholecystectomy — anatomically inaccurate")

    # RULE 6: Non-RUQ pathology
    NON_RUQ_TERMS = {
        "appendicitis": "RLQ pathology",
        "renal colic": "flank/urological",
        "ovarian": "pelvic",
        "pulmonary embolism": "chest",
        "pneumonia": "chest",
        "cardiac": "chest",
        "headache": "neurological",
        "stroke": "neurological",
    }
    for term, explanation in NON_RUQ_TERMS.items():
        if term in indication_lower:
            flags.append(f"Indication mentions '{term}' ({explanation}) — unlikely for RUQ imaging")

    # RULE 7: Generic indication despite rich EMR data
    GENERIC_TERMS = {"abdominal pain", "abd pain", "belly pain", "stomach pain"}
    is_generic = indication_lower in GENERIC_TERMS or (len(indication_lower.split()) <= 3 and any(g in indication_lower for g in GENERIC_TERMS))
    if is_generic:
        findings = []
        if patient_context.has_fever: findings.append("fever")
        if patient_context.wbc_elevated: findings.append("elevated WBC")
        if patient_context.bilirubin_total and patient_context.bilirubin_total > 1.2: findings.append(f"elevated bilirubin ({patient_context.bilirubin_total})")
        if patient_context.lipase and patient_context.lipase > 60: findings.append(f"elevated lipase ({patient_context.lipase})")
        if patient_context.alt and patient_context.alt > 40: findings.append(f"elevated ALT ({patient_context.alt})")
        if len(findings) >= 2:
            flags.append(f"Indication is generic ('{indication.strip()}') but EMR shows {', '.join(findings)} — more specific indication would improve accuracy")

    # Compute coherence score
    if not flags:
        coherence_score = 1.0
    else:
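        # Major flags: Rule 5 ("anatomically inaccurate") and Rule 6 ("unlikely for RUQ imaging").
        # The Rule 6 message never contains "not RUQ", so match the text it actually emits.
        major = sum(1 for f in flags if "inaccurate" in f or "unlikely for RUQ" in f or "anatomically" in f)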
        minor = len(flags) - major
        coherence_score = max(0.0, 1.0 - (major * 0.3) - (minor * 0.15))
    return coherence_score, flags


def assess_indication_quality(indication, patient_context, llm_assessment=None):
    """Comprehensive indication quality assessment combining all three tiers."""
    is_junk, junk_reasons = detect_junk_indication(indication)
    if is_junk:
        return IndicationQuality(
            quality_tier="JUNK", quality_score=0.0,
            is_junk=True, junk_reasons=junk_reasons,
            coherence_score=0.0,
            summary=f"Non-clinical text: {'; '.join(junk_reasons)}",
        )

    coherence_score, coherence_flags = assess_indication_coherence(indication, patient_context)

    llm_quality, llm_explanation, llm_suggested = None, None, None
    if llm_assessment:
        llm_quality = llm_assessment.get("quality", "GOOD")
        llm_explanation = llm_assessment.get("explanation")
        llm_suggested = llm_assessment.get("suggested_indication")

    if coherence_score >= 0.8 and (llm_quality in (None, "GOOD")):
        quality_tier = "GOOD" if not coherence_flags else "QUESTIONABLE"
        quality_score = min(1.0, coherence_score)
    elif coherence_score >= 0.5 or llm_quality == "QUESTIONABLE":
        quality_tier, quality_score = "QUESTIONABLE", coherence_score
    else:
        quality_tier, quality_score = "POOR", coherence_score

    if llm_quality == "POOR" and quality_tier not in ("JUNK", "POOR"):
        quality_tier, quality_score = "POOR", min(quality_score, 0.3)

    parts = []
    if coherence_flags: parts.append(f"Clinical coherence: {'; '.join(coherence_flags[:2])}")
    if junk_reasons: parts.append(f"Format issues: {'; '.join(junk_reasons)}")
    if llm_explanation: parts.append(f"LLM assessment: {llm_explanation}")
    summary = " | ".join(parts) if parts else "Indication appears clinically appropriate"

    return IndicationQuality(
        quality_tier=quality_tier,
        quality_score=quality_score,
        is_junk=False,
        junk_reasons=junk_reasons,
        coherence_score=coherence_score,
        coherence_flags=coherence_flags,
        llm_indication_assessment=llm_explanation,
        llm_suggested_indication=llm_suggested,
        summary=summary,
    )

print("\u2713 Safety checks, information completeness, and indication quality assessment defined")
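
The junk and vagueness checks above match terms by plain substring (`term in text`), so a short abbreviation such as `ct` also matches inside `doctor`. The sketch below shows one possible refinement, not the notebook's own method: abbreviations of three characters or fewer must stand alone, while longer stems such as `hepat` only need to start a word. The names `compile_term_pattern` and `has_medical_term`, the `short_len` cutoff, and the shortened term list are all hypothetical and chosen for illustration.

```python
import re

def compile_term_pattern(terms, short_len=3):
    """Build one regex that matches any term at a word start.

    Terms with at most `short_len` characters must also end at a word
    boundary, so "ct" no longer matches inside "doctor".
    """
    parts = []
    for term in terms:
        escaped = re.escape(term)
        if len(term) <= short_len:
            # Abbreviation: no letter directly before or after it
            parts.append(rf"(?<![a-z]){escaped}(?![a-z])")
        else:
            # Stem: no letter directly before it; any suffix is allowed
            parts.append(rf"(?<![a-z]){escaped}")
    return re.compile("|".join(parts))

# Illustrative subset only; the real list lives in MEDICAL_TERM_SET
TERMS = ["ct", "hx", "r/o", "bile", "hepat", "cholecyst", "gallbladder"]
PATTERN = compile_term_pattern(TERMS)

def has_medical_term(text):
    """True if the lower-cased text contains a term under the rules above."""
    return PATTERN.search(text.lower()) is not None
```

Whether stricter matching is desirable depends on how often free-text indications contain these substrings by accident; it trades some recall on unusual spellings for fewer false matches.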

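The coherence score defined above starts at 1.0 and loses 0.3 per major flag and 0.15 per minor flag, floored at 0. The worked example below shows how that arithmetic maps onto quality tiers when no LLM assessment is supplied. `score_and_tier` is an illustrative helper that restates the thresholds from `assess_indication_quality` in simplified form; it is not part of the notebook.

```python
def score_and_tier(major, minor):
    """Restate the coherence arithmetic and tier thresholds (no LLM input)."""
    score = max(0.0, 1.0 - major * 0.3 - minor * 0.15)
    has_flags = (major + minor) > 0
    if score >= 0.8:
        # A high score with any flag present is still reported as QUESTIONABLE
        tier = "QUESTIONABLE" if has_flags else "GOOD"
    elif score >= 0.5:
        tier = "QUESTIONABLE"
    else:
        tier = "POOR"
    return round(score, 2), tier

# Flag combinations from clean to heavily flagged
for major, minor in [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]:
    print(f"major={major} minor={minor} -> {score_and_tier(major, minor)}")
```

With these weights a single minor flag keeps the score above 0.8 but the tier still drops to QUESTIONABLE, and it takes one major plus two minor flags to reach POOR.
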
## 3. Load MedGemma 27B Model

We load the instruction-tuned text variant of MedGemma 27B (`google/medgemma-27b-text-it`) with 4-bit NF4 quantization and double quantization. Quantized, the model needs roughly 16 GB of VRAM, so it fits on an L4 GPU (24 GB) or larger.
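
As a rough sanity check on the ~16 GB figure, the weight memory can be estimated from the parameter count. This is a back-of-the-envelope sketch under stated assumptions: a nominal 27 billion parameters, all of them quantized, and the block sizes described in the QLoRA paper (64 weights per quantization block, 256 constants per second-level block). The function name `estimate_nf4_weight_gb` is hypothetical.

```python
def estimate_nf4_weight_gb(n_params, block_size=64, double_quant=True):
    """Estimate weight memory in GB for 4-bit block-wise quantization.

    Assumes every parameter is quantized, which overstates the savings for
    layers that are kept in higher precision.
    """
    bits_per_param = 4.0
    if double_quant:
        # 8-bit constants per block, plus 32-bit constants per 256 blocks
        bits_per_param += 8 / block_size + 32 / (block_size * 256)
    else:
        # One 32-bit constant per block
        bits_per_param += 32 / block_size
    return n_params * bits_per_param / 8 / 1e9

print(f"With double quantization:    {estimate_nf4_weight_gb(27e9):.1f} GB")
print(f"Without double quantization: {estimate_nf4_weight_gb(27e9, double_quant=False):.1f} GB")
```

The estimate covers weights only. Layers left unquantized, activations, and the KV cache during generation add to it, which is consistent with budgeting about 16 GB in practice.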

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

MODEL_ID = "google/medgemma-27b-text-it"

# 4-bit NF4 quantization -- reduces 27B model to ~16 GB VRAM
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print(f"Loading MedGemma 27B model ({MODEL_ID})...")
print(f"  Note: Requires ~16 GB VRAM with 4-bit quantization (A100/L4 recommended)")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
)
processor = AutoTokenizer.from_pretrained(MODEL_ID)

print(f"\u2713 MedGemma 27B loaded!")
if torch.cuda.is_available():
    print(f"  GPU memory used: {torch.cuda.memory_allocated() / 1e9:.1f} GB")


## 4. Define RUQ Protocol Reviewer
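
The reviewer defined in this section has to pull a JSON object out of free-form model output. The sketch below is an alternative to manual brace counting, not the notebook's own parser: it lets the standard library's `json.JSONDecoder.raw_decode` attempt a parse at each `{` in turn, which also copes with braces that appear inside string values. The function name `extract_first_json_object` and the `required_key` parameter are illustrative.

```python
import json

def extract_first_json_object(text, required_key="variant"):
    """Return the first JSON object in `text` that has `required_key`, else None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at `start`
            # and ignores whatever follows it
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict) and required_key in obj:
            return obj
        # Not valid JSON here, or the wrong object: try the next '{'
        start = text.find("{", start + 1)
    return None

reply = (
    'Here is my answer: {"variant": "VARIANT_2", "confidence": 0.85, '
    '"reasoning": "biliary colic {classic}", '
    '"indication_assessment": {"quality": "GOOD"}} Hope this helps.'
)
parsed = extract_first_json_object(reply)
print(parsed["variant"], parsed["indication_assessment"]["quality"])
```

Because the decoder understands JSON strings, the `{classic}` inside the reasoning text does not disturb the match, and a stray `}` before the object is simply skipped.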

In [None]:
class RUQProtocolReviewer:
    """
    Main class for RUQ pain imaging order triage.
    Uses MedGemma to classify clinical variants and assess appropriateness.
    """

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.acr_criteria = ACR_RUQ_CRITERIA

    def review(self, triage_input: TriageInput) -> TriageOutput:
        """Main entry point for protocol review."""
        # Step 0: Check information completeness
        completeness = assess_information_completeness(
            triage_input.patient_context,
            clinical_indication=triage_input.clinical_indication,
        )
        is_sufficient, missing_fields, missing_count = completeness

        if not is_sufficient:
            if any("indication" in f.lower() for f in missing_fields):
                reasoning = (
                    "Clinical indication is too vague for reliable variant classification. "
                    "Request specific diagnosis, symptom, or clinical question from ordering provider."
                )
                assessment = "Cannot assess - clinical indication insufficient."
            else:
                reasoning = "Insufficient clinical information to classify variant."
                assessment = "Cannot assess - insufficient clinical information."
            return TriageOutput(
                priority=TriagePriority.PURPLE,
                confidence=0.0,
                classified_variant=RUQVariant.VARIANT_1,
                variant_confidence=0.0,
                variant_reasoning=reasoning,
                ordered_protocol_rating=None,
                ordered_protocol_category=None,
                appropriateness_assessment=assessment,
                insufficient_information=True,
                missing_fields=missing_fields,
            )

        # Step 0.5: Early junk detection (Tier 1)
        is_junk, junk_reasons = detect_junk_indication(triage_input.clinical_indication)
        if is_junk:
            junk_quality = IndicationQuality(
                quality_tier="JUNK", quality_score=0.0,
                is_junk=True, junk_reasons=junk_reasons,
                coherence_score=0.0,
                summary=f"Non-clinical indication text: {'; '.join(junk_reasons)}"
            )
            return TriageOutput(
                priority=TriagePriority.PURPLE,
                confidence=0.0,
                classified_variant=RUQVariant.VARIANT_1,
                variant_confidence=0.0,
                variant_reasoning="Indication is not valid clinical text — cannot classify variant.",
                ordered_protocol_rating=None,
                ordered_protocol_category=None,
                appropriateness_assessment="Cannot assess — indication is not valid clinical text.",
                insufficient_information=True,
                missing_fields=["Clinical indication (junk/non-clinical text detected)"],
                indication_quality=junk_quality,
            )

        # Step 1: Safety assessment
        safety_flags = assess_safety(triage_input.ordered_protocol, triage_input.patient_context)

        # Step 2: Classify variant using LLM (now includes indication assessment)
        variant_result = self._classify_variant(triage_input)

        # Enforce hard ACR decision tree rules
        variant_result = self._enforce_variant_rules(
            variant_result, triage_input.patient_context
        )

        # Step 3: Indication quality assessment (Tiers 1-3 combined)
        indication_quality = assess_indication_quality(
            triage_input.clinical_indication,
            triage_input.patient_context,
            llm_assessment=variant_result.get("indication_assessment"),
        )

        # Step 4: ACR lookup
        acr_assessment = self._assess_appropriateness(
            variant_result["variant"],
            triage_input.ordered_protocol.modality
        )

        # Step 5: Determine priority (quality-aware)
        priority, confidence = self._determine_priority(
            acr_assessment, safety_flags, variant_result["confidence"],
            indication_quality=indication_quality,
        )

        # Step 6: Generate recommendations if needed
        recommendations = []
        if priority != TriagePriority.GREEN:
            recommendations = self._generate_recommendations(
                variant_result["variant"],
                triage_input.ordered_protocol.modality
            )

        return TriageOutput(
            priority=priority,
            confidence=confidence,
            classified_variant=variant_result["variant"],
            variant_confidence=variant_result["confidence"],
            variant_reasoning=variant_result["reasoning"],
            ordered_protocol_rating=acr_assessment["rating"],
            ordered_protocol_category=acr_assessment["category"],
            appropriateness_assessment=acr_assessment["assessment"],
            recommended_protocols=recommendations,
            has_safety_concern=safety_flags.has_any_concern,
            contrast_contraindicated=safety_flags.contrast_contraindicated,
            contrast_warning=safety_flags.contrast_warning,
            renal_function_concern=safety_flags.renal_concern,
            renal_warning=safety_flags.renal_warning,
            pregnancy_concern=safety_flags.pregnancy_concern,
            pregnancy_warning=safety_flags.pregnancy_warning,
            coagulation_concern=safety_flags.coagulation_concern,
            coagulation_warning=safety_flags.coagulation_warning,
            mri_safety_concern=safety_flags.mri_safety_concern,
            mri_safety_warning=safety_flags.mri_safety_warning,
            metformin_warning=safety_flags.metformin_warning,
            safety_warnings=list(safety_flags.all_warnings),
            indication_quality=indication_quality,
            llm_reasoning=variant_result.get("llm_reasoning"),
        )


    def _enforce_variant_rules(self, variant_result: dict, ctx) -> dict:
        """
        Override LLM variant classification when it violates hard ACR rules.

        The ACR decision tree has deterministic gates that the LLM sometimes
        ignores (e.g., classifying as Variant 4 when no prior US exists).
        This method enforces those rules after the LLM call.
        """
        variant = variant_result["variant"]
        original = variant

        # RULE 1: No prior US → CANNOT be Variant 3, 4, or 5
        # These variants require a prior negative/equivocal US
        if not ctx.prior_us_performed:
            if variant in (RUQVariant.VARIANT_3, RUQVariant.VARIANT_4, RUQVariant.VARIANT_5):
                # Determine if biliary suspected from LLM reasoning or indication
                reasoning = variant_result.get("reasoning", "").lower()
                is_biliary = any(kw in reasoning for kw in [
                    "biliary", "cholecyst", "gallbladder", "gallstone",
                    "choledocho", "bile duct",
                ])
                # Downgrade to Variant 2 (suspected biliary) or 1 (unknown)
                if is_biliary or variant in (RUQVariant.VARIANT_3, RUQVariant.VARIANT_4):
                    variant_result["variant"] = RUQVariant.VARIANT_2
                else:
                    variant_result["variant"] = RUQVariant.VARIANT_1
                variant_result["reasoning"] = (
                    f"[Rule override: LLM classified as {original.name} but no prior US "
                    f"performed → downgraded to {variant_result['variant'].name}] "
                    + variant_result.get("reasoning", "")
                )
                # Lower confidence since we had to override
                variant_result["confidence"] = min(variant_result["confidence"], 0.75)

        # RULE 2: Variant 5 requires ICU/critically ill
        if variant_result["variant"] == RUQVariant.VARIANT_5:
            if not ctx.is_icu_patient and not ctx.is_critically_ill:
                variant_result["variant"] = RUQVariant.VARIANT_4 if ctx.prior_us_performed else RUQVariant.VARIANT_2
                variant_result["reasoning"] = (
                    f"[Rule override: LLM classified as VARIANT_5 but patient is not "
                    f"ICU/critically ill → reclassified to {variant_result['variant'].name}] "
                    + variant_result.get("reasoning", "")
                )
                variant_result["confidence"] = min(variant_result["confidence"], 0.75)

        # RULE 3: Prior US + fever/elevated WBC → Variant 4 (not 3)
        if ctx.prior_us_performed and variant_result["variant"] == RUQVariant.VARIANT_3:
            if ctx.has_fever or ctx.wbc_elevated:
                variant_result["variant"] = RUQVariant.VARIANT_4
                variant_result["reasoning"] = (
                    f"[Rule override: LLM classified as VARIANT_3 but patient has "
                    f"fever/elevated WBC → upgraded to VARIANT_4] "
                    + variant_result.get("reasoning", "")
                )

        # RULE 4: Prior US + afebrile + normal WBC → Variant 3 (not 4)
        if ctx.prior_us_performed and variant_result["variant"] == RUQVariant.VARIANT_4:
            if ctx.has_fever is False and ctx.wbc_elevated is False:
                variant_result["variant"] = RUQVariant.VARIANT_3
                variant_result["reasoning"] = (
                    f"[Rule override: LLM classified as VARIANT_4 but patient is "
                    f"afebrile with normal WBC → downgraded to VARIANT_3] "
                    + variant_result.get("reasoning", "")
                )

        # RULE 5: Prior US with POSITIVE result → Variant 2 (not 3/4)
        # Variants 3 and 4 assume a non-diagnostic (negative/equivocal) prior US.
        # If the US was positive (e.g., confirmed cholelithiasis), the diagnosis
        # is already established and the patient belongs in Variant 2 (initial
        # biliary workup), not the "further workup after non-diagnostic US" track.
        if ctx.prior_us_performed and ctx.prior_us_result == "positive":
            if variant_result["variant"] in (RUQVariant.VARIANT_3, RUQVariant.VARIANT_4):
                original_v = variant_result["variant"]
                variant_result["variant"] = RUQVariant.VARIANT_2
                variant_result["reasoning"] = (
                    f"[Rule override: LLM classified as {original_v.name} but prior US "
                    f"was POSITIVE (not negative/equivocal) → reclassified to VARIANT_2. "
                    f"Variants 3/4 require a non-diagnostic prior US.] "
                    + variant_result.get("reasoning", "")
                )
                variant_result["confidence"] = min(variant_result["confidence"], 0.80)

        return variant_result

    def _classify_variant(self, triage_input: TriageInput) -> dict:
        """Classify clinical variant using MedGemma."""
        prompt = self._build_classification_prompt(triage_input)

        messages = [
            {"role": "user", "content": prompt}
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.model.device)

        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = self.model.generate(
                **inputs,
                max_new_tokens=500,
                do_sample=False,
            )
            output_tokens = generation[0][input_len:]

        response = self.processor.decode(output_tokens, skip_special_tokens=True)
        return self._parse_variant_response(response)

    def _build_classification_prompt(self, triage_input: TriageInput) -> str:
        ctx = triage_input.patient_context

        prompt = f"""You are an expert radiologist. Classify this RUQ pain case into one of 5 ACR variants.

CLINICAL INDICATION: {triage_input.clinical_indication}

PATIENT:
- Age: {ctx.age or 'Unknown'}, Sex: {ctx.sex or 'Unknown'}
- Fever: {'Yes' if ctx.has_fever else 'No' if ctx.has_fever is False else 'Unknown'}
- WBC: {ctx.wbc or 'Unknown'} ({'Elevated' if ctx.wbc_elevated else 'Normal' if ctx.wbc_elevated is False else 'Unknown'})
- Prior US: {'Yes - ' + (ctx.prior_us_result or 'performed') if ctx.prior_us_performed else 'No'}
- ICU patient: {'Yes' if ctx.is_icu_patient else 'No' if ctx.is_icu_patient is False else 'Unknown'}
- Post-operative: {'Yes' if ctx.is_post_operative else 'No' if ctx.is_post_operative is False else 'Unknown'}
- On TPN: {'Yes' if ctx.on_tpn else 'No'}
- Septic: {'Yes' if ctx.is_septic else 'No'}

ACR VARIANTS:
VARIANT_1: Unknown etiology, initial imaging (no prior US, unclear cause)
VARIANT_2: Suspected biliary disease, initial imaging (biliary symptoms, NO prior US)
VARIANT_3: Suspected biliary, afebrile/normal WBC, negative/equivocal US
VARIANT_4: Suspected biliary, fever/elevated WBC, negative/equivocal US
VARIANT_5: Suspected acalculous cholecystitis (ICU/critically ill, negative US)

RULES:
- No prior US → Cannot be Variant 3, 4, or 5
- Prior US + fever/elevated WBC → Variant 4
- Prior US + afebrile/normal WBC → Variant 3
- ICU/critically ill + prior negative US → Variant 5

CRITICAL V1 vs V2 DISTINCTION (when no prior US):
- VARIANT_2 if indication mentions: gallbladder, gallstone, cholecystitis, biliary,
  bile duct, choledocholithiasis, cholangitis, Murphy's sign, postprandial RUQ pain,
  fatty food intolerance, jaundice, biliary colic, or gallstone history
- VARIANT_1 if indication is non-specific (abdominal pain, hepatitis workup,
  liver mass, pancreatitis, or pain without biliary localizing features)
- When uncertain between V1 and V2, default to VARIANT_1

Also assess whether the CLINICAL INDICATION above is accurate and consistent with the patient data.

Respond ONLY with JSON:
{{"variant": "VARIANT_X", "confidence": 0.XX, "reasoning": "brief explanation", "indication_assessment": {{"quality": "GOOD or QUESTIONABLE or POOR", "explanation": "1 sentence on whether indication matches clinical picture", "suggested_indication": "what the indication should say, or null if adequate"}}}}"""
        return prompt

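    def _parse_variant_response(self, response: str) -> dict:
        """Extract the variant JSON from the LLM reply; default to VARIANT_1 at 0.5 confidence if parsing fails."""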
        result = {
            "variant": RUQVariant.VARIANT_1,
            "confidence": 0.5,
            "reasoning": "Unable to parse response",
            "indication_assessment": None,
        }

        # Use balanced-brace approach for nested JSON
        parsed = None
        brace_depth = 0
        json_start = None
        for i, ch in enumerate(response):
            if ch == '{':
                if brace_depth == 0:
                    json_start = i
                brace_depth += 1
            elif ch == '}' and brace_depth > 0:  # ignore unmatched closing braces
                brace_depth -= 1
                if brace_depth == 0 and json_start is not None:
                    candidate = response[json_start:i + 1]
                    if '"variant"' in candidate:
                        try:
                            parsed = json.loads(candidate)
                            break
                        except json.JSONDecodeError:
                            json_start = None

        # Fallback: flat regex
        if parsed is None:
            json_match = re.search(r'\{[^{}]*"variant"[^{}]*\}', response, re.DOTALL)
            if json_match:
                try:
                    parsed = json.loads(json_match.group())
                except json.JSONDecodeError:
                    pass  # leave parsed as None and return the defaults below

        if parsed:
            # str() guards against a non-string value (e.g. a bare number) from the LLM
            variant_str = str(parsed.get("variant", "VARIANT_1")).strip().upper()
            variant_map = {
                "VARIANT_1": RUQVariant.VARIANT_1,
                "VARIANT_2": RUQVariant.VARIANT_2,
                "VARIANT_3": RUQVariant.VARIANT_3,
                "VARIANT_4": RUQVariant.VARIANT_4,
                "VARIANT_5": RUQVariant.VARIANT_5,
            }
            result["variant"] = variant_map.get(variant_str, RUQVariant.VARIANT_1)
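            try:
                confidence = float(parsed.get("confidence", 0.7))
            except (TypeError, ValueError):
                confidence = 0.7  # non-numeric confidence from the LLM; use the default
            result["confidence"] = min(max(confidence, 0.0), 1.0)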
            result["reasoning"] = parsed.get("reasoning", "")

            # Extract indication assessment (Tier 3)
            ia = parsed.get("indication_assessment")
            if isinstance(ia, dict):
                result["indication_assessment"] = {
                    "quality": ia.get("quality", "GOOD"),
                    "explanation": ia.get("explanation", ""),
                    "suggested_indication": ia.get("suggested_indication"),
                }

        return result

    def _assess_appropriateness(self, variant: RUQVariant, modality: Modality) -> dict:
        """Deterministic ACR lookup: rating, category, and assessment text for the ordered modality."""
        criteria = self.acr_criteria[variant]
        rating_obj = criteria.get_rating(modality)

        if rating_obj is None:
            return {
                "rating": None,
                "category": None,
                "assessment": f"Modality {modality.value} not rated for this variant. Manual review needed."
            }

        category = rating_obj.category
        rating = rating_obj.rating

        if category == AppropriatenessCategory.USUALLY_APPROPRIATE:
            assessment = f"Usually Appropriate (ACR {rating}/9) for {VARIANT_DESCRIPTIONS[variant]}."
        elif category == AppropriatenessCategory.MAY_BE_APPROPRIATE:
            assessment = f"May Be Appropriate (ACR {rating}/9). Consider higher-rated alternatives."
        else:
            appropriate = criteria.usually_appropriate
            alts = [r.modality.value for r in appropriate[:2]]
            assessment = f"Usually Not Appropriate (ACR {rating}/9). Recommended: {', '.join(alts)}."

        return {"rating": rating, "category": category, "assessment": assessment}

    def _determine_priority(self, acr_assessment, safety_flags, variant_confidence, indication_quality=None):
        # ANY safety concern → RED (safety always overrides ACR rating)
        if safety_flags.has_any_concern:
            return TriagePriority.RED, 0.90

        # Normal ACR-based logic (no safety concerns)
        rating = acr_assessment.get("rating")
        if rating is None:
            return TriagePriority.YELLOW, 0.5

        # Apply indication quality confidence modifier
        effective_confidence = variant_confidence
        if indication_quality is not None:
            if indication_quality.quality_tier == "POOR":
                effective_confidence = min(effective_confidence, 0.60)
            elif indication_quality.quality_tier == "QUESTIONABLE":
                effective_confidence *= 0.85

        if rating <= 3:
            return TriagePriority.RED, min(0.90, effective_confidence + 0.1)
        if rating <= 6:
            return TriagePriority.YELLOW, min(0.80, effective_confidence)

        # Low confidence → YELLOW even if ACR rating is high
        if effective_confidence < 0.7:
            return TriagePriority.YELLOW, min(0.75, effective_confidence)

        return TriagePriority.GREEN, min(effective_confidence, 0.95)

    def _generate_recommendations(self, variant: RUQVariant, ordered_modality: Modality):
        criteria = self.acr_criteria[variant]
        recommendations = []

        for i, rating_obj in enumerate(criteria.usually_appropriate):
            if rating_obj.modality != ordered_modality:
                recommendations.append(RecommendedProtocol(
                    modality=rating_obj.modality,
                    acr_rating=rating_obj.rating,
                    rationale=f"ACR Usually Appropriate for {VARIANT_DESCRIPTIONS[variant]}",
                    is_top_recommendation=(i == 0),
                ))
            if len(recommendations) >= 2:
                break

        return recommendations

# Initialize reviewer
reviewer = RUQProtocolReviewer(model=model, processor=processor)
print("✓ RUQ Protocol Reviewer initialized")
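
The priority rules in `_determine_priority` are easy to check in isolation. The sketch below restates them as a standalone function, using plain strings in place of `TriagePriority` and a bare `quality_tier` string in place of the indication-quality object. Both substitutions, and the name `sketch_priority`, are assumptions made so the example runs on its own.

```python
from typing import Optional, Tuple


def sketch_priority(
    rating: Optional[int],
    has_safety_concern: bool,
    variant_confidence: float,
    quality_tier: Optional[str] = None,
) -> Tuple[str, float]:
    """Standalone restatement of the triage rules (illustrative only)."""
    # Any safety concern overrides the ACR rating.
    if has_safety_concern:
        return "RED", 0.90
    # Unrated modality: route to manual review.
    if rating is None:
        return "YELLOW", 0.5

    # Indication quality lowers the effective confidence.
    confidence = variant_confidence
    if quality_tier == "POOR":
        confidence = min(confidence, 0.60)
    elif quality_tier == "QUESTIONABLE":
        confidence *= 0.85

    if rating <= 3:  # Usually Not Appropriate
        return "RED", min(0.90, confidence + 0.1)
    if rating <= 6:  # May Be Appropriate
        return "YELLOW", min(0.80, confidence)
    if confidence < 0.7:  # High rating, but the variant call is shaky
        return "YELLOW", min(0.75, confidence)
    return "GREEN", min(confidence, 0.95)


print(sketch_priority(8, False, 0.9))
print(sketch_priority(8, False, 0.8, "QUESTIONABLE"))
print(sketch_priority(8, True, 0.99))
```

The second call shows the modifier at work: 0.8 × 0.85 falls below the 0.7 threshold, so a highly rated study is still routed to YELLOW.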

## 5. EMR-to-PatientContext Mapping

Since the generated EMR JSONs already contain fully structured data (demographics, vitals, labs, allergies, problem lists), we **skip the LLM-based EMR extraction step** and build `PatientContext` via deterministic field mapping. This is faster, lossless, and more reliable than an additional LLM call.
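
As a minimal sketch of what deterministic field mapping means here, the example below flattens lab panels from a toy EMR dictionary into a name-to-value lookup. The sample record and the function name `flatten_labs` are invented for illustration; only the key names (`encounter_history`, `lab_results`, `results`, `test_name`, `value`) follow the generated EMR layout used in this notebook.

```python
def flatten_labs(emr_json: dict) -> dict:
    """Collect {lowercased test name: value} from the latest encounter."""
    encounters = emr_json.get("encounter_history", [])
    if not encounters:
        return {}
    lab_map = {}
    for panel in encounters[-1].get("lab_results", []):
        for result in panel.get("results", []):
            name = result.get("test_name", "").lower().strip()
            value = result.get("value")
            if value is not None:  # skip tests that were ordered but have no value
                lab_map[name] = value
    return lab_map


# Invented sample record (not taken from the dataset).
sample_emr = {
    "encounter_history": [
        {"lab_results": [{"results": [{"test_name": "WBC", "value": 9.1}]}]},
        {"lab_results": [
            {"results": [{"test_name": "WBC ", "value": 14.2},
                         {"test_name": "Lipase", "value": None}]},
        ]},
    ]
}

labs = flatten_labs(sample_emr)
print(labs)  # only the latest encounter contributes
```

No model call is involved: the same input always yields the same lookup, which is what makes the downstream threshold checks reproducible.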

In [None]:
# =============================================================================
# BUILD PATIENT CONTEXT FROM GENERATED EMR JSON (Expanded)
# =============================================================================
# Extracts ALL clinically relevant fields from the generated EMR JSON:
# - Demographics (age, sex, BMI)
# - Vitals (full set including RR, SpO2)
# - Labs (core + coagulation + additional)
# - Allergies (with severity, reaction, type details)
# - Medications (with anticoagulant, metformin, nephrotoxic flags)
# - Prior imaging (US, CT, MRI, HIDA — all modalities)
# - Pregnancy (from special_populations + problem_list)
# - Implants/devices (from surgical history)
# - Clinical state (ICU, post-op, TPN, sepsis)
# - ICD-10 codes from problem_list
# - Clinical notes for LLM context

# Keywords for medication classification
ANTICOAGULANT_KEYWORDS = [
    "warfarin", "coumadin", "heparin", "enoxaparin", "lovenox",
    "apixaban", "eliquis", "rivaroxaban", "xarelto", "dabigatran",
    "pradaxa", "edoxaban", "fondaparinux",
]
NEPHROTOXIC_KEYWORDS = [
    "toradol", "ketorolac", "ibuprofen", "naproxen", "meloxicam",
    "diclofenac", "indomethacin", "celecoxib",
    "gentamicin", "tobramycin", "vancomycin", "amphotericin",
    "cisplatin", "cyclosporine", "tacrolimus",
]
CARDIAC_DEVICE_KEYWORDS = [
    "pacemaker", "icd", "defibrillator", "aicd", "crt",
    "cardiac resynchronization", "loop recorder",
]
METALLIC_IMPLANT_KEYWORDS = [
    "stent", "joint replacement", "hip replacement", "knee replacement",
    "spinal hardware", "rod", "plate", "screw", "clip", "coil",
    "aaa repair", "aortic graft", "vascular graft",
]


def _classify_imaging_result(impression: str) -> str:
    """Classify imaging impression as negative/equivocal/positive for biliary disease."""
    imp_lower = (impression or "").lower()

    # Explicit negative markers
    NEGATIVE_KEYWORDS = [
        "negative", "no evidence", "normal", "unremarkable", "no acute",
        "within normal limits", "wnl", "no significant abnormality",
        "no abnormality", "no abnormalities",
    ]
    if any(kw in imp_lower for kw in NEGATIVE_KEYWORDS):
        return "negative"

    # Equivocal markers
    EQUIVOCAL_KEYWORDS = [
        "equivocal", "inconclusive", "indeterminate", "cannot exclude",
        "cannot rule out", "limited study", "suboptimal",
    ]
    if any(kw in imp_lower for kw in EQUIVOCAL_KEYWORDS):
        return "equivocal"

    # Biliary-specific: if no biliary pathology mentioned, treat as negative
    BILIARY_POSITIVE_MARKERS = [
        "cholelithiasis", "gallstone", "cholecystitis", "gallbladder wall",
        "pericholecystic", "bile duct", "choledocholithiasis", "dilat",
        "obstruct", "common duct", "cbd", "sludge", "stone",
        "pancreatitis", "mass", "abscess", "collection",
    ]
    if not any(kw in imp_lower for kw in BILIARY_POSITIVE_MARKERS):
        return "negative"

    return "positive"


def build_patient_context_from_emr(emr_json: dict) -> PatientContext:
    """Build PatientContext by deterministic mapping from generated EMR JSON."""
    ctx = PatientContext()

    # --- Demographics ---
    patient = emr_json.get("patient", {})
    ctx.patient_id = patient.get("mrn")
    ctx.age = patient.get("age")
    ctx.sex = patient.get("sex")

    # --- Current encounter (last in encounter_history) ---
    encounters = emr_json.get("encounter_history", [])
    if not encounters:
        return ctx
    current = encounters[-1]

    # --- Clinical setting ---
    parsed = emr_json.get("parsed_vignette", {})
    ctx.clinical_setting = parsed.get("clinical_setting", "").lower() or None

    # --- Vital signs (last set in current encounter) ---
    vital_sets = current.get("vital_signs", [])
    if vital_sets:
        latest = vital_sets[-1]
        ctx.temperature = latest.get("temperature_f")
        if ctx.temperature is not None:
            ctx.has_fever = ctx.temperature > 100.4
        ctx.heart_rate = latest.get("heart_rate")
        ctx.blood_pressure_systolic = latest.get("blood_pressure_systolic")
        ctx.blood_pressure_diastolic = latest.get("blood_pressure_diastolic")
        ctx.respiratory_rate = latest.get("respiratory_rate")
        ctx.spo2 = latest.get("spo2")
        # BMI from vitals if available
        bmi_val = latest.get("bmi")
        if bmi_val:
            ctx.bmi = bmi_val

    # --- Labs (flatten all panels in current encounter) ---
    lab_map = {}
    for panel in current.get("lab_results", []):
        for result in panel.get("results", []):
            test_name = result.get("test_name", "").lower().strip()
            val = result.get("value")
            if val is not None:
                lab_map[test_name] = val

    # Core labs
    ctx.wbc = lab_map.get("wbc")
    if ctx.wbc is not None:
        ctx.wbc_elevated = ctx.wbc > 11.0

    ctx.ast = lab_map.get("ast")
    ctx.alt = lab_map.get("alt")
    ctx.alp = lab_map.get("alp") or lab_map.get("alkaline phosphatase")
    ctx.bilirubin_total = (
        lab_map.get("bilirubin, total")
        or lab_map.get("total bilirubin")
        or lab_map.get("t.bili")
        or lab_map.get("bilirubin total")
    )
    ctx.bilirubin_direct = (
        lab_map.get("bilirubin, direct")
        or lab_map.get("direct bilirubin")
        or lab_map.get("d.bili")
        or lab_map.get("bilirubin direct")
    )
    ctx.lipase = lab_map.get("lipase")
    ctx.creatinine = lab_map.get("creatinine") or lab_map.get("cr")
    ctx.gfr = lab_map.get("gfr") or lab_map.get("egfr")

    # Coagulation labs
    ctx.platelets = lab_map.get("platelets") or lab_map.get("platelet count")
    ctx.inr = lab_map.get("inr")
    ctx.pt = lab_map.get("pt") or lab_map.get("prothrombin time")
    ctx.ptt = lab_map.get("ptt") or lab_map.get("aptt") or lab_map.get("partial thromboplastin time")

    # Additional labs
    ctx.hemoglobin = lab_map.get("hemoglobin") or lab_map.get("hgb")
    ctx.glucose = lab_map.get("glucose")
    ctx.tsh = lab_map.get("tsh")
    ctx.crp = lab_map.get("crp") or lab_map.get("c-reactive protein")
    ctx.d_dimer = lab_map.get("d-dimer")

    # --- Allergies (structured with severity/reaction) ---
    for allergy in emr_json.get("allergies", []):
        allergen = allergy.get("allergen", "")
        allergen_lower = allergen.lower()
        severity = allergy.get("severity")
        reaction = allergy.get("reaction")

        ctx.allergy_list.append(AllergyDetail(
            allergen=allergen,
            severity=severity,
            reaction=reaction,
            allergy_type=allergy.get("allergy_type"),
        ))

        if "iodinated" in allergen_lower or ("contrast" in allergen_lower and "gadolinium" not in allergen_lower):
            ctx.has_contrast_allergy = True
            ctx.iodinated_contrast_allergy = True
            ctx.contrast_allergy_severity = severity
            ctx.contrast_allergy_reaction = reaction
        if "gadolinium" in allergen_lower:
            ctx.has_contrast_allergy = True
            ctx.gadolinium_allergy = True
            if not ctx.contrast_allergy_severity:
                ctx.contrast_allergy_severity = severity
                ctx.contrast_allergy_reaction = reaction
        if "shellfish" in allergen_lower:
            ctx.shellfish_allergy = True

    # Also check parsed_vignette.safety_flags_mentioned
    for flag in parsed.get("safety_flags_mentioned", []):
        flag_lower = flag.lower()
        if "iodinated contrast allergy" in flag_lower:
            ctx.iodinated_contrast_allergy = True
            ctx.has_contrast_allergy = True
        if "gadolinium" in flag_lower and "allergy" in flag_lower:
            ctx.gadolinium_allergy = True
            ctx.has_contrast_allergy = True
        if "contrast allergy" in flag_lower and not ctx.has_contrast_allergy:
            ctx.has_contrast_allergy = True
        if "shellfish" in flag_lower:
            ctx.shellfish_allergy = True

    if ctx.has_contrast_allergy is None:
        ctx.has_contrast_allergy = False

    # --- Medical history (with ICD-10 codes) ---
    for entry in emr_json.get("problem_list", []):
        condition = entry.get("condition", "")
        icd10 = entry.get("icd10")
        status = entry.get("status")
        if status in ("Active", "Chronic", None):
            ctx.medical_history.append(condition)
            if icd10:
                ctx.medical_history_icd10.append((condition, icd10))

    # --- Surgical history ---
    ctx.surgical_history = [
        s.get("procedure", "") for s in emr_json.get("surgical_history", [])
    ]

    # --- Medications (with classification flags) ---
    meds_data = emr_json.get("current_medications", {})
    home_meds = meds_data.get("home_medications", []) if isinstance(meds_data, dict) else []
    inpatient_meds = meds_data.get("inpatient_medications", []) if isinstance(meds_data, dict) else []

    ctx.home_medications = [m.get("name", "") for m in home_meds]
    ctx.inpatient_medications = [m.get("name", "") for m in inpatient_meds]

    all_med_names = [n.lower() for n in ctx.home_medications + ctx.inpatient_medications]

    # Anticoagulant detection (keep the original-case name of the first match)
    for orig in ctx.home_medications + ctx.inpatient_medications:
        if any(kw in orig.lower() for kw in ANTICOAGULANT_KEYWORDS):
            ctx.on_anticoagulation = True
            ctx.anticoagulant_name = orig
            break

    # Metformin detection
    ctx.on_metformin = any("metformin" in n for n in all_med_names)

    # Nephrotoxic drug detection (collect each matching drug once)
    for orig in ctx.home_medications + ctx.inpatient_medications:
        if any(kw in orig.lower() for kw in NEPHROTOXIC_KEYWORDS):
            ctx.on_nephrotoxic_drugs = True
            if orig not in ctx.nephrotoxic_drug_names:
                ctx.nephrotoxic_drug_names.append(orig)

    # --- Pregnancy detection ---
    special_pops = parsed.get("special_populations", [])
    if any("pregnant" in sp.lower() for sp in special_pops):
        ctx.is_pregnant = True

    # Also check problem list for pregnancy ICD codes / conditions
    for cond, icd in ctx.medical_history_icd10:
        cond_lower = cond.lower()
        if "pregnancy" in cond_lower or (icd and icd.startswith("Z34")):
            ctx.is_pregnant = True
            if "first" in cond_lower:
                ctx.pregnancy_trimester = "first"
            elif "second" in cond_lower:
                ctx.pregnancy_trimester = "second"
            elif "third" in cond_lower:
                ctx.pregnancy_trimester = "third"

    # --- Implants & devices (from surgical history) ---
    for proc in ctx.surgical_history:
        proc_lower = proc.lower()
        if any(kw in proc_lower for kw in CARDIAC_DEVICE_KEYWORDS):
            ctx.has_cardiac_device = True
            ctx.device_details.append(proc)
        if any(kw in proc_lower for kw in METALLIC_IMPLANT_KEYWORDS):
            ctx.has_metallic_implants = True
            if proc not in ctx.device_details:
                ctx.device_details.append(proc)

    # --- Prior imaging (ALL modalities, not just US) ---
    for prior_enc in encounters[:-1]:
        enc_date = prior_enc.get("encounter", {}).get("admission_date")

        for report in prior_enc.get("imaging_reports", []):
            modality_str = report.get("modality", "").lower()
            impression = report.get("impression", "")
            findings = report.get("findings", "")
            result = _classify_imaging_result(impression)

            study = PriorImagingStudy(
                modality=report.get("modality", ""),
                date=enc_date,
                result=result,
                impression=impression,
                findings=findings,
            )
            ctx.prior_imaging_studies.append(study)

            # Classify by modality type
            # Only count abdominal/RUQ ultrasound as "prior US" for ACR variant gating.
            # Exclude non-abdominal US (DVT, carotid, thyroid, etc.) which are irrelevant.
            _NON_ABDOMINAL_US = [
                "lower extremity", "upper extremity", "leg", "arm",
                "dvt", "venous", "vascular", "carotid", "thyroid",
                "breast", "scrotal", "testicular", "pelvic", "transvaginal",
                "obstetric", "ob ", "fetal",
            ]
            if "us" in modality_str or "ultrasound" in modality_str:
                _is_abdominal = not any(kw in modality_str for kw in _NON_ABDOMINAL_US)
                if _is_abdominal:
                    ctx.prior_us_performed = True
                    ctx.prior_us_result = result
                    ctx.prior_us_findings = findings
            elif "ct" in modality_str:
                ctx.prior_ct_performed = True
                ctx.prior_ct_result = result
                ctx.prior_ct_findings = findings
            elif "mri" in modality_str or "mr " in modality_str:
                ctx.prior_mri_performed = True
                ctx.prior_mri_result = result
                ctx.prior_mri_findings = findings
            elif "hida" in modality_str or "hepatobiliary" in modality_str:
                ctx.prior_hida_performed = True
                ctx.prior_hida_result = result
                ctx.prior_hida_findings = findings

    # --- Clinical state ---
    clinical_setting = parsed.get("clinical_setting", "").lower()
    dept = current.get("encounter", {}).get("department", "").lower()

    if any(term in clinical_setting for term in ["icu", "sicu", "micu"]):
        ctx.is_icu_patient = True
        ctx.is_critically_ill = True
    elif "icu" in dept:
        ctx.is_icu_patient = True
        ctx.is_critically_ill = True

    # Scan history_conditions and clinical notes for post-op, TPN, sepsis
    for cond in parsed.get("history_conditions", []):
        cond_lower = cond.lower()
        if "s/p" in cond_lower or "post-op" in cond_lower or "pod" in cond_lower:
            ctx.is_post_operative = True
        if "tpn" in cond_lower:
            ctx.on_tpn = True
        if "septic" in cond_lower or "sepsis" in cond_lower:
            ctx.is_septic = True

    for note in current.get("clinical_notes", []):
        note_text = note.get("note_text", "").lower()
        if "post-operative" in note_text or "pod #" in note_text:
            ctx.is_post_operative = True
        if "tpn" in note_text or "parenteral nutrition" in note_text:
            ctx.on_tpn = True
        if "septic" in note_text or "sepsis" in note_text:
            ctx.is_septic = True
        if "mechanical ventilation" in note_text or "mechanically ventilated" in note_text:
            ctx.is_critically_ill = True

    # --- Clinical notes (HPI for variant classification context) ---
    for note in current.get("clinical_notes", []):
        if note.get("note_type") in ("HPI", "ED Note"):
            ctx.clinical_notes = note.get("note_text")
            break

    return ctx


print("\u2713 build_patient_context_from_emr() defined (expanded: coag, meds, all prior imaging, pregnancy, devices, ICD-10)")
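
Because `_classify_imaging_result` tests its keyword lists in a fixed order, the first matching list decides the label. The sketch below reproduces that precedence with shortened keyword lists (a subset of the ones above, chosen for brevity) so the behavior on mixed impressions can be inspected. The function name `classify_impression` and the sample impressions are invented for the demonstration.

```python
# Shortened keyword lists; the notebook's lists are longer.
NEGATIVE = ["no evidence", "unremarkable", "no acute"]
EQUIVOCAL = ["equivocal", "cannot exclude", "limited study"]
POSITIVE = ["cholelithiasis", "cholecystitis", "gallstone", "sludge"]


def classify_impression(impression: str) -> str:
    """Return 'negative', 'equivocal', or 'positive' using first-match-wins order."""
    text = (impression or "").lower()
    if any(kw in text for kw in NEGATIVE):
        return "negative"
    if any(kw in text for kw in EQUIVOCAL):
        return "equivocal"
    # No biliary finding mentioned at all is treated as negative.
    if not any(kw in text for kw in POSITIVE):
        return "negative"
    return "positive"


examples = [
    "Cholelithiasis with gallbladder sludge.",
    "Limited study; cannot exclude cholecystitis.",
    "Cholelithiasis. No acute cholecystitis.",
    "",
]
for imp in examples:
    print(f"{classify_impression(imp):9s} <- {imp!r}")
```

The third impression is labeled negative because "no acute" matches before any positive marker is considered, even though stones are reported. Impressions of that shape may deserve a manual spot check.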

In [None]:
# =============================================================================
# PARSE PROTOCOL STRING TO MODALITY ENUM
# =============================================================================

def parse_protocol_string_to_modality(protocol_str: str):
    """
    Map free-text protocol string from dataset to Modality enum.
    Returns None for protocols not in ACR criteria (ERCP, PET/CT, etc.)

    Ordering matters: check more specific strings before generic ones.
    """
    if not protocol_str:
        return None

    s = protocol_str.lower().strip()

    # --- Normalize common abbreviations ---
    # "w/o" = "without", "w/" = "with" (must replace w/o BEFORE w/)
    s = s.replace("w/o ", "without ").replace(" w/o", " without")
    s = s.replace("w/ ", "with ").replace(" w/", " with")

    # --- Handle compound/sequential protocols ---
    # Rate the primary (first) modality only
    SEQUENTIAL_SEPARATORS = [" then ", " followed by ", " and then ", " -> ", " prior to "]
    for sep in SEQUENTIAL_SEPARATORS:
        if sep in s:
            primary = s.split(sep)[0].strip()
            return parse_protocol_string_to_modality(primary)

    # --- Non-ACR procedures -> None ---
    if "ercp" in s:
        return None
    if "pet" in s:
        return None
    if "laparoscopy" in s or "laparoscopic" in s:
        return None
    if "eus" in s or "endoscopic ultrasound" in s:
        return None
    if "endoscopy" in s:
        return None

    # --- Image-guided cholecystostomy (has its own Modality entry) ---
    if "cholecystostomy" in s:
        return Modality.IMAGE_GUIDED_CHOLECYSTOSTOMY

    # --- HIDA ---
    if "hida" in s or "hepatobiliary" in s or "cholescintigraphy" in s:
        return Modality.NUCLEAR_HIDA

    # --- MRCP (check before MRI) ---
    if "mrcp" in s:
        if "gadolinium" in s or "with contrast" in s or "with iv" in s:
            return Modality.MRI_ABDOMEN_WITHOUT_AND_WITH_CONTRAST_WITH_MRCP
        return Modality.MRI_ABDOMEN_WITHOUT_CONTRAST_WITH_MRCP

    # --- MRI ---
    if "mri" in s or "mr " in s or "multiphase liver" in s:
        if "without and with" in s or "with and without" in s:
            return Modality.MRI_ABDOMEN_WITHOUT_AND_WITH_CONTRAST_WITH_MRCP
        if "without contrast" in s or "without iv" in s or "non-contrast" in s:
            return Modality.MRI_ABDOMEN_WITHOUT_CONTRAST
        if "with contrast" in s or "with iv" in s or "with gadolinium" in s:
            return Modality.MRI_ABDOMEN_WITHOUT_AND_WITH_CONTRAST_WITH_MRCP
        # Default MRI for RUQ = MRCP
        return Modality.MRI_ABDOMEN_WITHOUT_CONTRAST_WITH_MRCP

    # --- CT ---
    if "ct" in s:
        if "without and with" in s or "with and without" in s:
            return Modality.CT_ABDOMEN_WITHOUT_AND_WITH_CONTRAST
        if "without contrast" in s or "without iv" in s or "non-contrast" in s:
            return Modality.CT_ABDOMEN_WITHOUT_CONTRAST
        if "cholangiography" in s:
            return Modality.CT_ABDOMEN_WITH_CONTRAST
        # Default CT = with contrast
        return Modality.CT_ABDOMEN_WITH_CONTRAST

    # --- Ultrasound ---
    if "ultrasound" in s or "us " in s or s.startswith("us") or "ruq u" in s or s == "us":
        return Modality.US_ABDOMEN

    # --- X-ray ---
    if "x-ray" in s or "xray" in s or "radiograph" in s or "kub" in s:
        return Modality.RADIOGRAPHY_ABDOMEN

    return None  # Unknown


# Quick validation: test a few protocol strings
_test_cases = [
    ("RUQ ultrasound", Modality.US_ABDOMEN),
    ("CT abdomen with IV contrast", Modality.CT_ABDOMEN_WITH_CONTRAST),
    ("MRCP", Modality.MRI_ABDOMEN_WITHOUT_CONTRAST_WITH_MRCP),
    ("ERCP", None),
    ("HIDA scan", Modality.NUCLEAR_HIDA),
    ("MRI abdomen with contrast", Modality.MRI_ABDOMEN_WITHOUT_AND_WITH_CONTRAST_WITH_MRCP),
    ("CT abdomen without contrast", Modality.CT_ABDOMEN_WITHOUT_CONTRAST),
    ("CT abdomen w/o contrast", Modality.CT_ABDOMEN_WITHOUT_CONTRAST),
    ("CT abdomen w/ contrast", Modality.CT_ABDOMEN_WITH_CONTRAST),
    ("PET/CT", None),
]
_passed = 0
for proto_str, expected in _test_cases:
    result = parse_protocol_string_to_modality(proto_str)
    status = "\u2713" if result == expected else f"\u2717 got {result}"
    if result == expected:
        _passed += 1
    print(f"  {status}  '{proto_str}' \u2192 {result}")

print(f"\n\u2713 parse_protocol_string_to_modality() defined ({_passed}/{len(_test_cases)} tests passed)")
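
The comment about replacing `w/o` before `w/` can be demonstrated directly. The helper names below are hypothetical; the first applies the same replacements as the parser, and the second deliberately reverses them to show the failure mode.

```python
def normalize_abbreviations(s: str) -> str:
    """Expand 'w/o' first, then 'w/' (same order as the parser above)."""
    s = s.lower().strip()
    s = s.replace("w/o ", "without ").replace(" w/o", " without")
    s = s.replace("w/ ", "with ").replace(" w/", " with")
    return s


def normalize_wrong_order(s: str) -> str:
    """Expand 'w/' first: the 'w/' inside 'w/o' is consumed too early."""
    s = s.lower().strip()
    s = s.replace("w/ ", "with ").replace(" w/", " with")
    s = s.replace("w/o ", "without ").replace(" w/o", " without")
    return s


print(normalize_abbreviations("CT abdomen w/o contrast"))
print(normalize_abbreviations("CT abdomen w/ contrast"))
print(normalize_wrong_order("CT abdomen w/o contrast"))
```

With the order reversed, the last call yields `witho contrast`, which no longer contains `without contrast` and would fall through to the parser's default of CT with contrast.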

## 6. Load Dataset and Generated EMRs

In [None]:
import pandas as pd
import os
from pathlib import Path

# Mount Google Drive (Colab) and set project root
import subprocess
try:
    from google.colab import drive
    drive.mount("/content/drive")
    # Auto-find project root on Drive
    result = subprocess.run(['find', '/content/drive', '-maxdepth', '5', '-name', 'generated_emrs_27b', '-type', 'd'],
                           capture_output=True, text=True, timeout=30)
    _found = [p.replace('/data/generated_emrs_27b', '') for p in result.stdout.strip().split('\n') if p]
    BASE_DIR = _found[0] if _found else ".."
    print(f"Project root: {BASE_DIR}")
except ImportError:
    BASE_DIR = ".."  # Local development

# Load extended dataset (with verification columns)
DATASET_PATH = os.path.join(BASE_DIR, "ruq_pain_dataset_2 (1).xlsx")
if not os.path.exists(DATASET_PATH):
    # Fallback to original dataset in data/
    DATASET_PATH = os.path.join(BASE_DIR, "data", "ruq_pain_dataset_2.xlsx")
    if not os.path.exists(DATASET_PATH):
        raise FileNotFoundError(
            f"Dataset not found. Searched:\n"
            f"  1. {os.path.join(BASE_DIR, 'ruq_pain_dataset_2 (1).xlsx')}\n"
            f"  2. {DATASET_PATH}\n"
            f"Check your Google Drive sync path."
        )
    print(f"NOTE: Using original dataset (fewer columns) at {DATASET_PATH}")

df = pd.read_excel(DATASET_PATH)
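# Rename positionally; tolerate datasets with fewer or more columns than expected
_expected_cols = [
    "case_num", "setting", "pattern_type", "actual_diagnosis",
    "patient_hx", "provider_indication", "protocol_ordered",
    "ai_recommended_protocol", "issue_notes",
]
_n_named = min(len(_expected_cols), len(df.columns))
df.columns = _expected_cols[:_n_named] + list(df.columns[_n_named:])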

# Load generated EMR JSONs
EMR_DIR = os.path.join(BASE_DIR, "data", "generated_emrs_27b")
emr_data = {}
if os.path.exists(EMR_DIR):
    for fname in sorted(os.listdir(EMR_DIR)):
        if fname.startswith("RUQ-") and fname.endswith(".json"):
            case_num = int(fname.replace("RUQ-", "").replace(".json", ""))
            with open(os.path.join(EMR_DIR, fname)) as f:
                emr_data[case_num] = json.load(f)

print(f"Dataset: {len(df)} rows, {len(df.columns)} columns")
print(f"Generated EMRs: {len(emr_data)}/{len(df)} loaded")

if emr_data:
    # Compare against the dataset's actual case numbers rather than assuming 1..N
    missing = set(df["case_num"].astype(int)) - set(emr_data.keys())
    if missing:
        print(f"Missing EMRs for cases: {sorted(missing)[:10]}{'...' if len(missing) > 10 else ''}")
    else:
        print("All cases have generated EMRs \u2713")
else:
    print(f"WARNING: No EMR files found at {EMR_DIR}")
    print("Run notebook 04 first to generate EMRs.")

# Show pattern type distribution
if "pattern_type" in df.columns:
    print(f"\nPattern Type Distribution:")
    for pt, count in df["pattern_type"].value_counts().items():
        print(f"  {count:2d} \u00d7 {pt}")
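
The filename convention (`RUQ-<case number>.json`) drives the loader. As a hedged sketch, the function below does the same job with `pathlib` and a regular expression, and skips files whose names do not match instead of raising. The function name `load_emrs` and the temporary directory are illustrative, not part of the notebook.

```python
import json
import re
import tempfile
from pathlib import Path

EMR_NAME = re.compile(r"^RUQ-(\d+)\.json$")


def load_emrs(emr_dir: Path) -> dict:
    """Return {case_num: parsed JSON} for files named RUQ-<digits>.json."""
    loaded = {}
    for path in sorted(emr_dir.glob("RUQ-*.json")):
        match = EMR_NAME.match(path.name)
        if match is None:
            continue  # e.g. 'RUQ-draft.json' is ignored rather than crashing
        with path.open(encoding="utf-8") as f:
            loaded[int(match.group(1))] = json.load(f)
    return loaded


# Demonstrate on throwaway files.
with tempfile.TemporaryDirectory() as tmp:
    tmp_dir = Path(tmp)
    (tmp_dir / "RUQ-001.json").write_text('{"patient": {"mrn": "A1"}}', encoding="utf-8")
    (tmp_dir / "RUQ-draft.json").write_text("{}", encoding="utf-8")
    demo_emrs = load_emrs(tmp_dir)

print(sorted(demo_emrs))
```

Zero-padded names such as `RUQ-001.json` map to the integer case number 1, which matches how `case_num` is compared in the cells that follow.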


In [None]:
# =============================================================================
# SINGLE CASE DEMO
# =============================================================================

# Pick the first available case
demo_case_num = sorted(emr_data.keys())[0] if emr_data else None

if demo_case_num is not None:
    demo_emr = emr_data[demo_case_num]
    demo_row = df[df["case_num"] == demo_case_num].iloc[0]

    print(f"DEMO: Case {demo_case_num}")
    print(f"  Setting: {demo_row['setting']}")
    print(f"  Actual Diagnosis: {demo_row['actual_diagnosis']}")
    print(f"  Protocol Ordered: {demo_row['protocol_ordered']}")
    print(f"  Provider Indication: {demo_row['provider_indication']}")
    print()

    # Step 1: Build PatientContext from EMR JSON (deterministic)
    demo_ctx = build_patient_context_from_emr(demo_emr)
    print("=" * 50)
    print("EXTRACTED PATIENT CONTEXT")
    print("=" * 50)

    print("\n--- Demographics ---")
    print(f"  Age: {demo_ctx.age}, Sex: {demo_ctx.sex}, BMI: {demo_ctx.bmi}")

    print("\n--- Vitals ---")
    print(f"  Temp: {demo_ctx.temperature}\u00b0F (fever: {demo_ctx.has_fever})")
    print(f"  HR: {demo_ctx.heart_rate}, BP: {demo_ctx.blood_pressure_systolic}/{demo_ctx.blood_pressure_diastolic}")
    print(f"  RR: {demo_ctx.respiratory_rate}, SpO2: {demo_ctx.spo2}")

    print("\n--- Core Labs ---")
    print(f"  WBC: {demo_ctx.wbc} (elevated: {demo_ctx.wbc_elevated})")
    print(f"  AST: {demo_ctx.ast}, ALT: {demo_ctx.alt}, ALP: {demo_ctx.alp}")
    print(f"  T.bili: {demo_ctx.bilirubin_total}, D.bili: {demo_ctx.bilirubin_direct}")
    print(f"  Lipase: {demo_ctx.lipase}, Cr: {demo_ctx.creatinine}, GFR: {demo_ctx.gfr}")

    print("\n--- Coagulation Labs ---")
    print(f"  Platelets: {demo_ctx.platelets}, INR: {demo_ctx.inr}, PT: {demo_ctx.pt}, PTT: {demo_ctx.ptt}")

    print("\n--- Additional Labs ---")
    print(f"  Hgb: {demo_ctx.hemoglobin}, Glucose: {demo_ctx.glucose}, CRP: {demo_ctx.crp}")

    print("\n--- Allergies ---")
    print(f"  Contrast allergy: {demo_ctx.has_contrast_allergy}")
    print(f"  Iodinated: {demo_ctx.iodinated_contrast_allergy}, Gadolinium: {demo_ctx.gadolinium_allergy}")
    print(f"  Shellfish: {demo_ctx.shellfish_allergy}")
    if demo_ctx.allergy_list:
        for a in demo_ctx.allergy_list:
            print(f"    - {a.allergen} | severity={a.severity} | reaction={a.reaction}")

    print("\n--- Medications ---")
    print(f"  Home: {demo_ctx.home_medications}")
    print(f"  Inpatient: {demo_ctx.inpatient_medications}")
    print(f"  Anticoagulation: {demo_ctx.on_anticoagulation} ({demo_ctx.anticoagulant_name})")
    print(f"  Metformin: {demo_ctx.on_metformin}")
    print(f"  Nephrotoxic drugs: {demo_ctx.on_nephrotoxic_drugs} ({demo_ctx.nephrotoxic_drug_names})")

    print("\n--- Prior Imaging ---")
    print(f"  US: {demo_ctx.prior_us_performed} ({demo_ctx.prior_us_result})")
    print(f"  CT: {demo_ctx.prior_ct_performed} ({demo_ctx.prior_ct_result})")
    print(f"  MRI: {demo_ctx.prior_mri_performed} ({demo_ctx.prior_mri_result})")
    print(f"  HIDA: {demo_ctx.prior_hida_performed} ({demo_ctx.prior_hida_result})")
    if demo_ctx.prior_imaging_studies:
        for s in demo_ctx.prior_imaging_studies:
            print(f"    - {s.modality} ({s.date}) \u2192 {s.result}")

    print("\n--- Pregnancy ---")
    print(f"  Pregnant: {demo_ctx.is_pregnant}, Trimester: {demo_ctx.pregnancy_trimester}")

    print("\n--- Implants/Devices ---")
    print(f"  Cardiac device: {demo_ctx.has_cardiac_device}")
    print(f"  Metallic implants: {demo_ctx.has_metallic_implants}")
    if demo_ctx.device_details:
        print(f"  Details: {demo_ctx.device_details}")

    print("\n--- Clinical State ---")
    print(f"  Setting: {demo_ctx.clinical_setting}")
    print(f"  ICU: {demo_ctx.is_icu_patient}, Critically ill: {demo_ctx.is_critically_ill}")
    print(f"  Post-op: {demo_ctx.is_post_operative}, Septic: {demo_ctx.is_septic}, TPN: {demo_ctx.on_tpn}")

    print("\n--- Medical History (with ICD-10) ---")
    for cond, icd in demo_ctx.medical_history_icd10[:10]:
        print(f"  {cond} [{icd}]")

    print("\n--- Surgical History ---")
    for proc in demo_ctx.surgical_history:
        print(f"  {proc}")
    print()

    # Step 2: Parse protocol string
    demo_modality = parse_protocol_string_to_modality(demo_row["protocol_ordered"])
    print(f"Protocol mapping: '{demo_row['protocol_ordered']}' \u2192 {demo_modality}")

    if demo_modality is not None:
        # Step 3: Build OrderedProtocol with metadata from EMR
        current_enc = demo_emr.get("encounter_history", [{}])[-1]
        imaging_orders = current_enc.get("imaging_orders", [])
        order_provider = imaging_orders[0].get("ordering_provider") if imaging_orders else None
        order_urgency = imaging_orders[0].get("urgency") if imaging_orders else None

        demo_ordered = OrderedProtocol(
            modality=demo_modality,
            ordering_provider=order_provider,
            urgency=order_urgency,
        )
        demo_input = TriageInput(
            ordered_protocol=demo_ordered,
            clinical_indication=demo_row["provider_indication"],
            patient_context=demo_ctx,
        )

        print(f"\nOrdering provider: {order_provider}")
        print(f"Urgency: {order_urgency}")
        print("\nRunning triage (MedGemma variant classification)...")
        demo_result = reviewer.review(demo_input)

        print("\n" + demo_result.to_radiologist_summary())

        # Show verification columns
        if "ai_recommended_protocol" in demo_row.index:
            print(f"\n--- Dataset verification ---")
            print(f"  Pattern Type: {demo_row.get('pattern_type', 'N/A')}")
            print(f"  AI Recommended: {demo_row.get('ai_recommended_protocol', 'N/A')}")
            print(f"  Issue/Notes: {demo_row.get('issue_notes', 'N/A')}")
    else:
        print(f"\n\u26a0 Protocol '{demo_row['protocol_ordered']}' not in ACR criteria")
else:
    print("No EMR data loaded.")

## 7. Batch Triage — All Cases

Run triage on all available generated EMRs. Cases without EMR files are skipped. Cases whose protocol is not in the ACR criteria (ERCP, PET/CT, etc.) bypass the model and are recorded as RED with `modality_unmapped=True` and an explanation in `appropriateness`.
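
The control flow described above reduces to three branches per row: skip, record as unmapped, or run full triage. The sketch below shows only that branching; `run_batch`, `parse_fn`, `triage_fn`, and `unmapped_priority` are hypothetical names standing in for the notebook's loop, parser, reviewer, and chosen priority label.

```python
from typing import Callable, Iterable, Optional


def run_batch(
    rows: Iterable[dict],
    emr_data: dict,
    parse_fn: Callable[[str], Optional[str]],
    triage_fn: Callable[[dict, str], str],
    unmapped_priority: str,
) -> list:
    """Apply the skip / unmapped / triage branching to each dataset row."""
    results = []
    for row in rows:
        case_num = int(row["case_num"])
        if case_num not in emr_data:
            continue  # no generated EMR: skip the case entirely
        modality = parse_fn(row["protocol_ordered"])
        if modality is None:
            # Protocol outside the ACR criteria: record it without calling the model.
            results.append({"case_num": case_num,
                            "triage_priority": unmapped_priority,
                            "modality_unmapped": True})
            continue
        results.append({"case_num": case_num,
                        "triage_priority": triage_fn(emr_data[case_num], modality),
                        "modality_unmapped": False})
    return results


# Toy stand-ins, invented for the demonstration.
rows = [
    {"case_num": 1, "protocol_ordered": "RUQ ultrasound"},
    {"case_num": 2, "protocol_ordered": "ERCP"},
    {"case_num": 3, "protocol_ordered": "RUQ ultrasound"},  # no EMR on disk
]
emrs = {1: {}, 2: {}}
out = run_batch(
    rows, emrs,
    parse_fn=lambda s: None if "ercp" in s.lower() else "US_ABDOMEN",
    triage_fn=lambda emr, modality: "GREEN",
    unmapped_priority="REVIEW",
)
print(out)
```

Keeping the unmapped branch separate means those cases never consume a model call, and the `modality_unmapped` flag lets them be filtered out of accuracy metrics later.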

In [None]:
# =============================================================================
# BATCH TRIAGE (with expanded safety checks)
# =============================================================================
import time

results = []
errors = []
start_time = time.time()

for idx, row in df.iterrows():
    case_num = int(row["case_num"])

    if case_num not in emr_data:
        continue

    try:
        emr_json = emr_data[case_num]

        # Step 1: Build PatientContext from EMR JSON (deterministic, no LLM)
        patient_ctx = build_patient_context_from_emr(emr_json)

        # Step 2: Parse protocol string to Modality
        modality = parse_protocol_string_to_modality(row["protocol_ordered"])

        if modality is None:
            # Protocol not in ACR criteria (ERCP, PET/CT, etc.)
            results.append({
                "case_num": case_num,
                "setting": row.get("setting"),
                "pattern_type": row.get("pattern_type"),
                "actual_diagnosis": row.get("actual_diagnosis"),
                "provider_indication": row.get("provider_indication"),
                "protocol_ordered": row["protocol_ordered"],
                "ai_recommended": row.get("ai_recommended_protocol"),
                "issue_notes": row.get("issue_notes"),
                "triage_priority": "RED",
                "triage_confidence": 0.0,
                "classified_variant": "N/A",
                "variant_confidence": 0.0,
                "variant_reasoning": f"Protocol '{row['protocol_ordered']}' not in ACR RUQ criteria",
                "acr_rating": None,
                "appropriateness": "Protocol not in ACR Appropriateness Criteria for RUQ Pain. Consider ACR-rated alternative.",
                "has_safety_concern": False,
                "safety_warnings": [],
                "contrast_warning": None,
                "renal_warning": None,
                "pregnancy_warning": None,
                "coagulation_warning": None,
                "mri_safety_warning": None,
                "metformin_warning": None,
                "recommendations": [],
                "modality_unmapped": True,
                "insufficient_information": False,
                # Expanded fields for analysis
                "on_anticoagulation": patient_ctx.on_anticoagulation,
                "on_metformin": patient_ctx.on_metformin,
                "is_pregnant": patient_ctx.is_pregnant,
                "prior_us": patient_ctx.prior_us_performed,
                "prior_ct": patient_ctx.prior_ct_performed,
                "prior_mri": patient_ctx.prior_mri_performed,
                "gfr": patient_ctx.gfr,
                "creatinine": patient_ctx.creatinine,
                "inr": patient_ctx.inr,
                "platelets": patient_ctx.platelets,
            })
            print(f"  Case {case_num}: RED (protocol not in ACR: {row['protocol_ordered']})")
            continue

        # Step 3: Build OrderedProtocol with metadata from EMR
        # `or [{}]` also covers an encounter_history that is present but empty
        current_enc = (emr_json.get("encounter_history") or [{}])[-1]
        imaging_orders = current_enc.get("imaging_orders", [])
        order_provider = imaging_orders[0].get("ordering_provider") if imaging_orders else None
        order_urgency = imaging_orders[0].get("urgency") if imaging_orders else None

        ordered = OrderedProtocol(
            modality=modality,
            ordering_provider=order_provider,
            urgency=order_urgency,
        )
        triage_input = TriageInput(
            ordered_protocol=ordered,
            clinical_indication=row["provider_indication"],
            patient_context=patient_ctx,
        )

        # Step 4: Run triage (single MedGemma call for variant classification)
        # Safety checks (contrast, renal, pregnancy, coag, MRI, metformin) are
        # now handled inside reviewer.review() via expanded assess_safety()
        result = reviewer.review(triage_input)
        triage_priority = result.priority.name

        results.append({
            "case_num": case_num,
            "setting": row.get("setting"),
            "pattern_type": row.get("pattern_type"),
            "actual_diagnosis": row.get("actual_diagnosis"),
            "provider_indication": row.get("provider_indication"),
            "protocol_ordered": row["protocol_ordered"],
            "ai_recommended": row.get("ai_recommended_protocol"),
            "issue_notes": row.get("issue_notes"),
            "triage_priority": triage_priority,
            "triage_confidence": result.confidence,
            "classified_variant": result.classified_variant.name,
            "variant_confidence": result.variant_confidence,
            "variant_reasoning": result.variant_reasoning,
            "acr_rating": result.ordered_protocol_rating,
            "appropriateness": result.appropriateness_assessment,
            "has_safety_concern": result.has_safety_concern,
            "safety_warnings": result.safety_warnings,
            "contrast_warning": result.contrast_warning,
            "renal_warning": result.renal_warning,
            "pregnancy_warning": result.pregnancy_warning,
            "coagulation_warning": result.coagulation_warning,
            "mri_safety_warning": result.mri_safety_warning,
            "metformin_warning": result.metformin_warning,
            "recommendations": [r.modality.value for r in result.recommended_protocols],
            "modality_unmapped": False,
            "insufficient_information": result.insufficient_information,
            # Indication quality fields
            "indication_quality_tier": result.indication_quality.quality_tier if result.indication_quality else "N/A",
            "indication_quality_score": result.indication_quality.quality_score if result.indication_quality else None,
            "indication_coherence_flags": "; ".join(result.indication_quality.coherence_flags) if result.indication_quality and result.indication_quality.coherence_flags else "",
            "indication_quality_summary": result.indication_quality.summary if result.indication_quality else "",
            # Expanded fields for analysis
            "on_anticoagulation": patient_ctx.on_anticoagulation,
            "on_metformin": patient_ctx.on_metformin,
            "is_pregnant": patient_ctx.is_pregnant,
            "prior_us": patient_ctx.prior_us_performed,
            "prior_ct": patient_ctx.prior_ct_performed,
            "prior_mri": patient_ctx.prior_mri_performed,
            "gfr": patient_ctx.gfr,
            "creatinine": patient_ctx.creatinine,
            "inr": patient_ctx.inr,
            "platelets": patient_ctx.platelets,
        })

        emoji = {"GREEN": "\U0001f7e2", "YELLOW": "\U0001f7e1", "RED": "\U0001f534", "PURPLE": "\U0001f7e3"}.get(triage_priority, "?")
        rating_str = f"{result.ordered_protocol_rating}/9" if result.ordered_protocol_rating else "N/A"
        safety_str = f" \u26a0\ufe0f ({len(result.safety_warnings)} warnings)" if result.has_safety_concern else ""
        print(f"  Case {case_num}: {emoji} {triage_priority} (ACR {rating_str}, {result.classified_variant.name}){safety_str}")

    except Exception as e:
        import traceback
        errors.append({"case_num": case_num, "error": str(e), "traceback": traceback.format_exc()})
        print(f"  Case {case_num}: ERROR \u2014 {e}")

    # Progress update every 10 cases
    if len(results) % 10 == 0 and len(results) > 0:
        elapsed = time.time() - start_time
        rate = len(results) / elapsed
        remaining = (len(emr_data) - len(results) - len(errors)) / max(rate, 0.01)
        print(f"  --- {len(results)} cases done ({elapsed:.0f}s, {rate:.1f}/sec, ~{remaining:.0f}s remaining) ---")

results_df = pd.DataFrame(results)
elapsed = time.time() - start_time
print(f"\n{'='*60}")
print(f"Completed: {len(results)} cases in {elapsed:.0f}s ({len(errors)} errors)")
if errors:
    print(f"\nErrors:")
    for e in errors:
        print(f"  Case {e['case_num']}: {e['error']}")

## 8. Verification Against Dataset

Compare triage results to the 6 Pattern Types and verification columns from the dataset.

In [None]:
# =============================================================================
# VERIFICATION BY PATTERN TYPE
# =============================================================================

PATTERN_EXPECTATIONS = {
    "Correct diagnosis // Incorrect protocol": {
        "expected_priorities": ["RED", "YELLOW"],
        "metric_name": "Incorrect protocol flagged",
        "description": "Protocol is wrong \u2014 system should flag RED or YELLOW",
    },
    "Incorrect diagnosis (from hx) // Correct protocol": {
        "expected_priorities": ["GREEN", "YELLOW"],
        "metric_name": "Correct protocol recognized despite wrong dx",
        "description": "Protocol is correct for the clinical picture \u2014 may be GREEN or YELLOW",
    },
    "Incorrect diagnosis (from hx) // Incorrect protocol": {
        "expected_priorities": ["RED", "YELLOW"],
        "metric_name": "Incorrect protocol flagged (wrong dx context)",
        "description": "Protocol is wrong \u2014 system should flag RED or YELLOW",
    },
    "Correct diagnosis // Correct protocol // Safety issue": {
        "expected_priorities": ["YELLOW", "RED"],
        "check_safety": True,
        "metric_name": "Safety issue detected",
        "description": "Protocol correct but safety concern exists \u2014 should flag with safety warning",
    },
    "No diagnosis or ddx in indication // Unclear appropriate protocol": {
        "expected_priorities": ["PURPLE", "YELLOW"],
        "metric_name": "Insufficient info / uncertain",
        "description": "Vague indication \u2014 system should flag PURPLE or YELLOW",
    },
    "Correct diagnosis // Correct protocol // Safe": {
        "expected_priorities": ["GREEN"],
        "metric_name": "Appropriate protocol confirmed",
        "description": "Everything is correct \u2014 should be GREEN",
    },
}

if len(results_df) > 0 and "pattern_type" in results_df.columns:
    print("VERIFICATION BY PATTERN TYPE")
    print("=" * 80)

    total_correct = 0
    total_cases = 0

    for pattern_type, expectations in PATTERN_EXPECTATIONS.items():
        pattern_results = results_df[results_df["pattern_type"] == pattern_type]
        if len(pattern_results) == 0:
            continue

        expected = expectations["expected_priorities"]
        correct = pattern_results["triage_priority"].isin(expected).sum()
        total = len(pattern_results)
        total_correct += correct
        total_cases += total

        # Count unmapped protocols
        unmapped = pattern_results["modality_unmapped"].sum()

        print(f"\n{pattern_type}")
        print(f"  {expectations['description']}")
        print(f"  {expectations['metric_name']}: {correct}/{total} ({correct/total:.0%})")
        print(f"  Expected: {expected}  |  Got: {dict(pattern_results['triage_priority'].value_counts())}")
        if unmapped > 0:
            print(f"  ({unmapped} cases had non-ACR protocols \u2192 auto RED)")

        if expectations.get("check_safety"):
            safety_detected = pattern_results["has_safety_concern"].sum()
            print(f"  Safety concern detected: {safety_detected}/{total} ({safety_detected/total:.0%})")

    if total_cases > 0:
        print(f"\n{'='*80}")
        print(f"OVERALL: {total_correct}/{total_cases} ({total_correct/total_cases:.0%}) cases matched expected triage")
else:
    print("No results to verify. Run batch triage first.")

In [None]:
# =============================================================================
# SAFETY ISSUE DETECTION REPORT (Pattern Type 4 deep-dive)
# =============================================================================

if len(results_df) > 0 and "pattern_type" in results_df.columns:
    safety_pattern = "Correct diagnosis // Correct protocol // Safety issue"
    safety_cases = results_df[results_df["pattern_type"] == safety_pattern]

    if len(safety_cases) > 0:
        print("SAFETY ISSUE DETECTION REPORT")
        print("=" * 80)
        print(f"Pattern: {safety_pattern}")
        print(f"Cases: {len(safety_cases)}")
        print()

        detected = 0
        detection_details = {
            "contrast_allergy": 0,
            "renal": 0,
            "pregnancy": 0,
            "coagulation": 0,
            "mri_safety": 0,
            "metformin": 0,
        }

        for _, row in safety_cases.iterrows():
            case_num = int(row["case_num"])

            is_detected = row["has_safety_concern"]
            if is_detected:
                detected += 1

            marker = "\u2713" if is_detected else "\u2717"
            print(f"{marker} Case {case_num}: {row['actual_diagnosis']}")
            print(f"    Protocol: {row['protocol_ordered']} \u2192 {row['triage_priority']}")
            print(f"    Expected issue: {str(row.get('issue_notes', 'N/A'))[:120]}")

            # Show all safety warnings
            warnings = row.get("safety_warnings", [])
            if warnings:
                for w in warnings:
                    print(f"    \u26a0 {w}")
            else:
                # Check individual fields
                for field, label in [
                    ("contrast_warning", "Contrast"),
                    ("renal_warning", "Renal"),
                    ("pregnancy_warning", "Pregnancy"),
                    ("coagulation_warning", "Coagulation"),
                    ("mri_safety_warning", "MRI Safety"),
                    ("metformin_warning", "Metformin"),
                ]:
                    val = row.get(field)
                    if val:
                        print(f"    \u26a0 {label}: {val}")

            # Track detection types
            if row.get("contrast_warning"):
                detection_details["contrast_allergy"] += 1
            if row.get("renal_warning"):
                detection_details["renal"] += 1
            if row.get("pregnancy_warning"):
                detection_details["pregnancy"] += 1
            if row.get("coagulation_warning"):
                detection_details["coagulation"] += 1
            if row.get("mri_safety_warning"):
                detection_details["mri_safety"] += 1
            if row.get("metformin_warning"):
                detection_details["metformin"] += 1

            # Show relevant patient data
            extras = []
            # Missing labs become NaN in the DataFrame, so test with pd.notna
            # (an `is not None` check would let NaN through and print "GFR=nan")
            if pd.notna(row.get("gfr")):
                extras.append(f"GFR={row['gfr']:.0f}")
            if pd.notna(row.get("creatinine")):
                extras.append(f"Cr={row['creatinine']:.1f}")
            if pd.notna(row.get("inr")):
                extras.append(f"INR={row['inr']:.1f}")
            if pd.notna(row.get("platelets")):
                extras.append(f"Plt={row['platelets']:.0f}")
            if row.get("is_pregnant"):
                extras.append("PREGNANT")
            if row.get("on_anticoagulation"):
                extras.append("on anticoag")
            if row.get("on_metformin"):
                extras.append("on metformin")
            if extras:
                print(f"    Labs/Flags: {', '.join(extras)}")
            print()

        print(f"\nSAFETY DETECTION SUMMARY")
        print(f"  Detected: {detected}/{len(safety_cases)} ({detected/len(safety_cases):.0%})")
        print(f"\n  By type:")
        for stype, count in detection_details.items():
            if count > 0:
                print(f"    {stype}: {count}")
    else:
        print("No safety pattern cases found.")
else:
    print("No results to analyze.")

In [None]:
# =============================================================================
# SUMMARY TABLES
# =============================================================================

if len(results_df) > 0:
    print("TRIAGE RESULTS BY PATTERN TYPE")
    print("=" * 100)

    pivot = pd.crosstab(
        results_df["pattern_type"],
        results_df["triage_priority"],
        margins=True,
    )
    print(pivot.to_string())

    print(f"\n\nOVERALL TRIAGE DISTRIBUTION")
    print("=" * 40)
    for priority in ["GREEN", "YELLOW", "RED", "PURPLE"]:
        count = (results_df["triage_priority"] == priority).sum()
        emoji = {"GREEN": "\U0001f7e2", "YELLOW": "\U0001f7e1", "RED": "\U0001f534", "PURPLE": "\U0001f7e3"}.get(priority, "")
        print(f"  {emoji} {priority:8s}: {count:3d} ({count/len(results_df):.0%})")

    print(f"\n  Total:     {len(results_df)}")
    unmapped = results_df["modality_unmapped"].sum()
    if unmapped:
        print(f"  Non-ACR protocols (auto RED): {int(unmapped)}")

    # Safety overview
    safety_count = results_df["has_safety_concern"].sum()
    print(f"\n  Safety concerns detected: {int(safety_count)}")
    if safety_count > 0:
        for field, label in [
            ("contrast_warning", "Contrast allergy"),
            ("renal_warning", "Renal impairment"),
            ("pregnancy_warning", "Pregnancy"),
            ("coagulation_warning", "Coagulation"),
            ("mri_safety_warning", "MRI safety"),
            ("metformin_warning", "Metformin interaction"),
        ]:
            if field in results_df.columns:
                n = results_df[field].notna().sum()
                if n > 0:
                    print(f"    {label}: {int(n)}")

    print(f"\n\nVARIANT CLASSIFICATION DISTRIBUTION")
    print("=" * 40)
    variant_counts = results_df[results_df["classified_variant"] != "N/A"]["classified_variant"].value_counts()
    for variant, count in variant_counts.items():
        print(f"  {variant}: {count}")

    # Medication/clinical context overview
    print(f"\n\nCLINICAL CONTEXT OVERVIEW")
    print("=" * 40)
    for col, label in [
        ("on_anticoagulation", "On anticoagulation"),
        ("on_metformin", "On metformin"),
        ("is_pregnant", "Pregnant"),
        ("prior_us", "Prior US"),
        ("prior_ct", "Prior CT"),
        ("prior_mri", "Prior MRI"),
    ]:
        if col in results_df.columns:
            n_true = results_df[col].fillna(False).astype(bool).sum()
            if n_true > 0:
                print(f"  {label}: {int(n_true)} cases")

    # GFR distribution for renal concerns
    gfr_vals = results_df["gfr"].dropna() if "gfr" in results_df.columns else pd.Series(dtype=float)
    if len(gfr_vals) > 0:
        low_gfr = (gfr_vals < 60).sum()
        very_low_gfr = (gfr_vals < 30).sum()
        print(f"  GFR < 60: {int(low_gfr)} cases")
        print(f"  GFR < 30: {int(very_low_gfr)} cases")
else:
    print("No results to summarize.")

In [None]:
# =============================================================================
# DETAILED CASE REVIEW — MISMATCHES
# =============================================================================

if len(results_df) > 0 and "pattern_type" in results_df.columns:
    print("MISMATCHED CASES (triage doesn't match expected)")
    print("=" * 80)

    mismatch_count = 0
    for _, row in results_df.iterrows():
        pattern = row.get("pattern_type")
        if pattern not in PATTERN_EXPECTATIONS:
            continue

        expected = PATTERN_EXPECTATIONS[pattern]["expected_priorities"]
        if row["triage_priority"] not in expected:
            mismatch_count += 1
            print(f"\nCase {int(row['case_num'])}: {row['actual_diagnosis']}")
            print(f"  Pattern: {pattern}")
            print(f"  Protocol: {row['protocol_ordered']} \u2192 {row['triage_priority']} (expected {expected})")
            print(f"  Variant: {row['classified_variant']} (conf {row['variant_confidence']:.0%})")
            print(f"  ACR Rating: {row['acr_rating']}")
            print(f"  Reasoning: {str(row['variant_reasoning'])[:150]}")
            print(f"  Dataset AI Rec: {row.get('ai_recommended', 'N/A')}")
            print(f"  Dataset Notes: {str(row.get('issue_notes', 'N/A'))[:150]}")

            # Show safety warnings if any
            warnings = row.get("safety_warnings", [])
            if warnings:
                print(f"  Safety warnings: {warnings}")

            # Show key clinical context
            extras = []
            if row.get("prior_us"):
                extras.append("prior US")
            if row.get("prior_ct"):
                extras.append("prior CT")
            if row.get("on_anticoagulation"):
                extras.append("on anticoag")
            if row.get("on_metformin"):
                extras.append("on metformin")
            if row.get("is_pregnant"):
                extras.append("PREGNANT")
            if row.get("gfr") is not None and row["gfr"] < 60:
                extras.append(f"GFR={row['gfr']:.0f}")
            if extras:
                print(f"  Clinical context: {', '.join(extras)}")

    if mismatch_count == 0:
        print("  No mismatches found! All cases matched expected triage.")
    else:
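        # Only cases whose pattern has defined expectations were checked above,
        # so use that count (not every row) as the denominator
        evaluated = int(results_df["pattern_type"].isin(list(PATTERN_EXPECTATIONS)).sum())
        match_rate = 1 - mismatch_count / evaluated
        print(f"\n{mismatch_count} mismatched cases out of {evaluated} evaluated ({match_rate:.0%} match rate)")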
else:
    print("No results to analyze.")

## 9. Export Results

In [None]:
# =============================================================================
# EXPORT RESULTS TO CSV
# =============================================================================

if len(results_df) > 0:
    output_dir = os.path.join(BASE_DIR, "data", "generated_emrs_27b")
    os.makedirs(output_dir, exist_ok=True)
    csv_path = os.path.join(output_dir, "triage_results.csv")

    # Select columns for export (expanded)
    export_cols = [
        "case_num", "setting", "pattern_type", "actual_diagnosis",
        "provider_indication", "protocol_ordered",
        "triage_priority", "triage_confidence",
        "classified_variant", "variant_confidence", "variant_reasoning",
        "acr_rating", "appropriateness",
        "has_safety_concern",
        "contrast_warning", "renal_warning", "pregnancy_warning",
        "coagulation_warning", "mri_safety_warning", "metformin_warning",
        "recommendations", "modality_unmapped", "insufficient_information",
        "ai_recommended", "issue_notes",
        # Expanded clinical context
        "on_anticoagulation", "on_metformin", "is_pregnant",
        "prior_us", "prior_ct", "prior_mri",
        "gfr", "creatinine", "inr", "platelets",
    ]
    export_df = results_df[[c for c in export_cols if c in results_df.columns]]
    export_df.to_csv(csv_path, index=False)
    print(f"\u2713 Results exported to {csv_path}")
    print(f"  {len(export_df)} rows \u00d7 {len(export_df.columns)} columns")

    # Also save detailed safety report
    safety_csv = os.path.join(output_dir, "triage_safety_report.csv")
    safety_cols = [
        "case_num", "pattern_type", "protocol_ordered", "triage_priority",
        "has_safety_concern", "safety_warnings",
        "contrast_warning", "renal_warning", "pregnancy_warning",
        "coagulation_warning", "mri_safety_warning", "metformin_warning",
        "on_anticoagulation", "on_metformin", "is_pregnant",
        "gfr", "creatinine", "inr", "platelets",
    ]
    safety_df = results_df[results_df["has_safety_concern"] == True]
    if len(safety_df) > 0:
        safety_export = safety_df[[c for c in safety_cols if c in safety_df.columns]]
        safety_export.to_csv(safety_csv, index=False)
        print(f"\u2713 Safety report exported to {safety_csv}")
        print(f"  {len(safety_export)} safety cases")
    else:
        print("  No safety concerns to export separately")
else:
    print("No results to export.")

## Notes

### Pipeline Architecture
This notebook makes **one LLM call per case** (step 2); every other step is deterministic:
1. **Deterministic mapping** (`build_patient_context_from_emr`): EMR JSON → PatientContext (no LLM)
2. **MedGemma variant classification** (`RUQProtocolReviewer._classify_variant`): PatientContext + Provider Indication → ACR Variant
3. **Deterministic rule enforcement** (`_enforce_variant_rules`): Overrides LLM when it violates hard ACR decision tree gates
4. **Deterministic ACR lookup**: Variant + Modality → Rating (1-9) → Triage Flag
5. **Deterministic safety checks** (`assess_safety`): Contrast allergy, renal, pregnancy, coagulation, MRI safety, metformin

The EMRExtractor step (which was in the original notebook) is intentionally skipped because the generated EMR JSONs already contain fully structured data.

### Priority Categories
- 🟢 **GREEN**: Likely Appropriate — Low priority for radiologist review
- 🟡 **YELLOW**: Possibly Inappropriate — Requires manual review
- 🔴 **RED**: Definitely Inappropriate — Must change or contact ordering provider
- 🟣 **PURPLE**: Insufficient Information — Cannot evaluate, request more clinical data
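
As a rough illustration of the "Rating (1-9) → Triage Flag" step, the sketch below maps a rating to one of the four flags. The bands follow the standard ACR categories (7-9 usually appropriate, 4-6 may be appropriate, 1-3 usually not appropriate); pairing them with GREEN/YELLOW/RED is an assumption, and `rating_to_priority` is a hypothetical helper, not the logic inside `RUQProtocolReviewer`. The real reviewer may also change the flag based on safety warnings, which this sketch ignores.

```python
def rating_to_priority(rating: int, insufficient_information: bool = False) -> str:
    """Map an ACR rating (1-9) to a triage flag.

    Assumed thresholds for illustration only; the reviewer's own
    cut-offs and safety escalation are not reproduced here.
    """
    if insufficient_information:
        return "PURPLE"  # cannot evaluate, so the rating is not consulted
    if not 1 <= rating <= 9:
        raise ValueError(f"ACR rating must be between 1 and 9, got {rating}")
    if rating >= 7:
        return "GREEN"   # usually appropriate
    if rating >= 4:
        return "YELLOW"  # may be appropriate
    return "RED"         # usually not appropriate
```

Keeping this step as a pure function of the rating is what makes the flag reproducible: the LLM only influences the result through the variant it selects.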

### Safety Checks (All Deterministic — No LLM)
- **Contrast allergy**: Iodinated/gadolinium allergy with severity and reaction type → RED if contrast ordered
- **Renal function**: GFR < 30 = HIGH RISK, GFR < 60 = MODERATE RISK; includes NSF risk warning for gadolinium
- **Pregnancy**: Detected from special_populations + problem_list ICD codes → YELLOW if radiation modality ordered; caution for gadolinium
- **Coagulation**: INR > 1.5, platelets < 50K, anticoagulation status → warning for invasive procedures (cholecystostomy)
- **MRI safety**: Cardiac devices (pacemaker, ICD) and metallic implants from surgical history → verify compatibility
- **Metformin + contrast**: Hold metformin 48h post-contrast, especially with reduced GFR
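
The numeric thresholds listed above can be sketched as two small functions. This is a minimal illustration, not the body of `assess_safety`: the function names are hypothetical, and platelets are assumed to be reported in thousands per µL (so "50K" is the value 50).

```python
from typing import List, Optional


def renal_risk(gfr: Optional[float]) -> Optional[str]:
    """Return the renal risk tier for a GFR value, or None if no concern or no data."""
    if gfr is None:
        return None
    if gfr < 30:
        return "HIGH"
    if gfr < 60:
        return "MODERATE"
    return None


def coagulation_flags(
    inr: Optional[float],
    platelets: Optional[float],  # assumed unit: 10^3 per microliter
    on_anticoagulation: bool,
) -> List[str]:
    """Collect coagulation concerns relevant to invasive procedures."""
    flags = []
    if inr is not None and inr > 1.5:
        flags.append("INR > 1.5")
    if platelets is not None and platelets < 50:
        flags.append("platelets < 50K")
    if on_anticoagulation:
        flags.append("on anticoagulation")
    return flags
```

For example, `renal_risk(45)` falls in the MODERATE tier, and a missing lab value yields no flag rather than a false alarm.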

### Variant Classification Rules (Hard Overrides)
- **Rule 1**: No prior US → Cannot be Variant 3, 4, or 5 (must be 1 or 2)
- **Rule 2**: Variant 5 requires ICU/critically ill patient
- **Rule 3**: Prior US + fever/elevated WBC → Variant 4 (not 3)
- **Rule 4**: Prior US + afebrile + normal WBC → Variant 3 (not 4)
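
One way to read the four rules is as a filter on which variants remain admissible for a given patient. The sketch below takes that reading; `allowed_variants` and its parameters are hypothetical, and it only removes what the rules exclude. How `_enforce_variant_rules` picks a replacement when the LLM's choice is excluded is not shown in these notes, so it is not modeled here.

```python
from typing import Set


def allowed_variants(prior_us: bool, icu: bool, fever_or_high_wbc: bool) -> Set[int]:
    """Return the ACR variants (1-5) not ruled out by the hard override rules."""
    allowed = {1, 2, 3, 4, 5}
    if not prior_us:
        # Rule 1: without a prior US, only Variant 1 or 2 is possible
        allowed -= {3, 4, 5}
    if not icu:
        # Rule 2: Variant 5 requires an ICU / critically ill patient
        allowed.discard(5)
    if prior_us:
        # Rule 3: fever or elevated WBC excludes Variant 3
        # Rule 4: afebrile with normal WBC excludes Variant 4
        allowed.discard(3 if fever_or_high_wbc else 4)
    return allowed


def violates_rules(llm_variant: int, prior_us: bool, icu: bool, fever_or_high_wbc: bool) -> bool:
    """True when the LLM's variant must be overridden."""
    return llm_variant not in allowed_variants(prior_us, icu, fever_or_high_wbc)
```

A classification of Variant 4 for a patient with no prior ultrasound, for instance, would be reported as a violation by `violates_rules(4, prior_us=False, icu=False, fever_or_high_wbc=True)`.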

### PatientContext Fields Extracted
- **Demographics**: age, sex, BMI
- **Vitals**: temperature, HR, BP, RR, SpO2
- **Core labs**: WBC, AST, ALT, ALP, bilirubin (total/direct), lipase, creatinine, GFR
- **Coagulation**: platelets, INR, PT, PTT
- **Additional labs**: hemoglobin, glucose, TSH, CRP, D-dimer
- **Allergies**: allergen, severity, reaction type; contrast/gadolinium/shellfish flags
- **Medications**: home + inpatient; anticoagulant, metformin, nephrotoxic drug flags
- **Prior imaging**: US, CT, MRI, HIDA, each with a result classification (negative/equivocal/positive)
- **Pregnancy**: from special_populations + problem_list ICD codes + trimester detection
- **Devices/implants**: cardiac devices, metallic implants from surgical history
- **Clinical state**: ICU, post-operative, TPN, septic, clinical setting
- **Medical history**: conditions with ICD-10 codes

### Verification Columns
- `AI Recommended Protocol`: What the AI system recommends as the correct protocol
- `Issue/Notes`: Clinical rationale with ACR ratings explaining why the recommendation differs
- `Pattern Type`: 6 balanced categories (16 cases each) testing different failure modes