# SurvSHAP(t): Time-Dependent Explanations of Machine Learning Survival Models
### M. Krzyziński, M. Spytek, H. Baniecki, P. Biecek
## Experiment 3: Real-world use case: predicting survival of patients with heart failure

#### Imports

In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle
from scipy.integrate import trapezoid

#### Preparing data and models 

In [None]:
# Load the heart-failure table and expose the event indicator under the
# column name "event" (renamed from "DEATH_EVENT").
real_data = pd.read_csv("data/exp3_heart_failure_dataset.csv").rename(
    columns={"DEATH_EVENT": "event"}
)

In [None]:
from sksurv.util import Surv

# Model inputs: eight covariates selected from the loaded table.
X = real_data.loc[
    :,
    [
        "age",
        "creatinine_phosphokinase",
        "ejection_fraction",
        "platelets",
        "serum_creatinine",
        "serum_sodium",
        "sex",
        "smoking",
    ],
]
# Survival target built from the "event" indicator and "time" columns.
y = Surv.from_dataframe(event="event", time="time", data=real_data)

In [None]:
from sksurv.ensemble import RandomSurvivalForest

# Random survival forest with fixed hyperparameters and a pinned seed.
rsf = RandomSurvivalForest(
    n_estimators=120,
    max_depth=8,
    min_samples_leaf=4,
    max_features=3,
    random_state=42,
)
rsf.fit(X, y)
# Displayed cell output: the model's score on the same data it was fitted on.
rsf.score(X, y)

In [None]:
from sksurv.linear_model import CoxPHSurvivalAnalysis

# Cox proportional hazards model with default settings, fitted on all data.
cph = CoxPHSurvivalAnalysis()
cph.fit(X=X, y=y)
# Displayed cell output: the model's score on the same data it was fitted on.
cph.score(X=X, y=y)

#### Evaluating model performance

In [None]:
from sksurv.metrics import brier_score, integrated_brier_score

# Keep only observations whose follow-up time lies strictly between the
# smallest and largest observed event times, so the Brier score can be
# evaluated on them.
event_times = y["time"][y["event"] == 1]
mask = (y["time"] > event_times.min()) & (y["time"] < event_times.max())
# Evaluation grid: the distinct values among 101 evenly spaced percentiles
# (0.1th to 99.9th) of the masked follow-up times.
times = np.unique(np.percentile(y[mask]["time"], np.linspace(0.1, 99.9, 101)))

In [None]:
# Predicted survival functions of both models for the masked observations.
survs_rsf, survs_cph = (
    model.predict_survival_function(X[mask]) for model in (rsf, cph)
)
# Evaluate every predicted step function on the common time grid.
preds_rsf, preds_cph = (
    [sf(times) for sf in survs] for survs in (survs_rsf, survs_cph)
)
# Time-dependent Brier score of each model: the full y is passed first,
# followed by the masked subset that the predictions belong to.
brier_rsf, brier_cph = (
    brier_score(y, y[mask], preds, times) for preds in (preds_rsf, preds_cph)
)

In [None]:
# One long table of per-time Brier scores (element [1] of each brier_score
# result), tagged with the model label, written to CSV.
brier_frames = [
    pd.DataFrame({"time": times, "brier_score": scores[1], "label": label})
    for scores, label in ((brier_rsf, "RSF"), (brier_cph, "CPH"))
]
pd.concat(brier_frames).to_csv("results/exp3_model_brier_score.csv")

In [None]:
# Displayed cell output: (RSF, CPH) integrated Brier scores over the same grid.
tuple(integrated_brier_score(y, y[mask], preds, times) for preds in (preds_rsf, preds_cph))

#### Creating explanations

In [None]:
from survshap import SurvivalModelExplainer, ModelSurvSHAP
from survlime import SurvLIME

# Wrap each fitted model together with the data (X, y) it was fitted on;
# the RSF wrapper is built first, then the CPH one.
rsf_exp, cph_exp = (SurvivalModelExplainer(model, X, y) for model in (rsf, cph))

In [None]:
# Fit a ModelSurvSHAP explainer on the RSF wrapper; its per-observation
# results are read later through `individual_explanations`.
# random_state is pinned to 42 (presumably so the sampling is reproducible).
exp3_survshap_global_rsf = ModelSurvSHAP(random_state=42)
exp3_survshap_global_rsf.fit(rsf_exp)

In [None]:
# Persist the fitted RSF SurvSHAP(t) explainer so the analysis can be rerun
# without recomputing it.
with open("pickles/exp3_survshap_global_rsf", mode="wb") as pickle_file:
    pickle.dump(exp3_survshap_global_rsf, pickle_file)

In [None]:
# Fit a ModelSurvSHAP explainer on the CPH wrapper; its per-observation
# results are read later through `individual_explanations`.
# random_state is pinned to 42 (presumably so the sampling is reproducible).
exp3_survshap_global_cph = ModelSurvSHAP(random_state=42)
exp3_survshap_global_cph.fit(cph_exp)

In [None]:
# Persist the fitted CPH SurvSHAP(t) explainer so the analysis can be rerun
# without recomputing it.
with open("pickles/exp3_survshap_global_cph", mode="wb") as pickle_file:
    pickle.dump(exp3_survshap_global_cph, pickle_file)

In [None]:
# Fit one SurvLIME explanation per observation for the Cox model and collect
# them, in row order, in `sls`.
n_obs = len(X)
sls = [None] * n_obs
# enumerate() has no len(), so pass total= to give tqdm a real progress bar/ETA.
for i, obs in tqdm(enumerate(X.values), total=n_obs):
    # Rebuild a one-row frame carrying the explainer's column names.
    xx = pd.DataFrame(np.atleast_2d(obs), columns=cph_exp.data.columns)
    survlime = SurvLIME(N=1000)
    survlime.fit(cph_exp, xx, k=1)
    sls[i] = survlime

In [None]:
# Persist the list of per-observation SurvLIME explanations for the Cox model.
with open("pickles/exp3_survlime_global_cph", mode="wb") as pickle_file:
    pickle.dump(sls, pickle_file)

In [None]:
# Fit one SurvLIME explanation per observation for the random survival forest
# and collect them, in row order, in `sls`.
n_obs = len(X)
sls = [None] * n_obs
# The time grid passed to every fit is always taken from observation 0's
# predicted survival function, so compute it once instead of re-running the
# forest's prediction on every iteration.
rsf_timestamps = rsf.predict_survival_function(X.iloc[[0]])[0].x
for i in tqdm(range(n_obs)):
    survlime = SurvLIME(N=1000)
    # NOTE(review): k=2 here but k=1 in the Cox SurvLIME loop; confirm intended.
    survlime.fit(rsf_exp, X.iloc[[i]], k=2, timestamps=rsf_timestamps)
    sls[i] = survlime

In [None]:
# Persist the list of per-observation SurvLIME explanations for the RSF model.
with open("pickles/exp3_survlime_global_rsf", mode="wb") as pickle_file:
    pickle.dump(sls, pickle_file)

#### Results analysis

In [None]:
# Reload the RSF SurvSHAP(t) explainer. pickle.load can execute arbitrary
# code, so only load files this notebook produced itself.
with open("pickles/exp3_survshap_global_rsf", mode="rb") as pickle_file:
    exp3_survshap_global_rsf = pickle.load(pickle_file)

In [None]:
# Reload the CPH SurvSHAP(t) explainer. pickle.load can execute arbitrary
# code, so only load files this notebook produced itself.
with open("pickles/exp3_survshap_global_cph", mode="rb") as pickle_file:
    exp3_survshap_global_cph = pickle.load(pickle_file)

In [None]:
# Reload the per-observation SurvLIME explanations for the Cox model.
with open("pickles/exp3_survlime_global_cph", mode="rb") as pickle_file:
    exp3_survlime_global_cph = pickle.load(pickle_file)

In [None]:
# Reload the per-observation SurvLIME explanations for the RSF model.
with open("pickles/exp3_survlime_global_rsf", mode="rb") as pickle_file:
    exp3_survlime_global_rsf = pickle.load(pickle_file)

##### Plot examples

In [None]:
# Example explanation: the RSF SurvSHAP(t) result for the observation at position 12.
example_rsf = exp3_survshap_global_rsf.individual_explanations[12]

In [None]:
# Long format: one row per (variable_name, time column) pair. Columns from
# position 6 onward of the result hold the time-indexed values.
time_columns_rsf = example_rsf.result.columns[6:]
melted_example_rsf = example_rsf.result.melt(id_vars="variable_name", value_vars=time_columns_rsf)
# Drop the first 4 characters of each time-column label and parse the rest
# as a float (assumes a fixed-width label prefix; TODO confirm format).
melted_example_rsf["variable"] = melted_example_rsf["variable"].str.slice(4).astype(float)
melted_example_rsf.to_csv("results/exp3_example_rsf.csv", index=False)
example_rsf.simplified_result.to_csv("results/exp3_example_rsf_agg.csv", index=False)

In [None]:
# Second RSF example: the observation at position 14, reshaped to long format.
example_rsf2 = exp3_survshap_global_rsf.individual_explanations[14]
melted_example_rsf_2 = example_rsf2.result.melt(
    id_vars="variable_name",
    value_vars=example_rsf2.result.columns[6:],
)
# Strip the 4-character label prefix and parse the remaining time as a float.
time_labels_rsf_2 = melted_example_rsf_2["variable"].str.slice(4)
melted_example_rsf_2["variable"] = time_labels_rsf_2.astype(float)
melted_example_rsf_2.to_csv("results/exp3_example_rsf_2.csv", index=False)

In [None]:
# Matching CPH example for the same observation (position 14), in long format.
example_cph2 = exp3_survshap_global_cph.individual_explanations[14]
melted_example_cph_2 = example_cph2.result.melt(
    id_vars="variable_name",
    value_vars=example_cph2.result.columns[6:],
)
# Strip the 4-character label prefix and parse the remaining time as a float.
time_labels_cph_2 = melted_example_cph_2["variable"].str.slice(4)
melted_example_cph_2["variable"] = time_labels_cph_2.astype(float)
melted_example_cph_2.to_csv("results/exp3_example_cph_2.csv", index=False)

##### Importance rankings

In [None]:
def get_orderings_and_ranks_shap(explanations):
    """Rank variables by their time-aggregated absolute SurvSHAP(t) values.

    For every explanation, the absolute values in the time-indexed columns of
    ``explanation.result`` (every column from position 6 onward) are integrated
    over ``explanation.timestamps`` with the trapezoidal rule. The input
    frames are not modified.

    Returns
    -------
    tuple of (pd.DataFrame, pd.DataFrame)
        Orderings: row ``j`` lists the ``result`` index labels of explanation
        ``j`` sorted from most to least important.
        Ranks: row ``j`` holds each variable's descending rank (1 = most
        important, pandas' default tie handling), in ``result`` row order.
    """
    orderings, ranks = [], []
    for explanation in explanations:
        table = explanation.result.copy()
        # Integrate |value| over time for every variable (row).
        table["aggregated_change"] = trapezoid(
            np.abs(table.iloc[:, 6:].values), explanation.timestamps
        )
        by_importance = table.sort_values(by="aggregated_change", key=lambda col: -abs(col))
        orderings.append(by_importance.index.to_list())
        ranks.append(np.abs(table.aggregated_change).rank(ascending=False).to_list())
    return pd.DataFrame(orderings), pd.DataFrame(ranks)

from scipy.stats import weightedtau
def mean_weighted_tau(ranks1, ranks2, n_obs=100):
    """Mean and standard deviation of row-wise weighted Kendall's tau.

    Row ``i`` of ``ranks1`` is compared with row ``i`` of ``ranks2`` using
    ``scipy.stats.weightedtau``.

    Parameters
    ----------
    ranks1, ranks2 : pd.DataFrame
        Importance ranks, one row per explained observation.
    n_obs : int or None, default 100
        Number of leading rows to compare. ``None`` compares every row present
        in both frames. The default keeps the previous fixed behaviour of
        comparing exactly the first 100 rows.

    Returns
    -------
    tuple of (float, float)
        ``np.mean`` and ``np.std`` of the per-row taus.

    Raises
    ------
    IndexError
        If either frame has fewer than ``n_obs`` rows.
    """
    available = min(len(ranks1), len(ranks2))
    if n_obs is None:
        n_obs = available
    if n_obs > available:
        raise IndexError(
            f"n_obs={n_obs} exceeds the rows available "
            f"({len(ranks1)} and {len(ranks2)})"
        )
    taus = []
    for i in range(n_obs):
        tau, _ = weightedtau(ranks1.iloc[i], ranks2.iloc[i])
        taus.append(tau)
    return np.mean(taus), np.std(taus)

def prepare_ranking_summary_long(ordering, feature_names=None):
    """Count how often each feature occupies each importance rank (long format).

    Parameters
    ----------
    ordering : pd.DataFrame
        One row per explained observation; column ``k`` holds the positional
        index of the feature placed at rank ``k + 1`` (as returned by the
        ``get_orderings_and_ranks_*`` helpers).
    feature_names : sequence of str, optional
        Feature names in positional order. Defaults to the notebook-level
        ``X.columns``; the number of ranks tabulated equals its length.

    Returns
    -------
    pd.DataFrame
        Columns ``importance_ranking`` (1-based rank), ``variable`` (feature
        name) and ``value`` (count). Rank/feature pairs that never occur are
        ``<NA>``.
    """
    if feature_names is None:
        feature_names = X.columns
    feature_names = list(feature_names)
    n_features = len(feature_names)
    # Build the rank x feature count matrix in one constructor call instead of
    # concatenating onto an empty frame. pandas has deprecated the dtype
    # handling of that pattern, which made the output dtype version-dependent.
    # Nullable Int64 keeps counts integral and leaves missing pairs as <NA>.
    counts = pd.DataFrame(
        [ordering[rank].value_counts().to_dict() for rank in range(n_features)],
        index=range(1, n_features + 1),
        columns=range(n_features),
    ).astype("Int64")
    counts.columns = feature_names
    counts = counts.rename_axis("importance_ranking").reset_index()
    return counts.melt(id_vars=["importance_ranking"], value_vars=feature_names)

In [None]:
def get_orderings_and_ranks_shap(explanations):
    """Order and rank variables by integrated absolute SurvSHAP(t) values.

    This redefines, with identical behaviour, the helper of the same name from
    an earlier cell. Each explanation's time-indexed columns (position 6
    onward of ``explanation.result``) are made absolute and integrated over
    ``explanation.timestamps`` with the trapezoidal rule. Input frames are
    copied, never modified.

    Returns
    -------
    tuple of (pd.DataFrame, pd.DataFrame)
        Per-explanation index orderings (most important first) and
        descending ranks (1 = most important) in ``result`` row order.
    """
    def _order_and_rank(explanation):
        # Score one explanation and return (ordering, ranks) for it.
        scores = explanation.result.copy()
        time_values = scores.iloc[:, 6:].values
        scores["aggregated_change"] = trapezoid(np.abs(time_values), explanation.timestamps)
        ordering = scores.sort_values(
            by="aggregated_change", key=lambda s: -abs(s)
        ).index.to_list()
        rank = np.abs(scores.aggregated_change).rank(ascending=False).to_list()
        return ordering, rank

    results = [_order_and_rank(e) for e in explanations]
    importance_orderings = [ordering for ordering, _ in results]
    importance_ranks = [rank for _, rank in results]
    return pd.DataFrame(importance_orderings), pd.DataFrame(importance_ranks)

def get_orderings_and_ranks_lime(explanations):
    """Order and rank variables by absolute SurvLIME impact.

    The impact of a variable is ``variable_value * coefficient`` taken from
    ``explanation.result``. The result frame is copied first so the caller's
    explanation objects are not mutated. Previously an ``impact`` column was
    written into them, unlike ``get_orderings_and_ranks_shap``, which copies.

    Returns
    -------
    tuple of (pd.DataFrame, pd.DataFrame)
        Orderings: row ``j`` lists the ``result`` index labels of explanation
        ``j`` sorted by descending ``|impact|``.
        Ranks: row ``j`` holds each variable's descending rank of ``|impact|``
        (1 = largest), in ``result`` row order.
    """
    importance_orderings = []
    importance_ranks = []
    for explanation in explanations:
        df = explanation.result.copy()
        df["impact"] = df["variable_value"] * df["coefficient"]
        importance_orderings.append(
            df.sort_values(by="impact", key=lambda x: -abs(x)).index.to_list()
        )
        importance_ranks.append(np.abs(df.impact).rank(ascending=False).to_list())
    return pd.DataFrame(importance_orderings), pd.DataFrame(importance_ranks)

In [None]:
# SurvSHAP(t) importance orderings/ranks for the Cox model; the rank-frequency
# summary is written to CSV.
cph_survshap_orderings, cph_survshap_ranks = get_orderings_and_ranks_shap(
    exp3_survshap_global_cph.individual_explanations
)
cph_survshap_summary = prepare_ranking_summary_long(cph_survshap_orderings)
cph_survshap_summary.to_csv("results/exp3_survshap_orderings_cph.csv")

In [None]:
# SurvSHAP(t) importance orderings/ranks for the RSF model; the rank-frequency
# summary is written to CSV.
rsf_survshap_orderings, rsf_survshap_ranks = get_orderings_and_ranks_shap(
    exp3_survshap_global_rsf.individual_explanations
)
rsf_survshap_summary = prepare_ranking_summary_long(rsf_survshap_orderings)
rsf_survshap_summary.to_csv("results/exp3_survshap_orderings_rsf.csv")

#### Calculating permutation-based variable importance

In [None]:
def loss_integrated_brier_score(model, data, y, eval_times=None):
    """Score a fitted survival model as ``1 - integrated Brier score``.

    Follows the ``scoring(estimator, X, y)`` callable form used with
    ``permutation_importance`` in this notebook; higher is better.

    Parameters
    ----------
    model : fitted estimator exposing ``predict_survival_function``.
    data : features passed to ``model.predict_survival_function``.
    y : structured survival target, passed as both the first (reference)
        and second (evaluated) argument of ``integrated_brier_score``.
    eval_times : array-like, optional
        Time grid for the score. Defaults to the notebook-level ``times``.
        Previously that dependency was implicit, and a stray, discarded
        recomputation of the grid preceded this definition.
    """
    if eval_times is None:
        eval_times = times
    survival_functions = model.predict_survival_function(data)
    predictions = [sf(eval_times) for sf in survival_functions]
    return 1 - integrated_brier_score(y, y, predictions, eval_times)

In [None]:
from sklearn.inspection import permutation_importance

# Permutation importance of each feature for the RSF, scored by
# 1 - integrated Brier score and averaged over 100 shuffles.
rsf_permutation = permutation_importance(
    rsf,
    X,
    y,
    scoring=loss_integrated_brier_score,
    n_repeats=100,
    random_state=42,
)
imp_mean_rsf = rsf_permutation["importances_mean"]

In [None]:
# Displayed cell output: (feature, mean permutation importance) pairs for the
# RSF, highest importance first.
pd.DataFrame(zip(X.columns, imp_mean_rsf)).sort_values(by=1, ascending=False)

In [None]:
# Permutation importance of each feature for the Cox model, scored by
# 1 - integrated Brier score and averaged over 100 shuffles.
cph_permutation = permutation_importance(
    cph,
    X,
    y,
    scoring=loss_integrated_brier_score,
    n_repeats=100,
    random_state=42,
)
imp_mean_cph = cph_permutation["importances_mean"]

In [None]:
# Displayed cell output: (feature, mean permutation importance) pairs for the
# Cox model, highest importance first.
pd.DataFrame(zip(X.columns, imp_mean_cph)).sort_values(by=1, ascending=False)

In [None]:
# SurvLIME importance orderings/ranks, computed for the Cox model first and
# then for the RSF.
(
    (cph_survlime_orderings, cph_survlime_ranks),
    (rsf_survlime_orderings, rsf_survlime_ranks),
) = (
    get_orderings_and_ranks_lime(explanations)
    for explanations in (exp3_survlime_global_cph, exp3_survlime_global_rsf)
)

In [None]:
# Write the SurvLIME rank-frequency summaries (Cox first, then RSF) to CSV.
for survlime_orderings, summary_path in (
    (cph_survlime_orderings, "results/exp3_survlime_orderings_cph.csv"),
    (rsf_survlime_orderings, "results/exp3_survlime_orderings_rsf.csv"),
):
    prepare_ranking_summary_long(survlime_orderings).to_csv(summary_path)