In [None]:
### Please use this when running code in Google Colab.
### Mounts Google Drive at /content/drive; the data files read by the functions below
### are addressed under /content/drive/MyDrive/, so this cell has to run first.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
### Module ###
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.cm import ScalarMappable

In [None]:
###
### SHAP Behavior Visualization ###
### Note: Partially modified function as used in 4_SHAP_behavior_visualization.ipynb ###
###
def Visualize_shap_behavior(subtype_val_type, subtype, cm_label="tp", figsize=(12,6), fontsize=15, n_time_points=90):
    """Plot the mean SHAP value at every time point, one line per test item, for one subtype.

    Parameters
    ----------
    subtype_val_type : str
        Prefix of the subtype-label column used for filtering; the column
        ``subtype_val_type + "_subtype"`` of the merged table is compared with ``subtype``.
    subtype
        Value of that column to keep.
    cm_label : str, default "tp"
        Suffix of the pickle file that is loaded (``shap_test_values_norm_<cm_label>.pkl``).
    figsize : tuple, default (12, 6)
        Figure size passed to ``plt.subplots``.
    fontsize : int or float, default 15
        Base font size; labels, ticks and legend use fixed multiples of it.
    n_time_points : int, default 90
        Time points ``1..n_time_points`` are drawn on the x axis.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()`` and then closed.

    Notes
    -----
    A time point without any row for a test item gets a NaN mean.
    The x axis is inverted, so the largest time point is drawn on the left.
    """
    ### NOTE(review): pd.read_pickle can execute arbitrary code; only load trusted files ###
    data_dir = "/content/drive/MyDrive/res_death_destiny/data/shap_test_values/"

    ### Load SHAP values ###
    df = pd.read_pickle(data_dir+"shap_test_values_norm_"+cm_label+".pkl")
    df = df.loc[:, ["Patient_ID", "Test_item", "SHAP_value", "Time_point"]]

    ### Subtype labeling ###
    df_label = pd.read_pickle(data_dir+"patient_id_subtype_label.pkl")
    df = pd.merge(df, df_label, on="Patient_ID").drop(["Patient_ID"], axis=1)

    ### Subtype filtering ###
    df = df.rename(columns={subtype_val_type+"_subtype": "Plot_subtype"})
    df = df.query("Plot_subtype == @subtype")

    ls_test_item = df["Test_item"].unique().tolist()

    ### Colors: Please change as appropriate ###
    dict_color = {"Item_1":"r", "Item_2":"orange", "Item_3":"gold"}

    time_points = np.arange(1, n_time_points+1)

    fig, ax_behavior = plt.subplots(figsize=figsize)

    for test_item in ls_test_item:

        df_one_test_item = df.query("Test_item == @test_item")

        ### Mean SHAP value per time point in one pass; reindex keeps every time point ###
        ### on the axis and leaves NaN where the item has no rows ###
        mean_by_time = df_one_test_item.groupby("Time_point")["SHAP_value"].mean().reindex(time_points)

        ### Items missing from dict_color fall back to matplotlib's color cycle instead of raising KeyError ###
        color_kwargs = {"color": dict_color[test_item]} if test_item in dict_color else {}

        ### Plot Behavior ###
        ax_behavior.plot(time_points, mean_by_time.values,
                         linewidth=1, marker="*", markersize=8, label=test_item, **color_kwargs)

    ax_behavior.legend(title="Laboratory test items", loc="upper left", bbox_to_anchor=(0.0, 1.0), fontsize=fontsize*1.1, title_fontsize=fontsize*1.2, fancybox=None, framealpha=1)

    ax_behavior.set_xlabel("Time point (Day)", fontsize=fontsize*1.3)
    ax_behavior.set_ylabel("Scaled mean SHAP value", fontsize=fontsize*1.3)

    ### Ticks at the first time point and at every 10th one ###
    ax_behavior.set_xticks([1]+list(range(10, n_time_points+1, 10)))

    ax_behavior.tick_params(axis='x', labelsize=fontsize*1.1)
    ax_behavior.tick_params(axis='y', labelsize=fontsize*1.1)
    ax_behavior.grid(color='k', linestyle=':', linewidth=1)
    ax_behavior.invert_xaxis()

    plt.show()
    ### Release the figure so repeated calls do not accumulate open figures ###
    plt.close(fig)



###
### Distribution of SHAP/Test values ###
### Note: Partially modified function as used in 4_SHAP_behavior_visualization.ipynb ###
###
def Visualize_distribution(val_type, subtype_val_type, subtype, cm_label="tp", round=3, figsize=(12,4), fontsize=15, n_time_points=90):
    """Draw, for one subtype, a box plot of ``val_type`` over grouped time points for every test item.

    One figure is produced per test item found in the filtered data.

    Parameters
    ----------
    val_type : str
        Name of the value column to plot (the calls in this file use "SHAP_value").
    subtype_val_type : str
        Prefix of the subtype-label column used for filtering; the column
        ``subtype_val_type + "_subtype"`` of the merged table is compared with ``subtype``.
    subtype
        Value of that column to keep.
    cm_label : str, default "tp"
        Suffix of the pickle file that is loaded (``shap_test_values_norm_<cm_label>.pkl``).
    round : int, default 3
        Number of consecutive time points merged into one box
        (time point ``t`` goes to group ``(t - 1) // round + 1``).
    figsize : tuple, default (12, 4)
        Figure size passed to ``plt.subplots``.
    fontsize : int or float, default 15
        Base font size; labels, ticks and title use fixed multiples of it.
    n_time_points : int, default 90
        Largest time point expected in the data; only used to compute the number of x ticks.

    Returns
    -------
    None
        Each figure is displayed with ``plt.show()`` and then closed.

    Raises
    ------
    ValueError
        If ``round`` is smaller than 1.
    """
    if round < 1:
        raise ValueError("'round' must be a positive integer, got "+str(round))

    ### NOTE(review): pd.read_pickle can execute arbitrary code; only load trusted files ###
    data_dir = "/content/drive/MyDrive/res_death_destiny/data/shap_test_values/"

    ### Load values ###
    df = pd.read_pickle(data_dir+"shap_test_values_norm_"+cm_label+".pkl")
    df = df.loc[:, ["Patient_ID", "Test_item", val_type, "Time_point"]]

    ### Subtype labeling ###
    df_label = pd.read_pickle(data_dir+"patient_id_subtype_label.pkl")
    df = pd.merge(df, df_label, on="Patient_ID").drop(["Patient_ID"], axis=1)

    ### Subtype filtering ###
    df = df.rename(columns={subtype_val_type+"_subtype": "Plot_subtype"})
    df = df.query("Plot_subtype == @subtype")

    ls_test_item = df["Test_item"].unique().tolist()

    ### Colors: Please change as appropriate ###
    dict_color = {"Item_1":"r", "Item_2":"orange", "Item_3":"gold"}

    ### Group index of the last time point = number of boxes, also when 'round' ###
    ### does not divide n_time_points evenly ###
    n_groups = (n_time_points - 1) // round + 1

    for test_item in ls_test_item:

        df_one_test_item = df.query("Test_item == @test_item").copy()

        ymin, ymax = df_one_test_item[val_type].min(), df_one_test_item[val_type].max()

        ### Round time points ###
        ### Tips: Set 'round' argument to group and display several time points together ###
        df_one_test_item["Time"] = (df_one_test_item["Time_point"] - 1) // round + 1
        df_one_test_item = df_one_test_item.drop(["Time_point"], axis="columns")

        ### Plot distributions for each test item ###
        ### Items missing from dict_color get seaborn's default color instead of raising KeyError ###
        fig, ax_dist = plt.subplots(figsize=figsize)
        sns.boxplot(x="Time", y=val_type, ax=ax_dist, data=df_one_test_item, showfliers=False, color=dict_color.get(test_item))

        ax_dist.set_xlabel("Time point (×"+str(round)+"Day)", fontsize=fontsize*1.3)
        ax_dist.set_ylabel("Scaled_"+val_type, fontsize=fontsize*1.3)
        ax_dist.set_title(test_item, fontsize=fontsize*1.2)

        ### Pad by 5% of the data range on both sides. Multiplying the limits by 1.05 ###
        ### would cut into the data whenever ymin > 0 or ymax < 0. ###
        if np.isfinite(ymin) and np.isfinite(ymax):
            y_margin = (ymax - ymin)*0.05 or 0.05
            ax_dist.set_ylim(ymin - y_margin, ymax + y_margin)

        ax_dist.tick_params(axis='x', labelsize=fontsize*1.1)
        ax_dist.tick_params(axis='y', labelsize=fontsize*1.1)
        ax_dist.xaxis.set_tick_params(rotation=90)
        ### One tick per box; box positions start at 0 ###
        ax_dist.set_xticks(list(range(n_groups)))
        ax_dist.grid(color='k', linestyle=':', linewidth=1)
        ax_dist.invert_xaxis()

        plt.show()
        ### Release the figure; one figure is created per test item ###
        plt.close(fig)



###
### UMAP Visualization ###
### Note: Partially modified function as used in "5_Visualize_SHAP_behavior_via_UMAP.ipynb" ###
###
def Visualize_SHAP_behavior_via_UMAP(val_type, subtype_val_type, subtype, cm_label="tp", fontsize=15, cbar_labelsize=10, figsize=7, alpha=0.3, size=10):
    """Scatter-plot the UMAP embedding of one subtype, colored by time point.

    Parameters
    ----------
    val_type : str
        Prefix of the embedding columns; ``val_type + "_umap1"`` and
        ``val_type + "_umap2"`` are plotted. Either "SHAP" or "Test".
    subtype_val_type : str
        Prefix of the subtype-label column used for filtering; the column
        ``subtype_val_type + "_subtype"`` is compared with ``subtype``.
    subtype
        Value of that column to keep.
    cm_label : str, default "tp"
        Value of the ``Cm_label`` column to keep.
    fontsize : int or float, default 15
        Base font size; title, labels and ticks use fixed multiples of it.
    cbar_labelsize : int or float, default 10
        Base size of the colorbar tick labels (doubled when applied).
    figsize : int or float, default 7
        Side length of the square figure.
    alpha : float, default 0.3
        Opacity of the scatter markers.
    size : int or float, default 10
        Marker size of the scatter plot.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()`` and then closed.

    Raises
    ------
    ValueError
        If no row matches the requested subtype and ``cm_label``.
    """
    ### Dataset ###
    ### NOTE(review): pd.read_pickle can execute arbitrary code; only load trusted files ###
    filename = "shap_test_values_all_after_UMAP_standardized_subtyped.pkl"
    df = pd.read_pickle("/content/drive/MyDrive/res_death_destiny/data/shap_test_values/"+filename)

    # Tips: The "val_type" argument is either "SHAP" or "Test"
    c_1 = val_type+"_umap1"
    c_2 = val_type+"_umap2"

    ### Axis limits come from the unfiltered data, so every subtype is drawn in the same frame ###
    (xmin, xmax, ymin, ymax) = (df[c_1].min(), df[c_1].max(), df[c_2].min(), df[c_2].max())

    x_bks = (xmax - xmin)*0.05
    y_bks = (ymax - ymin)*0.05

    df = df.rename(columns={subtype_val_type+"_subtype": "Plot_subtype"})
    df = df.query("Plot_subtype == @subtype")
    df_cm_label = df.query("Cm_label==@cm_label")

    label = np.array(df_cm_label.loc[:,"Time_point"])
    if label.size == 0:
        ### Fail before a figure is created, with a message naming the empty selection ###
        raise ValueError("No data for subtype("+subtype_val_type+"): "+str(subtype)+" (CM_label: "+str(cm_label)+")")

    ### One normalization shared by the scatter and the colorbar, so both always map ###
    ### time points to the same colors (NaN time points are ignored for the limits) ###
    norm = colors.Normalize(vmin=np.nanmin(label), vmax=np.nanmax(label))
    cm_name = "viridis_r"
    ### 90 discrete color levels -- presumably one per time point (1-90); TODO confirm ###
    cmap = plt.get_cmap(cm_name,90)

    fig, ax = plt.subplots(figsize=(figsize,figsize))

    ### Scatter Plot ###
    ax.scatter(df_cm_label.loc[:,c_1], df_cm_label.loc[:,c_2], alpha=alpha, s=size, c=label, cmap=cmap, norm=norm)

    ax.set_title("Subtype("+subtype_val_type+"): "+str(subtype)+" (CM_label: "+str(cm_label)+")", fontsize=fontsize*1.2)
    ax.set_xlabel("umap1 ("+val_type+")", fontsize=fontsize*1.3)
    ax.set_ylabel("umap2 ("+val_type+")", fontsize=fontsize*1.3)

    ax.tick_params(axis='x', labelsize=fontsize*1.5)
    ax.tick_params(axis='y', labelsize=fontsize*1.5)

    ax.set_xlim(xmin-x_bks, xmax+x_bks)
    ax.set_ylim(ymin-y_bks, ymax+y_bks)

    ax.grid(linestyle=":", linewidth=1)

    ### Colorbar ###
    ### Placed in its own axes right of the figure area, aligned with the main axes. ###
    ### A dedicated ScalarMappable drives it instead of the scatter artist ###
    ### (presumably so the bar is not drawn with the markers' alpha). ###
    axpos = ax.get_position()
    cbar_ax = fig.add_axes([1.0, axpos.y0, 0.05, axpos.height])
    mappable = ScalarMappable(cmap=cmap,norm=norm)
    mappable.set_array([])
    pp = fig.colorbar(mappable, cax=cbar_ax)

    pp.set_label(label="Time_point", size=fontsize*1.2)

    pp.ax.tick_params(labelsize=cbar_labelsize*2.0)

    plt.show()
    ### Release the figure so repeated calls do not accumulate open figures ###
    plt.close(fig)


In [None]:
### Subtype Visualization ###

### General: Each argument controls the following:
### "val_type": Input data type. Please select "SHAP" or "Test" (For Visualize_distribution, "SHAP_value" or "Test_value").
### "subtype_val_type": Type of subtype label to be applied for filtering. Please select "SHAP" or "Test".
### "subtype": Subtype to be visualized.

### Subtypes drawn by every plot below ###
subtype_ids = (1, 2)

### UMAP of SHAP values: first for the subtypes identified by SHAP values, ###
### then for the subtypes identified by Test values ###
for label_source in ("SHAP", "Test"):
    for subtype_id in subtype_ids:
        Visualize_SHAP_behavior_via_UMAP(val_type="SHAP", subtype_val_type=label_source, subtype=subtype_id, cm_label="tp")

### SHAP behaviors for each subtype identified by SHAP values ###
for subtype_id in subtype_ids:
    Visualize_shap_behavior(subtype_val_type="SHAP", subtype=subtype_id, cm_label="tp")

### SHAP value distribution for each subtype identified by SHAP values ###
for subtype_id in subtype_ids:
    Visualize_distribution(val_type="SHAP_value", subtype_val_type="SHAP", subtype=subtype_id, cm_label="tp")