In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest
from survlimepy import SurvLimeExplainer
from survlimepy.load_datasets import RandomSurvivalData
from xgbse import XGBSEKaplanNeighbors
import pandas as pd
from pycox.models import DeepHitSingle, CoxPH
import torchtuples as tt

In [None]:
# Generate data
# Simulation settings passed to RandomSurvivalData below.
n_points = 500
true_coef = [1, 2]          # coefficients used to simulate the data
r = 1                       # passed as `radius`
center = [0, 0]
prob_event = 0.9
lambda_weibull = 10 ** -6   # Weibull parameters forwarded to the simulator
v_weibull = 2
n_features = len(true_coef)

rsd = RandomSurvivalData(
    center=center, radius=r, coefficients=true_coef,
    prob_event=prob_event,
    lambda_weibull=lambda_weibull, v_weibull=v_weibull,
    time_cap=None, random_seed=90,
)

# Train
X, time_to_event, delta = rsd.random_survival_data(num_points=n_points)

# sksurv expects a structured array with an event-indicator field and a time field.
survival_dtype = [("delta", np.bool_), ("time_to_event", np.float32)]
z = list(zip(delta, time_to_event))
y = np.array(z, dtype=survival_dtype)

total_row_train = len(X)
print('total_row_train:', total_row_train)

# np.unique already returns its result sorted in ascending order.
unique_times = np.unique(time_to_event)

In [None]:
# Fit a Cox model on the simulated data; its coefficients can be compared
# against `true_coef` used in the simulation.
cox = CoxPHSurvivalAnalysis().fit(X, y)
print(cox.coef_)

In [None]:
# Instance to explain: the origin of feature space. Its length follows
# `n_features` so it stays consistent with the model if `true_coef` changes size.
x_new = [0] * n_features

In [None]:
# SurvLime for COX
# Event indicators and times are read from the structured array `y` by field name.
explainer_cox = SurvLimeExplainer(
    training_features=X,
    training_events=list(y["delta"]),
    training_times=list(y["time_to_event"]),
    model_output_times=cox.event_times_,
    sample_around_instance=True,
    random_state=10,
)

# Explain `x_new` using the Cox model's cumulative hazard predictions.
b_cox = explainer_cox.explain_instance(
    data_row=x_new,
    predict_fn=cox.predict_cumulative_hazard_function,
    num_samples=100,
    verbose=False,
)

print(b_cox)

In [None]:
# Wrap the single instance as one row, shape (1, n_features), before predicting.
prediction = cox.predict_cumulative_hazard_function(np.atleast_2d(x_new))

In [None]:
# Number of predicted functions returned (a single row was passed in).
prediction.shape[0]