In [None]:
### Google Colab only: mount Google Drive at /content/drive.
### The data paths used below all live under /content/drive/MyDrive, so run this cell before the others.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
### Imports ###
import math
import os

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster, set_link_color_palette

In [None]:
###
### Run all ###
###
### Thresholds: choose the cut heights after inspecting the dendrograms this function draws ###
def Subtype_classification_main(threshold_shap=10, threshold_testval=10):
    """Cluster the time-point-1 'tp' rows and assign SHAP- and test-value subtypes.

    Pipeline: hierarchical clustering -> one dendrogram per feature space
    (with a dashed cut line at its threshold) -> flat subtype labels ->
    print the resulting tables.

    Args:
        threshold_shap: Cut height applied to the SHAP-value dendrogram.
        threshold_testval: Cut height applied to the test-value dendrogram.

    Returns:
        None. Results are printed; Subtype_labeling also writes them to disk.
    """
    header, coords, shap_vals, test_vals, shap_tree, test_tree = Clustering_shap_testval()

    # Same plot settings for both feature spaces; only data, title and cut differ.
    dendrogram_specs = (
        (shap_tree, "SHAP value", threshold_shap),
        (test_tree, "Test value", threshold_testval),
    )
    for tree, title, cut in dendrogram_specs:
        Dendrogram(tree, val_type=title, threshold=cut, fontsize=20)

    labels, labeled = Subtype_labeling(shap_tree, test_tree, threshold_shap, threshold_testval,
                                       header, coords, shap_vals, test_vals)

    # Quick visual check of the outputs.
    for table in (labels, labeled):
        print(table)



###
### Clustering using time point 1 values of TP data ###
###
def Clustering_shap_testval(data_dir="/content/drive/MyDrive/res_death_destiny/data/shap_test_values/",
                            filename="shap_test_values_all_after_UMAP_standardized.pkl"):
    """Ward-cluster the Time_point == 1, Cm_label == 'tp' rows on SHAP values and on test values.

    Args:
        data_dir: Directory containing the input pickle (defaults to the Colab Drive location).
        filename: Name of the pickled DataFrame of SHAP values and standardised test values.

    Returns:
        (X_header, X_c, X_shap, X_testval, Z_shap, Z_testval), all built from the same
        selected rows in the same order:
            X_header  -- "Time_point", "Cm_label", "Patient_ID" columns.
            X_c       -- "SHAP_umap1", "SHAP_umap2", "Test_umap1", "Test_umap2" columns.
            X_shap    -- every column whose name contains "shap" (case-sensitive).
            X_testval -- every column whose name contains "test" (case-sensitive).
            Z_shap, Z_testval -- scipy linkage matrices (Ward method, Euclidean metric).

    Raises:
        ValueError: if fewer than two rows match the selection (a hierarchy needs >= 2 observations).
    """
    ### Dataset: SHAP values & standardised test values ###
    # NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
    df = pd.read_pickle(os.path.join(data_dir, filename))

    # Case-sensitive substring match: the capitalised "SHAP_umap*" / "Test_umap*"
    # columns do NOT match, so the UMAP coordinates are not clustering features.
    shap_cols = [col for col in df.columns if "shap" in col]
    testval_cols = [col for col in df.columns if "test" in col]

    # Filter once; every output is sliced from this one subset so rows line up.
    df_tp1 = df.query("Time_point == 1 and Cm_label == 'tp'")

    X_header = df_tp1.loc[:, ["Time_point", "Cm_label", "Patient_ID"]]
    X_c = df_tp1.loc[:, ["SHAP_umap1", "SHAP_umap2", "Test_umap1", "Test_umap2"]]
    X_shap = df_tp1.loc[:, shap_cols]
    X_testval = df_tp1.loc[:, testval_cols]

    if len(df_tp1) < 2:
        raise ValueError(
            "Hierarchical clustering needs at least 2 rows with Time_point == 1 and "
            "Cm_label == 'tp'; found {} in {!r}.".format(len(df_tp1), filename)
        )

    ### Hierarchical clustering (Ward linkage on Euclidean distances) ###
    Z_shap = linkage(X_shap.to_numpy(), metric="euclidean", method="ward")
    Z_testval = linkage(X_testval.to_numpy(), metric="euclidean", method="ward")

    return (X_header, X_c, X_shap, X_testval, Z_shap, Z_testval)



###
### Dendrogram ###
###
def Dendrogram(Z, val_type, threshold, fontsize=20, colors=("red", "blue")):
    """Plot a dendrogram of linkage matrix `Z` with a dashed line at the cut height.

    Clusters formed below `threshold` are coloured from `colors`; links at or
    above it are grey.

    Args:
        Z: scipy linkage matrix.
        val_type: Text inserted into the title, e.g. "SHAP value".
        threshold: Cut height; used as dendrogram's color_threshold and drawn
            as a red dashed horizontal line.
        fontsize: Font size for the title and the y-axis label/ticks.
        colors: Link colours for the below-threshold clusters. scipy reuses
            them cyclically when there are more clusters than colours.

    Returns:
        The dict returned by scipy's `dendrogram` (leaf order, link colours, ...).
    """
    _, ax = plt.subplots(figsize=(14, 8))

    # set_link_color_palette changes scipy's process-global palette; restore the
    # default afterwards (even on error) so other dendrogram calls are unaffected.
    set_link_color_palette(list(colors))
    try:
        # no_labels=True hides the (unreadable) per-leaf x-axis labels.
        dend = dendrogram(Z, color_threshold=threshold, ax=ax,
                          above_threshold_color="grey", no_labels=True)
    finally:
        set_link_color_palette(None)

    ax.axhline(y=threshold, color="red", linestyle="--", linewidth=1)
    ax.set_ylabel("Distance", fontsize=fontsize)
    ax.tick_params(axis="y", labelsize=fontsize)
    ax.set_title("Dendrogram (" + val_type + ")", fontsize=fontsize)

    plt.show()
    return dend



###
### Subtype labeling ###
###
def Subtype_labeling(Z_shap, Z_testval, threshold_shap, threshold_testval,
                     X_header, X_c, X_shap, X_testval,
                     data_dir="/content/drive/MyDrive/res_death_destiny/data/shap_test_values/"):
    """Cut both dendrograms into flat subtypes, attach them to the full dataset and save.

    Args:
        Z_shap, Z_testval: Linkage matrices built from the rows of `X_header`, in the same order.
        threshold_shap, threshold_testval: Cut heights for fcluster(criterion="distance").
        X_header: Frame with a "Patient_ID" column, one row per clustered observation.
        X_c, X_shap, X_testval: Accepted for interface compatibility; not used here.
        data_dir: Directory holding the input pickle; both outputs are written there as well.

    Returns:
        (id_subtype_label, df_labeled):
            id_subtype_label -- "Patient_ID", "SHAP_subtype", "Test_subtype" (fcluster ids).
            df_labeled       -- rows of the full dataset whose Patient_ID was clustered,
                                with both subtype columns merged in (inner join).

    Raises:
        pandas.errors.MergeError: if a Patient_ID occurs more than once in `X_header`,
            which would otherwise silently duplicate rows of the full dataset.
    """
    ### Generate labels ###
    # One flat cluster id per row of the linkage input (i.e. per row of X_header).
    subtype_shap = fcluster(Z_shap, t=threshold_shap, criterion="distance")
    subtype_testval = fcluster(Z_testval, t=threshold_testval, criterion="distance")

    # Explicit copy so adding columns below cannot raise SettingWithCopyWarning.
    id_subtype_label = X_header.loc[:, ["Patient_ID"]].copy()
    id_subtype_label["SHAP_subtype"] = subtype_shap
    id_subtype_label["Test_subtype"] = subtype_testval

    ### Labeling ###
    filename = "shap_test_values_all_after_UMAP_standardized.pkl"
    # NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
    df = pd.read_pickle(os.path.join(data_dir, filename))
    # Each Patient_ID must map to exactly one subtype row; duplicates would fan out rows.
    df_labeled = pd.merge(df, id_subtype_label, on="Patient_ID",
                          how="inner", validate="many_to_one")

    ### Save DataFrame ###
    filename_id_label = "patient_id_subtype_label.pkl"
    id_subtype_label.to_pickle(os.path.join(data_dir, filename_id_label))

    filename_labeled_df = "shap_test_values_all_after_UMAP_standardized_subtyped.pkl"
    df_labeled.to_pickle(os.path.join(data_dir, filename_labeled_df))

    return (id_subtype_label, df_labeled)


In [None]:
###
### Run all ###
###
# Dendrogram cut heights. Set them by inspecting the dendrograms drawn by this call
# (the dashed lines show the current cuts), then re-run the cell.
THRESHOLD_SHAP = 40
THRESHOLD_TESTVAL = 15
Subtype_classification_main(threshold_shap=THRESHOLD_SHAP, threshold_testval=THRESHOLD_TESTVAL)