In [None]:
### Google Colab only: mount Google Drive so that the data and model files
### read from /content/drive/MyDrive/... later in this notebook are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
! pip install shap==0.46.0

In [None]:
### Module ###
import pandas as pd
import matplotlib.pyplot as plt
import pickle
from tqdm.notebook import tqdm
tqdm.pandas(desc="progress: ")

from sklearn.model_selection import KFold
from sklearn.model_selection import GroupKFold
from sklearn import metrics
import shap

import warnings

# Ignore UserWarning
warnings.filterwarnings('ignore', category=UserWarning)

In [None]:
### Flatten Doubled List ###
def flatten(l: list):
    """Return a list holding the first element of every item in ``l``.

    Used to turn a sequence of one-element rows (e.g. ``[[a], [b]]``) into a
    flat list (``[a, b]``). Only index 0 of each item is kept; any further
    elements of an item are ignored.
    """
    return [row[0] for row in l]

### F1score ###
def f1_score(i, precision, recall):
    """Return the F1 score for the ``i``-th precision/recall pair.

    Parameters
    ----------
    i : int
        Index into ``precision`` and ``recall``.
    precision, recall : indexable sequences of numbers
        Precision and recall values; element ``i`` of each is used.

    Returns
    -------
    float
        ``2 * p * r / (p + r)``. When ``p + r`` is zero the F1 score is
        defined here as ``0.0`` instead of producing NaN (NumPy floats) or
        raising ``ZeroDivisionError`` (Python floats); a NaN would otherwise
        be sorted last and mistaken for the maximum by callers that sort on
        this value.
    """
    p = precision[i]
    r = recall[i]
    denominator = p + r
    if denominator == 0:
        return 0.0
    return 2*p*r/denominator


In [None]:
###
##### Main #####
###
def shap_value_calsulation_main(time_point: int, shap_summary_plot=False):
    """Compute SHAP values for one time point over 5-fold CV and save them.

    Loads the labeled dataset for ``time_point``, splits the unique
    ``Patient_ID`` values into 5 folds, lets ``KFold_one`` process each fold,
    and finally pickles the accumulated SHAP values and test values, one file
    per confusion-matrix label (tp / fn / fp / tn), via ``save_df``.

    Parameters
    ----------
    time_point : int
        Time point whose rows form the positive class (see
        ``load_prediction_dataset``).
    shap_summary_plot : bool, default False
        Forwarded to ``KFold_one``; when true a SHAP summary plot is shown
        for each fold.
    """
    n_splits = 5
    # Order matters: it is the positional order KFold_one expects and returns.
    cm_labels = ("tp", "fn", "fp", "tn")

    dataset = load_prediction_dataset(time_point)

    ls_pt_id = dataset["Patient_ID"].unique()
    df_pt_id = pd.DataFrame(ls_pt_id, columns=["Patient_ID"])

    ### Column lists: everything but the label, and the features only ###
    columns = list(dataset.columns)
    columns.remove("Answer")

    rem_columns = columns.copy()
    rem_columns.remove("Patient_ID")

    ### Empty accumulators for the SHAP values, one per confusion-matrix label ###
    shap_frames = {label: pd.DataFrame(columns=columns) for label in cm_labels}

    ### Empty accumulators for the test (feature) values, one per label ###
    testval_frames = {label: pd.DataFrame(columns=columns) for label in cm_labels}

    ### 5-fold cross validation (split on patients, not on rows) ###
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=0)

    # CV progress bar; the context manager closes it even if a fold fails,
    # so the notebook widget is always finalised.
    with tqdm(total=n_splits) as bar:
        bar.set_description("5-fold cross validation >>> ")

        # cv_index is 1-based because KFold_one uses it to build the name of
        # the per-fold model file.
        for cv_index, (train_index, test_index) in enumerate(kf.split(df_pt_id), start=1):
            s_train = flatten(df_pt_id.iloc[train_index].values)
            s_test = flatten(df_pt_id.iloc[test_index].values)

            ### Run the SHAP calculation for this fold ###
            results = KFold_one(time_point, s_train, s_test, dataset,
                                shap_frames["tp"], shap_frames["fn"],
                                shap_frames["fp"], shap_frames["tn"],
                                testval_frames["tp"], testval_frames["fn"],
                                testval_frames["fp"], testval_frames["tn"],
                                columns, rem_columns, cv_index, shap_summary_plot)

            # KFold_one returns the four SHAP frames followed by the four
            # test-value frames, each group in tp / fn / fp / tn order.
            shap_frames = dict(zip(cm_labels, results[:4]))
            testval_frames = dict(zip(cm_labels, results[4:]))

            bar.update(1)

    ### Save DataFrame ###
    test_items = ["Item_1", "Item_2", "Item_3"]

    for label in cm_labels:
        save_df(time_point, shap_frames[label], "shap", label, test_items)

    for label in cm_labels:
        save_df(time_point, testval_frames[label], "test", label, test_items)

###
##### Load Datasets #####
###
def load_prediction_dataset(time_point: int, negative_time_point: int = 168):
    """Load the EHR time-series CSV and build a binary-labeled dataset.

    Rows whose ``Time_point`` equals ``time_point`` are labeled ``Answer = 1``
    and rows whose ``Time_point`` equals ``negative_time_point`` are labeled
    ``Answer = 0``. The two groups are stacked (original row index kept) and
    the ``Time_point`` column is dropped.

    Parameters
    ----------
    time_point : int
        Time point that defines the positive class.
    negative_time_point : int, default 168
        Time point that defines the negative class.

    Returns
    -------
    pd.DataFrame
        All CSV columns except ``Time_point``, plus ``Answer``.

    Raises
    ------
    ValueError
        If ``time_point`` equals ``negative_time_point``; the same rows would
        otherwise be emitted twice with contradictory labels.
    """
    if time_point == negative_time_point:
        raise ValueError(
            "time_point must differ from negative_time_point ("
            + str(negative_time_point)
            + "); otherwise identical rows get both labels."
        )

    df_all = pd.read_csv("/content/drive/MyDrive/res_death_destiny/data/dummy_time_series_EHRdata.csv")

    ### Labeling ###
    # query() returns a new frame and assign() another, so the source frame
    # is never modified.
    df_death_positive = df_all.query("Time_point == @time_point").assign(Answer=1)
    df_death_negative = df_all.query("Time_point == @negative_time_point").assign(Answer=0)

    prediction_dataset = pd.concat([df_death_positive, df_death_negative], axis=0).drop("Time_point", axis=1)

    return prediction_dataset



###
##### SHAP value calculation #####
###
def KFold_one(time_point: int, s_train: list, s_test: list, dataset: pd.DataFrame,
              df_shap_all_tp: pd.DataFrame, df_shap_all_fn: pd.DataFrame,
              df_shap_all_fp: pd.DataFrame, df_shap_all_tn: pd.DataFrame,
              df_testval_all_tp: pd.DataFrame, df_testval_all_fn: pd.DataFrame,
              df_testval_all_fp: pd.DataFrame, df_testval_all_tn: pd.DataFrame,
              columns: list, rem_columns: list, cv_index: int, shap_summary_plot: bool):
    """Process one CV fold: predict, split by confusion-matrix label, compute SHAP.

    Loads the pickled model for (``time_point``, ``cv_index``), predicts on the
    rows of ``dataset`` whose ``Patient_ID`` is in ``s_test``, binarises the
    prediction at the threshold that maximises the F1 score on this fold, and
    appends the test rows and their SHAP values to the accumulator that
    matches each row's confusion-matrix label.

    Parameters
    ----------
    time_point, cv_index : int
        Select the model file ``mpmodel_tp_<time_point>_cv_<cv_index>.pkl``.
    s_train : list
        Training patient IDs. Accepted for interface symmetry; not used here.
    s_test : list
        Patient IDs that form the test set of this fold.
    dataset : pd.DataFrame
        Must contain ``Patient_ID``, ``Answer`` and the columns in ``rem_columns``.
    df_shap_all_tp/fn/fp/tn : pd.DataFrame
        SHAP-value accumulators, one per confusion-matrix label.
    df_testval_all_tp/fn/fp/tn : pd.DataFrame
        Test-value accumulators, one per confusion-matrix label.
    columns : list
        Not used in this function.
    rem_columns : list
        Feature column names (no ``Patient_ID``, no ``Answer``).
    shap_summary_plot : bool
        When true, show a SHAP summary plot for the true positives of this fold.

    Returns
    -------
    tuple of 8 pd.DataFrame
        ``(shap_tp, shap_fn, shap_fp, shap_tn, testval_tp, testval_fn,
        testval_fp, testval_tn)`` with this fold's rows appended.
    """
    ### Load Model ###
    model_name = "mpmodel_tp_"+str(time_point)+"_cv_"+str(cv_index)+".pkl"
    file = "/content/drive/MyDrive/res_death_destiny/data/models/"+model_name
    # SECURITY NOTE: unpickling executes arbitrary code; only load model files
    # you created yourself or otherwise trust.
    # The context manager closes the file handle (it used to be leaked).
    with open(file, 'rb') as model_file:
        model = pickle.load(model_file)

    ### Build SHAP explainer ###
    explainer = shap.TreeExplainer(model)

    ### Dataset for calculate SHAP values: Test Dataset ###
    # query() already returns a new frame; the extra copy() makes the later
    # column assignments unambiguous.
    dataset_c_test = dataset.query('Patient_ID in @s_test').copy()

    X_test = dataset_c_test.drop(["Patient_ID", "Answer"], axis=1).astype(float)
    y_test = dataset_c_test["Answer"].astype(float)

    ### Prediction ###
    # NOTE(review): the output is used as a continuous score below, so
    # model.predict presumably returns probabilities/scores - confirm for the
    # pickled model type.
    y_pred = model.predict(X_test)

    precision, recall, thresholds = metrics.precision_recall_curve(y_test, y_pred)
    # precision/recall have one more entry than thresholds; keep the first k.
    k = len(thresholds)

    df_f1_score = pd.DataFrame({
        "f1_score": [f1_score(i, precision, recall) for i in range(k)],
        "precision": precision[0:k],
        "recall": recall[0:k],
        "thresholds": thresholds,
    })

    ### Threshold which maximize F1-score ###
    # An undefined F1 (precision + recall == 0) can show up as NaN. NaN is
    # sorted last by default and would be picked as the "maximum", so NaNs are
    # pushed to the front. The column is selected by name, not by position.
    thresholds_f1max = (
        df_f1_score.sort_values(["f1_score"], na_position="first")["thresholds"].iloc[-1]
    )

    ### Confusion-matrix labels: TP and TN (FP, FN) ###
    # Cast the feature columns by name so the rows handed to the explainer
    # have the same dtype as X_test, independent of column position.
    dataset_c_test = dataset_c_test.astype({col: float for col in rem_columns})

    # Direct column assignment instead of a chained in-place mask(): the
    # chained form does not update the frame under pandas Copy-on-Write,
    # which would leave every prediction at 0.
    prediction_val = pd.Series(y_pred, index=dataset_c_test.index)
    dataset_c_test["Prediction"] = (prediction_val >= thresholds_f1max).astype(int)

    df_tp_only = dataset_c_test.query("Answer == 1").query("Prediction == 1").drop(["Answer","Prediction"], axis=1)
    df_fn_only = dataset_c_test.query("Answer == 1").query("Prediction == 0").drop(["Answer","Prediction"], axis=1)
    df_fp_only = dataset_c_test.query("Answer == 0").query("Prediction == 1").drop(["Answer","Prediction"], axis=1)
    df_tn_only = dataset_c_test.query("Answer == 0").query("Prediction == 0").drop(["Answer","Prediction"], axis=1)

    ### SHAP value calculation (skipped for labels with no rows in this fold) ###
    if df_tp_only.shape[0] != 0:
        df_testval_all_tp, df_shap_all_tp = calc_shap_value(df_testval_all_tp, df_shap_all_tp, df_tp_only, explainer, rem_columns)

    if df_fn_only.shape[0] != 0:
        df_testval_all_fn, df_shap_all_fn = calc_shap_value(df_testval_all_fn, df_shap_all_fn, df_fn_only, explainer, rem_columns)

    if df_fp_only.shape[0] != 0:
        df_testval_all_fp, df_shap_all_fp = calc_shap_value(df_testval_all_fp, df_shap_all_fp, df_fp_only, explainer, rem_columns)

    if df_tn_only.shape[0] != 0:
        df_testval_all_tn, df_shap_all_tn = calc_shap_value(df_testval_all_tn, df_shap_all_tn, df_tn_only, explainer, rem_columns)

    ### SHAP Summary Plot (Optional): true positives only ###
    if shap_summary_plot:
        if df_tp_only.shape[0] != 0:
            # Select the features by name so Patient_ID is never plotted as a feature.
            tp_features = df_tp_only[rem_columns]
            shap_values = explainer.shap_values(tp_features)
            fig = plt.figure(figsize=(10,10),dpi=100,tight_layout=True,facecolor='w')
            fig.add_subplot(111)
            shap.summary_plot(shap_values, tp_features, plot_size=(5,5),show=False)
            plt.xlabel('SHAP value',fontsize=14)
            plt.show()
        else:
            print("Display SHAP summary plot: Failed. No TP data found.")

    return (df_shap_all_tp, df_shap_all_fn, df_shap_all_fp, df_shap_all_tn,
            df_testval_all_tp, df_testval_all_fn, df_testval_all_fp, df_testval_all_tn)



###
### SHAP Value Calculation ###
def calc_shap_value(df_testval_all: pd.DataFrame, df_shap_all: pd.DataFrame,
                    df_cmlabel_only: pd.DataFrame, explainer: "shap.TreeExplainer", rem_columns: list):
    """Append one confusion-matrix group's rows and SHAP values to the accumulators.

    Parameters
    ----------
    df_testval_all : pd.DataFrame
        Accumulated test values; ``df_cmlabel_only`` is appended to it.
    df_shap_all : pd.DataFrame
        Accumulated SHAP values; the newly computed values are appended to it.
    df_cmlabel_only : pd.DataFrame
        Rows of one confusion-matrix label. Must contain ``Patient_ID`` and
        every column named in ``rem_columns``.
    explainer : shap.TreeExplainer
        Object whose ``shap_values`` method is called on the feature columns.
    rem_columns : list
        Feature column names, in the order the explainer should see them.

    Returns
    -------
    tuple
        ``(df_testval_all, df_shap_all)`` with the new rows appended.
    """
    df_testval_all = pd.concat([df_testval_all, df_cmlabel_only])

    ### Calculation on the feature columns only ###
    # The features are selected by name rather than "all but the last
    # column", so Patient_ID can never be passed to the explainer as a
    # feature, wherever it sits in the frame.
    features = df_cmlabel_only[rem_columns]
    shap_values = explainer.shap_values(features)
    shap_values = pd.DataFrame(shap_values, columns=rem_columns)
    # shap_values has a fresh 0..n-1 index, so the IDs are re-indexed to
    # match before being attached.
    shap_values["Patient_ID"] = df_cmlabel_only.loc[:, "Patient_ID"].reset_index(drop=True)

    df_shap_all = pd.concat([df_shap_all, shap_values])

    return (df_testval_all, df_shap_all)



###
### Save DataFrame ###
###
def save_df(time_point: int, df: pd.DataFrame, val_type: str, cm_label: str, test_items: list):
    """Pickle ``df`` under the results directory for its value type and label.

    The file is written to
    ``.../data/<val_type>_values/<val_type>_values_<cm_label>/<val_type>_values_<cm_label>_tp_<time_point>.pkl``
    after the row index has been reset to 0..n-1.

    Parameters
    ----------
    time_point : int
        Time point, used in the file name.
    df : pd.DataFrame
        Frame to save; the caller's object is not modified.
    val_type : str
        Value type used in directory and file names (callers pass "shap" or "test").
    cm_label : str
        Confusion-matrix label used in directory and file names.
    test_items : list
        Not used in this function; kept for interface compatibility.
    """
    # Local import: os is not imported at the top of this notebook.
    import os

    subdir_name = val_type+"_values_"+cm_label
    filename = val_type+"_values_"+cm_label+"_tp_"+str(time_point)+".pkl"
    out_dir = "/content/drive/MyDrive/res_death_destiny/data/"+val_type+"_values/"+subdir_name

    # Create the target directory if needed so that a missing folder does not
    # abort the save after the whole cross validation has already run.
    os.makedirs(out_dir, exist_ok=True)

    # drop must be the boolean True; the string "True" only worked because
    # any non-empty string is truthy.
    df = df.reset_index(drop=True)
    df.to_pickle(out_dir+"/"+filename)


In [None]:
###
### SHAP value calculation for each time point ###
###
# Range of time points to process (inclusive). The notebook was written for
# time points 1 to 90; only time point 1 is processed here, exactly as before.
FIRST_TIME_POINT = 1
LAST_TIME_POINT = 1   # set to 90 to process the full range
time_points = range(FIRST_TIME_POINT, LAST_TIME_POINT + 1)

# The bar total is derived from the range actually iterated (it used to be a
# fixed 90 and could never complete), and the context manager closes the bar
# when the loop ends or fails.
with tqdm(total=len(time_points)) as bar:
    bar.set_description("Calculating SHAP values >>> ")

    for time_point in time_points:
        shap_value_calsulation_main(time_point=time_point, shap_summary_plot=True)
        bar.update(1)
