# In [None]:
import pandas as pd
import os
import re
import warnings
import matplotlib.pyplot as plt
from IPython.display import display
from sqlalchemy import create_engine

# 📦 Project path setup
import sys
from pathlib import Path

# Append the repo root to sys.path (NAS version)
sys.path.append(str(Path.home() / "work" / "EMS_QI_Projects" / "ahaems-2025-submission"))

from project_paths import (
    BASE_DIR, FALLOUTS_DIR, DATA_RAW_DIR, DATA_CLEANED_DIR, REPORTS_DIR, print_project_paths
)

warnings.filterwarnings("ignore", message="Could not infer format.*", category=UserWarning)

# === DB Connection ===
# Credentials must not live in source control, so the SQLAlchemy URL is read
# from the AHAEMS_DB_URL environment variable, e.g.
#   postgresql://<user>:<password>@<host>:5432/datalake
# NOTE(review): the password that used to be hard-coded on this line should be
# treated as exposed and rotated.
_db_url = os.environ.get("AHAEMS_DB_URL")
if not _db_url:
    raise RuntimeError(
        "AHAEMS_DB_URL is not set. Export a SQLAlchemy PostgreSQL URL "
        "(postgresql://user:password@host:5432/datalake) before running."
    )
engine = create_engine(_db_url)

# === Constants ===
# Minimum patient age (inclusive) for an incident to count toward the measure.
AGE_THRESHOLD = 18

# Label searched for (case-insensitively, as a literal substring) in the
# primary and secondary impression columns.
STEMI_IMPRESSION_STRING = "chest pain - stemi (i21.3)"

# Substring that must appear in the transport disposition.
TRANSPORT_KEYWORD = "transport by this ems unit"

# Code that must appear in the eResponse.05 response-type value.
# NOTE(review): presumably the NEMSIS "emergency response" code - confirm
# against the data dictionary version the export uses.
RESPONSE_CODE = "2205001"

# eArrest.01 codes that remove an incident from the denominator.
# NOTE(review): presumably the two "Yes" (arrest occurred) codes - confirm.
ARREST_EXCLUDE_CODES = {
    "3001003",
    "3001005",
}

# === Load data from PostgreSQL ===
print("📥 Loading data from PostgreSQL (ahaems_cleaned)...")
# The whole table is pulled; all filtering is done in pandas afterwards.
df = pd.read_sql("SELECT * FROM ahaems_cleaned", con=engine)

# === Rename for internal use ===
# Maps each source column header to the short snake_case name the rest of the
# script works with.
COLUMN_ALIASES = {
    "UniqueIncidentKey": "incident_id",
    "Patient Age (ePatient.15)": "age",
    "Patient Age Units (ePatient.16)": "age_units",
    "Primary Impression": "primary_impression",
    "Secondary Impression": "secondary_impression",
    "Transport Disposition": "transport_disposition",
    "Response Type Of Service Requested With Code (eResponse.05)": "response_type",
    "Vitals Signs Taken Date Time (eVitals.01)": "vitals_time",
    "Cardiac Arrest During EMS Event With Code (eArrest.01)": "cardiac_arrest",
    "Aspirin Given": "aspirin_given",
}
df = df.rename(columns=COLUMN_ALIASES)

# === Extract arrest code ===
# First run of decimal digits in the field. The earlier pattern r"(\\d+)" was a
# raw string, so it looked for a literal backslash followed by the letter "d"
# and never matched a real code.
_ARREST_CODE_RE = re.compile(r"\d+")


def extract_arrest_code(text):
    """Return the first run of digits found in *text*.

    Args:
        text: Raw cardiac-arrest field value. Anything that is not a ``str``
            (``None``, ``NaN``, numbers) is treated as "no code".

    Returns:
        The digits as a string, or ``None`` when *text* is not a string or
        contains no digits.
    """
    if not isinstance(text, str):
        return None
    match = _ARREST_CODE_RE.search(text)
    return match.group(0) if match else None

# Per-row arrest code pulled out of the raw cardiac-arrest text.
df["cardiac_arrest_code"] = df["cardiac_arrest"].apply(extract_arrest_code)

# === Detect STEMI impression ===
# A row qualifies when either impression column contains the STEMI label.
# The match is a literal, case-insensitive substring test; missing values
# count as "no match".
primary_is_stemi = df["primary_impression"].str.contains(
    STEMI_IMPRESSION_STRING, case=False, na=False, regex=False
)
secondary_is_stemi = df["secondary_impression"].str.contains(
    STEMI_IMPRESSION_STRING, case=False, na=False, regex=False
)
df["impression_valid"] = primary_is_stemi | secondary_is_stemi

# === Parse vitals time ===
# Unparseable timestamps become NaT rather than raising.
df["vitals_time"] = pd.to_datetime(df["vitals_time"], errors="coerce")

# === Aggregate by incident ===
# One row per incident: the first non-null value of each descriptive field,
# whether any row carried the STEMI impression, and the earliest vitals time.
grouped = (
    df.groupby("incident_id")
    .agg(
        age=("age", "first"),
        age_units=("age_units", "first"),
        impression_valid=("impression_valid", "max"),
        transport_disposition=("transport_disposition", "first"),
        response_type=("response_type", "first"),
        cardiac_arrest_code=("cardiac_arrest_code", "first"),
        vitals_time=("vitals_time", "min"),
    )
    .reset_index()
)

# === Validation ===
def validate_grouped_data(df):
    """Summarise missing and implausible values in the per-incident frame.

    Args:
        df: Frame with the columns ``age``, ``vitals_time``,
            ``transport_disposition``, ``response_type`` and
            ``cardiac_arrest_code``.

    Returns:
        A ``pd.Series`` named ``"Validation Summary"`` holding one count per
        check.
    """
    # The caller runs this before age has been converted to numbers, so the
    # column may still hold text. Comparing text with 0 raises TypeError, hence
    # the local coercion (unparseable ages become NaN and are not "negative").
    numeric_age = pd.to_numeric(df["age"], errors="coerce")
    return pd.Series({
        "missing_age": df["age"].isna().sum(),
        "invalid_age": (numeric_age < 0).sum(),
        "missing_vitals_time": df["vitals_time"].isna().sum(),
        "missing_transport_disposition": df["transport_disposition"].isna().sum(),
        "missing_response_type": df["response_type"].isna().sum(),
        "missing_cardiac_arrest_code": df["cardiac_arrest_code"].isna().sum(),
    }, name="Validation Summary")

# Snapshot data-quality counts before age is coerced below.
validation_summary = validate_grouped_data(grouped)

# === Denominator Logic ===
grouped["age"] = pd.to_numeric(grouped["age"], errors="coerce")

# The age value is only meaningful together with its unit: "20" with units of
# months is an infant, not an adult. Rows whose unit explicitly names a
# sub-year unit are rejected; missing or unrecognised units are still treated
# as years, as they were before.
# NOTE(review): assumes the unit text contains words such as "Months" or
# "Days" - confirm against the values actually present in age_units.
non_year_units = grouped["age_units"].astype(str).str.contains(
    r"minute|hour|day|week|month", case=False, na=False
)
age_valid = (grouped["age"] >= AGE_THRESHOLD) & ~non_year_units

transport_valid = grouped["transport_disposition"].str.contains(TRANSPORT_KEYWORD, case=False, na=False)
response_valid = grouped["response_type"].astype(str).str.contains(RESPONSE_CODE, na=False)
arrest_exclude = grouped["cardiac_arrest_code"].isin(ARREST_EXCLUDE_CODES)

# An incident is in the denominator when every inclusion test passes and the
# arrest exclusion does not apply.
grouped["in_denominator"] = (
    age_valid
    & grouped["impression_valid"]
    & transport_valid
    & response_valid
    & ~arrest_exclude
)

# === Numerator Logic ===
# Count, per incident, the source rows where aspirin_given is set.
# NOTE(review): assumes aspirin_given is a boolean column - confirm.
aspirin_given_any = (
    df[df["aspirin_given"]]
    .groupby("incident_id")
    .size()
    .reset_index(name="aspirin_count")
)
grouped = grouped.merge(aspirin_given_any, how="left", on="incident_id")

# Incidents with no matching rows have a NaN count after the left merge, which
# is read as zero.
grouped["in_numerator"] = grouped["in_denominator"] & grouped["aspirin_count"].fillna(0).gt(0)

# === Summarize by Quarter ===
# Each incident is assigned to the calendar quarter of its earliest vitals time.
grouped["quarter"] = grouped["vitals_time"].dt.to_period("Q")

# Only denominator incidents are summarised, so summing the two boolean flags
# gives the denominator and numerator counts for each quarter.
denominator_cases = grouped.loc[grouped["in_denominator"]]
summary = (
    denominator_cases.groupby("quarter")[["in_denominator", "in_numerator"]]
    .sum()
    .rename(columns={
        "in_denominator": "AHAEMS6_Denominator",
        "in_numerator": "AHAEMS6_Numerator",
    })
    .reset_index()
)
summary["AHAEMS6_Percentage"] = (
    summary["AHAEMS6_Numerator"]
    .div(summary["AHAEMS6_Denominator"])
    .mul(100)
    .round(2)
)

# === Display ===
display(summary)

# === Plot ===
def plot_measure_trends(summary_df, measure_name):
    """Draw the quarterly percentage of a measure with the 90% target line.

    Args:
        summary_df: Frame with a ``quarter`` column and a
            ``<measure_name>_Percentage`` column.
        measure_name: Prefix of the percentage column; also used in the title.
    """
    quarter_labels = summary_df["quarter"].astype(str)
    percentages = summary_df[f"{measure_name}_Percentage"]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(quarter_labels, percentages, marker='o')
    ax.axhline(90, color='red', linestyle='--', label='Target (90%)')

    ax.set_title(f"{measure_name} Trends Over Time")
    ax.set_xlabel("Quarter")
    ax.set_ylabel("Percentage")
    ax.grid(True)
    ax.legend()

    # Tilt the quarter labels so neighbouring ticks do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45)
    fig.tight_layout()
    plt.show()

plot_measure_trends(summary, "AHAEMS6")

# === Validation Printout ===
print("\n🩺 Post-Aggregation Data Validation:")
print(validation_summary.to_string())

# === Export Fallout CSV ===
# Fallouts are incidents that count toward the denominator but missed the
# numerator.
fallouts = grouped[grouped["in_denominator"] & ~grouped["in_numerator"]]

# This script computes AHAEMS6, but the export used to be written to
# "ahaems1_fallouts.csv", which would overwrite the AHAEMS1 export.
# NOTE(review): confirm nothing downstream still reads the old file name.
fallout_path = FALLOUTS_DIR / "ahaems6_fallouts.csv"

# Ensure directory exists before writing
fallout_path.parent.mkdir(parents=True, exist_ok=True)

# Write fallout file
fallouts.to_csv(fallout_path, index=False)