In [1]:
import os
import icd10
from tqdm import tqdm
from functools import reduce
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr  
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# Paths
# Root of the project; every other location below is derived from it.
PROJ_PATH = '/data/gusev/USERS/jpconnor/clinical_text_project/'

# Input data and model-result directories.
DATA_PATH = os.path.join(PROJ_PATH, 'data/')
FEATURE_PATH = os.path.join(DATA_PATH, 'clinical_and_genomic_features/')
SURV_PATH = os.path.join(DATA_PATH, 'survival_data/')
RESULTS_PATH = os.path.join(SURV_PATH, 'results/icd_results/')
LANDMARK_PATH = os.path.join(RESULTS_PATH, 'landmark_results/')

# Figure output directories.
# FIGURE_PATH was previously assigned twice in this cell (first to
# 'figures/model_metrics/'); that first value was overwritten before any use,
# so only the effective assignment is kept.
FIGURE_PATH = os.path.join(PROJ_PATH, 'figures/landmark_analysis/')
FEATURE_FIG_PATH = os.path.join(FIGURE_PATH, 'feature_based_clusters/')
UNSCALED_FEATURE_FIG_PATH = os.path.join(FEATURE_FIG_PATH, 'unscaled/')
ADJUSTED_FEATURE_FIG_PATH = os.path.join(FEATURE_FIG_PATH, 'adjusted/')
SCALED_FEATURE_FIG_PATH = os.path.join(FEATURE_FIG_PATH, 'scaled/')

def zscore(ts):
    """Standardize ``ts`` by subtracting its mean and dividing by its std.

    A small constant (1e-8) is added to the standard deviation so that a
    constant input does not trigger a division by zero. ``ts`` only needs to
    provide ``.mean()`` and ``.std()`` and support arithmetic with the results.
    """
    centered = ts - ts.mean()
    spread = ts.std() + 1e-8
    return centered / spread

def gen_clusters(feature_df, feature_cols, clusters_to_test=tuple(range(2, 20))):
    """Fit one KMeans model per candidate cluster count and collect the labels.

    Parameters
    ----------
    feature_df : pandas.DataFrame
        Must contain an ``'event'`` column plus every column in ``feature_cols``.
    feature_cols : list of str
        Columns of ``feature_df`` passed to KMeans.
    clusters_to_test : iterable of int, optional
        Cluster counts to fit; defaults to 2 through 19. An immutable default
        is used so the default cannot be mutated between calls.

    Returns
    -------
    cluster_label_df : pandas.DataFrame
        The ``'event'`` column followed by one ``label_w_{n}_clusters`` column
        per tested count, holding 1-based cluster labels.
    inertias : list of float
        KMeans inertia for each tested count, in the order tested.
    """
    # The feature matrix does not change between fits, so slice it once.
    feature_matrix = feature_df[feature_cols]

    inertias = []
    cluster_labels = {'event' : feature_df['event']}
    for n_clust in tqdm(clusters_to_test):
        km = KMeans(n_clusters=n_clust, random_state=0).fit(feature_matrix)
        inertias.append(km.inertia_)
        # KMeans labels start at 0; shift them so clusters are numbered from 1.
        cluster_labels[f'label_w_{n_clust}_clusters'] = km.predict(feature_matrix) + 1

    return pd.DataFrame(cluster_labels), inertias

def gen_mean_risk_traj_plot(long_traj_df, output_path, x='time', y='c_index', title = 'Mean C-Index Trajectories by Cluster', figsize=(10,6)):
    """Save a line plot of the per-cluster mean of ``y`` over ``x``.

    ``long_traj_df`` is a long-format frame with columns ``x``, ``y`` and
    ``'cluster_label'``. One line is drawn per cluster (seaborn ``estimator='mean'``
    with ``errorbar='sd'``). The figure is written to ``output_path`` and then
    closed; nothing is returned.
    """
    fig, ax = plt.subplots(figsize=figsize)

    lineplot_kwargs = dict(
        x=x,
        y=y,
        hue='cluster_label',
        estimator='mean',
        errorbar='sd',
        linewidth=3)
    sns.lineplot(data=long_traj_df, ax=ax, **lineplot_kwargs)

    ax.set_title(title)
    ax.set_ylabel(y)
    ax.set_xlabel(x)
    ax.legend(title='Cluster')

    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)
    
def gen_spaghetti_plot(long_traj_df, output_path, x='time', y='c_index', title='C-Index Trajectories by Cluster'):
    """Save a faceted "spaghetti" plot with one panel per cluster.

    ``long_traj_df`` is a long-format frame with columns ``x``, ``y``,
    ``'event'`` and ``'cluster_label'``. Each event is drawn as its own
    semi-transparent line (``units='event'``, no aggregation) inside the panel
    of its cluster. The figure is written to ``output_path`` and then closed;
    nothing is returned.
    """
    g = sns.relplot(
        data=long_traj_df,
        x=x,
        y=y,
        col='cluster_label',
        kind='line',
        units='event',
        estimator=None,
        alpha=0.2,
        linewidth=1,
        col_wrap=2,
        height=3,
        aspect=1.2)

    g.set_axis_labels(x, y)
    # `FacetGrid.figure` replaces the deprecated `FacetGrid.fig` accessor.
    g.figure.suptitle(title, y=1.05)
    # The suptitle sits above the figure canvas (y=1.05), so a plain
    # plt.savefig() crops it out of the file. FacetGrid.savefig() saves with a
    # tight bounding box, which keeps the title in the saved image.
    g.savefig(output_path)
    plt.close(g.figure)
    
def gen_trajectory_heatmap(long_traj_df, output_path, x='time', y='c_index', title='C-Index Heatmap Ordered by Cluster', figsize=(10,8)):
    """Save an event-by-``x`` heatmap of ``y`` with rows grouped by cluster.

    ``long_traj_df`` is a long-format frame with columns ``x``, ``y``,
    ``'event'`` and ``'cluster_label'``. Rows (events) are sorted by cluster
    label and, within a cluster, by the event's mean ``y``. Y-axis ticks name
    each cluster at the middle of its row block and thin black lines separate
    consecutive blocks. The figure is written to ``output_path`` and then
    closed; nothing is returned.

    The separator loop previously reused the name ``y`` for its loop variable,
    which overwrote the ``y`` argument and made the y-axis label a row offset
    instead of the column name; the loop variable now has its own name.
    """
    # Wide matrix: one row per event, one column per distinct value of `x`.
    heatmap_df = long_traj_df.pivot_table(index='event', columns=x, values=y)

    # One row per event carrying its cluster and mean `y`, sorted so that
    # events of the same cluster are contiguous.
    event_summary = (long_traj_df
                     .groupby('event')
                     .agg(cluster=('cluster_label', 'first'),
                          mean_risk=(y, 'mean'))
                     .sort_values(['cluster', 'mean_risk']))
    order = event_summary.index
    ordered_clusters = event_summary['cluster']

    # sort=False keeps the clusters in the order they appear in the sorted
    # rows, so sizes line up with the row blocks of the heatmap.
    cluster_sizes = ordered_clusters.value_counts(sort=False)
    cluster_starts = np.cumsum([0] + cluster_sizes.tolist()[:-1])
    cluster_mids = cluster_starts + cluster_sizes.values / 2

    plt.figure(figsize=figsize)
    ax = sns.heatmap(
        heatmap_df.loc[order],
        cmap='viridis',
        center=0,
        cbar_kws={'label' : y},
        yticklabels=False)

    ax.set_yticks(cluster_mids)
    ax.set_yticklabels([f'Cluster {c}' for c in cluster_sizes.index], rotation=0)

    # Draw a separator at the first row of every cluster except the first.
    for boundary in cluster_starts[1:]:
        ax.hlines(boundary, *ax.get_xlim(), colors='black', linewidth=0.5)

    plt.title(title)
    plt.xlabel(x)
    plt.ylabel(y)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

In [2]:
# Landmark offsets (in months) to evaluate.
months_to_eval = [0, 3, 6, 9, 12]

# Keep only the entries that are present in the results directory of every
# landmark month.
events = set.intersection(*[set(os.listdir(os.path.join(LANDMARK_PATH, f'plus_{month}_months/')))
                            for month in months_to_eval])

# Look up a description, chapter and block description for each event code.
# Iterating in sorted order makes the row order of `event_descr_df` identical
# from run to run (plain set iteration order is not).
event_descr_data = []
for event in sorted(events):
    event_descr = None
    event_chapter = None
    event_block_descr = None
    if icd10.exists(event):
        event_code = icd10.find(event)
        event_descr = event_code.description
        try:
            event_chapter = event_code.chapter
            event_block_descr = event_code.block_description
        except Exception:
            # Best effort: if either lookup fails, record neither. `Exception`
            # (rather than a bare `except`) lets KeyboardInterrupt/SystemExit
            # propagate.
            # NOTE(review): narrow this to the specific error icd10 raises once confirmed.
            event_chapter = None
            event_block_descr = None
    event_descr_data.append([event, event_descr, event_chapter, event_block_descr])

event_descr_df = pd.DataFrame(event_descr_data, columns=['event', 'event_descr', 'event_chapter', 'event_block_descr'])

In [3]:
# For every landmark month, read each event's metrics for the base model
# (rows of type_model_metrics.csv with eval_data == 'test_data') and the text
# model (text_test.csv), keeping the first row's mean C-index and mean AUC(t).
# The per-month frames are then inner-joined on 'event' into one wide frame
# per metric.
metric_cols = ['mean_c_index', 'mean_auc(t)']

def _merge_on_event(left, right):
    """Inner-join two per-month metric frames on the 'event' column."""
    return pd.merge(left, right, on='event', how='inner')

c_index_metric_dfs = []
auc_metric_dfs = []
for month in months_to_eval:
    month_path = os.path.join(LANDMARK_PATH, f'plus_{month}_months/')

    c_index_data = []
    mean_auc_data = []
    for event in events:
        event_path = os.path.join(month_path, event)

        text_metrics = pd.read_csv(os.path.join(event_path, 'text_test.csv'))
        base_metrics = pd.read_csv(os.path.join(event_path, 'type_model_metrics.csv'))
        base_test_metrics = base_metrics[base_metrics['eval_data'] == 'test_data']

        base_c_index, base_mean_auc = base_test_metrics[metric_cols].to_numpy()[0]
        text_c_index, text_mean_auc = text_metrics[metric_cols].to_numpy()[0]

        c_index_data.append([event, base_c_index, text_c_index])
        mean_auc_data.append([event, base_mean_auc, text_mean_auc])

    c_index_columns = ['event', f'{month}_months_base_c_index', f'{month}_months_text_c_index']
    auc_columns = ['event', f'{month}_months_base_mean_auc(t)', f'{month}_months_text_mean_auc(t)']
    c_index_metric_dfs.append(pd.DataFrame(c_index_data, columns=c_index_columns))
    auc_metric_dfs.append(pd.DataFrame(mean_auc_data, columns=auc_columns))

complete_c_index_metric_df = reduce(_merge_on_event, c_index_metric_dfs)
complete_auc_metric_df = reduce(_merge_on_event, auc_metric_dfs)

# Optional export of the merged tables (left disabled):
# complete_c_index_metric_df.to_csv(os.path.join(RESULTS_PATH, 'landmark_analysis_c_index_df.csv'), index=False)
# complete_auc_metric_df.to_csv(os.path.join(RESULTS_PATH, 'landmark_analysis_mean_auc_df.csv'), index=False)

# Feature-Based Clustering

## Unscaled

In [4]:
# Columns holding the text model's C-index, one per landmark month.
text_metric_cols = [col for col in complete_c_index_metric_df.columns
                    if col.endswith("_months_text_c_index")]

# Long format: one row per (event, landmark month). The month is parsed from
# the leading integer of the original column name, e.g. '3_months_...' -> 3.
landmark_long_df = (
    complete_c_index_metric_df
    .melt(id_vars='event',
          value_vars=text_metric_cols,
          var_name='time',
          value_name='c_index')
    .dropna()
    .assign(time=lambda df: df['time'].map(lambda col_name: int(col_name.split('_')[0])))
    .sort_values(by=['event', 'time']))

In [5]:
# Summarise each event's C-index trajectory with four features:
#   baseline - C-index at the earliest landmark time
#   slope    - slope of a degree-1 least-squares fit of C-index on time
#   auc      - trapezoidal area under the C-index curve over time
#   max      - largest C-index across landmark times
# The value columns are selected explicitly after the groupby so that `apply`
# does not operate on the grouping column 'event' (pandas deprecates that
# behaviour and warns about it).
# NOTE(review): the recorded output also shows a warning raised from the
# polyfit line; presumably a poorly conditioned fit for events with too few
# landmark points - confirm.
features = (landmark_long_df
            .groupby('event')[['time', 'c_index']]
            .apply(lambda x : pd.Series({
                'baseline' : x.sort_values('time').iloc[0]['c_index'],
                'slope' : np.polyfit(x['time'], x['c_index'], 1)[0],
                'auc' : np.trapezoid(x['c_index'], x['time']),
                'max' : x['c_index'].max()}))
            .reset_index())

# Cluster on every feature column (everything except the event identifier).
cluster_label_df, inertias = gen_clusters(features, [col for col in features.columns if col != 'event'])

  'slope' : np.polyfit(x['time'], x['c_index'], 1)[0],
  .apply(lambda x : pd.Series({
100%|██████████| 18/18 [00:00<00:00, 64.92it/s]


In [6]:
chosen_clust_num = 4

# Plot inertia against the actual number of clusters. The counts are recovered
# from the 'label_w_{n}_clusters' column names, which are in the same order as
# `inertias`. Plotting against range(len(inertias)) would start the axis at 0
# while the tested counts start higher, so the marker for `chosen_clust_num`
# would land on the wrong point of the curve.
cluster_counts = [int(col.split('_')[2]) for col in cluster_label_df.columns if col != 'event']

plt.plot(cluster_counts, inertias)
plt.axvline(x=chosen_clust_num, color='red', label=f'{chosen_clust_num} clusters')
plt.xlabel('Number of clusters')
plt.ylabel('Inertia')
plt.legend()
plt.title('Raw Landmark Results Elbow Plot')
plt.savefig(os.path.join(UNSCALED_FEATURE_FIG_PATH, 'elbow_plot.png'))
plt.close()

In [7]:
# Attach the labels of the chosen clustering to every trajectory row, then
# render the three per-cluster figures into the unscaled figure directory.
label_col = f'label_w_{chosen_clust_num}_clusters'
trajs_to_plot = (
    landmark_long_df
    .merge(cluster_label_df[['event', label_col]], on='event')
    .rename(columns={label_col : 'cluster_label'})
    .sort_values(by=['cluster_label', 'event', 'time']))

for plot_fn, file_name in (
        (gen_mean_risk_traj_plot, 'mean_c_index_trajectories_by_cluster.png'),
        (gen_spaghetti_plot, 'spaghetti_plot_by_cluster.png'),
        (gen_trajectory_heatmap, 'heatmap_by_cluster.png')):
    plot_fn(trajs_to_plot, output_path=os.path.join(UNSCALED_FEATURE_FIG_PATH, file_name))

## Adjusted

In [8]:
# Express each event's text-model C-index as the change from its own 0-month
# value: the 0-month column is subtracted row-wise from every landmark column
# (so the 0-month column itself becomes 0).
adjusted_c_index_metric_df = complete_c_index_metric_df.copy()
baseline_c_index = adjusted_c_index_metric_df['0_months_text_c_index'].to_numpy()[:, np.newaxis]
adjusted_c_index_metric_df[text_metric_cols] = adjusted_c_index_metric_df[text_metric_cols] - baseline_c_index

# Long format: one row per (event, landmark month), month parsed from the
# leading integer of the original column name.
adjusted_landmark_long_df = (
    adjusted_c_index_metric_df
    .melt(id_vars='event',
          value_vars=text_metric_cols,
          var_name='time',
          value_name='c_index')
    .dropna()
    .assign(time=lambda df: df['time'].map(lambda col_name: int(col_name.split('_')[0])))
    .sort_values(by=['event', 'time']))

In [9]:
# Same four trajectory features as the unscaled analysis (baseline, slope,
# auc, max), computed on the baseline-adjusted trajectories. The value columns
# are selected explicitly after the groupby so that `apply` does not operate on
# the grouping column 'event' (pandas deprecates that behaviour and warns).
adjusted_features = (adjusted_landmark_long_df
                     .groupby('event')[['time', 'c_index']]
                     .apply(lambda x : pd.Series({'baseline' : x.sort_values('time').iloc[0]['c_index'],
                                                  'slope' : np.polyfit(x['time'], x['c_index'], 1)[0],
                                                  'auc' : np.trapezoid(x['c_index'], x['time']),
                                                  'max' : x['c_index'].max()}))
                     .reset_index())

# Take the feature columns from `adjusted_features` itself rather than from the
# `features` frame of the unscaled section, so this cell does not depend on
# that earlier cell having been run.
adjusted_cluster_label_df, inertias = gen_clusters(adjusted_features, [col for col in adjusted_features.columns if col != 'event'])

  .apply(lambda x : pd.Series({'baseline' : x.sort_values('time').iloc[0]['c_index'],
100%|██████████| 18/18 [00:00<00:00, 122.06it/s]


In [10]:
chosen_clust_num = 4

# Plot inertia against the actual number of clusters. The counts are recovered
# from the 'label_w_{n}_clusters' column names, which are in the same order as
# `inertias`. Plotting against range(len(inertias)) would start the axis at 0
# while the tested counts start higher, so the marker for `chosen_clust_num`
# would land on the wrong point of the curve.
cluster_counts = [int(col.split('_')[2]) for col in adjusted_cluster_label_df.columns if col != 'event']

plt.plot(cluster_counts, inertias)
plt.axvline(x=chosen_clust_num, color='red', label=f'{chosen_clust_num} clusters')
plt.xlabel('Number of clusters')
plt.ylabel('Inertia')
plt.legend()
plt.title('Adjusted Landmark Results Elbow Plot')
plt.savefig(os.path.join(ADJUSTED_FEATURE_FIG_PATH, 'elbow_plot.png'))
plt.close()

In [11]:
# Attach the labels of the chosen clustering to every adjusted trajectory row,
# then render the three per-cluster figures into the adjusted figure directory.
label_col = f'label_w_{chosen_clust_num}_clusters'
trajs_to_plot = (
    adjusted_landmark_long_df
    .merge(adjusted_cluster_label_df[['event', label_col]], on='event')
    .rename(columns={label_col : 'cluster_label'})
    .sort_values(by=['cluster_label', 'event', 'time']))

for plot_fn, file_name in (
        (gen_mean_risk_traj_plot, 'mean_c_index_trajectories_by_cluster.png'),
        (gen_spaghetti_plot, 'spaghetti_plot_by_cluster.png'),
        (gen_trajectory_heatmap, 'heatmap_by_cluster.png')):
    plot_fn(trajs_to_plot, output_path=os.path.join(ADJUSTED_FEATURE_FIG_PATH, file_name))