In [None]:
%load_ext autoreload
%autoreload 2
from sksurv.linear_model import CoxPHSurvivalAnalysis
from survlimepy import SurvLimeExplainer
from survlimepy.load_datasets import RandomSurvivalData
import numpy as np

In [None]:
# Generate data: simulation parameters for RandomSurvivalData.
# The dimensionality is derived from `true_coef` first, so the sampling
# centre always has one entry per coefficient when `true_coef` is edited.
n_points = 500
true_coef = [0.1, 0.5, -1.7, -0.2]
n_features = len(true_coef)
r = 1  # radius of the ball around `center` passed to RandomSurvivalData
center = [0.0] * n_features  # origin, one coordinate per feature
prob_event = 0.9
# Passed through to RandomSurvivalData as lambda_weibull / v_weibull
# (presumably Weibull scale and shape of the event times — see its docs).
lambda_weibull = 10**(-6)
v_weibull = 2

rsd = RandomSurvivalData(
    center=center,
    radius=r,
    coefficients=true_coef,
    prob_event=prob_event,
    lambda_weibull=lambda_weibull,
    v_weibull=v_weibull,
    time_cap=None,
    random_seed=90,
)

# Train: draw the synthetic sample and pack event indicators and times into
# the structured (event, time) array that scikit-survival estimators take.
X, time_to_event, delta = rsd.random_survival_data(num_points=n_points)

# Fill the structured array field by field. Times are kept in float64:
# downcasting to float32 rounds them, so the times seen by the model would no
# longer match the float64 `unique_times` computed below.
y = np.empty(
    len(time_to_event),
    dtype=[("delta", np.bool_), ("time_to_event", np.float64)],
)
y["delta"] = delta
y["time_to_event"] = time_to_event
assert X.shape[0] == len(y), "features and survival outcomes must have the same number of rows"

total_row_train = X.shape[0]
print('total_row_train:', total_row_train)
# np.unique already returns its values sorted.
unique_times = np.unique(time_to_event)

In [None]:
# Point to explain: the centre of the sampling ball.
# Cast explicitly to float: `np.array` on an all-integer list would give an
# int64 row, whereas the explainer samples real-valued neighbours around it
# and the Cox model is fitted on real-valued features.
x_new = np.asarray(center, dtype=float)

In [None]:
# Fit a Cox proportional-hazards model to the simulated data; the printed
# coefficients can be compared against `true_coef`.
# `fit` returns the fitted estimator itself, so construction and fitting chain.
cox = CoxPHSurvivalAnalysis().fit(X, y)
print(cox.coef_)

In [None]:
# SurvLime for COX: build the explainer from the training sample, then
# explain `x_new` through the Cox model's cumulative hazard predictions.
# `list(y[field])` yields the same per-row numpy scalars as iterating `y`.
explainer = SurvLimeExplainer(
    training_features=X,
    training_events=list(y["delta"]),
    training_times=list(y["time_to_event"]),
    model_output_times=cox.event_times_,
    sample_around_instance=True,
    random_state=10,
)

b = explainer.explain_instance(
    data_row=x_new,
    predict_fn=cox.predict_cumulative_hazard_function,
    num_samples=1000,
    verbose=False,
)

print('b:', b)

In [None]:
# Plot the SurvLIME weights — presumably those computed by the
# explain_instance call in the previous cell (confirm against survlimepy docs).
explainer.plot_weights()

In [None]:
# Monte Carlo SurvLIME: explain the first two training rows, repeating the
# explanation `num_repetitions` times with `num_samples` neighbours each.
mc_b = explainer.montecarlo_explanation(
    data=X[:2],
    predict_fn=cox.predict_cumulative_hazard_function,
    num_samples=100,
    num_repetitions=10,
)

In [None]:

# Plot the weights from the Monte Carlo explanation above — presumably the
# montecarlo_explanation results (confirm against survlimepy docs).
explainer.plot_montecarlo_weights()

In [None]:
# Reference feature orderings from the true coefficients, for comparison with
# the explanations above. The signed ranking alone would put the strongest
# effect (-1.7) last, so the ranking by absolute magnitude is shown as well.
true_coef_arr = np.asarray(true_coef)
print('by value (desc):    ', np.argsort(true_coef_arr)[::-1])
print('by magnitude (desc):', np.argsort(np.abs(true_coef_arr))[::-1])

In [None]:
# Show the raw Monte Carlo explanation returned by montecarlo_explanation.
print(mc_b)