# Preliminary

In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules
# before each cell executes, so edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
# Put the relative directory 'dcs' on the module search path (it is resolved
# against the kernel's current working directory at import time).
# NOTE(review): the traceback recorded below this cell shows that `dcs` is
# itself a package directory, so this also exposes its subpackages as
# top-level modules — confirm this is intended and unrelated to the
# circular-import error.
sys.path.append('dcs')
# %aimport imports `dcs` and registers it with the autoreload extension.
%aimport dcs 

import pandas as pd
import numpy as np

ImportError: cannot import name 'DcsModel' from partially initialized module 'dcs.models' (most likely due to a circular import) (/home/pfuhlert/src/discrete-calibrated-survival/dcs/models/__init__.py)

In [None]:
# Run configuration: the dataset to load, the fraction of rows held out for
# testing, and the seed handed to the train/test split.
dataset_name = "support"
test_size = 0.2
random_seed = 40

# Dataset

In [None]:
# Fetch the dataset and the preprocessing pipeline looked up by the same
# dataset name, then preview the first rows of the raw data.
dataset = dcs.datasets.get_dataset(dataset_name)
pipeline = dcs.pipelines.get_pipeline(dataset_name)
display(dataset.head())

# Split into train/test features and targets using the configured test
# fraction and seed.
train_X, train_y, test_X, test_y = dcs.preprocessing.train_test_split_X_y(
    dataset, test_size=test_size, random_state=random_seed
)

# `fit_transform` is called on the training features; the test features only
# go through `transform`.
train_X_t = pipeline.fit_transform(train_X)
test_X_t = pipeline.transform(test_X)

display(train_X_t.head())

# Models

In [None]:
# Test-set predictions of every fitted model, keyed by a display name; the
# evaluation cells iterate over this mapping.
predictions = dict()

## CoxPH

In [None]:
# Fit the CoxPH model with its default settings on the transformed training
# features and store its test-set predictions under 'CoxPH'.
cox = dcs.models.CoxPH()
cox.fit(train_X_t, train_y)
predictions["CoxPH"] = cox.predict(test_X_t)

## DeepSurv

In [None]:
# Fit the DeepSurv model with its default settings on the transformed
# training features and store its test-set predictions under 'DeepSurv'.
deepsurv = dcs.models.DeepSurv()
deepsurv.fit(train_X_t, train_y)
predictions["DeepSurv"] = deepsurv.predict(test_X_t)

## CoxTime

In [None]:
# Fit the CoxTime model with its default settings on the transformed
# training features and store its test-set predictions under 'CoxTime'.
coxtime = dcs.models.CoxTime()
coxtime.fit(train_X_t, train_y)
predictions["CoxTime"] = coxtime.predict(test_X_t)

## DRSA

In [None]:
# Number of output grid nodes: the largest `event_days` value in the training
# targets, scaled by 12/365 (days -> months, judging by the names) and
# rounded up to a whole number.
train_max_months = int(np.ceil(train_y["event_days"].max() * 12 / 365))
# Epoch budget for the neural models; the Kamran and DCS cells further down
# reuse both `nns_epochs` and `train_max_months`, so this cell must run first.
nns_epochs = 100

drsa = dcs.models.Drsa(
    epochs=nns_epochs,
    batch_size=50,
    learning_rate=1e-4,
    validation_size=0.1,
    use_early_stopping=True,
    early_stopping_patience=10,
    output_grid_num_nodes=train_max_months,
)
drsa.fit(train_X_t, train_y)
predictions["DRSA"] = drsa.predict(test_X_t)

## Kamran

In [None]:
# Kamran model. `nns_epochs` and `train_max_months` are defined in the DRSA
# cell above, which therefore has to be executed before this one.
kamran = dcs.models.Kamran(
    epochs=nns_epochs,
    batch_size=50,
    learning_rate=1e-4,
    validation_size=0.1,
    use_early_stopping=True,
    early_stopping_patience=10,
    sigma=0.7,
    lambda_=1,
    output_grid_num_nodes=train_max_months,
)

# Train, show the recorded training history, then store test-set predictions.
kamran.fit(train_X_t, train_y)
kamran.plot_history()

predictions["Kamran"] = kamran.predict(test_X_t)

## DCS

In [None]:
# DCS model. It takes the same hyperparameters as the Kamran cell plus
# output_grid_type='quantile'. `nns_epochs` and `train_max_months` are
# defined in the DRSA cell above, which has to be executed first.
dcs_model = dcs.models.DcsModel(
    epochs=nns_epochs,
    batch_size=50,
    learning_rate=1e-4,
    validation_size=0.1,
    use_early_stopping=True,
    early_stopping_patience=10,
    sigma=0.7,
    lambda_=1,
    output_grid_type="quantile",
    output_grid_num_nodes=train_max_months,
)

# Train, show the recorded training history, then store test-set predictions.
dcs_model.fit(train_X_t, train_y)
dcs_model.plot_history()

predictions["DCS-Model"] = dcs_model.predict(test_X_t)

# Evaluation

## Qualitative Survival Curves

In [None]:
import matplotlib.pyplot as plt

# Draw the same few test subjects for every model so the panels are directly
# comparable. The draw is seeded with `random_seed` (from the configuration
# cell) so the figure is identical after "Restart & Run All"; the sample
# size is capped so a test set with fewer than 5 rows does not raise.
n_curves = min(5, len(test_X))
sample_idx = test_X.sample(n_curves, random_state=random_seed).index

# Three panels per row. At least one row is kept so that an empty
# `predictions` mapping yields a blank figure instead of a subplots() error.
ncols = 3
nrows = max(1, int(np.ceil(len(predictions) / ncols)))

fig, axs = plt.subplots(
    figsize=(16, 4 * nrows), dpi=100,
    ncols=ncols, nrows=nrows,
    sharex=True, sharey=True)
# Flatten the (possibly 2-D) axes grid so models can be assigned in order.
axs = axs.reshape(-1)

for ax, (model_name, prediction) in zip(axs, predictions.items()):
    # assumes each prediction frame is indexed like test_X with one column
    # per time point, so the transpose plots one curve per subject — TODO confirm
    prediction.loc[sample_idx].T.plot(ax=ax)
    ax.set_title(model_name)


## Quantitative

In [None]:
# Collect one record of metrics per model and build the frame once.
# `DataFrame.append` was deprecated in pandas 1.4 and removed in pandas 2.0,
# and growing a frame row by row re-copies it on every iteration.
results_by_model = {}
for model_name, prediction in predictions.items():
    results_by_model[model_name] = {
        "c-index-td": dcs.evaluation.concordance_index_td(test_y, prediction),
        "cdauc": dcs.evaluation.cdauc(test_y, prediction),
        "ddc": dcs.evaluation.ddc(test_y, prediction),
    }

# Rows are models, columns are metrics — the layout the appended frame had.
results_df = pd.DataFrame.from_dict(results_by_model, orient="index")

# Transposed for display: metrics as rows, one column per model.
display(results_df.T)