In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import torch
import torchtuples as tt
from pycox.models import CoxPH
from sksurv.metrics import concordance_index_censored

from survlimepy import SurvLimeExplainer
from survlimepy.load_datasets import Loader


In [None]:
# Load the UDCA dataset: a feature table X plus per-row `events` and `times`.
loader = Loader(dataset_name='udca')
X, events, times = loader.load_data()
# The network below is fed single-precision features, so convert once here.
X_transformed = X.to_numpy(dtype='float32')

In [None]:
# Split between train and test (80 / 20).
# Rows are shuffled with a fixed seed before splitting, so the partition does not
# depend on the row order of the loaded dataset but is still reproducible.
n = X.shape[0]
n_train = round(0.8 * n)
split_rng = np.random.default_rng(10)
permutation = split_rng.permutation(n)
train_idx, test_idx = permutation[:n_train], permutation[n_train:]

# Plain arrays make the indexing below positional whatever container the loader returned.
events_array = np.asarray(events)
times_array = np.asarray(times)

training_features = X_transformed[train_idx]
training_events = events_array[train_idx]
training_times = times_array[train_idx]
# pycox expects the target as a (durations, events) tuple.
training_target = (training_times, training_events)
# np.unique already returns its values sorted in ascending order.
unique_times = np.unique(training_times)

test_features = X_transformed[test_idx]
# Kept so the model can be evaluated on held-out data (e.g. with concordance_index_censored).
test_events = events_array[test_idx]
test_times = times_array[test_idx]

In [None]:
# DeepSurv network and training hyperparameters.
in_features = X.shape[1]                            # one input unit per feature column
num_nodes = [2] * 2                                 # two hidden layers of two units each
batch_norm, dropout, output_bias = True, 0.1, False
batch_size, epochs = 256, 512

In [None]:
# Fit the DeepSurv model: a pycox CoxPH model whose risk function is an MLP.
# Seed torch first so weight initialisation, dropout masks and mini-batch
# shuffling are reproducible across "Restart & Run All".
torch.manual_seed(10)
net_deep_surv = tt.practical.MLPVanilla(
    in_features,
    num_nodes,
    1,  # a single output unit, consumed by CoxPH as the risk score
    batch_norm,
    dropout,
    output_bias=output_bias,
)
deep_surv = CoxPH(net_deep_surv, tt.optim.Adam())
deep_surv.optimizer.set_lr(0.001)
log = deep_surv.fit(
    input=training_features,
    target=training_target,
    batch_size=batch_size,
    epochs=epochs,
    verbose=False,
)

In [None]:
# Compute the baseline cumulative hazard. With no arguments, pycox uses the data
# passed to `fit`. The result is stored on `deep_surv.baseline_cumulative_hazards_`,
# which the SurvLIME explainer below reads as H0.
deep_surv.compute_baseline_cumulative_hazards()

In [None]:
def predict_chf(pred_fn, dtype="float32"):
    """Wrap a cumulative-hazard predictor so it returns one row per individual.

    pycox's ``predict_cumulative_hazards`` returns a DataFrame indexed by time,
    with one column per individual. The wrapper transposes that result into an
    array of shape (n_individuals, n_times).

    Parameters
    ----------
    pred_fn : callable
        Maps a 2-D feature matrix to a DataFrame (or array) of cumulative
        hazards shaped (n_times, n_individuals).
    dtype : str, numpy dtype or None, default "float32"
        dtype the features are cast to before calling ``pred_fn``. The model in
        this notebook is trained on float32 features, but samples generated by an
        explainer may arrive as float64. Pass ``None`` to forward ``X`` unchanged.

    Returns
    -------
    callable
        ``inner(X) -> np.ndarray`` with shape (n_individuals, n_times).
    """
    def inner(X):
        features = X if dtype is None else np.asarray(X, dtype=dtype)
        pred_values = pred_fn(features)
        # np.asarray accepts both a DataFrame and a plain ndarray.
        return np.asarray(pred_values).T
    return inner

In [None]:
# Sanity-check the fitted model on a single hand-written instance.
# The values are presumably in the column order of X (TODO confirm). The assert
# guards against the dataset's feature count no longer matching this row.
example = np.array([[5.735963, -0.2648227, 1.055282, 0.05477221]], dtype="float32")
assert example.shape[1] == in_features, "example must have one value per feature column"
deep_surv.predict_cumulative_hazards(example)

In [None]:
# Build a SurvLIME explainer around the fitted DeepSurv model, then run
# Monte-Carlo explanations for the rows of `test_features`.
baseline_chf = deep_surv.baseline_cumulative_hazards_.to_numpy()
explainer_kwargs = dict(
    H0=baseline_chf,
    training_features=training_features,
    training_events=training_events,
    training_times=training_times,
    model_output_times=unique_times,
    sample_around_instance=True,
    random_state=10,
)
explainer = SurvLimeExplainer(**explainer_kwargs)

chf_predictor = predict_chf(deep_surv.predict_cumulative_hazards)
b = explainer.montecarlo_explanation(
    data=test_features,
    predict_fn=chf_predictor,
    num_samples=30,
    num_repetitions=1,
)
