## Anchor Direction Visualization Framework: DeepSurv trained on the SUPPORT dataset (where the encoder has a Euclidean norm 1 constraint) -- using different numbers of clusters still with a mixture of von Mises-Fisher distributions

This notebook is a modified and slightly shortened version of the main SUPPORT DeepSurv (norm 1 constraint) notebook, where the main difference is just that we try multiple numbers of clusters. We still cluster using a mixture of von Mises-Fisher distributions.

*To save space, due to the supplemental file size constraint, this notebook needs to be re-run to see the code output.*

### Loading in data (including some outputs from the already trained neural survival analysis model)

In [None]:
# Notebook display settings (IPython magics) plus matplotlib/seaborn styling and shared imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib
# fonttype 42 embeds TrueType fonts (rather than Type 3) in figures saved as PDF/PS
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
matplotlib.rcParams['font.family'] = ['sans-serif', 'Arial']
# use matplotlib's built-in mathtext (STIX sans font) instead of an external LaTeX installation
matplotlib.rcParams['text.usetex'] = False
matplotlib.rcParams['mathtext.fontset'] = 'stixsans'
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

import numpy as np
import pandas as pd
import scipy.stats

# project-local helpers (visualization_utils module)
from visualization_utils import get_experiment_data, l2_normalize_rows, longest_common_prefix, \
    compute_median_survival_times

Some explanation of the variables below:

- `emb_direction`: these are embedding vectors for anchor direction estimation data
- `emb_vis`: these are embedding vectors for the visualization raw inputs
- `raw_direction`: raw inputs of the anchor direction estimation data
- `raw_vis`: visualization raw inputs
- `label_direction`: survival labels (2 columns: column 0 stores observed times and column 1 stores event indicators) of the anchor direction estimation data
- `label_vis`: survival labels of the visualization raw inputs
- `unique_train_times`: discretized time grid used for the predicted survival curves
- `predicted_surv_vis`: predicted survival curves of the visualization raw inputs (each visualization raw input has a predicted survival curve evaluated on the time grid given by `unique_train_times`)

In [None]:
# Load the embeddings, raw inputs, survival labels and predicted survival curves of the trained
# hypersphere-constrained DeepSurv model on SUPPORT (the 7th returned item is not needed here).
(emb_direction, emb_vis, raw_direction, raw_vis, label_direction, label_vis,
 _, unique_train_times, predicted_surv_vis) = get_experiment_data(
    'support', '../train_models/output_tabular_hypersphere')

For example, in this case, there are 665 anchor direction estimation data points and the embedding dimension is 10:

In [None]:
# (number of anchor direction estimation points, embedding dimension)
emb_direction.shape

### 2D PCA plot of the visualization data

In [None]:
# One median survival time estimate per visualization point, derived from its predicted survival curve on
# the `unique_train_times` grid; used below to color the 2D PCA and t-SNE scatter plots.
median_surv_time_estimates = compute_median_survival_times(predicted_surv_vis, unique_train_times)

In [None]:
from sklearn.decomposition import PCA
# Project the visualization embeddings onto their first two principal components and color each point
# by its estimated median survival time.
pca = PCA(n_components=2)
emb_vis_pca_2d = pca.fit_transform(emb_vis)
pca_x, pca_y = emb_vis_pca_2d[:, 0], emb_vis_pca_2d[:, 1]
plt.axis('equal')
plt.scatter(pca_x, pca_y, c=median_surv_time_estimates, cmap='flare', alpha=.6)
plt.colorbar()
plt.title('2D PCA plot of a DeepSurv embedding space (SUPPORT dataset)')
# plt.savefig('support-embedding-space-pca-hypersphere.pdf', bbox_inches='tight')

### 2D t-SNE plot of the visualization data

In [None]:
from sklearn.manifold import TSNE
# 2D t-SNE embedding of the visualization points (fixed random_state so the layout is reproducible),
# colored by estimated median survival time.
tsne = TSNE(n_components=2, perplexity=50, random_state=3676767249)
emb_vis_tsne_2d = tsne.fit_transform(emb_vis)
tsne_x, tsne_y = emb_vis_tsne_2d[:, 0], emb_vis_tsne_2d[:, 1]
plt.axis('equal')
plt.scatter(tsne_x, tsne_y, c=median_surv_time_estimates, cmap='flare', alpha=.6)
plt.colorbar()
plt.title('2D t-SNE plot of a DeepSurv embedding space (SUPPORT dataset)')
# plt.savefig('support-embedding-space-tsne-hypersphere.pdf', bbox_inches='tight')

### For different choices for the number of components, fit a mixture of von Mises-Fisher distributions

We use the Expectation-Maximization algorithm implementation by Minyoung Kim: https://github.com/minyoungkim21/vmf-lib

In [None]:
from lifelines.statistics import pairwise_logrank_test
import random
import torch

import models
import utils

seed = 1861600023

n_init = 100  # random EM restarts per number of clusters; the fit with the highest log-likelihood is kept
n_clusters_to_try = list(range(2, 11))
logrank_p_values = []  # per k: p-values of the pairwise log-rank tests between the k clusters
cluster_assignments = []  # per k: hard cluster label of each anchor direction estimation point
cluster_models = []  # per k: the best-scoring fitted models.MixvMF
samples = torch.tensor(emb_direction, dtype=torch.float32)
opts = {}
opts['max_iters'] = 100  # maximum number of EM iterations
opts['rll_tol'] = 1e-5  # tolerance of relative loglik improvement


def seed_everything(random_seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and every CUDA device) RNGs with `random_seed`."""
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)


def fit_mixvmf_em(x, n_clusters, max_iters, rll_tol):
    """Fit a mixture of von Mises-Fisher distributions to the rows of `x` with EM.

    Follows the EM recipe of vmf-lib (https://github.com/minyoungkim21/vmf-lib).

    Parameters
    ----------
    x : torch.Tensor of shape (n_samples, dim)
        Points to cluster (here: the norm-1 anchor direction estimation embeddings).
    n_clusters : int
        Number of mixture components.
    max_iters : int
        Maximum number of EM iterations.
    rll_tol : float
        Stop once the relative log-likelihood change between iterations drops below this.

    Returns
    -------
    mix : models.MixvMF
        The fitted model.
    ll : float
        Total log-likelihood of `x` under the final parameters.
    assignment : np.ndarray of shape (n_samples,)
        MAP cluster of each point, i.e. argmax_k [log alpha_k + log p(x | k)]. The mixing
        weights are included so the hard labels agree with the E-step responsibilities.
    """
    mix = models.MixvMF(x_dim=x.shape[1], order=n_clusters)

    ll_old = -np.inf
    with torch.no_grad():
        for step in range(max_iters):
            # E-step: responsibilities qz[i, k] = p(z=k | x_i)
            logalpha, _, _ = mix.get_params()
            logliks, logpcs = mix(x)
            ll = logliks.sum()
            jll = logalpha.unsqueeze(0) + logpcs
            qz = jll.log_softmax(1).exp()

            # tolerance check on the relative log-likelihood improvement
            if step > 0:
                rll = (ll - ll_old).abs() / (ll_old.abs() + utils.realmin)
                if rll < rll_tol:
                    break

            ll_old = ll

            # M-step: closed-form mean directions / mixing weights, approximate concentrations
            qzx = (qz.unsqueeze(2) * x.unsqueeze(1)).sum(0)
            qzx_norms = utils.norm(qzx, dim=1)
            mus_new = qzx / qzx_norms
            Rs = qzx_norms[:, 0] / (qz.sum(0) + utils.realmin)
            kappas_new = (mix.x_dim * Rs - Rs**3) / (1 - Rs**2 + 1e-6)
            alpha_new = qz.sum(0) / x.shape[0]

            # assign new params
            mix.set_params(alpha_new, mus_new, kappas_new)

        # score and assign points with the final parameters
        logalpha, _, _ = mix.get_params()
        logliks, logpcs = mix(x)
        ll = logliks.sum().item()
        assignment = np.argmax((logalpha.unsqueeze(0) + logpcs).numpy(), axis=1)
    return mix, ll, assignment


for n_clusters in n_clusters_to_try:
    print('Trying %d clusters...' % n_clusters)

    best_ll = -np.inf
    best_model = None
    best_cluster_assignment = None
    for repeat_idx in range(n_init):
        # a distinct, reproducible initialization for every restart
        seed_everything(seed + repeat_idx)
        mix, ll, assignment = fit_mixvmf_em(samples, n_clusters, opts['max_iters'], opts['rll_tol'])
        if ll > best_ll:
            best_ll, best_model, best_cluster_assignment = ll, mix, assignment

    # label_direction: column 0 = observed times, column 1 = event indicators
    result = pairwise_logrank_test(label_direction[:, 0],
                                   best_cluster_assignment,
                                   label_direction[:, 1])
    logrank_p_values.append(result.p_value)
    cluster_assignments.append(best_cluster_assignment)
    cluster_models.append(best_model)

Below, we make a violin plot to help us select the number of clusters to use.

In [None]:
# One violin per number of clusters k, showing the distribution of its pairwise log-rank p-values
plt.figure(figsize=(6, 2.25))
plt.violinplot(logrank_p_values, positions=n_clusters_to_try)
plt.ylabel('Log-rank test p-value')
plt.xlabel('Number of clusters $k$')
plt.yticks(np.linspace(0, 1, 11))
# plt.savefig('support-logrank-pvalue-vs-nclusters-hypersphere.pdf', bbox_inches='tight')

### For several numbers of clusters, use the resulting clusters' anchor directions to make various visualizations

In [None]:
# code that discretizes all raw features (discrete and continuous) for the raw feature probability heatmap
# visualizations (this code is specific to the SUPPORT dataset but can be modified for other tabular data)

from sklearn.preprocessing import KBinsDiscretizer, OneHotEncoder

def _dense_one_hot_encoder(categories):
    """Build a OneHotEncoder with dense (ndarray) output on any scikit-learn version.

    scikit-learn 1.2 renamed the ``sparse`` argument to ``sparse_output`` and 1.4 removed
    ``sparse``, so passing ``sparse=False`` alone fails on current releases.
    """
    try:
        return OneHotEncoder(sparse_output=False, categories=categories)
    except TypeError:  # scikit-learn < 1.2 does not know `sparse_output`
        return OneHotEncoder(sparse=False, categories=categories)


def _bin_name(feature_name, bin_idx, n_bins, bin_edges):
    """Label one quantile bin as '<name> bin#k[lo,hi)', with open-ended first and last bins."""
    if bin_idx == 0:
        return feature_name + ' bin#1(-inf,%.2f)' % bin_edges[bin_idx + 1]
    if bin_idx == n_bins - 1:
        return feature_name + ' bin#%d[%.2f,inf)' % (n_bins, bin_edges[bin_idx])
    return feature_name + ' bin#%d[%.2f,%.2f)' % (bin_idx + 1, bin_edges[bin_idx], bin_edges[bin_idx + 1])


def transform(raw_features, continuous_n_bins=5):
    """Discretize SUPPORT raw features into 0/1 columns for the raw feature probability heatmaps.

    Continuous variables are split into (up to) `continuous_n_bins` quantile bins, binary
    variables are kept as a single indicator column, and the categorical race (6 codes) and
    cancer (3 codes) variables are one-hot encoded.

    Parameters
    ----------
    raw_features : 2D array with at least 14 columns
        Raw SUPPORT inputs, column ``i`` holding the variable ``feature_names[i]`` below.
    continuous_n_bins : int
        Requested number of quantile bins per continuous variable. Quantile binning may merge
        bins whose edges coincide; a variable left with a single bin carries no information
        and is omitted.

    Returns
    -------
    discretized : np.ndarray of shape (n_samples, n_columns)
    discretized_feature_names : list of str
        One name per column of `discretized`.
    all_n_bins_to_use : list of int
        Number of consecutive columns occupied by each kept raw variable, in column order
        (it sums to n_columns). Callers rely on this to locate each variable's row block.
    """
    feature_names = ['age', 'female', 'race', 'number of comorbidities',
                     'diabetes', 'dementia', 'cancer', 'mean arterial blood pressure',
                     'heart rate', 'respiration rate', 'temperature', 'white blood count',
                     'serum sodium', 'serum creatinine']
    binary_indices = [1, 4, 5]
    continuous_indices = [0, 3, 7, 8, 9, 10, 11, 12, 13]
    discretized_features = []
    discretized_feature_names = []
    all_n_bins_to_use = []

    for idx in continuous_indices:
        discretizer = KBinsDiscretizer(n_bins=continuous_n_bins,
                                       strategy='quantile',
                                       encode='onehot-dense')
        new_features = discretizer.fit_transform(raw_features[:, idx].reshape(-1, 1).astype(float))
        # duplicate quantile edges are merged, so fewer bins than requested may remain
        n_bins_to_use = int(discretizer.n_bins_[0])
        if n_bins_to_use <= 1:
            # omit the variable entirely (columns, names AND column count) so that
            # all_n_bins_to_use stays aligned with the emitted columns
            continue
        discretized_features.append(new_features)
        discretized_feature_names.extend(
            _bin_name(feature_names[idx], bin_idx, n_bins_to_use, discretizer.bin_edges_[0])
            for bin_idx in range(n_bins_to_use))
        all_n_bins_to_use.append(n_bins_to_use)

    for idx in binary_indices:
        discretized_features.append(raw_features[:, idx].reshape(-1, 1).astype(float))
        discretized_feature_names.append(feature_names[idx])
        all_n_bins_to_use.append(1)

    # race: one-hot over its six integer codes
    race_encoder = _dense_one_hot_encoder([[0, 1, 2, 3, 4, 5]])
    discretized_features.append(race_encoder.fit_transform(raw_features[:, 2].reshape(-1, 1).astype(float)))
    discretized_feature_names.extend(['race cat#1(unspecified)',
                                      'race cat#2(asian)',
                                      'race cat#3(black)',
                                      'race cat#4(hispanic)',
                                      'race cat#5(other)',
                                      'race cat#6(white)'])
    all_n_bins_to_use.append(6)

    # cancer: one-hot over its three integer codes
    cancer_encoder = _dense_one_hot_encoder([[0, 1, 2]])
    discretized_features.append(cancer_encoder.fit_transform(raw_features[:, 6].reshape(-1, 1).astype(float)))
    discretized_feature_names.extend(['cancer cat#1(no)',
                                      'cancer cat#2(yes)',
                                      'cancer cat#3(metastatic)'])
    all_n_bins_to_use.append(3)
    return np.hstack(discretized_features), discretized_feature_names, all_n_bins_to_use

In [None]:
# Discretize the visualization raw inputs: returns the 0/1 feature matrix, one name per column, and the
# per-variable column counts used below to group the heatmap rows by raw variable.
raw_vis_discretized, discretized_feature_names, all_n_bins_to_use = transform(raw_vis)

In [None]:
# Fraction of visualization points (those with the largest projection onto an anchor direction) used in
# the final step to summarize each anchor direction by a median survival time.
alpha = .1

# For each number of clusters: (1) chi-squared ranking tables and raw feature probability heatmaps,
# (2) survival probability heatmaps, (3) anchor directions ranked by median survival time.
for final_n_clusters in [3, 4, 5, 6, 7]:
    print('[Number of clusters: %d]' % final_n_clusters)
    # best cluster assignment (over the EM restarts) found above for this number of clusters
    final_cluster_assignment = cluster_assignments[n_clusters_to_try.index(final_n_clusters)]
    # anchor direction of each cluster: the cluster's mean embedding minus the overall mean embedding
    center_of_mass = emb_direction.mean(axis=0)
    anchor_directions = np.array([emb_direction[final_cluster_assignment == cluster_idx].mean(axis=0) - center_of_mass
                                  for cluster_idx in range(final_n_clusters)])

    all_raw_feature_probability_heatmaps = []
    all_bin_edges = []

    # number of equal-width bins along each projection axis
    n_bins = 7
    # center the visualization embeddings with the same center of mass and row-normalize both sides so the
    # projections below are cosine similarities (assumes l2_normalize_rows rescales rows to unit L2 norm)
    emb_vis_normalized = l2_normalize_rows(emb_vis - center_of_mass[np.newaxis, :])
    anchor_directions_normalized = l2_normalize_rows(anchor_directions)
    for cluster_idx in range(final_n_clusters):
        projections = np.dot(emb_vis_normalized, anchor_directions_normalized[cluster_idx])

        # equal-width bins over the observed projection range; np.digitize is right-open, so the last edge is
        # replaced by +inf to keep the maximum projection in the last bin (index n_bins - 1)
        bin_counts, bin_edges = np.histogram(projections, bins=n_bins)
        bin_edges_copy_with_inf_right_edge = bin_edges.copy()
        bin_edges_copy_with_inf_right_edge[-1] = np.inf

        bin_assignments = np.digitize(projections, bin_edges_copy_with_inf_right_edge) - 1

        # heatmap[f, b]: mean of discretized column f over the points in projection bin b (for 0/1 columns,
        # the empirical probability of that feature value)
        # NOTE(review): an empty projection bin yields the mean of an empty slice, i.e. NaN plus a RuntimeWarning
        heatmap = np.zeros((len(discretized_feature_names), n_bins))
        for discretized_feature_idx in range(len(discretized_feature_names)):
            for projection_bin_idx in range(n_bins):
                heatmap[discretized_feature_idx, projection_bin_idx] = raw_vis_discretized[bin_assignments == projection_bin_idx][:, discretized_feature_idx].mean()

        all_raw_feature_probability_heatmaps.append(heatmap)
        all_bin_edges.append(bin_edges)

        # compute ranking table
        # Each raw variable is ranked by a chi-squared test of independence between its value and the
        # projection bin; contingency counts are rebuilt as (probability x number of points in the bin).
        projection_bin_counts = np.array([(bin_assignments == bin_idx).sum() for bin_idx in range(n_bins)])
        print('[Cluster %d]' % (cluster_idx + 1))
        current_row = 0
        variable_pval_pairs = []
        for variable_idx in range(len(all_n_bins_to_use)):
            n_bins_to_use = all_n_bins_to_use[variable_idx]
            if n_bins_to_use >= 2:
                # multi-column variable (quantile bins or categories): name it by the common prefix of its
                # column names, trimmed of a trailing '(' / '[' and of the ' cat#' / ' bin#' marker
                prefix = \
                    longest_common_prefix([discretized_feature_names[idx]
                                           for idx in range(current_row,
                                                            current_row + n_bins_to_use)])
                if prefix.endswith('(') or prefix.endswith('['):
                    prefix = prefix[:-1]
                if prefix.endswith(' cat#') or prefix.endswith(' bin#'):
                    prefix = prefix[:-5]
                prefix = prefix.strip()
                # contingency table: the variable's columns x projection bins; res[1] is the p-value
                res = scipy.stats.chi2_contingency(
                    heatmap[current_row:(current_row + n_bins_to_use), :]
                    * projection_bin_counts[np.newaxis, :])
                variable_pval_pairs.append((prefix, res[1]))
            else:
                # single indicator column: table of (value 1, value 0) x projection bins. The literal below has
                # shape (1, 2, 1, n_bins); chi2_contingency accepts N-d tables, and the singleton axes give the
                # same statistic and degrees of freedom as the 2 x n_bins table.
                indicator_row = heatmap[current_row:(current_row + n_bins_to_use), :]
                res = scipy.stats.chi2_contingency(
                    np.array([[indicator_row, 1. - indicator_row]])
                    * projection_bin_counts[np.newaxis, :])
                # print(discretized_feature_names[current_row], res[1])
                variable_pval_pairs.append((discretized_feature_names[current_row], res[1]))
            current_row += n_bins_to_use
        # print variables from most to least significant (ascending p-value)
        for idx, (variable_name, pval) in enumerate(sorted(variable_pval_pairs, key=lambda x: x[1])):
            print(idx + 1, '-', variable_name, '-', pval)
            # print('%d &' % (idx + 1), variable_name, (('& $%.2E' % pval).replace('E-', '\\times 10^{-') + '}$ \\\\').replace('^{-0', '^{-'))
        print()

    # plot all raw feature probability heatmaps together
    # (one panel per cluster sharing the feature-name y axis, with a single colorbar on the right)

    # same number of projection bins as above
    n_bins = 7
    # fig, axn = plt.subplots(1, final_n_clusters, sharex=False, sharey=True, figsize=(12.5, 15))
    fig, axn = plt.subplots(1, final_n_clusters, sharex=False, sharey=True, figsize=(15, 15))
    cbar_ax = fig.add_axes([.91, .3, .03, .4])

    for cluster_idx in range(final_n_clusters):
        ax = axn.flat[cluster_idx]
        # horizontal black lines separating the row blocks of consecutive raw variables
        current_row = 0
        for idx, count in enumerate(all_n_bins_to_use[:-1]):
            current_row += count
            ax.plot([0, n_bins], [current_row, current_row], 'black')
        # x tick labels are the projection bin centers
        sns.heatmap(pd.DataFrame(all_raw_feature_probability_heatmaps[cluster_idx],
                                 index=discretized_feature_names,
                                 columns=['%.2f' % x for x in (all_bin_edges[cluster_idx][:-1]
                                                               + all_bin_edges[cluster_idx][1:])/2]),
                    ax=ax,
                    cmap=sns.light_palette("#4a72ae", reverse=False, as_cmap=True),
                    vmin=0, vmax=1,
                    cbar=(cluster_idx == 0),
                    cbar_ax=None if (cluster_idx != 0) else cbar_ax)
        ax.set_xlabel('Projection onto\nanchor direction\nfor cluster %d' % (cluster_idx + 1))
    fig.tight_layout(rect=[0, 0, .9, 1], pad=1.)
    # plt.savefig('support-raw-feature-prob-heatmaps-hypersphere.pdf', bbox_inches='tight')

    # compute survival probability heatmaps
    # For each projection bin (recomputed exactly as above), average the predicted survival curves of its
    # points and evaluate that average by linear interpolation on an evenly spaced grid of n_rows times;
    # rows are reversed so the latest time is at the top of the heatmap.

    from scipy import interpolate

    n_rows = 10
    discrete_time_grid = np.linspace(unique_train_times.min(), unique_train_times.max(), n_rows)

    all_survival_probability_heatmaps = []
    all_projection_bin_surv_curves = []
    for cluster_idx in range(final_n_clusters):
        projections = np.dot(emb_vis_normalized, anchor_directions_normalized[cluster_idx])

        bin_counts, bin_edges = np.histogram(projections, bins=n_bins)
        bin_edges_copy_with_inf_right_edge = bin_edges.copy()
        bin_edges_copy_with_inf_right_edge[-1] = np.inf

        bin_assignments = np.digitize(projections, bin_edges_copy_with_inf_right_edge) - 1

        heatmap = np.zeros((n_rows, n_bins))
        projection_bin_surv_curves = []
        for projection_bin_idx in range(n_bins):
            # NOTE(review): an empty projection bin makes this mean curve all-NaN
            projection_bin_surv = interpolate.interp1d(
                unique_train_times,
                predicted_surv_vis[bin_assignments == projection_bin_idx].mean(axis=0))
            projection_bin_surv_curves.append(
                predicted_surv_vis[bin_assignments == projection_bin_idx].mean(axis=0))
            heatmap[:, projection_bin_idx] = projection_bin_surv(discrete_time_grid)[::-1]

        all_survival_probability_heatmaps.append(heatmap)
        all_projection_bin_surv_curves.append(projection_bin_surv_curves)

    # plot all survival probability heatmaps

    # fig, axn = plt.subplots(1, final_n_clusters, sharex=False, sharey=True, figsize=(12.5, 3.5))
    fig, axn = plt.subplots(1, final_n_clusters, sharex=False, sharey=True, figsize=(15, 3.5))
    cbar_ax = fig.add_axes([.91, .3, .03, .4])

    for cluster_idx in range(final_n_clusters):
        ax = axn.flat[cluster_idx]
        sns.heatmap(
            pd.DataFrame(all_survival_probability_heatmaps[cluster_idx],
                         index=['%.1f' % x for x in discrete_time_grid[::-1]],
                         columns=['%.2f' % x for x in (all_bin_edges[cluster_idx][:-1]
                                                       + all_bin_edges[cluster_idx][1:])/2]),
            cmap=sns.light_palette("#4a72ae", reverse=False, as_cmap=True),
            vmin=0, vmax=1, ax=ax,
            cbar=(cluster_idx == 0),
            cbar_ax=None if (cluster_idx != 0) else cbar_ax)
        ax.set_xlabel('Projection onto\nanchor direction\nfor cluster %d' % (cluster_idx + 1))
        if cluster_idx == 0:
            ax.set_ylabel('Survival time (days)')
    fig.tight_layout(rect=[0, 0, .9, 1], pad=1.5)
    # plt.savefig('support-surv-prob-heatmaps-hypersphere.pdf', bbox_inches='tight')

    # ranking anchor directions based on median survival time

    # estimate a survival curve for the top alpha fraction of visualization data points per cluster/anchor direction
    anchor_direction_median_survival_time_pairs = []
    for cluster_idx in range(final_n_clusters):
        projections = np.dot(emb_vis_normalized, anchor_directions_normalized[cluster_idx])
        # threshold = sorted projection at index ceil((1 - alpha) * n), so roughly the alpha fraction of points
        # with the largest projections (plus any ties) satisfy projections >= q_alpha
        # NOTE(review): alpha == 0 would index one past the end of the sorted array
        q_alpha = np.sort(projections)[int(np.ceil((1-alpha)*len(projections)))]

        surv_curv_alpha = predicted_surv_vis[projections >= q_alpha].mean(axis=0)
        median_surv_time_estimate = compute_median_survival_times(surv_curv_alpha, unique_train_times)
        anchor_direction_median_survival_time_pairs.append((cluster_idx, median_surv_time_estimate))

    # sort anchor directions by median survival time estimates (shortest first)
    sorted_anchor_direction_median_survival_time_pairs = \
        sorted(anchor_direction_median_survival_time_pairs, key=lambda x: x[1])
    for cluster_idx, median_survival_time in sorted_anchor_direction_median_survival_time_pairs:
        print('Cluster', cluster_idx + 1, ': median survival time estimate', median_survival_time)
    print()
    print()