In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import torch
import torchtuples as tt
from pycox.evaluation import EvalSurv
from pycox.models import DeepHitSingle, CoxPH
from survlimepy import SurvLimeExplainer
from survlimepy.load_datasets import RandomSurvivalData

In [None]:
# Generate data for the first cluster.
# Coefficients for features 0, 3 and 4 are ~1e-6, so only features 1 and 2
# carry meaningful signal -- the explanations below can be checked against this.
true_coef_1 = [10**(-6), 0.1, -0.15, 10**(-6), 10**(-6)]
n_features_1 = len(true_coef_1)
center_1 = [0] * n_features_1  # sampling ball centred at the origin
r_1 = 8
n_points_1 = 1000
# NOTE(review): presumably the probability that a point is an observed event
# rather than censored -- confirm against RandomSurvivalData docs.
prob_event_1 = 0.9
# Distribution parameters for the simulated times (Weibull, judging by the
# parameter names -- TODO confirm).
lambda_weibull_1 = 10**(-5)
v_weibull_1 = 2

rsd_1 = RandomSurvivalData(
    center=center_1,
    radius=r_1,
    coefficients=true_coef_1,
    prob_event=prob_event_1,
    lambda_weibull=lambda_weibull_1,
    v_weibull=v_weibull_1,
    time_cap=2000,  # NOTE(review): magic number; presumably an upper bound on simulated times
    random_seed=90,
)
X_1, time_to_event_1, delta_1 = rsd_1.random_survival_data(
    num_points=n_points_1
)

In [None]:

# Train/test split for the first cluster: `n_train_1` randomly chosen rows for
# training, all remaining rows for testing.
n_train_1 = 900
np.random.seed(90)
all_idx_1 = np.arange(X_1.shape[0])
idx_train_1 = np.random.choice(a=all_idx_1, size=n_train_1, replace=False)
# Complement of the training indices, in ascending order. Same result as
# `[i for i in all_idx_1 if i not in idx_train_1]`, but without a linear scan
# of `idx_train_1` for every candidate index (O(n^2) -> O(n log n)).
idx_test_1 = np.setdiff1d(all_idx_1, idx_train_1, assume_unique=True)
assert len(idx_train_1) + len(idx_test_1) == len(all_idx_1)

X_train_1 = X_1[idx_train_1, :]
X_test_1 = X_1[idx_test_1, :]
time_to_event_train_1 = [time_to_event_1[i] for i in idx_train_1]
time_to_event_test_1 = [time_to_event_1[i] for i in idx_test_1]
delta_train_1 = [delta_1[i] for i in idx_train_1]
delta_test_1 = [delta_1[i] for i in idx_test_1]
# Structured (event flag, time) array for the training split.
z_train_1 = [(d, t) for d, t in zip(delta_train_1, time_to_event_train_1)]
y_train_1 = np.array(z_train_1, dtype=[("delta", np.bool_), ("time_to_event", np.float32)])

In [None]:
# Cast the features to float32 for the pycox DeepSurv (CoxPH) model built
# below: torch layers are created with float32 parameters by default.
X_transformed_train = X_train_1.astype('float32')
X_transformed_test = X_test_1.astype('float32')

In [None]:
def get_target(df):
    """Split a survival frame into the ``(durations, events)`` target tuple.

    The tuple form is what is passed as ``target`` to ``CoxPH.fit`` and
    unpacked into the arrays given to ``EvalSurv``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``'duration'`` column and an ``'event'`` column.

    Returns
    -------
    tuple of numpy.ndarray
        ``(durations, events)`` as numpy arrays, in row order.
    """
    return df['duration'].to_numpy(), df['event'].to_numpy()

In [None]:
# Wrap the training durations and event flags in a frame, then extract them as
# the (durations, events) tuple passed to `CoxPH.fit` as `target`.
y_df_train = pd.DataFrame(
    data={
        'duration': time_to_event_train_1,
        'event': delta_train_1,
    }
)
y_deepsurv_train = get_target(y_df_train)

In [None]:
# Held-out durations and event flags, unpacked for evaluation with `EvalSurv`.
y_df_test = pd.DataFrame(
    data={
        'duration': time_to_event_test_1,
        'event': delta_test_1,
    }
)
durations_test, events_test = get_target(y_df_test)

In [None]:
# DeepSurv network architecture
in_features = X_transformed_train.shape[1]  # number of input covariates
num_nodes = [32] * 2  # hidden-layer sizes passed to MLPVanilla
batch_norm = True
dropout = 0.1
output_bias = False

# Training schedule
# NOTE(review): a fixed 512 epochs with no validation split / early stopping
# may overfit -- consider tt.callbacks.EarlyStopping with a validation set.
batch_size = 256
epochs = 512

In [None]:
# Train a DeepSurv model (pycox CoxPH on an MLP).
# np.random.seed does not seed torch, so seed it explicitly. This makes weight
# initialisation, dropout and mini-batch shuffling reproducible across
# Restart & Run All.
torch.manual_seed(90)
net_deep_surv = tt.practical.MLPVanilla(
    in_features,
    num_nodes,
    1,  # single output unit: the Cox risk score
    batch_norm,
    dropout,
    output_bias=output_bias,
)
deep_surv = CoxPH(net_deep_surv, tt.optim.Adam())
deep_surv.optimizer.set_lr(0.001)
log = deep_surv.fit(
    input=X_transformed_train,
    target=y_deepsurv_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=False,
)

In [None]:
# Estimate the baseline hazards. No input is passed, so (per pycox) the data
# given to `fit` is used. This must run before survival / cumulative-hazard
# predictions.
baseline_hazards = deep_surv.compute_baseline_hazards()
baseline_hazards

In [None]:
# Predicted survival curves for the test split (pycox returns a DataFrame with
# time as the index and one column per individual).
predictions = deep_surv.predict_surv_df(input=X_transformed_test)

In [None]:
# Evaluation helper on the test split; the censoring distribution is
# estimated with Kaplan-Meier ('km').
ev = EvalSurv(
    surv=predictions,
    durations=durations_test,
    events=events_test,
    censor_surv='km',
)

In [None]:
# Time-dependent concordance index on the test split (pycox default method).
c_index_td = ev.concordance_td()
c_index_td

In [None]:
def create_chf(fun):
    """Adapt a prediction function so that its output is returned transposed.

    Used below to wrap pycox's ``predict_cumulative_hazards``. Its output
    presumably has times along the rows and individuals along the columns, so
    the wrapper yields one row per individual (TODO confirm against survlimepy's
    expected ``predict_fn`` layout).

    Parameters
    ----------
    fun : callable
        Function mapping a feature matrix ``X`` to an object exposing ``.T``
        (e.g. a DataFrame or ndarray).

    Returns
    -------
    callable
        ``inner(X)`` returning ``fun(X).T``.
    """
    def inner(X):
        return fun(X).T

    return inner

# Cumulative-hazard predictor handed to SurvLIME as `predict_fn`.
predict_chf = create_chf(fun=deep_surv.predict_cumulative_hazards)

In [None]:
# Feature values of the test row explained in the next cell.
X_test_1[0, :]

In [None]:
# Explain a single test instance with SurvLIME.
# NOTE(review): the explainer's training_* arguments use the *test* split,
# while model_output_times come from the training split -- confirm intended.
# NOTE(review): data_row comes from X_test_1, not the float32-cast
# X_transformed_test used as training_features -- consider X_transformed_test[0].
explainer_deepsurv = SurvLimeExplainer(
    training_features=X_transformed_test,
    training_events=delta_test_1,
    training_times=time_to_event_test_1,
    # np.unique already returns its values in sorted order
    model_output_times=np.unique(time_to_event_train_1),
    random_state=10,
)

b_deepsurv = explainer_deepsurv.explain_instance(
    predict_fn=predict_chf,
    data_row=X_test_1[0],
    verbose=False,
    num_samples=1000,
)

explainer_deepsurv.plot_weights()

In [None]:
# Ground-truth coefficients used to simulate the first cluster, shown for
# comparison with the SurvLIME weights. Display the existing constant rather
# than redefining it, so the two copies cannot silently drift apart.
true_coef_1

In [None]:
# Monte Carlo SurvLIME explanations for the first ten test rows, using a
# freshly constructed explainer with the same settings as the single-row cell.
# NOTE(review): as before, the explainer's training_* arguments use the test split.
explainer_deepsurv = SurvLimeExplainer(
    training_features=X_transformed_test,
    training_events=delta_test_1,
    training_times=time_to_event_test_1,
    # np.unique already returns its values in sorted order
    model_output_times=np.unique(time_to_event_train_1),
    random_state=10,
)

b_deepsurv = explainer_deepsurv.montecarlo_explanation(
    predict_fn=predict_chf,
    data=X_test_1[:10],
    verbose=False,
    num_samples=1000,
)

In [None]:
# Plot the Monte Carlo SurvLIME weights (presumably those stored by the
# preceding `montecarlo_explanation` call -- confirm in survlimepy docs).
explainer_deepsurv.plot_montecarlo_weights()

In [None]:
# Raw result of `montecarlo_explanation` for the ten explained test rows.
# NOTE(review): presumably one weight vector per row -- confirm the shape.
b_deepsurv