# Imports

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
from functools import partial
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest
from survlime.survlime_explainer import SurvLimeExplainer
from survlime.load_datasets import RandomSurvivalData
from xgbse import XGBSEKaplanNeighbors
import pandas as pd
from pycox.models import DeepHitSingle, CoxPH
import torchtuples as tt

# Generate data

In [None]:
# Generate a synthetic survival dataset with RandomSurvivalData; `true_coef`
# are the coefficients handed to the generator.
n_points = 500
true_coef = [1, 1]
r = 1
center = [0, 0]
prob_event = 0.9
lambda_weibull = 10**(-6)
v_weibull = 2
n_features = len(true_coef)

rsd = RandomSurvivalData(
    center=center,
    radius=r,
    coefficients=true_coef,
    prob_event=prob_event,
    lambda_weibull=lambda_weibull,
    v_weibull=v_weibull,
    time_cap=None,
    random_seed=90,
)

# Train
X, time_to_event, delta = rsd.random_survival_data(num_points=n_points)

# Structured target in the scikit-survival layout: event-indicator field
# first, time field second. Later cells index `time_to_event[:, 0]`, i.e. it
# is a column array, so flatten both inputs explicitly instead of relying on
# numpy to coerce 1-element rows into scalar record fields.
y = np.empty(
    np.size(time_to_event),
    dtype=[("delta", np.bool_), ("time_to_event", np.float32)],
)
y["delta"] = np.ravel(delta)
y["time_to_event"] = np.ravel(time_to_event)

total_row_train = X.shape[0]
print('total_row_train:', total_row_train)
# np.unique already returns its values sorted (and flattened).
unique_times = np.unique(time_to_event)


In [None]:
# Point to explain: the instance every SurvLIME cell below explains,
# with both features set to 0.1.
x_new = np.full(2, 0.1)

# Cox

In [None]:
# Fit scikit-survival's Cox proportional-hazards model (fit() returns the
# fitted estimator itself) and show the learned coefficients.
cox = CoxPHSurvivalAnalysis().fit(X, y)
print(cox.coef_)

In [None]:
# SurvLime for COX: explain the Cox model at x_new. The cumulative hazard is
# requested as a numeric array (return_array=True), evaluated at the model's
# event_times_, which is therefore what is passed as model_output_times.
explainer = SurvLimeExplainer(
    training_features=X,
    # (sic) "traininig_events" must match the SurvLimeExplainer keyword.
    traininig_events=[record[0] for record in y],
    training_times=[record[1] for record in y],
    model_output_times=cox.event_times_,
    sample_around_instance=True,
    random_state=10,
)

b = explainer.explain_instance(
    data_row=x_new,
    predict_fn=partial(
        cox.predict_cumulative_hazard_function,
        return_array=True,
    ),
    num_samples=1000,
    verbose=False,
)

print('b:', b)

In [None]:
# SurvLime for COX, variant: same explainer as the previous cell, but the Cox
# model's predict_cumulative_hazard_function is passed as-is, without
# return_array=True.
# NOTE(review): with scikit-survival's default return_array=False this returns
# step-function objects instead of a numeric array -- presumably this cell
# checks that SurvLimeExplainer accepts that form; confirm.
explainer = SurvLimeExplainer(
    training_features=X,
    traininig_events=[row[0] for row in y],  # (sic) keyword spelling
    training_times=[row[1] for row in y],
    model_output_times=cox.event_times_,
    sample_around_instance=True,
    random_state=10,
)

b = explainer.explain_instance(
    data_row=x_new,
    predict_fn=cox.predict_cumulative_hazard_function,
    num_samples=1000,
    verbose=False,
)

print('b:', b)

# Random Survival Forest

In [None]:
# Random survival forest with scikit-survival's default hyperparameters.
# The data generator (random_seed=90) and every explainer (random_state=10)
# in this notebook are seeded; seed the forest as well so its SurvLIME
# explanations are reproducible between runs.
rsf = RandomSurvivalForest(random_state=10).fit(X, y)

In [None]:
# SurvLime for RSF: explain the random survival forest at x_new, using its
# cumulative hazard as a numeric array (return_array=True) on the forest's
# event_times_ grid.
explainer = SurvLimeExplainer(
    training_features=X,
    traininig_events=[record[0] for record in y],  # (sic) keyword spelling
    training_times=[record[1] for record in y],
    model_output_times=rsf.event_times_,
    sample_around_instance=True,
    random_state=10,
)

b = explainer.explain_instance(
    data_row=x_new,
    predict_fn=partial(
        rsf.predict_cumulative_hazard_function,
        return_array=True,
    ),
    num_samples=1000,
    verbose=False,
)

print('b:', b)

In [None]:
# SurvLime for RSF, variant: the forest's predict_cumulative_hazard_function is
# passed directly, without return_array=True.
# NOTE(review): as with the Cox variant, this presumably yields step-function
# objects rather than an array; confirm SurvLimeExplainer handles them.
explainer = SurvLimeExplainer(
    training_features=X,
    traininig_events=[row[0] for row in y],  # (sic) keyword spelling
    training_times=[row[1] for row in y],
    model_output_times=rsf.event_times_,
    sample_around_instance=True,
    random_state=10,
)

b = explainer.explain_instance(
    data_row=x_new,
    predict_fn=rsf.predict_cumulative_hazard_function,
    num_samples=1000,
    verbose=False,
)

print('b:', b)

# xgbse

In [None]:
# Feature matrix as a DataFrame with named columns "A" and "B", used to fit xgbse.
X_df = pd.DataFrame(data=X, columns=list("AB"))

In [None]:
# xgbse Kaplan-Meier-neighbours model configured with n_neighbors=50,
# fitted on the named-column features and the structured target.
xgbse = XGBSEKaplanNeighbors(n_neighbors=50)
xgbse.fit(X_df, y)

In [None]:
# SurvLIME explanation of the xgbse model at x_new.
# xgbse's predict() returns survival probabilities at xgbse.time_bins, not
# cumulative hazards. As in the DeepHit cell (which also explains a
# survival-output predictor), pass type_fn="survival". Otherwise the
# explainer applies its default treatment, the one the Cox/RSF cells rely on
# for their cumulative-hazard predictors, to survival curves.
explainer = SurvLimeExplainer(
    training_features=X_df,
    traininig_events=[tp[0] for tp in y],  # (sic) keyword spelling
    training_times=[tp[1] for tp in y],
    model_output_times=xgbse.time_bins,
    sample_around_instance=True,
    random_state=10,
)

b = explainer.explain_instance(
    data_row=x_new,
    predict_fn=xgbse.predict,
    type_fn="survival",
    num_samples=1000,
    verbose=False,
)

print('b:', b)

# DeepHit

In [None]:
# float32 copy of the features, used as input for the PyTorch-based pycox models below.
X_transformed = X.astype(np.float32)

In [None]:
# Hyperparameters shared by the two pycox networks (DeepHit and DeepSurv).
in_features = X.shape[1]
num_nodes = [32, 32]  # hidden-layer widths passed to tt.practical.MLPVanilla
batch_norm = True
dropout = 0.1
output_bias = False
batch_size = 256
epochs = 512
# NOTE(review): the fit() calls in this notebook pass verbose=False
# explicitly, so this flag does not affect them.
verbose = True


def get_target(df):
    """Return the ``(durations, events)`` arrays pycox expects as a target.

    ``df`` must have ``'duration'`` and ``'event'`` columns.
    """
    return (df['duration'].values, df['event'].values)


# DeepHit works on a discrete time grid: the label transform maps durations
# onto `num_durations` cut points (labtrans.cuts later becomes DeepHit's
# duration_index).
num_durations = 50
labtrans = DeepHitSingle.label_transform(num_durations)
y_transformed = labtrans.fit_transform(time_to_event[:, 0], delta)

In [None]:
# DeepHit (single risk): an MLP with labtrans.out_features outputs, one per
# discrete time cut produced by the label transform.
net_deep_hit = tt.practical.MLPVanilla(
    in_features, num_nodes, labtrans.out_features, batch_norm, dropout,
    output_bias=output_bias,
)
deep_hit = DeepHitSingle(
    net_deep_hit,
    tt.optim.Adam,
    alpha=0.2,
    sigma=0.1,
    duration_index=labtrans.cuts,
)
deep_hit.optimizer.set_lr(0.001)

# Train! on the float32 features and the discretised targets; the training
# log returned by fit() is kept in `log`.
log = deep_hit.fit(
    input=X_transformed,
    target=y_transformed,
    batch_size=batch_size,
    epochs=epochs,
    verbose=False,
)

In [None]:
# SurvLIME explanation of DeepHit at x_new. predict_surv yields survival
# probabilities on deep_hit.duration_index, hence type_fn="survival".
# NOTE(review): the network was trained on float32 inputs while x_new is
# float64 -- confirm SurvLimeExplainer casts the sampled neighbours before
# calling predict_fn.
explainer = SurvLimeExplainer(
    training_features=X_transformed,
    traininig_events=[rec[0] for rec in y],  # (sic) keyword spelling
    training_times=[rec[1] for rec in y],
    model_output_times=deep_hit.duration_index,
    sample_around_instance=True,
    random_state=10,
)

b = explainer.explain_instance(
    data_row=x_new,
    predict_fn=deep_hit.predict_surv,
    type_fn="survival",
    num_samples=1000,
    verbose=False,
)

print('b:', b)

# DeepSurv

In [None]:
# Put durations and event flags into a DataFrame, then extract the
# (durations, events) target tuple used to train DeepSurv.
y_df = pd.DataFrame({'duration': time_to_event[:, 0], 'event': delta})
y_deepsurv = get_target(y_df)

In [None]:
# DeepSurv: the same MLP architecture as DeepHit but with a single output,
# wrapped in pycox's CoxPH model and trained on the (durations, events) target.
net_deep_surv = tt.practical.MLPVanilla(
    in_features, num_nodes, 1, batch_norm, dropout,
    output_bias=output_bias,
)
deep_surv = CoxPH(net_deep_surv, tt.optim.Adam())
deep_surv.optimizer.set_lr(0.001)

log = deep_surv.fit(
    input=X_transformed,
    target=y_deepsurv,
    batch_size=batch_size,
    epochs=epochs,
    verbose=False,
)

In [None]:
# Estimate DeepSurv's baseline hazards. No data is passed, so (per the pycox
# docs) the training data stored by fit() is used. pycox's CoxPH needs these
# before predict_surv, used by the explainer cell below, can produce survival
# curves. As the cell's last expression, its return value is also displayed.
deep_surv.compute_baseline_hazards()

In [None]:
# SurvLIME explanation of DeepSurv at x_new.
# pycox's predict_surv returns survival probabilities, not cumulative hazards.
# Exactly as in the DeepHit cell, which explains the same pycox predict_surv
# API, pass type_fn="survival". Otherwise the explainer's default (the one the
# Cox/RSF cumulative-hazard cells rely on) is applied to survival curves.
explainer = SurvLimeExplainer(
    training_features=X_transformed,
    traininig_events=[tp[0] for tp in y],  # (sic) keyword spelling
    training_times=[tp[1] for tp in y],
    # Unique training durations. Presumably these coincide with the
    # baseline-hazard grid that predict_surv evaluates on -- confirm.
    model_output_times=unique_times,
    sample_around_instance=True,
    random_state=10,
)

b = explainer.explain_instance(
    data_row=x_new,
    predict_fn=deep_surv.predict_surv,
    type_fn="survival",
    num_samples=1000,
    verbose=False,
)

print('b:', b)