### Imports

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
import statsmodels.formula.api as smf
import statsmodels.api as sm
import seaborn as sns
import ast
import re
import warnings
from google.colab import drive
from ast import literal_eval
from statsmodels.regression.mixed_linear_model import MixedLM
from collections import Counter
from collections import defaultdict
from scipy.interpolate import UnivariateSpline
from scipy import stats
from scipy.stats import chi2
from scipy.stats import t
from datetime import date
from patsy import dmatrix

# Silence every warning category for the rest of the session.
# NOTE(review): this presumably also hides optimizer/convergence warnings
# raised while fitting the mixed models below — confirm that is intended.
warnings.filterwarnings('ignore')
# Mount Google Drive (Colab) so the CSV files under /content/drive are readable.
drive.mount('/content/drive')
# Never truncate columns or rows when a DataFrame is displayed.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

Mounted at /content/drive


### Load and Prepare Data


In [None]:
# --- Load Data ---
# Four CSV exports read from the mounted Drive folder: two flare-episode
# summaries (onset/end per user), the annotated subjective table and the
# objective table.
summary_symptom_flares = pd.read_csv('/content/drive/My Drive/coreway_ml/Thesis - Mika/summary_symptom_flares.csv')
summary_regular_flares = pd.read_csv('/content/drive/My Drive/coreway_ml/Thesis - Mika/summary_regular_flares.csv')
subjective_flare_annotated = pd.read_csv('/content/drive/My Drive/coreway_ml/Thesis - Mika/subjective_flare_annotated.csv')
objective = pd.read_csv('/content/drive/My Drive/coreway_ml/Thesis - Mika/objective.csv')

# --- Keep relevant columns ---
# A KeyError here means one of the listed columns is missing from its CSV.
summary_symptom_flares = summary_symptom_flares[['user_id', 'date_flare_onset', 'date_flare_end']]
summary_regular_flares = summary_regular_flares[['user_id', 'date_flare_onset', 'date_flare_end']]
subjective_flare_annotated = subjective_flare_annotated[["user_id", "date", "gender", "age", "diagnosis", "symptom_flare", "flare"]]
objective = objective[["user_id", "date", "provider", "REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"]]

# --- Merge datasets ---
# Inner join: only (user_id, date) pairs present in BOTH tables are kept.
# NOTE(review): `date` is still whatever dtype read_csv produced in both
# frames (no parsing happens before the merge) — confirm the two files use
# the same date format, otherwise rows will silently fail to match.
merged = pd.merge(subjective_flare_annotated, objective, on=["user_id", "date"], how="inner")

### Association Analysis

In [None]:
def compute_emm(mixed_model, df, predictor):
    """Compute estimated marginal means (EMMs) of a fitted mixed model.

    For each distinct value of ``df[predictor]`` the predictor column is
    overwritten with that value for every row, the fixed-effects design
    matrix is rebuilt from the model's stored design info, and the
    fixed-effects predictions are averaged over all rows.

    Parameters
    ----------
    mixed_model : fitted results object exposing ``fe_params``,
        ``cov_params()``, ``df_resid`` and ``model.data.design_info``.
    df : DataFrame holding the columns the model formula needs.
    predictor : name of the column whose levels are compared.

    Returns
    -------
    DataFrame with one row per level and the columns ``predictor``,
    ``EMM``, ``SE``, ``CI_lower`` and ``CI_upper`` (95% interval based on a
    t quantile with ``mixed_model.df_resid`` degrees of freedom).
    """

    def _mean_and_se(level):
        # Counterfactual copy in which every row carries this level.
        counterfactual = df.copy()
        counterfactual[predictor] = level

        design = dmatrix(
            mixed_model.model.data.design_info.builder,
            counterfactual,
            return_type="dataframe",
        )

        # Population-averaged fixed-effects prediction.
        level_mean = (design @ mixed_model.fe_params.values).mean()

        # Var(mean prediction) = x̄ᵀ Σ x̄, with Σ restricted to the
        # fixed-effect columns of the parameter covariance.
        sigma = mixed_model.cov_params()
        sigma = sigma.loc[design.columns, design.columns]
        avg_row = design.mean(axis=0).values.reshape(1, -1)
        variance = avg_row @ sigma.values @ avg_row.T
        return level_mean, np.sqrt(variance.item())

    levels = df[predictor].unique()
    per_level = [_mean_and_se(level) for level in levels]
    emms = [pair[0] for pair in per_level]
    ses = [pair[1] for pair in per_level]

    # Symmetric 95% interval around each EMM.
    half_width = t.ppf(0.975, mixed_model.df_resid) * np.array(ses)
    centre = np.array(emms)

    return pd.DataFrame({
        predictor: levels,
        "EMM": emms,
        "SE": ses,
        "CI_lower": centre - half_width,
        "CI_upper": centre + half_width,
    })

# --- Define metrics & predictors ---
sleep_metrics = ["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"]
flare_predictors = ["symptom_flare", "flare"]

# --- Run models ---
# For every (sleep metric, flare indicator) pair: fit a random-intercept
# model and a second model with an extra variance component on the
# predictor, compare them, keep one, and store its marginal means.
results = {}

for metric in sleep_metrics:
    for predictor in flare_predictors:
        # Complete cases only, restricted to the columns the formula uses.
        model_columns = ["user_id", metric, "gender", "age", predictor, "provider"]
        df_model = merged[model_columns].dropna()
        formula = f"{metric} ~ gender + age + {predictor} + provider"

        # Homogeneous variance model (ML rather than REML so that the
        # log-likelihoods / AICs of the two fits can be compared).
        res1 = smf.mixedlm(
            formula, data=df_model, groups=df_model["user_id"]
        ).fit(reml=False)

        # Heterogeneous variance model: same fixed effects plus a variance
        # component keyed on the predictor.
        res2 = smf.mixedlm(
            formula,
            data=df_model,
            groups=df_model["user_id"],
            vc_formula={predictor: f"C({predictor})"},
        ).fit(reml=False)

        # Likelihood ratio test of the extended model against the base one.
        df_diff = 1
        lr_stat = 2 * (res2.llf - res1.llf)
        p_value = chi2.sf(lr_stat, df_diff)

        # The extended model wins on a significant LRT or on a lower AIC.
        prefer_heterogeneous = (p_value < 0.05) or (res2.aic < res1.aic)
        if prefer_heterogeneous:
            final_model, model_type = res2, "heterogeneous"
        else:
            final_model, model_type = res1, "homogeneous"

        marginal_means = compute_emm(final_model, df_model, predictor)

        # Keyed as "<metric> ~ <predictor>"; later cells split on " ~ ".
        results[f"{metric} ~ {predictor}"] = {
            "final_model_type": model_type,
            "AIC_homogeneous": res1.aic,
            "AIC_heterogeneous": res2.aic,
            "LRT_stat": lr_stat,
            "LRT_df": df_diff,
            "LRT_p": p_value,
            "summary": final_model.summary(),
            "marginal_means": marginal_means,
            "final_model": final_model,
        }

# --- Print results ---
# One report per fitted model: selection diagnostics, the model summary
# and the marginal-means table.
separator = "=" * 100

for key, res in results.items():
    aic_line = (
        f"AIC (homogeneous): {res['AIC_homogeneous']:.2f}, "
        f"AIC (heterogeneous): {res['AIC_heterogeneous']:.2f}"
    )
    lrt_line = f"LRT: stat={res['LRT_stat']:.3f}, df={res['LRT_df']}, p={res['LRT_p']:.4f}"

    print(separator)
    print("\n")
    print(f"Model: {key}")
    print(f"Final model chosen: {res['final_model_type']}")
    print(aic_line)
    print(lrt_line)
    print(res["summary"])
    print("\n--- Marginal means (95% CI) ---")
    print(res["marginal_means"])
    print("\n")

#  --- EMM summary table ---
# One row per fitted model: the marginal mean "EMM (SE)" of each predictor
# level as a formatted string (column named after the level), plus the
# p-value of the predictor's fixed-effect coefficient.
summary_rows = []

for key, res in results.items():
    metric, predictor = key.split(" ~ ")
    fitted = res["final_model"]

    # Extract marginal means, keyed by the stringified level (e.g. "True").
    mm = res["marginal_means"]
    mm_dict = {f"{row[predictor]}": f"{row['EMM']:.3f} ({row['SE']:.3f})"
               for _, row in mm.iterrows()}

    # Extract the p-value for the predictor from FIXED-EFFECT terms only.
    # `pvalues` also carries the random-effect / variance-component
    # parameters; the heterogeneous model names its variance component after
    # the predictor, so a bare prefix match on `pvalues.index` would pick
    # that parameter up as well and distort the reported p-value.
    fixed_effect_names = set(fitted.fe_params.index)
    predictor_terms = [name for name in fitted.pvalues.index
                       if name in fixed_effect_names and name.startswith(predictor)]

    if len(predictor_terms) == 1:
        pval = fitted.pvalues[predictor_terms[0]]
    elif len(predictor_terms) > 1:
        # Placeholder for multi-level predictors: report the LARGEST
        # per-level p-value (the conservative choice). A joint test of all
        # levels would be the proper replacement.
        pval = fitted.pvalues[predictor_terms].max()
    else:
        # Predictor has no fixed-effect coefficient in this model.
        pval = None

    row = {
        "metric": metric,
        "predictor": predictor,
        **mm_dict,
        "p_value": pval
    }
    summary_rows.append(row)

summary_df = pd.DataFrame(summary_rows)

summary_df

In [None]:
# A signed decimal number, optionally with an exponent: "12", "12.5", ".5",
# "5.", "1e-3", "-2.5E+4".
_NUMBER_PATTERN = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
_MEAN_SE_RE = re.compile(
    rf"\s*({_NUMBER_PATTERN})\s*\(\s*({_NUMBER_PATTERN})\s*\)\s*"
)


def parse_mean_se(cell):
    """
    Parse strings like '20.109 (0.503)' -> (20.109, 0.503).

    Parameters
    ----------
    cell : a "mean (se)" string, a bare number (or numeric string), or a
        missing value.

    Returns
    -------
    (mean, se) as floats. A missing or unparseable cell yields
    ``(nan, nan)``; a bare number yields ``(number, nan)``.

    Both numbers may carry a sign and an exponent ('1e-3 (2e-4)'), so values
    formatted in scientific notation are no longer silently turned into NaN.
    """
    if pd.isna(cell):
        return np.nan, np.nan
    m = _MEAN_SE_RE.match(str(cell))
    if m:
        return float(m.group(1)), float(m.group(2))
    # Not in "mean (se)" form: accept a plain number and report no SE.
    try:
        return float(cell), np.nan
    except (TypeError, ValueError):
        # float() raises only these for non-numeric input.
        return np.nan, np.nan

def stars_from_p(p):
    """Return the significance marker for a p-value.

    ``"**"`` below 0.01, ``"*"`` below 0.05, otherwise an empty string
    (which is also what a NaN p-value produces, since it compares false
    against every threshold).
    """
    for threshold, marker in ((0.01, "**"), (0.05, "*")):
        if p < threshold:
            return marker
    return ""

# Marker colours for the two states drawn in each panel.
CLR_FLARE, CLR_REMIS = "#e41a1c", "#f4a3b4"

# Column order of the figure (one panel column per metric).
metric_order = ["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"]

# Panel title per metric, listed in the same order as `metric_order`.
metric_titles = dict(zip(
    metric_order,
    ["Rem", "Deep", "Light", "Sleep Efficiency", "Time Asleep"],
))

# X-axis label per metric: the three sleep-stage shares use one label, the
# remaining two metrics use the other.
xlabels = {
    **dict.fromkeys(["REM_pct", "deep_pct", "light_pct"], "% of Sleep (M ± SEM)"),
    **dict.fromkeys(["sleep_eff", "dur_asleep"], "Sleep Measure (M ± SEM)"),
}

# ----------------------------
# Prepare plotting data
# ----------------------------
# Work on a copy so the summary table itself stays untouched.
tbl = summary_df.copy()

# Turn the formatted "mean (se)" strings in the "False" / "True" columns
# back into numeric mean and standard-error columns, one dict per row.
parsed = []
for _, row in tbl.iterrows():
    (m_false, se_false), (m_true, se_true) = (
        parse_mean_se(row["False"]),
        parse_mean_se(row["True"]),
    )
    parsed.append(dict(
        metric=row["metric"],
        predictor=row["predictor"],
        mean_false=m_false,
        se_false=se_false,
        mean_true=m_true,
        se_true=se_true,
        p_value=float(row["p_value"]),
    ))
tbl_parsed = pd.DataFrame(parsed)

# Convenience accessors for each predictor
def row_for(metric, predictor):
    """Return the first row of ``tbl_parsed`` for this metric/predictor pair.

    Raises ``IndexError`` when no row matches.
    """
    is_metric = tbl_parsed["metric"] == metric
    is_predictor = tbl_parsed["predictor"] == predictor
    return tbl_parsed.loc[is_metric & is_predictor].iloc[0]

# ----------------------------
# Plot
# ----------------------------
# 2 x 5 grid: one column per sleep metric, one row per flare definition.
# Each panel shows the marginal mean ± SE for the True and False level.
# NOTE(review): the wspace given here appears to be superseded by the
# subplots_adjust(wspace=0.35) call at the end of this cell — confirm which
# spacing is intended.
fig, axes = plt.subplots(2, 5, figsize=(16, 3.2), sharex='col', gridspec_kw=dict(wspace=0.1, hspace=0.35))

# (predictor column, human-readable label) for the top and bottom row.
row_info = [("flare", "IBD Inflammatory Flare"), ("symptom_flare", "IBD Symptom Flare")]

for col_idx, metric in enumerate(metric_order):
    # Compute shared x-lims across both rows for this metric
    mins, maxs = [], []
    for predictor, _ylabel in row_info:
        r = row_for(metric, predictor)
        # Extremes of both error bars (mean ± SE) for this predictor.
        vals = [
            r["mean_false"] - r["se_false"],
            r["mean_true"]  - r["se_true"],
            r["mean_false"] + r["se_false"],
            r["mean_true"]  + r["se_true"],
        ]
        mins.append(np.nanmin(vals))
        maxs.append(np.nanmax(vals))
    xmin = np.nanmin(mins)
    xmax = np.nanmax(maxs)
    xrng = xmax - xmin
    # 8% padding on each side; fixed 0.5 when the range is degenerate.
    pad = 0.08 * xrng if xrng > 0 else 0.5
    xmin -= pad
    xmax += pad

    # NOTE(review): `ylab` is unpacked but never used; the row labels set
    # after this loop are hard-coded strings instead.
    for row_idx, (predictor, ylab) in enumerate(row_info):
        ax = axes[row_idx, col_idx]
        r = row_for(metric, predictor)

        # Two slightly offset y positions so the two dots don't overlap
        y_flare   = 0.58
        y_remis   = 0.42

        # Remission (False)
        ax.errorbar(
            r["mean_false"], y_remis,
            xerr=r["se_false"],
            fmt='o', ms=6, capsize=6, elinewidth=2, linewidth=2,
            color=CLR_REMIS, mec=CLR_REMIS
        )
        # Flare (True)
        ax.errorbar(
            r["mean_true"], y_flare,
            xerr=r["se_true"],
            fmt='o', ms=6, capsize=6, elinewidth=2, linewidth=2,
            color=CLR_FLARE, mec=CLR_FLARE
        )

        # Aesthetic tweaks
        ax.set_ylim(0.2, 0.8)
        ax.set_yticks([])  # hide per-panel y ticks; we add row labels on left only
        ax.set_xlim(xmin, xmax)
        ax.grid(axis='x', alpha=0.25)

        # Add significance stars (top-right)
        sig = stars_from_p(r["p_value"])
        if sig:
            # NOTE: this rebinds the column-level xmin/xmax that the next
            # row reuses; the limits were just set from those same values,
            # so the numbers read back are expected to be unchanged.
            xmin, xmax = ax.get_xlim()
            ymin, ymax = ax.get_ylim()

            # Anchor the text 5% in from the right edge, 10% down from the top.
            x_star = xmax - 0.05 * (xmax - xmin)
            y_star = ymax - 0.10 * (ymax - ymin)

            ax.text(x_star, y_star, sig, ha='right', va='top', fontsize=12, color='k')

        # Titles on top row only
        if row_idx == 0:
            ax.set_title(metric_titles.get(metric, metric), fontsize=14, pad=6)

        # Bottom row x-labels
        if row_idx == 1:
            ax.set_xlabel(xlabels[metric], fontsize=10)

# Row labels on the left-most column
# NOTE(review): these strings differ from both the predictor names and the
# labels in `row_info` ("IBD Inflammatory Flare" / "IBD Symptom Flare") —
# confirm which text should appear on the figure.
axes[0, 0].set_ylabel("rate_as_flare", fontsize=10, labelpad=10)
axes[1, 0].set_ylabel("symptom_deg", fontsize=10, labelpad=10)

# Tight layout with a little extra left margin for row labels
plt.subplots_adjust(left=0.10, right=0.98, top=0.88, bottom=0.20, wspace=0.35)

plt.show()

#### Physiological Metrics During Periods of Flares Compared with Periods of Remission (Restricted to Users that Contribute Both) - Welch's t-Test

In [None]:
# Setup
physiological_metrics = ["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"]
groups = ["CD", "UC", "IBD"]
flare_predictors = ["symptom_flare", "flare"]

# Column layout of one result row; passing it to the DataFrame constructor
# keeps the columns present even when no comparison could be computed.
welch_result_columns = [
    "flare_col", "feature", "diagnosis",
    "flare_mean", "flare_std", "flare_n",
    "rem_mean", "rem_std", "rem_n",
    "diff", "ci_low", "ci_high", "p_value",
]


# Significance stars (defined once, not once per loop iteration)
def significance_stars(p):
    """Return '***', '**', '*' or '' for p < 0.001, < 0.01, < 0.05, otherwise.

    A missing p-value (NaN / None) yields ''.
    """
    if pd.isna(p):
        return ""
    if p < 0.001:
        return "***"
    elif p < 0.01:
        return "**"
    elif p < 0.05:
        return "*"
    else:
        return ""


all_results = []

# Compute differences & statistical tests
for flare_col in flare_predictors:
    results = []

    for feature in physiological_metrics:
        for diag in groups:
            # Subset depending on group ("IBD" pools CD and UC)
            if diag == "IBD":
                relevant = merged[merged["diagnosis"].isin(["CD", "UC"])]
            else:
                relevant = merged[merged["diagnosis"] == diag]

            # Only keep users with both flare and remission data
            users_with_both = relevant.groupby("user_id")[flare_col].nunique()
            users_with_both = users_with_both[users_with_both == 2].index
            from_paired_user = relevant["user_id"].isin(users_with_both)

            df_flare = relevant[from_paired_user & (relevant[flare_col] == True)]
            df_remission = relevant[from_paired_user & (relevant[flare_col] == False)]

            # Skip if no overlap
            if df_flare.empty or df_remission.empty:
                continue

            # Non-missing observations of this feature in each state
            flare_values = df_flare[feature].dropna()
            remission_values = df_remission[feature].dropna()

            # Summary stats
            flare_mean, flare_std, flare_n = (flare_values.mean(), flare_values.std(), flare_values.count())
            rem_mean, rem_std, rem_n = (remission_values.mean(), remission_values.std(), remission_values.count())

            # Difference
            diff = flare_mean - rem_mean

            # Welch’s t-test (unequal variances) on the row-level observations
            t_stat, p_val = stats.ttest_ind(flare_values, remission_values, equal_var=False)

            # Confidence interval of the mean difference.
            # The comparison object gets its own name: the original bound it
            # to `cm`, which overwrote the `matplotlib.cm` module imported
            # under that name at the top of the notebook.
            mean_comparison = sm.stats.CompareMeans.from_data(flare_values, remission_values)
            ci_low, ci_high = mean_comparison.tconfint_diff(usevar='unequal')

            results.append({
                "flare_col": flare_col,
                "feature": feature,
                "diagnosis": diag,
                "flare_mean": flare_mean,
                "flare_std": flare_std,
                "flare_n": flare_n,
                "rem_mean": rem_mean,
                "rem_std": rem_std,
                "rem_n": rem_n,
                "diff": diff,
                "ci_low": ci_low,
                "ci_high": ci_high,
                "p_value": p_val
            })

    results_df = pd.DataFrame(results, columns=welch_result_columns)
    results_df["stars"] = results_df["p_value"].apply(significance_stars)

    all_results.append(results_df)

# Combine results from both flare predictors
final_results_df = pd.concat(all_results, ignore_index=True)

# Display results
pd.set_option("display.float_format", "{:.3f}".format)
display_df = final_results_df[["flare_col", "feature", "diagnosis", "flare_mean", "rem_mean",  "flare_std", "flare_n", "rem_n", "rem_std", "p_value", "stars"]]

display_df

#### Physiological Metrics During Periods of Flares Compared with Periods of Remission (Restricted to Users that Contribute Both) - Paired t-Test

In [None]:
physiological_metrics = ["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"]
groups = ["CD", "UC", "IBD"]
flare_predictors = ["symptom_flare", "flare"]


# Significance stars (handle NaN p-values)
def significance_stars(p):
    """Return '***', '**', '*' or '' for p < 0.001, < 0.01, < 0.05, otherwise.

    A missing p-value yields ''.
    """
    if pd.isna(p):
        return ""
    for cutoff, marker in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p < cutoff:
            return marker
    return ""


all_results = []

for flare_col in flare_predictors:
    results = []

    for feature in physiological_metrics:
        for diag in groups:
            # "IBD" pools both diagnoses; otherwise keep the single diagnosis.
            diagnoses = ["CD", "UC"] if diag == "IBD" else [diag]
            relevant = merged[merged["diagnosis"].isin(diagnoses)]

            # Nothing recorded for this group: no result row at all.
            if relevant.empty:
                continue

            # Per-user mean of the feature in each state; after unstacking,
            # rows are users and columns are the states that occur.
            df_user = relevant.groupby(["user_id", flare_col])[feature].mean().unstack()

            # Descriptive statistics per state over ALL users that have the
            # state, whether or not they can be paired.
            state_summary = {}
            for state in (True, False):
                if state in df_user.columns:
                    state_vals = df_user[state]
                else:
                    state_vals = pd.Series([], dtype=float)
                state_summary[state] = (
                    state_vals.mean() if not state_vals.empty else np.nan,
                    state_vals.std(ddof=1) if len(state_vals) > 1 else np.nan,
                    int(state_vals.count()) if not state_vals.empty else 0,
                )
            flare_mean, flare_std, flare_n = state_summary[True]
            rem_mean, rem_std, rem_n = state_summary[False]

            # Paired users: a non-missing per-user mean in both states.
            if {True, False}.issubset(df_user.columns):
                df_pairs = df_user.dropna(subset=[True, False])
            else:
                df_pairs = pd.DataFrame(columns=[True, False])
            paired_n = len(df_pairs)

            # Defaults used when there are too few pairs.
            diff = p_val = ci_low = ci_high = np.nan

            # A mean difference needs at least one pair.
            if paired_n >= 1:
                diffs = df_pairs[True] - df_pairs[False]
                diff = diffs.mean()

            # The paired t-test and its 95% CI (t distribution, df = n - 1)
            # need at least two pairs.
            if paired_n >= 2:
                t_stat, p_val = stats.ttest_rel(df_pairs[True], df_pairs[False], nan_policy="omit")
                se_diff = diffs.std(ddof=1) / np.sqrt(paired_n)
                margin = stats.t.ppf(0.975, paired_n - 1) * se_diff
                ci_low, ci_high = diff - margin, diff + margin

            results.append(dict(
                flare_col=flare_col,
                feature=feature,
                diagnosis=diag,
                flare_mean=flare_mean,
                flare_std=flare_std,
                flare_n=flare_n,
                rem_mean=rem_mean,
                rem_std=rem_std,
                rem_n=rem_n,
                diff=diff,
                ci_low=ci_low,
                ci_high=ci_high,
                p_value=p_val,
                paired_n=paired_n,
            ))

    results_df = pd.DataFrame(results)
    results_df["stars"] = results_df["p_value"].apply(significance_stars)
    all_results.append(results_df)

# Combine results from both flare predictors
final_results_df = pd.concat(all_results, ignore_index=True)

# Display results
pd.set_option("display.float_format", "{:.3f}".format)
shown_columns = [
    "flare_col", "feature", "diagnosis", "paired_n",
    "flare_mean", "rem_mean",
    "flare_std", "rem_std", "p_value", "stars",
]
display_df = final_results_df[shown_columns]

display_df

### Causal Inference Analysis

#### Trajectories of Sleep Metrics Around Flares

##### Aggregated

In [None]:
def _collect_window_values(user_obj, anchor, offsets, metrics, flare_means):
    """Collect one user's flare-normalised daily values on one side of a flare.

    For every day offset ``d`` in ``offsets`` the rows of ``user_obj`` dated
    ``anchor + d`` days are averaged per metric and the flare mean of that
    metric is subtracted.

    Returns ``(values, days_with_data)``: ``values[metric][d]`` is a list
    holding the normalised value (absent when the day's value or the flare
    mean is NaN) and ``days_with_data`` counts offsets that had any row.
    """
    values = {m: defaultdict(list) for m in metrics}
    days_with_data = 0
    for d in offsets:
        day_rows = user_obj[user_obj["date"] == anchor + pd.Timedelta(days=d)]
        if day_rows.empty:
            continue
        days_with_data += 1
        for m in metrics:
            v = day_rows[m].mean()
            if pd.notna(v) and pd.notna(flare_means[m]):
                values[m][d].append(v - flare_means[m])
    return values, days_with_data


def plot_metrics_around_flares(
    flare_df,
    objective,
    metrics,
    pre_days=45,
    post_days=45,
    ncols=5,
    poly_degree=3,
    min_days_threshold=7,
    method="AND",
    raw_scatter_jitter=0.12,
    raw_scatter_size=12,
    raw_scatter_alpha=0.75,
    rng_seed=42
):
    """Plot metric trajectories before and after flares, aligned on the flare.

    Each flare is collapsed to a single point at relative day 0. Days
    ``-pre_days..-1`` are counted back from the onset date and days
    ``1..post_days`` forward from the end date. Every value is expressed as
    the difference from that flare's own in-flare mean of the metric, so
    day 0 is always 0.

    A flare is included only if it has at least one objective row between
    onset and end (inclusive) and enough days with data around it: with
    ``method="AND"`` both the pre and the post window need at least
    ``min_days_threshold`` days with data, with ``method="OR"`` one of them
    suffices.

    Per metric, one panel shows all normalised values as a jittered grey
    scatter, a polynomial of degree ``poly_degree`` fitted to the per-day
    means, and a dashed line at day 0. A one-line count summary per metric
    is printed.

    Parameters
    ----------
    flare_df : DataFrame with ``user_id``, ``date_flare_onset`` and
        ``date_flare_end``.
    objective : DataFrame with ``user_id``, ``date`` and the metric columns.
    metrics : metric column name, or list of names.
    pre_days, post_days : window lengths in days.
    ncols : number of panel columns in the figure grid.
    poly_degree : degree of the fitted polynomial; no curve is drawn when
        fewer than ``poly_degree + 1`` relative days have a mean.
    min_days_threshold, method : inclusion rule described above.
    raw_scatter_jitter, raw_scatter_size, raw_scatter_alpha : appearance of
        the raw scatter (horizontal jitter half-width, marker size, alpha).
    rng_seed : seed of the jitter random generator.

    Returns
    -------
    dict mapping each metric to ``coeffs`` (polynomial coefficients or
    ``None``), ``num_flares``, ``avg_flare_length``,
    ``avg_points_within_flare`` and ``raw_counts`` (pre / flare / post).

    Raises
    ------
    ValueError
        If ``method`` is neither 'AND' nor 'OR' (case-insensitive); raised
        when the first flare with in-flare data is evaluated.
    """
    # Accept a single metric name, as the sibling flare helpers do; without
    # this a string would be iterated character by character.
    if isinstance(metrics, str):
        metrics = [metrics]

    # ---------- Display names + font sizes ----------
    metric_display = {
        "REM_pct": "REM Sleep (%)",
        "deep_pct": "Deep Sleep (%)",
        "light_pct": "Light Sleep (%)",
        "sleep_eff": "Sleep Efficiency (%)",
        "dur_asleep": "Total Time Asleep (h)",
    }
    title_fontsize = 18
    label_fontsize = 15
    # ------------------------------------------------

    # Copy + ensure datetime (the caller's frames are left untouched)
    flare = flare_df.copy()
    obj_daily = objective.copy()
    flare["date_flare_onset"] = pd.to_datetime(flare["date_flare_onset"])
    flare["date_flare_end"] = pd.to_datetime(flare["date_flare_end"])
    obj_daily["date"] = pd.to_datetime(obj_daily["date"])

    # relative index for x-axis
    rel_index = list(range(-pre_days, 0)) + [0] + list(range(1, post_days + 1))
    mode = method.upper()

    # storage (included flares only)
    metric_values = {m: defaultdict(list) for m in metrics}
    contributing_flares = {m: set() for m in metrics}
    flare_lengths = {m: [] for m in metrics}
    datapoints_within_flare = {m: [] for m in metrics}
    raw_points_pre = {m: 0 for m in metrics}
    raw_points_post = {m: 0 for m in metrics}

    # iterate over flares
    for flare_idx, row in flare.iterrows():
        uid, onset, end = row["user_id"], row["date_flare_onset"], row["date_flare_end"]

        # Filter to this user once; every per-day lookup below then scans
        # only the user's rows instead of the whole objective table.
        user_obj = obj_daily[obj_daily["user_id"] == uid]

        # --- subject's mean during the flare, used for normalization ---
        in_flare = (user_obj["date"] >= onset) & (user_obj["date"] <= end)
        flare_vals = user_obj.loc[in_flare, metrics]

        # Rule 1: skip if 0 datapoints during flare
        if flare_vals.empty:
            continue
        flare_means = flare_vals.mean().to_dict()
        rows_in_flare = int(in_flare.sum())

        # Buffer pre/post values; they are merged only if the flare passes.
        pre_values, pre_points_count = _collect_window_values(
            user_obj, onset, range(-pre_days, 0), metrics, flare_means
        )
        post_values, post_points_count = _collect_window_values(
            user_obj, end, range(1, post_days + 1), metrics, flare_means
        )

        # Rule 2: flexible datapoint threshold (decide BEFORE merging)
        if mode == "AND":
            if (pre_points_count < min_days_threshold) or (post_points_count < min_days_threshold):
                continue
        elif mode == "OR":
            if (pre_points_count < min_days_threshold) and (post_points_count < min_days_threshold):
                continue
        else:
            raise ValueError("method must be 'AND' or 'OR'")

        # Flare passed thresholds → merge buffered data and add day 0
        flare_len = (end - onset).days + 1
        for m in metrics:
            for window, counter in ((pre_values, raw_points_pre), (post_values, raw_points_post)):
                for d, lst in window[m].items():
                    metric_values[m][d].extend(lst)
                    counter[m] += len(lst)

            # A metric contributes exactly when its in-flare mean exists:
            # buffered values are only created in that case, and the flare
            # then collapses to 0 (normalized to its own mean) at day 0.
            if pd.notna(flare_means[m]):
                metric_values[m][0].append(0.0)
                contributing_flares[m].add(flare_idx)
                flare_lengths[m].append(flare_len)
                datapoints_within_flare[m].append(rows_in_flare)

    # setup subplot grid
    nrows = int(np.ceil(len(metrics) / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4 * nrows), squeeze=False)

    rng = np.random.default_rng(rng_seed)
    results = {}

    for i, m in enumerate(metrics):
        ax = axes[i // ncols, i % ncols]

        # aggregate per relative day (NaN where no flare contributed)
        x = np.array(rel_index)
        y = np.array([
            np.mean(metric_values[m][d]) if len(metric_values[m].get(d, [])) > 0 else np.nan
            for d in rel_index
        ])

        # ---------- RAW POINTS: scatter all contributing values ----------
        # A tiny horizontal jitter keeps overlapping points visible.
        raw_x, raw_y = [], []
        for d in rel_index:
            vals = metric_values[m].get(d, [])
            if len(vals) == 0:
                continue
            jitter = rng.uniform(-raw_scatter_jitter, raw_scatter_jitter, size=len(vals))
            raw_x.extend(d + jitter)
            raw_y.extend(vals)
        if len(raw_x) > 0:
            ax.scatter(
                raw_x,
                raw_y,
                s=raw_scatter_size,
                c="lightgrey",
                alpha=raw_scatter_alpha,
                edgecolors="none"
            )
        # -----------------------------------------------------------------

        # polynomial regression fit (to per-day means)
        valid = ~np.isnan(y)
        coeffs = np.polyfit(x[valid], y[valid], deg=poly_degree) if valid.sum() >= poly_degree + 1 else None

        if coeffs is not None:
            xsmooth = np.linspace(x.min(), x.max(), 400)
            ysmooth = np.polyval(coeffs, xsmooth)
            ax.plot(xsmooth, ysmooth, color="blue", linewidth=2)

        ax.axvline(0, linestyle="--", color="red", linewidth=0.8)

        # ---------- Titles & Labels ----------
        display_name = metric_display.get(m, m)
        ax.set_title(display_name, fontsize=title_fontsize)
        ax.set_xlabel("Days relative to flare (0 = flare mean)", fontsize=label_fontsize)
        ax.set_ylabel("Δ " + display_name, fontsize=label_fontsize)
        # ------------------------------------

        ax.grid(True, linestyle=":", linewidth=0.5)

        # counts (raw, not collapsed) — INCLUDED flares only
        pre_points = raw_points_pre[m]
        flare_points = int(np.sum(datapoints_within_flare[m])) if datapoints_within_flare[m] else 0
        post_points = raw_points_post[m]
        total_points = pre_points + flare_points + post_points
        num_flares = len(contributing_flares[m])

        avg_len = np.mean(flare_lengths[m]) if flare_lengths[m] else np.nan
        avg_points_within = np.mean(datapoints_within_flare[m]) if datapoints_within_flare[m] else np.nan

        print(
            f"[{m}] {num_flares} contributing flares | "
            f"avg length: {avg_len:.2f} days | "
            f"avg datapoints within flare: {avg_points_within:.2f} | "
            f"pre: {pre_points} pts, flare: {flare_points} pts, post: {post_points} pts "
            f"(total={total_points})"
        )

        results[m] = {
            "coeffs": coeffs,
            "num_flares": num_flares,
            "avg_flare_length": avg_len,
            "avg_points_within_flare": avg_points_within,
            "raw_counts": {"pre": pre_points, "flare": flare_points, "post": post_points},
        }

    # hide unused axes
    for j in range(len(metrics), nrows * ncols):
        axes[j // ncols, j % ncols].axis("off")

    fig.tight_layout()
    plt.show()

    return results

In [None]:
# Aggregated trajectories around the episodes in `summary_regular_flares`;
# a flare needs >= 10 days with data both before onset and after the end.
results = plot_metrics_around_flares(summary_regular_flares, objective, metrics=["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"], min_days_threshold=10, method="AND")

In [None]:
# Same analysis for the episodes in `summary_symptom_flares`.
# NOTE: this rebinds `results`, so any earlier value of that name is lost.
results = plot_metrics_around_flares(summary_symptom_flares, objective, metrics=["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"], min_days_threshold=10, method="AND")

##### Individual

In [None]:
def list_plottable_flares(
    flare_df,
    objective,
    metrics,
    pre_days=45,
    post_days=45,
    min_days_threshold=7,
    method="OR",
):
    """Return the index labels of flares with enough data to be plotted.

    A flare qualifies as soon as ANY of the given metrics satisfies both
    rules:

    1. the user has at least one non-NaN daily value of the metric between
       the flare onset and end (inclusive);
    2. enough days with a daily row exist around the flare — counted over
       ``pre_days`` days before the onset and ``post_days`` days after the
       end. With ``method="AND"`` both counts must reach
       ``min_days_threshold``; with ``method="OR"`` one of them suffices.

    Parameters
    ----------
    flare_df : DataFrame with ``user_id``, ``date_flare_onset`` and
        ``date_flare_end``.
    objective : DataFrame with ``user_id``, ``date`` and the metric columns.
    metrics : metric column name, or list of names.
    pre_days, post_days : window lengths in days.
    min_days_threshold, method : threshold rule described above.

    Returns
    -------
    list of index labels of ``flare_df``, in row order.

    Raises
    ------
    ValueError
        If ``method`` is neither 'AND' nor 'OR' (case-insensitive); raised
        when the first flare/metric with in-flare data is evaluated.
    """
    # normalize inputs
    if isinstance(metrics, str):
        metrics = [metrics]

    flare = flare_df.copy()
    obj = objective.copy()

    flare["date_flare_onset"] = pd.to_datetime(flare["date_flare_onset"])
    flare["date_flare_end"] = pd.to_datetime(flare["date_flare_end"])
    obj["date"] = pd.to_datetime(obj["date"])

    # Daily mean per (user, date) for each metric; values that cannot be
    # read as numbers are treated as missing.
    obj_daily_all = {
        metric: (
            obj.assign(**{metric: pd.to_numeric(obj[metric], errors="coerce")})
            .groupby(["user_id", "date"], as_index=False)[metric]
            .mean()
        )
        for metric in metrics
    }

    pre_offsets = range(-pre_days, 0)
    post_offsets = range(1, post_days + 1)
    plottable_indices = []

    # Check each flare once
    for flare_idx, row in flare.iterrows():
        uid, onset, end = row["user_id"], row["date_flare_onset"], row["date_flare_end"]

        for metric in metrics:
            obj_daily = obj_daily_all[metric]

            # Restrict to this user once instead of rescanning the whole
            # table for every relative day.
            user_rows = obj_daily[obj_daily["user_id"] == uid]

            # Rule 1: must have usable data during the flare. Once this
            # holds, the collapsed day-0 point exists, so the flare always
            # has at least one point to draw.
            in_flare = (user_rows["date"] >= onset) & (user_rows["date"] <= end)
            flare_vals = user_rows.loc[in_flare, metric]
            if flare_vals.empty or pd.isna(flare_vals.mean()):
                continue

            # Count pre/post days that have a daily row (exact date match).
            user_dates = set(user_rows["date"])
            pre_points_count = sum(
                (onset + pd.Timedelta(days=d)) in user_dates for d in pre_offsets
            )
            post_points_count = sum(
                (end + pd.Timedelta(days=d)) in user_dates for d in post_offsets
            )

            # Rule 2: threshold logic
            if method.upper() == "AND":
                if (pre_points_count < min_days_threshold) or (post_points_count < min_days_threshold):
                    continue
            elif method.upper() == "OR":
                if (pre_points_count < min_days_threshold) and (post_points_count < min_days_threshold):
                    continue
            else:
                raise ValueError("method must be 'AND' or 'OR'")

            # One qualifying metric is enough to keep the flare.
            plottable_indices.append(flare_idx)
            break

    return plottable_indices


def plot_metrics_with_spaghetti_highlight(
    flare_df,
    objective,
    metrics,
    pre_days=45,
    post_days=45,
    poly_degree=3,
    highlight=None,
    highlight_color="orange",
    sharey=False,
    min_days_threshold=7,
    method="OR",
    raw_scatter_jitter=0.12,
    raw_scatter_size=12,
    raw_scatter_alpha=0.75,
    highlight_scatter_size=20,
    rng_seed=42
):
    """Plot flare-normalized metric values around flares, optionally highlighting one flare.

    One subplot is drawn per metric. For every flare that passes the inclusion
    rules, daily values are expressed as the difference from that flare's own
    in-flare mean. Negative x positions are days before ``date_flare_onset``,
    positive x positions are days after ``date_flare_end``, and x = 0 stands for
    the whole flare period (mean of the in-flare daily values).

    Each subplot shows:
      1. all contributing normalized points as jittered light-grey dots,
      2. a degree-``poly_degree`` polynomial fitted to the per-day means (blue),
      3. if ``highlight`` matches a flare, that flare's points plus a polynomial
         without intercept term (i.e. forced through the origin) in
         ``highlight_color``.

    Parameters
    ----------
    flare_df : DataFrame with ``user_id``, ``date_flare_onset``, ``date_flare_end``.
    objective : DataFrame with ``user_id``, ``date`` and the metric columns.
    metrics : str or list of str; column names of ``objective`` to plot.
    pre_days, post_days : number of days shown before onset / after end.
    poly_degree : degree of the group polynomial; upper bound for the highlight fit.
    highlight : index label of ``flare_df`` to highlight, or None.
    highlight_color : matplotlib color for the highlighted flare.
    sharey : forwarded to ``plt.subplots``.
    min_days_threshold : minimum number of days with a daily record required in
        the pre and/or post window.
    method : "AND" requires both windows to reach the threshold, "OR" at least one.
    raw_scatter_jitter : half-width of the uniform horizontal jitter.
    raw_scatter_size, raw_scatter_alpha : marker size / alpha of the grey points.
    highlight_scatter_size : marker size of the highlighted points.
    rng_seed : seed of the NumPy generator that draws the jitter.

    Returns
    -------
    bool
        False (figure closed, nothing shown) when ``highlight`` is given but no
        metric produced a valid highlighted trajectory; True otherwise, after
        the figure has been shown.

    Raises
    ------
    ValueError
        If ``method`` is neither "AND" nor "OR" (case-insensitive). The check is
        only reached for flares that have data during the flare.
    """

    # ---------- Display names + font sizes ----------
    metric_display = {
        "REM_pct": "REM Sleep (%)",
        "deep_pct": "Deep Sleep (%)",
        "light_pct": "Light Sleep (%)",
        "sleep_eff": "Sleep Efficiency (%)",
        "dur_asleep": "Total Time Asleep (h)",
    }
    title_fontsize = 18
    label_fontsize = 15
    # ------------------------------------------------

    # Ensure metrics is a list
    if isinstance(metrics, str):
        metrics = [metrics]

    # Work on copies so the caller's frames are not modified by the date parsing
    flare = flare_df.copy()
    obj = objective.copy()
    flare["date_flare_onset"] = pd.to_datetime(flare["date_flare_onset"])
    flare["date_flare_end"] = pd.to_datetime(flare["date_flare_end"])
    obj["date"] = pd.to_datetime(obj["date"])

    # Relative-day axis: -pre_days..-1, 0 (the flare itself), 1..post_days
    rel_index = list(range(-pre_days, 0)) + [0] + list(range(1, post_days + 1))
    # A single generator is shared by all jitter draws, so the rendered figure
    # depends on the order in which the draws below happen.
    rng = np.random.default_rng(rng_seed)

    # Build figure; if nothing to highlight, we'll close it.
    fig, axes = plt.subplots(
        1, len(metrics), figsize=(7 * len(metrics), 5), sharey=sharey
    )
    # plt.subplots returns a bare Axes for a single panel; normalise to a list
    if len(metrics) == 1:
        axes = [axes]

    has_highlight_any_metric = False

    # Iterate over metrics
    for m_idx, (ax, metric) in enumerate(zip(axes, metrics)):
        # aggregate daily per user (keep metric numeric only)
        obj[metric] = pd.to_numeric(obj[metric], errors="coerce")
        obj_daily = obj.groupby(["user_id", "date"], as_index=False)[metric].mean()

        # store values per day for group aggregation
        metric_values = {d: [] for d in rel_index}
        # highlighted trajectory for this metric (only if same flare index)
        highlight_traj = None  # tuple: (uid, flare_idx, days_np, values_np)

        # iterate over flares
        for flare_idx, row in flare.iterrows():
            uid, onset, end = row["user_id"], row["date_flare_onset"], row["date_flare_end"]

            # --- compute subject's mean during flare for normalization ---
            mask = (obj_daily["user_id"] == uid) & (obj_daily["date"] >= onset) & (obj_daily["date"] <= end)
            flare_vals = obj_daily.loc[mask, metric]

            # Rule 1: skip if 0 datapoints during flare
            if flare_vals.empty:
                continue

            # Also skip when every in-flare daily value is NaN (no usable baseline)
            flare_mean = flare_vals.mean()
            if pd.isna(flare_mean):
                continue

            # --- check pre and post datapoints ---
            # A day counts as soon as a daily row exists for it in obj_daily,
            # whether or not the metric value on that row is NaN.
            pre_points_count, post_points_count = 0, 0
            for d in range(-pre_days, 0):
                qdate = onset + pd.Timedelta(days=d)
                if not obj_daily[(obj_daily["user_id"] == uid) & (obj_daily["date"] == qdate)].empty:
                    pre_points_count += 1
            for d in range(1, post_days + 1):
                qdate = end + pd.Timedelta(days=d)
                if not obj_daily[(obj_daily["user_id"] == uid) & (obj_daily["date"] == qdate)].empty:
                    post_points_count += 1

            # Rule 2: flexible datapoint threshold
            if method.upper() == "AND":
                if (pre_points_count < min_days_threshold) or (post_points_count < min_days_threshold):
                    continue
            elif method.upper() == "OR":
                if (pre_points_count < min_days_threshold) and (post_points_count < min_days_threshold):
                    continue
            else:
                raise ValueError("method must be 'AND' or 'OR'")

            # --- build trajectory (per-day normalized values for this flare) ---
            traj, days = [], []
            for d in rel_index:
                if d < 0:
                    # pre days are counted backwards from the flare onset
                    qdate = onset + pd.Timedelta(days=d)
                    vals = obj_daily[(obj_daily["user_id"] == uid) & (obj_daily["date"] == qdate)][metric]
                elif d == 0:
                    # day 0 collapses the whole flare; its normalized value is
                    # therefore flare_mean - flare_mean
                    vals = flare_vals
                else:  # post days
                    # post days are counted forwards from the flare end
                    qdate = end + pd.Timedelta(days=d)
                    vals = obj_daily[(obj_daily["user_id"] == uid) & (obj_daily["date"] == qdate)][metric]

                if not vals.empty:
                    v = vals.mean()
                    if pd.notna(v):
                        v_norm = v - flare_mean
                        traj.append(v_norm)
                        metric_values[d].append(v_norm)
                    else:
                        traj.append(np.nan)
                else:
                    traj.append(np.nan)
                days.append(d)

            # If this is the selected flare, keep its trajectory for *this* metric
            if highlight is not None and flare_idx == highlight:
                if np.any(pd.notna(traj)):
                    highlight_traj = (uid, flare_idx, np.array(days, dtype=float), np.array(traj, dtype=float))

        # ------------------------------------------------------------------
        # 1) RAW POINTS: Plot ALL contributing datapoints as light grey dots
        # ------------------------------------------------------------------
        raw_x, raw_y = [], []
        for d in rel_index:
            vals = metric_values[d]
            if len(vals) == 0:
                continue
            # horizontal jitter only; y values are plotted unchanged
            jitter = rng.uniform(-raw_scatter_jitter, raw_scatter_jitter, size=len(vals))
            raw_x.extend(d + jitter)
            raw_y.extend(vals)
        if len(raw_x) > 0:
            ax.scatter(
                raw_x, raw_y,
                s=raw_scatter_size,
                c="lightgrey",
                alpha=raw_scatter_alpha,
                edgecolors="none",
                label="Raw datapoints"
            )

        # ------------------------------------------------------------------
        # 2) GROUP MEAN POLYNOMIAL (blue) — fit to per-day means (with intercept)
        # ------------------------------------------------------------------
        x_vals, means = [], []
        for d in rel_index:
            vals = metric_values[d]
            mean = np.mean(vals) if len(vals) > 0 else np.nan
            x_vals.append(d)
            means.append(mean)

        x = np.array(x_vals, dtype=float)
        y = np.array(means, dtype=float)
        valid = ~np.isnan(y)

        # Fit only when there are at least poly_degree + 1 days with a mean
        if valid.sum() >= poly_degree + 1:
            coeffs = np.polyfit(x[valid], y[valid], deg=poly_degree)
            xsmooth = np.linspace(x.min(), x.max(), 400)
            ysmooth = np.polyval(coeffs, xsmooth)
            ax.plot(xsmooth, ysmooth, color="blue", linewidth=2, label=f"Group poly{poly_degree}")

        # ------------------------------------------------------------------
        # 3) HIGHLIGHTED TRAJECTORY: orange points + ORIGIN-ANCHORED polynomial
        # ------------------------------------------------------------------
        if highlight_traj is not None:
            uid, fidx, days, values = highlight_traj
            hv_valid = ~np.isnan(values)

            if hv_valid.any():
                has_highlight_any_metric = True

                # orange scatter points (jittered)
                hjitter = rng.uniform(-raw_scatter_jitter, raw_scatter_jitter, size=hv_valid.sum())
                # legend shows only the first 7 characters of the user id
                uid_short = str(uid)[:7] + "..."
                ax.scatter(
                    days[hv_valid] + hjitter,
                    values[hv_valid],
                    s=highlight_scatter_size,
                    c=highlight_color,
                    alpha=0.95,
                    label=f"Highlight points (uid={uid_short}, flare={fidx})"
                )

                # Fit polynomial THROUGH ORIGIN (0,0)
                xh = days[hv_valid].astype(float)
                yh = values[hv_valid].astype(float)
                # Require at least one nonzero x to avoid singularity
                nonzero = xh != 0
                if nonzero.any():
                    # degree is capped by the number of nonzero-day points, at least 1
                    deg_anchor = max(1, min(poly_degree, nonzero.sum()))
                    # design matrix has columns x^1..x^deg and no constant column,
                    # which is what pins the fitted curve to (0, 0)
                    X = np.column_stack([xh[nonzero]**k for k in range(1, deg_anchor + 1)])
                    try:
                        coefs, *_ = np.linalg.lstsq(X, yh[nonzero], rcond=None)
                        xs = np.linspace(xh.min(), xh.max(), 300)
                        Xs = np.column_stack([xs**k for k in range(1, deg_anchor + 1)])
                        ys = Xs @ coefs
                        ax.plot(xs, ys, color=highlight_color, linewidth=2.5,
                                label=f"Highlight poly{deg_anchor} (uid={uid_short})")
                    except np.linalg.LinAlgError:
                        # best effort: keep the scatter, omit the curve
                        pass

        # cosmetics
        ax.axvline(0, linestyle="--", color="red", linewidth=0.8)
        display_name = metric_display.get(metric, metric)
        ax.set_title(display_name, fontsize=title_fontsize)
        ax.set_xlabel("Days relative to flare (0 = flare mean)", fontsize=label_fontsize)
        ax.set_ylabel("Δ " + display_name, fontsize=label_fontsize)
        ax.grid(True, linestyle=":", linewidth=0.5)

        # only put legend on the first axis
        if m_idx == 0:
            ax.legend()

    # If no metric had a valid highlighted trajectory, close and skip
    if (highlight is not None) and (not has_highlight_any_metric):
        plt.close(fig)
        return False

    plt.tight_layout()
    plt.show()
    return True

In [None]:
# One highlighted figure per symptom flare that list_plottable_flares accepts
# under the same window/threshold settings used for plotting.
plottable = list_plottable_flares(
    summary_symptom_flares,
    objective,
    metrics=["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"],
    pre_days=45,
    post_days=45,
    min_days_threshold=10,
    method="AND",
)
for idx in plottable:
    plot_metrics_with_spaghetti_highlight(
        summary_symptom_flares,
        objective,
        metrics=["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"],
        pre_days=45,
        post_days=45,
        min_days_threshold=10,
        method="AND",
        highlight=idx,
    )

In [None]:
# One highlighted figure per regular flare that list_plottable_flares accepts
# under the same window/threshold settings used for plotting.
plottable = list_plottable_flares(
    summary_regular_flares,
    objective,
    metrics=["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"],
    pre_days=45,
    post_days=45,
    min_days_threshold=10,
    method="AND",
)
for idx in plottable:
    plot_metrics_with_spaghetti_highlight(
        summary_regular_flares,
        objective,
        metrics=["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"],
        pre_days=45,
        post_days=45,
        min_days_threshold=10,
        method="AND",
        highlight=idx,
    )

#### Rate of Change in Sleep Metrics Preceding and Following Flares

In [None]:
# ==========================================================
# CONFIG
# ==========================================================
# Metric columns analysed in this section; used as the default `metrics`
# argument of the functions defined below.
METRICS = ["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"]
# Default length (in days) of the window before the flare onset.
PRE_DAYS  = 45
# Default length (in days) of the window after the flare end.
POST_DAYS = 45

# ==========================================================
# HELPERS
# ==========================================================
def _sig_marker(p):
    """Map a p-value to a significance marker.

    Returns "**" for p < 0.01, "*" for p < 0.05, "+" for p < 0.10 and "ns"
    otherwise, including when ``p`` is NaN or infinite.
    """
    if np.isfinite(p):
        # cut-offs are checked from strictest to loosest; the first hit wins
        for cutoff, marker in ((0.01, "**"), (0.05, "*"), (0.10, "+")):
            if p < cutoff:
                return marker
    return "ns"

def _sem(x):
    """Return the standard error of the mean of ``x``, ignoring NaN entries.

    NaN is returned when fewer than two non-NaN values remain.
    """
    observed = np.asarray(x, dtype=float)
    observed = observed[~np.isnan(observed)]
    return stats.sem(observed) if observed.size > 1 else np.nan

def _fit_slope_through_origin(days, values):
    """Return the least-squares slope of ``values`` on ``days`` with no intercept.

    The slope is ``sum(days * values) / sum(days**2)`` computed over the pairs
    in which neither entry is NaN.

    Returns NaN when fewer than two valid pairs remain, or when every remaining
    day offset is 0 (the denominator vanishes and the slope is undefined).
    """
    days = np.asarray(days, dtype=float)
    values = np.asarray(values, dtype=float)
    keep = ~(np.isnan(days) | np.isnan(values))
    days, values = days[keep], values[keep]
    if days.size < 2:
        return np.nan
    denom = np.sum(days**2)
    if denom == 0:
        # All observations sit on day 0: return NaN explicitly instead of
        # relying on a 0/0 division (which also emits a RuntimeWarning).
        return np.nan
    return np.sum(days * values) / denom

def _fmt_slope(m, se):
    """Format a mean slope and its SEM as ``"M (SEM: SE)"`` with three decimals.

    An en dash is returned when either value is NaN or infinite.
    """
    printable = np.isfinite(m) and np.isfinite(se)
    return f"{m:.3f} (SEM: {se:.3f})" if printable else "–"

def _fmt_p(p):
    """Format a p-value for the results table, without the leading zero.

    * NaN or infinite  -> en dash
    * p < 0.001        -> "< .001"
    * otherwise        -> "= .xx" (two decimals)

    Values that would round to ".00" at two decimals (0.001 <= p < 0.005) are
    shown with three decimals instead, so the table never reports "= .00".
    """
    if not np.isfinite(p):
        return "–"
    if p < 0.001:
        return "< .001"
    s = f"{p:.2f}"
    if s == "0.00":
        # two decimals would hide the value entirely; use three
        s = f"{p:.3f}"
    if s.startswith("0"):
        s = s[1:]
    return f"= {s}"

def _span_correction(days_array, ref_span_days):
    """Return the observed day span as a fraction of a reference span.

    The observed span is ``nanmax - nanmin`` of ``days_array``. The neutral
    factor 1.0 is returned when the array is empty or has no finite entry, when
    the span is not finite or not positive, or when ``ref_span_days`` is not
    finite or not positive.
    """
    offsets = np.asarray(days_array, dtype=float)
    if offsets.size and np.isfinite(offsets).any():
        observed = np.nanmax(offsets) - np.nanmin(offsets)
        span_ok = np.isfinite(observed) and observed > 0
        if span_ok and np.isfinite(ref_span_days) and ref_span_days > 0:
            return float(observed) / float(ref_span_days)
    return 1.0

def compute_prepost_slopes_raw(flare_df, objective,
                               metrics=METRICS,
                               pre_days=PRE_DAYS,
                               post_days=POST_DAYS,
                               min_days_threshold=7,
                               method="OR"):
    """Compute per-flare weekly slopes of each metric before, during and after a flare.

    For every flare the rows of ``objective`` belonging to the same user are
    split into a pre window (``pre_days`` days before onset), the flare itself
    (onset..end inclusive) and a post window (``post_days`` days after end).
    Metric values are expressed relative to the flare-period mean and a slope
    through the origin is fitted per phase; the result is reported per week
    (daily slope * 7). Pre and post slopes are additionally multiplied by
    ``_span_correction`` (observed day span / window length).

    Parameters
    ----------
    flare_df : DataFrame with ``user_id``, ``date_flare_onset``, ``date_flare_end``.
    objective : DataFrame with ``user_id``, ``date`` and the metric columns.
    metrics : list of metric column names.
    pre_days, post_days : window lengths in days.
    min_days_threshold : minimum number of rows a pre/post window must contain.
    method : "AND" keeps a flare only if both windows reach the threshold,
        "OR" if at least one does (case-insensitive).

    Returns
    -------
    DataFrame with columns ``flare_id`` (row position in ``flare_df``),
    ``user_id``, ``phase`` ("pre", "flare", "post"), ``metric`` and
    ``slope_per_week``. The columns are present even when no flare qualifies,
    so downstream group-bys keep working on an empty result.

    Raises
    ------
    ValueError
        If ``method`` is neither "AND" nor "OR"; checked once, up front.
    """
    mode = method.upper()
    if mode not in ("AND", "OR"):
        raise ValueError("method must be 'AND' or 'OR'")

    df = objective.copy()
    df["date"] = pd.to_datetime(df["date"])
    fl = flare_df.copy()
    fl["date_flare_onset"] = pd.to_datetime(fl["date_flare_onset"])
    fl["date_flare_end"]   = pd.to_datetime(fl["date_flare_end"])

    out_columns = ["flare_id", "user_id", "phase", "metric", "slope_per_week"]
    out_rows = []

    for flare_id, row in fl.reset_index(drop=True).iterrows():
        uid   = row["user_id"]
        start = row["date_flare_onset"]
        end   = row["date_flare_end"]

        u = df[df["user_id"] == uid]
        if u.empty:
            continue

        # windows
        pre   = u[(u["date"] >= start - pd.Timedelta(days=pre_days)) & (u["date"] <  start)]
        post  = u[(u["date"] >  end) & (u["date"] <= end + pd.Timedelta(days=post_days))]
        flare = u[(u["date"] >= start) & (u["date"] <= end)]

        # Rule 1: skip if 0 datapoints during flare
        if flare.empty:
            continue

        # Rule 2: flexible datapoint threshold requirement (applies to pre/post only)
        pre_ok = len(pre) >= min_days_threshold
        post_ok = len(post) >= min_days_threshold
        if mode == "AND":
            if not (pre_ok and post_ok):
                continue
        elif not (pre_ok or post_ok):
            continue

        # Phases to fit for this flare: (name, rows, anchor date, reference span).
        # A reference span of None means "no span normalization".
        phases = []
        if pre_ok:
            phases.append(("pre", pre, start, pre_days))
        if len(flare) >= 2:
            phases.append(("flare", flare, start, None))
        if post_ok:
            phases.append(("post", post, end, post_days))

        # normalization base (per flare, per metric): mean during the flare
        flare_means = flare[metrics].mean()

        for metric in metrics:
            base = flare_means.get(metric, np.nan)
            if not np.isfinite(base):
                continue

            for phase, window, anchor, ref_span in phases:
                d = (window["date"] - anchor).dt.days.values
                y = window[metric].values.astype(float) - base
                slope_day = _fit_slope_through_origin(d, y)
                if ref_span is not None:
                    slope_day *= _span_correction(d, ref_span)
                out_rows.append(dict(flare_id=flare_id, user_id=uid, phase=phase,
                                     metric=metric, slope_per_week=slope_day * 7.0))

    return pd.DataFrame(out_rows, columns=out_columns)

def summarize_and_test_all(slopes, metrics=METRICS):
    """Summarise weekly slopes per metric/phase and run paired t-tests between phases.

    Parameters
    ----------
    slopes : DataFrame with ``flare_id``, ``user_id``, ``phase``, ``metric`` and
        ``slope_per_week`` columns.
    metrics : metrics for which the paired tests are run.

    Returns
    -------
    (summary, tests_all)
        ``summary`` has one row per (metric, phase) with mean ``M``, ``SEM`` and
        count ``N``. ``tests_all`` has one row per metric and comparison with
        the paired t-test p-value ``p`` (NaN when fewer than two pairs exist or
        a phase is missing) and the number of pairs ``N``.
    """
    grouped = slopes.groupby(["metric", "phase"])["slope_per_week"]
    summary = grouped.agg(M="mean", SEM=_sem, N="count").reset_index()

    # (label, first phase, second phase), in the order rows are emitted
    comparisons = (
        ("Preflare vs. Postflare", "pre", "post"),
        ("Preflare vs. Flare", "pre", "flare"),
        ("Flare vs. Postflare", "flare", "post"),
    )

    records = []
    for metric in metrics:
        # one row per flare, one column per phase
        wide = slopes[slopes["metric"] == metric].pivot_table(
            index=["flare_id", "user_id"],
            columns="phase",
            values="slope_per_week",
        )

        outcomes = []
        for name, first, second in comparisons:
            p_value, n_pairs = np.nan, 0
            if first in wide.columns and second in wide.columns:
                # only flares with a slope in both phases form a pair
                complete = wide.dropna(subset=[first, second])
                n_pairs = int(complete.shape[0])
                if n_pairs >= 2:
                    _, raw_p = stats.ttest_rel(complete[first], complete[second])
                    p_value = float(raw_p)
            outcomes.append((name, p_value, n_pairs))

        for name, p_value, n_pairs in outcomes:
            records.append(dict(metric=metric, comparison=name, p=p_value, N=n_pairs))

    tests_all = pd.DataFrame(records)
    return summary, tests_all

def build_prepost_table(label, summary, tests_all, metrics=METRICS, metric_name_map=None):
    """Build a display table of phase slopes and paired-test p-values.

    Parameters
    ----------
    label : value written to the ``Section`` column of every row.
    summary : per (metric, phase) summary with ``M`` and ``SEM`` columns.
    tests_all : per (metric, comparison) results with ``p`` and ``N`` columns.
    metrics : metrics to include, in output order.
    metric_name_map : optional mapping from metric name to display name; the
        result is passed through ``.replace("_", " ").title()``.

    Returns
    -------
    DataFrame with three rows per metric (one per comparison) and the columns
    ``Section``, ``Measurement``, ``n (flares)``, ``Preflare Slope``,
    ``Flare Slope``, ``Postflare Slope``, ``Comparison`` and ``P-value``.
    """
    rows = []
    metric_name_map = metric_name_map or {}

    # Fetch (mean, SEM) for a (metric, phase); (NaN, NaN) when absent.
    # Scalars are pulled out with .iloc[0] because float(Series) is deprecated
    # in recent pandas releases.
    def _get(summary, m, ph):
        hit = summary[(summary["metric"] == m) & (summary["phase"] == ph)]
        if hit.empty:
            return np.nan, np.nan
        return float(hit["M"].iloc[0]), float(hit["SEM"].iloc[0])

    for m in metrics:
        m_pre,  se_pre  = _get(summary, m, "pre")
        m_flr,  se_flr  = _get(summary, m, "flare")
        m_post, se_post = _get(summary, m, "post")

        for comp in ["Preflare vs. Postflare", "Preflare vs. Flare", "Flare vs. Postflare"]:
            p_row = tests_all[(tests_all["metric"] == m) & (tests_all["comparison"] == comp)]
            if p_row.empty:
                p, n = np.nan, 0
            else:
                p = float(p_row["p"].iloc[0])
                n = int(p_row["N"].iloc[0]) if np.isfinite(p_row["N"]).all() else 0

            rows.append({
                "Section": label,
                "Measurement": metric_name_map.get(m, m).replace("_", " ").title(),
                "n (flares)": f"n = {n}",
                "Preflare Slope":  _fmt_slope(m_pre,  se_pre),
                "Flare Slope":     _fmt_slope(m_flr,  se_flr),
                "Postflare Slope": _fmt_slope(m_post, se_post),
                "Comparison": comp,
                "P-value": _fmt_p(p),
            })

    return pd.DataFrame(rows, columns=[
        "Section", "Measurement", "n (flares)", "Preflare Slope", "Flare Slope",
        "Postflare Slope", "Comparison", "P-value"
    ])

def plot_change_per_week_shared(groups, metrics=METRICS, metric_name_map=None,
                                p_light=0.10, alpha_sig=1.0, alpha_ns=0.35):
    """Plot pre- and post-flare weekly slopes (M ± SEM) on a groups x metrics grid.

    Each row of the grid is one entry of ``groups`` and each column one metric.
    Columns share symmetric x-limits derived from the largest |M| + SEM of the
    pre/post phases across all groups. Panels whose pre-vs-post p-value is below
    ``p_light`` are drawn with ``alpha_sig``, all others with ``alpha_ns``; a
    significance marker from ``_sig_marker`` is written next to the points when
    it is not "ns".

    Parameters
    ----------
    groups : sequence of ``(label, summary, tests)`` tuples, where ``summary``
        has ``metric``, ``phase``, ``M``, ``SEM`` columns and ``tests`` has
        ``metric``, ``comparison``, ``p`` columns.
    metrics : metrics to plot, one column each.
    metric_name_map : optional mapping metric -> panel title.
    p_light : p-value below which a panel is drawn at full opacity.
    alpha_sig, alpha_ns : marker alpha for highlighted / other panels.

    Returns
    -------
    None. Nothing is drawn when ``groups`` or ``metrics`` is empty.
    """
    metric_name_map = metric_name_map or {
        "REM_pct": "REM Sleep (%)",
        "deep_pct": "Deep Sleep (%)",
        "light_pct": "Light Sleep (%)",
        "sleep_eff": "Sleep Efficiency (%)",
        "dur_asleep": "Total Time Asleep (h)",
    }

    colors = dict(pre="#6A1B9A", post="#FF00B3")
    label_map = {"pre": "Before Flare", "post": "After Flare"}

    G = len(groups)
    M = len(metrics)
    if G == 0 or M == 0:
        return

    fig, axes = plt.subplots(G, M, figsize=(3*M, 1.7*G), squeeze=False)

    # ---- compute shared x-lims per metric across all groups (PRE/POST ONLY)
    shared_xlim = {}
    for metric in metrics:
        max_extent = []
        for _label, summary, _tests in groups:
            sel = summary[(summary["metric"] == metric) &
                          (summary["phase"].isin(["pre", "post"]))]
            if sel.empty:
                continue
            # extent = max(|mean| + SEM) over the plotted phases
            Mvals = np.abs(sel["M"].to_numpy())
            SEvals = sel["SEM"].fillna(0).to_numpy()
            ext = np.nanmax(Mvals + SEvals) if Mvals.size else np.nan
            if np.isfinite(ext):
                max_extent.append(ext)

        xmax = np.nanmax(max_extent) if max_extent else 1.0
        if not np.isfinite(xmax) or xmax == 0:
            xmax = 1.0
        shared_xlim[metric] = (-1.2 * xmax, 1.2 * xmax)  # small padding

    # Scalars are extracted with .iloc[0]; float(Series) is deprecated in
    # recent pandas releases.
    def _get_p_pre_post(tests_df, metric):
        row = tests_df[(tests_df["metric"] == metric) &
                       (tests_df["comparison"] == "Preflare vs. Postflare")]
        return float(row["p"].iloc[0]) if not row.empty else np.nan

    for i, (label, summary, tests) in enumerate(groups):
        for j, metric in enumerate(metrics):
            ax = axes[i, j]
            sel = summary[summary["metric"] == metric]
            if sel.empty:
                ax.set_axis_off()
                continue

            p = _get_p_pre_post(tests, metric)
            is_sig = np.isfinite(p) and (p < p_light)
            alpha = alpha_sig if is_sig else alpha_ns

            # plot pre/post only
            for idx, phase in enumerate(["pre", "post"]):
                row = sel[sel["phase"] == phase]
                if row.empty:
                    continue
                m  = float(row["M"].iloc[0])
                se = float(row["SEM"].iloc[0]) if np.isfinite(row["SEM"]).all() else np.nan
                y = idx
                # only the top-left panel carries labels, so the figure legend
                # lists each phase once
                ax.errorbar(m, y, xerr=se, fmt='o', capsize=4,
                            color=colors[phase], alpha=alpha, markersize=8, lw=2,
                            label=label_map[phase] if (i == 0 and j == 0) else None)

            # significance marker positioned using pre/post only
            sel_pp = sel[sel["phase"].isin(["pre", "post"])]
            if np.isfinite(p):
                sig = _sig_marker(p)
                if sig != "ns" and not sel_pp.empty:
                    right_x = np.nanmax(sel_pp["M"].to_numpy() +
                                        sel_pp["SEM"].fillna(0).to_numpy())
                    ax.text(right_x + 0.08, 0.5, sig, va="center", ha="left", fontsize=11)

            ax.axvline(0, color="k", lw=2)
            ax.set_yticks([0, 1]); ax.set_yticklabels([])
            ax.set_ylim(-1, 2)
            ax.set_xlim(shared_xlim[metric])
            if i == 0:
                ax.set_title(metric_name_map.get(metric, metric.replace("_", " ")))
            if j == 0:
                ax.set_ylabel(label)
            ax.grid(True, axis="x", linestyle=":", alpha=0.6)

    handles, labels = axes[0,0].get_legend_handles_labels()
    if handles:
        fig.legend(handles, labels, loc="lower left", ncol=2, frameon=False)
    fig.text(0.5, 0.02, "Change per week (M ± SEM)", ha="center", va="center")
    plt.tight_layout(rect=[0, 0.08, 1, 1])
    plt.show()

# USAGE
# Each flare type goes through the same pipeline:
# per-flare slopes -> phase summary + paired tests -> display table.

# 1) summary_regular_flares
slopes_infl = compute_prepost_slopes_raw(
    summary_regular_flares,
    objective,
    metrics=METRICS,
    min_days_threshold=10,
    method="AND",
)
summary_infl, tests_infl_all = summarize_and_test_all(slopes_infl, metrics=METRICS)
table_infl = build_prepost_table(
    "Inflammatory Flares",
    summary_infl,
    tests_infl_all,
    metrics=METRICS,
)

# 2) summary_symptom_flares
slopes_symp = compute_prepost_slopes_raw(
    summary_symptom_flares,
    objective,
    metrics=METRICS,
    min_days_threshold=10,
    method="AND",
)
summary_symp, tests_symp_all = summarize_and_test_all(slopes_symp, metrics=METRICS)
table_symp = build_prepost_table(
    "Symptomatic Flares",
    summary_symp,
    tests_symp_all,
    metrics=METRICS,
)

# 3) Combined table: inflammatory rows first, then symptomatic rows
final_table = pd.concat([table_infl, table_symp], ignore_index=True)

final_table

In [None]:
# One grid row per flare type; the first tuple element becomes the row's y-axis label.
plot_change_per_week_shared(
    [
        ("symptom_deg", summary_symp, tests_symp_all),
        ("rate_as_flare", summary_infl, tests_infl_all),
    ],
    metrics=METRICS,
    p_light=0.10,
)

##### Aggregated Slopes

In [None]:
# helper: slope through origin
def _fit_slope_through_origin(days, values):
    """Return the least-squares slope of ``values`` on ``days`` with no intercept.

    The slope is ``sum(days * values) / sum(days**2)`` computed over the pairs
    in which neither entry is NaN.

    Returns NaN when fewer than two valid pairs remain, or when every remaining
    day offset is 0 (the denominator vanishes and the slope is undefined).
    """
    days = np.asarray(days, dtype=float)
    values = np.asarray(values, dtype=float)
    keep = ~(np.isnan(days) | np.isnan(values))
    days, values = days[keep], values[keep]
    if days.size < 2:
        return np.nan
    denom = np.sum(days**2)
    if denom == 0:
        # All observations sit on day 0: return NaN explicitly instead of
        # relying on a 0/0 division (which also emits a RuntimeWarning).
        return np.nan
    return np.sum(days * values) / denom

# Display names for the metric columns and font sizes for titles/labels.
# These are module-level names: plot_overlay_all_metrics below reads
# metric_display, title_fontsize and label_fontsize as globals.
metric_display = {
    "REM_pct": "REM Sleep (%)",
    "deep_pct": "Deep Sleep (%)",
    "light_pct": "Light Sleep (%)",
    "sleep_eff": "Sleep Efficiency (%)",
    "dur_asleep": "Total Time Asleep (h)",
}
title_fontsize = 14
label_fontsize = 12

def plot_overlay_all_metrics(
    flare_df,
    objective,
    metrics,
    pre_days=45,
    post_days=45,
    alpha=0.15,
    point_size=10,
    min_days_threshold=10,
    method="OR",
):
    """Overlay all flares' pre/post datapoints per metric with pooled origin-anchored fits.

    One panel is drawn per metric. For every flare that has data during the
    flare and satisfies the pre/post row-count threshold, the pre-window and
    post-window values are expressed relative to the flare-period mean and
    pooled across flares. The pooled points are scattered, and a straight line
    through the origin is fitted separately to the pre and post points. The
    fitted slopes (per week) are printed for each metric.

    Parameters
    ----------
    flare_df : DataFrame with ``user_id``, ``date_flare_onset``, ``date_flare_end``.
    objective : DataFrame with ``user_id``, ``date`` and the metric columns.
    metrics : list of metric column names, one panel each.
    pre_days, post_days : window lengths in days.
    alpha, point_size : scatter transparency and marker size.
    min_days_threshold : minimum number of rows a pre/post window must contain.
    method : "AND" requires both windows to reach the threshold, "OR" at least
        one (case-insensitive).

    Raises
    ------
    ValueError
        If ``method`` is neither "AND" nor "OR"; raised when the first flare
        with in-flare data reaches the threshold check.
    """
    records = objective.copy()
    records["date"] = pd.to_datetime(records["date"])
    flares = flare_df.copy()
    flares["date_flare_onset"] = pd.to_datetime(flares["date_flare_onset"])
    flares["date_flare_end"] = pd.to_datetime(flares["date_flare_end"])

    n_panels = len(metrics)
    fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 3), sharey=False)

    # a single panel comes back as a bare Axes; wrap it so zip() works
    if n_panels == 1:
        axes = [axes]

    for ax, metric in zip(axes, metrics):
        # legend entries and the y-label are attached to the first metric only
        is_first = metric == metrics[0]
        pre_x, pre_y = [], []
        post_x, post_y = [], []

        for _, flare_row in flares.iterrows():
            uid = flare_row["user_id"]
            onset = flare_row["date_flare_onset"]
            offset = flare_row["date_flare_end"]

            user_rows = records[records["user_id"] == uid].copy()
            if user_rows.empty:
                continue

            when = user_rows["date"]
            before = user_rows[(when >= onset - pd.Timedelta(days=pre_days)) & (when < onset)]
            after = user_rows[(when > offset) & (when <= offset + pd.Timedelta(days=post_days))]
            during = user_rows[(when >= onset) & (when <= offset)]

            if during.empty:
                continue

            # decide whether the flare has too little surrounding data
            n_before, n_after = len(before), len(after)
            mode = method.upper()
            if mode == "AND":
                # BOTH sides must meet the threshold
                too_sparse = (n_before < min_days_threshold) or (n_after < min_days_threshold)
            elif mode == "OR":
                # AT LEAST ONE side must meet the threshold
                too_sparse = (n_before < min_days_threshold) and (n_after < min_days_threshold)
            else:
                raise ValueError("method must be 'AND' or 'OR'")
            if too_sparse:
                continue

            # flare-period mean is the normalization baseline
            baseline = during[metric].mean()
            if not np.isfinite(baseline):
                continue

            # pre days are relative to onset, post days relative to flare end
            for window, anchor, xs, ys in (
                (before, onset, pre_x, pre_y),
                (after, offset, post_x, post_y),
            ):
                if window.empty:
                    continue
                rel_days = (window["date"] - anchor).dt.days.values
                rel_vals = window[metric].values.astype(float) - baseline
                xs.extend(rel_days)
                ys.extend(rel_vals)

        pre_x, pre_y = np.array(pre_x), np.array(pre_y)
        post_x, post_y = np.array(post_x), np.array(post_y)

        # scatter raw points
        ax.scatter(
            pre_x, pre_y, color="#6A1B9A", alpha=alpha, s=point_size,
            label="Pre (raw)" if is_first else "",
        )
        ax.scatter(
            post_x, post_y, color="#FF00B3", alpha=alpha, s=point_size,
            label="Post (raw)" if is_first else "",
        )

        # regression lines (forced through origin)
        m_pre = _fit_slope_through_origin(pre_x, pre_y)
        m_post = _fit_slope_through_origin(post_x, post_y)

        if np.isfinite(m_pre):
            line_x = np.linspace(pre_x.min(), 0, 100)
            ax.plot(line_x, m_pre * line_x, color="#6A1B9A", lw=2,
                    label="Pre (fit)" if is_first else "")
        if np.isfinite(m_post):
            line_x = np.linspace(0, post_x.max(), 100)
            ax.plot(line_x, m_post * line_x, color="#FF00B3", lw=2,
                    label="Post (fit)" if is_first else "")

        # print slope values per week
        print(f"{metric}: pre slope = {m_pre*7:.4f} per week, post slope = {m_post*7:.4f} per week")

        ax.axvline(0, color="k", lw=1)
        ax.axhline(0, color="k", lw=1)
        ax.set_title(metric_display.get(metric, metric.replace('_', ' ')), fontsize=title_fontsize)
        ax.grid(True, linestyle=":", alpha=0.6)

        if is_first:
            ax.set_ylabel("Normalized value", fontsize=label_fontsize)
        ax.set_xlabel("Days relative to flare", fontsize=label_fontsize)

    # figure-level legend placed above the panels
    fig.legend(
        loc="upper center",
        ncol=4,
        frameon=False,
        bbox_to_anchor=(0.5, 1.15),
        fontsize=title_fontsize,
    )
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

# Example call: pooled overlay for the regular flares
plot_overlay_all_metrics(
    summary_regular_flares,
    objective,
    metrics=["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"],
    min_days_threshold=10,
    method="AND",
)

In [None]:
# Pooled overlay for the symptom flares, same settings as the regular flares
plot_overlay_all_metrics(
    summary_symptom_flares,
    objective,
    metrics=["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"],
    min_days_threshold=10,
    method="AND",
)

##### Individual Slopes

In [None]:
def _span_correction(days_array, ref_span_days):
    """Return the observed day span as a fraction of a reference span.

    The observed span is ``nanmax - nanmin`` of ``days_array``. The neutral
    factor 1.0 is returned when the array is empty or has no finite entry, when
    the span is not finite or not positive, or when ``ref_span_days`` is not
    finite or not positive.
    """
    offsets = np.asarray(days_array, dtype=float)
    if offsets.size and np.isfinite(offsets).any():
        observed = np.nanmax(offsets) - np.nanmin(offsets)
        span_ok = np.isfinite(observed) and observed > 0
        if span_ok and np.isfinite(ref_span_days) and ref_span_days > 0:
            return float(observed) / float(ref_span_days)
    return 1.0

def plot_per_flare(
    flare_df,
    objective,
    metrics,
    pre_days=45,
    post_days=45,
    alpha=0.15,
    point_size=10,
    min_days_threshold=10,
    method="OR",
    normalize_spans=True,
):
    """Plot pre/post-flare metric values in a grid: one row per flare, one column per metric.

    For every row of ``flare_df`` the rows of ``objective`` belonging to the
    same ``user_id`` are split into three windows: ``pre`` (the ``pre_days``
    days before ``date_flare_onset``), ``flare`` (onset to end, inclusive) and
    ``post`` (the ``post_days`` days after ``date_flare_end``). A flare gets a
    row in the figure when

    * its flare window contains at least one row,
    * the pre/post row counts satisfy ``min_days_threshold`` under ``method``,
    * at least one metric has a finite flare-window mean and a finite value in
      the pre or post window.

    Plotted values are the metric minus its mean over the flare window.
    Pre-window days are counted relative to the onset, post-window days
    relative to the end. A slope is obtained per side from
    ``_fit_slope_through_origin`` (defined elsewhere in this file) when the
    side has more than one point, and the weekly slopes are printed.

    Parameters
    ----------
    flare_df : pandas.DataFrame
        Needs columns ``user_id``, ``date_flare_onset``, ``date_flare_end``.
    objective : pandas.DataFrame
        Needs columns ``user_id``, ``date`` and every name in ``metrics``.
    metrics : sequence of str
        Column names of ``objective`` to plot; must not be empty.
    pre_days, post_days : int
        Window lengths in days before the onset / after the end.
    alpha, point_size : float
        Transparency and marker size of the scatter points.
    min_days_threshold : int
        Minimum number of rows required in the pre/post windows.
    method : {"AND", "OR"}
        ``"AND"`` requires both windows to reach the threshold, ``"OR"``
        only one of them. Case-insensitive.
    normalize_spans : bool
        When true, each raw slope is multiplied by
        ``_span_correction(days, pre_days)`` (resp. ``post_days``).

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.

    Raises
    ------
    ValueError
        If ``method`` is not ``"AND"``/``"OR"`` or ``metrics`` is empty. Both
        are checked before any data is touched, so a bad argument is reported
        even when no flare would have qualified.
    """
    # Function-scope imports keep this notebook cell runnable on its own.
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt

    # ---------- Display names + font sizes ----------
    metric_display = {
        "REM_pct": "REM Sleep (%)",
        "deep_pct": "Deep Sleep (%)",
        "light_pct": "Light Sleep (%)",
        "sleep_eff": "Sleep Efficiency (%)",
        "dur_asleep": "Total Time Asleep (h)",
        "sleep": "Subj. Sleep Score"}
    title_fontsize = 14
    label_fontsize = 12
    # ------------------------------------------------

    # Validate arguments once, up front, instead of lazily inside the loop.
    mode = method.upper()
    if mode not in ("AND", "OR"):
        raise ValueError("method must be 'AND' or 'OR'")
    metrics = list(metrics)
    if not metrics:
        # plt.subplots cannot build a grid with zero columns.
        raise ValueError("metrics must contain at least one column name")

    # Prepare data (work on copies so the caller's frames are not modified)
    df = objective.copy()
    df["date"] = pd.to_datetime(df["date"])
    fl = flare_df.copy()
    fl["date_flare_onset"] = pd.to_datetime(fl["date_flare_onset"])
    fl["date_flare_end"] = pd.to_datetime(fl["date_flare_end"])

    def _threshold_ok(n_pre, n_post):
        """Apply the AND/OR row-count rule to the pre and post windows."""
        pre_ok = n_pre >= min_days_threshold
        post_ok = n_post >= min_days_threshold
        return (pre_ok and post_ok) if mode == "AND" else (pre_ok or post_ok)

    def _flare_baseline(flare_rows, metric):
        """Mean of the metric over the flare window, NaN if nothing is usable."""
        vals = flare_rows[metric].values.astype(float)
        if np.isnan(vals).all():
            # Same result as np.nanmean, without its all-NaN RuntimeWarning.
            return np.nan
        return np.nanmean(vals)

    def _has_finite(window, metric):
        """True when the window holds at least one finite value of the metric."""
        return bool(np.isfinite(window[metric].values.astype(float)).any())

    def _relative_points(window, anchor, metric, base):
        """Return (day offsets from anchor, metric - base) with non-finite values dropped."""
        if window.empty:
            return np.array([]), np.array([])
        days = (window["date"] - anchor).dt.days.values
        vals = window[metric].values.astype(float) - base
        keep = np.isfinite(vals)
        return days[keep], vals[keep]

    def _decorate(ax, metric, first_col, msg=None):
        """Draw zero lines, grid, title and labels; optionally a centred note."""
        ax.axvline(0, color="k", lw=1)
        ax.axhline(0, color="k", lw=1)
        ax.grid(True, linestyle=":", alpha=0.6)
        ax.set_title(metric_display.get(metric, metric.replace('_', ' ')), fontsize=title_fontsize)
        ax.set_xlabel("Days relative to flare", fontsize=label_fontsize)
        if first_col:
            ax.set_ylabel("Normalized value", fontsize=label_fontsize)
        if msg is not None:
            ax.text(0.5, 0.5, msg, transform=ax.transAxes, ha="center", va="center", alpha=0.6)

    # First pass: decide which flares get a row in the figure
    rows = []  # each: dict with flare_id, uid, start, end, pre, post, flare
    for flare_id, row in fl.reset_index(drop=True).iterrows():
        uid = row["user_id"]
        start = row["date_flare_onset"]
        end = row["date_flare_end"]

        # Read-only slice; no copy needed because it is never written to.
        u = df[df["user_id"] == uid]
        if u.empty:
            continue

        pre = u[(u["date"] >= start - pd.Timedelta(days=pre_days)) & (u["date"] < start)]
        post = u[(u["date"] > end) & (u["date"] <= end + pd.Timedelta(days=post_days))]
        flare = u[(u["date"] >= start) & (u["date"] <= end)]

        if flare.empty:
            continue
        if not _threshold_ok(len(pre), len(post)):
            continue

        # Keep the flare only if ANY metric has a finite baseline plus a
        # finite pre or post value.
        usable = any(
            np.isfinite(_flare_baseline(flare, metric))
            and (_has_finite(pre, metric) or _has_finite(post, metric))
            for metric in metrics
        )
        if usable:
            rows.append(dict(
                flare_id=flare_id,
                uid=uid, start=start, end=end,
                pre=pre, post=post, flare=flare
            ))

    n_rows = len(rows)
    n_cols = len(metrics)

    # If nothing qualifies, show one empty row with notes
    if n_rows == 0:
        fig, axes = plt.subplots(1, n_cols, figsize=(4*n_cols, 3), sharey=False, squeeze=False)
        for j, metric in enumerate(metrics):
            _decorate(axes[0, j], metric, j == 0, msg="No flares with usable data")
        plt.tight_layout()
        plt.show()
        return

    # One big grid: rows=flares, cols=metrics. squeeze=False always yields a
    # 2-D axes array, so single-row/single-column grids need no special case.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4*n_cols, 3*n_rows), sharey=False, squeeze=False
    )

    # --- Plot each row/flare ---
    for i, entry in enumerate(rows):
        uid = entry["uid"]
        start = entry["start"]
        end = entry["end"]
        pre = entry["pre"]
        post = entry["post"]
        flare = entry["flare"]

        print(f"\n=== Flare {entry['flare_id']}, user {uid} ===")

        for j, metric in enumerate(metrics):
            ax = axes[i, j]
            first_col = (j == 0)

            base = _flare_baseline(flare, metric)
            if not np.isfinite(base):
                _decorate(ax, metric, first_col, msg="No valid values for metric")
                continue

            pre_x, pre_y = _relative_points(pre, start, metric, base)
            post_x, post_y = _relative_points(post, end, metric, base)

            # If this metric has zero usable points, render empty axis
            if pre_x.size == 0 and post_x.size == 0:
                _decorate(ax, metric, first_col, msg="No finite pre/post values")
                continue

            # scatter (raw)
            if pre_x.size > 0:
                ax.scatter(pre_x, pre_y, color="#6A1B9A", alpha=alpha, s=point_size, label="Pre (raw)")
            if post_x.size > 0:
                ax.scatter(post_x, post_y, color="#FF00B3", alpha=alpha, s=point_size, label="Post (raw)")

            # slopes (through origin); a single point is not enough for a fit
            m_pre_raw = _fit_slope_through_origin(pre_x, pre_y) if pre_x.size > 1 else np.nan
            m_post_raw = _fit_slope_through_origin(post_x, post_y) if post_x.size > 1 else np.nan

            # normalize to fixed span
            m_pre = m_pre_raw
            m_post = m_post_raw
            if normalize_spans and np.isfinite(m_pre_raw):
                m_pre = m_pre_raw * _span_correction(pre_x, pre_days)
            if normalize_spans and np.isfinite(m_post_raw):
                m_post = m_post_raw * _span_correction(post_x, post_days)

            # lines
            # NOTE(review): with normalize_spans=True the drawn line uses the
            # span-scaled slope, not the raw fit, so it need not follow the
            # scatter exactly -- confirm this is intended.
            if np.isfinite(m_pre) and pre_x.size > 0:
                x_pre = np.linspace(pre_x.min(), 0, 100)
                ax.plot(x_pre, m_pre * x_pre, color="#6A1B9A", lw=2, label="Pre (fit)")
            if np.isfinite(m_post) and post_x.size > 0:
                x_post = np.linspace(0, post_x.max(), 100)
                ax.plot(x_post, m_post * x_post, color="#FF00B3", lw=2, label="Post (fit)")

            # print slope values per week (normalized if enabled)
            if np.isfinite(m_pre) or np.isfinite(m_post):
                note = " (normalized to ref span)" if normalize_spans else ""
                print(f"{metric}: pre slope = {m_pre*7:.4f} per week, post slope = {m_post*7:.4f} per week{note}")

            # aesthetics + labels (after the data so the draw order is unchanged)
            _decorate(ax, metric, first_col)

    plt.tight_layout()
    plt.show()

In [None]:
# Per-flare grid for the regular flares: with method="AND" a flare is kept only
# when both its pre and post windows hold at least `min_days_threshold` rows.
plot_per_flare(
    summary_regular_flares,
    objective,
    metrics=["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"],
    pre_days=45,
    post_days=45,
    min_days_threshold=10,
    method="AND",
    normalize_spans=True,
)

In [None]:
# Per-flare grid for the symptom flares: with method="AND" a flare is kept only
# when both its pre and post windows hold at least `min_days_threshold` rows.
plot_per_flare(
    summary_symptom_flares,
    objective,
    metrics=["REM_pct", "deep_pct", "light_pct", "sleep_eff", "dur_asleep"],
    pre_days=45,
    post_days=45,
    min_days_threshold=10,
    method="AND",
    normalize_spans=True,
)