In [153]:
import os
import warnings
import pandas as pd
import numpy as np
from joblib import Parallel, delayed, parallel_backend
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored
import matplotlib.pyplot as plt
import seaborn as sns

# =====================
# Config / Options
# =====================
# NOTE(review): this silences SettingWithCopyWarning for the whole notebook, which can hide
# real chained-assignment bugs; prefer .loc / .assign and drop this once nothing relies on it.
pd.options.mode.chained_assignment = None  # suppress SettingWithCopy warnings

# Paths
PROJECT_PATH = '/data/gusev/USERS/jpconnor/clinical_text_project/'
DATA_PATH = os.path.join(PROJECT_PATH, 'data/')
SURV_PATH = os.path.join(DATA_PATH, 'survival_data/')
RESULTS_PATH = os.path.join(SURV_PATH, 'results/')
NOTES_PATH = os.path.join(DATA_PATH, 'batched_datasets/VTE_data/processed_datasets/')
STAGE_PATH = '/data/gusev/PROFILE/CLINICAL/OncDRS/DERIVED_FROM_CLINICAL_TEXTS_2024_03/derived_files/cancer_stage/'
OUTPUT_PATH = os.path.join(RESULTS_PATH, 'phecode_model_comps')
HELD_OUT_PRED_PATH = os.path.join(RESULTS_PATH, 'phecode_held_out_preds')
FIGURE_PATH = os.path.join(PROJECT_PATH, 'figures/')
RISK_FIGURE_PATH = os.path.join(FIGURE_PATH, 'risk_score_models')
os.makedirs(OUTPUT_PATH, exist_ok=True)
# The plotting cells save into figures/risk_score_models/ and savefig does not create directories.
os.makedirs(RISK_FIGURE_PATH, exist_ok=True)

# Held-out CV predictions are read as '<event>_cv_preds.csv'. Derive event codes from those
# files only (ignoring checkpoints / unrelated files), de-duplicated and sorted so the run
# order and the row order of the results are reproducible across machines.
CV_PREDS_SUFFIX = '_cv_preds.csv'
events = sorted({
    fname[:-len(CV_PREDS_SUFFIX)]
    for fname in os.listdir(HELD_OUT_PRED_PATH)
    if fname.endswith(CV_PREDS_SUFFIX)
})

# Risk-score columns combined by the per-event Cox model (default feature set).
RISK_SCORE_FEATURES = ("text_risk_score", "stage_type_risk_score", "ehr_risk_score", "genomics_risk_score")


def _harrell_cindex(y_struct, risk):
    """Return Harrell's C-index of `risk` against a structured ('event', 'duration') array."""
    return concordance_index_censored(y_struct["event"], y_struct["duration"], risk)[0]


def _mean_permutation_drops(model, X_scaled, y_struct, base_cindex, n_permutations, rng):
    """Mean C-index drop for each column of `X_scaled` when only that column is shuffled.

    Shuffling an already-standardized column gives exactly the same matrix as shuffling the
    raw column and re-scaling, because StandardScaler transforms each column independently.
    """
    n_cols = X_scaled.shape[1]
    mean_drops = np.empty(n_cols)
    for col in range(n_cols):
        drops = []
        for _ in range(n_permutations):
            X_perm = X_scaled.copy()
            X_perm[:, col] = rng.permutation(X_perm[:, col])
            drops.append(base_cindex - _harrell_cindex(y_struct, model.predict(X_perm)))
        mean_drops[col] = np.mean(drops)
    return mean_drops


def evaluate_risk_model_cv(event, data_path=HELD_OUT_PRED_PATH, n_splits=5, n_permutations=5,
                           random_state=42, features=RISK_SCORE_FEATURES):
    """Cross-validate a Cox PH model on the risk scores for one event.

    Reads ``<data_path>/<event>_cv_preds.csv``, drops rows missing the duration
    (``tt_<event>``), the event indicator (column ``<event>``) or any feature, then runs
    shuffled K-fold CV. In each fold the features are standardized on the training split,
    a ``CoxPHSurvivalAnalysis`` is fit, and the test-fold C-index plus per-feature
    permutation importance (mean C-index drop over ``n_permutations`` shuffles) are recorded.

    Parameters
    ----------
    event : str
        Event code; also the name of the event-indicator column.
    data_path : str
        Directory containing the ``*_cv_preds.csv`` files.
    n_splits : int
        Number of CV folds.
    n_permutations : int
        Shuffles per feature per fold for permutation importance.
    random_state : int
        Seeds both the fold split and the permutation shuffles, so results are reproducible
        (including when run in separate joblib worker processes).
    features : sequence of str
        Feature columns; result keys use these names with ``"_risk_score"`` removed.

    Returns
    -------
    dict
        ``event``, ``cindex`` (mean test C-index over folds), ``coef_<feat>`` (mean Cox
        coefficient on standardized features), ``perm_<feat>`` (mean permutation importance)
        and ``most_important_feature`` (feature with the largest permutation importance).

    Raises
    ------
    ValueError
        If fewer complete rows remain than ``n_splits``.
    """
    features = list(features)
    df = pd.read_csv(os.path.join(data_path, event + '_cv_preds.csv'))

    duration_col = f"tt_{event}"
    event_col = event

    # Drop rows missing the outcome or any feature
    subset = df.dropna(subset=[duration_col, event_col] + features)
    if len(subset) < n_splits:
        raise ValueError(
            f"{event}: only {len(subset)} complete rows, fewer than n_splits={n_splits}"
        )

    X = subset[features].to_numpy(dtype=float)
    # sksurv expects a structured array with a boolean event field and a float time field
    y = np.empty(len(subset), dtype=[('event', '?'), ('duration', 'f8')])
    y['event'] = subset[event_col].to_numpy().astype(bool)
    y['duration'] = subset[duration_col].to_numpy(dtype=float)

    # Local generator: the global np.random state is not seeded by random_state and is
    # independent per joblib worker, which made permutation importances non-reproducible.
    rng = np.random.default_rng(random_state)
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    cindices, coefs, fold_perm_drops = [], [], []
    for train_idx, test_idx in kf.split(X):
        # Fit the scaler on the training fold only to avoid leakage into the test fold
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X[train_idx])
        X_test = scaler.transform(X[test_idx])
        y_train, y_test = y[train_idx], y[test_idx]

        model = CoxPHSurvivalAnalysis()
        model.fit(X_train, y_train)

        base_cindex = _harrell_cindex(y_test, model.predict(X_test))
        cindices.append(base_cindex)
        coefs.append(model.coef_)
        fold_perm_drops.append(
            _mean_permutation_drops(model, X_test, y_test, base_cindex, n_permutations, rng)
        )

    # Aggregate across folds
    mean_coefs = np.mean(coefs, axis=0)
    mean_perm = np.mean(fold_perm_drops, axis=0)

    # Report features without the "_risk_score" suffix
    clean_features = [feat.replace("_risk_score", "") for feat in features]
    perm_by_feat = dict(zip(clean_features, mean_perm))
    most_important_feat = max(perm_by_feat, key=perm_by_feat.get)

    return {
        "event": event,
        "cindex": np.mean(cindices),
        **{f"coef_{name}": mean_coefs[i] for i, name in enumerate(clean_features)},
        **{f"perm_{name}": perm_by_feat[name] for name in clean_features},
        "most_important_feature": most_important_feat,
    }

In [154]:
# Events deliberately left out of the comparison.
# TODO(review): document why phecode 8.0 is excluded.
EXCLUDED_EVENTS = {'8.0'}

# One cross-validated Cox fit per event across all CPUs (n_jobs=-1). Excluded events are
# dropped before dispatch so no work is spent on them and they cannot abort the run.
events_to_fit = [ev for ev in events if ev not in EXCLUDED_EVENTS]
results = Parallel(n_jobs=-1, backend="loky")(
    delayed(evaluate_risk_model_cv)(ev) for ev in events_to_fit
)
final_results_df = pd.DataFrame(results)

  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(*arrays, *other_args, **kwargs)
  return f(

In [155]:
# How often each risk-score source wins on permutation importance, across all events
final_results_df.most_important_feature.value_counts()

most_important_feature
text          100
stage_type      6
ehr             6
genomics        2
Name: count, dtype: int64

In [156]:
# Permutation-importance clustermaps: first for every event, then for the 40 events with
# the highest cross-validated C-index (rows labelled with phecode descriptions).
importance_cols = ['perm_text', 'perm_ehr', 'perm_genomics', 'perm_stage_type']
heatmap_data = final_results_df.set_index('event')[importance_cols]
fig_dir = os.path.join(FIGURE_PATH, 'risk_score_models')

# Optional: Normalize each row (event) to better visualize relative importance
# heatmap_data = heatmap_data.div(heatmap_data.sum(axis=1), axis=0)

# Clustered heatmap over all events; row labels are hidden because there are too many to read
g = sns.clustermap(
    heatmap_data,
    method='average',
    metric='euclidean',
    cmap='viridis',
    annot=False,
    figsize=(12, max(8, len(heatmap_data) * 0.1)),  # ~0.1 in of height per event, at least 8 in
    cbar_kws={"label": "Feature Importance"},
    linewidths=0,
)
g.ax_heatmap.set_yticks([])
g.ax_heatmap.set_ylabel('')
g.ax_heatmap.tick_params(axis='x', labelsize=16)
# Title the clustermap's own figure rather than whatever pyplot currently considers "current"
g.fig.suptitle("Clustered Heatmap of Risk Score Importances Across All Events", fontsize=20, y=1.02)

g.fig.savefig(os.path.join(fig_dir, 'all_risk_score_feature_importances_clustermap.png'), dpi=300, bbox_inches='tight')
plt.close(g.fig)

top40_df = final_results_df.sort_values(by='cindex', ascending=False).head(40)

icd_to_phecode_map = pd.read_csv(os.path.join(DATA_PATH, 'code_data/icd_to_phecode_map.csv'))
phecode_descr_dict = dict(zip(icd_to_phecode_map['PHECODE'].astype(str), icd_to_phecode_map['PHECODE_DESCR'].astype(str)))

# .assign builds a new frame (no write into the head() slice); unmapped codes keep the raw code
top40_df = top40_df.assign(
    event_descr=top40_df['event'].map(lambda code: phecode_descr_dict.get(code, code))
)

heatmap_data = top40_df.set_index('event_descr')[importance_cols]

# Optional: Normalize each row (event) to better visualize relative importance
# heatmap_data = heatmap_data.div(heatmap_data.sum(axis=1), axis=0)

# Clustered, annotated heatmap of the top-40 events
g = sns.clustermap(
    heatmap_data,
    method='average',
    metric='euclidean',
    cmap='viridis',
    annot=True,
    fmt='.2f',
    figsize=(18, max(8, len(heatmap_data) * 0.4)),  # ~0.4 in of height per event, at least 8 in
    cbar_kws={"label": "Feature Importance"},
    linewidths=0,
)
g.fig.suptitle("Clustered Heatmap of Risk Score Importances Across Top 40 Events by C-Index", fontsize=20, y=1.02)
g.ax_heatmap.tick_params(axis='both', labelsize=16)

g.fig.savefig(os.path.join(fig_dir, 'top_40_risk_score_feature_importances_clustermap.png'), dpi=300, bbox_inches='tight')
plt.close(g.fig)

In [157]:
# Set 'event' as index and keep only importance columns
importance_cols = ['coef_text', 'coef_ehr', 'coef_genomics', 'coef_stage_type']
heatmap_data = final_results_df.set_index('event')[importance_cols]

# Optional: Normalize each row (event) to better visualize relative importance
# heatmap_data = heatmap_data.div(heatmap_data.sum(axis=1), axis=0)

# Clustered heatmap
g = sns.clustermap(
    heatmap_data,
    method='average',
    metric='euclidean',
    cmap='viridis',
    annot=False,
    figsize=(12, max(8, len(heatmap_data) *0.1)),  # Dynamically size based on number of events
    cbar_kws={"label": "Coefficient"},
    linewidths=0
)
g.ax_heatmap.set_yticklabels([])
g.ax_heatmap.yaxis.set_ticks([])
g.ax_heatmap.set_ylabel('')
g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), fontsize=16)
plt.suptitle("Clustered Heatmap of Risk Score Coefficients Across All Events", fontsize=20, y=1.02)

g.fig.savefig(os.path.join(FIGURE_PATH, 'risk_score_models/all_risk_score_coefficients_clustermap.png'), dpi=300, bbox_inches='tight')
plt.close()

In [158]:
# Clustermap of mean Cox coefficients for the 40 events with the highest CV C-index.
# The coefficient columns are named here instead of reusing another cell's `importance_cols`,
# so re-running cells out of order cannot silently plot importances under a "Coefficient" label.
coef_cols = ['coef_text', 'coef_ehr', 'coef_genomics', 'coef_stage_type']
fig_dir = os.path.join(FIGURE_PATH, 'risk_score_models')

top40_df = final_results_df.sort_values(by='cindex', ascending=False).head(40)
# Uses `phecode_descr_dict` built in the feature-importance plotting cell; unmapped codes keep the raw code.
top40_df = top40_df.assign(
    event_descr=top40_df['event'].map(lambda code: phecode_descr_dict.get(code, code))
)

heatmap_data = top40_df.set_index('event_descr')[coef_cols]

# Optional: Normalize each row (event) to better visualize relative importance
# heatmap_data = heatmap_data.div(heatmap_data.sum(axis=1), axis=0)

# Clustered, annotated heatmap.
# NOTE(review): coefficients are signed; a diverging cmap centered at 0 would show direction.
g = sns.clustermap(
    heatmap_data,
    method='average',
    metric='euclidean',
    cmap='viridis',
    annot=True,
    fmt='.2f',
    figsize=(18, max(8, len(heatmap_data) * 0.4)),  # ~0.4 in of height per event, at least 8 in
    cbar_kws={"label": "Coefficient"},
    linewidths=0,
)
g.fig.suptitle("Clustered Heatmap of Risk Score Coefficients Across Top 40 Events by C-Index", fontsize=20, y=1.02)
g.ax_heatmap.tick_params(axis='both', labelsize=16)

g.fig.savefig(os.path.join(fig_dir, 'top_40_risk_score_coefficients_clustermap.png'), dpi=300, bbox_inches='tight')
plt.close(g.fig)

In [174]:
# Spot-check one phecode description; direct indexing raises KeyError if '783.0' is unmapped.
# NOTE(review): looks like a leftover interactive check — consider removing from the final notebook.
phecode_descr_dict['783.0']

'Fever of unknown origin'