In [None]:
%load_ext autoreload
%autoreload 2
import os
import numpy as np
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest
from survLime.survlime_explainer import SurvLimeExplainer
from survLime.datasets.load_datasets import RandomSurvivalData
from shap import KernelExplainer
from functools import partial

In [None]:
# Generate data
# Simulated survival data from the project's RandomSurvivalData generator;
# the parameter names suggest points around `center` with a Weibull-type
# baseline hazard (presumably — see survLime.datasets for the exact model).
# `random_seed` fixes the draw so the notebook is reproducible.
rsd = RandomSurvivalData(
    center=[0, 0, 0],
    radius=1,
    coefficients=[1, 1, 1],
    prob_event=0.8,
    lambda_weibull=10**(-5),
    v_weibull=2,
    random_seed=99,
)
X, time_to_event, delta = rsd.random_survival_data(num_points=500)

# sksurv expects a structured target: (event indicator, time).
# NOTE: times are truncated to whole units with int() before being stored
# as float32 — presumably deliberate to coarsen the time grid; confirm.
survival_dtype = [("delta", np.bool_), ("time_to_event", np.float32)]
z = list(zip(delta, map(int, time_to_event)))
y = np.array(z, dtype=survival_dtype)
print(X[:5])


In [None]:
# Fit a Cox model
# Fit a Cox proportional-hazards model on the simulated data and show the
# estimated coefficient for each of the three features.
model = CoxPHSurvivalAnalysis().fit(X, y)
print(model.coef_)

In [None]:
# Fit a random forest
# The forest is stochastic (bootstrap samples, random candidate features per
# split). Without a fixed random_state every Restart & Run All yields a
# different model and therefore different SurvLIME / SHAP explanations in the
# cells below. Reuse the data-generation seed for reproducibility.
rf = RandomSurvivalForest(random_state=99)
rf.fit(X, y)

In [None]:
# Point to explain
# The same instance in two layouts: a 1-D row for SurvLIME's `data_row`
# and a (1, n_features) matrix for the sklearn-style / SHAP APIs.
x_new_list = [0, 0, 0]
x_new, x_new_shap = np.array(x_new_list), np.array([x_new_list])


In [None]:
# SurvLime for COX
# Explain the Cox model at x_new with SurvLIME, using its cumulative hazard
# function (as an array) as the black-box output.
cox_chf_fn = partial(model.predict_cumulative_hazard_function, return_array=True)
explainer = SurvLimeExplainer(
    training_data=X,
    target_data=y,
    # presumably the time grid matching the columns of cox_chf_fn's output — confirm
    model_output_times=model.event_times_,
    sample_around_instance=True,
)
# NOTE(review): no seed is passed to SurvLIME; if its neighbourhood sampling
# is random, `b` will change between runs — check whether it accepts one.
b, result = explainer.explain_instance(data_row=x_new, predict_fn=cox_chf_fn, num_samples=100)
print(b)

In [None]:
# Function predict for Cox
# CoxPHSurvivalAnalysis.predict returns the risk score, i.e. the linear
# predictor x · coef_. Show both values and check that they agree, instead of
# computing `pred` and never using it.
pred = model.predict(x_new_shap)
linear_predictor = np.dot(x_new_shap, model.coef_)
print('model.predict:', pred)
print('x_new @ coef_:', linear_predictor)
np.testing.assert_allclose(pred, linear_predictor)

In [None]:
# SHAP for Cox
# Kernel SHAP on the Cox risk score, with the full training set X as the
# background. The explainer's expected_value is then the mean of model.predict
# over X. By SHAP's efficiency property, the attributions must sum to
# f(x_new) - mean(f(X)), which is asserted at the end.
# A distinct name keeps the SurvLIME `explainer` object from being clobbered.
explainer_shap_cox = KernelExplainer(model.predict, X)
# `shap_values` is the documented public method; `explain` is the per-row
# routine it calls internally.
shap_values = explainer_shap_cox.shap_values(x_new_shap)
pred_cox = model.predict(x_new_shap)
baseline_cox = np.mean(model.predict(X))
print('shap_values:', shap_values)
print('sum(shap):', np.sum(shap_values))
print('prediction:', pred_cox)
print('prediction - mean_predicted_value:', pred_cox - baseline_cox)
np.testing.assert_allclose(np.sum(shap_values), pred_cox - baseline_cox, rtol=1e-6, atol=1e-8)

In [None]:
# SurvLime for rsf
# Explain the random survival forest at x_new with SurvLIME, using the
# forest's cumulative hazard function (as an array) as the black-box output.
rf_chf_fn = partial(rf.predict_cumulative_hazard_function, return_array=True)
explainer_rf = SurvLimeExplainer(
    training_data=X,
    target_data=y,
    # presumably the time grid matching the columns of rf_chf_fn's output — confirm
    model_output_times=rf.event_times_,
    sample_around_instance=True,
)
# NOTE(review): as in the Cox cell, no seed is passed to SurvLIME here.
b_rf, result_rf = explainer_rf.explain_instance(data_row=x_new, predict_fn=rf_chf_fn, num_samples=100)
print(b_rf)

In [None]:
# Predict for RSF
# Risk score of the random survival forest at the point of interest
# (shown via the cell's last expression).
rf_risk_new = rf.predict(x_new_shap)
rf_risk_new

In [None]:
# SHAP for RSF
# Kernel SHAP on the forest's risk score, with the full training set X as the
# background (500 rows; SHAP may warn that this is slow). expected_value is
# then the mean of rf.predict over X, so the attributions must sum to
# f(x_new) - mean(f(X)). That additivity is asserted at the end.
explainer_rsf = KernelExplainer(rf.predict, X)
# `shap_values` is the documented public method; `explain` is the per-row
# routine it calls internally.
shap_values_rsf = explainer_rsf.shap_values(x_new_shap)
pred_rsf = rf.predict(x_new_shap)
baseline_rsf = np.mean(rf.predict(X))
print('shap_values:', shap_values_rsf)
print('sum(shap):', np.sum(shap_values_rsf))
print('prediction:', pred_rsf)
print('prediction - mean_predicted_value:', pred_rsf - baseline_rsf)
np.testing.assert_allclose(np.sum(shap_values_rsf), pred_rsf - baseline_rsf, rtol=1e-6, atol=1e-8)