In [None]:
# CELL 1: Setup and Dependencies for Synapse v2.1
# Run this cell first to install packages and set up environment

# Install required packages quietly
!pip install openai ipywidgets matplotlib seaborn pandas -q

# Enable widgets in Colab (required for interactive dashboard)
from google.colab import output
output.enable_custom_widget_manager()

# Import all required libraries
import re
import time
import traceback
import json
import logging
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass, field
from enum import Enum
import os
from openai import OpenAI
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
import pandas as pd
import numpy as np
import sys

# Set up OpenAI API key
import getpass

def _prompt_for_api_key(prompt: str, success_message: str) -> None:
    """Read an API key without echo, store it in OPENAI_API_KEY and report the result."""
    # Strip stray whitespace/newlines that often come along with a pasted key
    entered_key = getpass.getpass(prompt).strip()
    os.environ['OPENAI_API_KEY'] = entered_key
    if entered_key:
        print(success_message)
    else:
        print("⚠️ No API key entered — OPENAI_API_KEY is empty.")


# Check if an API key already exists in the environment and show its status
existing_key = os.environ.get('OPENAI_API_KEY', '')
if existing_key:
    # Mask the key for display: show only the first 8 and the last 4 characters
    masked_key = existing_key[:8] + '...' + existing_key[-4:] if len(existing_key) > 12 else 'Key too short'
    print(f"🔑 Found existing OpenAI API key: {masked_key}")
    print("Do you want to use this existing key? (y/n): ", end='')
    use_existing = input().lower().strip()

    # Accept "y" or "yes"; any other answer prompts for a replacement key
    if use_existing not in ('y', 'yes'):
        _prompt_for_api_key('🔑 Enter new OpenAI API key: ', "✅ New API key set!")
else:
    _prompt_for_api_key('🔑 Enter your OpenAI API key: ', "✅ API key set!")

# Verify setup
print("\n✅ Setup complete!")
print("📚 Libraries loaded: re, time, json, logging, dataclasses, openai, matplotlib, seaborn, pandas, numpy")
print(f"🔑 OpenAI API key: {'Set' if os.environ.get('OPENAI_API_KEY') else 'Not set'}")
print("📊 Widgets enabled: Yes")
print("\n" + "="*60)
print("✨ Ready for Cell 2: Load your original Synapse v2.0 code")
print("="*60)

In [None]:
import re
import time
import traceback
import json
import logging
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass, field
from enum import Enum
import os
from openai import OpenAI

# === Configuration Management ===
class Config:
    """Configuration management for Synapse.

    Every setting is read from an environment variable when the instance is
    created, falling back to the defaults written below.
    """

    def __init__(self):
        # Prefer the OPENAI_API_KEY environment variable over hard-coding a key here
        self.openai_api_key = os.getenv('OPENAI_API_KEY')
        # Normalised to an upper-case name so e.g. LOG_LEVEL=debug maps to the
        # DEBUG level rather than to the logging.debug function. WARNING keeps
        # output quiet by default; blank values also fall back to it.
        self.log_level = os.getenv('LOG_LEVEL', 'WARNING').strip().upper() or 'WARNING'
        self.enable_gpt_summary = self._env_flag('ENABLE_GPT_SUMMARY', 'true')
        self.confidence_threshold = self._env_float('CONFIDENCE_THRESHOLD', 0.7)
        self.enable_stroke_detection = self._env_flag('ENABLE_STROKE_DETECTION', 'true')
        self.enable_seizure_detection = self._env_flag('ENABLE_SEIZURE_DETECTION', 'true')

    @staticmethod
    def _env_flag(name: str, default: str) -> bool:
        """Return True when env var *name* (or *default*) equals 'true', ignoring case and surrounding whitespace."""
        return os.getenv(name, default).strip().lower() == 'true'

    @staticmethod
    def _env_float(name: str, default: float) -> float:
        """Parse env var *name* as a float; warn and return *default* if it is missing or malformed."""
        raw = os.getenv(name)
        if raw is None:
            return default
        try:
            return float(raw)
        except ValueError:
            # A typo in an env var should not make the whole module fail to import.
            logging.getLogger('synapse').warning(
                "Invalid %s=%r; falling back to %s", name, raw, default)
            return default

config = Config()

# === Severity Levels ===
class Severity(Enum):
    """Severity/category tag attached to every ClinicalFinding.

    Each member's value is an emoji. Declaration order is significant:
    ClinicalFindings.get_all_findings sorts findings by each member's position
    in this class, so CRITICAL sorts first and NORMAL last. Do not reorder.
    """
    CRITICAL = "🚨"
    WARNING = "⚠️"
    MOTOR = "🦴"
    NEURO = "🧠"
    MONITORING = "🔍"
    NORMAL = "✅"

# === Structured Clinical Findings ===
@dataclass
class ClinicalFinding:
    """Structured representation of a single clinical finding.

    Field order defines the generated constructor signature
    (message, severity, confidence, category, negated, baseline, timestamp).
    """
    message: str  # human-readable description of the finding
    severity: Severity  # tag used for ordering in ClinicalFindings.get_all_findings
    confidence: float  # literals in this file lie in 0.0-0.98; presumably a 0-1 scale
    category: str  # free-form label, e.g. "motor", "reflexes", "system"
    negated: bool = False  # presumably marks a finding that was denied/negated — confirm with callers
    baseline: bool = False  # presumably marks a finding consistent with the patient's baseline — confirm
    timestamp: datetime = field(default_factory=datetime.now)  # naive local time at creation

@dataclass
class ClinicalFindings:
    """All findings from one analysis run, grouped into severity buckets.

    ``overall_confidence`` and ``processing_time_ms`` summarise the run.
    """
    critical_flags: List[ClinicalFinding] = field(default_factory=list)
    warning_flags: List[ClinicalFinding] = field(default_factory=list)
    motor_flags: List[ClinicalFinding] = field(default_factory=list)
    neuro_flags: List[ClinicalFinding] = field(default_factory=list)
    monitoring_flags: List[ClinicalFinding] = field(default_factory=list)
    normal_findings: List[ClinicalFinding] = field(default_factory=list)
    overall_confidence: float = 0.0
    processing_time_ms: int = 0

    def get_all_findings(self) -> List[ClinicalFinding]:
        """Return every finding from every bucket, ordered by Severity declaration order.

        The sort is stable, so findings of equal severity keep bucket order
        (critical, warning, motor, neuro, monitoring, normal), then insertion order.
        """
        severity_order = list(Severity)
        buckets = (
            self.critical_flags,
            self.warning_flags,
            self.motor_flags,
            self.neuro_flags,
            self.monitoring_flags,
            self.normal_findings,
        )
        return sorted(
            (finding for bucket in buckets for finding in bucket),
            key=lambda finding: severity_order.index(finding.severity),
        )

# === Compiled Patterns for Performance ===
class CompiledPatterns:
    """Case-insensitive regular expressions, compiled once and shared via ``patterns``."""

    def __init__(self):
        def compile_ci(source):
            # Every pattern in this collection is matched case-insensitively
            return re.compile(source, re.IGNORECASE)

        # Reusable fragments
        limb = (r'(rue|lue|rle|lle|right\s+upper\s+extremity|left\s+upper\s+extremity|'
                r'right\s+lower\s+extremity|left\s+lower\s+extremity)')
        side = r'(rt|lt|right|left|r|l)?'
        grade = r'\d+[-+]?'

        # Level of consciousness (Glasgow Coma Scale)
        self.gcs_pattern = compile_ci(r'gcs\s*(?:of|is|at|score)?\s*(\d+)')
        self.gcs15_pattern = compile_ci(
            r'gcs\s*(?:of|is|at|score)?\s*15'
            r'|glasgow\s+(?:coma\s+)?(?:scale|score)\s*(?:of|is|at)?\s*15'
        )

        # Motor strength: five slash-separated grades, or one grade "throughout"
        self.detailed_strength_pattern = compile_ci(limb + r'\s+(' + '/'.join([grade] * 5) + ')')
        self.generalized_strength_pattern = compile_ci(limb + r'\s+(' + grade + r')\s+throughout')

        # Sensory level, e.g. "~T10 sensory level"
        self.sensory_level_pattern = compile_ci(r'~?([tl]\d+)\s+sensory\s+level')

        # Upper-motor-neuron signs with an optional side prefix
        self.clonus_pattern = compile_ci(side + r'\s*clonus')
        self.babinski_pattern = compile_ci(side + r'\s*babinski')
        self.hyperreflexia_pattern = compile_ci(
            r'(rue|lue|rle|lle|right|left|bilateral|bl)?\s*(?:dtrs?|reflexes?)\s*(\d+\+)'
        )

        # Cauda equina / sensory findings
        self.saddle_anesthesia_pattern = compile_ci(r'saddle\s+anesthesia')
        self.sensory_deficit_pattern = compile_ci(
            r'(?:decreased|diminished|reduced|impaired|absent)\s+sensation'
        )
        self.bilateral_sensory_pattern = compile_ci(
            r'(?:bilateral|bl|both)\s+(?:feet|soles|lower\s+extremit)'
        )

        # Diminished reflexes
        self.absent_reflexes_pattern = compile_ci(r'absent\s+(?:patellar|knee|dtr|reflex)')
        self.hyporeflexia_pattern = compile_ci(r'(?:hyporeflexia|areflexia|absent.*reflex)')

        # Spinal tumours
        self.spinal_tumor_pattern = compile_ci(
            r'(?:plasmacytoma|myeloma|metastas|tumor).*(?:spine|spinal|vertebr)'
        )

# Global compiled patterns instance
patterns = CompiledPatterns()

# === Enhanced Abbreviation Dictionary ===
# NOTE: several keys differ only by letter case ("loc"/"LOC", "PT"/"Pt"/"pt",
# "MAES"/"maes", "Tx"/"tx") and mean different things. Any case-insensitive
# matcher must account for that. Each key may appear only once: a repeated key
# in a dict literal silently overwrites the earlier entry.
abbreviation_map = {
    # Common medical prefixes
    "pmh": "past medical history",
    "hx": "history",
    "s/p": "status post",
    "p/t": "presenting to",
    "p/w": "presented with",
    "c/f": "concern for",
    "c/b": "complicated by",
    "c/t": "compared to",
    "c/s": "consult",
    "w/": "with",
    "w/o": "without",
    "wo": "without",
    "wwo": "with and without",
    "b/b": "bowel or bladder",
    "2/2": "secondary to",

    # Locations and facilities
    "ED": "emergency department",
    "OSH": "outside hospital",
    "CCH": "Cook County Hospital",
    "NSGY": "neurosurgery",

    # Imaging
    "CTA": "computed tomography angiography",
    "CTH": "CT head",
    "CAP": "chest abdomen pelvis",
    "MRI": "magnetic resonance imaging",
    "XR": "x-ray",

    # Procedures
    "ACDF": "anterior cervical discectomy and fusion",
    "lami": "laminectomy",
    "lamis": "laminectomies",

    # Conditions
    "tSAH": "traumatic subarachnoid hemorrhage",
    "SDH": "subdural hematoma",
    "aSDH": "acute subdural hematoma",
    "mets": "metastases",
    "AMS": "altered mental status",
    "AVN": "avascular necrosis",
    "fx": "fracture",
    "comp fx": "compression fracture",

    # Body parts and directions
    "TP": "transverse process",
    "SP": "spinous process",
    "BL": "bilateral",
    "Rt": "right",
    "Lt": "left",
    "L": "left",
    "R": "right",
    "RUE": "right upper extremity",
    "LUE": "left upper extremity",
    "RLE": "right lower extremity",
    "LLE": "left lower extremity",
    "BUE": "bilateral upper extremities",
    "BLE": "bilateral lower extremities",
    "RUL": "right upper lobe",
    "LBP": "low back pain",

    # Examination terms
    "DTR": "deep tendon reflex",
    "DTRs": "deep tendon reflexes",
    "TTP": "tenderness to palpation",
    "MAES": "moves all extremities spontaneously",
    "maes": "moves all extremities spontaneously",
    "EHL": "extensor hallucis longus",

    # Orientation
    "ox0": "not oriented to person, place, or time",
    "ox1": "oriented to person only",
    "ox2": "oriented to person and place or time",
    "ox3": "oriented to person, place, and time",
    "GCS": "Glasgow Coma Scale",

    # Movement responses
    "loc": "localizes to pain",
    "wd": "withdraws to pain",

    # Trauma
    "GLF": "ground level fall",
    "LOC": "loss of consciousness",
    "BHT": "blunt head trauma",
    "MVC": "motor vehicle collision",
    "MVA": "motor vehicle accident",

    # Patient state
    "ADLs": "activities of daily living",
    "ACAP": "anticoagulant or antiplatelet therapy",
    "PVR": "post-void residual",

    # Anatomical regions
    "T-spine": "thoracic spine",
    "L-spine": "lumbar spine",
    "C-spine": "cervical spine",

    # Time descriptors
    "dx": "diagnosed",
    "x3d": "for 3 days",
    "x2d": "for 2 days",
    "x4d": "for 4 days",
    "x10d": "for 10 days",
    "x1w": "for 1 week",
    "x1mo": "for 1 month",

    # Labs
    "Na": "sodium",
    "CBC": "complete blood count",
    "BMP": "basic metabolic panel",
    "Coags": "coagulation studies",
    "HH": "hemoglobin hematocrit",
    "PLT": "platelets",
    # "PT" is ambiguous (prothrombin time vs physical therapy) and a dict can
    # hold only one meaning. This map was previously declared twice and the
    # later "physical therapy" entry silently won. That effective value is kept
    # here, in the key's original slot, so insertion order is unchanged.
    "PT": "physical therapy",
    "INR": "international normalized ratio",
    "EtOH": "alcohol",

    # Status descriptors
    "neg": "negative",
    "nl": "normal",
    "wnl": "within normal limits",
    "N/V": "nausea and vomiting",
    "HA": "headache",
    "Pt": "patient",
    "pt": "patient",
    "s/s": "signs and symptoms",
    "AFO": "ankle-foot orthosis",
    "↓": "decreased",
    "↑": "increased",

    # Comorbidities
    "HTN": "hypertension",
    "HLD": "hyperlipidemia",
    "HF": "heart failure",
    "PE": "pulmonary embolism",
    "DVT": "deep vein thrombosis",
    "Ca": "cancer",

    # Treatments ("PT" / physical therapy lives in the Labs section above)
    "Tx": "treatment",
    "ASA": "aspirin",
    "tx": "treatment"
}

# === Enhanced Logging System ===
class ClinicalLogger:
    """Logging facade for clinical processing.

    Wraps the stdlib ``synapse`` logger (console output) and keeps an
    in-memory ``audit_trail`` list with one dict per processed note or error.
    """

    def __init__(self):
        self.logger = logging.getLogger('synapse')
        self.logger.setLevel(self._resolve_level(config.log_level))

        # logging.getLogger returns the same logger object on every call, so
        # re-running this notebook cell would otherwise stack handlers and
        # print each message several times. Attach the console handler once.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        # Audit trail (unbounded, in memory)
        self.audit_trail = []

    @staticmethod
    def _resolve_level(level_name: Any) -> int:
        """Map a level name such as "debug" or "WARNING" to its numeric level.

        Matching is case-insensitive. Unknown names fall back to WARNING
        instead of raising while the module is being imported.
        """
        level = logging.getLevelName(str(level_name).strip().upper())
        return level if isinstance(level, int) else logging.WARNING

    def log_clinical_processing(self, input_text: str, findings: ClinicalFindings,
                              abbreviations_expanded: List[str]):
        """Append a processing entry to the audit trail and log a one-line summary at INFO level.

        Args:
            input_text: Raw note text; only its length is recorded.
            findings: Result of the analysis run.
            abbreviations_expanded: Abbreviation keys that were expanded.
        """
        findings_count = len(findings.get_all_findings())
        audit_entry = {
            'timestamp': datetime.now().isoformat(),
            'input_length': len(input_text),
            'abbreviations_expanded': abbreviations_expanded,
            'findings_count': findings_count,
            'processing_time_ms': findings.processing_time_ms,
            'confidence': findings.overall_confidence
        }
        self.audit_trail.append(audit_entry)

        self.logger.info(f"Processed clinical text: {findings_count} findings, "
                        f"{findings.processing_time_ms}ms, confidence: {findings.overall_confidence:.2f}")

    def log_error(self, component: str, error_code: str, error_message: str,
                  severity: str = "medium", stack_trace: Optional[str] = None):
        """Append a structured error entry to the audit trail and log it at ERROR level.

        The stack trace, when given, is additionally logged at DEBUG level.
        """
        error_entry = {
            'timestamp': datetime.now().isoformat(),
            'component': component,
            'error_code': error_code,
            'error_message': error_message,
            'severity': severity,
            'stack_trace': stack_trace
        }
        self.audit_trail.append(error_entry)

        self.logger.error(f"[{component}] {error_code}: {error_message}")
        if stack_trace:
            self.logger.debug(stack_trace)

# Global logger instance
clinical_logger = ClinicalLogger()

# === Enhanced Abbreviation Expansion ===
def _abbreviation_pattern(abbr: str) -> str:
    r"""Return a regex source that matches *abbr* only as a standalone token.

    Plain \b boundaries misbehave for abbreviations that end in punctuation
    ("w/" never matches before a space). They also let slash abbreviations
    such as "2/2" match inside longer slash-separated tokens like the strength
    grades "4/2/2/5". Explicit lookarounds handle both cases.
    """
    # Abbreviations containing "/" must not touch another "/" on either side,
    # otherwise "2/2" would be rewritten inside "4/2/2/5" or "12/22".
    token_chars = r'\w/' if '/' in abbr else r'\w'
    left = rf'(?<![{token_chars}])'
    # Only guard the right edge when the abbreviation ends in a word character,
    # so "w/" still matches in "w/pain" and "w/ pain".
    right = rf'(?![{token_chars}])' if re.match(r'\w', abbr[-1]) else ''
    return left + re.escape(abbr) + right


def expand_abbreviations(blurb: str, dictionary: Dict[str, str]) -> Tuple[str, List[str]]:
    """
    Expand clinical abbreviations in *blurb* using *dictionary*.

    Matching rules:
      * Single-character keys (e.g. "L", "R") are skipped as too ambiguous.
      * Longer keys are expanded first, so "w/o" and "comp fx" are not
        partially consumed by "w/" and "fx".
      * Keys are matched case-insensitively, except when several keys differ
        only by case (e.g. "loc"/"LOC", "PT"/"Pt"/"pt"). Each of those is
        matched with its exact casing. Any other casing falls back to the
        member that appears first in *dictionary*.
      * A fixed set of slash abbreviations is expanded in a second pass even
        if *dictionary* does not contain them.

    Args:
        blurb: Free-text clinical note. Falsy values are returned unchanged.
        dictionary: Mapping of abbreviation -> expansion.

    Returns:
        ``(expanded_text, expanded_keys)``. On an unexpected error the original
        text and an empty list are returned and the error is logged.
    """
    if not blurb:
        return blurb, []

    expanded_terms = set()
    blurb_original = blurb

    def _expand(text: str, abbr: str, expansion: str, flags: int) -> str:
        """Replace standalone occurrences of *abbr*; record *abbr* if anything matched."""
        # A callable replacement keeps any backslashes in *expansion* literal.
        new_text, count = re.subn(_abbreviation_pattern(abbr),
                                  lambda _match: expansion, text, flags=flags)
        if count:
            expanded_terms.add(abbr)
        return new_text

    try:
        # Group keys that are identical ignoring case, preserving dictionary order.
        case_groups: Dict[str, List[str]] = {}
        for abbr in dictionary:
            # Skip very short abbreviations that might cause issues
            if len(abbr) <= 1:
                continue
            case_groups.setdefault(abbr.lower(), []).append(abbr)

        # First pass: longest abbreviations first. sorted() is stable (also with
        # reverse=True), so keys of equal length keep their dictionary order.
        for group in sorted(case_groups.values(), key=lambda keys: len(keys[0]), reverse=True):
            if len(group) == 1:
                blurb = _expand(blurb, group[0], dictionary[group[0]], re.IGNORECASE)
                continue
            # Case-colliding keys mean different things ("LOC" is loss of
            # consciousness, "loc" is localizes to pain): honour exact casing.
            for abbr in group:
                blurb = _expand(blurb, abbr, dictionary[abbr], 0)
            # Remaining casings fall back to the group's first key.
            blurb = _expand(blurb, group[0], dictionary[group[0]], re.IGNORECASE)

        # Second pass: slash abbreviations that must be expanded even when a
        # custom dictionary omits them.
        special_patterns = {
            r's/p': 'status post',
            r'c/f': 'concern for',
            r'p/t': 'presenting to',
            r'c/t': 'compared to',
            r'c/b': 'complicated by',
            r'c/s': 'consult',
            r'w/': 'with',
            r'w/o': 'without',
            r'b/b': 'bowel or bladder',
            r'2/2': 'secondary to',
            r'p/w': 'presented with'
        }

        for abbr, replacement in sorted(special_patterns.items(),
                                        key=lambda item: len(item[0]), reverse=True):
            blurb = _expand(blurb, abbr, replacement, re.IGNORECASE)

        clinical_logger.logger.debug(f"Expanded {len(expanded_terms)} abbreviations: {', '.join(expanded_terms)}")
        return blurb, list(expanded_terms)

    except Exception as e:
        error_msg = f"Error in abbreviation expansion: {str(e)}"
        clinical_logger.log_error(
            component="abbreviation_expansion",
            error_code="EXPANSION_ERROR",
            error_message=error_msg,
            stack_trace=traceback.format_exc()
        )
        return blurb_original, []

# === Enhanced Negation Detection ===
# Denial cues, each anchored on word boundaries. Without anchors, "no" fires
# inside "known"/"now"/"diagnosis" and "not" fires inside "notable", which
# wrongly negates genuine findings.
_DENIAL_PATTERNS = [
    re.compile(r'\b' + cue + r'\b')
    for cue in (
        r'denies', r'denied', r'no', r'not', r'negative',
        r'absence of', r'without', r'hasn\'t', r'doesn\'t',
        r'has not', r'does not', r'not (?:having|experiencing|showing)',
        r'hasn\'t endorsed', r'hasn\'t had', r'hasn\'t experienced'
    )
]
# "not ... denied/negative/absent" in the same sentence is treated as a double negative.
_DOUBLE_NEGATIVE_PATTERN = re.compile(r'\bnot\b.+?(?:denied|negative|absent)')


def is_negated(term: str, blurb: str) -> bool:
    """
    Decide whether *term* is negated in *blurb* (e.g. "denies weakness").

    Rules, in order:
      1. Text after the last "exam:" is treated as exam findings. If *term*
         appears there after "intact except", or with no preceding negation
         cue, it is positive (not negated).
      2. Otherwise each sentence containing *term* is checked. The term is
         positive if it follows "intact except". It is negated if a denial cue
         precedes it by at most 10 words and the sentence is not a double
         negative.

    Args:
        term: Phrase to test; matched case-insensitively after stripping.
        blurb: Full note text.

    Returns:
        True if the term appears negated, False otherwise (including empty input).
    """
    if not term or not blurb:
        return False

    term = term.strip().lower()
    blurb_lower = blurb.lower()
    escaped_term = re.escape(term)

    # Exam findings: positive unless explicitly preceded by a negation word
    exam_sections = re.split(r'exam\s*:', blurb_lower)
    if len(exam_sections) > 1:
        exam_part = exam_sections[-1]  # Get the part after the last "exam:"
        if term in exam_part:
            # "intact except X" lists X as a positive finding
            if re.search(r'intact\s+except.*?' + escaped_term, exam_part):
                return False
            if not re.search(r'\b(?:denies|no|negative|without|absent)\s+.*?' + escaped_term, exam_part):
                return False

    # Parse the text into sentences to restrict negation scope
    for sentence in re.split(r'[.!?]\s+', blurb_lower):
        if term not in sentence:
            continue
        term_pos = sentence.find(term)

        # Findings after "intact except" are positive
        if 'intact except' in sentence and term_pos > sentence.find('intact except'):
            return False

        # A double negative cancels every denial cue in this sentence
        if _DOUBLE_NEGATIVE_PATTERN.search(sentence):
            continue

        # Negate only when a denial cue precedes the term within 10 words
        for cue in _DENIAL_PATTERNS:
            for denial in cue.finditer(sentence):
                denial_end = denial.end()
                if term_pos > denial_end and len(sentence[denial_end:term_pos].split()) <= 10:
                    return True

    return False

# === Enhanced Exam Components Extraction ===
def extract_enhanced_exam_components(blurb: str, blurb_lower: str) -> List[ClinicalFinding]:
    """
    Extract structured exam findings from a clinical note.

    Covers consciousness (GCS), arousal, orientation, motor responses, cranial
    nerves, motor strength grades, reflexes, sensation and cauda-equina signs.

    Args:
        blurb: Original note text (not read by the current implementation).
        blurb_lower: Lower-cased note text. The exam section is located in it
            and all patterns are matched against it.

    Returns:
        List of ClinicalFinding objects. An empty list if any exception occurs;
        the error is recorded via clinical_logger.

    NOTE(review): relies on is_baseline_deficit_mentioned, which is not
    defined in the visible part of this file — confirm it exists.
    """
    findings = []

    try:
        # Identify exam section; fall back to the whole note when no header matches
        exam_section = blurb_lower
        exam_section_patterns = [
            r'exam(?:\s*[-:]\s*)(.*?)(?:\n\n|\Z)',
            r'(?:physical|neuro(?:logical)?|motor)\s+exam(?:\s*[-:]\s*)(.*?)(?:\n\n|\Z)',
            r'(?:physical|neurological) findings(?:\s*[-:]\s*)(.*?)(?:\n\n|\Z)'
        ]

        for pattern in exam_section_patterns:
            match = re.search(pattern, blurb_lower, re.DOTALL)
            if match:
                exam_section = match.group(1).strip()
                break

        # === CONSCIOUSNESS AND ORIENTATION ASSESSMENT ===
        # Check for GCS score
        gcs_pattern = r'gcs\s*(?:of|is|at|score)?\s*(\d+)'
        gcs_match = re.search(gcs_pattern, exam_section)

        if gcs_match:
            gcs_score = int(gcs_match.group(1))
            if gcs_score == 15:
                findings.append(ClinicalFinding(
                    message="Full consciousness: GCS 15",
                    severity=Severity.NORMAL,
                    confidence=0.95,
                    category="consciousness"
                ))
            elif gcs_score >= 13:
                findings.append(ClinicalFinding(
                    message=f"Mild consciousness impairment: GCS {gcs_score}",
                    severity=Severity.MONITORING,
                    confidence=0.9,
                    category="consciousness"
                ))
            elif gcs_score >= 9:
                findings.append(ClinicalFinding(
                    message=f"Moderate consciousness impairment: GCS {gcs_score}",
                    severity=Severity.WARNING,
                    confidence=0.95,
                    category="consciousness"
                ))
            else:
                findings.append(ClinicalFinding(
                    message=f"Severe consciousness impairment: GCS {gcs_score}",
                    severity=Severity.CRITICAL,
                    confidence=0.98,
                    category="consciousness"
                ))

        # Check for arousal descriptions
        arousal_patterns = {
            r'(?:eyes?\s+open(?:ing)?)\s+(?:to|with)\s+(?:heavy\s+)?(?:stim|stimulation|pain)': (
                "Decreased arousal level—requires stimulation for eye opening", Severity.NEURO, 0.85),
            r'alert|awake|wide awake': (
                "Alert and awake", Severity.NORMAL, 0.9),
            r'drowsy|lethargic|somnolent': (
                "Decreased arousal—drowsy/lethargic", Severity.NEURO, 0.8)
        }

        for pattern, (message, severity, confidence) in arousal_patterns.items():
            match = re.search(pattern, exam_section)
            if match and not is_negated(match.group(0), blurb_lower):
                findings.append(ClinicalFinding(
                    message=message,
                    severity=severity,
                    confidence=confidence,
                    category="arousal"
                ))

        # Check for orientation status
        orientation_patterns = {
            r'\box0\b|not oriented|disoriented': (
                "Altered mental status: Ox0 or disoriented", Severity.NEURO, 0.9),
            r'\box1\b|oriented to person only': (
                "Partial orientation: Ox1", Severity.NEURO, 0.85),
            r'\box2\b|oriented to person and (place|time)': (
                "Oriented to person and place: Ox2", Severity.MONITORING, 0.8),
            r'\box3\b|oriented x3|fully oriented': (
                "Fully oriented: Ox3", Severity.NORMAL, 0.9)
        }

        for pattern, (message, severity, confidence) in orientation_patterns.items():
            match = re.search(pattern, exam_section)
            if match and not is_negated(match.group(0), blurb_lower):
                findings.append(ClinicalFinding(
                    message=message,
                    severity=severity,
                    confidence=confidence,
                    category="orientation"
                ))
                break  # Only add the most specific orientation finding

        # === MOTOR RESPONSES TO PAINFUL STIMULI ===
        motor_response_patterns = {
            r'bue\s+loc': (
                "Localizes to pain in bilateral upper extremities", Severity.NEURO, 0.85),
            r'ble\s+wd': (
                "Withdrawal response in bilateral lower extremities", Severity.NEURO, 0.85),
            r'bue\s+loc\s+ble\s+wd': (
                "Localizes in upper extremities, withdraws in lower extremities", Severity.NEURO, 0.9),
            r'(?:does not follow|unable to follow) commands': (
                "Decreased level of consciousness—unable to follow commands", Severity.NEURO, 0.9),
            r'follows? commands': (
                "Follows commands appropriately", Severity.NORMAL, 0.85)
        }

        for pattern, (message, severity, confidence) in motor_response_patterns.items():
            match = re.search(pattern, exam_section)
            if match and not is_negated(match.group(0), blurb_lower):
                findings.append(ClinicalFinding(
                    message=message,
                    severity=severity,
                    confidence=confidence,
                    category="motor_response"
                ))

        # === CRANIAL NERVE EXAMINATION ===
        cranial_nerve_patterns = {
            r'(?:facial|face)\s+(?:droop|weakness|asymmetry)': (
                "Facial droop—possible cranial nerve VII weakness", Severity.MOTOR, 0.85),
            r'(?:tongue|lingual)\s+(?:deviation|weakness)': (
                "Tongue deviation—possible cranial nerve XII weakness", Severity.MOTOR, 0.85),
            r'(?:slurred|dysarthric)\s+speech': (
                "Dysarthria/slurred speech—monitor for progression", Severity.NEURO, 0.8)
        }

        for pattern, (message, severity, confidence) in cranial_nerve_patterns.items():
            match = re.search(pattern, exam_section)
            if match and not is_negated(match.group(0), blurb_lower):
                findings.append(ClinicalFinding(
                    message=message,
                    severity=severity,
                    confidence=confidence,
                    category="cranial_nerves"
                ))

        # === COMPREHENSIVE MOTOR STRENGTH ASSESSMENT ===
        limb_map = {
            "rue": "RIGHT UPPER EXTREMITY",
            "lue": "LEFT UPPER EXTREMITY",
            "rle": "RIGHT LOWER EXTREMITY",
            "lle": "LEFT LOWER EXTREMITY",
            "bue": "BILATERAL UPPER EXTREMITIES",
            "ble": "BILATERAL LOWER EXTREMITIES",
            "right arm": "RIGHT UPPER EXTREMITY",
            "left arm": "LEFT UPPER EXTREMITY",
            "right upper extremity": "RIGHT UPPER EXTREMITY",
            "left upper extremity": "LEFT UPPER EXTREMITY",
            "right leg": "RIGHT LOWER EXTREMITY",
            "left leg": "LEFT LOWER EXTREMITY",
            "right lower extremity": "RIGHT LOWER EXTREMITY",
            "left lower extremity": "LEFT LOWER EXTREMITY"
        }

        # Muscle tested at each position of a multi-grade strength string
        muscle_groups = {
            "RIGHT UPPER EXTREMITY": ["shoulder abduction", "elbow flexion", "elbow extension", "wrist extension", "handgrip"],
            "LEFT UPPER EXTREMITY": ["shoulder abduction", "elbow flexion", "elbow extension", "wrist extension", "handgrip"],
            "RIGHT LOWER EXTREMITY": ["hip flexion", "knee extension", "ankle dorsiflexion", "EHL", "ankle plantarflexion"],
            "LEFT LOWER EXTREMITY": ["hip flexion", "knee extension", "ankle dorsiflexion", "EHL", "ankle plantarflexion"]
        }

        processed_extremities = set()
        exam_normalized = exam_section

        # Normalize strength notation: "5 4" / "5,4" -> "5/4", "5s" -> "5"
        exam_normalized = re.sub(r'(\d+[-+]?)\s+(\d+[-+]?)', r'\1/\2', exam_normalized)
        exam_normalized = re.sub(r'(\d+[-+]?)[,;](\d+[-+]?)', r'\1/\2', exam_normalized)
        exam_normalized = re.sub(r'(\d+)s', r'\1', exam_normalized)

        # Complex multiple muscle strength (e.g., RUE 5/5/5/4-/4)
        complex_pattern = r'(rue|lue|rle|lle|right\s+upper\s+extremity|left\s+upper\s+extremity|right\s+lower\s+extremity|left\s+lower\s+extremity)(?:\s*[:;-])?\s*(?:\()?(\d+[-+]?(?:[\/\s,.-]?\d+[-+]?){2,4})(?:\))?'
        complex_matches = re.finditer(complex_pattern, exam_normalized, re.IGNORECASE)

        for match in complex_matches:
            # Collapse runs of whitespace ("right  upper extremity") so the
            # limb_map / muscle_groups lookups succeed for spelled-out limbs.
            extremity = re.sub(r'\s+', ' ', match.group(1).lower())
            if extremity in processed_extremities:
                continue

            extremity_name = limb_map.get(extremity, extremity.upper())
            strength_text = match.group(2)
            values = re.findall(r'(\d+[-+]?)', strength_text)

            weak_values = []
            for i, val in enumerate(values):
                if val.strip() not in ['5', '5+']:
                    if extremity_name in muscle_groups and i < len(muscle_groups[extremity_name]):
                        muscle_name = muscle_groups[extremity_name][i]
                    else:
                        muscle_name = f"muscle {i+1}"
                    weak_values.append(f"{muscle_name}: {val}")

            if weak_values:
                processed_extremities.add(extremity)
                weakness_details = ", ".join(weak_values)

                # Any grade <= 2 (no movement against gravity) escalates to WARNING
                has_severe_weakness = any(int(val.strip().replace('+', '').replace('-', '')) <= 2
                                        for val in values if val.strip().replace('+', '').replace('-', '').isdigit())
                severity = Severity.WARNING if has_severe_weakness else Severity.MOTOR
                confidence = 0.95 if has_severe_weakness else 0.85

                findings.append(ClinicalFinding(
                    message=f"Weakness in {extremity_name}: {weakness_details}",
                    severity=severity,
                    confidence=confidence,
                    category="motor"
                ))

        # Simple strength scores (e.g., LUE 4+/5)
        simple_pattern = r'(rue|lue|rle|lle|right arm|left arm|right leg|left leg)\s+(\d+[-+]?)/5'
        simple_matches = re.finditer(simple_pattern, exam_section, re.IGNORECASE)

        for match in simple_matches:
            extremity = match.group(1).lower()
            if extremity in processed_extremities:
                continue

            strength = match.group(2)
            if strength not in ['5', '5+']:
                processed_extremities.add(extremity)
                extremity_name = limb_map.get(extremity, extremity.upper())

                severity = Severity.WARNING if int(strength.replace('+', '').replace('-', '')) <= 2 else Severity.MOTOR
                confidence = 0.9

                findings.append(ClinicalFinding(
                    message=f"Weakness in {extremity_name}: {strength}/5",
                    severity=severity,
                    confidence=confidence,
                    category="motor"
                ))

        # === REFLEX ASSESSMENT ===
        # Only 3+/4+ grades count as hyperreflexia; 2+ is normal and 0/1+ are
        # diminished, so the grade must not be matched by a bare \d.
        reflex_patterns = {
            r'hoffmans?\b|hoffman\'?s?\s+sign': (
                "Hoffman's sign—UMN risk", Severity.WARNING, 0.85),
            r'babinski\b|babinski\'?s?\s+sign': (
                "Babinski reflex noted", Severity.WARNING, 0.85),
            r'(?:hyperreflexia|\b[34]\+?\s+dtrs?|dtrs?\s+[34](?!\d)\+?)': (
                "Hyperreflexia—monitor for UMN lesion", Severity.WARNING, 0.8),
            r'(?:hyporeflexia|areflexia|0\s+dtrs?|dtrs?\s+0)': (
                "Hyporeflexia—possible LMN involvement", Severity.WARNING, 0.8),
            r'(?:normal\s+reflexes|2\+?\s+dtrs?|dtrs?\s+2\+?)': (
                "Reflexes within normal limits", Severity.NORMAL, 0.9)
        }

        for pattern, (message, severity, confidence) in reflex_patterns.items():
            matches = re.finditer(pattern, exam_section, re.IGNORECASE)
            for match in matches:
                if not is_negated(match.group(0), blurb_lower):
                    findings.append(ClinicalFinding(
                        message=message,
                        severity=severity,
                        confidence=confidence,
                        category="reflexes"
                    ))
                    break  # At most one finding per reflex pattern

        # === SENSORY ASSESSMENT ===
        sensory_patterns = {
            r'(?:decreased|diminished|reduced|impaired)\s+sensation': (
                "Sensory deficit present", Severity.NEURO, 0.8),
            r'numbness\b|paresthesia\b|dysesthesia\b': (
                "Sensory deficit present", Severity.NEURO, 0.8),
            r'sensation\s+(?:is\s+)?(?:intact|normal|preserved)|silt': (
                "Sensation intact", Severity.NORMAL, 0.9)
        }

        for pattern, (message, severity, confidence) in sensory_patterns.items():
            match = re.search(pattern, exam_section, re.IGNORECASE)
            if match and not is_negated(match.group(0), blurb_lower):
                if "deficit" in message and is_baseline_deficit_mentioned(blurb_lower):
                    findings.append(ClinicalFinding(
                        message="Sensory findings noted but consistent with baseline",
                        severity=Severity.NORMAL,
                        confidence=0.7,
                        category="sensory"
                    ))
                else:
                    findings.append(ClinicalFinding(
                        message=message,
                        severity=severity,
                        confidence=confidence,
                        category="sensory"
                    ))
                break  # Only the first matching sensory pattern is reported

        # === CAUDA EQUINA ASSESSMENT ===
        # Check for bowel/bladder dysfunction (searched in the whole note, not just the exam)
        bowel_bladder_pattern = r'(?:bowel|bladder)\s+(?:incontinence|retention|dysfunction)'
        bowel_bladder_match = re.search(bowel_bladder_pattern, blurb_lower, re.IGNORECASE)
        if bowel_bladder_match and not is_negated(bowel_bladder_match.group(0), blurb_lower):
            findings.append(ClinicalFinding(
                message="Bowel/bladder dysfunction—evaluate for cauda equina",
                severity=Severity.WARNING,
                confidence=0.9,
                category="cauda_equina"
            ))

        # Check for rectal tone
        rectal_tone_pattern = r'\+rectal\s+tone|normal\s+rectal\s+tone|rectal\s+tone\s+(?:present|intact)'
        rectal_tone_match = re.search(rectal_tone_pattern, exam_section, re.IGNORECASE)
        if rectal_tone_match and not is_negated(rectal_tone_match.group(0), blurb_lower):
            findings.append(ClinicalFinding(
                message="Rectal tone intact",
                severity=Severity.NORMAL,
                confidence=0.9,
                category="cauda_equina"
            ))

        return findings

    except Exception as e:
        clinical_logger.log_error(
            component="extract_enhanced_exam_components",
            error_code="ENHANCED_EXAM_ERROR",
            error_message=str(e),
            stack_trace=traceback.format_exc()
        )
        return []
# Key muscles reported, left to right, in "detailed" strength strings such as
# "LLE 2/5/4+/4+/5". The lower-extremity order appears to follow the ASIA/ISNCSCI
# key muscles (L2-S1).
_LOWER_EXTREMITY_MUSCLES = ["hip flexion", "knee extension", "ankle dorsiflexion", "EHL", "ankle plantarflexion"]
# Upper-extremity strings previously reused the lower-extremity labels, so "RUE 4/5/..."
# was reported as "hip flexion: 4". NOTE(review): assumes upper-extremity values follow
# the ASIA/ISNCSCI key-muscle order (C5-T1) — confirm against real consult notes.
_UPPER_EXTREMITY_MUSCLES = ["elbow flexion", "wrist extension", "elbow extension", "finger flexion", "finger abduction"]

# Message keywords that identify the same clinical concept across differently worded
# findings; checked in this order, first hit wins.
_CONCEPT_KEYWORDS = (
    ("babinski", "babinski"),
    ("clonus", "clonus"),
    ("sensory level", "sensory_level"),
    ("saddle anesthesia", "saddle_anesthesia"),
    ("hoffman", "hoffman"),
    ("hyperreflexia", "hyperreflexia"),
)
# Limb names that, together with "weakness" in the message, identify a per-limb concept.
_LIMB_WEAKNESS_KEYS = (
    ("right upper extremity", "rue_weakness"),
    ("left upper extremity", "lue_weakness"),
    ("right lower extremity", "rle_weakness"),
    ("left lower extremity", "lle_weakness"),
)


def _strength_grade(value: str) -> Optional[int]:
    """Return the integer grade of a strength token such as '4+' or '3-', or None if non-numeric."""
    digits = value.strip().replace('+', '').replace('-', '')
    return int(digits) if digits.isdigit() else None


def _finding_concept_key(message: str) -> str:
    """Map a finding message to its dedup key: a shared concept name, else the message itself."""
    msg_lower = message.lower()
    for keyword, key in _CONCEPT_KEYWORDS:
        if keyword in msg_lower:
            return key
    if 'weakness' in msg_lower:
        for limb, key in _LIMB_WEAKNESS_KEYS:
            if limb in msg_lower:
                return key
    return msg_lower.strip()


def _deduplicate_findings(finding_list):
    """Collapse findings that share a concept key, keeping the better one in the first one's slot.

    A later duplicate replaces the kept finding when it has higher confidence or a longer
    (more specific) message. The kept finding's position is tracked per key. The previous
    implementation re-scanned messages for the key text, which never matches keys such as
    'lle_weakness' or 'sensory_level'; next() then raised StopIteration and the whole
    analysis fell into the error path.
    """
    unique_findings = []
    slot_by_key = {}
    for finding in finding_list:
        key = _finding_concept_key(finding.message)
        slot = slot_by_key.get(key)
        if slot is None:
            slot_by_key[key] = len(unique_findings)
            unique_findings.append(finding)
            continue
        kept = unique_findings[slot]
        if finding.confidence > kept.confidence or len(finding.message) > len(kept.message):
            unique_findings[slot] = finding
    return unique_findings


def risk_flag(blurb: str, prior_findings: Optional[ClinicalFindings] = None) -> ClinicalFindings:
    """Rule-based neurological risk flagging for a free-text consult note.

    Args:
        blurb: Raw consult text. Abbreviations are expanded with ``expand_abbreviations``
            before pattern matching.
        prior_findings: Accepted for API compatibility; not used by this function.

    Returns:
        A ``ClinicalFindings`` whose findings are sorted into severity buckets and
        de-duplicated per clinical concept, with ``overall_confidence`` (mean finding
        confidence, 0.0 when there are none) and ``processing_time_ms`` filled in.
        Empty input yields a single "Empty input" WARNING finding (placed in
        ``critical_flags``). Any exception is logged via ``clinical_logger`` and reported
        as a single CRITICAL "Error in processing" finding instead of being raised.
    """
    if not blurb:
        return ClinicalFindings(
            critical_flags=[ClinicalFinding(
                message="Empty input—manual review advised",
                severity=Severity.WARNING,
                confidence=0.0,
                category="system"
            )]
        )

    start_time = time.time()
    findings = ClinicalFindings()

    try:
        # Expand abbreviations
        expanded_blurb, expanded_terms = expand_abbreviations(blurb, abbreviation_map)
        blurb_lower = expanded_blurb.lower()

        # === MOTOR STRENGTH ASSESSMENT ===
        limb_map = {
            "rue": "RIGHT UPPER EXTREMITY",
            "lue": "LEFT UPPER EXTREMITY",
            "rle": "RIGHT LOWER EXTREMITY",
            "lle": "LEFT LOWER EXTREMITY",
            "right upper extremity": "RIGHT UPPER EXTREMITY",
            "left upper extremity": "LEFT UPPER EXTREMITY",
            "right lower extremity": "RIGHT LOWER EXTREMITY",
            "left lower extremity": "LEFT LOWER EXTREMITY",
            "right arm": "RIGHT UPPER EXTREMITY",
            "left arm": "LEFT UPPER EXTREMITY",
            "right leg": "RIGHT LOWER EXTREMITY",
            "left leg": "LEFT LOWER EXTREMITY"
        }

        # Detailed strength patterns (e.g., LLE 2/5/4+/4+/5 OR left lower extremity 2/5/4+/4+/5)
        for match in patterns.detailed_strength_pattern.finditer(expanded_blurb):
            extremity = match.group(1).lower()
            extremity_name = limb_map.get(extremity, extremity.upper())
            values = re.findall(r'(\d+[-+]?)', match.group(2))
            muscle_groups = (_UPPER_EXTREMITY_MUSCLES if 'UPPER' in extremity_name
                             else _LOWER_EXTREMITY_MUSCLES)

            # Any value < 5 indicates weakness; any value <= 2 is "severe".
            weak_values = []
            has_severe_weakness = False
            for i, val in enumerate(values):
                grade = _strength_grade(val)
                if grade is None or grade >= 5:
                    continue
                muscle_name = muscle_groups[i] if i < len(muscle_groups) else f"muscle {i+1}"
                weak_values.append(f"{muscle_name}: {val}")
                has_severe_weakness = has_severe_weakness or grade <= 2

            if not weak_values:
                continue

            weakness_details = ", ".join(weak_values)
            # Severe weakness (any muscle 2/5 or less) is escalated to WARNING instead of MOTOR.
            if has_severe_weakness:
                findings.warning_flags.append(ClinicalFinding(
                    message=f"Significant weakness in {extremity_name}: {weakness_details}",
                    severity=Severity.WARNING,
                    confidence=0.95,
                    category="motor"
                ))
            else:
                findings.motor_flags.append(ClinicalFinding(
                    message=f"Weakness in {extremity_name}: {weakness_details}",
                    severity=Severity.MOTOR,
                    confidence=0.85,
                    category="motor"
                ))

        # Generalized strength patterns (e.g., "LLE 4+ throughout")
        for match in patterns.generalized_strength_pattern.finditer(expanded_blurb):
            extremity = match.group(1).lower()
            extremity_name = limb_map.get(extremity, extremity.upper())
            strength_value = match.group(2)
            grade = _strength_grade(strength_value)
            if grade is None or grade >= 5:
                continue
            # Grade 3 or below "throughout" is escalated to WARNING.
            if grade <= 3:
                findings.warning_flags.append(ClinicalFinding(
                    message=f"Generalized weakness in {extremity_name}: {strength_value} throughout",
                    severity=Severity.WARNING,
                    confidence=0.9,
                    category="motor"
                ))
            else:
                findings.motor_flags.append(ClinicalFinding(
                    message=f"Weakness in {extremity_name}: {strength_value} throughout",
                    severity=Severity.MOTOR,
                    confidence=0.9,
                    category="motor"
                ))

        # === DIRECT SENSORY LEVEL DETECTION ===
        sensory_level_match = patterns.sensory_level_pattern.search(blurb_lower)
        if sensory_level_match and not is_negated(sensory_level_match.group(0), blurb_lower):
            level = sensory_level_match.group(1).upper()
            findings.warning_flags.append(ClinicalFinding(
                message=f"Sensory level at {level}—possible cord involvement",
                severity=Severity.WARNING,
                confidence=0.9,
                category="sensory"
            ))

        # === DIRECT CLONUS DETECTION ===
        clonus_match = patterns.clonus_pattern.search(blurb_lower)
        if clonus_match and not is_negated(clonus_match.group(0), blurb_lower):
            side = clonus_match.group(1) if clonus_match.group(1) else ""
            side_text = f" {side}" if side else ""
            findings.warning_flags.append(ClinicalFinding(
                message=f"Clonus detected{side_text}—UMN involvement",
                severity=Severity.WARNING,
                confidence=0.9,
                category="reflexes"
            ))

        # === BABINSKI DETECTION ===
        babinski_match = patterns.babinski_pattern.search(blurb_lower)
        if babinski_match and not is_negated(babinski_match.group(0), blurb_lower):
            side = babinski_match.group(1) if babinski_match.group(1) else ""
            side_text = f" {side}" if side else ""
            findings.warning_flags.append(ClinicalFinding(
                message=f"Babinski sign{side_text}—UMN involvement",
                severity=Severity.WARNING,
                confidence=0.9,
                category="reflexes"
            ))

        # === HYPERREFLEXIA DETECTION ===
        hyperreflexia_match = patterns.hyperreflexia_pattern.search(blurb_lower)
        if hyperreflexia_match and not is_negated(hyperreflexia_match.group(0), blurb_lower):
            extremity = hyperreflexia_match.group(1) if hyperreflexia_match.group(1) else ""
            reflex_grade = hyperreflexia_match.group(2)
            extremity_text = f" {extremity}" if extremity else ""
            findings.warning_flags.append(ClinicalFinding(
                message=f"Hyperreflexia{extremity_text} ({reflex_grade})—UMN involvement",
                severity=Severity.WARNING,
                confidence=0.85,
                category="reflexes"
            ))

        # === SADDLE ANESTHESIA DETECTION ===
        saddle_match = patterns.saddle_anesthesia_pattern.search(blurb_lower)
        if saddle_match and not is_negated(saddle_match.group(0), blurb_lower):
            findings.critical_flags.append(ClinicalFinding(
                message="Saddle anesthesia—urgent cauda equina evaluation needed",
                severity=Severity.CRITICAL,
                confidence=0.95,
                category="cauda_equina"
            ))

        # === SENSORY DEFICIT DETECTION ===
        sensory_deficit_match = patterns.sensory_deficit_pattern.search(blurb_lower)
        if sensory_deficit_match and not is_negated(sensory_deficit_match.group(0), blurb_lower):
            # Bilateral deficits raise concern for cord/cauda equina involvement.
            if patterns.bilateral_sensory_pattern.search(blurb_lower):
                findings.warning_flags.append(ClinicalFinding(
                    message="Bilateral sensory deficits—possible cord/cauda equina involvement",
                    severity=Severity.WARNING,
                    confidence=0.85,
                    category="sensory"
                ))
            else:
                findings.neuro_flags.append(ClinicalFinding(
                    message="Sensory deficit noted",
                    severity=Severity.NEURO,
                    confidence=0.8,
                    category="sensory"
                ))

        # === REFLEX ABNORMALITIES ===
        absent_reflexes_match = patterns.absent_reflexes_pattern.search(blurb_lower)
        if absent_reflexes_match and not is_negated(absent_reflexes_match.group(0), blurb_lower):
            findings.warning_flags.append(ClinicalFinding(
                message="Absent reflexes—possible lower motor neuron involvement",
                severity=Severity.WARNING,
                confidence=0.9,
                category="reflexes"
            ))

        # === ENHANCED PHYSICAL EXAM ANALYSIS ===
        for finding in extract_enhanced_exam_components(expanded_blurb, blurb_lower):
            if finding.severity == Severity.CRITICAL:
                findings.critical_flags.append(finding)
            elif finding.severity == Severity.WARNING:
                findings.warning_flags.append(finding)
            elif finding.severity == Severity.MOTOR:
                findings.motor_flags.append(finding)
            elif finding.severity == Severity.NEURO:
                findings.neuro_flags.append(finding)
            elif finding.severity == Severity.MONITORING:
                findings.monitoring_flags.append(finding)
            else:
                findings.normal_findings.append(finding)

        # === SPINAL TUBERCULOSIS CONTEXT ===
        if re.search(r'potts\s+disease|spinal\s+tb|tuberculosis.*spine', blurb_lower):
            findings.monitoring_flags.append(ClinicalFinding(
                message="Known spinal TB/Potts disease—monitor for neurological progression",
                severity=Severity.MONITORING,
                confidence=0.9,
                category="infection"
            ))

        # === DEDUPLICATION ===
        findings.critical_flags = _deduplicate_findings(findings.critical_flags)
        findings.warning_flags = _deduplicate_findings(findings.warning_flags)
        findings.motor_flags = _deduplicate_findings(findings.motor_flags)
        findings.neuro_flags = _deduplicate_findings(findings.neuro_flags)
        findings.monitoring_flags = _deduplicate_findings(findings.monitoring_flags)
        findings.normal_findings = _deduplicate_findings(findings.normal_findings)

        # === OVERALL CONFIDENCE (computed once, on the de-duplicated set) ===
        all_findings = findings.get_all_findings()
        if all_findings:
            findings.overall_confidence = sum(f.confidence for f in all_findings) / len(all_findings)
        else:
            findings.overall_confidence = 0.0

        findings.processing_time_ms = int((time.time() - start_time) * 1000)

        clinical_logger.log_clinical_processing(blurb, findings, expanded_terms)

        return findings

    except Exception as e:
        error_msg = f"Error processing consult text: {str(e)}"
        clinical_logger.log_error(
            component="risk_flag_engine",
            error_code="PROCESSING_ERROR",
            error_message=error_msg,
            severity="high",
            stack_trace=traceback.format_exc()
        )

        return ClinicalFindings(
            critical_flags=[ClinicalFinding(
                message=f"Error in processing: {error_msg}",
                severity=Severity.CRITICAL,
                confidence=0.0,
                category="system"
            )]
        )

# === OpenAI Integration ===
def initialize_openai_client():
    """Build an OpenAI client from ``config.openai_api_key``.

    Returns:
        The constructed ``OpenAI`` client, or ``None`` when construction raised.
        Failures are recorded through ``clinical_logger`` under the
        ``CLIENT_INIT_ERROR`` code rather than propagated.
    """
    client = None
    try:
        client = OpenAI(api_key=config.openai_api_key)
    except Exception as exc:  # any construction failure degrades to "no client"
        clinical_logger.log_error(
            component="openai_integration",
            error_code="CLIENT_INIT_ERROR",
            error_message=str(exc),
            severity="medium",
        )
    return client

# === GPT Summary Integration ===
def gpt_summarize(consult: str, findings: Optional[ClinicalFindings] = None,
                  model: str = "gpt-4") -> str:
    """Ask an OpenAI chat model for a 2–3 sentence, neuro-focused consult summary.

    Args:
        consult: Free-text consult note. NOTE(review): this text is sent to an external
            API — confirm it is de-identified / permitted under your data policy.
        findings: Optional engine output (e.g. from ``risk_flag``). When it contains
            findings, their messages are listed in the prompt. The prompt previously
            asked the model to summarize "flagged neurological findings" without ever
            supplying any.
        model: Chat-completions model name; defaults to the previously hard-coded "gpt-4".

    Returns:
        The summary text. Otherwise a human-readable status string when summaries are
        disabled in ``config``, the client cannot be created, the API call fails (the
        error is logged via ``clinical_logger``), or the model returns no text.
    """
    if not config.enable_gpt_summary:
        return "GPT summary disabled in configuration"

    client = initialize_openai_client()
    if not client:
        return "GPT client initialization failed"

    try:
        flagged = findings.get_all_findings() if findings is not None else []
        if flagged:
            flagged_lines = "\n".join(f"- {f.message}" for f in flagged)
            task = "Summarize the following consult note and its flagged neurological findings:"
            findings_section = f"\nFlagged findings:\n{flagged_lines}\n"
        else:
            task = "Summarize the following consult note:"
            findings_section = ""

        prompt = f"""
You are a clinical AI engine trained to summarize neurological consults.

{task}

Consult:
{consult}
{findings_section}
Output a clear, 2–3 sentence summary with emphasis on the neuro exam and clinical concern.
Focus on critical findings, motor/sensory deficits, and risk factors.
"""
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a clinical summarizer for neurosurgical triage."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=200,
            temperature=0.3
        )
        # The SDK types message.content as optional; calling .strip() on None would raise.
        content = response.choices[0].message.content
        if not content:
            return "GPT summary returned no content"
        return content.strip()

    except Exception as e:
        clinical_logger.log_error(
            component="gpt_summary",
            error_code="SUMMARY_ERROR",
            error_message=str(e),
            severity="low"
        )
        return f"GPT summary failed: {str(e)}"

# === Enhanced CLI Interface ===
def format_findings_output(findings: ClinicalFindings) -> str:
    """Render findings as a plain-text report, most severe section first.

    Each non-empty bucket gets a heading, one indented line per finding
    (severity marker, message, two-decimal confidence) and a trailing blank
    line. A summary line with total count, overall confidence and processing
    time always closes the report.
    """
    sections = (
        ("🚨 CRITICAL FINDINGS:", findings.critical_flags),
        ("⚠️ WARNING FINDINGS:", findings.warning_flags),
        ("🦴 MOTOR FINDINGS:", findings.motor_flags),
        ("🧠 NEUROLOGICAL FINDINGS:", findings.neuro_flags),
        ("🔍 MONITORING:", findings.monitoring_flags),
        ("✅ NORMAL FINDINGS:", findings.normal_findings),
    )

    report = []
    for heading, bucket in sections:
        if not bucket:
            continue
        report.append(heading)
        report.extend(
            f"  {item.severity.value} {item.message} (confidence: {item.confidence:.2f})"
            for item in bucket
        )
        report.append("")

    finding_count = len(findings.get_all_findings())
    report.append(f"📊 SUMMARY: {finding_count} findings, overall confidence: {findings.overall_confidence:.2f}, processing time: {findings.processing_time_ms}ms")

    return "\n".join(report)

# Printed last in this cell so the notebook output confirms every definition above executed.
print("✅ Synapse v2.0 core engine loaded successfully!")

In [None]:
# Age phrases such as "65 year old" / "65-year-old" are removed before searching:
# the word "old" there describes the patient, not a chronic deficit.
_AGE_PHRASE_PATTERN = re.compile(
    r'\b\d+\s*-?\s*(?:year|yr|month|mo|week|wk|day)s?\s*-?\s*old\b',
    re.IGNORECASE,
)

# Word-bounded terms, so that e.g. "unknown", "priority", "hold" and "told" no longer count.
# "chronic" has no trailing boundary so "chronically" still matches.
_BASELINE_PATTERN = re.compile(
    r'\bbaseline|\bchronic|\bknown\b|\bprior\b|\bold(?:er)?\b'
    r'|\bpre[-\s]?existing\b|\blong[-\s]?standing\b|\bhistory\s+of\b',
    re.IGNORECASE,
)


def is_baseline_deficit_mentioned(blurb_lower: str) -> bool:
    """Return True if the text describes deficits as baseline / chronic / pre-existing.

    The previous version used bare substrings, so "old" matched every
    "NN year old", "hold" and "told", "known" matched "unknown", and "prior"
    matched "priority". It therefore returned True for nearly any consult.

    Args:
        blurb_lower: Consult text (callers pass it lower-cased; matching is
            case-insensitive regardless).

    Returns:
        True if a baseline/chronicity term appears outside an age phrase.
        Empty input returns False.
    """
    if not blurb_lower:
        return False
    without_ages = _AGE_PHRASE_PATTERN.sub(' ', blurb_lower)
    return _BASELINE_PATTERN.search(without_ages) is not None

class EnhancedClinicalPatterns:
    """Compiled, case-insensitive regexes for the v2.1 assessments.

    Each attribute is a dict mapping a finding name to a compiled pattern. Matching is
    plain regex search: negation is not handled, so "no facial droop" still matches.
    """

    def __init__(self):
        # Stroke/CVA patterns
        self.stroke_patterns = {
            'facial_droop': re.compile(r'facial\s+(?:droop|weakness|asymmetry)|(?:left|right|lt|rt)\s+facial', re.IGNORECASE),
            'dysarthria': re.compile(r'dysarth|slurred\s+speech|speech\s+difficulty', re.IGNORECASE),
            'aphasia': re.compile(r'aphas|word.?finding|speech\s+(?:difficulty|impairment)', re.IGNORECASE),
            'neglect': re.compile(r'(?:left|right)\s+(?:neglect|inattention)|hemineglect', re.IGNORECASE),
            'visual_field': re.compile(r'(?:visual\s+field|VF)\s+(?:cut|defect)|hemianop', re.IGNORECASE),
            # group(1) is the score. Accepts "nihss 12", "nihss of 12", "nihss score 12",
            # "nihss score of 12", "nihss: 12"; the old form missed the last two.
            'nihss': re.compile(r'nihss\s*(?:score\s*)?(?:of|is|was|=|:)?\s*(\d+)', re.IGNORECASE)
        }

        # Spinal instability patterns
        self.spine_instability = {
            'mechanical_pain': re.compile(r'(?:mechanical|positional)\s+(?:pain|symptoms)|worse\s+with\s+(?:movement|position)', re.IGNORECASE),
            'step_off': re.compile(r'step.?off|palpable\s+(?:deformity|instability)', re.IGNORECASE),
            'midline_tenderness': re.compile(r'midline\s+(?:tender|ttp)|tender.*midline', re.IGNORECASE)
        }

        # Myelopathy patterns
        self.myelopathy = {
            'gait_disturbance': re.compile(r'(?:wide.?based|ataxic|unsteady|spastic)\s+gait|gait\s+(?:instability|difficulty)', re.IGNORECASE),
            # NOTE(review): the bare "fine motor" / "button" alternatives also match normal
            # statements such as "fine motor intact" — consider requiring a deficit word.
            'hand_clumsiness': re.compile(r'(?:hand|finger)\s+(?:clumsiness|dexterity)|(?:fine\s+motor|button)', re.IGNORECASE),
            'lhermitte': re.compile(r'lhermitte|electric\s+shock.*(?:neck|spine)', re.IGNORECASE)
        }

        # Peripheral nerve patterns
        self.peripheral_nerve = {
            # Thoracic levels are spelled out explicitly: the old "t[1-12]" is the character
            # class {1, 2}, so T3-T12 (and the second digit of T10-T12) never matched.
            'radiculopathy': re.compile(r'radiculopathy|(?:c[1-8]|t(?:1[0-2]|[1-9])|l[1-5]|s[1-5])\s+(?:distribution|dermatomal)', re.IGNORECASE),
            'foot_drop': re.compile(r'foot\s+drop|(?:ankle\s+)?dorsiflexion\s+(?:weakness|0\/5|1\/5|2\/5)', re.IGNORECASE),
            'wrist_drop': re.compile(r'wrist\s+drop|wrist\s+extension\s+(?:weakness|0\/5|1\/5|2\/5)', re.IGNORECASE)
        }

        # Critical timeline patterns
        self.timeline_critical = {
            'acute_onset': re.compile(r'(?:sudden|acute|abrupt)\s+onset|(?:symptoms?\s+)?(?:started|began)\s+(?:today|hours?\s+ago)', re.IGNORECASE),
            'progressive': re.compile(r'(?:rapid|quickly|rapidly)?\s*progress|worsen|deteriorat', re.IGNORECASE),
            # group(1) = amount, group(2) = unit. The unit is captured because
            # assess_stroke_risk reads group(2), which previously did not exist.
            'time_critical': re.compile(r'(\d+)\s*(hours?|hrs?|minutes?|mins?)\s+ago', re.IGNORECASE)
        }

# Add these assessment functions to enhance analysis

def assess_stroke_risk(blurb_lower: str, patterns: EnhancedClinicalPatterns) -> List[ClinicalFinding]:
    """Flag stroke concern from an NIHSS score and FAST-style components.

    Args:
        blurb_lower: Lower-cased consult text.
        patterns: Pattern container providing ``stroke_patterns`` and ``timeline_critical``.

    Returns:
        Zero or more findings:
        - NIHSS >= 15 gives a CRITICAL finding; 5-14 gives a WARNING finding.
        - A CRITICAL "Stroke alert" when at least two components are present among
          facial droop, speech difficulty, arm weakness, and onset <= 4 hours ago.
        Negation is not handled ("no facial droop" still counts).
    """
    findings = []
    stroke_score = 0

    # Check NIHSS score
    nihss_match = patterns.stroke_patterns['nihss'].search(blurb_lower)
    if nihss_match:
        score = int(nihss_match.group(1))
        if score >= 15:
            findings.append(ClinicalFinding(
                message=f"High NIHSS score ({score})—severe stroke",
                severity=Severity.CRITICAL,
                confidence=0.95,
                category="stroke"
            ))
        elif score >= 5:
            findings.append(ClinicalFinding(
                message=f"Moderate NIHSS score ({score})—significant stroke",
                severity=Severity.WARNING,
                confidence=0.9,
                category="stroke"
            ))

    # FAST criteria
    fast_components = []
    if patterns.stroke_patterns['facial_droop'].search(blurb_lower):
        fast_components.append("Facial droop")
        stroke_score += 1

    if patterns.stroke_patterns['dysarthria'].search(blurb_lower) or \
       patterns.stroke_patterns['aphasia'].search(blurb_lower):
        fast_components.append("Speech difficulty")
        stroke_score += 1

    # Check for arm weakness (already in motor patterns)
    if re.search(r'(?:rue|lue|arm)\s+(?:weakness|drift|0\/5|1\/5|2\/5)', blurb_lower):
        fast_components.append("Arm weakness")
        stroke_score += 1

    # Time-sensitive check
    time_match = patterns.timeline_critical['time_critical'].search(blurb_lower)
    if time_match:
        time_value = int(time_match.group(1))
        # Read the unit from the text that follows the number instead of group(2):
        # the pattern is only guaranteed to capture the number, and group(2) raised
        # IndexError ("no such group") on every "N hours ago". "hr"/"hrs" count as hours.
        unit_text = blurb_lower[time_match.end(1):time_match.end()].strip().lower()
        if unit_text.startswith(('hour', 'hr')) and time_value <= 4:
            fast_components.append(f"Within thrombolysis window ({time_value} hours)")
            stroke_score += 1

    if stroke_score >= 2:
        findings.append(ClinicalFinding(
            message=f"Stroke alert: {', '.join(fast_components)}",
            severity=Severity.CRITICAL,
            confidence=0.9,
            category="stroke"
        ))

    return findings

def assess_spine_stability(blurb_lower: str, patterns: EnhancedClinicalPatterns) -> List[ClinicalFinding]:
    """Score weighted spinal-instability indicators found in the consult text.

    Weights: mechanical pain 1, palpable step-off 2, midline tenderness 1, and a
    named fracture type (unstable/burst/chance/compression) 2.

    Returns:
        An empty list when no indicator is present. Otherwise one finding:
        CRITICAL when the total weight is at least 3, WARNING when it is 1-2.
    """
    instability = patterns.spine_instability
    weighted_checks = (
        (instability['mechanical_pain'].search(blurb_lower), "mechanical pain", 1),
        (instability['step_off'].search(blurb_lower), "palpable step-off", 2),
        (instability['midline_tenderness'].search(blurb_lower), "midline tenderness", 1),
        (re.search(r'(?:unstable|burst|chance|compression)\s+(?:fracture|fx)', blurb_lower),
         "unstable fracture pattern", 2),
    )
    indicators = [label for hit, label, _ in weighted_checks if hit]
    total_weight = sum(weight for hit, _, weight in weighted_checks if hit)

    if total_weight >= 3:
        return [ClinicalFinding(
            message=f"Spinal instability indicators: {', '.join(indicators)}",
            severity=Severity.CRITICAL,
            confidence=0.85,
            category="spine_stability"
        )]
    if total_weight >= 1:
        return [ClinicalFinding(
            message=f"Possible spinal instability: {', '.join(indicators)}",
            severity=Severity.WARNING,
            confidence=0.75,
            category="spine_stability"
        )]
    return []

def calculate_myelopathy_score(blurb_lower: str, patterns: EnhancedClinicalPatterns) -> Tuple[int, List[str]]:
    """Heuristic myelopathy score (loosely mJOA-inspired; not a validated mJOA).

    Points: gait disturbance 2, hand clumsiness 2, hyperreflexia 1,
    pathological reflexes (Babinski/Hoffman/clonus) 2, sensory changes in the hands 1.
    Negation is not handled.

    Returns:
        ``(score, components)``, where components lists the labels that
        contributed, in the order above.
    """
    checks = (
        (patterns.myelopathy['gait_disturbance'].search(blurb_lower), "gait disturbance", 2),
        (patterns.myelopathy['hand_clumsiness'].search(blurb_lower), "hand clumsiness", 2),
        (re.search(r'hyperreflex|(?:3|4)\+\s*(?:dtr|reflex)', blurb_lower), "hyperreflexia", 1),
        (re.search(r'babinski|hoffman|clonus', blurb_lower), "pathological reflexes", 2),
        (re.search(r'(?:numbness|paresthesia|sensory).*(?:hands?|fingers?)', blurb_lower),
         "sensory changes in hands", 1),
    )
    components = [label for matched, label, _ in checks if matched]
    score = sum(points for matched, _, points in checks if matched)
    return score, components

# Keep a handle on the v2.0 engine so risk_flag_enhanced can delegate to it.
# Only capture it while risk_flag is not already the enhanced wrapper. Re-running this
# cell would otherwise bind original_risk_flag to the wrapper itself, and every call
# would recurse until RecursionError.
if getattr(risk_flag, '__name__', '') != 'risk_flag_enhanced':
    original_risk_flag = risk_flag

# Shared pattern instance read by risk_flag_enhanced; previously it only existed after Cell 3.5 ran.
enhanced_patterns = EnhancedClinicalPatterns()

def _route_by_severity(findings: ClinicalFindings, finding: ClinicalFinding) -> None:
    """Append *finding* to the bucket of *findings* that matches its severity."""
    if finding.severity == Severity.CRITICAL:
        findings.critical_flags.append(finding)
    elif finding.severity == Severity.WARNING:
        findings.warning_flags.append(finding)
    elif finding.severity == Severity.MOTOR:
        findings.motor_flags.append(finding)
    elif finding.severity == Severity.NEURO:
        findings.neuro_flags.append(finding)
    elif finding.severity == Severity.MONITORING:
        findings.monitoring_flags.append(finding)
    else:
        findings.normal_findings.append(finding)


def risk_flag_enhanced(blurb: str, prior_findings: Optional[ClinicalFindings] = None) -> ClinicalFindings:
    """Run the v2.0 engine, then add the v2.1 stroke/spine/myelopathy/peripheral assessments.

    Args:
        blurb: Raw consult text.
        prior_findings: Passed through to ``original_risk_flag``.

    Returns:
        The base ``ClinicalFindings`` extended with v2.1 findings, with
        ``overall_confidence`` recomputed. If the v2.1 assessments raise, the error is
        logged via ``clinical_logger`` and a WARNING "Enhanced v2.1 assessment failed"
        finding is added, so the base results are still returned. The v2.1 patterns do
        not check negation.
    """
    # Base v2.0 analysis; it handles empty input and its own errors.
    findings = original_risk_flag(blurb, prior_findings)
    if not blurb:
        # The base engine already reported "Empty input"; blurb.lower() would fail on None.
        return findings

    blurb_lower = blurb.lower()

    # Fall back to a fresh instance when no shared enhanced_patterns has been created yet.
    patterns_obj = globals().get('enhanced_patterns')
    if patterns_obj is None:
        patterns_obj = EnhancedClinicalPatterns()

    try:
        # Add stroke assessment
        for finding in assess_stroke_risk(blurb_lower, patterns_obj):
            _route_by_severity(findings, finding)

        # Add spine stability assessment
        for finding in assess_spine_stability(blurb_lower, patterns_obj):
            _route_by_severity(findings, finding)

        # Add myelopathy assessment
        myelo_score, myelo_components = calculate_myelopathy_score(blurb_lower, patterns_obj)
        if myelo_score >= 4:
            findings.warning_flags.append(ClinicalFinding(
                message=f"Myelopathy indicators (score {myelo_score}): {', '.join(myelo_components)}",
                severity=Severity.WARNING,
                confidence=0.85,
                category="myelopathy"
            ))
        elif myelo_score >= 2:
            findings.neuro_flags.append(ClinicalFinding(
                message=f"Possible myelopathy signs: {', '.join(myelo_components)}",
                severity=Severity.NEURO,
                confidence=0.75,
                category="myelopathy"
            ))

        # Check for peripheral nerve issues
        if patterns_obj.peripheral_nerve['foot_drop'].search(blurb_lower):
            findings.motor_flags.append(ClinicalFinding(
                message="Foot drop detected—evaluate for L5 radiculopathy or peroneal nerve injury",
                severity=Severity.MOTOR,
                confidence=0.85,
                category="peripheral_nerve"
            ))

    except Exception as e:
        # Mirror the base engine: log and surface the failure instead of crashing the caller.
        clinical_logger.log_error(
            component="risk_flag_enhanced",
            error_code="ENHANCED_ASSESSMENT_ERROR",
            error_message=str(e),
            severity="medium",
            stack_trace=traceback.format_exc()
        )
        findings.warning_flags.append(ClinicalFinding(
            message="Enhanced v2.1 assessment failed—manual review advised",
            severity=Severity.WARNING,
            confidence=0.0,
            category="system"
        ))

    # Recalculate overall confidence after adding new findings
    all_findings = findings.get_all_findings()
    if all_findings:
        findings.overall_confidence = sum(f.confidence for f in all_findings) / len(all_findings)

    return findings

# Rebind the public name: from here on, every call to risk_flag runs the v2.1 wrapper.
risk_flag = risk_flag_enhanced

# One print of newline-joined lines yields exactly the same output as one print per line.
print("\n".join((
    "✅ Enhanced clinical patterns loaded successfully!",
    "🔍 New detection capabilities added:",
    "   • Stroke/CVA assessment (FAST criteria, NIHSS)",
    "   • Spinal instability indicators",
    "   • Myelopathy scoring",
    "   • Peripheral nerve patterns",
    "   • Timeline urgency detection",
    "\n📊 Ready for Cell 4: Clinical Scoring System",
)))

In [None]:
# CELL 3.5: Fix Enhanced Patterns Issue
# Run this AFTER Cell 3 to ensure enhanced_patterns is properly created

# Check if EnhancedClinicalPatterns class exists
if 'EnhancedClinicalPatterns' in globals():
    print("✅ EnhancedClinicalPatterns class found")

    # Create the instance if it doesn't exist
    if 'enhanced_patterns' not in globals():
        enhanced_patterns = EnhancedClinicalPatterns()
        print("✅ Created enhanced_patterns instance")
    else:
        print("✅ enhanced_patterns already exists")
else:
    print("❌ EnhancedClinicalPatterns class not found - please re-run Cell 3")

# The checks below need the instance. Skip them (instead of raising NameError) when it
# could not be created above.
if 'enhanced_patterns' in globals():
    # Test the patterns
    print("\n🧪 Testing pattern detection...")
    test_text = "patient with nihss 12 and facial droop"
    if enhanced_patterns.stroke_patterns['nihss'].search(test_text):
        print("✅ NIHSS pattern detection working")
    else:
        print("❌ NIHSS pattern did not match the test text")
    if enhanced_patterns.stroke_patterns['facial_droop'].search(test_text):
        print("✅ Facial droop pattern detection working")
    else:
        print("❌ Facial droop pattern did not match the test text")

    # Now test the full analysis again
    print("\n🧪 Testing full analysis with enhanced patterns...")
    try:
        test_finding = risk_flag("Patient with LUE 2/5 weakness, facial droop, NIHSS 12")
        print(f"✅ Analysis successful! Found {len(test_finding.get_all_findings())} findings")

        # Check if stroke findings were added
        stroke_findings = [f for f in test_finding.get_all_findings() if 'stroke' in f.category]
        if stroke_findings:
            print(f"✅ Enhanced stroke detection working - found {len(stroke_findings)} stroke-related findings")
        else:
            print("⚠️ No stroke-related findings returned - check that risk_flag is the enhanced version")

    except Exception as e:
        print(f"❌ Error: {str(e)}")
        import traceback
        traceback.print_exc()

    print("\n✨ Enhanced patterns should now be working! Try the dashboard again.")
else:
    print("\n❌ Skipping pattern tests: enhanced_patterns is not available.")

In [None]:
# CELL 4: Clinical Scoring System for Synapse v2.1
# This adds standardized clinical scoring to quantify severity
# Run this AFTER Cell 3

from dataclasses import dataclass
from typing import Dict, List, Optional
import numpy as np

@dataclass
class ClinicalScore:
    """Structured output of a single clinical score calculation.

    Instances are produced by the ``ClinicalScoringEngine.calculate_*`` methods
    and rendered by the dashboard as ``value/max_value`` plus severity,
    interpretation and recommendations.
    """
    score_name: str  # Display name, e.g. "ASIA Impairment Scale"
    value: float  # Numeric score; integer-valued scales also store ints here
    max_value: float  # Top of the scale, used for "value/max_value" display
    severity: str  # "normal", "mild", "moderate", "severe", "critical" or "unknown"
    components: Dict[str, float]  # Named sub-scores/inputs behind the value
    interpretation: str  # Human-readable explanation of the score
    recommendations: List[str]  # Suggested next steps; may be empty

class ClinicalScoringEngine:
    """Calculate approximate standardized clinical scores from findings.

    Every ``calculate_*`` method takes the ``ClinicalFindings`` for a note plus
    the raw note text (``blurb``) and returns a ``ClinicalScore``, or ``None``
    when the note contains nothing relevant to that score. The scores are
    keyword/regex heuristics that approximate the named clinical scales; they
    are not validated implementations of them.
    """

    def __init__(self):
        # NOTE(review): not read by any method below; kept in case other cells
        # inspect it.
        self.scores = {}

    def calculate_all_scores(self, findings: ClinicalFindings, blurb: str) -> Dict[str, ClinicalScore]:
        """Run every scorer and return only the scores that apply.

        Args:
            findings: Structured findings for the note.
            blurb: Original note text (each scorer lower-cases it itself).

        Returns:
            Mapping of 'asia', 'rankin', 'nurick', 'synapse_risk' and
            'cauda_equina' to their ClinicalScore, omitting scorers that
            returned None.
        """
        scores = {}

        # ASIA Impairment Scale approximation
        scores['asia'] = self.calculate_asia_score(findings, blurb)

        # Modified Rankin Scale for stroke
        scores['rankin'] = self.calculate_rankin_score(findings, blurb)

        # Nurick Grade for myelopathy
        scores['nurick'] = self.calculate_nurick_score(findings, blurb)

        # Custom Synapse Risk Score
        scores['synapse_risk'] = self.calculate_synapse_risk_score(findings)

        # Cauda Equina Risk Score
        scores['cauda_equina'] = self.calculate_cauda_equina_score(findings, blurb)

        return {k: v for k, v in scores.items() if v is not None}

    def calculate_asia_score(self, findings: ClinicalFindings, blurb: str) -> Optional[ClinicalScore]:
        """Approximate the ASIA Impairment Scale (grades A-D) from findings.

        Returns None when no weakness- or sensory-related finding is present.
        ``value`` runs from 4 (grade A) down to 1 (grade D); higher is worse.
        """
        blurb_lower = blurb.lower()

        # Check for complete vs incomplete injury
        motor_findings = [f for f in findings.motor_flags + findings.warning_flags
                         if 'weakness' in f.message.lower()]
        sensory_findings = [f for f in findings.neuro_flags + findings.warning_flags
                           if 'sensory' in f.message.lower()]

        if not (motor_findings or sensory_findings):
            return None

        components = {
            'motor_preserved': 0,
            'sensory_preserved': 0,
            'sacral_sparing': 0
        }

        # Check for motor preservation
        # NOTE(review): motor_preserved stays 0 when there is no weakness finding
        # at all, so a sensory-only note grades as A/B. Confirm this is intended.
        if motor_findings:
            # Look for any motor function below injury
            if re.search(r'(?:some|partial|incomplete)\s+(?:motor|movement)', blurb_lower):
                components['motor_preserved'] = 1
            elif any('0/5' in f.message or '1/5' in f.message for f in motor_findings):
                components['motor_preserved'] = 0
            else:
                components['motor_preserved'] = 0.5

        # Check for sensory preservation
        if sensory_findings:
            if re.search(r'(?:intact|preserved|some)\s+sensation', blurb_lower):
                components['sensory_preserved'] = 1
            elif re.search(r'absent\s+sensation|complete\s+sensory\s+loss', blurb_lower):
                components['sensory_preserved'] = 0
            else:
                components['sensory_preserved'] = 0.5

        # Check for sacral sparing
        if re.search(r'(?:sacral\s+sparing|perianal\s+sensation|rectal\s+tone)', blurb_lower):
            components['sacral_sparing'] = 1

        # Calculate ASIA grade
        total_score = sum(components.values())

        if total_score == 0:
            grade = "A"
            severity = "critical"
            interpretation = "Complete injury - no motor or sensory function preserved"
        elif components['motor_preserved'] == 0:
            # Some sensation preserved (sacral sparing on its own counts) but no
            # motor function: sensory-incomplete.
            grade = "B"
            severity = "severe"
            interpretation = "Incomplete - sensory but no motor function preserved"
        elif total_score < 2:
            grade = "C"
            severity = "moderate"
            interpretation = "Incomplete - motor function preserved, less than half key muscles"
        else:
            # Grade E ("normal") is deliberately not produced: this scorer only
            # runs when a weakness or sensory deficit was found.
            grade = "D"
            severity = "mild"
            interpretation = "Incomplete - motor function preserved, at least half key muscles"

        recommendations = []
        if grade in ["A", "B"]:
            recommendations.append("Urgent spinal stabilization assessment")
            # NOTE(review): confirm the steroid recommendation against current
            # acute spinal cord injury guidelines.
            recommendations.append("High-dose steroids within 8-hour window if acute")
            recommendations.append("Immediate MRI to assess cord compression")
        elif grade == "C":
            recommendations.append("Expedited MRI and neurosurgical consultation")
            recommendations.append("Monitor for neurological progression")

        return ClinicalScore(
            score_name="ASIA Impairment Scale",
            value=ord('E') - ord(grade),  # A=4 ... D=1 (higher is worse)
            max_value=4,
            severity=severity,
            components=components,
            interpretation=f"Grade {grade}: {interpretation}",
            recommendations=recommendations
        )

    def calculate_synapse_risk_score(self, findings: ClinicalFindings) -> ClinicalScore:
        """Compute the custom Synapse composite risk score.

        Weighted flag counts (critical x3, warning x2, motor x1, neuro x1) are
        multiplied by ``findings.overall_confidence``. The score always applies.
        """

        components = {
            'critical_findings': len(findings.critical_flags) * 3,
            'warning_findings': len(findings.warning_flags) * 2,
            'motor_findings': len(findings.motor_flags) * 1,
            'neuro_findings': len(findings.neuro_flags) * 1,
            'confidence_factor': findings.overall_confidence
        }

        # Calculate weighted score
        base_score = (components['critical_findings'] +
                     components['warning_findings'] +
                     components['motor_findings'] +
                     components['neuro_findings'])

        # Apply confidence factor
        total_score = base_score * components['confidence_factor']

        # Determine severity
        if total_score >= 10:
            severity = "critical"
            interpretation = "Very high risk - immediate intervention required"
            recommendations = [
                "Immediate clinical evaluation",
                "Emergent imaging studies",
                "Activate stroke/spine team if applicable",
                "Continuous neurological monitoring"
            ]
        elif total_score >= 6:
            severity = "severe"
            interpretation = "High risk - urgent evaluation needed"
            recommendations = [
                "Urgent clinical assessment within 1 hour",
                "Expedited imaging studies",
                "Neurology/neurosurgery consultation"
            ]
        elif total_score >= 3:
            severity = "moderate"
            interpretation = "Moderate risk - expedited evaluation recommended"
            recommendations = [
                "Clinical evaluation within 4-6 hours",
                "Consider imaging based on clinical exam",
                "Serial neurological examinations"
            ]
        else:
            severity = "mild"
            interpretation = "Low risk - routine evaluation appropriate"
            recommendations = [
                "Routine clinical follow-up",
                "Document baseline neurological status"
            ]

        return ClinicalScore(
            score_name="Synapse Risk Score",
            value=round(total_score, 2),
            # NOTE: value is not capped, so it can exceed this with many flags.
            max_value=20.0,  # Theoretical maximum
            severity=severity,
            components=components,
            interpretation=interpretation,
            recommendations=recommendations
        )

    def calculate_cauda_equina_score(self, findings: ClinicalFindings, blurb: str) -> Optional[ClinicalScore]:
        """Score cauda equina syndrome risk (0-5, one point per feature).

        Returns None unless at least one core feature is present: bilateral
        weakness, saddle anesthesia, bowel/bladder or sexual dysfunction.
        "Progressive symptoms" only adds to the score. Words like "worsening"
        appear in many unrelated notes, so on their own they do not produce a
        cauda equina score.
        """
        blurb_lower = blurb.lower()

        components = {
            'bilateral_weakness': 0,
            'saddle_anesthesia': 0,
            'bowel_bladder': 0,
            'sexual_dysfunction': 0,
            'progressive_symptoms': 0
        }

        # Check each component
        if any('saddle' in f.message.lower() for f in findings.critical_flags + findings.warning_flags):
            components['saddle_anesthesia'] = 1

        if re.search(r'(?:bowel|bladder)\s+(?:dysfunction|incontinence|retention)'
                     r'|urinary\s+(?:retention|incontinence)', blurb_lower):
            components['bowel_bladder'] = 1

        if re.search(r'(?:bilateral|ble)\s+(?:weakness|motor)', blurb_lower):
            components['bilateral_weakness'] = 1

        if re.search(r'(?:sexual|erectile)\s+dysfunction', blurb_lower):
            components['sexual_dysfunction'] = 1

        if re.search(r'(?:progressive|worsening|deteriorating)', blurb_lower):
            components['progressive_symptoms'] = 1

        core_features = (components['bilateral_weakness'] + components['saddle_anesthesia'] +
                         components['bowel_bladder'] + components['sexual_dysfunction'])
        if core_features == 0:
            return None

        total_score = sum(components.values())

        # Determine severity
        if total_score >= 3:
            severity = "critical"
            interpretation = "High suspicion for cauda equina syndrome"
            recommendations = [
                "EMERGENT MRI lumbar spine",
                "Immediate neurosurgical consultation",
                "Post-void residual measurement",
                "Consider emergent decompression"
            ]
        elif total_score == 2:
            severity = "severe"
            interpretation = "Concerning for incomplete cauda equina"
            recommendations = [
                "Urgent MRI within 4 hours",
                "Neurosurgical consultation",
                "Serial neurological exams",
                "Monitor bowel/bladder function"
            ]
        else:
            severity = "moderate"
            interpretation = "Some cauda equina features present"
            recommendations = [
                "MRI lumbar spine within 24 hours",
                "Document bowel/bladder function",
                "Consider neurosurgical referral"
            ]

        return ClinicalScore(
            score_name="Cauda Equina Risk Score",
            value=total_score,
            max_value=5,
            severity=severity,
            components=components,
            interpretation=interpretation,
            recommendations=recommendations
        )

    def calculate_nurick_score(self, findings: ClinicalFindings, blurb: str) -> Optional[ClinicalScore]:
        """Estimate a Nurick-style grade (0-5) for cervical myelopathy.

        Returns None when the note has neither a recognised gait description
        nor any myelopathy feature.
        """
        blurb_lower = blurb.lower()

        # Look for gait-related findings; the first matching pattern wins.
        # NOTE(review): these thresholds may not match the published Nurick
        # definitions (e.g. walking only with assistance). Confirm with the
        # clinical owner.
        gait_patterns = [
            (r'normal\s+gait|ambulates?\s+(?:well|independently)', 0, "Grade 0: No gait impairment"),
            (r'(?:mild|slight)\s+(?:gait|walking)\s+(?:difficulty|impairment)', 1, "Grade 1: Mild gait disturbance"),
            (r'(?:moderate)\s+(?:gait|walking)\s+(?:difficulty|impairment)', 2, "Grade 2: Moderate gait impairment"),
            (r'(?:requires?\s+assistance|walker|cane)\s+(?:to\s+)?(?:walk|ambulate)', 3, "Grade 3: Requires assistance"),
            (r'(?:severe|significant)\s+(?:gait|walking)\s+(?:difficulty|impairment)', 4, "Grade 4: Severe impairment"),
            (r'(?:wheelchair|non.?ambulatory|unable\s+to\s+walk)', 5, "Grade 5: Wheelchair bound")
        ]

        grade = None
        interpretation = ""

        for pattern, score, desc in gait_patterns:
            if re.search(pattern, blurb_lower):
                grade = score
                interpretation = desc
                break

        # Check for myelopathy features
        myelopathy_features = 0
        if any('hyperreflexia' in f.message.lower() for f in findings.warning_flags):
            myelopathy_features += 1
        if re.search(r'(?:hand|finger)\s+(?:clumsiness|dexterity)', blurb_lower):
            myelopathy_features += 1
        if re.search(r'(?:wide.?based|spastic|ataxic)\s+gait', blurb_lower):
            myelopathy_features += 1
            if grade is None:
                grade = 2  # Default to moderate if gait abnormality mentioned

        if grade is None and myelopathy_features == 0:
            return None

        if grade is None:
            grade = 1  # Default to mild if features present but no specific gait description

        # Determine severity
        severity_map = {
            0: "normal",
            1: "mild",
            2: "mild",
            3: "moderate",
            4: "severe",
            5: "critical"
        }

        recommendations = []
        if grade >= 3:
            recommendations.extend([
                "Urgent MRI cervical spine",
                "Neurosurgical evaluation for decompression",
                "Fall risk assessment and precautions"
            ])
        elif grade >= 1:
            recommendations.extend([
                "MRI cervical spine",
                "Consider neurosurgical referral",
                "Physical therapy evaluation"
            ])

        return ClinicalScore(
            score_name="Nurick Grade",
            value=grade,
            max_value=5,
            severity=severity_map.get(grade, "unknown"),
            components={
                'gait_impairment': grade,
                'myelopathy_features': myelopathy_features
            },
            interpretation=interpretation or f"Grade {grade}: Myelopathy severity",
            recommendations=recommendations
        )

    def calculate_rankin_score(self, findings: ClinicalFindings, blurb: str) -> Optional[ClinicalScore]:
        """Estimate the Modified Rankin Scale (0-6) from functional descriptions.

        Returns None unless some finding is stroke-related or mentions weakness.
        Also returns None when the note contains no recognisable functional
        status, instead of reporting an unsupported "mRS 0".
        """
        blurb_lower = blurb.lower()

        # Look for functional status indicators
        if not any('stroke' in f.category or 'weakness' in f.message.lower()
                  for f in findings.get_all_findings()):
            return None

        # Ordered most -> least severe so the worst documented status wins;
        # notes often record a better baseline ("independent at baseline")
        # alongside current deficits. `\bdependent` keeps "independent" from
        # matching the grade-4 pattern.
        rankin_patterns = [
            (r'(?:total\s+care|complete\s+dependence)', 5, "Severe disability"),
            (r'(?:wheelchair|bedbound|\bdependent)', 4, "Moderately severe disability"),
            (r'(?:requires?\s+assistance|walker|cane)', 3, "Moderate disability"),
            (r'(?:some\s+assistance|moderate\s+disability)', 2, "Slight disability"),
            (r'(?:minimal|mild)\s+(?:symptoms|deficit)', 1, "No significant disability"),
            (r'(?:independent|no\s+assistance|fully\s+functional)', 0, "No symptoms"),
        ]

        for pattern, score, interpretation in rankin_patterns:
            if re.search(pattern, blurb_lower):
                break
        else:
            # No functional description: don't fabricate a score
            return None

        severity_map = {0: "normal", 1: "mild", 2: "mild",
                       3: "moderate", 4: "severe", 5: "critical"}

        return ClinicalScore(
            score_name="Modified Rankin Scale",
            value=score,
            max_value=6,
            severity=severity_map.get(score, "unknown"),
            components={'functional_status': score},
            interpretation=f"mRS {score}: {interpretation}",
            recommendations=["Monitor functional status", "Consider rehabilitation referral"]
        )

# Create global scoring engine instance.
# One shared instance is reused by add_clinical_scores(). calculate_all_scores()
# builds its result in a local dict, so a call does not depend on earlier calls.
scoring_engine = ClinicalScoringEngine()

# Convenience entry point used by the dashboard cell
def add_clinical_scores(findings: ClinicalFindings, blurb: str) -> Dict[str, ClinicalScore]:
    """Score a note's findings with the shared module-level engine.

    Args:
        findings: Structured findings for the note.
        blurb: The original note text.

    Returns:
        Dict of applicable ClinicalScore objects keyed by score id.
    """
    engine = scoring_engine
    applicable_scores = engine.calculate_all_scores(findings, blurb)
    return applicable_scores

# Announce the scores registered by this cell, one line per argument.
print(
    "✅ Clinical Scoring System loaded successfully!",
    "📊 Available scores:",
    "   • ASIA Impairment Scale (spinal cord injury)",
    "   • Modified Rankin Scale (stroke/disability)",
    "   • Nurick Grade (cervical myelopathy)",
    "   • Synapse Risk Score (composite severity)",
    "   • Cauda Equina Risk Score",
    "\n🎯 Ready for Cell 5: Interactive Dashboard",
    sep="\n",
)

In [None]:
# CELL 5 FIXED: Complete Working Dashboard
# Replace your current Cell 5 with this version
# This is self-contained and will work properly

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
from datetime import datetime
import pandas as pd
import json
import numpy as np

# First, fix the stroke assessment function that's causing issues
def assess_stroke_risk_fixed(blurb_lower: str, patterns) -> List[ClinicalFinding]:
    """Flag stroke findings from NIHSS severity and FAST-style criteria.

    Args:
        blurb_lower: Note text, already lower-cased by the caller.
        patterns: Object exposing compiled regexes in ``stroke_patterns``
            (keys 'nihss', 'facial_droop', 'dysarthria', 'aphasia') and
            ``timeline_critical`` (key 'time_critical'). Presumably an
            EnhancedClinicalPatterns instance; verify against the caller.

    Returns:
        List of ClinicalFinding objects (possibly empty). A "Stroke alert"
        CRITICAL finding is added when at least two of facial droop, speech
        difficulty, arm weakness and onset-within-window are present.
    """
    findings = []
    stroke_score = 0

    # Check NIHSS score
    nihss_match = patterns.stroke_patterns['nihss'].search(blurb_lower)
    if nihss_match:
        score = int(nihss_match.group(1))
        if score >= 15:
            findings.append(ClinicalFinding(
                message=f"High NIHSS score ({score})—severe stroke",
                severity=Severity.CRITICAL,
                confidence=0.95,
                category="stroke"
            ))
        elif score >= 5:
            findings.append(ClinicalFinding(
                message=f"Moderate NIHSS score ({score})—significant stroke",
                severity=Severity.WARNING,
                confidence=0.9,
                category="stroke"
            ))

    # FAST criteria
    fast_components = []
    if patterns.stroke_patterns['facial_droop'].search(blurb_lower):
        fast_components.append("Facial droop")
        stroke_score += 1

    if patterns.stroke_patterns['dysarthria'].search(blurb_lower) or \
       patterns.stroke_patterns['aphasia'].search(blurb_lower):
        fast_components.append("Speech difficulty")
        stroke_score += 1

    # Check for arm weakness. The leading \b keeps words such as "true" or
    # "warm" from matching "rue"/"arm".
    if re.search(r'\b(?:rue|lue|arm)\s+(?:weakness|drift|0\/5|1\/5|2\/5)', blurb_lower):
        fast_components.append("Arm weakness")
        stroke_score += 1

    # Time-sensitive check.
    # NOTE(review): the thrombolysis window is commonly cited as 4.5 h. Confirm
    # that the 4 h / 240 min thresholds here are intended.
    time_match = patterns.timeline_critical['time_critical'].search(blurb_lower)
    if time_match:
        try:
            time_value = int(time_match.group(1))
        except (IndexError, TypeError, ValueError):
            # Best-effort: pattern without a numeric group 1, or a non-integer
            # capture; skip the time criterion rather than fail the analysis.
            time_value = None
        if time_value is not None:
            full_match = time_match.group(0).lower()

            if 'hour' in full_match and time_value <= 4:
                fast_components.append(f"Within thrombolysis window ({time_value} hours)")
                stroke_score += 1
            elif 'min' in full_match and time_value <= 240:
                fast_components.append(f"Within thrombolysis window ({time_value} minutes)")
                stroke_score += 1

    if stroke_score >= 2:
        findings.append(ClinicalFinding(
            message=f"Stroke alert: {', '.join(fast_components)}",
            severity=Severity.CRITICAL,
            confidence=0.9,
            category="stroke"
        ))

    return findings

# Override the problematic function.
# Rebinding the module-level name only affects callers that look up
# `assess_stroke_risk` by global name at call time (presumably risk_flag from an
# earlier cell; verify). Code that already holds a direct reference to the old
# function keeps using it.
assess_stroke_risk = assess_stroke_risk_fixed

# Create a simple analysis function that works reliably
def analyze_note_simple(text):
    """Run rule-based analysis and clinical scoring on a note.

    Args:
        text: Raw clinical note text.

    Returns:
        ``(findings, scores)``. If ``risk_flag`` fails, the error is printed and
        ``(None, None)`` is returned. If only scoring fails, the error is printed
        and ``(findings, None)`` is returned, so valid findings are not lost.
    """
    # UI boundary: report errors instead of raising into the widget callback
    try:
        findings = risk_flag(text)
    except Exception as e:
        print(f"Error in analysis: {str(e)}")
        return None, None

    try:
        scores = add_clinical_scores(findings, text)
    except Exception as e:
        print(f"Error in clinical scoring: {str(e)}")
        return findings, None

    return findings, scores

# Create the dashboard
class SimpleSynapseInterface:
    """Minimal ipywidgets front-end for the Synapse analyzer.

    Renders a note text area, Analyze/Clear buttons and an output pane. The
    results of the most recent analysis are kept on ``current_findings`` and
    ``current_scores``.
    """

    def __init__(self):
        # Results of the most recent analysis (None until one has run)
        self.current_findings = None
        self.current_scores = None

    def create_interface(self):
        """Build the widgets, wire up the button handlers and display them."""

        # Header
        display(HTML("""
        <div style='text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                    color: white; border-radius: 10px; margin-bottom: 20px;'>
            <h1 style='margin: 0;'>🧠 Synapse v3.0</h1>
            <p style='margin: 10px 0 0 0; font-size: 18px;'>Clinical Risk Analysis Dashboard</p>
        </div>
        """))

        # Text input
        self.text_area = widgets.Textarea(
            value='',
            placeholder='Paste any clinical note here and click Analyze...',
            layout=widgets.Layout(width='100%', height='200px')
        )

        # Analyze button
        self.analyze_btn = widgets.Button(
            description='🔍 Analyze Clinical Note',
            button_style='primary',
            layout=widgets.Layout(width='200px', height='40px')
        )

        # Clear button
        self.clear_btn = widgets.Button(
            description='🗑️ Clear',
            button_style='warning',
            layout=widgets.Layout(width='100px', height='40px')
        )

        # Output area
        self.output = widgets.Output()

        # Button handlers
        def on_analyze(b):
            """Analyze the note in the text area and print results to the output pane."""
            with self.output:
                clear_output()

                note = self.text_area.value
                if not note.strip():
                    print("❌ Please paste a clinical note first!")
                    return

                print("🔍 Analyzing...")
                print("=" * 80)

                # Run analysis
                findings, scores = analyze_note_simple(note)
                self.current_findings = findings
                self.current_scores = scores

                # Explicit None check: a findings object that happens to be falsy
                # (e.g. empty) should still be rendered, and a failure should say so.
                if findings is None:
                    print("❌ Analysis failed - no findings to display")
                    return

                # Display findings
                print(format_findings_output(findings))

                # Display scores
                if scores:
                    print("\n\n📊 CLINICAL SCORES")
                    print("=" * 80)
                    for score in scores.values():
                        print(f"\n{score.score_name}: {score.value}/{score.max_value}")
                        print(f"Severity: {score.severity.upper()}")
                        print(f"{score.interpretation}")
                        if score.recommendations:
                            print("Recommendations:")
                            for rec in score.recommendations:
                                print(f"  → {rec}")

                # Risk summary
                critical = len(findings.critical_flags)
                warning = len(findings.warning_flags)

                print("\n\n⚡ RISK ASSESSMENT")
                print("=" * 80)
                if critical > 0:
                    print("🚨 CRITICAL RISK - IMMEDIATE ACTION REQUIRED")
                elif warning > 0:
                    print("⚠️ HIGH RISK - URGENT EVALUATION NEEDED")
                else:
                    print("📋 MODERATE/LOW RISK - ROUTINE EVALUATION")

                # GPT Summary
                if config.enable_gpt_summary:
                    print("\n\n📋 GPT SUMMARY")
                    print("=" * 80)
                    try:
                        summary = gpt_summarize(note)
                        print(summary)
                    except Exception as e:
                        # Best-effort: a failed API call must not hide the
                        # rule-based results printed above.
                        print(f"GPT summary unavailable ({e})")

        def on_clear(b):
            """Empty the text area and the output pane."""
            self.text_area.value = ''
            with self.output:
                clear_output()

        self.analyze_btn.on_click(on_analyze)
        self.clear_btn.on_click(on_clear)

        # Display interface
        display(self.text_area)
        display(widgets.HBox([self.analyze_btn, self.clear_btn]))
        display(self.output)

        # Instructions
        display(HTML("""
        <div style='margin-top: 20px; padding: 15px; background: #f0f0f0; border-radius: 5px; color: #333;'>
            <h4 style='color: #333;'>📋 Instructions:</h4>
            <ol>
                <li>Paste any clinical/consult note in the text area above</li>
                <li>Click "Analyze Clinical Note" to process</li>
                <li>View findings, clinical scores, and risk assessment below</li>
            </ol>
            <p><b>Synapse detects:</b> Motor weakness, sensory deficits, reflexes, cauda equina signs,
            stroke symptoms, myelopathy, and more!</p>
        </div>
        """))

# Create and display the interface.
# The button handlers wired up in create_interface() call names from earlier
# cells (analyze_note_simple, format_findings_output, config, gpt_summarize),
# so those cells must have been run before clicking Analyze.
print("🚀 Creating Synapse Dashboard...")
dashboard = SimpleSynapseInterface()
dashboard.create_interface()
print("\n✅ Dashboard ready! Paste any clinical note and click 'Analyze Clinical Note'")