# SurvSHAP(t): Time-Dependent Explanations Of Machine Learning Survival Models
### M. Krzyziński, M. Spytek, H. Baniecki, P. Biecek
## Experiment 2: Comparison to SurvLIME
### Generating SurvLIME explanations for dataset1

#### Imports

In [None]:
import pandas as pd
import numpy as np 
import pickle
from tqdm import tqdm
# Seed NumPy's global random generator so that the np.random.random draws made
# later in this notebook are reproducible from run to run.
np.random.seed(123)

#### Preparing data - split saved for further use

In [None]:
# Load the complete dataset1 table from disk.
dataset1 = pd.read_csv("data/exp2_dataset1.csv")

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the rows as the test part; the fixed random_state makes the
# split repeatable.
dataset1_train, dataset1_test = train_test_split(
    dataset1,
    test_size=0.1,
    random_state=123,
)

# Write both parts to CSV without the index column.
dataset1_train.to_csv("data/exp2_dataset1_train.csv", index=False)
dataset1_test.to_csv("data/exp2_dataset1_test.csv", index=False)

In [None]:
# Read the train/test parts back from the CSV files, so the rest of the
# notebook works on the stored split.
dataset1_train = pd.read_csv("data/exp2_dataset1_train.csv")
dataset1_test = pd.read_csv("data/exp2_dataset1_test.csv")

#### Creating models

In [None]:
from sksurv.util import Surv

# The first five columns of each part are used as model inputs.
X_train, X_test = dataset1_train.iloc[:, :5], dataset1_test.iloc[:, :5]

# Survival targets built from the "event" and "time" columns of each part.
y_train = Surv.from_dataframe(event="event", time="time", data=dataset1_train)
y_test = Surv.from_dataframe(event="event", time="time", data=dataset1_test)

In [None]:
from sksurv.linear_model import CoxPHSurvivalAnalysis
# Cox proportional hazards model created with the library's default settings.
cph = CoxPHSurvivalAnalysis()
# Kept as the cell's final expression, so the notebook shows the value returned by fit.
cph.fit(X_train, y_train)

In [None]:
from sksurv.ensemble import RandomSurvivalForest

# Random survival forest with explicitly fixed hyper-parameters and seed.
rsf = RandomSurvivalForest(
    n_estimators=150,
    max_depth=12,
    max_features=3,
    min_samples_leaf=6,
    min_samples_split=10,
    random_state=123,
)
# Kept as the cell's final expression, so the notebook shows the value returned by fit.
rsf.fit(X_train, y_train)

#### Evaluating performance of models 

In [None]:
from sksurv.metrics import brier_score, integrated_brier_score

# Select the test observations whose time lies strictly between the smallest
# and the largest event time observed in the training set; the Brier score is
# computed on this subset only.
mask = np.logical_and(
    y_test["time"] < y_train[y_train["event"] == 1]["time"].max(),
    y_test["time"] > y_train[y_train["event"] == 1]["time"].min(),
)

# Evaluation grid: 101 percentiles (from the 0.1th to the 99.9th) of the
# selected test times.
times = np.percentile(
    y_test[mask]["time"],
    np.linspace(0.1, 99.9, 101),
)

In [None]:
# Random survival forest: predicted survival functions for the masked test
# rows, evaluated on the `times` grid, then scored with the Brier score.
survs_rsf = rsf.predict_survival_function(X_test[mask])
preds_rsf = [surv_fn(times) for surv_fn in survs_rsf]
brier_rsf = brier_score(y_train, y_test[mask], preds_rsf, times)

# Cox model: the same three steps.
survs_cph = cph.predict_survival_function(X_test[mask])
preds_cph = [surv_fn(times) for surv_fn in survs_cph]
brier_cph = brier_score(y_train, y_test[mask], preds_cph, times)

In [None]:
# Long-format table: one row per (time, model) with that model's Brier score.
brier_frames = [
    pd.DataFrame({"time": times, "brier_score": brier_rsf[1], "label": "RSF"}),
    pd.DataFrame({"time": times, "brier_score": brier_cph[1], "label": "CPH"}),
]
# No index argument is passed, so the (per-model) row index is written as the
# first column of the file.
pd.concat(brier_frames).to_csv("results/exp2_dataset1_model_brier_score.csv")

In [None]:
# Integrated Brier score on the masked test rows over `times`, displayed as an (RSF, CPH) pair.
integrated_brier_score(y_train, y_test[mask], preds_rsf, times), integrated_brier_score(y_train, y_test[mask], preds_cph, times)

#### Calculating permutational variable importance

In [None]:
# Time grid read by loss_integrated_brier_score: 90 evenly spaced percentiles
# (10th to 90th) of all test-set times.
# NOTE: this rebinds the global `times` that the Brier-score cells above also use.
times = np.percentile(y_test["time"], np.linspace(10, 90, 90))
def loss_integrated_brier_score(model, data, y):
    """Return one minus the integrated Brier score of ``model`` on ``data``.

    The ``(model, data, y)`` signature is the form in which the function is
    passed as ``scoring`` to ``permutation_importance``. Despite "loss" in the
    name, a higher return value corresponds to a lower integrated Brier score.

    The function reads two module-level names: the evaluation grid ``times``
    and the training targets ``y_train``.

    Parameters
    ----------
    model
        Fitted estimator exposing ``predict_survival_function``.
    data
        Feature rows to predict for.
    y
        Survival targets matching ``data``.

    Returns
    -------
    ``1 - integrated_brier_score(y_train, y, predictions, times)``.
    """
    # Evaluate every predicted survival function on the shared time grid.
    curves = [curve(times) for curve in model.predict_survival_function(data)]
    return 1 - integrated_brier_score(y_train, y, curves, times)

In [None]:
from sklearn.inspection import permutation_importance

# Only test observations with a time below 1000 take part in the estimate.
short_time = y_test["time"] < 1000

perm_imp = permutation_importance(
    rsf,
    X_test[short_time],
    y_test[short_time],
    scoring=loss_integrated_brier_score,
    n_repeats=100,
    random_state=42,
)
# Kept as the cell's final expression so the notebook displays the per-feature
# mean importances.
perm_imp["importances_mean"]

#### Creating explanations

In [None]:
from survshap import SurvivalModelExplainer
# Wrap each fitted model together with the training data. These explainer
# objects are what SurvLIME is fitted against further down, and their `.data`
# attribute supplies the feature column names.
cph_exp = SurvivalModelExplainer(cph, X_train, y_train)
rsf_exp = SurvivalModelExplainer(rsf, X_train, y_train)

In [None]:
# Kernel and neighbourhood sampling follow the SurvLIME paper
def kernel(distance):
    """Convert a distance into a weight.

    The weight is 1 at distance 0 and decreases to 0 at distance 0.5; for
    distances above 0.5 the result is negative. Works element-wise on NumPy
    arrays as well as on scalars.
    """
    scaled = distance / 0.5
    return 1 - np.sqrt(scaled)
def generate_neighbours(ind, n_neighbours=1000):
    """Sample a neighbourhood around one test observation.

    Candidate offsets are drawn uniformly from the cube [-0.5, 0.5)^d and kept
    only when their Euclidean norm is at most 0.5 (rejection sampling), so the
    accepted offsets lie in a ball of radius 0.5. The offsets are then shifted
    by the observation, and the first row is replaced by the observation
    itself.

    The number of features ``d`` is taken from the module-level ``X_test``
    rather than being hard-coded.

    Parameters
    ----------
    ind : int
        Positional row index of the observation in ``X_test``.
    n_neighbours : int, default 1000
        Number of rows in the result, including the observation in row 0.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_neighbours, d)``.

    Raises
    ------
    ValueError
        If ``n_neighbours`` is smaller than 1.
    """
    if n_neighbours < 1:
        raise ValueError("n_neighbours must be at least 1")

    radius = 0.5  # same constant as the 0.5 used in `kernel`
    centre = X_test.iloc[[ind]].values  # shape (1, d)
    n_features = centre.shape[1]

    neighbours = np.zeros((n_neighbours, n_features))
    accepted = 0
    while accepted < n_neighbours:
        # One candidate per draw: this keeps the sequence of random numbers
        # consumed from the global generator identical to the original loop.
        offset = np.random.random(n_features) - 0.5
        if np.sqrt(np.sum(offset**2)) <= radius:
            neighbours[accepted] = offset
            accepted += 1

    neighbours += centre
    # Row 0 is the explained observation itself.
    neighbours[0] = centre[0]
    return neighbours

In [None]:
from survlime import SurvLIME

# Fit one SurvLIME explanation of the Cox model for every test observation.
n_obs = len(X_test)
sls = [None] * n_obs
# Not filled in this cell; the name is kept in case other cells rely on it.
funcs_dist = [None] * n_obs
# Column names are the same for every observation, so look them up once.
feature_names = cph_exp.data.columns
# tqdm cannot infer a length from an enumerate object, so pass `total`
# explicitly to get a real progress bar.
for i, obs in tqdm(enumerate(X_test.values), total=n_obs):
    # Single-row frame holding the observation being explained.
    xx = pd.DataFrame(np.atleast_2d(obs), columns=feature_names)
    survlime = SurvLIME(N=1000)
    survlime.fit(cph_exp, xx, kernel=kernel, neighbourhood=generate_neighbours(i), k=1)
    sls[i] = survlime

In [None]:
# Store the list of fitted SurvLIME objects for the Cox model.
with open("pickles/exp2_survlime_dataset1_cph", "wb") as out_file:
    pickle.dump(sls, out_file)

In [None]:
# Fit one SurvLIME explanation of the random survival forest for every test
# observation.
n_obs = len(X_test)
sls = [None] * n_obs
# Not filled in this cell; the name is kept in case other cells rely on it.
funcs_dist = [None] * n_obs
# Column names are the same for every observation, so look them up once.
feature_names = rsf_exp.data.columns
# tqdm cannot infer a length from an enumerate object, so pass `total`
# explicitly to get a real progress bar.
for i, obs in tqdm(enumerate(X_test.values), total=n_obs):
    # Single-row frame holding the observation being explained.
    xx = pd.DataFrame(np.atleast_2d(obs), columns=feature_names)
    survlime = SurvLIME(N=1000)
    survlime.fit(rsf_exp, xx, kernel=kernel, neighbourhood=generate_neighbours(i), k=1 + 1e-4)
    sls[i] = survlime

In [None]:
# Store the list of fitted SurvLIME objects for the random survival forest.
with open("pickles/exp2_survlime_dataset1_rsf", "wb") as out_file:
    pickle.dump(sls, out_file)