In [None]:
%load_ext autoreload
%autoreload 
import os
os.chdir(os.path.dirname((os.path.dirname(os.getcwd()))))
from functools import partial
import cvxpy as cp
import numpy as np
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sklearn.model_selection import train_test_split
from survLime.datasets.load_datasets import RandomSurvivalData
from survLime import survlime_tabular
from survLime.scripts.experiment_1 import create_clusters
from survLime.utils.generic_utils import compare_survival_times
np.random.seed(42)

In [None]:
# create_clusters is defined in scripts/experiment_1; only the first of the two
# returned clusters is split here, holding out 10% of it as a test set.
cluster_0, cluster_1 = create_clusters()
x_train_1, x_test_1, y_train_1, y_test_1 = train_test_split(
    cluster_0[0], cluster_0[1], test_size=0.1
)
# Second field of every training target record (presumably the observed time —
# confirm against create_clusters).
times = [record[1] for record in y_train_1]
# Distinct values in ascending order, and how many there are.
times_to_fill = sorted(set(times))
m = len(times_to_fill)

In [None]:
# Black-box model: a Cox proportional hazards model fitted on the training split.
columns = [f'feat_{i}' for i in range(x_train_1.shape[1])]
model = CoxPHSurvivalAnalysis(alpha=1e-4).fit(x_train_1, y_train_1)
# Feature names are attached by hand; the survival-time comparison done later
# needs them.
model.feature_names_in_ = columns

# First held-out row, used as the instance to explain.
test_point = x_test_1[0]

In [None]:
# Wrapper for predict function: the model's cumulative hazard function predictor
# with return_array=True pre-bound, so callers only pass the samples to predict on.
predict_chf = partial(model.predict_cumulative_hazard_function, return_array=True)

In [None]:
# Baseline cumulative hazard values of the fitted model as an (m, 1) column;
# this raises if the number of values differs from m.
H0 = np.reshape(model.cum_baseline_hazard_.y, (m, 1))

In [None]:
# Value forwarded to explain_instance as num_samples.
num_neighbours = 1000
explainer = survlime_tabular.LimeTabularExplainer(
    x_train_1,
    y_train_1,
    feature_names=columns,
    H0=H0,
    verbose=True,
    discretize_continuous=False,
)
# Original author's note: "From here we are only using log_correction, Ho_t_ and
# inverse". Ho_t_ and inverse are not names bound in this cell — TODO confirm what
# the note refers to. The cells shown below only read b and opt_value.
(H, weights, log_correction,
 scaled_data, b, opt_value) = explainer.explain_instance(
    test_point,
    predict_chf,
    verbose=True,
    num_samples=num_neighbours,
)

In [None]:
# Show the optimisation value, the transposed b returned by the explainer and the
# fitted model's coefficients, separated by dashed rules.
print(
    opt_value,
    '--------------------',
    b.T,
    '--------------------',
    model.coef_,
    sep='\n',
)

In [None]:
# Coefficients copied from survLime/scripts/experiment_1.create_clusters
coefficients = [10 ** -6, 0.1, -0.15, 10 ** -6, 10 ** -6]
# First entry of every row of b, collected into a plain list.
values = [row[0] for row in b]
print(coefficients, values, model.coef_, sep='\n')
compare_survival_times(
    model,
    values,
    x_train_1,
    y_train_1,
    x_test_1,
    true_coef=coefficients,
)