<p align=center>
<img src="assets/cphbanner.png" width=1280>
</p>

In [None]:
%reload_ext autoreload
%autoreload 2

import sys
import os

path_to_project = os.path.abspath(os.path.join(os.getcwd(), '../'))    
sys.path.insert(1, os.path.join(path_to_project))

# **Project 1: Survival Analysis and Prediction [30 points]**

Many clinical trials and observational studies involve following patients for a long time. The primary event of
interest in those studies may include death, relapse, or the onset of a new disease. The follow-up time for a trial
or a study may range from a few weeks to many years. To analyze this data, we typically conduct time-to-event
analysis and build predictive models that learn time-to-event distributions. The goal of this project is to test
your ability to conduct basic survival analyses as well as develop ML models for survival prediction.

**Please submit your report and code by <u> Tuesday 2/4 11:59 PST </u>.**

## Task 1.1: Nonparametric Survival Analysis in Heart Failure [7 pts]

Nonparametric models of survival data do not make parametric assumptions on the distribution of time-to-event outcomes. They are widely used in clinical studies to derive descriptive statistics of survival in a population. In this task, we will apply standard nonparametric estimators to analyze survival of heart failure patients in a recent, widely-recognized study [1].
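
As a point of reference, the Kaplan-Meier (product-limit) estimator multiplies, over the distinct event times, the fraction of at-risk patients who survive each one. The sketch below is a minimal illustration on made-up toy data; the function name `km_estimate` and the toy arrays are hypothetical and are not part of the project code.

```python
import numpy as np

def km_estimate(times, events):
    """Product-limit estimate of S(t) at each distinct observed event time."""
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=int)
    event_times = np.unique(times[events == 1])
    surv, s = [], 1.0
    for t in event_times:
        at_risk = np.sum(times >= t)                    # patients still under observation at t
        deaths = np.sum((times == t) & (events == 1))   # events occurring exactly at t
        s *= 1.0 - deaths / at_risk
        surv.append(s)
    return event_times, np.array(surv)

# Toy data (hypothetical): 1 = event observed, 0 = censored
toy_times = [2, 3, 3, 5, 8]
toy_events = [1, 1, 0, 1, 0]
t, s = km_estimate(toy_times, toy_events)
print(dict(zip(t, s)))
```

Censored patients never trigger a drop in the curve, but they do count in the at-risk set until their censoring time.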

####  Setup and Dataset

The dataset we will use in this task was extracted from the electronic health records (EHRs) of 299 heart failure patients at the Faisalabad Institute of Cardiology and the Allied Hospital in Faisalabad (Punjab, Pakistan), collected during April–December 2015. The cohort included 105 women and 194 men, aged between 40 and 95 years. All 299 patients had left ventricular systolic dysfunction and a history of heart failure (HF) that placed them in class III or IV of the New York Heart Association (NYHA) classification of the stages of heart failure. The dataset contains 13 features, which report clinical, body, and lifestyle information. The patients were followed up for 130 days on average (the maximum follow-up period was 285 days). The event of interest was death during the follow-up period.

The dataset is publicly accessible and was shared with the class through UCSF Box. You can load the dataset from the directory "./data" and inspect all the features/outcomes using pandas as follows:

In [None]:
import os
import pandas as pd
from src.directory import csv_paths

In [None]:
# read data
dataset = pd.read_csv(csv_paths['faisalabad'])

In [None]:
dataset

In [None]:
# descriptive statistics
dataset.describe()

In [None]:
# check for NaN
na_counts = dataset.isna().sum(axis=0)
print(f'NA Count by Variable : {na_counts[na_counts > 0]}')

## Solution

### Task 1.1.1

In [None]:
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from lifelines import KaplanMeierFitter
from src.estimators import kaplan_meier
from src.directory import csv_paths
import numpy as np

In [None]:
time_col = 'time'
event_col = 'DEATH_EVENT'
dataset = pd.read_csv(csv_paths['faisalabad'])

# set alpha for confidence intervals
alpha = 0.05

In [None]:
# get KM estimate from scratch
km_df = kaplan_meier(dataset, time_col, event_col, alpha=alpha)

In [None]:
# get KM estimate from lifelines
kmf = KaplanMeierFitter(alpha=alpha)
kmf.fit(dataset['time'], event_observed=dataset['DEATH_EVENT'])

In [None]:
# plot KM estimates from scratch and with lifelines
fig, axs = plt.subplots()

# plot from scratch KM
sns.lineplot(data=km_df, x='time', y='survival_prob', 
             drawstyle='steps-pre', ax=axs, label='From scratch', legend=True)
plt.fill_between(km_df['time'], km_df['ci_lower'], km_df['ci_upper'], alpha=0.3)

# plot lifelines KM
kmf.plot_survival_function(label='Lifelines')

plt.ylim(.45, 1.05)
plt.title(fr'Kaplan-Meier Estimates, CI $\alpha$ ={alpha}')
plt.show()

# check survival function equality
assert all(np.isclose(km_df['survival_prob'].values.ravel(),
                      kmf.survival_function_.reindex(range(km_df['time'].max() +1)).ffill().values.ravel())), 'Survival functions don\'t match!'
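
The shaded band above is a confidence interval around the KM curve. One common way to obtain it is Greenwood's formula for the variance of the estimate, followed by a normal approximation. The sketch below shows the plain (untransformed) version on the same kind of toy numbers; `greenwood_ci` is a hypothetical helper, and neither `src.estimators.kaplan_meier` nor other packages are guaranteed to use this exact variant (log-log transformed intervals are also common).

```python
import numpy as np
from scipy.stats import norm

def greenwood_ci(surv, n_at_risk, deaths, alpha=0.05):
    """Plain Greenwood standard error and normal-approximation interval.

    Assumes n_at_risk > deaths at every event time (otherwise the term is undefined).
    """
    surv = np.asarray(surv, dtype=float)
    n = np.asarray(n_at_risk, dtype=float)
    d = np.asarray(deaths, dtype=float)
    var_terms = np.cumsum(d / (n * (n - d)))   # running sum of d_i / (n_i (n_i - d_i))
    se = surv * np.sqrt(var_terms)
    z = norm.ppf(1 - alpha / 2)
    lower = np.clip(surv - z * se, 0.0, 1.0)   # keep the band inside [0, 1]
    upper = np.clip(surv + z * se, 0.0, 1.0)
    return se, lower, upper

# Toy inputs (hypothetical): survival estimates with matching risk-set sizes and death counts
surv = np.array([0.8, 0.6, 0.3])
se, lower, upper = greenwood_ci(surv, n_at_risk=[5, 4, 2], deaths=[1, 1, 1])
print(se, lower, upper)
```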

### Task 1.1.2

In [None]:
from scipy.optimize import curve_fit

In [None]:
eps = 1e-8

def exp_model(x, a, b, c):
    return a * np.exp(-b * x) + c

popt, pcov = curve_fit(f=exp_model, xdata=km_df.index, ydata=km_df['survival_prob'])
S_t = exp_model(km_df.index, *popt)

In [None]:
# plot the empirical KM curve against the fitted exponential curve
fig2, axs2 = plt.subplots()

# plot from scratch KM
sns.lineplot(data=km_df, x='time', y='survival_prob', 
             drawstyle='steps-pre', ax=axs2, label='Empirical curve', legend=True)
plt.fill_between(km_df['time'], km_df['ci_lower'], km_df['ci_upper'], alpha=0.3)

axs2.plot(S_t, label='Exponential Curve')
plt.legend()
plt.show()

# Limitation: the fit guarantees neither S(0) == 1 (a + c is unconstrained)
# nor S(t) -> 0 as t grows (c need not be 0)
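
One way to sidestep the boundary issue is to fit the one-parameter form S(t) = exp(-λt), which equals 1 at t = 0 and decays to 0 by construction. The sketch below fits it with `scipy.optimize.curve_fit` on a synthetic, noise-free curve standing in for the KM estimate; `exp_survival`, `t_grid`, and `s_obs` are illustrative names, and this is an alternative parameterization rather than the model used above.

```python
import numpy as np
from scipy.optimize import curve_fit

def exp_survival(t, lam):
    # One-parameter exponential survival function; S(0) = 1 by construction
    return np.exp(-lam * t)

# Synthetic stand-in for an empirical survival curve (true rate 0.01 per day)
t_grid = np.arange(0, 100, dtype=float)
s_obs = np.exp(-0.01 * t_grid)

# p0 is an arbitrary starting guess for the rate
(lam_hat,), _ = curve_fit(exp_survival, t_grid, s_obs, p0=[0.05])
print(f"Estimated rate: {lam_hat:.4f}")
```

The trade-off is flexibility: with a single rate parameter the curve cannot plateau above zero, so it may fit a KM curve with a long flat tail worse than the three-parameter form.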

### Task 1.1.3

In [None]:
from src.data_dict import feature_config
from src.estimators import nearest_neighbor_km
from src.metrics import evaluate_c_index

dataset_name = 'faisalabad'
patient_features = feature_config[dataset_name]
time_col = 'time'
event_col = 'DEATH_EVENT'
max_time = dataset[time_col].max()

patient_km_fits = nearest_neighbor_km(dataset, patient_features, time_col, event_col, n_neighbors=20)

c_index = evaluate_c_index(dataset, patient_km_fits, time_col, event_col)
print(f"C-index: {c_index}")

# Example of getting survival probabilities for the first five patients
fig, axs = plt.subplots()
for patient_index in range(5):
    time_points = np.arange(max_time)
    patient_survival_prob = patient_km_fits[patient_index].survival_function_at_times(time_points)
    patient_survival_prob.index.name = time_col
    
    sns.lineplot(data=patient_survival_prob.reset_index(), x='time', y='KM_estimate', 
                 drawstyle='steps-pre', ax=axs, label=f'Patient {patient_index}', legend=True)
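
For intuition on the C-index reported above, the sketch below computes Harrell's concordance index from scratch: among comparable pairs (the patient with the shorter time had an observed event), it counts how often the higher predicted risk belongs to the patient who died earlier. `harrell_c_index` and the toy arrays are hypothetical; how `evaluate_c_index` turns per-patient KM fits into a risk score is not shown in this notebook, so this is not a reimplementation of it.

```python
import numpy as np

def harrell_c_index(times, events, risk_scores):
    """Fraction of comparable pairs ordered correctly; tied risks count as 0.5."""
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=int)
    risk_scores = np.asarray(risk_scores, dtype=float)
    concordant, comparable = 0.0, 0
    n = len(times)
    for i in range(n):
        if events[i] != 1:
            continue  # a pair is only comparable if the earlier time is an observed event
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    return concordant / comparable

# Toy data (hypothetical): higher risk score should mean shorter survival
toy_times = [1, 2, 3, 4]
toy_events = [1, 1, 0, 1]
c_perfect = harrell_c_index(toy_times, toy_events, [4, 3, 2, 1])
c_reversed = harrell_c_index(toy_times, toy_events, [1, 2, 3, 4])
print(c_perfect, c_reversed)
```

A value of 0.5 corresponds to random ordering, so it is the natural baseline when reading the number printed above.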

## Task 1.2: Survival Prediction in HF patients using the Cox Model [7 pts]

### Task 1.2.1

In [None]:
import numpy as np
import pandas as pd
from src.data_dict import feature_config
from src.lightning import CoxRiskLightning, get_trainer, get_checkpoint_callback, get_log_dir_path
from src.dataset import SurvivalDataModule

In [None]:
time_col = 'time'
event_col = 'DEATH_EVENT'
dataset_name = 'faisalabad'

# set up datamodule
datamodule = SurvivalDataModule(
    dataset_name=dataset_name,
    input_features=feature_config[dataset_name],
    time_col=time_col,
    event_col=event_col
)

# set up model 
model = CoxRiskLightning(
    dataset_name=dataset_name,
    clinical_features=feature_config[dataset_name],
    time_col=time_col,
    event_col=event_col
)

# set log dir
model_name = 'cox'
log_dir_path = get_log_dir_path(model_name)

# get checkpoint callback
checkpoint_callback = get_checkpoint_callback(model_name, log_dir_path)

# get trainer
trainer = get_trainer(model_name, checkpoint_callback)

print("Training model")
trainer.fit(model, datamodule)

In [None]:
model_coefficients = model.model.risk.weight.squeeze().detach().numpy()

feature_of_interest = 'age'
idx = model.feature_names.index(feature_of_interest)
effect_of_increment = np.exp(model_coefficients[idx])
print(f'\nHazard Ratio of 1-year increment of {feature_of_interest} on risk:', effect_of_increment)
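
The hazard ratio printed above is exp(β) for the chosen feature: under the Cox model h(t | x) = h0(t) exp(βᵀx), so a one-unit increase in a feature multiplies the hazard by exp(β), in whatever units the feature is in when it enters the model. The sketch below writes out the negative log partial likelihood that a Cox model minimizes, using Breslow-style risk sets. It is a plain NumPy illustration with hypothetical names; the loss inside `CoxRiskLightning` is not shown here and may differ.

```python
import numpy as np

def cox_neg_log_partial_likelihood(beta, X, times, events):
    """Negative log partial likelihood with Breslow-style handling of ties."""
    X = np.asarray(X, dtype=float)
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=int)
    eta = X @ np.asarray(beta, dtype=float)      # linear risk score per patient
    nll = 0.0
    for i in np.flatnonzero(events == 1):
        risk_set = times >= times[i]             # patients still at risk at the i-th event time
        nll -= eta[i] - np.log(np.sum(np.exp(eta[risk_set])))
    return nll

# Toy data (hypothetical): two patients, one feature, both events observed
X_toy = [[1.0], [0.0]]
nll_at_zero = cox_neg_log_partial_likelihood([0.0], X_toy, times=[1, 2], events=[1, 1])
hazard_ratio = np.exp(0.5)  # effect of a one-unit increase if beta were 0.5
print(nll_at_zero, hazard_ratio)
```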

In [None]:
coefficient_df = pd.DataFrame(model_coefficients, index=feature_config[dataset_name], columns=['Model coefficients'])
latex_tab = coefficient_df.to_latex(index=True, 
                            float_format="%.3f",
                            label=f'tab:cox_coefficients',
                            caption=f'Cox model coefficients',
                            sparsify=True)
latex_tab = latex_tab.replace('_', ' ')
print(latex_tab)

### Task 1.2.2

In [None]:
# run train set through model
predict_trainer = get_trainer(model_name, checkpoint_callback)
_ = predict_trainer.predict(model, datamodule)
metrics = model.metric_dict
print('C-index on train set:', metrics['predict']['predict_cindex'])
print('AUC-ROC on train set:', metrics['predict']['predict_auc'])

### Task 1.2.3

In [None]:
from src.dataset import merge_batches
from src.data_dict import feature_config
from src.lightning import CoxRiskLightning, get_trainer, get_checkpoint_callback, get_log_dir_path
from src.dataset import SurvivalDataModule

In [None]:
time_col = 'time'
event_col = 'DEATH_EVENT'
dataset_name = 'faisalabad'

# set up datamodule
datamodule = SurvivalDataModule(
    dataset_name=dataset_name,
    input_features=feature_config[dataset_name],
    time_col=time_col,
    event_col=event_col
)

# set up model 
model = CoxRiskLightning(
    dataset_name=dataset_name,
    clinical_features=feature_config[dataset_name],
    interaction_features=[('age', 'sex')],
    time_col=time_col,
    event_col=event_col
)

# get log dir
model_name = 'cox_age_sex'
log_dir_path = get_log_dir_path(model_name)

# get checkpoint callback
checkpoint_callback = get_checkpoint_callback(model_name, log_dir_path)

# get trainer
trainer = get_trainer(model_name, checkpoint_callback)

print("Training model")
trainer.fit(model, datamodule)

In [None]:
X = merge_batches(datamodule.train_dataloader())
X, _, _ = model.get_xtc(X)

In [None]:
# get model coefficient pvalues
cox_pvals = model.get_coefficient_pvals(X)
pvals_df = pd.DataFrame.from_dict(cox_pvals, columns=['Coefficient p-values'], orient='index')

latex_tab = pvals_df.to_latex(index=True, 
                            float_format="%.3f",
                            label=f'tab:cox_pvals',
                            caption=f'Cox model coefficient p-values',
                            sparsify=True)
latex_tab = latex_tab.replace('pval_', '').replace('_', ' ')
print(latex_tab)
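
A standard way to attach p-values to Cox coefficients is the Wald test: divide each coefficient by its standard error and compare the result with a standard normal. The sketch below shows only that final step and takes the standard errors as given; `wald_p_values` is a hypothetical helper, and whether `get_coefficient_pvals` uses a Wald test (or how it estimates standard errors) is an assumption, since its implementation is not shown here.

```python
import numpy as np
from scipy.stats import norm

def wald_p_values(coefs, std_errs):
    """Two-sided p-values for H0: coefficient == 0 under a normal approximation."""
    z = np.asarray(coefs, dtype=float) / np.asarray(std_errs, dtype=float)
    return 2.0 * norm.sf(np.abs(z))

# Toy inputs (hypothetical coefficients and standard errors)
p_vals = wald_p_values(coefs=[0.0, 1.96, -3.0], std_errs=[1.0, 1.0, 1.0])
print(p_vals)
```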

## Task 1.3: Deep Survival Prediction for Heart Transplantation [8 pts]

####  Setup and Dataset

For this task, we will use data collected by the United Network for Organ Sharing (UNOS) [2], a non-profit organization that administers the only Organ Procurement and Transplantation Network (OPTN) in the US. UNOS is involved in many aspects of the organ transplant and donation process in the US, including data collection and maintenance, providing assistance to patients and caretakers, and informing policy makers on how best to use the limited supply of organs and give all patients a fair chance at receiving the organ they need. UNOS manages the heart transplant waiting list, i.e., the list of terminally-ill patients waiting for a donor heart. In order to determine the order of priority for receipt of a donor heart, individuals are classified by degree of severity, blood type, body weight, and geographic location.

This task will focus on the cohort of terminally-ill patients who are enrolled in the wait-list for heart transplantation. In this setup, our goal is to predict the patients who are less likely to survive in order to prioritize them for receiving donated organs. The UNOS data covers 30 years of heart transplantation data in the US, spanning the years from 1985 to 2015. We will use data for patients who were on the wait-list for heart transplantation in the US from 1985 to 2010 (27,926 patients) to train an ML-based model for predicting individual-level survival. A held-out test set of 8,403 patients enrolled in the wait-list between 2010 and 2015 will be used by the instructor to evaluate your model. You can load the UNOS data in pandas as follows.

In [None]:
import pandas as pd
from src.directory import csv_paths

In [None]:
UNOS_data = pd.read_csv(csv_paths['unos'])

In [None]:
UNOS_data

In [None]:
UNOS_data.describe()

In [None]:
unos_na_counts = UNOS_data.isna().sum(axis=0)
print(f'NA Count by Variable : {unos_na_counts[unos_na_counts > 0]}')

#### Feature Dictionary

Each patient's record in the UNOS database is associated with the following variables:

In [None]:
patient_variables   = ["init_age", "gender", "hgt_cm_tcr", "wgt_kg_tcr", "diab", "ventilator_tcr",
                       "ecmo_tcr", "most_rcnt_creat", "abo_A", "abo_B", "abo_O", "vad_while_listed",
                       "days_stat1", "days_stat1a", "days_stat2", "days_stat1b", "iabp_tcr",
                       "init_bmi_calc", "tah", "inotropic", "Censor (Censor = 1)", "Survival Time"]

The interpretation of each variable is provided below:

- "init_age": Patient's age at time of enrolling in the wait-list
- "gender": Patient's biological sex
- "hgt_cm_tcr": Patient's height in cm
- "wgt_kg_tcr": Patient's weight in kgs
- "diab": Indication on whether or not the patient is diabetic
- "abo_A": Indication on whether patient's blood type is A
- "abo_B": Indication on whether patient's blood type is B
- "abo_O": Indication on whether patient's blood type is O
- "ventilator_tcr": Indication on whether the patient was dependent on a ventilator at time of enrollment in the wait-list
- "ecmo_tcr": Indication on whether the patient was treated with ECMO (extracorporeal membrane oxygenation) by the time they were enrolled in the wait-list. ECMO is an artificial life support that continuously pumps blood out of the patient's body and sends it through a series of devices that add oxygen and remove carbon dioxide, pumping the blood back to the patient. It is used for a patient whose heart and lungs are not functioning properly.
- "most_rcnt_creat": Creatinine level in the patient's most recent blood test before enrolling in wait-list.
- "vad_while_listed": Whether the patient was on ventricular assist device (VAD) support when listed for a heart transplant. VAD is a mechanical pump used to restore cardiac function by pumping blood from the lower chambers of the heart to the rest of the body.
- "iabp_tcr": Whether the patient was on Intra-Aortic Balloon Pump (IABP) Therapy. This is a therapeutic device used to improve blood flow when the heart is unable to pump enough blood for the body.
- "init_bmi_calc": Patient's Body Mass Index at time of enrollment in the wait-list.
- "tah": Whether the patient underwent a total artificial heart (TAH) surgery. This is a mechanical pump that replaces the heart when it is not working as it should.
- "inotropic": Whether the patient was on an Inotropic drug at time of enrollment in wait-list. These are medicines that change the force of the heart's contractions.
- "days_stat1", "days_stat1a", "days_stat1b", "days_stat2": UNOS has an internal system for classifying the priority of patients for receiving a heart transplant. Individuals classified as Status 1A have the highest priority on the heart transplant waiting list. Status 1A are individuals who must stay in the hospital as in-patients and require high doses of intravenous drugs, require a VAD for survival, are dependent on a ventilator or have a life expectancy of a week or less without a transplant. Individuals classified as Status 1B are generally not required to stay in the hospital as in-patients. All other candidates for the transplant are listed under Status 2. These variables indicate the number of days a patient spends in each status during the time between their enrollment in the wait-list and death or reception of a transplant.
- "Censor (Censor = 1)": Indication of censoring
- "Survival Time": Time between enrollment in wait-list and death
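
Most survival libraries expect an event indicator (1 = death observed) rather than a censoring indicator. If the "Censor (Censor = 1)" column equals 1 for censored patients, as its name suggests, the event indicator is its complement. The sketch below shows the conversion on a made-up four-row frame; the encoding is an assumption that should be checked against the data before relying on it.

```python
import pandas as pd

# Toy frame (hypothetical values) using the two outcome columns described above
toy = pd.DataFrame({
    "Censor (Censor = 1)": [1, 0, 0, 1],
    "Survival Time": [30, 400, 12, 900],
})

# Assumption: 1 means the patient was censored, so an observed death is the complement
event_observed = 1 - toy["Censor (Censor = 1)"]
print(event_observed.tolist())
```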

### Task 1.3.2

In [None]:
import numpy as np
import pandas as pd
from src.data_dict import feature_config
from src.directory import csv_paths, deep_survival_model_path
from src._torch import DeepSurvival
from lifelines import KaplanMeierFitter
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch

In [None]:
event_col = "Censor (Censor = 1)"
time_col = "Survival Time"

UNOS_data = pd.read_csv(csv_paths['unos'])
X = UNOS_data[feature_config['unos']].to_numpy()
T = UNOS_data[time_col].to_numpy()
C = UNOS_data[event_col].to_numpy()
n_features = X.shape[1]

X_train, X_test, T_train, T_test, C_train, C_test = train_test_split(X, T, C, test_size=.2, random_state=40)

In [None]:
torch.manual_seed(40)
model = DeepSurvival(save_path=deep_survival_model_path)
model.fit(X_train, T_train, C_train, lr=3e-2)

In [None]:
# check test c_index
cindex = model.get_cindex(X_test, T_test, C_test)
print('C-index on test set:', cindex.item())

In [None]:
predictions = model.predict(X, max_years=10)

In [None]:
# plot mean survival curve
plt.plot(predictions.mean(axis=0), label='DeepSurvival')

# plot KM survival curve
kmf = KaplanMeierFitter()
kmf.fit(UNOS_data[time_col]/365, event_observed=UNOS_data[event_col])
kmf.plot_survival_function(label='Lifelines KM')

# adjust plot
plt.xlim((-0.5, 10))
plt.ylim((0,1.05))
plt.xlabel("Years")
plt.ylabel("Survival Probability")
plt.legend()
plt.show()

## Task 1.4: Handling Informative Censoring via Domain Adaptation [8 pts]

### Task 1.4.1

In [None]:
from pycox import datasets
import seaborn as sns
import matplotlib.pyplot as plt
from src.informative_censoring import check_informative_censoring, generate_semi_synthetic_dataset

In [None]:
# Note: "['Unnamed: 0'] not found in axis" error received reading in flchain & nwtco:
# To fix error: replace the following lines pycox/datasets/from_rdatasets.py:
# line 78: .drop(['chapter', 'Unnamed: 0', 'rownames'], axis=1, errors='ignore')
# line 146: .drop(['Unnamed: 0', 'seqno', 'instit', 'histol', 'study', 'rownames'], axis=1, errors='ignore'))

dataset_dict = dict(zip(['flchain', 'gbsg', 'metabric', 'nwtco'], 
                         [datasets.flchain.read_df(), 
                          datasets.gbsg.read_df(), 
                          datasets.metabric.read_df(), 
                          datasets.nwtco.read_df()]))

for name, dataset in dataset_dict.items():
    dataset.name = name
    

time_col_dict = dict(zip(dataset_dict.keys(),
                         ['futime',
                          'duration',
                          'duration',
                          'edrel']))

event_col_dict = dict(zip(dataset_dict.keys(),
                         ['death',
                          'event',
                          'event',
                          'rel']))

In [None]:
# check for informative censoring
results = {}
for name, dataset in dataset_dict.items():
    time_col, event_col = time_col_dict[name], event_col_dict[name]
    results[name] = check_informative_censoring(dataset, time_col, event_col)

# print results
for name, r in results.items():
    print(name, '\n', r.loc[['informative_censoring']].T, '\n')
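
The result frames above expose a `p_value` and an `informative_censoring` flag per feature, and the table-formatting cell below drops a `spearmanr_results` column, which suggests a rank-correlation test. A minimal version of such a check, assuming it correlates each feature with the censoring indicator and flags features whose p-value falls below a threshold, might look like the sketch below. `spearman_censoring_check`, the threshold, and the toy frame are hypothetical, and the actual `check_informative_censoring` may work differently.

```python
import pandas as pd
from scipy.stats import spearmanr

def spearman_censoring_check(df, event_col, feature_cols, alpha=0.05):
    """Flag features whose Spearman correlation with censoring is significant."""
    censored = 1 - df[event_col]   # assumes event_col is 1 when the event was observed
    rows = {}
    for col in feature_cols:
        rho, p = spearmanr(df[col], censored)
        rows[col] = {"rho": rho, "p_value": p, "informative_censoring": p < alpha}
    return pd.DataFrame(rows)      # features as columns, statistics as rows

# Toy data (hypothetical): low values of x are always censored
toy = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8],
                    "event": [0, 0, 0, 0, 1, 1, 1, 1]})
check = spearman_censoring_check(toy, event_col="event", feature_cols=["x"])
print(check)
```

A significant correlation only says that censoring depends on an observed covariate; it cannot detect dependence on unobserved factors.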

In [None]:
# format df of p-values
informative_censoring_df = []
for name, r in results.items():
    temp_df = pd.DataFrame.from_dict(r).loc[['p_value']]
    temp_df = temp_df.reset_index().drop(columns=['spearmanr_results', 'feature']).T
    temp_df['dataset'] = name
    # temp_df = temp_df.rename
    temp_df = temp_df.reset_index().set_index(['dataset', 'index'])
    informative_censoring_df.append(temp_df)
informative_censoring_df = pd.concat(informative_censoring_df)
informative_censoring_df.index = informative_censoring_df.index.rename({'index': 'feature'})
informative_censoring_df.columns = ['p-value']

# get latex table
latex_tab = informative_censoring_df.to_latex(index=True, 
                                            float_format="%.3f",
                                            label=f'tab:spearman_pvals',
                                            caption=f'Informative censoring p-values',
                                            sparsify=True)
latex_tab = latex_tab.replace('_', ' ')
print(latex_tab)

In [None]:
# visually inspect time-dependent censoring
fig, axs = plt.subplots(2,2)
axs = axs.ravel()
for i, (name, dataset) in enumerate(dataset_dict.items()):
    time_col, event_col = time_col_dict[name], event_col_dict[name]
    sns.violinplot(dataset, x=time_col, hue=event_col, ax=axs[i], split=True, cut=0)
    axs[i].set_title(name)

plt.tight_layout()

In [None]:
# generate semi synthetic data
synthetic_data = {}
for name, dataset in dataset_dict.items():
    time_col, event_col = time_col_dict[name], event_col_dict[name]
    synthetic_data[name] = generate_semi_synthetic_dataset(dataset, time_col, event_col, max_loops=10)
    synthetic_data[name].name = name

In [None]:
# visually inspect time-dependent censoring
fig, axs = plt.subplots(2,2)
axs = axs.ravel()
for i, (name, dataset) in enumerate(synthetic_data.items()):
    time_col, event_col = time_col_dict[name], event_col_dict[name]
    sns.violinplot(dataset, x=time_col, hue=event_col, ax=axs[i], split=True, cut=0)
    axs[i].set_title(f'{name}_synth')

plt.tight_layout()

In [None]:
# check for informative censoring
results = {}
for name, dataset in synthetic_data.items():
    time_col, event_col = time_col_dict[name], event_col_dict[name]
    results[name] = check_informative_censoring(dataset, time_col, event_col)

# print results
for name, r in results.items():
    time_col = time_col_dict[name]
    print(f'{name}_synth', '\n', r.loc[['informative_censoring']].T, '\n')

### Task 1.4.2

In [None]:
import torch
from src._torch import DeepSurvival
from src.directory import deep_survival_model_path
from sklearn.model_selection import train_test_split

In [None]:
multindex = pd.MultiIndex.from_product(iterables=[synthetic_data.keys(), ['vanilla', 'importance sampling']], names=['dataset', 'method'])
c_index_df = pd.DataFrame(index=multindex, columns=['c-index'])

In [None]:
# modified model with importance-weighted ERM
for name, dataset in synthetic_data.items():

    # set up data
    time_col, event_col = time_col_dict[name], event_col_dict[name]
    feature_cols = list(set(dataset.columns) - set([time_col, event_col]))

    X = dataset[feature_cols].to_numpy()
    T = dataset[time_col].to_numpy()
    C = dataset[event_col].to_numpy()
    n_features = X.shape[1]

    # split data
    X_train, X_test, T_train, T_test, C_train, C_test = train_test_split(X, T, C, test_size=.2, random_state=40)

    # get save path
    save_path = os.path.join(os.path.dirname(deep_survival_model_path), f'best_model_DA_{name}.pth')

    # init and train model
    torch.manual_seed(40)
    model = DeepSurvival(dataset_name=name,
                         save_path=save_path,
                         clinical_features=feature_cols,
                         importance_weighting=True)
    model.fit(X_train, T_train, C_train)

    # check test c_index
    cindex = model.get_cindex(X_test, T_test, C_test).item()
    c_index_df.loc[(name, 'importance sampling')] = cindex
    print(f'C-index on {name}_synth test set:', cindex, '\n')

In [None]:
# original model
for name, dataset in synthetic_data.items():

    # set up data
    time_col, event_col = time_col_dict[name], event_col_dict[name]
    feature_cols = list(set(dataset.columns) - set([time_col, event_col]))

    X = dataset[feature_cols].to_numpy()
    T = dataset[time_col].to_numpy()
    C = dataset[event_col].to_numpy()
    n_features = X.shape[1]

    # split data
    X_train, X_test, T_train, T_test, C_train, C_test = train_test_split(X, T, C, test_size=.2, random_state=40)

    # get save path
    save_path = os.path.join(os.path.dirname(deep_survival_model_path), f'best_model_{name}.pth')

    # init and train model
    torch.manual_seed(40)
    model = DeepSurvival(dataset_name=name,
                         save_path=save_path,
                         clinical_features=feature_cols,
                         importance_weighting=False)
    model.fit(X_train, T_train, C_train)

    # check test c_index
    cindex = model.get_cindex(X_test, T_test, C_test).item()
    c_index_df.loc[(name, 'vanilla')] = cindex
    print(f'C-index on {name}_synth test set:', cindex, '\n')

In [None]:
# get latex table
latex_tab = c_index_df.to_latex(index=True, 
                            float_format="%.3f",
                            label=f'tab:informative censoring c-index',
                            caption=f'C-index on benchmark dataset test sets',
                            sparsify=True)
print(latex_tab)
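
One common way to build importance weights for this kind of covariate shift is to treat the uncensored patients as the source domain and the full cohort as the target, then weight each uncensored patient by P(uncensored) / P(uncensored | x), with the conditional probability estimated by a classifier. The sketch below does this with a logistic regression on random toy data; `importance_weights` is a hypothetical helper, and it is an assumption that the `importance_weighting=True` option of `DeepSurvival` does something comparable, since its implementation is not shown here.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def importance_weights(X, events, min_prob=1e-3):
    """Density-ratio weights for uncensored samples; censored samples get weight 0."""
    X = np.asarray(X, dtype=float)
    events = np.asarray(events, dtype=int)
    clf = LogisticRegression(max_iter=1000).fit(X, events)
    pos_col = list(clf.classes_).index(1)                 # column holding P(event observed | x)
    p_obs = np.clip(clf.predict_proba(X)[:, pos_col], min_prob, 1.0)
    weights = events.mean() / p_obs                       # P(observed) / P(observed | x)
    return np.where(events == 1, weights, 0.0)

# Toy data (hypothetical): censoring depends on the first feature
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(200, 2))
events_toy = (X_toy[:, 0] + rng.normal(size=200) > 0).astype(int)
w = importance_weights(X_toy, events_toy)
print(w[:5])
```

In a weighted loss, each uncensored patient's term is multiplied by its weight, so patients who resemble the frequently censored part of the cohort count for more. Clipping the probabilities keeps any single weight from dominating.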

## References

[1] Chicco, Davide, and Giuseppe Jurman. “Machine learning can predict survival of patients with heart failure from serum creatinine and ejection fraction alone.” BMC Medical Informatics and Decision Making, vol. 20, no. 1 (2020): 1-16.

[2] Weiss, Eric S., Lois U. Nwakanma, Stuart B. Russell, John V. Conte, and Ashish S. Shah. “Outcomes in bicaval versus biatrial techniques in heart transplantation: an analysis of the UNOS database.” The Journal of Heart and Lung Transplantation, vol. 27, no. 2 (2008): 178-183.