In [None]:
### Google Colab only: mount Google Drive at /content/drive. ###
### The plotting functions in this notebook load their pickle from under /content/drive/MyDrive by default, ###
### so run this cell first on Colab. Skip it elsewhere (presumably google.colab cannot be imported outside Colab - confirm). ###
from google.colab import drive
drive.mount('/content/drive')

In [None]:
### Module ###
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
###
### SHAP Behavior Visualization ###
###
def Visualize_shap_behavior(cm_label="tp", figsize=(12,6), fontsize=15, n_time_points=90,
                            data_dir="/content/drive/MyDrive/res_death_destiny/data/shap_test_values"):
    """Plot the mean SHAP value of each test item at every time point.

    Loads ``<data_dir>/shap_test_values_norm_<cm_label>.pkl``, keeps the
    ``Test_item``, ``SHAP_value`` and ``Time_point`` columns, and draws one
    line per distinct ``Test_item`` on a single axes. The x axis is inverted,
    so the largest time point is shown on the left.

    Parameters
    ----------
    cm_label : str
        Suffix of the pickle file name to load.
    figsize : tuple
        Figure size forwarded to ``plt.subplots``.
    fontsize : int or float
        Base font size; axis labels, tick labels and the legend are scaled
        from it.
    n_time_points : int
        Time points ``1 .. n_time_points`` are plotted (default 90).
    data_dir : str
        Directory that contains the pickle file.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.

    Notes
    -----
    A time point that has no rows for an item yields NaN and therefore a gap
    in that item's line. Items missing from the color table fall back to
    Matplotlib's default color cycle instead of raising ``KeyError``.
    """

    ### Load SHAP values ###
    ### NOTE: unpickling can execute arbitrary code - only load files you trust ###
    df = pd.read_pickle(data_dir + "/shap_test_values_norm_" + cm_label + ".pkl")
    df = df.loc[:, ["Test_item","SHAP_value","Time_point"]]

    ls_test_item = df["Test_item"].unique().tolist()
    time_points = np.arange(1, n_time_points + 1)

    ### Colors: Please change as appropriate ###
    dict_color = {"Item_1":"r", "Item_2":"orange", "Item_3":"gold"}

    fig, ax_behavior = plt.subplots(figsize=figsize)

    for test_item in ls_test_item:

        df_one_test_item = df.loc[df["Test_item"] == test_item]

        ### Mean SHAP value per time point in a single pass; ###
        ### reindex keeps every time point on the axis (NaN where there is no data) ###
        mean_by_time = (
            df_one_test_item.groupby("Time_point")["SHAP_value"]
            .mean()
            .reindex(time_points)
        )

        ### Plot Behavior ###
        ### color=None (item not in dict_color) lets Matplotlib pick the next cycle color ###
        ax_behavior.plot(time_points, mean_by_time.to_numpy(),
                         linewidth=1, marker="*", markersize=8,
                         color=dict_color.get(test_item), label=test_item)

    ax_behavior.legend(title="Laboratory test items", loc="upper left", bbox_to_anchor=(0.0, 1.0), fontsize=fontsize*1.1, title_fontsize=fontsize*1.2, fancybox=None, framealpha=1)

    ax_behavior.set_xlabel("Time point (Day)", fontsize=fontsize*1.3)
    ax_behavior.set_ylabel("Scaled mean SHAP value", fontsize=fontsize*1.3)

    ### Ticks at day 1 and then every 10th day ###
    ax_behavior.set_xticks([1] + list(range(10, n_time_points + 1, 10)))

    ax_behavior.tick_params(axis='x', labelsize=fontsize*1.1,)
    ax_behavior.tick_params(axis='y', labelsize=fontsize*1.1)
    ax_behavior.grid(color='k', linestyle=':', linewidth=1)
    ax_behavior.invert_xaxis()

    plt.show()



###
### Distribution of SHAP/Test values ###
###
def Visualize_distribution(val_type, cm_label="tp", round=3, figsize=(12,4), fontsize=15):

    ### Load values ###
    df = pd.read_pickle("/content/drive/MyDrive/res_death_destiny/data/shap_test_values/shap_test_values_norm_"+cm_label+".pkl")
    df = df.loc[:, ["Test_item", val_type, "Time_point"]]

    ls_test_item = df["Test_item"].unique().tolist()

    ### Colors: Please change as appropriate ###
    dict_color = {"Item_1":"r", "Item_2":"orange", "Item_3":"gold"}

    for test_item in ls_test_item:

        df_one_test_item = df.query("Test_item == @test_item").copy()

        ymin, ymax = df_one_test_item[val_type].min(), df_one_test_item[val_type].max()

        ### Round time points ###
        ### Tips: Set 'round' argument to group and display several time points together ###
        df_one_test_item["Group"] = (df_one_test_item["Time_point"] - 1) // round + 1
        df_one_test_item = df_one_test_item.drop(["Time_point"], axis="columns").rename(columns={"Group":"Time"})

        ### Plot distributions for each test item ###
        fig, ax_dist = plt.subplots(figsize=figsize)
        sns.boxplot(x="Time", y=val_type, ax=ax_dist, data=df_one_test_item, showfliers=False, color=dict_color[test_item])

        ax_dist.set_xlabel("Time point (×"+str(round)+"Day)", fontsize=fontsize*1.3)
        if val_type == "SHAP_value":
            ax_dist.set_ylabel("Scaled SHAP value", fontsize=fontsize*1.3)
        elif val_type == "Test_value":
            ax_dist.set_ylabel("Test value", fontsize=fontsize*1.3)
        ax_dist.set_title(test_item, fontsize=fontsize*1.2)

        ax_dist.set_ylim(ymin*1.05, ymax*1.05)

        ax_dist.tick_params(axis='x', labelsize=fontsize*1.1)
        ax_dist.tick_params(axis='y', labelsize=fontsize*1.1)
        ax_dist.xaxis.set_tick_params(rotation=90)
        ax_dist.set_xticks([0]+[i for i in range(1, 90//round)])
        ax_dist.grid(color='k', linestyle=':', linewidth=1)
        ax_dist.invert_xaxis()

        plt.show()


In [None]:
###
### SHAP Behavior Visualization ###
###
### Plot the mean SHAP value per time point for every test item, ###
### using the pickle whose file name ends in "_tp" (cm_label presumably names a confusion-matrix cell - confirm). ###
Visualize_shap_behavior(cm_label="tp")

In [None]:
###
### Distribution of SHAP/Test values ###
###
### Same box-plot view for both value columns, grouping every 3 time points into one box ###
for val_type in ("SHAP_value", "Test_value"):
    Visualize_distribution(val_type=val_type, cm_label="tp", round=3)