In [None]:
%load_ext autoreload
%autoreload 2
import os
import numpy as np
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest
from survLime.survlime_explainer import SurvLimeExplainer
from survLime.datasets.load_datasets import RandomSurvivalData
from shap import KernelExplainer
from functools import partial
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt

In [None]:
# Generate data
#
# Two clusters of synthetic survival data are drawn with RandomSurvivalData.
# With the values below the clusters share coefficients, radius, prob_event and
# lambda_weibull; they differ only in `center` and `v_weibull`.
n_points_1 = 500
n_points_2 = 500
n_points_test_per_cluster = 50
true_coef_1 = [1, 1]
true_coef_2 = [1, 1]
r_1 = 1
r_2 = 1
center_1 = [0, 0]
center_2 = [0.2, 0.2]
prob_event_1 = 0.9
prob_event_2 = 0.9
lambda_weibull_1 = 10**(-6)
lambda_weibull_2 = 10**(-6)
v_weibull_1 = 2
v_weibull_2 = 2.2
n_features = len(true_coef_1)

# Every generator below is seeded, so the row shuffles are seeded as well;
# otherwise a "Restart & Run All" would yield a different row order each time.
shuffle_seed = 0
shuffle_rng = np.random.default_rng(shuffle_seed)

# Structured dtype shared by every target array built in this cell.
surv_dtype = [("delta", np.bool_), ("time_to_event", np.float32)]

# Per-cluster generator settings; train and test generators of the same cluster
# differ only in their random seed.
cluster_1_params = dict(
    center=center_1,
    radius=r_1,
    coefficients=true_coef_1,
    prob_event=prob_event_1,
    lambda_weibull=lambda_weibull_1,
    v_weibull=v_weibull_1,
    time_cap=None,
)
cluster_2_params = dict(
    center=center_2,
    radius=r_2,
    coefficients=true_coef_2,
    prob_event=prob_event_2,
    lambda_weibull=lambda_weibull_2,
    v_weibull=v_weibull_2,
    time_cap=None,
)

rsd_1 = RandomSurvivalData(**cluster_1_params, random_seed=90)
rsd_2 = RandomSurvivalData(**cluster_2_params, random_seed=85)
rsdtest_1 = RandomSurvivalData(**cluster_1_params, random_seed=11)
rsdtest_2 = RandomSurvivalData(**cluster_2_params, random_seed=13)

# Train
X_1, time_to_event_1, delta_1 = rsd_1.random_survival_data(num_points=n_points_1)
X_2, time_to_event_2, delta_2 = rsd_2.random_survival_data(num_points=n_points_2)
# Pair each delta with its time so the rows can be packed into a structured array.
z_1 = list(zip(delta_1, time_to_event_1))
z_2 = list(zip(delta_2, time_to_event_2))
y_1 = np.array(z_1, dtype=surv_dtype)
y_2 = np.array(z_2, dtype=surv_dtype)
X_train = np.concatenate((X_1, X_2), axis=0)
y_train = np.concatenate((y_1, y_2))
total_row_train = X_train.shape[0]
print('total_row_train:', total_row_train)
# One permutation is applied to features and targets so rows stay aligned.
idx_train = shuffle_rng.permutation(total_row_train)
X_train = X_train[idx_train, :]
y_train = y_train[idx_train]

# Test
X_test_1, time_to_event_test_1, delta_test_1 = rsdtest_1.random_survival_data(num_points=n_points_test_per_cluster)
z_test_1 = list(zip(delta_test_1, time_to_event_test_1))
y_test_1 = np.array(z_test_1, dtype=surv_dtype)
X_test_2, time_to_event_test_2, delta_test_2 = rsdtest_2.random_survival_data(num_points=n_points_test_per_cluster)
z_test_2 = list(zip(delta_test_2, time_to_event_test_2))
y_test_2 = np.array(z_test_2, dtype=surv_dtype)
X_test = np.concatenate((X_test_1, X_test_2), axis=0)
y_test = np.concatenate((y_test_1, y_test_2))
total_row_test = X_test.shape[0]
idx_test = shuffle_rng.permutation(total_row_test)
X_test = X_test[idx_test, :]
y_test = y_test[idx_test]


print("X_train.shape:", X_train.shape)
print("X_test.shape:", X_test.shape)
print("y_train.shape:", y_train.shape)
print("y_test.shape:", y_test.shape)

In [None]:
# Point to explain
# A separate, seeded generator configured with the cluster-1 settings yields
# the single observation that the explainers below are asked about.
rsd_new = RandomSurvivalData(
    random_seed=55,
    time_cap=None,
    v_weibull=v_weibull_1,
    lambda_weibull=lambda_weibull_1,
    prob_event=prob_event_1,
    coefficients=true_coef_1,
    radius=r_1,
    center=center_1,
)

# Only the features are kept; the other two returned values are discarded.
new_point, _, _ = rsd_new.random_survival_data(num_points=1)
x_new_list = new_point[0].tolist()
x_new = np.array(x_new_list)
# Same point wrapped in an extra leading axis (a batch of one) for SHAP.
x_new_shap = np.stack([x_new])
print(x_new)

In [None]:
# Training features of each cluster before concatenation. Labels and a legend
# are needed because the two series are otherwise told apart by colour alone.
plt.scatter(X_1[:, 0], X_1[:, 1], label="cluster 1")
plt.scatter(X_2[:, 0], X_2[:, 1], label="cluster 2")
plt.xlabel("feature 0")
plt.ylabel("feature 1")
plt.title("Training features by cluster")
plt.legend();

In [None]:
# Shuffled training features: first column on x, second column on y.
plt.scatter(x=X_train[:, 0], y=X_train[:, 1])

In [None]:
# Shuffled test features: first column on x, second column on y.
plt.scatter(x=X_test[:, 0], y=X_test[:, 1])

In [None]:
# Training time-to-event distribution per cluster. The bars are drawn
# semi-transparent and labelled; with opaque bars the second histogram hides
# the first wherever they overlap.
plt.hist(time_to_event_1, alpha=0.5, label="cluster 1")
plt.hist(time_to_event_2, alpha=0.5, label="cluster 2")
plt.xlabel("time to event")
plt.ylabel("count")
plt.title("Training time to event by cluster")
plt.legend();

In [None]:
# Fit a Cox model
# This fitted estimator is the model that the SurvLime and SHAP cells explain.
model = CoxPHSurvivalAnalysis().fit(X_train, y_train)
print(model.coef_)

In [None]:
# Reference values for the error metrics: every feature of x_new multiplied by
# the matching entry of true_coef_1.
x_prod_b = []
for feature_value, true_beta in zip(x_new, true_coef_1):
    x_prod_b.append(feature_value * true_beta)
print(x_prod_b)

In [None]:
# SurvLime for COX
# NOTE(review): the explainer is given the *test* split as its training data;
# presumably deliberate (the SHAP cell uses X_test as well) - confirm.
explainer = SurvLimeExplainer(
    training_data=X_test,
    target_data=y_test,
    model_output_times=model.event_times_,
    sample_around_instance=True,
)

# The prediction function handed to SurvLime: the Cox model's cumulative hazard
# with return_array=True bound in advance.
cumulative_hazard_fn = partial(model.predict_cumulative_hazard_function, return_array=True)

b, result = explainer.explain_instance(
    data_row=x_new,
    predict_fn=cumulative_hazard_fn,
    num_samples=1000,
)

print('b:', b)
# Scale each returned coefficient by the feature value of the explained point
# so it is comparable with x_prod_b.
importance = [feature_value * weight for feature_value, weight in zip(x_new, b)]
print('importance:', importance)
# Mean squared difference between the SurvLime values and the reference.
residuals = [estimate - reference for estimate, reference in zip(importance, x_prod_b)]
error_survlime = 1/n_features * np.sum(np.square(residuals))
print('error_survlime:', error_survlime)

In [None]:
# SHAP for Cox
# Kernel SHAP on the Cox model's `predict` output, with X_test as the
# background data set.
explainer = KernelExplainer(model.predict, X_test)
# Use the documented `shap_values` entry point rather than the per-instance
# `explain` routine. x_new_shap holds a single row, so row 0 of the result is
# taken to get one value per feature; the zip below relies on that 1-D shape.
shap_values = explainer.shap_values(x_new_shap)[0]
print('shap_values:', shap_values)
# Mean squared difference between the SHAP values and the reference x_prod_b.
error_shap = 1/n_features * np.sum(np.square([a1 - a2 for a1, a2 in zip(shap_values, x_prod_b)]))
print('error_shap:', error_shap)