# Code Implementation

## 0. Package Import

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib import patches
from lifelines import KaplanMeierFitter
from tqdm import tqdm
import ast
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from lifelines import CoxPHFitter
from numpy import trapz
from collections import Counter
import statsmodels.api as sm
from scipy.stats import pearsonr, spearmanr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
import os 

## 1. Mean Cumulative Count Calculation

In [None]:
#!/usr/bin/env python3
import argparse
import numpy as np
import pandas as pd
from typing import List, Tuple, Optional
import os
from tqdm import tqdm
from multiprocessing import Pool, cpu_count

# ---------------- Configuration ----------------
# Condition columns in the input CSV. The __main__ block parses each one as a
# date and converts it to age at onset (event year minus "yob"); every
# non-missing value is then treated as one event when the MCC is computed.
COND_LIST = ['Anx','Dep','SMI','Ast','COPD','Diab','Hyp','CHD','StroTIA','AF','HF','PAD','CKD','Dem','Park','Ost','RA','Can']
# Index-age bands as [lo, hi] pairs. assign_age_band treats both ends as
# inclusive and labels ages that fall outside every band as NaN.
AGE_BANDS = [[20,29],[30,39],[40,49],[50,59],[60,69],[70,79]]

# ---------------- Core MCC on grid ----------------
def compute_mcc_curve_yearscale(events_times, enter_age, exit_age, t_grid):
    """Mean cumulative count (MCC) of events, evaluated on a time grid.

    At each distinct event time t_j the increment is d_j / Y_j, where d_j is
    the number of events at t_j and
    Y_j = #(enter_age <= t_j) - #(exit_age < t_j) is the risk-set size
    (the subjects with entry <= t_j <= exit, assuming no subject's entry
    exceeds its exit). Event times with Y_j <= 0 contribute nothing. The MCC
    is the cumulative sum of these increments, read off on ``t_grid`` as a
    right-continuous step function (0 before the first event).

    Parameters
    ----------
    events_times : array-like of float
        Time (age) of every event, one entry per event. NaNs are ignored.
    enter_age, exit_age : array-like of float
        Entry and exit time of every subject in the cohort.
    t_grid : array-like of float
        Times at which the curve is evaluated.

    Returns
    -------
    pandas.DataFrame
        Columns ``"t"`` (the grid) and ``"mcc"``.
    """
    events_times = np.asarray(events_times, float)
    enter_age = np.asarray(enter_age, float)
    exit_age = np.asarray(exit_age, float)
    t_grid = np.asarray(t_grid, float)

    # Drop missing event times first so an all-NaN input behaves like an
    # empty one (otherwise the cumulative array below would be empty and
    # indexing into it would fail).
    events_times = events_times[~np.isnan(events_times)]
    if events_times.size == 0:
        return pd.DataFrame({"t": t_grid, "mcc": np.zeros_like(t_grid, float)})

    # Distinct event times (sorted) and number of events at each.
    ev_times, counts = np.unique(events_times, return_counts=True)
    dj = counts.astype(float)

    # Risk set at each event time: entered at or before t_j minus exited strictly before t_j.
    enter_sorted = np.sort(enter_age)
    exit_sorted = np.sort(exit_age)
    Yj = (np.searchsorted(enter_sorted, ev_times, side="right")
          - np.searchsorted(exit_sorted, ev_times, side="left"))

    inc = np.zeros_like(ev_times)
    valid = Yj > 0
    inc[valid] = dj[valid] / Yj[valid]
    mcc_at_ev = np.cumsum(inc)

    # Index of the last event time <= each grid point; -1 means "no event yet".
    idx = np.searchsorted(ev_times, t_grid, side="right") - 1
    mcc_grid = np.zeros_like(t_grid)
    has_event = idx >= 0
    mcc_grid[has_event] = mcc_at_ev[idx[has_event]]

    return pd.DataFrame({"t": t_grid, "mcc": mcc_grid})

# ---------------- Helpers ----------------
def assign_age_band(df, bands, index_age_col="1st_age", band_label="1st_age_band"):
    """Return a copy of ``df`` with a column naming each row's age band.

    Each value of ``index_age_col`` is matched against ``bands`` (a sequence of
    ``(lo, hi)`` pairs, both ends inclusive). The first matching pair is stored
    as a tuple in ``band_label``. Missing ages, and ages outside every band,
    get ``np.nan``. The input frame is left unchanged.
    """
    def _band_of(age):
        if pd.isna(age):
            return np.nan
        return next(((lo, hi) for lo, hi in bands if lo <= age <= hi), np.nan)

    labelled = df.copy()
    labelled[band_label] = labelled[index_age_col].apply(_band_of)
    return labelled

def riskset_over_grid(enter_age, exit_age, t_grid):
    """Number of subjects at risk at each grid time.

    Computed as (#entries at or before t) minus (#exits strictly before t).
    Assuming each subject's entry does not exceed its exit, this counts the
    subjects with entry <= t <= exit. Returns an integer array aligned with
    ``t_grid``.
    """
    entries = np.sort(np.asarray(enter_age, float))
    exits = np.sort(np.asarray(exit_age, float))
    n_entered = np.searchsorted(entries, t_grid, side="right")
    n_exited = np.searchsorted(exits, t_grid, side="left")
    return n_entered - n_exited

# ---------------- Landmark per person ----------------
def landmark_per_person(sub_people, time_scale, index_age_col, band_col="1st_age_band"):
    """Return each person's landmark time, which is currently always 0.0.

    ``time_scale``, ``index_age_col`` and ``band_col`` are accepted for
    interface compatibility but are not used. The result is a float Series
    aligned with ``sub_people.index``.
    """
    return pd.Series(0.0, index=sub_people.index, dtype=float)

# ---------------- Events long table ----------------
def build_events_long(df, cond_cols, id_col, time_scale, landmark_col_name, index_cond=None):
    """Reshape per-person condition columns into one row per recorded event.

    Each non-missing value in ``cond_cols`` becomes one row. ``event_time`` is
    that value as a float, and ``condition`` is the column it came from.
    Returns a frame with columns ``[id_col, "condition", "event_time"]``.
    ``index_cond`` is accepted but unused.

    Raises ValueError if ``time_scale`` is not ``"attained_age"``.
    """
    long_df = df.melt(
        id_vars=[id_col, landmark_col_name],
        value_vars=cond_cols,
        var_name="condition",
        value_name="onset_age",
    )
    events = long_df.dropna(subset=["onset_age"]).copy()
    if time_scale != "attained_age":
        raise ValueError("time_scale must be 'attained_age'.")
    events["event_time"] = events["onset_age"].astype(float)
    return events[[id_col, "condition", "event_time"]]

# ---------------- One bootstrap replicate ----------------
def bootstrap_replicate(args):
    """Compute one bootstrap replicate of a stratum's MCC curve.

    ``args`` is a single tuple so the function can be mapped with
    ``Pool.imap``: (sub_people, cond_cols, t_grid, id_col, exit_col,
    time_scale, index_age_col, seed). ``seed`` is passed to
    ``np.random.default_rng``.

    Patients (distinct values of ``id_col``) are drawn with replacement. Every
    row of a patient drawn k times appears k times in the replicate, so both
    the patient's events and their time at risk are counted k times.

    Returns the replicate's MCC values on ``t_grid`` as a NumPy array.
    """
    sub_people, cond_cols, t_grid, id_col, exit_col, time_scale, index_age_col, seed = args
    rng = np.random.default_rng(seed)

    # Resample IDs with replacement.
    ids = sub_people[id_col].unique()
    samp_ids = rng.choice(ids, size=len(ids), replace=True)

    # Expand rows by how many times their ID was drawn. Filtering with
    # ``isin(samp_ids)`` would keep each drawn ID only once, i.e. a ~63%
    # subsample without replacement rather than a bootstrap sample.
    draws_per_id = pd.Series(samp_ids).value_counts()
    n_copies = sub_people[id_col].map(draws_per_id).fillna(0).astype(int).to_numpy()
    row_pos = np.repeat(np.arange(len(sub_people)), n_copies)
    samp = sub_people.iloc[row_pos].reset_index(drop=True)

    # Landmark
    lm = landmark_per_person(samp, time_scale, index_age_col, "1st_age_band")
    samp["__lm__"] = lm.values

    ev_b = build_events_long(samp, cond_cols, id_col=id_col, time_scale=time_scale, landmark_col_name="__lm__", index_cond=sub_people['1st_cond'].values[0])
    # Everyone enters at time 0 and leaves at their exit age.
    enter_b = np.zeros(len(samp))
    exit_b = samp[exit_col].astype(float).values

    return compute_mcc_curve_yearscale(ev_b["event_time"].values, enter_b, exit_b, t_grid)["mcc"].values

# ---------------- MCC for one stratum with CI ----------------
def mcc_for_stratum(sub_people, cond_cols, t_grid, id_col, exit_col, time_scale, index_age_col, min_alive=1000, n_boot=100, n_jobs=None, seed=None):
    """Compute the MCC curve and a percentile-bootstrap CI for one stratum.

    Every person enters at time 0 and leaves at ``sub_people[exit_col]``.
    Every non-missing value in ``cond_cols`` counts as an event. The grid is
    truncated at the last point where at least ``min_alive`` people are at
    risk.

    Parameters
    ----------
    sub_people : DataFrame
        One stratum's people. Must contain ``id_col``, ``exit_col``,
        ``cond_cols`` and a ``1st_cond`` column.
    t_grid : array-like of float
        Evaluation grid.
    min_alive : int
        Minimum risk-set size (``Y >= min_alive``) required for a grid point
        to be kept.
    n_boot : int
        Number of bootstrap replicates. 0 skips the bootstrap, and
        ``mcc_lo``/``mcc_hi`` are then NaN.
    n_jobs : int or None
        Worker processes. None or 0 means ``cpu_count()``. 1 runs the
        replicates in this process without creating a Pool.
    seed : None, int or SeedSequence
        Entropy for ``np.random.SeedSequence``. None (default) draws fresh
        entropy, so CIs differ between runs; pass an int for reproducible CIs.

    Returns
    -------
    DataFrame or None
        Columns ``t``, ``mcc``, ``mcc_lo``, ``mcc_hi`` (the 2.5th and 97.5th
        percentiles of the replicates), or None when no grid point has at
        least ``min_alive`` people at risk.
    """
    t_grid = np.asarray(t_grid, dtype=float)
    lm = landmark_per_person(sub_people, time_scale, index_age_col, "1st_age_band")
    tmp = sub_people.copy()
    tmp["__lm__"] = lm.values

    ev_long = build_events_long(tmp, cond_cols, id_col=id_col, time_scale=time_scale, landmark_col_name="__lm__", index_cond=sub_people['1st_cond'].values[0])
    enter = np.zeros(len(sub_people), dtype=float)
    exit_ = sub_people[exit_col].astype(float).values

    # Keep grid points up to the last one where Y >= min_alive.
    Y = riskset_over_grid(enter, exit_, t_grid)
    valid_idx = np.flatnonzero(Y >= min_alive)
    if valid_idx.size == 0:
        return None
    tau_star = t_grid[valid_idx[-1]]
    grid_trunc = t_grid[t_grid <= tau_star]

    point = compute_mcc_curve_yearscale(ev_long["event_time"].values, enter, exit_, grid_trunc)["mcc"].values

    if n_boot > 0:
        # One child SeedSequence per replicate, spawned from a single parent,
        # so each replicate gets its own random stream.
        seeds = np.random.SeedSequence(seed).spawn(n_boot)
        boot_args = [
            (sub_people, cond_cols, grid_trunc, id_col, exit_col, time_scale, index_age_col, child)
            for child in seeds
        ]
        n_jobs = n_jobs or cpu_count()
        if n_jobs == 1:
            replicates = list(tqdm(map(bootstrap_replicate, boot_args), total=n_boot, desc="Bootstrapping"))
        else:
            with Pool(processes=n_jobs) as pool:
                replicates = list(tqdm(pool.imap(bootstrap_replicate, boot_args), total=n_boot, desc="Bootstrapping"))
        B = np.vstack(replicates)
        lo = np.percentile(B, 2.5, axis=0)
        hi = np.percentile(B, 97.5, axis=0)
    else:
        lo = np.full_like(point, np.nan)
        hi = np.full_like(point, np.nan)
    return pd.DataFrame({"t": grid_trunc, "mcc": point, "mcc_lo": lo, "mcc_hi": hi})

# ---------------- Loop over strata ----------------
def mcc_by_strata(df, cond_cols, t_grid, strata, id_col, exit_col, time_scale, index_age_col, min_alive=1000, n_boot=100, n_jobs=None):
    """Run ``mcc_for_stratum`` for every group of every stratification.

    ``strata`` is a list of column-name tuples. For each tuple, ``df`` is
    grouped by those columns (NaN keys form their own group), and the MCC
    curve with bootstrap CI is computed per group. Groups for which
    ``mcc_for_stratum`` returns None are skipped. The stratum columns are
    added to each result; tuple/list key values (e.g. age bands) are stored
    as their ``str`` form, such as "(20, 29)".

    Returns one concatenated DataFrame. If no group produced a result, returns
    an empty DataFrame with columns t, mcc, mcc_lo, mcc_hi and the stratum
    columns, rather than letting ``pd.concat`` fail on an empty list.
    """
    out_rows = []
    for keys in strata:
        for key_vals, sub in tqdm(df.groupby(list(keys), dropna=False, sort=True)):
            # tqdm.write prints above the progress bar instead of breaking it.
            tqdm.write(f"{key_vals} {sub.shape[0]}")
            # Older pandas yields scalar keys for a single grouping column.
            if not isinstance(key_vals, tuple):
                key_vals = (key_vals,)
            res = mcc_for_stratum(
                sub_people=sub, cond_cols=cond_cols, t_grid=t_grid,
                id_col=id_col, exit_col=exit_col,
                time_scale=time_scale, index_age_col=index_age_col,
                min_alive=min_alive, n_boot=n_boot, n_jobs=n_jobs
            )
            if res is None:
                continue
            for k, v in zip(keys, key_vals):
                res[k] = str(v) if isinstance(v, (tuple, list)) else v
            out_rows.append(res)
    if not out_rows:
        key_cols = list(dict.fromkeys(col for keys in strata for col in keys))
        return pd.DataFrame(columns=["t", "mcc", "mcc_lo", "mcc_hi", *key_cols])
    return pd.concat(out_rows, axis=0, ignore_index=True)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MCC with CI (multiprocessing, landmark-corrected)")
    parser.add_argument("--input",  default="./data/df_patient_20_80.csv", help='Input file path')
    parser.add_argument("--output", default="./results/", help='Output directory path')
    parser.add_argument("--id_col", default="patid", help='Patient ID column name')
    parser.add_argument("--index_age_col", default="1st_age", help='Index age column name')
    parser.add_argument("--time_scale", default="attained_age", choices=["attained_age"], help='Time scale for analysis')
    parser.add_argument("--min_alive", type=int, default=1000, help='Minimum number of people at risk required to keep a grid point')
    parser.add_argument("--n_boot", type=int, default=100, help='Number of bootstrap replicates')
    parser.add_argument("--max_t", type=int, default=100, help='Maximum age for analysis')
    parser.add_argument("--n_jobs", type=int, default=None, help="Processes for multiprocessing (default: all cores)")
    # parse_known_args rather than parse_args: when this cell runs inside a
    # Jupyter kernel, sys.argv carries the kernel's own arguments, which
    # parse_args would reject with SystemExit. Later cells also rely on `args`.
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        print(f"Ignoring unrecognised arguments: {unknown_args}")

    df = pd.read_csv(args.input)
    cond_cols = COND_LIST
    # Convert each condition date to age at event (event year minus year of
    # birth). Whole columns are replaced (not written via .loc in place) so
    # the date-string columns become numeric.
    onset_year = df[cond_cols].apply(pd.to_datetime, errors="coerce").apply(lambda col: col.dt.year)
    df[cond_cols] = onset_year.sub(df["yob"], axis=0)

    df_ext = assign_age_band(df, AGE_BANDS, index_age_col=args.index_age_col, band_label="1st_age_band")  # Determine the index age band of each patient
    t_grid = np.arange(0, args.max_t + 1, dtype=float)
    strata = [("gender", "1st_cond", "1st_age_band")]
    df_ext = df_ext[df_ext['1st_cond'] != 'Health']  # Remove healthy people by end of 2019
    df_mcc = mcc_by_strata(
        df=df_ext, cond_cols=cond_cols, t_grid=t_grid, strata=strata,
        id_col=args.id_col, exit_col="exit_age",
        time_scale=args.time_scale, index_age_col=args.index_age_col,
        min_alive=args.min_alive, n_boot=args.n_boot, n_jobs=args.n_jobs
    )  # Calculate MCC and confidence interval
    os.makedirs(args.output, exist_ok=True)
    out_path = os.path.join(args.output, 'mcc_' + args.time_scale + '_ci.csv')
    df_mcc.to_csv(out_path, index=False)
    print("Saved!")

## 2. Power-Law Model Fitting

In [None]:
def power_law(t, alpha, beta):
    """Evaluate the power law ``alpha * t**beta`` (element-wise for arrays)."""
    powered = t ** beta
    return alpha * powered
    
# --------------------------
# f1: Fit power law for one stratum (with bootstrap)
# --------------------------
def fit_power_law_stratum(df_stratum, band_str, n_boot=100, least_fit_pts=5, random_state=42):
    """Fit a power law to the part of one stratum's MCC curve after its age band.

    The curve (columns ``t`` and ``mcc``) is sorted by ``t`` and restricted to
    points with ``t`` strictly greater than the band's upper bound. It is then
    re-based on the first retained point: x = t - t_first, y = mcc - mcc_first.
    The origin point (0, 0) is dropped, and f(x) = alpha * x**beta is fitted
    with ``scipy.optimize.curve_fit``. Percentile (2.5/97.5) bootstrap CIs for
    alpha and beta come from refitting on points resampled with replacement.

    Parameters
    ----------
    df_stratum : DataFrame
        MCC curve of one stratum, with columns ``t`` and ``mcc``.
    band_str : str or tuple
        Age band label such as "(20, 29)", or the ``(lo, hi)`` pair itself.
    n_boot : int
        Number of bootstrap refits. Refits that fail are skipped.
    least_fit_pts : int
        Minimum number of post-band points required before fitting.
    random_state : int
        Seed for the bootstrap resampling.

    Returns
    -------
    dict or None
        Keys alpha, beta, their standard errors (from the fit covariance), CI
        bounds, n_points and the ``evaluate_fit`` metrics. Returns None when
        the band cannot be parsed, there are too few points, or the main fit
        fails.
    """
    rng = np.random.default_rng(random_state)
    # literal_eval parses the "(lo, hi)" label without executing arbitrary code (unlike eval).
    band = ast.literal_eval(band_str) if isinstance(band_str, str) else band_str
    if not isinstance(band, (tuple, list)) or len(band) < 2:
        return None
    t0 = float(band[1])

    # Subset points AFTER band upper bound (sorted so the first point is the earliest).
    df_fit = df_stratum.sort_values("t")
    df_fit = df_fit[df_fit["t"] > t0]
    if len(df_fit) < least_fit_pts:
        return None
    t_vals = df_fit["t"].to_numpy(dtype=float)
    mcc_vals = df_fit["mcc"].to_numpy(dtype=float)
    x = (t_vals - t_vals[0])[1:]
    y = (mcc_vals - mcc_vals[0])[1:]
    # Two free parameters need at least two data points.
    if len(x) < 2:
        return None

    # Initial guesses: alpha from the first re-based value, beta slightly super-linear.
    p0 = [y[0], 1.25]
    try:
        popt, pcov = curve_fit(power_law, x, y, p0=p0, maxfev=20000)
    except (RuntimeError, ValueError):
        # RuntimeError: no convergence; ValueError: non-finite data.
        return None
    alpha, beta = popt
    perr = np.sqrt(np.diag(pcov))  # standard errors

    # Predictions and error metrics
    y_pred = power_law(x, alpha, beta)
    metrics = evaluate_fit(y, y_pred)

    # Bootstrap for CI: resample (x, y) pairs and refit from the point estimate.
    alpha_boot, beta_boot = [], []
    for _rep in range(n_boot):
        idx = rng.choice(len(x), size=len(x), replace=True)
        try:
            popt_b, _ = curve_fit(power_law, x[idx], y[idx], p0=[alpha, beta], maxfev=10000)
        except (RuntimeError, ValueError):
            continue
        alpha_boot.append(popt_b[0])
        beta_boot.append(popt_b[1])
    if alpha_boot:
        alpha_ci = (np.percentile(alpha_boot, 2.5), np.percentile(alpha_boot, 97.5))
        beta_ci = (np.percentile(beta_boot, 2.5), np.percentile(beta_boot, 97.5))
    else:
        alpha_ci = (np.nan, np.nan)
        beta_ci = (np.nan, np.nan)
    return {"alpha": alpha, "beta": beta, "alpha_se": perr[0], "beta_se": perr[1], "alpha_ci_low": alpha_ci[0], "alpha_ci_high": alpha_ci[1],
            "beta_ci_low": beta_ci[0], "beta_ci_high": beta_ci[1], "n_points": len(x), **metrics}

# --------------------------
# f2: Error metrics
# --------------------------
def evaluate_fit(y_true, y_pred):
    """Summarise goodness of fit between observed and predicted values.

    Returns a dict with keys ``"rmse"`` (square root of the sklearn MSE),
    ``"mae"`` and ``"r2"``.
    """
    mse = mean_squared_error(y_true, y_pred)
    return {
        "rmse": np.sqrt(mse),
        "mae": mean_absolute_error(y_true, y_pred),
        "r2": r2_score(y_true, y_pred),
    }

# --------------------------
# Fit across all strata
# --------------------------
def fit_all_strata(df_plot, n_boot=100, imd=None, least_fit_pts=5, age_bands=None):
    """Fit the power-law model to every stratum of an MCC results table.

    ``df_plot`` is grouped by ``gender``, ``1st_cond`` and ``1st_age_band``,
    plus ``imd`` when ``imd`` is not None (only used as a flag). Groups whose
    band label is not in ``age_bands`` are skipped before fitting. Each
    successful ``fit_power_law_stratum`` result gets ``gender``,
    ``condition``, ``band`` (and ``imd``) entries.

    Parameters
    ----------
    age_bands : sequence of str, optional
        Band labels to keep. Defaults to the six labels "(20, 29)" ...
        "(70, 79)". The default is local, so this function no longer depends
        on a global ``age_bands`` defined in a later cell.

    Returns
    -------
    DataFrame with one row per fitted stratum (empty if none were fitted).
    """
    if age_bands is None:
        age_bands = ('(20, 29)', '(30, 39)', '(40, 49)', '(50, 59)', '(60, 69)', '(70, 79)')
    group_cols = ["gender", "1st_cond", "1st_age_band"]
    if imd is not None:
        group_cols.append("imd")

    results = []
    for keys, sub in df_plot.groupby(group_cols):
        band = keys[2]
        # Filter first: avoids fits whose result would be discarded anyway.
        if band not in age_bands:
            continue
        res = fit_power_law_stratum(sub, band, n_boot, least_fit_pts)
        if res is None:
            continue
        res["gender"] = keys[0]
        res["condition"] = keys[1]
        res["band"] = band
        if imd is not None:
            res["imd"] = keys[3]
        results.append(res)
    return pd.DataFrame(results)

# Fit the power-law model to the MCC curves written by section 1.
dir_out = './results/'
time_scale = 'attained_age'
n_plot_points = 5
# Use the local `time_scale` (not `args` from the __main__ cell) so this cell runs on its own.
df_plot = pd.read_csv(os.path.join(dir_out, 'mcc_' + time_scale + '_ci.csv'))
# NOTE(review): this filter previously used `cond_date_list`, which is not defined
# anywhere visible; COND_LIST (the condition columns from section 1) is assumed
# to be the intended list - confirm.
df_plot = df_plot[df_plot['1st_cond'].isin(COND_LIST)]
df_results = fit_all_strata(df_plot, n_boot=100, imd=None, least_fit_pts=n_plot_points)
df_results = df_results[['alpha', 'alpha_ci_low', 'alpha_ci_high', 'beta', 'beta_ci_low', 'beta_ci_high', 'rmse', 'mae', 'r2', 'gender', 'condition', 'band']]
df_results = df_results.rename(columns={'condition': '1st_cond', 'band': '1st_age_band'})
# index=False, consistent with the MCC CSV written in section 1.
df_results.to_csv(os.path.join(dir_out, 'mcc_fit_power_law.csv'), index=False)

## 3. Analysis of Socioeconomic Deprivation

In [None]:
from itertools import combinations
from scipy.stats import norm

# Compare fitted power-law parameters between IMD groups 1 and 5 within each
# (gender, condition, age band), using a two-sided z-test on the difference.
df = pd.read_csv('./results/mcc_fit_power_law_imd.csv')
pairs = []
cond_sel_list = ['Anx', 'Dep', 'Ast', 'Diab', 'Hyp', 'CHD', 'Can']
age_bands = ['(20, 29)', '(30, 39)', '(40, 49)', '(50, 59)', '(60, 69)', '(70, 79)']
IMD_I, IMD_J = 1, 5
for g in ['M', 'F']:
    for metric_i in ['alpha', 'beta']:
        for cond_i in cond_sel_list:
            for ag_i in age_bands:
                df_beta = df[(df['gender'] == g) & (df['condition'] == cond_i) & (df['band'] == ag_i)]
                row_i = df_beta[df_beta['imd'] == IMD_I]
                row_j = df_beta[df_beta['imd'] == IMD_J]
                # Both IMD groups must have a fit for this stratum.
                if row_i.empty or row_j.empty:
                    continue
                bi, sei = row_i[metric_i].iloc[0], row_i[metric_i + '_se'].iloc[0]
                bj, sej = row_j[metric_i].iloc[0], row_j[metric_i + '_se'].iloc[0]
                # Combined SE treats the two fits as independent.
                z = (bi - bj) / np.sqrt(sei**2 + sej**2)
                # norm.sf(|z|) equals 1 - cdf(|z|) but avoids cancellation for large |z|.
                p = 2 * norm.sf(abs(z))
                pairs.append((g, metric_i, cond_i, ag_i, IMD_I, IMD_J, bi, bj, z, p))
df_pval = pd.DataFrame(pairs, columns=['gender', 'metric', 'cond', 'age_band', 'IMD_i', 'IMD_j', 'v_i', 'v_j', 'z', 'p'])