In [160]:
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr  
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

decay_param=0.01

# Paths
FIGURE_PATH = '/data/gusev/USERS/jpconnor/figures/clinical_text_embedding_project/model_metrics/'
DATA_PATH = '/data/gusev/USERS/jpconnor/data/clinical_text_embedding_project/'
FEATURE_PATH = os.path.join(DATA_PATH, 'clinical_and_genomic_features/')
SURV_PATH = os.path.join(DATA_PATH, 'time-to-event_analysis/')
RESULTS_PATH = os.path.join(SURV_PATH, 'results/level_3_ICD_results/')
TRAJECTORY_PATH = os.path.join(RESULTS_PATH, 'mortality_trajectories/')

FIGURE_PATH = '/data/gusev/USERS/jpconnor/figures/clinical_text_embedding_project/mortality_trajectories/'
# FEATURE_FIG_PATH = os.path.join(FIGURE_PATH, 'feature_based_clusters/')
FEATURE_FIG_PATH = os.path.join(FIGURE_PATH, f'decay_param_{decay_param}')
STAGE_VS_TRAJECTORY_FIG_PATH = os.path.join(FEATURE_FIG_PATH, 'stage_vs_trajectory/')

os.makedirs(STAGE_VS_TRAJECTORY_FIG_PATH, exist_ok=True)

def gen_clusters(feature_df, feature_cols, clusters_to_test=[i+2 for i in range(18)]):
    inertias = []
    silhouette_scores = []
    cluster_labels = {'DFCI_MRN' : feature_df['DFCI_MRN']}
    for n_clust in tqdm(clusters_to_test):
        km = KMeans(n_clusters=n_clust, random_state=0).fit(feature_df[feature_cols])
        clusters = km.predict(feature_df[feature_cols])
        inertias.append(km.inertia_)
        cluster_labels[f'label_w_{n_clust}_clusters'] = clusters

    cluster_label_df = pd.DataFrame(cluster_labels)
    cluster_label_df[[col for col in cluster_label_df.columns if col != 'DFCI_MRN']] += 1
    return cluster_label_df, inertias

In [161]:
# Load datasets
stage_df = pd.read_csv(os.path.join(FEATURE_PATH, 'cancer_stage_df.csv'))
type_df = pd.read_csv(os.path.join(FEATURE_PATH, 'cancer_type_df.csv'))

stage_cols = ["CANCER_STAGE_2.0", "CANCER_STAGE_3.0", "CANCER_STAGE_4.0"]
stage_df["STAGE"] = (stage_df[stage_cols]
                     .mul([2, 3, 4])
                     .sum(axis=1)
                     .replace(0, 1))
stage_df.drop(columns=stage_cols, inplace=True)

surv_traj = pd.read_csv(os.path.join(TRAJECTORY_PATH, f'survival_trajectories_w_decay_param_{decay_param}.csv'))

In [162]:
traj_long = (surv_traj.loc[~surv_traj['plus_0_months_data'].isna() & ~surv_traj['plus_3_months_data'].isna()]
             .melt(id_vars='DFCI_MRN',
                   value_vars = [c for c in surv_traj.columns if c.startswith('plus_')],
                   var_name='time',
                   value_name='risk_score')
             .dropna())
traj_long['time'] = traj_long['time'].apply(lambda x : int(x.split('_')[1]))
traj_long = traj_long.sort_values(by=['DFCI_MRN', 'time'])

features = (
    traj_long
    .groupby("DFCI_MRN")
    .apply(lambda x: pd.Series({
        "baseline": x.sort_values("time").iloc[0]["risk_score"],
        "slope": np.polyfit(x["time"], x["risk_score"], 1)[0],
        "auc": np.trapz(x["risk_score"], x["time"]),
        "max": x["risk_score"].max()}))
    .reset_index())

cluster_label_df, inertias = gen_clusters(features, [col for col in features.columns if col != 'DFCI_MRN'])

chosen_clust_num=4

  "auc": np.trapz(x["risk_score"], x["time"]),
  .apply(lambda x: pd.Series({
100%|██████████| 18/18 [00:03<00:00,  5.75it/s]


In [163]:
baseline_traj_label_df, _ = gen_clusters(features, ['baseline'])
full_traj_label_df, _ = gen_clusters(features, [col for col in features.columns if col != 'DFCI_MRN'])

cluster_col = f'label_w_{chosen_clust_num}_clusters'
baseline_traj_label_df = baseline_traj_label_df[['DFCI_MRN', cluster_col]].rename(columns={cluster_col : 'baseline_cluster_label'})
full_traj_label_df = full_traj_label_df[['DFCI_MRN', cluster_col]].rename(columns={cluster_col : 'full_trajectory_cluster_label'})

trajectory_cluster_df = baseline_traj_label_df.merge(full_traj_label_df, on='DFCI_MRN')

100%|██████████| 18/18 [00:02<00:00,  7.10it/s]
100%|██████████| 18/18 [00:02<00:00,  6.14it/s]


In [164]:
traj_plus_clust_label_df = surv_traj.merge(trajectory_cluster_df, on='DFCI_MRN')
data_cols = [col for col in traj_plus_clust_label_df.columns if 'plus_' in col]

mean_risk_by_full_traj = (traj_plus_clust_label_df[['full_trajectory_cluster_label'] + data_cols]
                          .groupby(by='full_trajectory_cluster_label')
                          .mean().mean(axis=1)
                          .sort_values(ascending=True)
                          .reset_index())

mean_risk_by_baseline = (traj_plus_clust_label_df[['baseline_cluster_label', 'plus_0_months_data']]
                         .groupby(by='baseline_cluster_label')
                         .mean()
                         .sort_values(by='plus_0_months_data', ascending=True)
                         .reset_index())

In [165]:
full_traj_update_cluster_dict = dict(zip(mean_risk_by_full_traj['full_trajectory_cluster_label'],
                                         mean_risk_by_full_traj.index.tolist()))
baseline_update_cluster_dict = dict(zip(mean_risk_by_baseline['baseline_cluster_label'],
                                        mean_risk_by_baseline.index.tolist()))

In [166]:
trajectory_cluster_df['baseline_cluster_label'] = trajectory_cluster_df['baseline_cluster_label'].map(baseline_update_cluster_dict) + 1
trajectory_cluster_df['full_trajectory_cluster_label'] = trajectory_cluster_df['full_trajectory_cluster_label'].map(full_traj_update_cluster_dict) + 1

In [167]:
cluster_stage_df = trajectory_cluster_df.merge(stage_df, on='DFCI_MRN')

In [168]:
cluster_stage_df

Unnamed: 0,DFCI_MRN,baseline_cluster_label,full_trajectory_cluster_label,STAGE
0,107014,3,3,3
1,124902,4,3,4
2,125735,1,2,3
3,132040,2,2,2
4,133178,1,2,3
...,...,...,...,...
6646,1178152,2,3,4
6647,1180002,4,3,4
6648,1180205,4,3,4
6649,1180349,2,3,1


## Stage vs. Cluster Distributions

In [169]:
baseline_rho, baseline_p = spearmanr(cluster_stage_df['STAGE'],
                                     cluster_stage_df['baseline_cluster_label'])

full_rho, full_p = spearmanr(cluster_stage_df['STAGE'], 
                             cluster_stage_df['full_trajectory_cluster_label'])

cross_rho, cross_p = spearmanr(cluster_stage_df['baseline_cluster_label'],
                               cluster_stage_df['full_trajectory_cluster_label'])

print(f'Spearman\'s rho between stage and baseline cluster = {baseline_rho : 0.2f} (p = {baseline_p : 0.3f})\n')
print(f'Spearman\'s rho between stage and full trajectory cluster = {full_rho : 0.2f} (p = {full_p : 0.3f})\n')
print(f'Spearman\'s rho between baseline and full trajectory cluster = {cross_rho : 0.2f} (p = {cross_rho : 0.3f})')

Spearman's rho between stage and baseline cluster =  0.22 (p =  0.000)

Spearman's rho between stage and full trajectory cluster =  0.15 (p =  0.000)

Spearman's rho between baseline and full trajectory cluster =  0.46 (p =  0.457)


In [170]:
stage_1_df = cluster_stage_df[cluster_stage_df['STAGE'] == 1]
stage_4_df = cluster_stage_df[cluster_stage_df['STAGE'] == 4]

baseline_1_df = cluster_stage_df[cluster_stage_df['baseline_cluster_label'] == 1]
baseline_4_df = cluster_stage_df[cluster_stage_df['baseline_cluster_label'] == 4]

full_1_df = cluster_stage_df[cluster_stage_df['full_trajectory_cluster_label'] == 1]
full_4_df = cluster_stage_df[cluster_stage_df['full_trajectory_cluster_label'] == 4]

In [171]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

def plot_proportion_bar(df, x_col, order, ax, title, xlabel, ylabel=None):
    s = pd.to_numeric(df[x_col], errors="coerce").dropna().astype(int)

    vc = s.value_counts(normalize=True).reindex(order, fill_value=0)

    prop_df = pd.DataFrame({
        x_col: pd.Categorical(vc.index, categories=order, ordered=True),
        "proportion": vc.values
    })

    sns.barplot(
        data=prop_df,
        x=x_col,
        y="proportion",
        ax=ax
    )

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Proportion" if ylabel is None else ylabel)
    ax.set_ylim(0, 1)

    # Force correct tick labels
    ax.set_xticks(range(len(order)))
    ax.set_xticklabels(order)

In [172]:
fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharey=False)

plot_proportion_bar(
    stage_1_df,
    x_col="baseline_cluster_label",
    order=[1, 2, 3, 4],
    ax=axes[0,0],
    title="Baseline Cluster Dist. in Stage 1",
    xlabel="Baseline Cluster",
    ylabel="Proportion of Stage 1")

plot_proportion_bar(
    stage_4_df,
    x_col="baseline_cluster_label",
    order=[1, 2, 3, 4],
    ax=axes[0,1],
    title="Baseline Cluster Dist. in Stage 4",
    xlabel="Baseline Cluster",
    ylabel='Proportion of Stage 4')

plot_proportion_bar(
    stage_1_df,
    x_col="full_trajectory_cluster_label",
    order=[1, 2, 3, 4],
    ax=axes[1,0],
    title="Full Trajectory Cluster Dist. in Stage 1",
    xlabel="Full Trajectory Cluster",
    ylabel="Proportion of Stage 1")

plot_proportion_bar(
    stage_4_df,
    x_col="full_trajectory_cluster_label",
    order=[1, 2, 3, 4],
    ax=axes[1,1],
    title="Full Trajectory Cluster Dist. in Stage 4",
    xlabel="Full Trajectory Cluster",
    ylabel='Proportion of Stage 4')

plt.tight_layout()
plt.savefig(os.path.join(STAGE_VS_TRAJECTORY_FIG_PATH, 'clusters_stratified_by_stage_barplot.png'))
plt.close()

In [173]:
fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharey=False)

plot_proportion_bar(
    baseline_1_df,
    x_col="STAGE",
    order=[1, 2, 3, 4],
    ax=axes[0,0],
    title="Stage Dist. in Baseline Cluster 1",
    xlabel="Stage",
    ylabel="Proportion of Baseline Cluster 1")

plot_proportion_bar(
    baseline_4_df,
    x_col="STAGE",
    order=[1, 2, 3, 4],
    ax=axes[0,1],
    title="Stage Dist. in Baseline Cluster 4",
    xlabel="Stage",
    ylabel='Proportion of Baseline Cluster 4')

plot_proportion_bar(
    full_1_df,
    x_col="STAGE",
    order=[1, 2, 3, 4],
    ax=axes[1,0],
    title="Stage Dist. in Full Trajectory Cluster 1",
    xlabel="Stage",
    ylabel="Proportion of Full Trajectory Cluster 1")

plot_proportion_bar(
    full_4_df,
    x_col="STAGE",
    order=[1, 2, 3, 4],
    ax=axes[1,1],
    title="Stage Dist. in Full Trajectory Cluster 4",
    xlabel="Stage",
    ylabel='Proportion of Full Trajectory Cluster 4')

plt.tight_layout()
plt.savefig(os.path.join(STAGE_VS_TRAJECTORY_FIG_PATH, 'stage_stratified_by_clusters_barplot.png'))
plt.close()

## Stage vs. Trajectory KM Curves

In [180]:
surv_df = pd.read_csv(os.path.join(SURV_PATH, 'level_3_ICD_surv_df.csv'), usecols=['DFCI_MRN', 'death', 'tt_death'])
tt_death_w_clusters_df = cluster_stage_df.merge(surv_df, on='DFCI_MRN')



In [207]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from lifelines import KaplanMeierFitter
from lifelines.statistics import multivariate_logrank_test
from lifelines.utils import concordance_index

LABEL_ORDER = [1, 2, 3, 4]
HORIZON_MONTHS = 60
HORIZON_DAYS = int(round(HORIZON_MONTHS * (365.25 / 12)))

# -----------------------
# Helpers
# -----------------------
def _prep_surv_df(df, label_col, time_col="tt_death", event_col="death"):
    d = df[[label_col, time_col, event_col]].copy()
    d[label_col] = pd.to_numeric(d[label_col], errors="coerce")
    d[time_col] = pd.to_numeric(d[time_col], errors="coerce")
    d[event_col] = pd.to_numeric(d[event_col], errors="coerce")
    d = d.dropna(subset=[label_col, time_col, event_col])
    d[label_col] = d[label_col].astype(int)
    d[event_col] = d[event_col].astype(int)
    return d

def logrank_pvalue(df, label_col, time_col="tt_death", event_col="death"):
    d = _prep_surv_df(df, label_col, time_col, event_col)
    # Need at least 2 groups with events/observations
    if d[label_col].nunique() < 2:
        return np.nan
    res = multivariate_logrank_test(
        event_durations=d[time_col],
        groups=d[label_col],
        event_observed=d[event_col]
    )
    return float(res.p_value)

def cindex_for_label(df, label_col, time_col="tt_death", event_col="death"):
    """
    Computes c-index using the label as an ordinal risk score.
    Since "higher label = higher risk" might not be guaranteed,
    we compute both directions and return the better one.

    Returns: (cindex, direction_str)
      direction_str is either "+label" or "-label" indicating which direction performed better.
    """
    d = _prep_surv_df(df, label_col, time_col, event_col)
    if len(d) == 0:
        return np.nan, "NA"
    # lifelines' concordance_index expects higher score -> higher risk (shorter survival)
    c_pos = concordance_index(d[time_col], d[label_col].astype(float), d[event_col])
    c_neg = concordance_index(d[time_col], -d[label_col].astype(float), d[event_col])

    if np.isnan(c_pos) and np.isnan(c_neg):
        return np.nan, "NA"

    if c_pos >= c_neg:
        return float(c_pos), "+label"
    else:
        return float(c_neg), "-label"

def km_plot_by_groups(df, label_col, ax, title, time_col="tt_death", event_col="death", order=LABEL_ORDER, max_time=None):
    d = _prep_surv_df(df, label_col, time_col, event_col)

    kmf = KaplanMeierFitter()
    for g in order:
        dd = d[d[label_col] == g]
        if len(dd) == 0:
            continue
        kmf.fit(dd[time_col], event_observed=dd[event_col], label=f"{g}")
        kmf.plot(ax=ax, ci_show=True)

    ax.set_title(title)
    ax.set_xlabel("Time (days)")
    ax.set_ylabel("Survival probability")
    if max_time is not None:
        ax.set_xlim(0, max_time)
    ax.legend(title=label_col, frameon=False)

def format_p(p):
    if p is None or np.isnan(p):
        return "p=NA"
    if p < 1e-4:
        return "p<1e-4"
    return f"p={p:.3g}"

# -----------------------
# Time-0 evaluation (full follow-up)
#   - Baseline cluster vs Stage
# -----------------------
p_base_stage1 = logrank_pvalue(tt_death_w_clusters_df, "baseline_cluster_label", "tt_death", "death")
c_base, dir_base = cindex_for_label(tt_death_w_clusters_df, "baseline_cluster_label", "tt_death", "death")

p_stage_full = logrank_pvalue(tt_death_w_clusters_df, "STAGE", "tt_death", "death")
c_stage0, dir_stage0 = cindex_for_label(tt_death_w_clusters_df, "STAGE", "tt_death", "death")

fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)

km_plot_by_groups(
    tt_death_w_clusters_df, "baseline_cluster_label", axes[0],
    title=f"Baseline clusters @ time 0 ({format_p(p_base_stage1)}, c={c_base:.3f})",
    time_col="tt_death", event_col="death", order=LABEL_ORDER
)

km_plot_by_groups(
    tt_death_w_clusters_df, "STAGE", axes[1],
    title=f"Stage @ time 0 ({format_p(p_stage_full)}, c={c_stage0:.3f})",
    time_col="tt_death", event_col="death", order=LABEL_ORDER
)

plt.tight_layout()
plt.savefig(os.path.join(STAGE_VS_TRAJECTORY_FIG_PATH, 'baseline_clusters_vs_stage_KM_curves.png'))
plt.close()

In [208]:
tt_full_traj_df = tt_death_w_clusters_df.copy()
tt_full_traj_df['tt_death_full_traj'] = tt_full_traj_df['tt_death'] - HORIZON_DAYS
tt_full_traj_df = tt_full_traj_df.loc[tt_full_traj_df['tt_death_full_traj'] > 0]

p_fulltraj = logrank_pvalue(tt_full_traj_df, "full_trajectory_cluster_label", "tt_death_full_traj", "death")
c_fulltraj, dir_fulltraj = cindex_for_label(tt_full_traj_df, "full_trajectory_cluster_label", "tt_death_full_traj", "death")

p_stage_60 = logrank_pvalue(tt_full_traj_df, "STAGE", "tt_death_full_traj", "death")
c_stage60, dir_stage60 = cindex_for_label(tt_full_traj_df, "STAGE", "tt_death_full_traj", "death")

fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)

km_plot_by_groups(
    tt_full_traj_df, "full_trajectory_cluster_label", axes[0],
    title=f"Full-trajectory clusters (≤{HORIZON_MONTHS}mo) ({format_p(p_fulltraj)}, c={c_fulltraj:.3f})",
    time_col="tt_death_full_traj", event_col="death", order=LABEL_ORDER, max_time=HORIZON_DAYS
)

km_plot_by_groups(
    tt_full_traj_df, "STAGE", axes[1],
    title=f"Stage (≤{HORIZON_MONTHS}mo) ({format_p(p_stage_60)}, c={c_stage60:.3f})",
    time_col="tt_death_full_traj", event_col="death", order=LABEL_ORDER, max_time=HORIZON_DAYS
)

plt.tight_layout()
plt.savefig(os.path.join(STAGE_VS_TRAJECTORY_FIG_PATH, 'full_trajectory_clusters_vs_stage_KM_curves.png'))
plt.close()

In [209]:
tt_stage_4_df = tt_death_w_clusters_df.copy()
tt_stage_4_df = tt_stage_4_df.loc[tt_stage_4_df['STAGE'] == 4]
tt_stage_4_df['tt_death_full_traj'] = tt_stage_4_df['tt_death'] - HORIZON_DAYS
tt_stage_4_full_traj_df = tt_stage_4_df.loc[tt_stage_4_df['tt_death_full_traj'] > 0]

p_baseline = logrank_pvalue(tt_stage_4_df, 'baseline_cluster_label', 'tt_death', 'death')
c_baseline, dir_baseline = cindex_for_label(tt_stage_4_df, 'baseline_cluster_label', 'tt_death', 'death')

p_fulltraj = logrank_pvalue(tt_stage_4_df, "full_trajectory_cluster_label", "tt_death_full_traj", "death")
c_fulltraj, dir_fulltraj = cindex_for_label(tt_stage_4_df, "full_trajectory_cluster_label", "tt_death_full_traj", "death")

fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)

km_plot_by_groups(
    tt_stage_4_df, "baseline_cluster_label", axes[0],
    title=f"Baseline clusters (Just Stage 4 Px's) ({format_p(p_baseline)}, c={c_baseline:.3f})",
    time_col="tt_death", event_col="death", order=LABEL_ORDER
)

km_plot_by_groups(
    tt_stage_4_full_traj_df, "full_trajectory_cluster_label", axes[1],
    title=f"Full trajectory clusters (Just Stage 4 Px's) ({format_p(p_fulltraj)}, c={c_fulltraj:.3f})",
    time_col="tt_death_full_traj", event_col="death", order=LABEL_ORDER, max_time=HORIZON_DAYS
)

plt.tight_layout()
plt.savefig(os.path.join(STAGE_VS_TRAJECTORY_FIG_PATH, 'baseline_vs_full_trajectory_stage_4_KM_curves.png'))
plt.close()

In [210]:
tt_stage_1_df = tt_death_w_clusters_df.copy()
tt_stage_1_df = tt_stage_1_df.loc[tt_stage_1_df['STAGE'] == 1]
tt_stage_1_df['tt_death_full_traj'] = tt_stage_1_df['tt_death'] - HORIZON_DAYS
tt_stage_1_full_traj_df = tt_stage_1_df.loc[tt_stage_1_df['tt_death_full_traj'] > 0]

p_baseline = logrank_pvalue(tt_stage_1_df, 'baseline_cluster_label', 'tt_death', 'death')
c_baseline, dir_baseline = cindex_for_label(tt_stage_1_df, 'baseline_cluster_label', 'tt_death', 'death')

p_fulltraj = logrank_pvalue(tt_stage_1_df, "full_trajectory_cluster_label", "tt_death_full_traj", "death")
c_fulltraj, dir_fulltraj = cindex_for_label(tt_stage_1_df, "full_trajectory_cluster_label", "tt_death_full_traj", "death")

fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)

km_plot_by_groups(
    tt_stage_1_df, "baseline_cluster_label", axes[0],
    title=f"Baseline clusters (Just Stage 1 Px's) ({format_p(p_baseline)}, c={c_baseline:.3f})",
    time_col="tt_death", event_col="death", order=LABEL_ORDER
)

km_plot_by_groups(
    tt_stage_1_full_traj_df, "full_trajectory_cluster_label", axes[1],
    title=f"Full trajectory clusters (Just Stage 1 Px's) ({format_p(p_fulltraj)}, c={c_fulltraj:.3f})",
    time_col="tt_death_full_traj", event_col="death", order=LABEL_ORDER, max_time=HORIZON_DAYS
)

plt.tight_layout()
plt.savefig(os.path.join(STAGE_VS_TRAJECTORY_FIG_PATH, 'baseline_vs_full_trajectory_stage_1_KM_curves.png'))
plt.close()