In [None]:
import os
import re
import time
from typing import List, Tuple, Dict, Any, Optional
from datetime import datetime, timedelta

from sqlalchemy import create_engine, text, inspect

from pymilvus import connections, utility, FieldSchema, CollectionSchema, DataType, Collection
from sentence_transformers import SentenceTransformer

import google.generativeai as genai
from google.api_core.exceptions import TooManyRequests, ServiceUnavailable

# Best-effort timezone setup: zoneinfo only exists in the standard library
# from Python 3.9, and ZoneInfo() raises when no tz database is available.
# In either case _DHAKA_TZ is left as None and resolve_time_window falls back
# to naive datetime.now().
try:
    from zoneinfo import ZoneInfo
    _DHAKA_TZ = ZoneInfo("Asia/Dhaka")
except Exception:
    _DHAKA_TZ = None


def build_engine() -> Any:
    """Create the SQLAlchemy engine for the Postgres database.

    Connection settings come from the environment: REAL_HOST, REAL_PORT and
    PG_PASSWORD are required; PG_DB defaults to "hospital" and PG_USER to
    "postgres".

    Returns:
        The engine returned by ``create_engine``, configured with
        ``pool_pre_ping`` and a 10 second connect timeout.

    Raises:
        ValueError: If a required variable is missing or empty, or if
            REAL_PORT is not an integer.
    """
    real_host = os.getenv("REAL_HOST", "")
    real_port = os.getenv("REAL_PORT", "")
    password = os.getenv("PG_PASSWORD", "")
    db_name = os.getenv("PG_DB", "hospital")
    user = os.getenv("PG_USER", "postgres")

    if not real_host or not real_port or not password:
        raise ValueError("Missing Postgres env vars REAL_HOST, REAL_PORT, PG_PASSWORD")

    # Local import keeps this block independent of the file-level import list.
    from sqlalchemy.engine import URL

    # Build the URL from components rather than string interpolation: a
    # password (or user name) containing "@", "/", ":" or "%" would otherwise
    # corrupt the URL or be decoded to a different value.
    database_url = URL.create(
        "postgresql",
        username=user,
        password=password,
        host=real_host,
        port=int(real_port),
        database=db_name,
    )
    return create_engine(
        database_url,
        pool_pre_ping=True,
        connect_args={"connect_timeout": 10},
    )


def build_milvus_collection(collection_name: str, embed_dim: int) -> Collection:
    """Connect to Milvus and return the loaded schema-docs collection.

    Reads MILVUS_URI and MILVUS_TOKEN from the environment and connects under
    the "default" alias. An existing collection is reused only when its field
    names match the expected layout exactly (same names, same order);
    otherwise it is dropped and recreated with an HNSW/COSINE index on the
    "embedding" field.

    Args:
        collection_name: Name of the Milvus collection.
        embed_dim: Dimension of the "embedding" vector field used when the
            collection has to be created.

    Returns:
        The collection, already loaded.

    Raises:
        ValueError: If MILVUS_URI or MILVUS_TOKEN is missing or empty.
    """
    milvus_uri = os.getenv("MILVUS_URI", "")
    milvus_token = os.getenv("MILVUS_TOKEN", "")
    if not (milvus_uri and milvus_token):
        raise ValueError("Missing Milvus env vars MILVUS_URI, MILVUS_TOKEN")

    connections.connect(alias="default", uri=milvus_uri, token=milvus_token)

    expected_names = [
        "id",
        "embedding",
        "doc_type",
        "table_name",
        "contains_phi",
        "hospital_scoped",
        "schema_version",
        "text",
    ]

    if utility.has_collection(collection_name):
        existing = Collection(collection_name)
        if [field.name for field in existing.schema.fields] == expected_names:
            existing.load()
            return existing
        # Layout differs from what this module writes: start over.
        utility.drop_collection(collection_name)

    collection_schema = CollectionSchema(
        [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=embed_dim),
            FieldSchema(name="doc_type", dtype=DataType.VARCHAR, max_length=32),
            FieldSchema(name="table_name", dtype=DataType.VARCHAR, max_length=128),
            FieldSchema(name="contains_phi", dtype=DataType.BOOL),
            FieldSchema(name="hospital_scoped", dtype=DataType.BOOL),
            FieldSchema(name="schema_version", dtype=DataType.VARCHAR, max_length=64),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=8192),
        ],
        description="Schema docs for SQL generation",
    )
    created = Collection(collection_name, schema=collection_schema)
    created.create_index(
        field_name="embedding",
        index_params={
            "index_type": "HNSW",
            "metric_type": "COSINE",
            "params": {"M": 16, "efConstruction": 200},
        },
    )
    created.load()
    return created


def build_embedder() -> Any:
    """Load the sentence-embedding model used for schema docs and questions."""
    model_id = "intfloat/e5-base-v2"
    return SentenceTransformer(model_id)


def embed_text(embedder: Any, s: str) -> List[float]:
    """Embed a single string and return the vector as a plain list.

    The text is encoded as a one-element batch with
    ``normalize_embeddings=True``; the first (only) row is converted with
    ``tolist()``.
    """
    batch = embedder.encode([s], normalize_embeddings=True)
    first_vector = batch[0]
    return first_vector.tolist()


def _common_joins_for_table(table_name: str) -> List[str]:
    """Return the hand-curated join hints for ``table_name``.

    Each hint has the form ``"<table>.<column> -> <ref_table>.<ref_column>"``.
    Tables without curated hints yield an empty list.
    """
    # For each table: the part of every hint that follows "<table>.".
    join_tails = {
        "admissions": (
            "patient_id -> patients.patient_id",
            "department_id -> departments.department_id",
            "ward_id -> wards.ward_id",
            "room_id -> rooms.room_id",
            "bed_id -> beds.bed_id",
            "attending_doctor_id -> staff.staff_id",
        ),
        "discharges": ("admission_id -> admissions.admission_id",),
        "diagnoses": ("admission_id -> admissions.admission_id",),
        "patient_conditions": ("admission_id -> admissions.admission_id",),
        "patient_vitals": ("admission_id -> admissions.admission_id",),
        "prescriptions": (
            "admission_id -> admissions.admission_id",
            "prescribed_by_staff_id -> staff.staff_id",
        ),
        "prescription_items": (
            "prescription_id -> prescriptions.prescription_id",
            "medication_id -> medications.medication_id",
        ),
        "nurse_assignments": (
            "admission_id -> admissions.admission_id",
            "nurse_staff_id -> staff.staff_id",
        ),
        "births": (
            "mother_patient_id -> patients.patient_id",
            "admission_id -> admissions.admission_id",
            "doctor_staff_id -> staff.staff_id",
        ),
        "newborns": ("birth_id -> births.birth_id",),
        "abortion_cases": (
            "patient_id -> patients.patient_id",
            "admission_id -> admissions.admission_id",
            "performed_by_staff_id -> staff.staff_id",
        ),
        "donations": (
            "donor_id -> donors.donor_id",
            "recipient_patient_id -> patients.patient_id",
        ),
        "organ_donation_items": (
            "donation_id -> donations.donation_id",
            "organ_id -> organs.organ_id",
        ),
        "surgeries": ("admission_id -> admissions.admission_id",),
        "surgery_team": (
            "surgery_id -> surgeries.surgery_id",
            "staff_id -> staff.staff_id",
        ),
        "blood_bank_inventory": ("blood_group_id -> blood_groups.blood_group_id",),
        "staff_department": (
            "staff_id -> staff.staff_id",
            "department_id -> departments.department_id",
        ),
        "staff_shifts": (
            "staff_id -> staff.staff_id",
            "department_id -> departments.department_id",
        ),
    }
    return [f"{table_name}.{tail}" for tail in join_tails.get(table_name, ())]


def build_table_doc(inspector, table_name: str) -> str:
    """Render one table's reflected metadata as a plain-text document.

    The document lists the primary key, every column with its type and
    nullability, the reflected foreign keys, and (when available) the curated
    join hints for the table.

    Args:
        inspector: Object providing ``get_columns``, ``get_pk_constraint`` and
            ``get_foreign_keys`` for a table name.
        table_name: Table to describe.

    Returns:
        The newline-joined document text.
    """
    columns = inspector.get_columns(table_name)
    pk_cols = inspector.get_pk_constraint(table_name).get("constrained_columns", [])
    foreign_keys = inspector.get_foreign_keys(table_name)

    doc = [
        "doc_type: table",
        f"table: {table_name}",
        f"primary_key: {', '.join(pk_cols) if pk_cols else 'none'}",
        "columns:",
    ]
    for column in columns:
        nullability = "nullable" if column.get("nullable", True) else "not null"
        doc.append(f"  {column['name']} {str(column['type'])} {nullability}")

    doc.append("foreign_keys:")
    if foreign_keys:
        for fk in foreign_keys:
            local_cols = ", ".join(fk.get("constrained_columns", []))
            target_table = fk.get("referred_table")
            target_cols = ", ".join(fk.get("referred_columns", []))
            doc.append(f"  {table_name}.{local_cols} -> {target_table}.{target_cols}")
    else:
        doc.append("  none")

    hints = _common_joins_for_table(table_name)
    if hints:
        doc.append("common_joins:")
        doc.extend(f"  {hint}" for hint in hints)

    return "\n".join(doc)


def table_contains_phi(table_name: str) -> bool:
    """Tell whether ``table_name`` is on the hard-coded list of PHI tables.

    The comparison is an exact, case-sensitive name match.
    """
    phi_tables = frozenset(
        (
            "patients",
            "admissions",
            "discharges",
            "diagnoses",
            "patient_conditions",
            "patient_vitals",
            "prescriptions",
            "prescription_items",
            "nurse_assignments",
            "births",
            "newborns",
            "abortion_cases",
            "donations",
            "donors",
            "surgeries",
            "surgery_team",
            "organ_donation_items",
        )
    )
    return table_name in phi_tables


def index_schema(engine, schema_col: Collection, embedder, schema_version: str = "v1") -> int:
    """Rebuild the Milvus schema index from the database's current tables.

    One document per table is built with ``build_table_doc``, embedded, and
    stored together with its PHI flag, whether the table has a ``hospital_id``
    column, and ``schema_version``.

    Args:
        engine: SQLAlchemy engine to reflect.
        schema_col: Target Milvus collection.
        embedder: Model passed to ``embed_text``.
        schema_version: Label stored on every indexed row.

    Returns:
        Number of tables reflected (and indexed).
    """
    inspector = inspect(engine)
    tables = inspector.get_table_names()

    # Build every row before touching the collection, so a reflection or
    # embedding failure leaves the previous index intact instead of empty.
    rows = []
    for t in tables:
        doc = build_table_doc(inspector, t)
        column_names = {c["name"] for c in inspector.get_columns(t)}
        rows.append(
            (
                embed_text(embedder, doc),
                "table",
                t,
                bool(table_contains_phi(t)),
                "hospital_id" in column_names,
                schema_version,
                doc,
            )
        )

    # Replace the previous contents wholesale.
    schema_col.delete(expr="id >= 0")
    schema_col.flush()

    # Skip the insert when there is nothing to index rather than sending an
    # empty payload.
    if rows:
        # Column-oriented payload: embedding, doc_type, table_name,
        # contains_phi, hospital_scoped, schema_version, text.
        schema_col.insert([list(column) for column in zip(*rows)])
        schema_col.flush()
    return len(tables)


def retrieve_schema(schema_col: Collection, embedder, question: str, top_k: int = 12) -> List[Tuple[str, str]]:
    """Return the schema documents most similar to ``question``.

    Args:
        schema_col: Milvus collection holding the schema documents.
        embedder: Model passed to ``embed_text``.
        question: Natural-language question to embed and search with.
        top_k: Maximum number of hits to return.

    Returns:
        ``(table_name, text)`` pairs for the hits, in the order returned by
        the search.
    """
    q_emb = embed_text(embedder, question)
    # HNSW needs the search-time candidate list (ef) to be at least as large
    # as the number of requested results; keep 128 as the floor so the default
    # behavior is unchanged while larger top_k values still work.
    ef = max(128, top_k)
    res = schema_col.search(
        data=[q_emb],
        anns_field="embedding",
        param={"params": {"ef": ef}},
        limit=top_k,
        output_fields=["table_name", "text"],
    )
    return [(h.entity.get("table_name"), h.entity.get("text")) for h in res[0]]


def _ensure_tables(schema_col: Collection, embedder, schema_hits: List[Tuple[str, str]], must_have: List[str]) -> List[Tuple[str, str]]:
    """Append schema docs for required tables that retrieval did not return.

    Missing tables are fetched with an exact ``table_name`` filter instead of
    another similarity search: a nearest-neighbour lookup is not guaranteed to
    return the requested table, which would silently leave it out of the
    schema context.

    Args:
        schema_col: Milvus collection holding the schema documents.
        embedder: Unused; kept so existing callers do not change.
        schema_hits: ``(table_name, text)`` pairs already retrieved.
        must_have: Table names that should be present in the result. These are
            interpolated into the filter expression, so they must be trusted
            identifiers (not user input).

    Returns:
        A new list: ``schema_hits`` followed by any recovered documents. Tables
        with no stored, non-empty document are skipped.
    """
    have = {t for t, _ in schema_hits}
    out = list(schema_hits)
    for t in must_have:
        if t in have:
            continue
        matches = schema_col.query(
            expr=f'table_name == "{t}"',
            output_fields=["table_name", "text"],
        )
        for row in matches:
            doc = row.get("text")
            if doc:
                out.append((t, doc))
                have.add(t)
                break
    return out


def build_schema_context(schema_hits: List[Tuple[str, str]]) -> str:
    """Join the non-empty document texts of ``schema_hits`` with blank lines."""
    texts = (text for _table, text in schema_hits)
    return "\n\n".join(filter(None, texts))


def safe_generate(model, prompt: str, max_retries: int = 8) -> str:
    """Call ``model.generate_content`` with retries and return the stripped text.

    Rate-limit/unavailable errors and connection-drop errors (recognised by
    their message) are retried with exponential backoff capped at 60 seconds.
    Any other exception is re-raised immediately.

    Args:
        model: Object exposing ``generate_content(prompt, request_options=...)``.
        prompt: Prompt text.
        max_retries: Maximum number of attempts.

    Returns:
        The response text with surrounding whitespace removed ("" when the
        response text is empty).

    Raises:
        RuntimeError: When every attempt failed with a retryable error (or
            ``max_retries`` is 0); chained to the last error seen.
    """
    transient_markers = ("RemoteDisconnected", "Connection aborted", "EOF")
    last_err = None
    for attempt in range(max_retries):
        try:
            resp = model.generate_content(prompt, request_options={"timeout": 120})
            return (resp.text or "").strip()
        except (TooManyRequests, ServiceUnavailable) as e:
            last_err = e
            delay = (2 ** attempt) + 0.25
        except Exception as e:
            if not any(marker in str(e) for marker in transient_markers):
                raise
            last_err = e
            delay = (2 ** attempt) + 0.5
        # No point in backing off after the final attempt: nothing follows it.
        if attempt < max_retries - 1:
            time.sleep(min(60, delay))
    raise RuntimeError(f"LLM failed after retries. Last error was: {last_err}") from last_err


# Process-wide cache of generated SQL, keyed by the cache key built in
# generate_sql. Entries are never evicted.
SQL_CACHE: Dict[str, str] = {}
# time.time() value recorded by chat() just before it asks for SQL; used to
# keep at least 1.2 seconds between consecutive chat() generations.
_LAST_LLM_CALL_TS = 0.0

# Lower-case SQL keywords that validate_sql rejects when they occur as whole
# words anywhere in the lower-cased statement.
FORBIDDEN = {
    "insert", "update", "delete", "drop", "alter", "truncate", "create",
    "grant", "revoke", "vacuum", "copy"
}


def sanitize_sql(sql: str) -> str:
    """Strip LLM formatting noise from a generated SQL string.

    Removes, in order: a leading code fence (with optional language tag), a
    trailing code fence, and trailing semicolons. Whitespace is trimmed after
    every step. ``None`` is treated as an empty string.
    """
    cleaned = (sql or "").strip()
    for noise in (r"^```[a-zA-Z]*\s*", r"\s*```$", r";+\s*$"):
        cleaned = re.sub(noise, "", cleaned).strip()
    return cleaned


def extract_cte_names(sql: str) -> set:
    """Return the lower-cased names of the CTEs declared by a WITH query.

    Recognises ``WITH name AS (``, ``WITH RECURSIVE name AS (``, an optional
    column list (``name(a, b) AS (``) and ``AS [NOT] MATERIALIZED (``, both
    for the first CTE and for the ones that follow after ``), ``.

    Args:
        sql: SQL text; ``None`` is treated as empty.

    Returns:
        A set of CTE names; empty when the statement does not start with
        ``WITH``.
    """
    s = re.sub(r"\s+", " ", (sql or "").strip()).strip().lower()
    if not s.startswith("with "):
        return set()

    cte_decl = re.compile(
        r"(?:\bwith\s+(?:recursive\s+)?|\)\s*,\s*)"   # start of the first / a later CTE
        r"([a-z_]\w*)"                                  # CTE name
        r"(?:\s*\([^()]*\)\s*|\s+)"                     # optional column list
        r"as\s*(?:(?:not\s+)?materialized\s*)?\("       # AS [NOT MATERIALIZED] (
    )
    return {m.group(1) for m in cte_decl.finditer(s)}


def extract_table_names(sql: str) -> set:
    """Return the lower-cased table names referenced after FROM / JOIN.

    Handles schema-qualified and double-quoted names (only the last component
    is kept, without quotes) and comma-separated FROM lists, so that every
    referenced table is visible to the allow-list check. CTE names are
    excluded, and the ``FROM`` inside ``EXTRACT(field FROM expr)`` is ignored.

    This is a regex heuristic, not a SQL parser: a set-returning function
    used in FROM is reported as if it were a table.

    Args:
        sql: SQL text; ``None`` is treated as empty.

    Returns:
        Set of table names.
    """
    s = re.sub(r"\s+", " ", (sql or "").strip().lower())
    ctes = extract_cte_names(sql)

    # EXTRACT(field FROM source) uses FROM without naming a table.
    s = re.sub(r"\bextract\s*\(\s*[a-z_]+\s+from\b", "extract(", s)

    ident = r'(?:"[^"]+"|[a-z_]\w*)'
    table_ref = rf"{ident}(?:\s*\.\s*{ident})*"
    head_re = re.compile(rf"\b(from|join)\s+({table_ref})")
    # Continuation of a FROM list: optional alias of the previous item, a
    # comma, then the next table reference.
    more_re = re.compile(rf"(?:\s+(?:as\s+)?[a-z_]\w*)?\s*,\s*(?!lateral\b)({table_ref})")

    tables = set()

    def _add(ref: str) -> None:
        name = ref.split(".")[-1].strip().strip('"')
        if name and name not in ctes:
            tables.add(name)

    for m in head_re.finditer(s):
        _add(m.group(2))
        if m.group(1) != "from":
            continue
        pos = m.end()
        nxt = more_re.match(s, pos)
        while nxt:
            _add(nxt.group(1))
            pos = nxt.end()
            nxt = more_re.match(s, pos)
    return tables


def validate_sql(sql: str, allowed_tables: List[str]):
    """Reject generated SQL that breaks the read-only / tenant-scoped rules.

    Checks, in order: non-empty, no semicolons, starts with SELECT or WITH,
    contains no FORBIDDEN keyword as a whole word, mentions ``:hospital_id``,
    and references only tables listed in ``allowed_tables``.

    Raises:
        ValueError: On the first rule that fails.
    """
    statement = (sql or "").strip()
    if not statement:
        raise ValueError("Empty SQL")
    if ";" in statement:
        raise ValueError("Semicolons not allowed")

    lowered = statement.lower()
    if not lowered.startswith(("select", "with")):
        raise ValueError("Only SELECT or WITH allowed")
    if any(re.search(rf"\b{kw}\b", lowered) for kw in FORBIDDEN):
        raise ValueError("Forbidden SQL keyword found")
    if ":hospital_id" not in statement:
        raise ValueError("Missing required :hospital_id parameter")

    unknown = extract_table_names(statement) - set(allowed_tables)
    if unknown:
        raise ValueError(f"SQL references tables not in retrieved schema: {sorted(unknown)}")


def extract_alias_map(sql: str) -> Dict[str, str]:
    """Map each table alias used after FROM / JOIN to its table name.

    A table without an alias is mapped to itself. SQL keywords that can
    directly follow a table reference (WHERE, JOIN, ON, ...) are never taken
    as aliases, a schema prefix (``schema.table``) is dropped, and the
    ``FROM`` inside ``EXTRACT(field FROM expr)`` is ignored. Identifier case
    is preserved as written.

    This is a regex heuristic, not a SQL parser: only the first item of a
    comma-separated FROM list is recognised.

    Args:
        sql: SQL text; ``None`` is treated as empty.

    Returns:
        Dict of ``alias -> table``.
    """
    s = re.sub(r"\s+", " ", (sql or "").strip())
    # EXTRACT(field FROM source) uses FROM without naming a table.
    s = re.sub(r"\bextract\s*\(\s*\w+\s+from\b", "extract(", s, flags=re.IGNORECASE)

    not_an_alias = (
        "where", "join", "inner", "left", "right", "full", "outer", "cross",
        "natural", "on", "using", "group", "order", "having", "limit",
        "offset", "union", "intersect", "except", "window", "fetch", "for",
        "returning",
    )
    pattern = re.compile(
        r"\b(?:from|join)\s+"
        r"(?:[a-z_]\w*\s*\.\s*)?"                      # optional schema prefix
        r"([a-z_]\w*)"                                  # table
        r"(?:\s+(?:as\s+)?"
        r"(?!(?:" + "|".join(not_an_alias) + r")\b)"    # keyword is not an alias
        r"([a-z_]\w*))?",                               # alias
        re.IGNORECASE,
    )

    alias_map = {}
    for m in pattern.finditer(s):
        table = m.group(1)
        alias_map[m.group(2) or table] = table
    return alias_map


def get_table_columns(engine, table_name: str) -> set:
    """Reflect ``table_name`` through ``engine`` and return its column names."""
    reflected = inspect(engine).get_columns(table_name)
    return {column["name"] for column in reflected}


def _parse_qualified_identifiers(sql: str) -> List[Tuple[str, str]]:
    """List the ``(table, column)`` pairs behind ``alias.column`` references.

    Whitespace is collapsed first; each ``name.name`` occurrence whose left
    side is a key of ``extract_alias_map`` is resolved to its table. Other
    qualified references are ignored. Pairs keep the order (and any
    duplicates) found in the text.
    """
    normalized = re.sub(r"\s+", " ", (sql or "").strip())
    aliases = extract_alias_map(normalized)
    references = re.finditer(r"\b([a-zA-Z_]\w*)\.([a-zA-Z_]\w*)\b", normalized)
    return [
        (aliases[ref.group(1)], ref.group(2))
        for ref in references
        if aliases.get(ref.group(1))
    ]


def validate_columns(engine, sql: str) -> None:
    """Check that every ``alias.column`` in ``sql`` names a real column.

    References that resolve to a CTE are skipped: a CTE is not a database
    table, so reflecting it would fail and reject every query that uses one.

    Args:
        engine: SQLAlchemy engine used to reflect table columns.
        sql: SQL text to check.

    Raises:
        ValueError: If a referenced column does not exist on its table.
    """
    cte_names = extract_cte_names(sql)
    cache: Dict[str, set] = {}
    for table, col in _parse_qualified_identifiers(sql):
        # extract_cte_names returns lower-cased names.
        if table.lower() in cte_names:
            continue
        if table not in cache:
            cache[table] = get_table_columns(engine, table)
        if col not in cache[table]:
            raise ValueError(f"Unknown column, {table}.{col}")


def remove_invalid_hospital_filters(engine, sql: str) -> str:
    """Strip ``alias.hospital_id = :hospital_id`` filters on tables lacking that column.

    For every alias whose table has no ``hospital_id`` column, the filter is
    removed together with an adjacent ``AND``; a filter that stands alone is
    replaced by ``TRUE``. Aliases that refer to a CTE are left untouched,
    because a CTE cannot be reflected and may legitimately expose
    ``hospital_id``.

    Args:
        engine: SQLAlchemy engine used to reflect table columns.
        sql: SQL text to clean.

    Returns:
        The SQL with invalid hospital filters removed (or ``sql`` unchanged
        when there are none).
    """
    alias_map = extract_alias_map(sql)
    cte_names = extract_cte_names(sql)  # lower-cased names

    columns_by_table: Dict[str, set] = {}
    bad_aliases = []
    for alias, table in alias_map.items():
        if table.lower() in cte_names:
            continue
        # Reflect each table once even when it appears under several aliases.
        if table not in columns_by_table:
            columns_by_table[table] = get_table_columns(engine, table)
        if "hospital_id" not in columns_by_table[table]:
            bad_aliases.append(alias)

    if not bad_aliases:
        return sql

    fixed = sql
    for alias in bad_aliases:
        a = re.escape(alias)
        fixed = re.sub(rf"\s+AND\s+{a}\.hospital_id\s*=\s*:hospital_id\b", "", fixed, flags=re.IGNORECASE)
        fixed = re.sub(rf"\b{a}\.hospital_id\s*=\s*:hospital_id\s+AND\s+", "", fixed, flags=re.IGNORECASE)
        fixed = re.sub(rf"\b{a}\.hospital_id\s*=\s*:hospital_id\b", "TRUE", fixed, flags=re.IGNORECASE)

    # Tidy up the placeholder left behind by the last substitution.
    fixed = re.sub(r"WHERE\s+TRUE\s+AND\s+", "WHERE ", fixed, flags=re.IGNORECASE)
    fixed = re.sub(r"WHERE\s+TRUE\s*$", "", fixed, flags=re.IGNORECASE)
    return fixed


def execute_sql(engine, sql: str, hospital_id: int, extra_params: Optional[Dict[str, Any]] = None, max_rows: int = 200) -> List[Dict[str, Any]]:
    """Run ``sql`` with the tenant parameter bound and a hard row cap.

    The statement is wrapped as a CTE so that ``LIMIT :_limit`` applies no
    matter what the inner query contains.

    Args:
        engine: SQLAlchemy engine.
        sql: A single SELECT/WITH statement using ``:hospital_id``.
        hospital_id: Value bound to ``:hospital_id``.
        extra_params: Additional bind parameters (for example time bounds).
            The reserved names ``hospital_id`` and ``_limit`` are ignored
            here so they cannot override the tenant filter or the row cap.
        max_rows: Maximum number of rows returned.

    Returns:
        The rows as plain dicts.
    """
    wrapped = f"WITH q AS ({sql}) SELECT * FROM q LIMIT :_limit"
    # Reserved keys are written last so they always win over extra_params.
    params: Dict[str, Any] = dict(extra_params or {})
    params["hospital_id"] = hospital_id
    params["_limit"] = max_rows

    with engine.connect() as conn:
        rows = conn.execute(text(wrapped), params).mappings().all()
    return [dict(r) for r in rows]


def summarize_rows(rows: List[Dict[str, Any]]) -> str:
    """Format query rows as text, one ``key: value, ...`` line per row.

    Column order is taken from the first row. At most ten rows are shown;
    when there are more, a summary header line is placed first. An empty
    result yields a fixed "no records" message.
    """
    if not rows:
        return "No matching records found."

    columns = list(rows[0].keys())

    def render(row: Dict[str, Any]) -> str:
        return ", ".join(f"{name}: {row.get(name)}" for name in columns)

    lines = [render(row) for row in rows[:10]]
    if len(rows) > 10:
        lines.insert(0, f"Summary, {len(rows)} rows returned. Showing first 10.")
    return "\n".join(lines)


def build_gemini_model() -> Any:
    """Configure the Gemini client and return the SQL-generation model.

    GEMINI_API_KEY is required; GEMINI_SQL_MODEL optionally overrides the
    default model name.

    Raises:
        ValueError: If GEMINI_API_KEY is missing or empty.
    """
    key = os.getenv("GEMINI_API_KEY", "")
    if key == "":
        raise ValueError("Missing GEMINI_API_KEY env var")

    genai.configure(api_key=key)
    return genai.GenerativeModel(
        os.getenv("GEMINI_SQL_MODEL", "models/gemini-2.0-flash-lite-001")
    )


# Intent table consumed by detect_intent. Each entry is
# (intent name, regexes that must ALL match the lower-cased question,
#  tables that chat() makes sure are present in the schema context).
# Entries are tried in order and the first full match wins, so the position of
# an entry relative to broader patterns matters.
_INTENTS = [
    ("ADMISSIONS_BY_DEPARTMENT", [r"\badmit", r"\bdepartment\b"], ["admissions", "departments"]),
    ("ADMISSIONS_DAILY_SERIES", [r"\badmit", r"last\s+\d+\s+days|last\s+7\s+days|daily"], ["admissions"]),
    ("ADMISSIONS_TODAY", [r"\badmit", r"\btoday\b"], ["admissions"]),
    ("CURRENT_ADMITTED_LIST", [r"\bcurrently\b|\bcurrent\b", r"\badmitted\b"], ["admissions", "patients", "departments", "wards", "rooms", "beds"]),
    ("DISCHARGES", [r"\bdischarg"], ["discharges", "admissions", "patients"]),
    ("CRITICAL_PATIENTS", [r"\bcritical\b"], ["patient_conditions", "admissions", "patients", "departments"]),
    ("VITALS", [r"\bvitals?\b"], ["patient_vitals", "admissions", "patients"]),
    ("BIRTHS", [r"\bbirths?\b|\bbabies\b"], ["births", "newborns", "patients"]),
    ("ABORTION", [r"\babortion\b"], ["abortion_cases", "patients", "staff"]),
    ("STAFF_COUNTS", [r"\bstaff\b", r"\brole\b|by\s+role|distribution"], ["staff", "staff_roles", "staff_department", "departments"]),
    ("SHIFTS", [r"\bshift\b|\bavailability\b|on\s+shift"], ["staff_shifts", "staff", "staff_roles", "departments"]),
    ("PRESCRIPTIONS", [r"\bprescription|\bmedication"], ["prescriptions", "prescription_items", "medications", "admissions", "patients", "staff"]),
    ("NURSE_ASSIGNMENTS", [r"\bnurse\b", r"\bassign"], ["nurse_assignments", "admissions", "patients", "staff", "wards", "rooms", "beds"]),
    ("BLOOD_BANK", [r"\bblood\b"], ["blood_bank_inventory", "blood_groups", "donors", "donations", "patients"]),
    ("ORGAN_DONATION", [r"\borgan\b"], ["donations", "organ_donation_items", "organs", "donors"]),
    ("SURGERY", [r"\bsurger"], ["surgeries", "surgery_team", "admissions", "patients", "staff"]),
]


def detect_intent(question: str) -> Tuple[str, List[str]]:
    """Classify ``question`` against the ``_INTENTS`` table.

    Returns:
        ``(intent_name, must_have_tables)`` of the first entry whose patterns
        all match the lower-cased question, or ``("GENERAL", [])`` when no
        entry matches.
    """
    lowered = (question or "").lower()
    for intent_name, required_patterns, tables in _INTENTS:
        if all(re.search(pattern, lowered) for pattern in required_patterns):
            return intent_name, tables
    return "GENERAL", []


def _parse_iso_dt(s: str) -> Optional[datetime]:
    """Parse an ISO-like timestamp, returning ``None`` when it is not valid.

    Surrounding whitespace is trimmed and spaces are replaced by ``T`` before
    the text is handed to ``datetime.fromisoformat``.
    """
    candidate = (s or "").strip().replace(" ", "T")
    try:
        return datetime.fromisoformat(candidate)
    except Exception:
        return None


def get_db_now(engine) -> Optional[datetime]:
    """Return the latest ``admitted_ts`` in ``admissions``, or ``None`` on any error.

    resolve_time_window uses this value as its reference "now".
    """
    query = "SELECT max(admitted_ts) FROM admissions"
    try:
        with engine.connect() as conn:
            return conn.execute(text(query)).scalar()
    except Exception:
        return None


def resolve_time_window(question: str, engine) -> Dict[str, Any]:
    """Translate time phrases in ``question`` into bind-parameter values.

    Recognised, in priority order:
      * an explicit timestamp ``YYYY-MM-DD HH:MM:SS`` (space or ``T``
        separator, optional ``+HH:MM`` / ``-HH:MM`` offset) -> ``{"at_ts": dt}``
      * "today" -> midnight-to-midnight window around "now"
      * "this month" -> first of the month to first of the next month
      * "last N days" -> the N days ending at "now"

    "Now" is ``get_db_now(engine)`` when that returns a value, otherwise the
    current time (in Asia/Dhaka when that zone could be loaded). The database
    is only consulted when a relative phrase is actually present.

    Args:
        question: User question; ``None`` is treated as empty.
        engine: SQLAlchemy engine passed to ``get_db_now``.

    Returns:
        ``{"at_ts": ...}``, ``{"start_ts": ..., "end_ts": ...}``, or ``{}``
        when the question contains no recognised time phrase.
    """
    raw = question or ""
    q = raw.lower()

    m_iso = re.search(
        r"(\d{4}-\d{2}-\d{2}(?:T|\s+)\d{2}:\d{2}:\d{2}(?:[+-]\d{2}:\d{2})?)", raw
    )
    if m_iso:
        dt = _parse_iso_dt(m_iso.group(1))
        if dt:
            return {"at_ts": dt}

    wants_today = re.search(r"\btoday\b", q)
    wants_month = re.search(r"\bthis\s+month\b", q)
    # Covers every "last N days" phrasing, including 7 and 30.
    m_days = re.search(r"last\s+(\d+)\s+days", q)
    if not (wants_today or wants_month or m_days):
        return {}

    db_now = get_db_now(engine)
    if db_now is not None:
        now = db_now
    else:
        now = datetime.now(_DHAKA_TZ) if _DHAKA_TZ else datetime.now()

    if wants_today:
        start = now.replace(hour=0, minute=0, second=0, microsecond=0)
        return {"start_ts": start, "end_ts": start + timedelta(days=1)}

    if wants_month:
        start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        if start.month == 12:
            end = start.replace(year=start.year + 1, month=1)
        else:
            end = start.replace(month=start.month + 1)
        return {"start_ts": start, "end_ts": end}

    n = int(m_days.group(1))
    return {"start_ts": now - timedelta(days=n), "end_ts": now}


def generate_sql(model, question: str, schema_hits: List[Tuple[str, str]], intent: str, time_info: Dict[str, Any], error_hint: str = "") -> str:
    """Ask the LLM for one PostgreSQL query answering ``question``.

    Results are cached in ``SQL_CACHE``. The key covers the intent, the schema
    context, which time parameters are available, and the question itself, so
    two different questions never share a cached query. Calls made with an
    ``error_hint`` (regeneration after a validation failure) bypass the cache
    in both directions.

    Args:
        model: LLM handle passed to ``safe_generate``.
        question: User question.
        schema_hits: ``(table_name, doc)`` pairs forming the schema context.
        intent: Intent label included in the prompt.
        time_info: Resolved time parameters (may be empty).
        error_hint: Optional explanation of why a previous attempt failed.

    Returns:
        The sanitized SQL text.
    """
    schema_context = build_schema_context(schema_hits)

    # The prompt depends on which time parameters exist, not on their values.
    time_bucket = ",".join(sorted(time_info)) if time_info else "notime"
    question_key = re.sub(r"\s+", " ", (question or "").strip())
    cache_key = f"{intent}:{hash(schema_context)}:{time_bucket}:{question_key}"

    if cache_key in SQL_CACHE and not error_hint:
        return SQL_CACHE[cache_key]

    time_rules = []
    if "start_ts" in time_info and "end_ts" in time_info:
        time_rules.append("If the question is time based, use :start_ts and :end_ts for filtering on the correct timestamp column.")
    if "at_ts" in time_info:
        time_rules.append("If the question asks availability at a time, use :at_ts and compare it between shift start and end timestamps.")
    time_rule_text = "\n".join(time_rules) if time_rules else "If the question is time based, you may use NOW() but prefer provided parameters when available."

    prompt = f"""
You generate PostgreSQL SQL to be executed with SQLAlchemy.

Rules
1 Output exactly one SQL query
2 Only SELECT or WITH is allowed
3 Use only tables and columns present in the schema context
4 Apply hospital filter only on tables that actually contain a hospital_id column
5 Add LIMIT 200 for list queries
6 Do not output semicolons or code fences
7 Use :hospital_id placeholder, not a literal number
8 {time_rule_text}
9 Prefer join paths shown in common_joins if present
10 When querying patient_conditions for current status, select the latest record per admission using max(updated_ts)

Intent
{intent}

Schema context
{schema_context}

Question
{question}

{error_hint}

Output only SQL
""".strip()

    sql = sanitize_sql(safe_generate(model, prompt))
    if not error_hint:
        SQL_CACHE[cache_key] = sql
    return sql


# Lazily created singletons; init_once fills in whichever are still None.
_ENGINE = None
_SCHEMA_COL = None
_EMBEDDER = None
_GEMINI = None

# Read once at import time. EMBED_DIM must parse as an integer, otherwise the
# int() call raises ValueError during import.
COLLECTION_NAME = os.getenv("MILVUS_COLLECTION", "db_schema_docs")
EMBED_DIM = int(os.getenv("EMBED_DIM", "768"))


def init_once(index: bool = False) -> None:
    """Create the shared engine, collection, embedder and LLM handle on first use.

    Each singleton is built only while it is still ``None``, so repeated calls
    are cheap.

    Args:
        index: When true, also rebuild the schema index, labelling rows with
            the SCHEMA_VERSION environment variable (default "v1").
    """
    global _ENGINE, _SCHEMA_COL, _EMBEDDER, _GEMINI

    if _ENGINE is None:
        _ENGINE = build_engine()

    if _SCHEMA_COL is None:
        _SCHEMA_COL = build_milvus_collection(COLLECTION_NAME, EMBED_DIM)

    if _EMBEDDER is None:
        _EMBEDDER = build_embedder()

    if _GEMINI is None:
        _GEMINI = build_gemini_model()

    if not index:
        return

    version_label = os.getenv("SCHEMA_VERSION", "v1")
    index_schema(_ENGINE, _SCHEMA_COL, _EMBEDDER, schema_version=version_label)


def chat(question: str, hospital_id: int) -> str:
    """Answer ``question`` for one hospital by generating and running SQL.

    Pipeline: detect intent, resolve time parameters, retrieve schema docs
    (adding the intent's required tables), generate SQL, strip hospital
    filters on tables without ``hospital_id``, validate, and execute. When
    validation fails, the SQL is regenerated once with the failure as a hint;
    a second failure propagates to the caller.

    Args:
        question: Natural-language question.
        hospital_id: Value bound to ``:hospital_id``.

    Returns:
        Text summary of the result rows.
    """
    global _LAST_LLM_CALL_TS

    init_once(index=False)

    intent, must_tables = detect_intent(question)
    time_info = resolve_time_window(question, _ENGINE)

    schema_hits = retrieve_schema(_SCHEMA_COL, _EMBEDDER, question, top_k=12)
    if must_tables:
        schema_hits = _ensure_tables(_SCHEMA_COL, _EMBEDDER, schema_hits, must_tables)
    allowed_tables = [table for table, _ in schema_hits]

    # Only forward the time parameters that were actually resolved.
    extra_params: Dict[str, Any] = {}
    if "start_ts" in time_info and "end_ts" in time_info:
        extra_params.update(start_ts=time_info["start_ts"], end_ts=time_info["end_ts"])
    if "at_ts" in time_info:
        extra_params["at_ts"] = time_info["at_ts"]

    # Keep at least 1.2 seconds between consecutive generations.
    elapsed = time.time() - _LAST_LLM_CALL_TS
    if elapsed < 1.2:
        time.sleep(1.2 - elapsed)
    _LAST_LLM_CALL_TS = time.time()

    def _generate(hint: str = "") -> str:
        candidate = generate_sql(
            _GEMINI, question, schema_hits, intent=intent, time_info=time_info, error_hint=hint
        )
        return remove_invalid_hospital_filters(_ENGINE, candidate)

    def _validate(candidate: str) -> None:
        validate_sql(candidate, allowed_tables=allowed_tables)
        validate_columns(_ENGINE, candidate)

    sql = _generate()
    try:
        _validate(sql)
    except Exception as e:
        hint = f"Previous SQL failed validation because: {str(e)}. Regenerate SQL using only valid identifiers from schema context."
        sql = _generate(hint)
        _validate(sql)

    rows = execute_sql(_ENGINE, sql, hospital_id=hospital_id, extra_params=extra_params, max_rows=200)
    return summarize_rows(rows)