## Stroke Work
<br>Author: Daniel Maina Nderitu<br>
Project: MADIVA<br>
Purpose: Publication-quality outputs<br>
Notes:   Rerun this notebook whenever reviewers ask for tweaks.

#### Bootstrap cell

In [3]:
# =================== BOOTSTRAP CELL ===================
# Standard setup for all notebooks:
#   1. third-party / stdlib imports
#   2. put the project root on sys.path so `src.*` can be imported
#   3. plotting defaults
#   4. project directories (DATA_DIR, OUT_DIR, FIG_DIR) from load_paths()
import os
import sys
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# ========================================================
# 1️⃣ Ensure project root is in Python path
# Adjust this if your notebooks are nested deeper
PROJECT_ROOT = Path.cwd().parents[0]  # assumes notebooks are in a subfolder
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Project-local import; must come after the sys.path tweak above.
from src.config.variables import COVARIATES

# ========================================================
# Optional for warnings and nicer plots
# NOTE: this silences *all* warnings for the rest of the session.
warnings.filterwarnings("ignore")
sns.set(style="whitegrid")

# ========================================================
# 2️⃣ Import helper to load paths
from src.utils.helpers import load_paths

# ========================================================
# 3️⃣ Load paths via the project helper (per the original note, these are
#    read from config.yaml)
paths = load_paths()

# ========================================================
# 4️⃣ Optionally, print paths to confirm
for key, value in paths.items():
    print(f"{key}: {value}")

# ========================================================
# 5️⃣ Directories used throughout the notebook.
# Later cells build file names with the `/` operator, which only works on
# Path objects, so coerce here in case load_paths() hands back plain strings.
DATA_DIR = Path(paths['DATA_DIR'])
OUT_DIR = Path(paths['OUT_DIR'])
FIG_DIR = Path(paths['FIG_DIR'])

# ========================================================

BASE_DIR: D:\APHRC\GoogleDrive_ii\stata_do_files\madiva\stroke_work
DATA_DIR: D:\APHRC\GoogleDrive_ii\stata_do_files\madiva\stroke_work\data
OUT_DIR: D:\APHRC\GoogleDrive_ii\stata_do_files\madiva\stroke_work\model_output
FIG_DIR: D:\APHRC\GoogleDrive_ii\stata_do_files\madiva\stroke_work\visualization


### Import data - from previous step

In [None]:
# Load the processed dataset that the previous step stored as a pickle.
step03_pickle = OUT_DIR / "df_step03_processed.pkl"
df = pd.read_pickle(step03_pickle)

#### Study period plots

In [None]:
import matplotlib.pyplot as plt

# =================================================================================
# Draw one horizontal bar per study (row of `study_periods`), ordered by start date.
# NOTE(review): `study_periods` is not created in this notebook - presumably it is
# carried over from an earlier step; confirm before running this cell on its own.
# =================================================================================
# Sort by study start AND renumber the rows. iterrows() yields index labels, so
# without reset_index the bars would be stacked in the original (unsorted) order.
study_periods_sorted = (
    study_periods.sort_values("study_start").reset_index(drop=True)
)

plt.figure(figsize=(12, 6))

for i, row in study_periods_sorted.iterrows():
    # i is now the 0-based rank by study_start and doubles as the y position
    plt.plot([row['study_start'], row['study_end']],
             [i, i], linewidth=6)
    # label each bar with its source, slightly above the bar's left end
    plt.text(row['study_start'], i+0.1, str(row['source']), fontsize=10)

plt.yticks([])  # y positions are only a stacking order, so hide the ticks
plt.xlabel("Calendar Time")
plt.title("Study Periods by Project (source)")
plt.tight_layout()

# =================================================================================
# Save before show(): the figure is written while it is still the current figure
# =================================================================================
plt.savefig(FIG_DIR / "study_periods_graph_main.png", dpi=300, bbox_inches='tight')

# =================================================================================
# Display
# =================================================================================
plt.show()
# Optional alternative export formats (call before plt.show()):
# plt.savefig("figure_name.png", dpi=300, bbox_inches='tight')
# plt.savefig("figure_name.pdf", bbox_inches='tight')   # optional
# plt.savefig("figure_name.svg", bbox_inches='tight')   # optional

#### Timeline plots

In [None]:
# Plotly Express is used below but is not imported by the bootstrap cell.
import plotly.express as px

# Get unique IDs
unique_ids = df['individual_id'].drop_duplicates()

# Adjust sample size to the number of available individuals
sample_size = min(12, len(unique_ids))

# Take sample safely (fixed seed so the same individuals are drawn on every rerun)
sample_ids = unique_ids.sample(sample_size, random_state=42)

# Filter dataset for those individuals
df_sample = df[df['individual_id'].isin(sample_ids)]

# Plotly Express treats a numeric colour column as continuous, in which case
# color_discrete_map is not applied. Colour by a string label instead.
# Assumes `event` is coded 0/1 with no missing values - astype(int) raises otherwise.
timeline_df = df_sample.assign(
    event_label=df_sample['event'].astype(int).astype(str)
)

# ---------------------------------------------------------------------------
# Create timeline visualization
fig = px.timeline(
    timeline_df,
    x_start="start_date",
    x_end="end_date",
    y="individual_id",
    color="event_label",
    color_discrete_map={"0": "skyblue", "1": "red"},
    category_orders={"event_label": ["0", "1"]},  # stable legend order
    labels={"event_label": "event"},               # keep the original hover label
    title="Observation Timelines with Stroke Events",
)

# Update layout
fig.update_yaxes(title="Individual ID", categoryorder="total ascending")
fig.update_xaxes(title="Date")
fig.update_layout(
    legend_title_text="Stroke Event (1=Yes, 0=No)",
    template="plotly_white",
    height=600,
)

# Save (Plotly)
fig.write_image(
    FIG_DIR / "observation_timelines_Agincourt_Nairobi.png",
    scale=3  # improves resolution for presentations
)
fig.show()

#### Annotated figures

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Compare the IRRs of three count models in one figure: one row per variable,
# three vertically offset points (with CI bars) per row, one per model.
# Each results table is expected to carry the columns Variable, IRR,
# IRR_CI_lower, IRR_CI_upper and, optionally, sig.

# --- prepare a base variable order (use Poisson result as reference) ---
varlist = results_pois.loc[results_pois["Variable"] != "const", "Variable"].tolist()
n_vars = len(varlist)
base_y = np.arange(n_vars)

# vertical offsets so the 3 model points don't overlap
y_offsets = {
    'Poisson': -0.18,
    'Robust Poisson': 0.0,
    'NegBinomial': 0.18
}

colors = {
    'Poisson': '#1f77b4',
    'Robust Poisson': '#ff7f0e',
    'NegBinomial': '#2ca02c'
}

# Align every model's table to varlist once, so row j always means varlist[j]
aligned_results = {
    model_name: df_irr.set_index('Variable').reindex(varlist).reset_index()
    for model_name, df_irr in [
        ('Poisson', results_pois),
        ('Robust Poisson', results_robust),
        ('NegBinomial', results_nb)
    ]
}

# One annotation offset for the whole figure, derived from the CI range of the
# rows that will actually be drawn. Reading plt.xlim() inside the plotting loop
# made the offset depend on draw order (the first point saw the default axis).
ci_low = pd.Series(
    [tbl.loc[tbl['IRR'].notna(), 'IRR_CI_lower'].min() for tbl in aligned_results.values()]
).min()
ci_high = pd.Series(
    [tbl.loc[tbl['IRR'].notna(), 'IRR_CI_upper'].max() for tbl in aligned_results.values()]
).max()
x_span = ci_high - ci_low
# NaN > 0 is False, so an empty or all-missing range also takes the fallback
text_offset = 0.03 * x_span if x_span > 0 else 0.05

plt.figure(figsize=(10, 6))

# loop over models and plot
for model_name, df_aligned in aligned_results.items():
    for j, row in df_aligned.iterrows():
        # skip variables with missing estimates
        if pd.isna(row['IRR']):
            continue

        y = base_y[j] + y_offsets[model_name]

        # compute left/right error for errorbar (asymmetric around the IRR)
        left = row['IRR'] - row['IRR_CI_lower']
        right = row['IRR_CI_upper'] - row['IRR']

        # Plot errorbar and point
        plt.errorbar(
            x=row['IRR'],
            y=y,
            xerr=np.array([[left], [right]]),
            fmt='o',
            capsize=4,
            color=colors[model_name],
            label=None  # legend handled below
        )

        # Add IRR text + significance star to the right of the point.
        # Only append sig when it is a string, so a missing (NaN) marker is
        # not rendered as the literal text "nan".
        irr_text = f'{row["IRR"]:.2f}'
        sig = row.get("sig", "")
        if isinstance(sig, str):
            irr_text += sig
        plt.text(row['IRR'] + text_offset, y, irr_text,
                 fontsize=9, color=colors[model_name], va='center')

# Add a manual legend (one marker per model)
for mn, color in colors.items():
    plt.plot([], [], 'o', color=color, label=mn)
plt.legend(title='Model')

# y-ticks and labels
plt.yticks(base_y, varlist)
plt.axvline(1, color='red', linestyle='--')  # IRR = 1 reference line
plt.xlabel("Incidence Rate Ratio (IRR)")
plt.title("Stroke Incidence – Model Comparison (IRR with 95% CI)")
plt.tight_layout()

# Save and show
plt.savefig(FIG_DIR / 'stroke_model_comparison_annotated_fixed_main.png', dpi=300)

plt.show()
print("Saved: stroke_model_comparison_annotated_fixed_main.png")

#### Export tables

In [None]:
# ---------------------------------------------------------------------------------
# Records per person: how often does each individual_id occur in df?
# ---------------------------------------------------------------------------------
counts = df['individual_id'].value_counts()   # index = individual_id, value = number of rows
counts_summary = counts.describe()            # count / mean / std / min / quartiles / max

print("\nRecords per individual (summary):", counts_summary, sep="\n")

# ---------------------------------------------------------------------------------
# Individuals that appear exactly once
# ---------------------------------------------------------------------------------
n_single = counts.eq(1).sum()
pct_single = (n_single / len(counts)) * 100

print(f"\nIndividuals with a single record: {n_single} ({pct_single:.1f}%)")

# ---------------------------------------------------------------------------------
# Distribution: for each record count X, how many persons have X records
# ---------------------------------------------------------------------------------
print("\nTop frequency counts (number of Agincourt persons with X records):")
freq_table = counts.value_counts().sort_index()   # index = number of records, value = # persons
print(freq_table.iloc[:20])                       # first 20 rows; widen the slice if needed

# ---------------------------------------------------------------------------------
# Same distribution as a two-column dataframe (n_records, n_persons)
# ---------------------------------------------------------------------------------
freq_df = (
    freq_table
    .reset_index(name="n_persons")
    .set_axis(['n_records', 'n_persons'], axis=1)   # fixed names, whatever reset_index produced
)

# ---------------------------------------------------------------------------------
# (Optional) If month offsets are to be used in models, prepare X, y and the
# offset from offset_months, e.g.:
#     offset_for_model = df['offset_months']
#     y_for_model = df['event']
#     X_for_model = df[covariates_present]  # after your usual pre-processing
# ---------------------------------------------------------------------------------

#### Grouped IRR plots

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Single-model IRR plot in which every variable is coloured by a conceptual
# category (demographics, lifestyle, comorbidities, site).

# Keep only rows that have a point estimate and both CI bounds
plot_df = plot_df.dropna(subset=["IRR", "IRR_CI_lower", "IRR_CI_upper"]).copy()

# Conceptual category -> variables belonging to it
variables_by_category = {
    'Demographics': ['sex_binary'],
    'Lifestyle': [
        'alcohol_use',
        'tobacco_use',
        'bmi_category_Normal',
        'bmi_category_Overweight',
        'bmi_category_Obese',
    ],
    'Comorbidities': [
        'obese_status_derived',
        'hpt_status_derived',
        'diab_status_derived',
        'hiv_status_derived',
        'tb_status_derived',
    ],
    'Site Effect': ['site_Nairobi'],
}

# Flattened lookup: variable -> category
category_map = {
    variable: category
    for category, variables in variables_by_category.items()
    for variable in variables
}

# One colour per category (this order is also the legend order)
category_colors = {
    'Demographics': '#1f77b4',   # blue
    'Lifestyle': '#ff7f0e',      # orange
    'Comorbidities': '#2ca02c',  # green
    'Site Effect': '#9467bd'     # purple
}

# Tag each row with its category and colour; unmapped variables fall back to grey
plot_df = plot_df.assign(Category=plot_df['Variable'].map(category_map))
plot_df = plot_df.assign(Color=plot_df['Category'].map(category_colors).fillna('#999999'))

# Order rows by category, then variable, and renumber them 0..n-1
plot_df = plot_df.sort_values(by=['Category', 'Variable']).reset_index(drop=True)

# Asymmetric error bars: distance from the IRR down to the lower bound (row 0)
# and up to the upper bound (row 1)
xerr = np.array([
    plot_df["IRR"] - plot_df["IRR_CI_lower"],
    plot_df["IRR_CI_upper"] - plot_df["IRR"]
])

y_positions = np.arange(len(plot_df))

# --- PLOT ---
plt.figure(figsize=(9, 6))

# Grey confidence intervals underneath ...
plt.errorbar(
    x=plot_df["IRR"],
    y=y_positions,
    xerr=xerr,
    fmt='o',
    capsize=4,
    color='gray',
    ecolor='lightgray',
    elinewidth=1.2,
    zorder=1
)

# ... category-coloured points and IRR labels on top
for idx, row in plot_df.iterrows():   # index was reset above, so idx is the row position
    plt.scatter(row["IRR"], idx, color=row["Color"], s=60, zorder=2)
    sig = row.get("sig")
    irr_text = f'{row["IRR"]:.2f}'
    if isinstance(sig, str):          # append the significance marker only when present
        irr_text += sig
    plt.text(row["IRR"] * 1.05, idx, irr_text, va='center', fontsize=9, color='black')

# IRR = 1 reference line
plt.axvline(1, color='red', linestyle='--', linewidth=1)

# Axis and labels
plt.yticks(y_positions, plot_df["Variable"])
plt.xlabel("Incidence Rate Ratio (IRR)")
plt.title("Negative Binomial Model – Stroke Incidence (Grouped by Category)")

# --- Legend in top-right corner: one proxy marker per category ---
handles = []
for cat, color in category_colors.items():
    handles.append(
        plt.Line2D([0], [0], marker='o', color='w', label=cat,
                   markerfacecolor=color, markersize=8)
    )
plt.legend(handles=handles, title="Category", loc='upper right',
           bbox_to_anchor=(1, 1), frameon=False)

# Layout and aesthetics
plt.grid(axis='x', linestyle='--', alpha=0.3)
plt.tight_layout()

# Save and show
output_path = FIG_DIR / 'stroke_model_nb_grouped_main.png'
plt.savefig(output_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ Saved: {output_path}")

#### End

In [42]:
# Persist the dataframe for the next step as a pickle
# (faster for large data, preserves types).
step04_pickle = OUT_DIR / "df_step04_processed.pkl"
df.to_pickle(step04_pickle)