# SurvSHAP(t): Time-Dependent Explanations Of Machine Learning Survival Models
### M. Krzyziński, M. Spytek, H. Baniecki, P. Biecek
## Experiment 2: Comparison to SurvLIME
### Generating SurvSHAP(t) explanations for dataset1

#### Imports

In [None]:
import pandas as pd
import numpy as np 
import pickle
from tqdm import tqdm
# Seed NumPy's global random number generator with a fixed value so repeated
# runs of this notebook start from the same RNG state.
# NOTE(review): whether survshap/sksurv draw from this global generator is not
# visible here; the forest below sets its own random_state explicitly.
np.random.seed(123)

#### Reading data

In [None]:
# Load the train and test splits of dataset1 from CSV, in that order.
dataset1_train, dataset1_test = map(
    pd.read_csv,
    ("data/exp2_dataset1_train.csv", "data/exp2_dataset1_test.csv"),
)

In [None]:
from sksurv.util import Surv

# Features: the first five columns of each split.
# Outcome: built by Surv.from_dataframe from the "event" and "time" columns.
splits = (dataset1_train, dataset1_test)
X_train, X_test = (split.iloc[:, :5] for split in splits)
y_train, y_test = (
    Surv.from_dataframe(event="event", time="time", data=split)
    for split in splits
)

#### Creating models

In [None]:
from sksurv.linear_model import CoxPHSurvivalAnalysis

# Cox proportional hazards model, constructed with the library defaults and
# fitted on the training features and survival outcome.
cph = CoxPHSurvivalAnalysis()
cph.fit(X=X_train, y=y_train)

In [None]:
from sksurv.ensemble import RandomSurvivalForest

# Hyperparameters of the random survival forest, kept in one place so they
# are easy to read; random_state pins the forest's own randomness.
rsf_params = {
    "n_estimators": 150,
    "max_depth": 12,
    "max_features": 3,
    "min_samples_leaf": 6,
    "min_samples_split": 10,
    "random_state": 123,
}
rsf = RandomSurvivalForest(**rsf_params)
rsf.fit(X=X_train, y=y_train)

#### Creating explanations

In [None]:
from survshap import SurvivalModelExplainer, PredictSurvSHAP, ModelSurvSHAP

# Wrap each fitted model together with the training data in an explainer
# object; these explainers are what PredictSurvSHAP.fit receives later on.
cph_exp, rsf_exp = (
    SurvivalModelExplainer(fitted_model, X_train, y_train)
    for fitted_model in (cph, rsf)
)

In [None]:
# Fit one PredictSurvSHAP explanation per test observation using the Cox
# model explainer, storing them in test-set order.
n_obs = len(X_test)
survshaps = [None] * n_obs
# enumerate() hides the length of the underlying array from tqdm, so the
# total is passed explicitly; without it the progress bar shows only a raw
# counter with no percentage or remaining-time estimate.
for i, obs in tqdm(enumerate(X_test.values), total=n_obs):
    # Re-wrap the raw row as a one-row DataFrame carrying the explainer's
    # column names, which is the form passed to PredictSurvSHAP.fit.
    xx = pd.DataFrame(np.atleast_2d(obs), columns=cph_exp.data.columns)
    survshap = PredictSurvSHAP()
    survshap.fit(cph_exp, xx)
    survshaps[i] = survshap

In [None]:
import os

# Make sure the output directory exists before writing, so the explanations
# already held in `survshaps` are not lost to a FileNotFoundError.
os.makedirs("pickles", exist_ok=True)
# Serialize the current `survshaps` list to the CPH output file.
with open("pickles/exp2_survshap_dataset1_cph", "wb") as file:
    pickle.dump(survshaps, file)

In [None]:
# Fit one PredictSurvSHAP explanation per test observation using the random
# survival forest explainer, storing them in test-set order.
n_obs = len(X_test)
survshaps = [None] * n_obs
# enumerate() hides the length of the underlying array from tqdm, so the
# total is passed explicitly; without it the progress bar shows only a raw
# counter with no percentage or remaining-time estimate.
for i, obs in tqdm(enumerate(X_test.values), total=n_obs):
    # Re-wrap the raw row as a one-row DataFrame carrying the explainer's
    # column names, which is the form passed to PredictSurvSHAP.fit.
    xx = pd.DataFrame(np.atleast_2d(obs), columns=rsf_exp.data.columns)
    survshap = PredictSurvSHAP()
    survshap.fit(rsf_exp, xx)
    survshaps[i] = survshap

In [None]:
import os

# Make sure the output directory exists before writing, so the explanations
# already held in `survshaps` are not lost to a FileNotFoundError.
os.makedirs("pickles", exist_ok=True)
# Serialize the current `survshaps` list to the RSF output file.
with open("pickles/exp2_survshap_dataset1_rsf", "wb") as file:
    pickle.dump(survshaps, file)