# SurvSHAP(t): Time-Dependent Explanations Of Machine Learning Survival Models
### M. Krzyziński, M. Spytek, H. Baniecki, P. Biecek
## Experiment 1: Evaluating explanations on synthetic data

#### Imports

In [None]:
import pandas as pd
import numpy as np 
import pickle
import itertools
from copy import deepcopy
from tqdm import tqdm
# Seed NumPy's global random number generator so every np.random draw in this notebook is reproducible.
np.random.seed(seed=123)

#### Preparing data and models 

In [None]:
# Load the synthetic dataset for Experiment 1.
data = pd.read_csv("data/exp1_data.csv")

In [None]:
from sksurv.util import Surv
# Features are the first five columns; the survival target is built from the "event"
# indicator and the "time" columns.
X = data.iloc[:, :5]
y = Surv.from_dataframe("event", "time", data)

In [None]:
from sksurv.linear_model import CoxPHSurvivalAnalysis
# Cox proportional hazards model fitted on the whole dataset; the final expression
# displays its score on that same (training) data.
cph = CoxPHSurvivalAnalysis()
cph.fit(X, y)
cph.score(X, y)

In [None]:
from sksurv.ensemble import RandomSurvivalForest
# Random survival forest (random_state fixed for reproducibility), also fitted and
# scored on the whole dataset.
rsf = RandomSurvivalForest(random_state=42, n_estimators=100, min_samples_split=8, min_samples_leaf=4, max_features=3, max_samples=0.8)
rsf.fit(X, y)
rsf.score(X, y)

#### Evaluating performance of models 

In [None]:
from sksurv.metrics import brier_score, integrated_brier_score
# Keep only observations whose time lies strictly between the earliest and the latest
# observed event time, so that the Brier score can be evaluated on them.
mask = (y["time"] > y[y["event"] == 1]["time"].min()) & (y["time"] < y[y["event"] == 1]["time"].max())
# Evaluation grid: 101 percentiles (from the 0.1th to the 99.9th) of the retained times.
times = np.percentile(y[mask]["time"], q=np.linspace(0.1, 99.9, 101))

In [None]:
# Predicted survival functions of both models for the masked observations, evaluated on
# the `times` grid, and the resulting time-dependent Brier scores (full y passed first,
# the masked subset as the evaluated set).
survs_rsf = rsf.predict_survival_function(X[mask])
survs_cph = cph.predict_survival_function(X[mask])
preds_rsf = [fn(times) for fn in survs_rsf]
preds_cph = [fn(times) for fn in survs_cph]
brier_rsf = brier_score(y, y[mask], preds_rsf, times)
brier_cph = brier_score(y, y[mask], preds_cph, times)

In [None]:
# Save both Brier score curves in long format; element [1] of each brier_score result
# holds the scores matching `times`.
pd.concat([pd.DataFrame({"time": times, "brier_score":  brier_rsf[1], "label": "RSF"}),
            pd.DataFrame({"time": times, "brier_score":  brier_cph[1], "label": "CPH"})]).to_csv("results/exp1_model_brier_score.csv")

In [None]:
# Display the integrated Brier score of RSF and CPH (in that order).
integrated_brier_score(y, y[mask], preds_rsf, times), integrated_brier_score(y, y[mask], preds_cph, times)

#### Creating explanations

In [None]:
from survshap import SurvivalModelExplainer, PredictSurvSHAP, ModelSurvSHAP
# Wrap each fitted model together with the data (X, y) it was fitted on.
rsf_exp = SurvivalModelExplainer(rsf, X, y)
cph_exp = SurvivalModelExplainer(cph, X, y)

In [None]:
# SurvSHAP(t) explanations for the RSF (random_state fixed for reproducibility); the
# per-observation results are later read via individual_explanations / full_result.
exp1_survshap_global_rsf = ModelSurvSHAP(random_state=42)
exp1_survshap_global_rsf.fit(rsf_exp)

In [None]:
# NOTE: pickle is already imported at the top of the notebook; this re-import is redundant.
import pickle
# Cache the fitted explanation so the analysis section can reload it without refitting.
with open("pickles/exp1_survshap_global_rsf", "wb") as file:
    pickle.dump(exp1_survshap_global_rsf, file)

In [None]:
# Same SurvSHAP(t) explanations for the CPH model.
exp1_survshap_global_cph = ModelSurvSHAP(random_state=42)
exp1_survshap_global_cph.fit(cph_exp)

In [None]:
# Cache the CPH explanation as well.
with open("pickles/exp1_survshap_global_cph", "wb") as file:
    pickle.dump(exp1_survshap_global_cph, file)

#### Calculating ground-truth SurvSHAP

In [None]:
import math
from sklearn.metrics import r2_score
def shap_kernel(
    explainer, new_observation, function_type, timestamps, baseline_f, simplified_inputs, kernel_weights, n
):
    """Compute Kernel SHAP values of ``explainer.model`` for one observation.

    A fresh background sample of ``n`` rows is simulated with ``generate_data`` and the
    weighted least-squares problem over every coalition in ``simplified_inputs`` is solved.

    Returns
    -------
    tuple
        (DataFrame of SHAP values with one "t = <time>" column per entry of ``timestamps``,
        per-time-point R^2 of the weighted linear fit).
    """
    background = generate_data(n)
    values, fit_r2 = calculate_shap_values(
        model=explainer.model,
        function_type=function_type,
        avg_function=baseline_f,
        data=background,
        simplified_inputs=simplified_inputs,
        shap_kernel_weights=kernel_weights,
        new_observation=new_observation,
        timestamps=timestamps,
    )
    column_names = ["t = " + str(time) for time in timestamps]
    return pd.DataFrame(values, columns=column_names), fit_r2


def generate_shap_kernel_weights(simplified_inputs, num_variables):
    """Return the Kernel SHAP weight of every coalition vector in ``simplified_inputs``.

    A coalition with k of the ``num_variables`` features present gets weight
    (M - 1) / (C(M, k) * k * (M - k)). The empty and the full coalition would have infinite
    weight, so they get a large constant (1e9) instead.
    """

    def _weight(coalition):
        size = np.count_nonzero(coalition)
        if size in (0, num_variables):
            return 1e9
        return (num_variables - 1) / (
            math.comb(num_variables, size) * size * (num_variables - size)
        )

    return [_weight(coalition) for coalition in simplified_inputs]


def make_prediction_for_simplified_input(
    model, function_type, data, simplified_inputs, new_observation, timestamps
):
    """Mean model prediction for every coalition in ``simplified_inputs``.

    For each coalition mask, the features where the mask is truthy take their value from
    ``new_observation`` (broadcast over all rows). The remaining features keep their
    values from ``data``. The predictions are then averaged over the rows.

    Returns
    -------
    numpy.ndarray
        Shape (len(simplified_inputs), len(timestamps)).
    """
    out = np.zeros((len(simplified_inputs), len(timestamps)))
    for row_idx, coalition in enumerate(simplified_inputs):
        hybrid = np.where(coalition, new_observation, data)
        out[row_idx, :] = calculate_mean_function(
            model, function_type, pd.DataFrame(hybrid, columns=data.columns), timestamps
        )
    return out

def calculate_mean_function(model, function_type, data, timestamps):
    """Average the model's predicted functions over the rows of ``data``.

    Parameters
    ----------
    model
        Fitted model exposing ``predict_survival_function`` and/or
        ``predict_cumulative_hazard_function``, each returning one callable per row.
    function_type : {"sf", "chf"}
        "sf" averages survival functions, "chf" cumulative hazard functions.
    data
        Observations passed to the model's predict method.
    timestamps
        Time points at which every predicted function is evaluated.

    Returns
    -------
    numpy.ndarray
        Element-wise mean of the predicted functions evaluated at ``timestamps``.

    Raises
    ------
    ValueError
        If ``function_type`` is neither "sf" nor "chf".
    """
    if function_type == "sf":
        all_functions = model.predict_survival_function(data)
    elif function_type == "chf":
        all_functions = model.predict_cumulative_hazard_function(data)
    else:
        # Fail with a clear message instead of an UnboundLocalError on all_functions.
        raise ValueError(f'function_type must be "sf" or "chf", got {function_type!r}')
    return np.mean([f(timestamps) for f in all_functions], axis=0)


def calculate_shap_values(
    model,
    function_type,
    avg_function,
    data,
    simplified_inputs,
    shap_kernel_weights,
    new_observation,
    timestamps,
):
    """Solve the Kernel SHAP weighted least-squares problem at every time point.

    The targets are the mean predictions for each coalition minus ``avg_function``. The
    design matrix is the 0/1 coalition matrix ``simplified_inputs``, and the observation
    weights are ``shap_kernel_weights``.

    Returns
    -------
    shap_values : numpy.ndarray
        One row per variable (coalition position), one column per timestamp.
    r2 : list of float
        Weighted R^2 of the linear fit, one value per timestamp.
    """
    weights = np.asarray(shap_kernel_weights, dtype=float)
    coalitions = np.array(simplified_inputs)
    targets = (
        make_prediction_for_simplified_input(
            model, function_type, data, simplified_inputs, new_observation, timestamps
        )
        - avg_function
    )
    # Z^T W computed by broadcasting instead of materializing the dense diagonal matrix W.
    weighted_coalitions_t = coalitions.T * weights
    # Solve the normal equations (Z^T W Z) beta = Z^T W y directly. This is numerically
    # more reliable than forming an explicit inverse when the weights span ~1e9 (empty/full
    # coalition) down to small values.
    shap_values = np.linalg.solve(
        weighted_coalitions_t @ coalitions, weighted_coalitions_t @ targets
    )
    fitted = coalitions @ shap_values
    r2 = [
        r2_score(targets[:, t], fitted[:, t], sample_weight=shap_kernel_weights)
        for t in range(targets.shape[1])
    ]
    return shap_values, r2

def generate_data(n):
    """Simulate ``n`` rows of the five synthetic features x1..x5 from NumPy's global RNG.

    x1, x2 ~ Bernoulli(0.5); x3 ~ N(10, 2); x4 ~ N(20, 4); x5 ~ N(0, 1), where the second
    normal parameter is the standard deviation passed to ``np.random.normal``. Draws happen
    in column order, so a fixed seed gives a fixed sample.
    """
    draws = (
        ("x1", np.random.binomial, (1, 0.5)),
        ("x2", np.random.binomial, (1, 0.5)),
        ("x3", np.random.normal, (10, 2)),
        ("x4", np.random.normal, (20, 4)),
        ("x5", np.random.normal, (0, 1)),
    )
    return pd.DataFrame({name: sampler(*params, n) for name, sampler, params in draws})


In [None]:
# Ground-truth setup for the RSF. The time grid is the x-grid of the first predicted step
# function (assumes all RSF survival functions share it). The baseline is the mean
# predicted survival function over X.
all_functions_rsf = rsf.predict_survival_function(X)
all_functions_vals = [f.y for f in all_functions_rsf]
timestamps = all_functions_rsf[0].x
baseline_f = np.mean(all_functions_vals, axis=0)
# All 2^5 binary coalition vectors over the five features and their Kernel SHAP weights.
simplified_inputs = [list(z) for z in itertools.product(range(2), repeat=5)]
kernel_weights = generate_shap_kernel_weights(simplified_inputs, 5)

In [None]:
from sklearn.model_selection import train_test_split
# Hold out 10% of the observations (stratified by the event indicator); ground-truth SHAP
# values are computed only for this test subset.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42, stratify=y["event"])

In [None]:
# Ground-truth SHAP values of the RSF for every test observation. Each shap_kernel call
# simulates 10000 background rows and evaluates every coalition in simplified_inputs.
to_calculate = list(X_test.index)
gt_frames = []
for i in tqdm(to_calculate):
    # `i` is an index label taken from X_test, so select the row by label (.loc), not by position.
    shap_gt = shap_kernel(
        rsf_exp, X.loc[[i]], "sf", timestamps, baseline_f, simplified_inputs, kernel_weights, 10000
    )
    shap_gt[0].insert(0, "index", i)
    gt_frames.append(shap_gt[0])
# Concatenate once instead of re-copying the growing table on every iteration.
shap_groundtruth = pd.concat(gt_frames) if gt_frames else pd.DataFrame()

In [None]:
# Persist the RSF ground truth. The default row index (variable position within each
# observation) is written as an unnamed first column.
shap_groundtruth.to_csv("results/exp1_shap_groundtruth_rsf.csv")

In [None]:
# NOTE: itertools is already imported at the top of the notebook; this re-import is redundant.
import itertools
# Ground-truth setup for the CPH: time grid from the first predicted step function and
# baseline = mean predicted survival function over X.
all_functions_cph = cph.predict_survival_function(X)
all_functions_vals = [f.y for f in all_functions_cph]
timestamps = all_functions_cph[0].x
baseline_f = np.mean(all_functions_vals, axis=0)
# Same coalitions and Kernel SHAP weights as for the RSF.
simplified_inputs = [list(z) for z in itertools.product(range(2), repeat=5)]
kernel_weights = generate_shap_kernel_weights(simplified_inputs, 5)

In [None]:
# Ground-truth SHAP values of the CPH model for every test observation. Each shap_kernel
# call simulates 10000 background rows and evaluates every coalition in simplified_inputs.
to_calculate = list(X_test.index)
gt_frames_cph = []
for i in tqdm(to_calculate):
    # `i` is an index label taken from X_test, so select the row by label (.loc), not by position.
    shap_gt = shap_kernel(
        cph_exp, X.loc[[i]], "sf", timestamps, baseline_f, simplified_inputs, kernel_weights, 10000
    )
    shap_gt[0].insert(0, "index", i)
    gt_frames_cph.append(shap_gt[0])
# Concatenate once instead of re-copying the growing table on every iteration.
shap_groundtruth_cph = pd.concat(gt_frames_cph) if gt_frames_cph else pd.DataFrame()

In [None]:
# Persist the CPH ground truth. The default row index (variable position within each
# observation) is written as an unnamed first column.
shap_groundtruth_cph.to_csv("results/exp1_shap_groundtruth_cph.csv")

#### Results analysis

In [None]:
# The ground-truth CSVs are written with their default row index as an unnamed first
# column; read_csv names it "Unnamed: 0". That index is the variable position within each
# observation. The "index" column holds the observation index. Rename both to the names the
# analysis below sorts and filters by. Renaming is a no-op for columns that are absent.
shap_grountruth_rsf = pd.read_csv("results/exp1_shap_groundtruth_rsf.csv").rename(
    columns={"Unnamed: 0": "variable_index", "index": "observation_index"}
)
shap_grountruth_cph = pd.read_csv("results/exp1_shap_groundtruth_cph.csv").rename(
    columns={"Unnamed: 0": "variable_index", "index": "observation_index"}
)

In [None]:
# Reload the cached SurvSHAP(t) explanations.
# NOTE: pickle.load can execute arbitrary code; only load pickles produced by this notebook.
with open("pickles/exp1_survshap_global_rsf", "rb") as file:
    exp1_survshap_global_rsf = pickle.load(file)
with open("pickles/exp1_survshap_global_cph", "rb") as file:
    exp1_survshap_global_cph = pickle.load(file)

##### Plot examples

In [None]:
# SurvSHAP(t) values plot example

# Explanations of the observation at position 690 of individual_explanations. They are
# deep-copied so that the in-place normalization in the next cell leaves the fitted objects intact.
example_rsf = deepcopy(exp1_survshap_global_rsf.individual_explanations[690])
example_cph = deepcopy(exp1_survshap_global_cph.individual_explanations[690])

# Melt the time columns (position 6 onwards) into long format. Their names are presumably
# "t = <time>", so .str[4:] strips that prefix and the time becomes a float.
melted_example_rsf = pd.melt(example_rsf.result, id_vars="variable_name", value_vars=example_rsf.result.columns[6:])
melted_example_rsf["variable"] = melted_example_rsf["variable"].str[4:].astype(float)
melted_example_rsf.to_csv("results/exp1_example_rsf.csv", index=False)
melted_example_cph = pd.melt(example_cph.result, id_vars="variable_name", value_vars=example_cph.result.columns[6:])
melted_example_cph["variable"] = melted_example_cph["variable"].str[4:].astype(float)
melted_example_cph.to_csv("results/exp1_example_cph.csv", index=False)

In [None]:
# normalized SurvSHAP(t) values plot example

# Divide every SurvSHAP(t) value by the sum of absolute values over all variables at the
# same time point; 0/0 becomes 0 via nan_to_num. Only columns from position 6 onwards are
# rescaled, because these are the time columns melted below and used everywhere else in
# this notebook. Position 5 is not one of them.
# NOTE(review): the rescaled rows (B == 0) are written back into all rows, which assumes
# every row has B == 0.
example_rsf.result.iloc[:, 6:] = np.nan_to_num(
    example_rsf.result[example_rsf.result["B"] == 0].iloc[:, 6:]
    / example_rsf.result[example_rsf.result["B"] == 0].iloc[:, 6:].abs().sum()
)

example_cph.result.iloc[:, 6:] = np.nan_to_num(
    example_cph.result[example_cph.result["B"] == 0].iloc[:, 6:]
    / example_cph.result[example_cph.result["B"] == 0].iloc[:, 6:].abs().sum()
)

# Long format with numeric time ("t = <time>" prefix stripped), as in the unnormalized example.
melted_example_norm_rsf = pd.melt(example_rsf.result, id_vars="variable_name", value_vars=example_rsf.result.columns[6:])
melted_example_norm_rsf["variable"] = melted_example_norm_rsf["variable"].str[4:].astype(float)
melted_example_norm_rsf.to_csv("results/exp1_example_norm_rsf.csv", index=False)

melted_example_norm_cph = pd.melt(example_cph.result, id_vars="variable_name", value_vars=example_cph.result.columns[6:])
melted_example_norm_cph["variable"] = melted_example_norm_cph["variable"].str[4:].astype(float)
melted_example_norm_cph.to_csv("results/exp1_example_norm_cph.csv", index=False)

##### Changing Signs Proportion

In [None]:
# Sign (-1, 0 or +1) of every SurvSHAP(t) value for each model, read from the time columns
# of full_result (position 6 onwards), together with each model's time grid.
timestamps_rsf = exp1_survshap_global_rsf.timestamps
shap_signs_rsf = np.sign(exp1_survshap_global_rsf.full_result.iloc[:, 6:].to_numpy())

timestamps_cph = exp1_survshap_global_cph.timestamps
shap_signs_cph = np.sign(exp1_survshap_global_cph.full_result.iloc[:, 6:].to_numpy())

In [None]:
# Positions of the first and last RSF timestamps inside the 10th-90th percentile window.
start_index, end_index = np.flatnonzero(
    (timestamps_rsf >= np.percentile(timestamps_rsf, 10))
    & (timestamps_rsf <= np.percentile(timestamps_rsf, 90))
)[[0, -1]]

In [None]:
# For every SurvSHAP(t) sign curve (one row of shap_signs_rsf), split the window
# [start_index, end_index] into stretches of constant sign. Record each stretch's length
# in time, multiplied by the stretch's sign. Zero values never start a new stretch; they
# extend the current one.
sign_ranges = []
for row in shap_signs_rsf:
    sign_ranges_row = []
    last_sign = row[start_index]
    start_time_sign_sequence = timestamps_rsf[start_index]
    for i in range(start_index, end_index+1):
        # a non-zero value of a different sign closes the current stretch at the previous time point
        if row[i] != last_sign and row[i] != 0:
            sign_ranges_row.append(last_sign*(timestamps_rsf[i-1] - start_time_sign_sequence))
            start_time_sign_sequence = timestamps_rsf[i-1]
        if row[i] != 0:
            last_sign = row[i] 
    # close the final stretch; after the loop i == end_index
    sign_ranges_row.append(last_sign*(timestamps_rsf[i] - start_time_sign_sequence))
    sign_ranges.append(sign_ranges_row)

In [None]:
# Total time each RSF curve spends with negative / positive sign within the window,
# and the length of the window itself.
negative_range = [sum(filter(lambda length: length < 0, row)) for row in sign_ranges]
positive_range = [sum(filter(lambda length: length > 0, row)) for row in sign_ranges]
timestamps_range = timestamps_rsf[end_index] - timestamps_rsf[start_index]

In [None]:
# An explanation row "changes sign" at threshold p when its SurvSHAP(t) curve is negative
# for a total time of at least p * timestamps_range and positive for at least as long.
changed_signs_005 = (np.abs(np.array(negative_range)) >= 0.05 * timestamps_range) & (np.array(positive_range) >= 0.05 * timestamps_range)
changed_signs_01 = (np.abs(np.array(negative_range)) >= 0.1 * timestamps_range) & (np.array(positive_range) >= 0.1 * timestamps_range)
changed_signs_02 = (np.abs(np.array(negative_range)) >= 0.2 * timestamps_range) & (np.array(positive_range) >= 0.2 * timestamps_range)
csp_rsf = pd.DataFrame({"variable_name": exp1_survshap_global_rsf.full_result.variable_name,
                        "variable_value": exp1_survshap_global_rsf.full_result.variable_value,
                        # Observation id stored in full_result's "index" column. Attribute access
                        # `.index` would return the DataFrame's row labels instead.
                        "index": exp1_survshap_global_rsf.full_result["index"],
                        "changed_signs_0.05": changed_signs_005,
                        "changed_signs_0.1": changed_signs_01,
                        "changed_signs_0.2": changed_signs_02})

In [None]:
# Positions of the first and last CPH timestamps inside the analysis window.
# NOTE(review): the bounds are the 10th/90th percentiles of timestamps_rsf (not timestamps_cph),
# presumably so both models are compared over the same absolute time window; confirm this is intended.
start_index, end_index = np.where((timestamps_cph >= np.percentile(timestamps_rsf, 10)) & (timestamps_cph <= np.percentile(timestamps_rsf, 90)))[0][[0, -1]]

In [None]:
# For every CPH SurvSHAP(t) sign curve, split the window [start_index, end_index] into
# stretches of constant sign and record each stretch's length in time, multiplied by its
# sign. Zero values never start a new stretch; they extend the current one.
sign_ranges = []
for row in shap_signs_cph:
    sign_ranges_row = []
    last_sign = row[start_index]
    start_time_sign_sequence = timestamps_cph[start_index]
    # end_index is inclusive (as in the RSF loop) so the last time point of the window is used.
    for i in range(start_index, end_index + 1):
        # a non-zero value of a different sign closes the current stretch at the previous time point
        if row[i] != last_sign and row[i] != 0:
            sign_ranges_row.append(last_sign * (timestamps_cph[i - 1] - start_time_sign_sequence))
            start_time_sign_sequence = timestamps_cph[i - 1]
        if row[i] != 0:
            last_sign = row[i]
    # close the final stretch at the end of the window (explicit index, no reliance on the loop variable)
    sign_ranges_row.append(last_sign * (timestamps_cph[end_index] - start_time_sign_sequence))
    sign_ranges.append(sign_ranges_row)

In [None]:
# Total time each CPH curve spends with negative / positive sign within the window.
negative_range = [sum(sign_seq_range for sign_seq_range in sign_ranges_row if sign_seq_range < 0) for sign_ranges_row in sign_ranges]
positive_range = [sum(sign_seq_range for sign_seq_range in sign_ranges_row if sign_seq_range > 0) for sign_ranges_row in sign_ranges]
# Normalize by the length of the analysed window [start_index, end_index], as for the RSF.
# The sign stretches only cover that window, so using the full timestamp range would
# understate the proportions.
timestamps_range = timestamps_cph[end_index] - timestamps_cph[start_index]

In [None]:
# An explanation row "changes sign" at threshold p when its SurvSHAP(t) curve is negative
# for a total time of at least p * timestamps_range and positive for at least as long.
changed_signs_005 = (np.abs(np.array(negative_range)) >= 0.05 * timestamps_range) & (np.array(positive_range) >= 0.05 * timestamps_range)
changed_signs_01 = (np.abs(np.array(negative_range)) >= 0.1 * timestamps_range) & (np.array(positive_range) >= 0.1 * timestamps_range)
changed_signs_02 = (np.abs(np.array(negative_range)) >= 0.2 * timestamps_range) & (np.array(positive_range) >= 0.2 * timestamps_range)
csp_cph = pd.DataFrame({"variable_name": exp1_survshap_global_cph.full_result.variable_name,
                        "variable_value": exp1_survshap_global_cph.full_result.variable_value,
                        # Observation id stored in full_result's "index" column. Attribute access
                        # `.index` would return the DataFrame's row labels instead.
                        "index": exp1_survshap_global_cph.full_result["index"],
                        "changed_signs_0.05": changed_signs_005,
                        "changed_signs_0.1": changed_signs_01,
                        "changed_signs_0.2": changed_signs_02})

In [None]:
# Per-variable share of RSF explanations that change sign. numeric_only=True keeps only the
# numeric/boolean columns (the pandas < 2.0 default) rather than failing on non-numeric ones.
csp_rsf.groupby("variable_name").mean(numeric_only=True)

In [None]:
# Per-variable share of CPH explanations that change sign. numeric_only=True keeps only the
# numeric/boolean columns (the pandas < 2.0 default) rather than failing on non-numeric ones.
csp_cph.groupby("variable_name").mean(numeric_only=True)

##### Local accuracy

In [None]:
def get_local_accuracy_from_shap_explanations(all_explanation, label, last_index=None):
    """Relative local-accuracy error of SurvSHAP(t) explanations at each time point.

    For each individual explanation the residual is
    ``predicted_function - baseline_function - (sum of SHAP values over variables)``.
    At each time point the result is sqrt(mean residual**2) / sqrt(mean predicted**2),
    where both means are taken over explanations.

    Parameters
    ----------
    all_explanation
        Object with ``timestamps`` and ``individual_explanations``. Each explanation exposes
        ``predicted_function``, ``baseline_function`` and ``result``, whose time columns
        start at position 6.
    label
        Value stored in the "label" column.
    last_index
        Use only the first ``last_index`` time points (all of them when None).

    Returns
    -------
    pandas.DataFrame
        Columns "time", "sigma", "label".
    """
    cutoff = len(all_explanation.timestamps) if last_index is None else last_index
    explanations = all_explanation.individual_explanations
    predictions = np.array([e.predicted_function[:cutoff] for e in explanations])
    residuals = np.array(
        [
            e.predicted_function[:cutoff]
            - e.baseline_function[:cutoff]
            - np.array(e.result.iloc[:, 6:].sum(axis=0))[:cutoff]
            for e in explanations
        ]
    )
    rms_residual = np.sqrt(np.mean(residuals ** 2, axis=0))
    rms_prediction = np.sqrt(np.mean(predictions ** 2, axis=0))
    return pd.DataFrame(
        {"time": all_explanation.timestamps[:cutoff], "sigma": rms_residual / rms_prediction, "label": label}
    )

In [None]:
# Relative local-accuracy error over time for the RSF explanations.
local_accuracy_rsf = get_local_accuracy_from_shap_explanations(exp1_survshap_global_rsf, "RSF")

In [None]:
# Relative local-accuracy error over time for the CPH explanations.
local_accuracy_cph = get_local_accuracy_from_shap_explanations(exp1_survshap_global_cph, "CPH")

In [None]:
# Save both curves in long format (distinguished by the "label" column).
pd.concat([local_accuracy_rsf, local_accuracy_cph]).to_csv("results/exp1_local_accuracy.csv")

##### GT-Shapley

In [None]:
# Order both ground-truth tables by observation, then by variable. The row-wise comparison
# with full_result below presumably relies on full_result using the same order; verify.
shap_grountruth_rsf = shap_grountruth_rsf.sort_values(["observation_index", "variable_index"])
shap_grountruth_cph = shap_grountruth_cph.sort_values(["observation_index", "variable_index"])

In [None]:
# Ground truth minus SurvSHAP(t), restricted to observations that have a ground truth.
# NOTE(review): assumes the filtered full_result rows follow the same (observation, variable)
# order as the sorted ground truth, and that full_result columns 6+ and ground-truth columns 2+
# are the same time points.
diff_rsf = shap_grountruth_rsf.values[:, 2:] - exp1_survshap_global_rsf.full_result[exp1_survshap_global_rsf.full_result["index"].isin(shap_grountruth_rsf["observation_index"])].values[:, 6:]
diff_cph = shap_grountruth_cph.values[:, 2:] - exp1_survshap_global_cph.full_result[exp1_survshap_global_cph.full_result["index"].isin(shap_grountruth_cph["observation_index"])].values[:, 6:]

In [None]:
# Ground-truth SHAP values alone (time columns); later used as the normalization scale.
gt_shap_profiles_rsf = shap_grountruth_rsf.values[:, 2:]
gt_shap_profiles_cph = shap_grountruth_cph.values[:, 2:]

In [None]:
# Per-time RMSE of the SurvSHAP(t) error and RMS magnitude of the ground truth, over all rows.
rmse_rsf = np.sqrt(np.array(np.mean(diff_rsf**2, axis=0), dtype=np.float64))
rmse_gt_shap_profiles_rsf = np.sqrt(np.array(np.mean(gt_shap_profiles_rsf**2, axis=0), dtype=np.float64))
rmse_cph = np.sqrt(np.array(np.mean(diff_cph**2, axis=0), dtype=np.float64))
rmse_gt_shap_profiles_cph = np.sqrt(np.array(np.mean(gt_shap_profiles_cph**2, axis=0), dtype=np.float64))

In [None]:
# Relative RMSE of SurvSHAP(t) w.r.t. the ground truth, per variable (rows) and time point
# (columns). Rows of diff_rsf / gt_shap_profiles_rsf are ordered by observation, then
# variable, so rows i, i+5, i+10, ... all belong to variable i (5 variables per observation).
# The number of columns is taken from the data rather than a fixed time-grid length.
gt_comp_by_vars_rsf = np.zeros((5, gt_shap_profiles_rsf.shape[1]))
for i in range(5):
    rmse = np.sqrt(np.array(np.mean(diff_rsf[i::5] ** 2, axis=0), dtype=np.float64))
    normalization_factor = np.sqrt(np.array(np.mean(gt_shap_profiles_rsf[i::5] ** 2, axis=0), dtype=np.float64))
    gt_comp_by_vars_rsf[i, :] = rmse / normalization_factor

In [None]:
# Relative RMSE of SurvSHAP(t) w.r.t. the ground truth, per variable (rows) and time point
# (columns). Rows of diff_cph / gt_shap_profiles_cph are ordered by observation, then
# variable, so rows i, i+5, i+10, ... all belong to variable i (5 variables per observation).
# The number of columns is taken from the data rather than a fixed time-grid length.
gt_comp_by_vars_cph = np.zeros((5, gt_shap_profiles_cph.shape[1]))
for i in range(5):
    rmse = np.sqrt(np.array(np.mean(diff_cph[i::5] ** 2, axis=0), dtype=np.float64))
    normalization_factor = np.sqrt(np.array(np.mean(gt_shap_profiles_cph[i::5] ** 2, axis=0), dtype=np.float64))
    gt_comp_by_vars_cph[i, :] = rmse / normalization_factor

In [None]:
# Long-format table of the per-variable relative RMSE (one row per variable and time point).
tmp = pd.DataFrame(gt_comp_by_vars_rsf, index=["x1", "x2", "x3", "x4", "x5"], columns=timestamps_rsf).reset_index().rename(columns={"index": "variable_name"})
# value_vars is left at its default (all non-id columns): listing tmp.columns would also
# include the id column "variable_name" as a value column.
pd.melt(tmp, id_vars="variable_name").to_csv("results/exp1_gt_shap_rsf.csv")

In [None]:
# Long-format table of the per-variable relative RMSE (one row per variable and time point).
tmp = pd.DataFrame(gt_comp_by_vars_cph, index=["x1", "x2", "x3", "x4", "x5"], columns=timestamps_cph).reset_index().rename(columns={"index": "variable_name"})
# value_vars is left at its default (all non-id columns): listing tmp.columns would also
# include the id column "variable_name" as a value column.
pd.melt(tmp, id_vars="variable_name").to_csv("results/exp1_gt_shap_cph.csv")

In [None]:
def GT_Shapley(all_explanation, groundtruth, label, n_variables=5):
    """Mean per-time correlation between ground-truth and SurvSHAP(t) values.

    For each observation, the ``n_variables`` ground-truth SHAP values are correlated with
    the estimated ones, separately at every time point. The correlations are then averaged
    over observations.

    Parameters
    ----------
    all_explanation
        Object with ``timestamps`` and ``full_result``. ``full_result`` has an "index"
        (observation id) column, and its time columns start at position 6.
    groundtruth
        Ground-truth table sorted by observation then variable. It has an
        "observation_index" column, and its time columns start at position 2.
    label
        Value stored in the "label" column.
    n_variables
        Number of variables (rows) per observation.

    Returns
    -------
    pandas.DataFrame
        Columns "time", "correlation", "label".
    """
    full_result = all_explanation.full_result
    # Filter the explanations with a ground truth once, instead of once per observation.
    estimated = np.array(
        full_result[full_result["index"].isin(groundtruth["observation_index"])].values[:, 6:],
        dtype=np.float64,
    )
    ground_truth = np.array(groundtruth.values[:, 2:], dtype=np.float64)
    n_observations = len(groundtruth) // n_variables
    corrs = []
    # Every observation, starting from the first one (index 0).
    for i in range(n_observations):
        rows = slice(i * n_variables, (i + 1) * n_variables)
        gt = pd.DataFrame(ground_truth[rows])
        obt = pd.DataFrame(estimated[rows])
        corrs.append(gt.corrwith(obt, axis=0))
    return pd.DataFrame({"time": all_explanation.timestamps, "correlation": np.array(corrs).mean(axis=0), "label": label})

In [None]:
# Mean per-time correlation between ground-truth and SurvSHAP(t) values for both models, saved in long format.
pd.concat([GT_Shapley(exp1_survshap_global_rsf, shap_grountruth_rsf, "RSF"), GT_Shapley(exp1_survshap_global_cph, shap_grountruth_cph, "CPH")]).to_csv("results/exp1_corr.csv")