# Test Notebook for AE Prevalence Analysis

This notebook allows you to:
1. Select a configuration from `configurations.py`
2. Run all steps from `ae_prevalence.py` interactively
3. Inspect intermediate outputs
4. Try out alternative code options (e.g. a different configuration, follow-up window set, or regression formula)


In [None]:
# --- Imports and Setup ---
import math
import os
import re
import textwrap
from pathlib import Path

import duckdb
# To fix "Tcl_AsyncDelete: async handler deleted by the wrong thread" error.
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt

import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from rich.console import Console
from rich.panel import Panel
from rich.progress import track, Progress, TextColumn, BarColumn, TimeElapsedColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.table import Table
from rich.text import Text

from configurations import CONFIGURATIONS

plt.rcParams['font.sans-serif'] = 'Arial'

# Initialize pretty printing console
console = Console()

# --- Global Paths & Settings ---
# The archive root defaults to the shared research drive; set the
# AE_PREVALENCE_ROOT_DIR environment variable to point the notebook at another
# copy of the same directory layout (e.g. a local mirror) without editing code.
ROOT_DIR = Path(
    os.environ.get(
        "AE_PREVALENCE_ROOT_DIR",
        "//10.100.117.220/Research_Archive$/Archive/R01/R01-Ayush/txagent/",
    )
)
RESULTS_DIR = ROOT_DIR / "results"
DATA_DIR = ROOT_DIR / "data"
GROUPS_DIR = ROOT_DIR / "code" / "groups"
# Follow-up windows in days after the index date; None means "any time after index".
WINDOWS = [30, 90, 365, 1825, None]

# --- Helper Functions ---
def slugify(text: str) -> str:
    """Return a lowercase slug of *text*.

    Every non-alphanumeric character becomes '_'; leading and trailing
    underscores are then stripped (interior runs are kept as-is).
    """
    def _slug_char(ch: str) -> str:
        return ch.lower() if ch.isalnum() else "_"

    return "".join(map(_slug_char, text)).strip("_")

def clean_label(text: str) -> str:
    """Return *text* lowercased with each underscore turned into a space, for display."""
    return " ".join(text.lower().split("_"))

def _read_clean_groups(path: Path) -> pd.DataFrame:
    """
    Reads and cleans a group definition CSV.
    It also filters out any rows where the 'exclude' column is set to True.
    """
    try:
        df = pd.read_csv(path)
        if df.empty:
            console.print(f"  [yellow]WARNING:[/] Definition file is empty: {path.name}")
            return pd.DataFrame()
    except pd.errors.EmptyDataError:
        console.print(f"  [yellow]WARNING:[/] Could not parse columns from file (it may be empty): {path.name}")
        return pd.DataFrame()
    df.columns = [c.strip().lower() for c in df.columns]

    # Standardize all columns to stripped strings for consistent processing
    for c in df.columns:
        df[c] = df[c].astype(str).str.strip()

    # If an 'exclude' column exists, drop rows where its value is 'True' (case-insensitive)
    if "exclude" in df.columns:
        df = df[df["exclude"].str.lower() != 'true']

    if "name" in df.columns:
        df["name_key"] = df["name"].str.casefold()
        
    return df

def build_where_clause(rules: list) -> str:
    """
    Build a parenthesised SQL OR-clause from rule tuples.

    Each rule is ``(col, type_col, type_val, like_pattern)`` and becomes
    ``(type_col = 'type_val' AND col LIKE 'like_pattern')``. Single quotes inside
    the values are doubled (SQL string escaping), so a pattern such as
    ``O'Brien%`` cannot break the statement. Column names are inserted verbatim
    and must come from trusted definition files.

    An empty rule list returns ``(FALSE)`` (matches nothing) rather than the
    syntactically invalid ``()``.
    """
    def _quote(value) -> str:
        return "'" + str(value).replace("'", "''") + "'"

    parts = [f"({tc} = {_quote(tv)} AND {c} LIKE {_quote(pat)})" for c, tc, tv, pat in rules]
    if not parts:
        return "(FALSE)"
    return "(" + " OR ".join(parts) + ")"

def or_stats(a, b, c, d, add_halves=True, z=1.96):
    """
    Calculate an odds ratio and its Wald confidence interval for a 2x2 table.

    Cell layout: a = exposed with event, b = exposed without event,
    c = unexposed with event, d = unexposed without event.

    If ``add_halves`` is True and any cell is zero, 0.5 is added to every cell
    (Haldane-Anscombe correction). If a zero cell remains (``add_halves=False``),
    the OR or its standard error is undefined and every output is NaN, instead
    of raising on ``log(0)`` or ``1/0``.

    Args:
        a, b, c, d: Cell counts.
        add_halves: Apply the Haldane-Anscombe correction when a cell is zero.
        z: Critical value for the interval (default 1.96 -> 95% CI).

    Returns:
        Tuple ``(OR, log(OR), SE of log(OR), lower CI, upper CI)``.
    """
    if add_halves and 0 in (a, b, c, d):
        a, b, c, d = a + 0.5, b + 0.5, c + 0.5, d + 0.5

    # Any remaining zero makes log(OR) or its SE undefined.
    if 0 in (a, b, c, d):
        return np.nan, np.nan, np.nan, np.nan, np.nan

    OR = (a * d) / (b * c)
    logOR = math.log(OR)
    SE = math.sqrt(1 / a + 1 / b + 1 / c + 1 / d)
    lcl = math.exp(logOR - z * SE)
    ucl = math.exp(logOR + z * SE)

    return OR, logOR, SE, lcl, ucl

def ae_window_predicate(ae_col, idx_col, window_days):
    """Return a SQL predicate requiring the AE date to fall strictly after the index date.

    When ``window_days`` is not None, the AE must also occur no later than
    ``window_days`` days after the index date.
    """
    after_index = f"{ae_col} > {idx_col}"
    if window_days is not None:
        upper_bound = f"{ae_col} <= {idx_col} + INTERVAL '{int(window_days)}' DAY"
        return f"({after_index} AND {upper_bound})"
    return after_index

# Setup-cell sentinel: only reached if every import and helper definition above succeeded.
console.print("[bold green]✓[/bold green] All imports and helper functions loaded.")


## Step 1: Select Configuration

Choose which configuration to run. You can modify the index or select by criteria.


In [None]:
# Choose the configuration to analyse.
# Option 1: pick by position in CONFIGURATIONS.
config_index = 1  # Change this to select different configs

# Option 2: pick by criteria instead (uncomment to use)
# config_index = next(i for i, c in enumerate(CONFIGURATIONS)
#                     if c.get("enabled", True) and c["disease"] == "hypertension")

config = CONFIGURATIONS[config_index]

# Build the summary panel: each entry is (prefix, prefix style, value text).
_disease_text = clean_label(config['disease'])
_comorbidity_text = clean_label(config['comorbidity'])
title = f"Selected Configuration: {_disease_text} + {_comorbidity_text}"

_summary_rows = [
    ("  Disease (D): ", "green", _disease_text),
    ("  Comorbidity (D2): ", "green", _comorbidity_text),
]

_drugs_text = ', '.join(config['drugs'])
if not config.get("drug_group_name"):
    drug_label = "Drug (C)" if len(config['drugs']) == 1 else "Drugs (C)"
    _summary_rows.append((f"  {drug_label}: ", "green", _drugs_text))
else:
    _summary_rows.append(("  Drug Group (C): ", "green", f"{config['drug_group_name']}"))
    _summary_rows.append(("  Individual Drugs: ", "green", _drugs_text))

aes = config.get("aes", [])
positive_controls = config.get("positive_controls", [])
negative_controls = config.get("negative_controls", [])

# Optional AE lists are only shown when non-empty.
for _prefix, _style, _items in (
    ("  Adverse Events (E): ", "green", aes),
    ("  Positive Controls: ", "cyan", positive_controls),
    ("  Negative Controls: ", "yellow", negative_controls),
):
    if _items:
        _summary_rows.append((_prefix, _style, ', '.join(_items)))

content = Text(no_wrap=True)
for _prefix, _style, _value in _summary_rows:
    content.append(_prefix, style=_style)
    content.append(f"{_value}\n")

console.print(Panel(content, title=title, border_style="cyan", title_align="left"))


## Step 2: Load Definition Files and Population AE Data


In [None]:
# Load shared definition files
# Each CSV is read through _read_clean_groups. Every file except confounders.csv
# is required; a missing required file stops the notebook with FileNotFoundError.
with console.status("[b]Loading shared definition files...[/b]", spinner="dots"):
    group_files = {
        "base": GROUPS_DIR / "base.csv",
        "comorbidity": GROUPS_DIR / "comorbidities.csv",
        "drugs": GROUPS_DIR / "drugs.csv",
        "ae": GROUPS_DIR / "adverse_effects.csv",
        "confounders": GROUPS_DIR / "confounders.csv"
    }
    group_dfs = {}
    for name, path in group_files.items():
        if path.exists():
            group_dfs[name] = _read_clean_groups(path)
        else:
            if name == 'confounders':
                # Optional file: an empty DataFrame stands in for it.
                # NOTE(review): that empty frame has no 'name_key' column, so code that
                # indexes group_dfs["confounders"]["name_key"] must check for emptiness first.
                console.print(f"[yellow]NOTE: '{path.name}' not found. Regression will only adjust for age, sex, and SES.[/yellow]")
                group_dfs[name] = pd.DataFrame()
            else:
                console.print(f"[bold red]ERROR:[/] Essential definition file not found: {path}. Exiting.")
                raise FileNotFoundError(f"Essential definition file not found: {path}")
console.print("✓ Shared definition files loaded.")

# Load population AE prevalence data
with console.status("[b]Loading population AE prevalence data...[/b]", spinner="dots"):
    pop_ae_file = ROOT_DIR / "data" / "codes" / "ae_patient_counts.csv"
    pop_ae_df = pd.read_csv(pop_ae_file)
    # Heuristic unit normalisation: if every non-missing value lies in [0, 1], the
    # column is treated as a fraction and converted to percent.
    # NOTE(review): a file already expressed in percent where every AE is rarer than
    # 1% would be scaled a second time; confirm the units of ae_patient_counts.csv.
    if pop_ae_df["prevalence_pct"].dropna().between(0, 1).all():
        pop_ae_df["prevalence_pct"] = pop_ae_df["prevalence_pct"] * 100.0
console.print("✓ Population AE prevalence data loaded.")

# Initialize database connection
# In-memory DuckDB: views and temp tables created by later cells exist only for this session.
con = duckdb.connect(database=":memory:")
console.print("✓ DuckDB connection established.")


## Step 3: Load and Validate Selections


In [None]:
console.print(Panel("Step 1: Load and Validate Selections", title_align="left", border_style="blue"))

disease, comorbidity, drugs = config["disease"], config["comorbidity"], config["drugs"]
# Combine all AEs for processing: primary AEs plus positive and negative controls.
aes = config.get("aes", [])
positive_controls = config.get("positive_controls", [])
negative_controls = config.get("negative_controls", [])
aes = aes + positive_controls + negative_controls

# Selected names are matched case-insensitively against each definition file's 'name_key'.
sel_disease, sel_comorb = [disease.strip().casefold()], [comorbidity.strip().casefold()]
sel_drugs, sel_aes = [d.strip().casefold() for d in drugs], [a.strip().casefold() for a in aes]
sel_confounders = [c.strip().casefold() for c in config.get("confounders", [])]


def _has_names(df):
    """Return True if a definition DataFrame can be looked up by name.

    A placeholder for a missing optional file, or an empty CSV, is an empty
    DataFrame without a 'name_key' column; indexing it would raise KeyError.
    """
    return not df.empty and "name_key" in df.columns


def _rules_from(df, selected_names):
    """Return (col, type_col, type_val, like_pattern) tuples for rows matching selected_names."""
    if not _has_names(df):
        return []
    sub = df[df["name_key"].isin(selected_names)].drop_duplicates()
    return [(r.col, r.type_col, r.type_val, r.like_pattern) for r in sub.itertuples(index=False)]


def _named_rule_sets(df, selected_names):
    """Return one {"name", "rules"} dict per matched definition, ordered by name_key."""
    if not _has_names(df):
        return []
    rule_sets = []
    for _, grp in df[df["name_key"].isin(selected_names)].groupby("name_key", sort=True):
        rules = [(r.col, r.type_col, r.type_val, r.like_pattern) for r in grp.itertuples(index=False)]
        rule_sets.append({"name": grp['name'].iloc[0], "rules": rules})
    return rule_sets


def _matched_keys(df, selected_names):
    """Return the subset of selected_names that has at least one definition row."""
    if not _has_names(df):
        return set()
    return set(df.loc[df["name_key"].isin(selected_names), "name_key"])


def _warn_unmatched(label, selected_names, matched_keys):
    """Print a warning listing requested names that were not found (they would otherwise be dropped silently)."""
    missing = sorted(set(selected_names) - set(matched_keys))
    if missing:
        console.print(f"  [yellow]WARNING:[/] {label} not found in definition files: {', '.join(missing)}")


with console.status("[b]Loading definitions and rules...[/b]", spinner="dots"):
    base_disease_rules = _rules_from(group_dfs["base"], sel_disease)
    comorbidity_rules = _rules_from(group_dfs["comorbidity"], sel_comorb)

    drug_defs = group_dfs["drugs"]
    if _has_names(drug_defs) and "atc5_code" in drug_defs.columns:
        drug_codes = drug_defs[drug_defs["name_key"].isin(sel_drugs)]["atc5_code"].dropna().unique().tolist()
    else:
        drug_codes = []

    adverse_events = _named_rule_sets(group_dfs["ae"], sel_aes)

    confounder_definitions = []
    if sel_confounders:
        # confounders.csv is optional; a missing or empty file yields no definitions.
        confounder_definitions = _named_rule_sets(group_dfs.get("confounders", pd.DataFrame()), sel_confounders)

_warn_unmatched("Drugs", sel_drugs, _matched_keys(group_dfs["drugs"], sel_drugs))
_warn_unmatched("Adverse events", sel_aes, _matched_keys(group_dfs["ae"], sel_aes))
if sel_confounders:
    _warn_unmatched("Confounders", sel_confounders, _matched_keys(group_dfs.get("confounders", pd.DataFrame()), sel_confounders))

console.print(f"  [green]Disease rules found:[/green] {len(base_disease_rules)}")
console.print(f"  [green]Comorbidity rules found:[/green] {len(comorbidity_rules)}")
console.print(f"  [green]Drug ATC5 codes found:[/green] {len(drug_codes)}")
console.print(f"  [green]Adverse events found:[/green] {len(adverse_events)}")
console.print(f"  [green]Confounder definitions found:[/green] {len(confounder_definitions)}")

for label, obj in [("DISEASE", base_disease_rules), ("COMORBIDITY", comorbidity_rules), ("DRUGS", drug_codes), ("AEs", adverse_events)]:
    if not obj:
        console.print(f"  [bold red]ERROR:[/] No entries found for '{label}'. Check names in config and CSVs.")
        raise ValueError(f"No entries found for '{label}'")

console.print("  [green]✓ Selections validated successfully.[/green]")


## Step 4: Set Up Database Views


In [None]:
console.print(Panel("Step 2: Set Up Database Views", title_align="left", border_style="blue"))

cohort_dir = DATA_DIR / "cohorts" / disease
pop_file = cohort_dir / f"population_{disease}.parquet"
dx_file = cohort_dir / f"diagnoses_{disease}.parquet"
med_file = cohort_dir / f"meds_{disease}.parquet"

if not all([pop_file.exists(), dx_file.exists(), med_file.exists()]):
    console.print(f"  [bold red]ERROR:[/] Data files not found for disease '{disease}' in {cohort_dir}.")
    raise FileNotFoundError(f"Data files not found for disease '{disease}'")

# Define save directory (created up front so later cells can write into it)
drug_slug = slugify(config.get("drug_group_name") or '_'.join(drugs))
save_dir_name = slugify(f"{disease}-{comorbidity}-{drug_slug}")
save_dir = RESULTS_DIR / "ae_prevalence" / slugify(disease) / save_dir_name
save_dir.mkdir(parents=True, exist_ok=True)


def _ses_tertiles(ses, labels):
    """
    Assign each non-missing SES value to a population tertile, labelled low -> high.

    ``pd.qcut`` raises ValueError when many patients share the same score and two
    tertile edges coincide (``duplicates='drop'`` does not help, because the
    number of edges then no longer matches the fixed list of labels). In that
    case, tertiles are taken over the rank of the scores instead (ties split by
    row order), so three groups always exist. Missing values stay missing.
    """
    if ses.notna().sum() == 0:
        return pd.Series(index=ses.index, dtype="object")
    try:
        return pd.qcut(ses, q=len(labels), labels=labels)
    except ValueError:
        console.print("  [yellow]WARNING:[/] SES tertile edges are not unique (many tied scores); "
                      "using rank-based tertiles (ties split by row order).")
        return pd.qcut(ses.rank(method="first"), q=len(labels), labels=labels)


with console.status("[b]Creating database views from Parquet files...[/b]", spinner="dots"):
    con.execute(f"CREATE OR REPLACE VIEW dx AS SELECT patient_id, time_stamp::DATE AS diagnosis_date, TRIM(code1) AS code1, code_type1 FROM read_parquet('{dx_file.as_posix()}');")
    con.execute(f"CREATE OR REPLACE VIEW meds AS SELECT patient_id, time_stamp::DATE AS rx_start_date, code1 AS atc5_code, type FROM read_parquet('{med_file.as_posix()}') WHERE type = 'Medications purchase';")
    # The latest diagnosis date serves as the administrative end of follow-up.
    latest_date = con.execute("SELECT MAX(diagnosis_date) FROM dx;").fetchone()[0]
    if latest_date is None:
        raise ValueError(f"No diagnoses found in {dx_file}; cannot derive the administrative end date.")
    admin_end_date = latest_date.strftime('%Y-%m-%d')

    # Load population data into a Pandas DataFrame for imputation
    pop_df = pd.read_parquet(pop_file)

    # --- Categorization and imputation of socioeconomic_status ---
    console.print("  [yellow]Categorizing and imputing socioeconomic status...[/yellow]")
    n_missing_before = pop_df['socioeconomic_status'].isna().sum()

    # Categorical SES levels based on population tertiles (~33% each).
    ses_order = ['low', 'intermediate', 'high']
    ses_cats = _ses_tertiles(pop_df['socioeconomic_status'], ses_order)

    # Impute missing SES values by assigning them to the 'intermediate' category.
    pop_df['socioeconomic_status'] = ses_cats.astype('object').fillna('intermediate')

    # Ensure the column has a consistent ordered categorical type
    pop_df['socioeconomic_status'] = pop_df['socioeconomic_status'].astype(pd.CategoricalDtype(categories=ses_order, ordered=True))

    n_imputed = n_missing_before
    console.print(f"  [green]Categorized SES into 3 levels and imputed {n_imputed:,} missing values to 'intermediate'.[/green]")

    # Observation end = date of death when present, otherwise the administrative end date.
    pop_df['observation_end_date'] = pd.to_datetime(pop_df['date_of_death']).fillna(pd.to_datetime(admin_end_date))

    # Register the imputed Pandas DataFrame as a temporary DuckDB table
    con.register('pop_imputed', pop_df)

    # Create the final 'pop' view, including sex and the imputed socioeconomic_status
    con.execute("""
        CREATE OR REPLACE VIEW pop AS 
        SELECT 
            patient_id, 
            date_of_birth::DATE AS birth_date,
            observation_end_date::DATE AS observation_end_date,
            sex,
            socioeconomic_status
        FROM pop_imputed;
    """)

cohort_size = con.execute('SELECT COUNT(*) FROM pop;').fetchone()[0]
console.print(f"  [green]Overall cohort size for '{disease}':[/green] {cohort_size:,}")

# --- Demographic Summary ---
console.print(Panel("Demographic Summary of Population", title_align="left", border_style="blue"))

sex_counts = pop_df['sex'].value_counts()
ses_counts = pop_df['socioeconomic_status'].value_counts()
n_pop = len(pop_df)


def _count_with_pct(n):
    """Format a count as '1,234 (12.3%)' of the whole population (0.0% if it is empty)."""
    pct = n / n_pop * 100 if n_pop else 0.0
    return f"{n:,} ({pct:.1f}%)"


demographics_table = Table(title=f"Population Demographics for '{clean_label(disease)}'")
demographics_table.add_column("Statistic", justify="right", style="cyan", no_wrap=True)
demographics_table.add_column("Value", justify="left", style="magenta")

demographics_table.add_row("Total Patients", f"{n_pop:,}")
for row_label, sex_code in [("Females", "F"), ("Males", "M")]:
    if sex_code in sex_counts:
        demographics_table.add_row(row_label, _count_with_pct(sex_counts[sex_code]))

demographics_table.add_section()
demographics_table.add_row("[bold]SES Category[/bold]", "")
for row_label, ses_level in [("  Low", "low"), ("  Intermediate", "intermediate"), ("  High", "high")]:
    if ses_level in ses_counts:
        demographics_table.add_row(row_label, _count_with_pct(ses_counts[ses_level]))

console.print(demographics_table)


## Step 5: Compute Index Dates and Define Cohorts


In [None]:
console.print(Panel("Step 3: Compute Index Dates and Define Cohorts", title_align="left", border_style="blue"))


def _sql_string_list(values):
    """Render strings as a SQL list literal, e.g. ('A01', 'B02'), doubling single quotes.

    Unlike ``str(tuple(values))`` this emits no trailing comma for a single value
    and uses SQL (not Python repr) quoting.
    """
    quoted = ("'" + str(v).replace("'", "''") + "'" for v in values)
    return "(" + ", ".join(quoted) + ")"


with console.status("[b]Identifying first diagnosis/drug dates...[/b]", spinner="dots"):
    # Earliest qualifying diagnosis / medication purchase per patient.
    con.execute(f"CREATE OR REPLACE TEMP TABLE base_disease AS SELECT patient_id, MIN(diagnosis_date) AS base_date FROM dx WHERE {build_where_clause(base_disease_rules)} GROUP BY patient_id;")
    con.execute(f"CREATE OR REPLACE TEMP TABLE comorbidity AS SELECT patient_id, MIN(diagnosis_date) AS comorb_date FROM dx WHERE {build_where_clause(comorbidity_rules)} GROUP BY patient_id;")
    con.execute(f"CREATE OR REPLACE TEMP TABLE drug_exposure AS SELECT patient_id, MIN(rx_start_date) AS drug_date FROM meds WHERE atc5_code IN {_sql_string_list(drug_codes)} GROUP BY patient_id;")

with console.status("[b]Building master patient flag table...[/b]", spinner="dots"):
    # One row per population patient with 0/1 flags and first dates for D, D2 and C.
    con.execute("""
    CREATE OR REPLACE TEMP TABLE patient_flags AS
    SELECT p.patient_id,
        CASE WHEN bd.base_date IS NOT NULL THEN 1 ELSE 0 END AS base_disease, bd.base_date,
        CASE WHEN cd.comorb_date IS NOT NULL THEN 1 ELSE 0 END AS comorbidity, cd.comorb_date,
        CASE WHEN de.drug_date IS NOT NULL THEN 1 ELSE 0 END AS drug, de.drug_date
    FROM pop p
    LEFT JOIN base_disease bd USING (patient_id) LEFT JOIN comorbidity cd USING (patient_id) LEFT JOIN drug_exposure de USING (patient_id);
    """)

with console.status("[b]Calculating cohort sizes...[/b]", spinner="dots"):
    # COALESCE: SUM over zero rows is NULL, which would break the ':,' formatting below.
    n_d, n_d2, n_c, n_d_d2, n_d_c, n_d_d2_c = con.execute("""
        SELECT COALESCE(SUM(base_disease), 0),
               COALESCE(SUM(comorbidity), 0),
               COALESCE(SUM(drug), 0),
               COALESCE(SUM(CASE WHEN base_disease = 1 AND comorbidity = 1 THEN 1 ELSE 0 END), 0),
               COALESCE(SUM(CASE WHEN base_disease = 1 AND drug = 1 THEN 1 ELSE 0 END), 0),
               COALESCE(SUM(CASE WHEN base_disease = 1 AND comorbidity = 1 AND drug = 1 THEN 1 ELSE 0 END), 0)
        FROM patient_flags;
    """).fetchone()
    con.execute("CREATE OR REPLACE TEMP TABLE cohort1 AS SELECT * FROM patient_flags WHERE base_disease = 1;")
    con.execute("CREATE OR REPLACE TEMP TABLE cohort2 AS SELECT * FROM patient_flags WHERE base_disease = 1 AND comorbidity = 1;")
    con.execute("CREATE OR REPLACE TEMP TABLE cohort3 AS SELECT * FROM patient_flags WHERE base_disease = 1 AND drug = 1;")
    con.execute("CREATE OR REPLACE TEMP TABLE cohort4 AS SELECT * FROM patient_flags WHERE base_disease = 1 AND comorbidity = 1 AND drug = 1;")

table = Table(title="Cohort Sizes")
table.add_column("Cohort Definition", justify="right", style="cyan", no_wrap=True)
table.add_column("Patient Count", justify="right", style="magenta")
table.add_row("N(D)", f"{n_d:,}")
table.add_row("N(D2)", f"{n_d2:,}")
table.add_row("N(C)", f"{n_c:,}")
table.add_row("N(D + D2) [c2]", f"{n_d_d2:,}")
table.add_row("N(D + C) [c3]", f"{n_d_c:,}")
table.add_row("N(D + D2 + C) [c4]", f"{n_d_d2_c:,}")
console.print(table)

console.print(f"\n[bold]Results will be saved to:[/bold] [blue underline]{save_dir.relative_to(RESULTS_DIR)}[/blue underline]\n")


## Step 6: Calculate Unadjusted Odds Ratios

You can inspect the results in the `unadjusted_or_results` variable.


In [None]:
console.print(Panel("Step 4a: Calculate Unadjusted Odds Ratios", title_align="left", border_style="blue"))

with console.status("[b]Defining analysis sets...[/b]", spinner="dots"):
    # Analysis set: D + D2 patients. Exposed = also have C (index = latest of the three
    # first dates); unexposed = never had C (index = latest of D and D2 dates).
    con.execute("CREATE OR REPLACE TEMP TABLE analysis_set AS SELECT * FROM patient_flags WHERE base_disease = 1 AND comorbidity = 1;")
    con.execute("CREATE OR REPLACE TEMP TABLE exposed_riskset AS SELECT patient_id, GREATEST(base_date, comorb_date, drug_date) AS index_date FROM analysis_set WHERE drug = 1;")
    con.execute("CREATE OR REPLACE TEMP TABLE unexposed_riskset AS SELECT patient_id, GREATEST(base_date, comorb_date) AS index_date FROM analysis_set WHERE drug = 0;")

total_exp = con.execute("SELECT COUNT(*) FROM exposed_riskset;").fetchone()[0]
total_unexp = con.execute("SELECT COUNT(*) FROM unexposed_riskset;").fetchone()[0]
console.print(f"  [green]Exposed starters (D+D2+C):[/green] {total_exp:,}")
console.print(f"  [green]Unexposed starters (D+D2, no C):[/green] {total_unexp:,}")

# Report windows in the order given by WINDOWS (30d, 90d, ..., any_after_index) instead of
# alphabetically, where "1825d" would sort before "30d".
window_labels = ["any_after_index" if W is None else f"{W}d" for W in WINDOWS]
window_rank = {label: i for i, label in enumerate(window_labels)}

if total_exp == 0 or total_unexp == 0:
    console.print("  [bold yellow]WARNING:[/] Cannot calculate odds ratios with zero patients in an analysis group.")
    unadjusted_or_results = pd.DataFrame()
else:
    rows = []
    for ae in track(adverse_events, description="Computing unadjusted ORs for AEs..."):
        ae_name, ae_slug = ae["name"], slugify(ae["name"])
        ae_where = build_where_clause(ae["rules"])
        # First recorded AE diagnosis per patient.
        # NOTE(review): a patient whose first AE precedes the index date counts as
        # "no event" even if the AE recurs after index; confirm this is intended.
        con.execute(f"CREATE OR REPLACE TEMP TABLE ae_{ae_slug} AS SELECT patient_id, MIN(diagnosis_date) AS ae_date FROM dx WHERE {ae_where} GROUP BY patient_id;")

        for W in WINDOWS:
            win_label = "any_after_index" if W is None else f"{W}d"
            pred_exp = ae_window_predicate("a.ae_date", "e.index_date", W)
            pred_unx = ae_window_predicate("a.ae_date", "u.index_date", W)

            # 2x2 table: a/b = exposed with/without AE in window, c/d = unexposed with/without.
            a = con.execute(f"SELECT COUNT(DISTINCT e.patient_id) FROM exposed_riskset e JOIN ae_{ae_slug} a USING (patient_id) WHERE {pred_exp};").fetchone()[0]
            b = total_exp - a
            c = con.execute(f"SELECT COUNT(DISTINCT u.patient_id) FROM unexposed_riskset u JOIN ae_{ae_slug} a USING (patient_id) WHERE {pred_unx};").fetchone()[0]
            d = total_unexp - c

            OR, logOR, SE, lcl, ucl = or_stats(a, b, c, d)

            rows.append({
                "adverse_event": clean_label(ae_name), "window": win_label,
                "a_exposed_E": a, "b_exposed_noE": b, "c_unexposed_E": c, "d_unexposed_noE": d,
                "total_exposed": total_exp, "total_unexposed": total_unexp,
                "odds_ratio": OR, "log_or": logOR, "se_log_or": SE, "ci95_low": lcl, "ci95_high": ucl
            })

    if rows:
        unadjusted_or_results = (
            pd.DataFrame(rows)
            .sort_values(
                ["adverse_event", "window"],
                key=lambda col: col.map(window_rank) if col.name == "window" else col,
            )
            .reset_index(drop=True)
        )
    else:
        unadjusted_or_results = pd.DataFrame()

# Display results
if not unadjusted_or_results.empty:
    console.print("\n[bold]Unadjusted OR Results:[/bold]")
    display(unadjusted_or_results)
else:
    console.print("[yellow]No unadjusted OR results to display.[/yellow]")


## Step 7: Calculate Adjusted Odds Ratios (Confounder-Adjusted)

You can inspect the results in the `adjusted_or_results` variable. This step is skipped if `run_regression` is False in the config.


In [None]:
if config.get("run_regression", True):
    console.print(Panel("Step 4b: Calculate Confounder-Adjusted Odds Ratios", title_align="left", border_style="blue"))

    with console.status("[b]Building base regression dataset...[/b]", spinner="dots"):
        # Same exposed/unexposed risk sets and index dates as the unadjusted analysis.
        con.execute("""
            CREATE OR REPLACE TEMP TABLE regression_base AS
            (SELECT patient_id, 1 AS exposed, GREATEST(base_date, comorb_date, drug_date) AS index_date FROM patient_flags WHERE base_disease = 1 AND comorbidity = 1 AND drug = 1)
            UNION ALL
            (SELECT patient_id, 0 AS exposed, GREATEST(base_date, comorb_date) AS index_date FROM patient_flags WHERE base_disease = 1 AND comorbidity = 1 AND drug = 0);
        """)

        # Age in completed years at the index date. DATE_SUB counts whole years elapsed,
        # whereas DATE_DIFF('year', ...) counts calendar-year boundaries crossed and can
        # overstate age by one (born Dec 31, indexed Jan 1 -> 1 instead of 0).
        base_query = """
            SELECT
                r.patient_id, r.exposed, r.index_date, p.sex, p.socioeconomic_status,
                DATE_SUB('year', p.birth_date, r.index_date) AS age_at_index
            FROM regression_base r
            JOIN pop p ON r.patient_id = p.patient_id
        """
        con.execute(f"CREATE OR REPLACE TEMP TABLE regression_data AS ({base_query});")

    # Add each configured confounder as a 0/1 column: 1 if first diagnosed strictly before index.
    confounder_slugs = []
    if confounder_definitions:
        for conf in track(confounder_definitions, description="Adding confounder flags..."):
            conf_slug = slugify(conf["name"])
            conf_where = build_where_clause(conf["rules"])
            confounder_slugs.append(conf_slug)

            con.execute(f"CREATE OR REPLACE TEMP TABLE {conf_slug}_dates AS SELECT patient_id, MIN(diagnosis_date) as conf_date FROM dx WHERE {conf_where} GROUP BY patient_id;")
            con.execute(f"""
                CREATE OR REPLACE TEMP TABLE temp_regression_data AS
                SELECT
                    rd.*,
                    CASE WHEN c.conf_date IS NOT NULL AND c.conf_date < rd.index_date THEN 1 ELSE 0 END AS {conf_slug}
                FROM regression_data rd
                LEFT JOIN {conf_slug}_dates c ON rd.patient_id = c.patient_id;
            """)
            con.execute("DROP TABLE regression_data;")
            con.execute("ALTER TABLE temp_regression_data RENAME TO regression_data;")
            con.execute(f"DROP TABLE {conf_slug}_dates;")

    model_df = con.execute("SELECT * FROM regression_data;").fetchdf()

    # Report windows in WINDOWS order rather than alphabetically ("1825d" < "30d").
    window_labels = ["any_after_index" if W is None else f"{W}d" for W in WINDOWS]
    window_rank = {label: i for i, label in enumerate(window_labels)}

    if model_df.empty:
        console.print("  [bold yellow]WARNING:[/] Base data for regression is empty. Skipping analysis.")
        adjusted_or_results = pd.DataFrame()
        model_summaries = []
    else:
        # The formula and model columns are the same for every AE/window, so build them once.
        base_formula = "outcome ~ exposed + age_at_index + C(sex) + C(socioeconomic_status)"
        confounders_str = " + ".join(confounder_slugs)
        formula = f"{base_formula} + {confounders_str}" if confounders_str else base_formula
        model_vars = ['outcome', 'exposed', 'age_at_index', 'sex', 'socioeconomic_status'] + confounder_slugs

        console.print(Panel(f"[cyan]Using formula for all models:[/cyan]\n{formula}", title="Regression Formula", border_style="yellow"))

        # Convert once rather than inside every AE iteration.
        model_df['index_date'] = pd.to_datetime(model_df['index_date'])

        all_results = []
        model_summaries = []

        # Progress bar layout
        progress_columns = [
            TextColumn("[progress.description]{task.description}"),
            BarColumn(), MofNCompleteColumn(), TextColumn("•"),
            TimeElapsedColumn(), TextColumn("•"), TimeRemainingColumn(),
        ]

        with Progress(*progress_columns, console=console) as progress:
            # A single task for the outer loop (adverse events)
            ae_task = progress.add_task("[cyan]Running regressions...", total=len(adverse_events))

            for ae in adverse_events:
                ae_name = ae["name"]
                ae_where = build_where_clause(ae["rules"])

                progress.update(ae_task, description=f"[cyan]AE: [bold]{clean_label(ae_name)}[/bold]")

                ae_dates_df = con.execute(f"SELECT patient_id, MIN(diagnosis_date) AS ae_date FROM dx WHERE {ae_where} GROUP BY patient_id;").fetchdf()

                merged_df = pd.merge(model_df, ae_dates_df, on='patient_id', how='left')
                merged_df['ae_date'] = pd.to_datetime(merged_df['ae_date'])
                # Missing AE dates are NaT; comparisons with NaT are False, i.e. outcome 0.
                after_index = merged_df['ae_date'] > merged_df['index_date']

                for W in WINDOWS:
                    win_label = "any_after_index" if W is None else f"{W}d"
                    temp_df = merged_df.copy()

                    if W is None:
                        temp_df['outcome'] = after_index.astype(int)
                    else:
                        window_days = pd.to_timedelta(W, unit='d')
                        within_window = merged_df['ae_date'] <= merged_df['index_date'] + window_days
                        temp_df['outcome'] = (after_index & within_window).astype(int)

                    n_events = int(temp_df['outcome'].sum())
                    if n_events < 5:
                        progress.log(f"[yellow]Skipping model for '{ae_name}' ({win_label}): Insufficient outcome events ({n_events}).[/yellow]")
                        continue

                    try:
                        temp_df = temp_df.dropna(subset=model_vars)

                        model = smf.logit(formula, data=temp_df).fit(maxiter=100, disp=0)

                        params = model.params
                        conf = model.conf_int()
                        pvalues = model.pvalues

                        adj_or = np.exp(params.get('exposed', np.nan))
                        ci_low = np.exp(conf.loc['exposed', 0])
                        ci_high = np.exp(conf.loc['exposed', 1])
                        p_value = pvalues.get('exposed', np.nan)

                        all_results.append({
                            "adverse_event": clean_label(ae_name),
                            "window": win_label,
                            "adjusted_odds_ratio": adj_or,
                            "ci95_low": ci_low,
                            "ci95_high": ci_high,
                            "p_value": p_value,
                            "formula": formula
                        })

                        summary_title = f"Model Summary: AE={clean_label(ae_name)}, Window={win_label}, Formula: {formula}\n"
                        model_summaries.append(summary_title + model.summary().as_csv())

                    except Exception as e:
                        # Per-model boundary: log the failure (e.g. separation, singular
                        # matrix) and continue with the remaining AE/window combinations.
                        progress.log(f"[bold red]ERROR:[/] Could not fit regression for '{ae_name}' ({win_label}). Reason: {e}")

                # Advance once all windows for this AE are done
                progress.advance(ae_task)

        # Every model may have been skipped; sorting a column-less DataFrame would raise KeyError.
        if all_results:
            adjusted_or_results = (
                pd.DataFrame(all_results)
                .sort_values(
                    ["adverse_event", "window"],
                    key=lambda col: col.map(window_rank) if col.name == "window" else col,
                )
                .reset_index(drop=True)
            )
        else:
            adjusted_or_results = pd.DataFrame()

    # Display results
    if not adjusted_or_results.empty:
        console.print("\n[bold]Adjusted OR Results:[/bold]")
        display(adjusted_or_results)
    else:
        console.print("[yellow]No adjusted OR results to display.[/yellow]")
else:
    console.print("[yellow]Skipping adjusted OR calculation (run_regression=False in config).[/yellow]")
    adjusted_or_results = pd.DataFrame()
    model_summaries = []


## Step 8: Calculate Prevalence Statistics

You can inspect the results in the `prevalence_results` variable.


In [None]:
console.print(Panel("Step 5: Calculate Prevalence Statistics", title_align="left", border_style="blue"))

denom_c1 = con.execute("SELECT COUNT(*) FROM cohort1;").fetchone()[0]
denom_c2 = con.execute("SELECT COUNT(*) FROM cohort2;").fetchone()[0]
denom_c3 = con.execute("SELECT COUNT(*) FROM cohort3;").fetchone()[0]
denom_c4 = con.execute("SELECT COUNT(*) FROM cohort4;").fetchone()[0]
denom_total = con.execute("SELECT COUNT(*) FROM pop;").fetchone()[0]

rows = []
disease_label = clean_label(config['disease'])
comorbidity_label = clean_label(config['comorbidity'])

# Use the drug-group name when one is configured and non-empty; otherwise list the drugs.
# A plain config.get(key, default) is not enough: a key present with None/"" would return
# that value and produce cohort labels such as "... + None".
drug_group_name = config.get("drug_group_name")
if drug_group_name:
    drug_display_name = clean_label(drug_group_name)
else:
    drug_display_name = ', '.join([clean_label(d) for d in config['drugs']])


def _prevalence_pct(numerator, denominator):
    """Percentage of the cohort with the AE; 0.0 for an empty cohort."""
    return (numerator / denominator * 100.0) if denominator else 0.0


for ae in track(adverse_events, description="Computing prevalence for AEs..."):
    ae_name, ae_slug = ae["name"], slugify(ae["name"])
    ae_label = clean_label(ae_name)
    ae_where = build_where_clause(ae["rules"])
    con.execute(f"CREATE OR REPLACE TEMP TABLE ae_{ae_slug} AS SELECT patient_id, MIN(diagnosis_date) AS ae_date FROM dx WHERE {ae_where} GROUP BY patient_id;")

    # "total": patients with the AE at any time (no index-date restriction).
    # NOTE(review): this counts patients present in dx, while the denominator is pop;
    # confirm both files cover the same patients.
    n_total = con.execute(f"SELECT COUNT(*) FROM ae_{ae_slug};").fetchone()[0]
    # c1 = D
    n1 = con.execute(f"SELECT COUNT(*) FROM cohort1 c JOIN ae_{ae_slug} a USING (patient_id) WHERE a.ae_date > c.base_date;").fetchone()[0]
    # c2 = D + D2
    n2 = con.execute(f"SELECT COUNT(*) FROM cohort2 c JOIN ae_{ae_slug} a USING (patient_id) WHERE a.ae_date > GREATEST(c.base_date, c.comorb_date);").fetchone()[0]
    # c3 = D + C
    n3 = con.execute(f"SELECT COUNT(*) FROM cohort3 c JOIN ae_{ae_slug} a USING (patient_id) WHERE a.ae_date > GREATEST(c.base_date, c.drug_date);").fetchone()[0]
    # c4 = D + D2 + C
    n4 = con.execute(f"SELECT COUNT(*) FROM cohort4 c JOIN ae_{ae_slug} a USING (patient_id) WHERE a.ae_date > GREATEST(c.base_date, c.comorb_date, c.drug_date);").fetchone()[0]

    cohort_counts = [
        ("total", n_total, denom_total),
        (f"{disease_label}", n1, denom_c1),
        (f"{disease_label} + {comorbidity_label}", n2, denom_c2),
        (f"{disease_label} + {drug_display_name}", n3, denom_c3),
        (f"{disease_label} + {comorbidity_label} + {drug_display_name}", n4, denom_c4),
    ]
    for cohort_label, n_with_ae, denominator in cohort_counts:
        rows.append({
            "adverse_event": ae_label,
            "cohort": cohort_label,
            "n_with_AE": n_with_ae,
            "denominator": denominator,
            "prevalence_pct": _prevalence_pct(n_with_ae, denominator),
        })

prevalence_results = pd.DataFrame(rows)

# Display results
if not prevalence_results.empty:
    console.print("\n[bold]Prevalence Results:[/bold]")
    display(prevalence_results)
else:
    console.print("[yellow]No prevalence results to display.[/yellow]")


## Step 9: Generate Plots

This step generates bar plots for AE prevalence. The plots are saved to the results directory.


In [None]:
# Plotting helpers taken from ae_prevalence.py (defined inline here rather than imported).
def format_cohort_label(label, width=15):
    """
    Wrap an 'a + b + c' cohort label into short lines for a plot axis.

    The label is split on '+' (including surrounding whitespace), and each
    segment is word-wrapped without breaking long words. Every segment except
    the last is wrapped two characters narrower so its final line can carry a
    trailing ' +'.
    """
    segments = [part for part in re.split(r"\s*\+\s*", str(label).strip()) if part]
    last_index = len(segments) - 1
    out_lines = []
    for idx, segment in enumerate(segments):
        if idx == last_index:
            out_lines += textwrap.wrap(segment, width=width, break_long_words=False)
            continue
        pieces = textwrap.wrap(segment, width=max(1, width - 2), break_long_words=False)
        if pieces:
            pieces[-1] = pieces[-1] + " +"
        out_lines += pieces
    return "\n".join(out_lines)

def add_value_labels(ax, bars, numerators, denominators, percentages):
    """Places text labels with counts and percentages above each bar."""
    for rect, num, den, pct in zip(bars, numerators, denominators, percentages):
        if pd.notna(pct) and np.isfinite(pct):
            label = f"{int(num):,}\nof {int(den):,}\n({pct:.2f}%)"
            ax.annotate(
                label,
                xy=(rect.get_x() + rect.get_width() / 2, rect.get_height()),
                xytext=(0, 3),
                textcoords="offset points",
                ha="center", va="bottom",
                fontsize=10
            )

if not prevalence_results.empty:
    # Step label matches this notebook's "Step 9: Generate Plots" heading.
    console.print(Panel("Step 9: Generate and Save Plots", title_align="left", border_style="blue"))

    disease_label = clean_label(config['disease'])
    comorbidity_label = clean_label(config['comorbidity'])
    # These labels must reproduce the "cohort" strings stored in `prevalence_results`,
    # otherwise cohort_map below will not recognise them.
    drug_display_name = config.get("drug_group_name", ', '.join([clean_label(d) for d in config['drugs']]))
    if config.get("drug_group_name"):
        drug_display_name = clean_label(drug_display_name)

    # --- Standardize cohort names for consistent plotting ---
    dynamic_c1 = disease_label
    dynamic_c2 = f"{disease_label} + {comorbidity_label}"
    dynamic_c3 = f"{disease_label} + {drug_display_name}"
    dynamic_c4 = f"{disease_label} + {comorbidity_label} + {drug_display_name}"

    cohort_map = {
        "population": "general population",
        dynamic_c1: "disease",
        dynamic_c2: "disease + comorbidity",
        dynamic_c3: "disease + drug",
        dynamic_c4: "disease + comorbidity + drug",
    }

    COHORT_ORDER = ["general population", "total", "disease", "disease + drug", "disease + comorbidity", "disease + comorbidity + drug"]

    # Bar colors keyed by standardized cohort. When an AE lacks one cohort, the
    # remaining bars keep their own colors instead of shifting by position.
    COHORT_COLORS = {
        "general population": "#e5e5e5",
        "disease": "#e9c46a",
        "disease + drug": "#f4a261",
        "disease + comorbidity": "#f4a261",
        "disease + comorbidity + drug": "#e76f51",
    }
    FALLBACK_COLOR = "#bdbdbd"  # any cohort not listed above (e.g. unmapped labels)

    # Only AEs that also have population-level prevalence can be plotted.
    available_AEs_pop = set(pop_ae_df["adverse_event"].unique())
    requested_AEs = set(prevalence_results["adverse_event"].unique())
    valid_AEs = requested_AEs & available_AEs_pop

    missing_aes = requested_AEs - valid_AEs
    for ae in missing_aes:
        console.print(f"  [bold yellow]WARNING:[/] Population prevalence data not available for '{ae}'. It will be excluded from plots.")

    res_use = prevalence_results[prevalence_results["adverse_event"].isin(valid_AEs)]
    pop_use = pop_ae_df[pop_ae_df["adverse_event"].isin(valid_AEs)]

    results_pop = pd.concat([res_use, pop_use], ignore_index=True)

    # Map dynamic cohort names to standard ones. Unmapped names are kept as-is,
    # and the ordered categorical fixes the bar order.
    results_pop["cohort_standard"] = results_pop["cohort"].map(cohort_map)
    results_pop["cohort_standard"] = results_pop["cohort_standard"].fillna(results_pop["cohort"])
    results_pop["cohort_standard"] = pd.Categorical(results_pop["cohort_standard"], categories=COHORT_ORDER, ordered=True)
    results_pop = results_pop.sort_values(["adverse_event", "cohort_standard"]).reset_index(drop=True)

    plot_dir = save_dir / "barplots"
    plot_dir.mkdir(parents=True, exist_ok=True)

    # The "total" cohort rows are not plotted.
    plot_data = results_pop[~(results_pop["cohort_standard"] == "total")].copy()

    if plot_data.empty:
        console.print("  [bold yellow]WARNING:[/] No adverse events left to plot after filtering; no barplots were generated.")
    else:
        for ae, g in track(plot_data.groupby("adverse_event", sort=True), description="Generating plots..."):
            g = g.sort_values("cohort_standard")
            cohort_names = g["cohort_standard"].astype(str).tolist()
            numerators = g["n_with_AE"].values
            denominators = g["denominator"].values
            percentages = g["prevalence_pct"].astype(float).values
            colors = [COHORT_COLORS.get(c, FALLBACK_COLOR) for c in cohort_names]
            x = np.arange(len(g))

            fig, ax = plt.subplots(figsize=(5, 4))
            try:
                bars = ax.bar(x, percentages, zorder=3, color=colors, edgecolor="black")

                add_value_labels(ax, bars, numerators, denominators, percentages)
                ax.set_xticks(x)
                ax.set_xticklabels([format_cohort_label(c) for c in cohort_names], ha="center")
                ax.set_ylabel("Prevalence (%)", fontsize=12)
                ax.set_xlabel("Cohort", fontsize=12)

                title_text = f"{str(ae).lower()}\nin {disease_label} + {comorbidity_label} + {drug_display_name}"
                ax.set_title(title_text, loc='left', pad=20, fontsize=12)

                ax.grid(axis='y', linestyle=":", linewidth=0.5, zorder=0)
                ax.margins(y=0.25)  # headroom for the value labels above the bars
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)
                fig.tight_layout()

                plot_base_name = slugify(f"{ae}-prevalence-plot")
                fig.savefig(plot_dir / f"{plot_base_name}.png", dpi=300)
                fig.savefig(plot_dir / f"{plot_base_name}.pdf")
            finally:
                # Release the figure even if drawing/saving fails, so the loop
                # does not accumulate open figures.
                plt.close(fig)

        console.print(f"  [bold green]Saved[/bold green] barplots to {plot_dir}.")
else:
    console.print("[yellow]Skipping plotting due to empty prevalence results.[/yellow]")


## Step 10: Testing Area - Compute Number of Visits Per Patient

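This cell computes, for each patient in `regression_data`, the number of outpatient visits and the visits per year within each analysis window in `WINDOWS`. It counts distinct visit days after the index date, excluding ER visits. The database connection `con` and all intermediate tables remain available for further experiments.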


In [None]:
# Use the rich console (like the rest of the notebook): the builtin print would
# show the Panel as an object repr and print markup tags literally.
console.print(Panel("Step 10: Compute Outpatient Visits Per Year (Window-Specific)", title_align="left", border_style="blue"))

# Expose the events parquet as a DuckDB view so it is queried lazily instead of
# being loaded into pandas.
events_file = cohort_dir / f"events_{disease}.parquet"

if not events_file.exists():
    console.print(f"  [bold red]ERROR:[/] Events file not found: {events_file}")
    raise FileNotFoundError(f"Events file not found: {events_file}")

console.print(f"  [green]Creating DuckDB view from:[/green] {events_file.name}")
con.execute(f"CREATE OR REPLACE VIEW events AS SELECT * FROM read_parquet('{events_file.as_posix()}');")

# Counts are aggregated inside DuckDB; only a single summary row is fetched.
events_info = con.execute("""
    SELECT 
        COUNT(*) as total_rows,
        COUNT(DISTINCT patient_id) as unique_patients
    FROM events
""").fetchdf()

console.print(f"  [green]Total events:[/green] {events_info['total_rows'].iloc[0]:,}")
console.print(f"  [green]Unique patients:[/green] {events_info['unique_patients'].iloc[0]:,}")

# Compute outpatient visits per year for each window.
# Only visits inside the follow-up window (index_date, index_date + window] are counted.
console.print("\n[bold]Computing outpatient visits per year for each analysis window:[/bold]")
console.print(f"  [yellow]Windows: {WINDOWS}[/yellow]")
console.print("  [yellow]Filtering for: type='visit' AND description != 'ER visit'[/yellow]")
console.print("  [yellow]Counting visits only within each window: index_date to index_date + window[/yellow]")


def get_window_label(W):
    """Return the label for window ``W`` days ("30d", ...; "any_after_index" for None)."""
    return "any_after_index" if W is None else f"{W}d"


def get_visits_col(W):
    """Return the visit-count column name for window ``W``."""
    return f"visits_{get_window_label(W)}"


def get_visits_per_year_col(W):
    """Return the annualized visit-rate column name for window ``W``."""
    return f"visits_per_year_{get_window_label(W)}"


# Build one correlated COUNT subquery and one annualized-rate expression per window.
visits_select_parts = []
visits_per_year_select_parts = []

for W in WINDOWS:
    win_label = get_window_label(W)
    visits_col = get_visits_col(W)
    visits_per_year_col = get_visits_per_year_col(W)

    # A "visit" is a distinct calendar day with a type='visit' event strictly after
    # index_date. For a finite window, the day must also be no later than index_date + W.
    # NOTE(review): `description != 'ER visit'` also drops rows whose description is NULL
    # (SQL three-valued logic). Confirm that is intended.
    window_upper_bound = (
        ""
        if W is None
        else f"\n              AND e.time_stamp::DATE <= r.index_date + INTERVAL '{W}' DAY"
    )
    visits_select_parts.append(f"""
        -- Window {win_label}
        COALESCE((
            SELECT COUNT(DISTINCT e.time_stamp::DATE)
            FROM events e
            WHERE e.patient_id = r.patient_id
              AND e.type = 'visit'
              AND e.description != 'ER visit'
              AND e.time_stamp::DATE > r.index_date{window_upper_bound}
        ), 0) AS {visits_col}""")

    if W is None:
        # Open-ended window: annualize over the patient's own follow-up period.
        visits_per_year_select_parts.append(f"""
        CASE WHEN follow_up_days > 0 THEN {visits_col} / (follow_up_days / 365.25) ELSE NULL END AS {visits_per_year_col}""")
    else:
        # Fixed window: visits_W / (W / 365.25).
        # NOTE(review): this divides by the full window length even when a patient's
        # follow_up_days < W, which understates their rate. Confirm that is intended
        # (an alternative is LEAST(W, follow_up_days)).
        visits_per_year_select_parts.append(f"""
        CASE WHEN {W} > 0 THEN {visits_col} / ({W}.0 / 365.25) ELSE NULL END AS {visits_per_year_col}""")

# Visits are computed for exactly the patients and index dates in `regression_data`,
# joined to `pop` for each patient's observation_end_date.
console.print("  [yellow]Note:[/yellow] Using existing 'regression_data' table (created in Step 7)")
console.print("  [yellow]This ensures we compute visits for the exact same cohort used in regression[/yellow]")

# Join `pop` directly. A second self-join of regression_data would multiply the
# rows of any patient listed more than once.
visits_query = f"""
    CREATE OR REPLACE TEMP TABLE outpatient_visits_per_year_by_window AS
    SELECT 
        r.patient_id,
        r.index_date,
        p.observation_end_date,
        DATE_DIFF('day', r.index_date, p.observation_end_date) AS follow_up_days,
        {','.join(visits_select_parts)}
    FROM regression_data r
    JOIN pop p ON r.patient_id = p.patient_id
"""

con.execute(visits_query)
console.print("  [green]✓[/green] Visit counts computed for all windows")

# Derive visits per year for each window from the counts above.
visits_per_year_query = f"""
    CREATE OR REPLACE TEMP TABLE outpatient_visits_per_year_final AS
    SELECT 
        patient_id,
        index_date,
        follow_up_days,
        {','.join([get_visits_col(W) for W in WINDOWS])},
        {','.join(visits_per_year_select_parts)}
    FROM outpatient_visits_per_year_by_window
"""

con.execute(visits_per_year_query)
console.print("  [green]✓[/green] Visits per year calculated for all windows")

# Summary statistics are aggregated in DuckDB; only one row per window is fetched.
console.print("\n[bold]Summary statistics (computed in DuckDB, not fetched):[/bold]")
for W in WINDOWS:
    win_label = get_window_label(W)
    col = get_visits_per_year_col(W)

    summary = con.execute(f"""
        SELECT 
            COUNT(*) as n,
            AVG({col}) as mean,
            PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY {col}) as median,
            MIN({col}) as min_val,
            MAX({col}) as max_val
        FROM outpatient_visits_per_year_final
        WHERE {col} IS NOT NULL
    """).fetchdf()

    if len(summary) > 0 and summary['n'].iloc[0] > 0:
        console.print(f"  [cyan]Window {win_label}:[/cyan]")
        console.print(f"    Patients: {summary['n'].iloc[0]:,}")
        console.print(f"    Mean visits/year: {summary['mean'].iloc[0]:.2f}")
        console.print(f"    Median visits/year: {summary['median'].iloc[0]:.2f}")

# List all available columns
available_cols = [get_visits_per_year_col(W) for W in WINDOWS]
console.print("\n[bold green]✓[/bold green] Outpatient visits per year computed for all windows!")
console.print("  - Table 'outpatient_visits_per_year_final' contains visits per year for each window")
console.print(f"  - Columns: {', '.join(available_cols)}")
console.print("  - Ready to join to regression_data for each window-specific analysis")



In [None]:
# Get column names and sample data
columns_info = con.execute("DESCRIBE SELECT * FROM events LIMIT 0;").fetchdf()
console.print(f"  [green]Columns:[/green] {', '.join(columns_info['column_name'].tolist())}")

# Get a small sample to explore structure
console.print("\n[bold]Sample of events data (first 10 rows):[/bold]")
events_sample = con.execute("SELECT * FROM events LIMIT 10;").fetchdf()
display(events_sample)

# Show all unique (type, description) combinations
console.print("\n[bold]All unique (type, description) combinations:[/bold]")
unique_type_desc = con.execute("""
    SELECT 
        type,
        description,
        COUNT(*) as count
    FROM events
    GROUP BY type, description
    ORDER BY type, description
""").fetchdf()

console.print(f"  [green]Total unique combinations:[/green] {len(unique_type_desc):,}")
console.print(f"  [green]Total events:[/green] {unique_type_desc['count'].sum():,}")

# Display the unique combinations
display(unique_type_desc)

# Also show summary by type
console.print("\n[bold]Summary by type:[/bold]")
type_summary = con.execute("""
    SELECT 
        type,
        COUNT(DISTINCT description) as num_descriptions,
        COUNT(*) as total_events
    FROM events
    GROUP BY type
    ORDER BY type
""").fetchdf()
display(type_summary)
