In [4]:
import os
import warnings
import random
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test
from lifelines.exceptions import StatisticalWarning

# Seed Python's built-in `random` module for reproducibility.
# NOTE: the bootstrap inside combined_km_in_io_two_panel does not draw from this
# generator; it builds its own np.random.default_rng from that function's `seed`
# argument, so this call has no effect on the bootstrap resamples.
random.seed(42)


def compute_common_support_and_iptw(df, ps_col="IO_prediction", treat_col="PX_on_IO",
                                   trim_common_support=True, trunc_pct=(1, 99), eps=1e-6):
    """
    Compute stabilized, percentile-truncated IPTW weights.

    Steps:
      - clip the propensity score to [eps, 1 - eps] so no weight divides by zero
      - optional common-support trimming: keep only rows whose clipped propensity
        lies between the larger of the two arm minima and the smaller of the two
        arm maxima (arms are treat_col == 1 and treat_col == 0)
      - stabilized weights: mean(treat_col) / ps for rows with treat_col == 1,
        (1 - mean(treat_col)) / (1 - ps) for every other row; the mean is taken
        after trimming
      - truncation: clip the weights to the `trunc_pct` percentiles of the weights

    Parameters:
      df                  : input DataFrame; it is copied, never modified.
      ps_col              : column holding the propensity score.
      treat_col           : column holding the treatment indicator (1 / 0).
      trim_common_support : whether to apply the common-support trimming step.
      trunc_pct           : (lower, upper) percentiles passed to np.percentile.
      eps                 : clipping margin applied to the propensity score.

    Returns: df copy (restricted to the common support when trimming is on)
             with an added 'IPTW' column.

    Raises:
      ValueError: if trimming is requested and one arm has no rows, or if no
                  rows are left to weight (empty input or no common support).
    """
    out = df.copy()

    ps_raw = out[ps_col].clip(eps, 1 - eps)

    if trim_common_support:
        is_treated = out[treat_col] == 1
        is_control = out[treat_col] == 0
        # With an empty arm the min/max below would be NaN, the mask would drop
        # every row, and the failure would only surface later in np.percentile.
        if not is_treated.any() or not is_control.any():
            raise ValueError(
                "Common-support trimming needs at least one treated and one control row "
                f"(treated={int(is_treated.sum())}, control={int(is_control.sum())})."
            )
        ps_t = ps_raw[is_treated]
        ps_c = ps_raw[is_control]
        lower = max(ps_t.min(), ps_c.min())
        upper = min(ps_t.max(), ps_c.max())
        out = out[(ps_raw >= lower) & (ps_raw <= upper)].copy()
        # Re-derive the clipped propensity so it stays aligned with the kept rows.
        ps_raw = out[ps_col].clip(eps, 1 - eps)

    if out.empty:
        raise ValueError(
            "No rows left to weight: the input is empty or the treated and control "
            "propensity ranges do not overlap."
        )

    p_treated = out[treat_col].mean()
    p_control = 1 - p_treated

    w = np.where(out[treat_col] == 1, p_treated / ps_raw, p_control / (1 - ps_raw))
    lo, hi = np.percentile(w, trunc_pct)
    out["IPTW"] = np.clip(w, lo, hi)

    return out


def combined_km_in_io_two_panel(df, marker, duration_col="tt_death", event_col="death",
                                treat_col="PX_on_IO", weight_col="IPTW",
                                marker_pos_label="Marker+", marker_neg_label="Marker-",
                                title=None,
                                # adjusted (bootstrap) options
                                renormalize_weights=True, n_boot=250, grid_points=200, seed=42,
                                show_adjusted_ci=True, adjusted_ci_alpha=0.20,
                                show_adjusted_naive_line=False,
                                # unadjusted options
                                show_unadjusted_ci=True, unadjusted_ci_alpha=0.15,
                                show_unadjusted_naive_points=False,
                                savefig=False, output_path=None):
    """
    One figure per marker with TWO PANELS, both fitted within treat_col == 1:
      - Left: IPTW-weighted KM curves. The plotted curve is the pointwise median
        over `n_boot` bootstrap refits; the band is the pointwise 2.5%-97.5%
        bootstrap quantile range.
      - Right: unweighted KM curves with the confidence band drawn by lifelines.

    Rows with a missing duration, event, marker or weight are dropped. Curves are
    split by marker == 1 vs marker == 0; rows with any other marker value are ignored.

    Parameters:
      df                            : DataFrame holding all columns named below.
      marker                        : name of the 0/1 marker column to split on.
      duration_col, event_col       : survival time and event indicator columns.
      treat_col                     : treatment indicator; only rows equal to 1 are used.
      weight_col                    : column with the IPTW weights.
      marker_pos_label / _neg_label : legend labels for the two marker groups.
      title                         : figure title; defaults to one built from `marker`.
      renormalize_weights           : rescale weights within each marker group so they
                                      sum to the group size (adjusted curves only).
      n_boot                        : number of bootstrap resamples per group.
      grid_points                   : number of time points (0 .. max duration) at which
                                      the adjusted curves are evaluated.
      seed                          : base seed; each group uses seed + a fixed offset.
      show_adjusted_ci / adjusted_ci_alpha     : draw the bootstrap band / its opacity.
      show_adjusted_naive_line      : overlay the single-fit weighted KM curve.
      show_unadjusted_ci / unadjusted_ci_alpha : draw the lifelines band / its opacity.
      show_unadjusted_naive_points  : overlay the unweighted KM estimate as points.
      savefig, output_path          : when both are set, save the figure to
                                      `output_path` and close it.

    Returns: fig, (ax_adj, ax_unadj), info_dict

    Raises:
      ValueError: if either marker subgroup is empty after filtering.
    """
    d = df.loc[df[treat_col] == 1, [duration_col, event_col, marker, weight_col]].dropna().copy()
    d_pos = d[d[marker] == 1].copy()
    d_neg = d[d[marker] == 0].copy()

    # Fail fast, before any test is run or figure is opened: an empty subgroup
    # cannot be fitted, and raising later would leave an unclosed figure behind.
    if d_pos.empty or d_neg.empty:
        raise ValueError(
            f"Empty subgroup for bootstrap KM. marker={marker!r}: "
            f"n_pos={len(d_pos)}, n_neg={len(d_neg)}"
        )

    # p-values shown in the panel titles.
    # NOTE(review): the weighted test uses the weights as supplied, i.e. before the
    # optional per-group renormalization applied to the adjusted curves below -
    # confirm this is intended.
    lr_w = logrank_test(
        d_pos[duration_col], d_neg[duration_col],
        event_observed_A=d_pos[event_col],
        event_observed_B=d_neg[event_col],
        weights_A=d_pos[weight_col],
        weights_B=d_neg[weight_col],
    )
    lr_u = logrank_test(
        d_pos[duration_col], d_neg[duration_col],
        event_observed_A=d_pos[event_col],
        event_observed_B=d_neg[event_col],
    )

    # figure with two subplots
    fig, (ax_adj, ax_unadj) = plt.subplots(1, 2, figsize=(14, 5.5), sharey=True)

    # -------------------------
    # Unadjusted panel (right)
    # -------------------------
    km_u = KaplanMeierFitter()
    for sub, base_label in ((d_pos, marker_pos_label), (d_neg, marker_neg_label)):
        km_u.fit(
            sub[duration_col], sub[event_col],
            label=f"{base_label} (n={len(sub)})"
        )
        km_u.plot_survival_function(
            ax=ax_unadj, ci_show=show_unadjusted_ci, ci_alpha=unadjusted_ci_alpha, at_risk_counts=False
        )
        if show_unadjusted_naive_points:
            # survival_function_ has a single column (named after the label), so
            # take it by position instead of reading the fitter's private _label.
            surv = km_u.survival_function_
            ax_unadj.scatter(surv.index.values, surv.iloc[:, 0].values, s=8)

    ax_unadj.set_title(f"Unadjusted KM\nlog-rank p={lr_u.p_value:.3g}")
    ax_unadj.set_xlabel("Time")
    ax_unadj.legend()

    # -------------------------
    # Adjusted panel (left)
    # -------------------------
    if renormalize_weights:
        # d_pos / d_neg are private copies, so rescaling in place is safe.
        for sub in (d_pos, d_neg):
            s = sub[weight_col].sum()
            if s > 0:
                sub[weight_col] = sub[weight_col] * (len(sub) / s)

    # Common time grid so both groups (and all resamples) are evaluated at the same times.
    tmax = float(np.nanmax(d[duration_col].values)) if len(d) else 0.0
    grid = np.linspace(0, tmax, grid_points) if tmax > 0 else np.array([0.0])

    def _weighted_km_on_grid(km, frame):
        # Fit a weighted KM on `frame` and return its survival values on `grid`,
        # silencing lifelines' StatisticalWarning for this fit only.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=StatisticalWarning)
            km.fit(frame[duration_col], frame[event_col], weights=frame[weight_col])
            return km.survival_function_at_times(grid).values

    def _bootstrap_group(df_g, seed_offset):
        # Returns (single-fit curve, bootstrap median, 2.5% quantile, 97.5% quantile),
        # each evaluated on `grid`. Groups are guaranteed non-empty by the check above.
        rng = np.random.default_rng(seed + seed_offset)
        n = len(df_g)
        km = KaplanMeierFitter()

        # naive single-fit weighted curve (optional overlay)
        naive = _weighted_km_on_grid(km, df_g)

        surv_mat = np.empty((n_boot, len(grid)), dtype=float)
        for b in range(n_boot):
            # resample rows with replacement, keeping each row's weight
            idx = rng.integers(0, n, size=n)
            surv_mat[b, :] = _weighted_km_on_grid(km, df_g.iloc[idx])

        med = np.median(surv_mat, axis=0)
        lo = np.quantile(surv_mat, 0.025, axis=0)
        hi = np.quantile(surv_mat, 0.975, axis=0)
        return naive, med, lo, hi

    naive_pos, med_pos, lo_pos, hi_pos = _bootstrap_group(d_pos, seed_offset=1)
    naive_neg, med_neg, lo_neg, hi_neg = _bootstrap_group(d_neg, seed_offset=2)

    ax_adj.plot(grid, med_pos, linewidth=2, label=f"{marker_pos_label} (IPTW)")
    if show_adjusted_ci:
        ax_adj.fill_between(grid, lo_pos, hi_pos, alpha=adjusted_ci_alpha)

    ax_adj.plot(grid, med_neg, linewidth=2, label=f"{marker_neg_label} (IPTW)")
    if show_adjusted_ci:
        ax_adj.fill_between(grid, lo_neg, hi_neg, alpha=adjusted_ci_alpha)

    if show_adjusted_naive_line:
        ax_adj.plot(grid, naive_pos, linestyle=":", linewidth=1, label=f"{marker_pos_label} naive")
        ax_adj.plot(grid, naive_neg, linestyle=":", linewidth=1, label=f"{marker_neg_label} naive")

    ax_adj.set_title(f"IPTW-adjusted KM (bootstrap CI)\nweighted log-rank p={lr_w.p_value:.3g}")
    ax_adj.set_xlabel("Time")
    ax_adj.set_ylabel("Survival probability")
    ax_adj.legend()

    # overall title
    if title is None:
        title = f"{marker}: Adjusted vs Unadjusted KM (IO cohort)"
    fig.suptitle(title, y=1.02)

    fig.tight_layout()

    if savefig and (output_path is not None):
        fig.savefig(output_path, bbox_inches="tight")
        # Close after saving so batch runs do not accumulate open figures.
        plt.close(fig)

    info = {
        "marker": marker,
        "n_io_pos": int(len(d_pos)),
        "n_io_neg": int(len(d_neg)),
        "p_weighted_logrank": float(lr_w.p_value),
        "p_unweighted_logrank": float(lr_u.p_value),
        "n_boot": int(n_boot),
        "grid_points": int(grid_points),
    }
    return fig, (ax_adj, ax_unadj), info


# Paths: project-rooted input (data) and output (figure) directories.
PROJ_PATH = '/data/gusev/USERS/jpconnor/clinical_text_project/'
DATA_PATH = os.path.join(PROJ_PATH, 'data/')
NOTES_PATH = os.path.join(DATA_PATH, 'batched_datasets/processed_datasets/')
MARKER_PATH = os.path.join(DATA_PATH, 'biomarker_analysis/')
IPTW_RESULTS_PATH = os.path.join(MARKER_PATH, 'IPTW_runs/')
FIGURE_PATH = os.path.join(PROJ_PATH, 'figures/')
MARKER_FIG_PATH = os.path.join(FIGURE_PATH, 'biomarker_analysis/')
IPTW_FIG_PATH = os.path.join(MARKER_FIG_PATH, 'IPTW_figures/')
KM_FIG_PATH = os.path.join(IPTW_FIG_PATH, 'KM_curves/')

# Patient-level table used for every cohort below.
interaction_IO_df = pd.read_csv(os.path.join(MARKER_PATH, 'IPTW_IO_interaction_runs_df.csv'))

cancer_types = ['pan_cancer', 'LUNG', 'SKIN']

for cancer_type in cancer_types:
    print(f'Starting cancer type {cancer_type}')

    # One output folder per effect direction, created on demand.
    TYPE_PATH = os.path.join(KM_FIG_PATH, cancer_type)
    HARM_PATH = os.path.join(TYPE_PATH, 'predicted_IO_harm/')
    BENEFIT_PATH = os.path.join(TYPE_PATH, 'predicted_IO_benefit/')
    for out_dir in (HARM_PATH, BENEFIT_PATH):
        os.makedirs(out_dir, exist_ok=True)

    marker_df = pd.read_csv(os.path.join(IPTW_RESULTS_PATH, f'{cancer_type}_IPTW_IO_predictive_markers.csv'))

    # Keep significant markers, then split on the interaction coefficient:
    # above 0.5 goes to the "harm" folder, below -0.5 to the "benefit" folder.
    IO_pred_marker_df = marker_df.query("significant_predictive")
    IO_pred_marker_df = IO_pred_marker_df.sort_values("beta_marker_IO", ascending=True)
    beta = IO_pred_marker_df['beta_marker_IO']
    markers_w_IO_harm = IO_pred_marker_df.loc[beta > 0.5]
    markers_w_IO_benefit = IO_pred_marker_df.loc[beta < -0.5]

    # Select the cohort: everyone for pan_cancer, otherwise the rows flagged by the
    # matching CANCER_TYPE_* column (assumes that column is boolean - TODO confirm).
    if cancer_type == 'pan_cancer':
        cohort_df = interaction_IO_df
    else:
        cohort_df = interaction_IO_df.loc[interaction_IO_df[f'CANCER_TYPE_{cancer_type}']]

    # compute IPTW on the selected cohort (common support + stabilized + truncation)
    df_ct_w = compute_common_support_and_iptw(
        cohort_df,
        ps_col="IO_prediction", treat_col="PX_on_IO",
        trim_common_support=True, trunc_pct=(1, 99)
    )

    # HARM markers first, then BENEFIT markers; one saved figure per marker.
    for selected, out_dir in ((markers_w_IO_harm, HARM_PATH), (markers_w_IO_benefit, BENEFIT_PATH)):
        if len(selected) == 0:
            continue  # nothing to plot (and no empty progress bar)
        for m in tqdm(selected['marker'].tolist()):
            fig, axes, info = combined_km_in_io_two_panel(
                df_ct_w, marker=m,
                duration_col="tt_death", event_col="death",
                treat_col="PX_on_IO", weight_col="IPTW",
                title=f"{cancer_type}: {m}",
                savefig=True,
                output_path=os.path.join(out_dir, f'{m}_combined_KM_curve.png')
            )


Starting cancer type pan_cancer


100%|██████████| 35/35 [02:28<00:00,  4.23s/it]
100%|██████████| 14/14 [00:58<00:00,  4.18s/it]


Starting cancer type LUNG


100%|██████████| 25/25 [01:44<00:00,  4.18s/it]
100%|██████████| 41/41 [02:58<00:00,  4.36s/it]


Starting cancer type SKIN


100%|██████████| 32/32 [02:09<00:00,  4.06s/it]
100%|██████████| 59/59 [03:59<00:00,  4.07s/it]
