# SurvSHAP(t): Time-Dependent Explanations Of Machine Learning Survival Models
### M. Krzyziński, M. Spytek, H. Baniecki, P. Biecek
## Experiment 1: Evaluating explanations on synthetic data
### DATASET: `EXP1_exponential`

#### Imports

In [1]:
import pandas as pd
import numpy as np 
import pickle
import itertools
from copy import deepcopy
from tqdm import tqdm
# Seed NumPy's global random generator so the notebook is reproducible.
np.random.seed(seed=123)

#### Preparing data and models 

In [3]:
import os
from pathlib import Path

# Directory holding the experiment CSV files. The default is the original
# absolute path; set SURVSHAP_DATA_DIR to run the notebook on another machine
# (all other paths in this notebook are already relative).
DATA_DIR = Path(os.environ.get("SURVSHAP_DATA_DIR", "/home/jkrajewski/survshap/data"))
data = pd.read_csv(DATA_DIR / "exp1_data_exponential.csv")

In [4]:
# Sanity check: (number of rows, number of columns) of the loaded dataset.
data.shape

(1000, 8)

In [None]:
from sksurv.util import Surv
# Survival target: structured array built from the "event" and "time" columns.
y = Surv.from_dataframe(event="event", time="time", data=data)
# Model inputs: the first five columns of the dataset.
X = data.iloc[:, 0:5]

In [None]:
from sksurv.linear_model import CoxPHSurvivalAnalysis
# Cox proportional-hazards model with default settings; fit() returns the estimator.
cph = CoxPHSurvivalAnalysis().fit(X, y)
# Score on the same data the model was trained on (in-sample).
cph.score(X, y)

In [None]:
from sksurv.ensemble import RandomSurvivalForest
# Random survival forest with fixed hyperparameters and seed.
rsf = RandomSurvivalForest(
    n_estimators=100,
    min_samples_split=8,
    min_samples_leaf=4,
    max_features=3,
    max_samples=0.8,
    random_state=42,
).fit(X, y)
# Score on the training data (in-sample).
rsf.score(X, y)

#### Evaluating performance of models 

In [None]:
from sksurv.metrics import brier_score, integrated_brier_score
# Keep only observations whose time lies strictly between the smallest and the
# largest observed event time, so the Brier score can be evaluated on them.
event_times = y["time"][y["event"] == 1]
mask = (y["time"] > event_times.min()) & (y["time"] < event_times.max())
# Evaluation grid: 101 percentiles (0.1% ... 99.9%) of the retained times.
times = np.percentile(y[mask]["time"], np.linspace(0.1, 99.9, 101))

In [None]:
# Subset used for evaluation (observations selected by `mask`).
X_eval = X[mask]
y_eval = y[mask]

# Predicted survival functions for the evaluation subset.
survs_rsf = rsf.predict_survival_function(X_eval)
survs_cph = cph.predict_survival_function(X_eval)

# Evaluate every survival function on the common time grid.
preds_rsf = [surv_fn(times) for surv_fn in survs_rsf]
preds_cph = [surv_fn(times) for surv_fn in survs_cph]

# Training target first, evaluation target second; element [1] holds the scores.
brier_rsf = brier_score(y, y_eval, preds_rsf, times)
brier_cph = brier_score(y, y_eval, preds_cph, times)

In [None]:
# One long-format table with the time-dependent Brier score of both models.
brier_frames = [
    pd.DataFrame({"time": times, "brier_score": brier[1], "label": model_label})
    for brier, model_label in ((brier_rsf, "RSF"), (brier_cph, "CPH"))
]
pd.concat(brier_frames).to_csv("results/exp1_exponential_model_brier_score.csv")

In [None]:
# Integrated Brier score over the same time grid, RSF first, then CPH.
ibs_rsf = integrated_brier_score(y, y[mask], preds_rsf, times)
ibs_cph = integrated_brier_score(y, y[mask], preds_cph, times)
ibs_rsf, ibs_cph

#### Creating explanations

In [None]:
from survshap import SurvivalModelExplainer, PredictSurvSHAP, ModelSurvSHAP
# Wrap each fitted model together with the data (X, y) it was trained on.
rsf_exp, cph_exp = (SurvivalModelExplainer(model, X, y) for model in (rsf, cph))

In [None]:
# Model-level SurvSHAP(t) explanation of the RSF. The fitted object is used
# later for its full_result, timestamps and individual_explanations.
# random_state=42 is presumably for reproducibility of any sampling inside
# ModelSurvSHAP — confirm against the survshap implementation.
exp1_survshap_global_rsf = ModelSurvSHAP(random_state=42)
exp1_survshap_global_rsf.fit(rsf_exp)

In [None]:
# Persist the fitted RSF explanation so it can be reloaded without refitting.
with open("pickles/exp1_exponential_survshap_global_rsf", "wb") as fh:
    pickle.dump(exp1_survshap_global_rsf, fh)

In [None]:
# Model-level SurvSHAP(t) explanation of the CPH model, configured exactly like
# the RSF one. random_state=42 presumably seeds sampling inside ModelSurvSHAP
# — confirm against the survshap implementation.
exp1_survshap_global_cph = ModelSurvSHAP(random_state=42)
exp1_survshap_global_cph.fit(cph_exp)

In [None]:
# Persist the fitted CPH explanation so it can be reloaded without refitting.
with open("pickles/exp1_exponential_survshap_global_cph", "wb") as fh:
    pickle.dump(exp1_survshap_global_cph, fh)

In [None]:
def _load_pickle(path):
    """Unpickle and return the object stored at *path*.

    Only use this on pickles produced by this notebook: unpickling untrusted
    files can execute arbitrary code.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


exp1_survshap_global_rsf = _load_pickle("pickles/exp1_exponential_survshap_global_rsf")
exp1_survshap_global_cph = _load_pickle("pickles/exp1_exponential_survshap_global_cph")

##### Changing Signs Proportion

In [None]:
# Columns from position 6 onward of full_result are the SHAP values, one
# column per timestamp. np.sign turns each value into -1, 0 or +1.
timestamps_rsf = exp1_survshap_global_rsf.timestamps
shap_signs_rsf = np.sign(exp1_survshap_global_rsf.full_result.iloc[:, 6:].to_numpy())

timestamps_cph = exp1_survshap_global_cph.timestamps
shap_signs_cph = np.sign(exp1_survshap_global_cph.full_result.iloc[:, 6:].to_numpy())

In [None]:
# Restrict the analysis to timestamps between the 10th and 90th percentile;
# keep the first and last index inside that window.
lower_t = np.percentile(timestamps_rsf, 10)
upper_t = np.percentile(timestamps_rsf, 90)
window_idx = np.flatnonzero((timestamps_rsf >= lower_t) & (timestamps_rsf <= upper_t))
start_index, end_index = window_idx[0], window_idx[-1]

In [None]:
# For every row of shap_signs_rsf (one row of full_result), split the window
# [start_index, end_index] (inclusive) into runs of constant SHAP sign. For
# each run, record its duration in time units multiplied by the run's sign
# (-1 / +1).
# Zero signs never start a new run; they extend the current one. A run ends at
# the timestamp just before the new non-zero sign appears. If the window
# starts with zeros, the first recorded entry is 0, which later counts as
# neither positive nor negative.
sign_ranges = []
for row in shap_signs_rsf:
    sign_ranges_row = []
    last_sign = row[start_index]  # sign of the current run (0 if the window starts at a zero)
    start_time_sign_sequence = timestamps_rsf[start_index]  # time at which the current run began
    for i in range(start_index, end_index+1):
        if row[i] != last_sign and row[i] != 0:
            # Non-zero sign change: close the current run at the previous timestamp.
            sign_ranges_row.append(last_sign*(timestamps_rsf[i-1] - start_time_sign_sequence))
            start_time_sign_sequence = timestamps_rsf[i-1]
        if row[i] != 0:
            last_sign = row[i] 
    # After the loop i == end_index, so the final run is closed at the window's end.
    sign_ranges_row.append(last_sign*(timestamps_rsf[i] - start_time_sign_sequence))
    sign_ranges.append(sign_ranges_row)

In [None]:
# Per row: total duration spent with a negative / positive SHAP sign
# (negative totals stay negative; zero entries count towards neither).
negative_range = []
positive_range = []
for runs in sign_ranges:
    negative_range.append(sum(r for r in runs if r < 0))
    positive_range.append(sum(r for r in runs if r > 0))
# Length of the analysed window; used as the reference for the thresholds.
timestamps_range = timestamps_rsf[end_index] - timestamps_rsf[start_index]

In [None]:
# A row "changed sign" at threshold f if it spent at least f * window length
# with a negative sign AND at least f * window length with a positive sign.
neg_abs = np.abs(np.array(negative_range))
pos_arr = np.array(positive_range)
changed_signs_005, changed_signs_01, changed_signs_02 = (
    (neg_abs >= frac * timestamps_range) & (pos_arr >= frac * timestamps_range)
    for frac in (0.05, 0.1, 0.2)
)
csp_rsf = pd.DataFrame({
    "variable_name": exp1_survshap_global_rsf.full_result.variable_name,
    "variable_value": exp1_survshap_global_rsf.full_result.variable_value,
    "index": exp1_survshap_global_rsf.full_result.index,
    "changed_signs_0.05": changed_signs_005,
    "changed_signs_0.1": changed_signs_01,
    "changed_signs_0.2": changed_signs_02,
})

In [None]:
# NOTE(review): the window bounds are the 10th/90th percentiles of
# timestamps_rsf (not timestamps_cph), presumably so both models are assessed
# on the same time window — confirm this is intended.
lower_t = np.percentile(timestamps_rsf, 10)
upper_t = np.percentile(timestamps_rsf, 90)
window_idx = np.flatnonzero((timestamps_cph >= lower_t) & (timestamps_cph <= upper_t))
start_index, end_index = window_idx[0], window_idx[-1]

In [None]:
# Signed run lengths of SHAP signs for the CPH explanations, computed like the
# RSF version. For each row of shap_signs_cph, the window
# [start_index, end_index] (inclusive) is split into runs of constant non-zero
# sign; zeros extend the current run. Each run contributes sign * duration,
# and a run ends at the timestamp just before the new sign appears.
# The loop runs up to and including end_index, matching the RSF cell. An
# exclusive bound drops the last time point of the window. It also leaves `i`
# unbound, or stale from an earlier cell, when start_index == end_index.
sign_ranges = []
for row in shap_signs_cph:
    sign_ranges_row = []
    last_sign = row[start_index]  # sign of the current run (0 if the window starts at a zero)
    start_time_sign_sequence = timestamps_cph[start_index]  # time at which the current run began
    for i in range(start_index, end_index + 1):
        if row[i] != last_sign and row[i] != 0:
            # Non-zero sign change: close the current run at the previous timestamp.
            sign_ranges_row.append(last_sign * (timestamps_cph[i - 1] - start_time_sign_sequence))
            start_time_sign_sequence = timestamps_cph[i - 1]
        if row[i] != 0:
            last_sign = row[i]
    # Close the final run at the last timestamp of the window.
    sign_ranges_row.append(last_sign * (timestamps_cph[end_index] - start_time_sign_sequence))
    sign_ranges.append(sign_ranges_row)

In [None]:
# Per row: total duration spent with a negative / positive SHAP sign
# (negative totals stay negative; zero entries count towards neither).
negative_range = [sum(r for r in runs if r < 0) for runs in sign_ranges]
positive_range = [sum(r for r in runs if r > 0) for runs in sign_ranges]
# Reference length for the changed-sign thresholds is the span of the analysed
# window, as for RSF. The full span of timestamps_cph would make the
# 5% / 10% / 20% thresholds relative to a longer interval than the one the run
# lengths are measured on. The CPH proportions would then not be comparable
# with the RSF ones.
timestamps_range = timestamps_cph[end_index] - timestamps_cph[start_index]

In [None]:
# A row "changed sign" at threshold f if it spent at least f * reference length
# with a negative sign AND at least f * reference length with a positive sign.
neg_abs = np.abs(np.array(negative_range))
pos_arr = np.array(positive_range)
changed_signs_005, changed_signs_01, changed_signs_02 = (
    (neg_abs >= frac * timestamps_range) & (pos_arr >= frac * timestamps_range)
    for frac in (0.05, 0.1, 0.2)
)
csp_cph = pd.DataFrame({
    "variable_name": exp1_survshap_global_cph.full_result.variable_name,
    "variable_value": exp1_survshap_global_cph.full_result.variable_value,
    "index": exp1_survshap_global_cph.full_result.index,
    "changed_signs_0.05": changed_signs_005,
    "changed_signs_0.1": changed_signs_01,
    "changed_signs_0.2": changed_signs_02,
})

In [None]:
# Per-variable means for RSF. For the boolean changed_signs_* columns, the mean
# is the fraction of rows whose SHAP sign changed at that threshold.
csp_rsf.groupby(by="variable_name").mean()

In [None]:
# Per-variable means for CPH. For the boolean changed_signs_* columns, the mean
# is the fraction of rows whose SHAP sign changed at that threshold.
csp_cph.groupby(by="variable_name").mean()

##### Local accuracy

In [None]:
def get_local_accuracy_from_shap_explanations(all_explanation, label, last_index=None):
    """Measure, per timestamp, how well SHAP values add up to the prediction.

    For every individual explanation the residual
    ``predicted_function - baseline_function - (sum of SHAP values)`` is taken
    at each timestamp. ``sigma`` is the root-mean-square residual divided by
    the root-mean-square prediction, both averaged over explanations.

    Parameters
    ----------
    all_explanation : object
        Must expose ``timestamps`` and ``individual_explanations``. Each
        individual explanation must expose ``predicted_function``,
        ``baseline_function`` and ``result``, a DataFrame whose columns from
        position 6 onward are summed (row-wise, per column) as the SHAP values.
    label : str
        Value written into the ``label`` column (e.g. the model name).
    last_index : int, optional
        Only the first ``last_index`` timestamps are used; default is all.

    Returns
    -------
    pandas.DataFrame
        Columns ``time``, ``sigma`` and ``label``.
    """
    n_times = len(all_explanation.timestamps) if last_index is None else last_index
    window = slice(None, n_times)

    residuals = []
    predictions = []
    for single in all_explanation.individual_explanations:
        predicted = single.predicted_function[window]
        shap_total = np.array(single.result.iloc[:, 6:].sum(axis=0))[window]
        predictions.append(predicted)
        residuals.append(predicted - single.baseline_function[window] - shap_total)

    rms_residual = np.sqrt(np.mean(np.array(residuals) ** 2, axis=0))
    rms_prediction = np.sqrt(np.mean(np.array(predictions) ** 2, axis=0))
    return pd.DataFrame({
        "time": all_explanation.timestamps[window],
        "sigma": rms_residual / rms_prediction,
        "label": label,
    })

In [None]:
# Local-accuracy curve (sigma over time) for the RSF explanation.
local_accuracy_rsf = get_local_accuracy_from_shap_explanations(exp1_survshap_global_rsf, label="RSF")

In [None]:
# Local-accuracy curve (sigma over time) for the CPH explanation.
local_accuracy_cph = get_local_accuracy_from_shap_explanations(exp1_survshap_global_cph, label="CPH")

In [None]:
# Stack both models' local-accuracy curves into one table and write it to results/.
local_accuracy_all = pd.concat([local_accuracy_rsf, local_accuracy_cph])
local_accuracy_all.to_csv("results/exp1_exponential_local_accuracy_exp.csv")