In [None]:
# --- 1. Imports and Script Configuration ---
import math
import os
import re
import textwrap
from pathlib import Path

import duckdb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from rich.console import Console
from rich.panel import Panel
from rich.progress import track
from rich.table import Table
from rich.text import Text

# Initialize pretty printing console
console = Console()

# --- Copied from your script ---
# One dictionary per analysis.  Keys read elsewhere in this file:
#   disease / comorbidity : names looked up (case-insensitively) in the group CSVs
#   drugs                 : drug names whose ATC codes define exposure
#   drug_group_name       : optional display label shown instead of the drug list
#   aes                   : adverse-event names to model
#   enabled               : only the FIRST entry with a truthy value is used by the setup below
CONFIGURATIONS = [
    {
        "disease": "diabetes",
        "comorbidity": "chronic kidney disease",
        "drugs": ["dulaglutide"],
        "aes": ["pancreatitis", "hypoglycemia"],
        "enabled": False,
    },
    {
        "disease": "hypertension",
        "comorbidity": "chronic kidney disease",
        "drugs": ["losartan", "valsartan", "irbesartan", "candesartan", "olmesartan"],
        "drug_group_name": "ARB",
        "aes": ["hyperkalemia", "systemic lupus erythematosus"],
        "enabled": False,
    },
    {
        "disease": "hypertension",
        "comorbidity": "gout",
        "drugs": ["atenolol", "metoprolol", "propranolol", "bisoprolol", "carvedilol", "labetalol", "timolol", "oxprenolol", "pindolol"],
        "drug_group_name": "beta-blocker",
        "aes": ["acute kidney failure", "unspecified acute kidney failure", "hyperkalemia", "cardiac dysrhythmia"],
        "enabled": True,
    },
    # ... other configurations from your script
]

# --- Global Paths & Settings ---
# Project root is addressed as a UNC network path; all other directories hang off it.
ROOT_DIR = Path("//10.100.117.220/Research_Archive$/Archive/R01/R01-Ayush/")
RESULTS_DIR = ROOT_DIR / "results"
DATA_DIR = ROOT_DIR / "data"
GROUPS_DIR = ROOT_DIR / "code" / "groups"
# Follow-up windows in days; None is labelled "any_after_index" by the regression loop.
WINDOWS = [30, 90, None]

# --- Helper Functions ---
def slugify(text: str) -> str:
    """Return *text* as a lower-case slug.

    Every non-alphanumeric character becomes ``_``; leading and trailing
    underscores are stripped, interior runs of underscores are kept.
    """
    pieces = []
    for char in text:
        if char.isalnum():
            pieces.append(char.lower())
        else:
            pieces.append("_")
    return "".join(pieces).strip("_")

def clean_label(text: str) -> str:
    """Return a display label: *text* lower-cased with underscores shown as spaces."""
    lowered = text.lower()
    return " ".join(lowered.split("_"))

def _read_clean_groups(path: Path) -> pd.DataFrame:
    """Load a group-definition CSV and normalise it for name-based lookups.

    Column names are stripped and lower-cased, every cell is stripped of
    surrounding whitespace, rows whose ``exclude`` cell equals ``true``
    (case-insensitive) are dropped, and a case-folded ``name_key`` column is
    added when a ``name`` column exists.

    Args:
        path: Location of the CSV file.

    Returns:
        The cleaned DataFrame.  All values are strings; empty cells stay
        missing (NaN) so that ``dropna()`` / ``isna()`` work downstream.
    """
    # dtype=str keeps codes and patterns exactly as written (no "09" -> 9 or
    # 9 -> "9.0" coercion) and, unlike astype(str), does not turn blanks into
    # the literal string "nan".
    df = pd.read_csv(path, dtype=str)
    df.columns = [c.strip().lower() for c in df.columns]
    for col in df.columns:
        df[col] = df[col].str.strip()
    if "exclude" in df.columns:
        excluded = df["exclude"].str.lower().fillna("") == "true"
        # .copy() so the name_key assignment below does not write into a slice.
        df = df[~excluded].copy()
    if "name" in df.columns:
        df["name_key"] = df["name"].str.casefold()
    return df

def build_where_clause(rules: list) -> str:
    """Build a parenthesised SQL predicate that matches any of *rules*.

    Each rule is a ``(col, type_col, type_val, like_pattern)`` tuple and is
    rendered as ``(type_col = 'type_val' AND col LIKE 'like_pattern')``; the
    rendered rules are joined with ``OR``.

    Args:
        rules: Non-empty list of rule tuples.  ``col`` and ``type_col`` are
            inserted verbatim as identifiers; ``type_val`` and ``like_pattern``
            are inserted as quoted string literals.

    Returns:
        The combined predicate, wrapped in one outer pair of parentheses.

    Raises:
        ValueError: If *rules* is empty (the result would be the invalid SQL ``()``).
    """
    if not rules:
        raise ValueError("build_where_clause() needs at least one rule; got an empty list.")

    def _quote(value) -> str:
        # Double embedded single quotes so the value stays one SQL string literal.
        return "'" + str(value).replace("'", "''") + "'"

    parts = [f"({tc} = {_quote(tv)} AND {c} LIKE {_quote(pat)})" for c, tc, tv, pat in rules]
    return "(" + " OR ".join(parts) + ")"

# --- 2. Main Setup Execution ---
console.print(Panel("[bold magenta]Adverse Event Analysis Setup for Jupyter[/bold magenta]", subtitle="Loading data...", expand=False))

# Load shared definition files
# Every file except confounders.csv is mandatory; a missing confounders file is
# replaced by an empty frame with the expected columns so later code can run.
with console.status("[b]Loading shared definition files...[/b]", spinner="dots"):
    group_files = {
        "base": GROUPS_DIR / "base.csv",
        "comorbidity": GROUPS_DIR / "comorbidities.csv",
        "drugs": GROUPS_DIR / "drugs.csv",
        "ae": GROUPS_DIR / "adverse_effects.csv",
        "confounders": GROUPS_DIR / "confounders.csv"
    }
    group_dfs = {}
    for name, path in group_files.items():
        if path.exists():
            group_dfs[name] = _read_clean_groups(path)
        else:
            if name == 'confounders':
                # NOTE(review): the regression formula further down adjusts for age, sex and
                # socioeconomic status and never adds these confounders — message may be stale.
                console.print("[yellow]NOTE: 'confounders.csv' not found. Regression will only adjust for age.[/yellow]")
                group_dfs[name] = pd.DataFrame(columns=['name', 'name_key', 'col', 'type_col', 'type_val', 'like_pattern'])
            else:
                raise FileNotFoundError(f"Essential definition file not found: {path}")
console.print("Shared definition files loaded.")

# Connect to database
# In-memory database: all views and temp tables live only for this session.
con = duckdb.connect(database=":memory:")
console.print("In-memory DuckDB connection established.")

# Select the first enabled configuration to run the setup
# Entries without an "enabled" key count as enabled (default True).
try:
    config = next(c for c in CONFIGURATIONS if c.get("enabled", True))
except StopIteration:
    raise ValueError("No enabled configuration found in the CONFIGURATIONS list.")

# --- Execute Steps 1-3 from process_configuration() ---

# Display configuration info
title = f"Setting up for: {clean_label(config['disease'])} + {clean_label(config['comorbidity'])}"
content = Text(no_wrap=True)
content.append("  Disease (D): ", style="green"); content.append(f"{clean_label(config['disease'])}\n")
content.append("  Comorbidity (D2): ", style="green"); content.append(f"{clean_label(config['comorbidity'])}\n")
# Show the group label when one is configured, otherwise list the individual drugs.
if config.get("drug_group_name"):
    content.append("  Drug Group (C): ", style="green"); content.append(f"{config['drug_group_name']}\n")
else:
    content.append("  Drugs (C): ", style="green"); content.append(f"{', '.join(config['drugs'])}\n")
content.append("  Adverse Events (E): ", style="green"); content.append(f"{', '.join(config['aes'])}")
console.print(Panel(content, title=title, border_style="cyan", title_align="left"))

# Step 1: Load and Validate Selections
console.print(Panel("Step 1: Load and Validate Selections", title_align="left", border_style="blue"))
disease, comorbidity, drugs, aes = config["disease"], config["comorbidity"], config["drugs"], config["aes"]
# Case-folded copies of the configured names, used to match the name_key column of the group CSVs.
sel_disease, sel_comorb = [disease.strip().casefold()], [comorbidity.strip().casefold()]
sel_drugs, sel_aes = [d.strip().casefold() for d in drugs], [a.strip().casefold() for a in aes]

def _rules_from(df, selected_names):
    """Return ``(col, type_col, type_val, like_pattern)`` tuples for the selected names.

    Rows of *df* are kept when their ``name_key`` is in *selected_names*;
    exact duplicate rows are collapsed before the tuples are built.
    """
    matching = df[df["name_key"].isin(selected_names)]
    unique_rows = matching.drop_duplicates()
    rules = []
    for row in unique_rows.itertuples(index=False):
        rules.append((row.col, row.type_col, row.type_val, row.like_pattern))
    return rules

def _named_rule_groups(defs):
    """Group definition rows by ``name_key`` into ``{"name", "rules"}`` dictionaries.

    Groups are produced in sorted ``name_key`` order.  ``name`` is the display
    name from the first row of the group; ``rules`` is the list of
    ``(col, type_col, type_val, like_pattern)`` tuples of its rows.
    """
    groups = []
    for _, grp in defs.groupby("name_key", sort=True):
        rules = [(r.col, r.type_col, r.type_val, r.like_pattern) for r in grp.itertuples(index=False)]
        groups.append({"name": grp["name"].iloc[0], "rules": rules})
    return groups


base_disease_rules = _rules_from(group_dfs["base"], sel_disease)
comorbidity_rules = _rules_from(group_dfs["comorbidity"], sel_comorb)

drug_defs = group_dfs["drugs"]
selected_drug_rows = drug_defs[drug_defs["name_key"].isin(sel_drugs)]
drug_codes = selected_drug_rows["atc5_code"].dropna().unique().tolist()

ae_defs = group_dfs["ae"]
adverse_events = _named_rule_groups(ae_defs[ae_defs["name_key"].isin(sel_aes)])

confounders = []
if "confounders" in group_dfs and not group_dfs["confounders"].empty:
    confounders = _named_rule_groups(group_dfs["confounders"])

# Fail fast with a readable message: an unresolved disease, comorbidity or drug
# selection would otherwise only surface later as malformed SQL.
problems = []
if not base_disease_rules:
    problems.append(f"no rules found for disease '{disease}'")
if not comorbidity_rules:
    problems.append(f"no rules found for comorbidity '{comorbidity}'")
if not drug_codes:
    problems.append(f"no ATC codes found for drugs {drugs}")
if problems:
    raise ValueError("Selections could not be resolved: " + "; ".join(problems))

# Partially unresolved lists are not fatal, but silently dropping names would
# change the analysis without anyone noticing, so report them.
known_drugs = set(selected_drug_rows["name_key"])
unknown_drugs = [d for d in sel_drugs if d not in known_drugs]
if unknown_drugs:
    console.print(f"  [yellow]WARNING:[/yellow] drugs without a definition (ignored): {', '.join(unknown_drugs)}")

known_aes = set(ae_defs["name_key"])
unknown_aes = [a for a in sel_aes if a not in known_aes]
if unknown_aes:
    console.print(f"  [yellow]WARNING:[/yellow] adverse events without a definition (ignored): {', '.join(unknown_aes)}")

console.print("  [green]Selections validated successfully.[/green]")

# Step 2: Set Up Database Views
console.print(Panel("Step 2: Set Up Database Views", title_align="left", border_style="blue"))
# Cohort files are stored per disease: <DATA_DIR>/cohorts/<disease>/<kind>_<disease>.parquet
cohort_dir = DATA_DIR / "cohorts" / disease
pop_file = cohort_dir / f"population_{disease}.parquet"
dx_file = cohort_dir / f"diagnoses_{disease}.parquet"
med_file = cohort_dir / f"meds_{disease}.parquet"

with console.status("[b]Creating database views from Parquet files...[/b]", spinner="dots"):
    # dx: one row per diagnosis record, with the timestamp cast to a date and the code trimmed.
    con.execute(f"CREATE OR REPLACE VIEW dx AS SELECT patient_id, time_stamp::DATE AS diagnosis_date, TRIM(code1) AS code1, code_type1 FROM read_parquet('{dx_file.as_posix()}');")
    # meds: only rows of type 'Medications purchase'; code1 is exposed as atc5_code.
    con.execute(f"CREATE OR REPLACE VIEW meds AS SELECT patient_id, time_stamp::DATE AS rx_start_date, code1 AS atc5_code, type FROM read_parquet('{med_file.as_posix()}') WHERE type = 'Medications purchase';")
    # The latest diagnosis date in the data serves as the end of observation for
    # patients without a recorded death date.
    # NOTE(review): latest_date is None when dx has no rows, which would make strftime fail — confirm dx is never empty.
    latest_date = con.execute("SELECT MAX(diagnosis_date) FROM dx;").fetchone()[0]
    admin_end_date = latest_date.strftime('%Y-%m-%d')
    # con.execute(f"CREATE OR REPLACE VIEW pop AS SELECT patient_id, date_of_birth::DATE AS birth_date, COALESCE(date_of_death::DATE, DATE '{admin_end_date}') AS observation_end_date FROM read_parquet('{pop_file.as_posix()}');")
    
    # Load population data into a Pandas DataFrame for imputation
    pop_df = pd.read_parquet(pop_file)
    
    # --- Imputation of and categorization of socioeconomic_status ---
    console.print("  [yellow]Categorizing and imputing socioeconomic_status...[/yellow]")
    # Counted before the column is overwritten so the number of imputed rows can be reported.
    n_missing_before = pop_df['socioeconomic_status'].isna().sum()

    # Create categorical SES levels based on population tertiles (~33% each).
    # NaNs in the original column will result in NaNs here.
    # NOTE(review): if duplicates='drop' actually removes a bin edge, the three fixed
    # labels presumably no longer match the number of bins — confirm qcut's behaviour here.
    ses_cats = pd.qcut(
        pop_df['socioeconomic_status'],
        q=3,
        labels=['low', 'intermediate', 'high'],
        duplicates='drop'  # Handles non-unique bin edges
    )

    # Impute missing SES values by assigning them to the 'intermediate' category.
    # We use .astype('object') to allow filling with a string not in the original categories.
    pop_df['socioeconomic_status'] = ses_cats.astype('object').fillna('intermediate')

    # Ensure the column has a consistent categorical type for statsmodels
    ses_order = ['low', 'intermediate', 'high']
    pop_df['socioeconomic_status'] = pop_df['socioeconomic_status'].astype(pd.CategoricalDtype(categories=ses_order, ordered=True))
    
    n_imputed = n_missing_before
    console.print(f"  [green]Categorized SES into 3 levels and imputed {n_imputed:,} missing values to 'intermediate'.[/green]")

    # Define observation end date for the view
    # Death date when present, otherwise the administrative end date derived above.
    pop_df['observation_end_date'] = pd.to_datetime(pop_df['date_of_death']).fillna(pd.to_datetime(admin_end_date))
    
    # Register the imputed Pandas DataFrame as a temporary DuckDB table
    con.register('pop_imputed', pop_df)
    
    # Create the final 'pop' view, now including sex and the imputed socioeconomic_status
    con.execute("""
        CREATE OR REPLACE VIEW pop AS 
        SELECT 
            patient_id, 
            date_of_birth::DATE AS birth_date,
            observation_end_date::DATE AS observation_end_date,
            sex,
            socioeconomic_status
        FROM pop_imputed;
    """)

# Sanity check: number of rows visible through the pop view.
cohort_size = con.execute('SELECT COUNT(*) FROM pop;').fetchone()[0]
console.print(f"  [green]Overall cohort size for '{disease}':[/green] {cohort_size:,}")

# --- Demographic Summary ---
console.print(Panel("Demographic Summary of Population", title_align="left", border_style="blue"))


def _add_share_row(table, label, counts, key, total):
    """Append a ``count (percent%)`` row for *key* when it is present in *counts*."""
    if key in counts:
        n = counts[key]
        table.add_row(label, f"{n:,} ({n/total*100:.1f}%)")


# Frequency tables for the two categorical demographics.
sex_counts = pop_df['sex'].value_counts()
ses_counts = pop_df['socioeconomic_status'].value_counts()

demographics_table = Table(title=f"Population Demographics for '{clean_label(disease)}'")
demographics_table.add_column("Statistic", justify="right", style="cyan", no_wrap=True)
demographics_table.add_column("Value", justify="left", style="magenta")

# Headline count, then one row per sex code that actually occurs.
demographics_table.add_row("Total Patients", f"{len(pop_df):,}")
for sex_code, sex_label in (("F", "Females"), ("M", "Males")):
    _add_share_row(demographics_table, sex_label, sex_counts, sex_code, len(pop_df))

# SES breakdown in its own section, ordered low -> high.
demographics_table.add_section()
demographics_table.add_row("[bold]SES Category[/bold]", "")
for ses_level, ses_label in (("low", "  Low"), ("intermediate", "  Intermediate"), ("high", "  High")):
    _add_share_row(demographics_table, ses_label, ses_counts, ses_level, len(pop_df))

console.print(demographics_table)

# Step 3: Compute Index Dates and Define Cohorts
console.print(Panel("Step 3: Compute Index Dates and Define Cohorts", title_align="left", border_style="blue"))

# Render the ATC codes as an explicit, quoted SQL list.  Interpolating a Python
# tuple is not safe: an empty list gives "()", a single code gives a trailing
# comma, and a code containing a quote is rendered with double quotes.
if not drug_codes:
    raise ValueError("No ATC codes selected for drug exposure; cannot build the drug_exposure table.")
drug_code_list = ", ".join("'" + str(code).replace("'", "''") + "'" for code in drug_codes)

# First qualifying date per patient for the base disease, the comorbidity and the drug exposure.
with console.status("[b]Identifying first diagnosis/drug dates...[/b]", spinner="dots"):
    con.execute(f"CREATE OR REPLACE TEMP TABLE base_disease AS SELECT patient_id, MIN(diagnosis_date) AS base_date FROM dx WHERE {build_where_clause(base_disease_rules)} GROUP BY patient_id;")
    con.execute(f"CREATE OR REPLACE TEMP TABLE comorbidity AS SELECT patient_id, MIN(diagnosis_date) AS comorb_date FROM dx WHERE {build_where_clause(comorbidity_rules)} GROUP BY patient_id;")
    con.execute(f"CREATE OR REPLACE TEMP TABLE drug_exposure AS SELECT patient_id, MIN(rx_start_date) AS drug_date FROM meds WHERE atc5_code IN ({drug_code_list}) GROUP BY patient_id;")

# One row per patient in pop, with a 0/1 flag and the first date for each condition.
with console.status("[b]Building master patient flag table...[/b]", spinner="dots"):
    con.execute("""
    CREATE OR REPLACE TEMP TABLE patient_flags AS
    SELECT p.patient_id,
        CASE WHEN bd.base_date IS NOT NULL THEN 1 ELSE 0 END AS base_disease, bd.base_date,
        CASE WHEN cd.comorb_date IS NOT NULL THEN 1 ELSE 0 END AS comorbidity, cd.comorb_date,
        CASE WHEN de.drug_date IS NOT NULL THEN 1 ELSE 0 END AS drug, de.drug_date
    FROM pop p
    LEFT JOIN base_disease bd USING (patient_id) LEFT JOIN comorbidity cd USING (patient_id) LEFT JOIN drug_exposure de USING (patient_id);
    """)

with console.status("[b]Calculating cohort sizes...[/b]", spinner="dots"):
    cohort_counts = con.execute("""
        SELECT SUM(base_disease), SUM(comorbidity), SUM(drug),
               SUM(CASE WHEN base_disease = 1 AND comorbidity = 1 THEN 1 ELSE 0 END),
               SUM(CASE WHEN base_disease = 1 AND comorbidity = 1 AND drug = 1 THEN 1 ELSE 0 END)
        FROM patient_flags;
    """).fetchone()
# SUM over an empty table is NULL; report that as 0 instead of failing in the formatting below.
n_d, n_d2, n_c, n_d_d2, n_d_d2_c = (int(value or 0) for value in cohort_counts)

table = Table(title="Cohort Sizes")
table.add_column("Cohort Definition", justify="right", style="cyan", no_wrap=True)
table.add_column("Patient Count", justify="right", style="magenta")
table.add_row("N(D)", f"{n_d:,}")
table.add_row("N(D2)", f"{n_d2:,}")
table.add_row("N(C)", f"{n_c:,}")
table.add_row("N(D + D2)", f"{n_d_d2:,}")
table.add_row("N(D + D2 + C)", f"{n_d_d2_c:,}")
console.print(table)

# --- 4. Expose Variables for Testing ---
# Closing banner followed by one bullet per object that later cells rely on.
console.print("\n[bold magenta]SETUP COMPLETE.[/bold magenta] The following variables are now available for testing in subsequent cells:")
for _exposed_note in (
    "- `con`: The DuckDB connection object containing `pop`, `dx`, `meds`, and `patient_flags`.",
    "- `config`: The configuration dictionary for the current analysis.",
    "- `adverse_events`: A list of adverse event definition dictionaries.",
    "- `confounders`: A list of confounder definition dictionaries.",
    f"- `WINDOWS`: The list of time windows: {WINDOWS}.",
):
    console.print(_exposed_note)

In [None]:
# Build the analysis population: patients with BOTH the base disease and the comorbidity.
#   exposed = 1: also has the drug; index date is the first drug date.
#   exposed = 0: never has the drug; index date is the later of the two first-diagnosis dates.
# NOTE(review): nothing here requires the first drug date to fall on or after the
# diagnoses, so an exposed patient's index date may precede them — confirm this is intended.
con.execute("""
CREATE OR REPLACE TEMP TABLE regression_base AS
(SELECT patient_id, 1 AS exposed, drug_date AS index_date FROM patient_flags WHERE base_disease = 1 AND comorbidity = 1 AND drug = 1)
UNION ALL
(SELECT patient_id, 0 AS exposed, GREATEST(base_date, comorb_date) AS index_date FROM patient_flags WHERE base_disease = 1 AND comorbidity = 1 AND drug = 0);
""")

In [None]:
# Attach the baseline covariates (sex, SES category, age at index) to every
# patient in regression_base.
# Age uses DATE_SUB, which counts COMPLETED years between birth and index date.
# DATE_DIFF('year', ...) counts calendar-year boundaries crossed instead, so it
# overstates age by one for anyone whose index date falls before their birthday.
con.execute("""
CREATE OR REPLACE TEMP TABLE regression_data AS
SELECT
    r.patient_id,
    r.exposed,
    r.index_date,
    p.sex,
    p.socioeconomic_status,
    DATE_SUB('year', p.birth_date, r.index_date) AS age_at_index
FROM regression_base r
JOIN pop p ON r.patient_id = p.patient_id;
""")

In [None]:
# Materialise the regression table as a pandas DataFrame for modelling.
model_df = con.execute("SELECT * FROM regression_data;").df()

# Guard kept from the function version of this code (disabled in the notebook):
# if model_df.empty:
#     console.print("  [bold yellow]WARNING:[/] Base data for regression is empty. Skipping analysis.")
#     return pd.DataFrame()

In [None]:
all_results = []
# Columns the model needs; rows missing any of them are excluded before fitting.
model_vars = ['outcome', 'exposed', 'age_at_index', 'sex', 'socioeconomic_status']
# Minimum number of outcome events required before a model is attempted.
min_events = 5

# 4. Iterate through each AE and run the regression model.
for ae in track(adverse_events, description="Running regression models for AEs..."):
    ae_name, ae_slug = ae["name"], slugify(ae["name"])
    ae_where = build_where_clause(ae["rules"])

    # First date on which each patient has a diagnosis matching this AE.
    ae_dates_df = con.execute(f"SELECT patient_id, MIN(diagnosis_date) AS ae_date FROM dx WHERE {ae_where} GROUP BY patient_id;").fetchdf()

    # Left merge keeps patients without the AE; their ae_date is NaT.
    merged_df = pd.merge(model_df, ae_dates_df, on='patient_id', how='left')
    merged_df['index_date'] = pd.to_datetime(merged_df['index_date'])
    merged_df['ae_date'] = pd.to_datetime(merged_df['ae_date'])

    for W in [None]: # WINDOWS:
        win_label = "any_after_index" if W is None else f"{W}d"
        console.print(f"Processing model for: ([bold cyan]{clean_label(ae_name)}[/bold cyan], [yellow]{win_label}[/yellow])")

        temp_df = merged_df.copy()

        # Outcome = first AE strictly after the index date (and inside the window, if any).
        # Comparisons against NaT are False, so patients without the AE get outcome 0.
        if W is None:
            temp_df['outcome'] = ((temp_df['ae_date'] > temp_df['index_date'])).astype(int)
        else:
            window_days = pd.to_timedelta(W, unit='d')
            temp_df['outcome'] = ((temp_df['ae_date'] > temp_df['index_date']) & (temp_df['ae_date'] <= temp_df['index_date'] + window_days)).astype(int)

        # Drop incomplete rows BEFORE counting events, so the threshold applies
        # to the rows the model is actually fitted on.
        temp_df = temp_df.dropna(subset=model_vars)

        n_events = int(temp_df['outcome'].sum())
        if n_events < min_events: # Min events needed for a stable model
            console.print(f"  [yellow]SKIPPED:[/yellow] only {n_events} outcome event(s) for '{ae_name}' ({win_label}); need at least {min_events}.")
            continue

        # 5. Build the formula and fit the logistic regression model.
        additional_confounders_str = ""
        formula = f"outcome ~ exposed + age_at_index + C(sex) + C(socioeconomic_status){additional_confounders_str}"

        # Deliberately broad: one model failing to fit must not abort the remaining AEs.
        try:
            model = smf.logit(formula, data=temp_df).fit(maxiter=100, disp=1)

            params = model.params
            conf = model.conf_int()
            pvalues = model.pvalues

            # Coefficients are log-odds; exponentiate to report odds ratios.
            adj_or = np.exp(params.get('exposed', np.nan))
            ci_low = np.exp(conf.loc['exposed', 0])
            ci_high = np.exp(conf.loc['exposed', 1])
            p_value = pvalues.get('exposed', np.nan)

            all_results.append({
                "adverse_event": clean_label(ae_name),
                "window": win_label,
                "adjusted_odds_ratio": adj_or,
                "ci95_low": ci_low,
                "ci95_high": ci_high,
                "p_value": p_value,
                "formula": formula
            })

        except Exception as e:
            console.print(f"  [bold red]ERROR:[/] Could not fit regression for '{ae_name}' ({win_label}). Reason: {e}")

# return pd.DataFrame(all_results)