In [None]:
from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pickle
import pandas as pd
import torch
from txgnn import TxData
from umap import UMAP
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score, silhouette_samples
from matplotlib.patches import Patch
from visualize_utils import visualize_pipeline, plot

In [None]:
# Predicted protein embeddings (the "GearNet-ProtGNN" set plotted further down).
# NOTE: pickle.load can execute arbitrary code - only open trusted files.
pred_emb_path = '../Data/embeddings/gearnet_protgnn_embeds_noesm.pkl'
with open(pred_emb_path, mode='rb') as f:
    pred_prot_embs = pickle.load(f)

# Fine-tuned ProtGNN embeddings; only the 'gene/protein' entry is used here.
emb_path = '../Data/embeddings/protgnn_finetuned_noesm.pkl'
with open(emb_path, mode='rb') as f:
    embeddings = pickle.load(f)
prot_embs = embeddings['gene/protein']

In [None]:
# Independent 2-D t-SNE projections of the target and the predicted embeddings.
# t-SNE is stochastic: a fixed random_state keeps the layout identical across
# kernel restarts so the figures below can be reproduced.
# NOTE(review): both result names are reassigned by the combined t-SNE cell
# that follows, so these two fits only matter if that cell is skipped.
tsne = TSNE(n_components=2, verbose=0, perplexity=40, n_iter=300, random_state=42)
prot_tsne_results = tsne.fit_transform(prot_embs)

tsne = TSNE(n_components=2, verbose=0, perplexity=40, n_iter=300, random_state=42)
pred_tsne_results = tsne.fit_transform(pred_prot_embs)

In [None]:
# Project target and predicted embeddings in ONE t-SNE fit so both sets share
# the same 2-D coordinate system, then split the rows back apart.
combined_embs = np.vstack((prot_embs, pred_prot_embs))

# Fixed random_state: without it every run yields a different layout.
tsne = TSNE(n_components=2, verbose=0, perplexity=40, n_iter=300, random_state=42)
combined_tsne_results = tsne.fit_transform(combined_embs)

# Rows keep the vstack order: target embeddings first, predicted ones after.
prot_tsne_results = combined_tsne_results[:len(prot_embs)]
pred_tsne_results = combined_tsne_results[len(prot_embs):]

In [None]:
# Fit UMAP on the target embeddings and reuse the fitted model to place the
# predicted embeddings in the same 2-D space.
# random_state pins the otherwise stochastic layout so reruns match
# (umap-learn trades parallelism for this reproducibility).
umap = UMAP(n_components=2, random_state=42)
prot_umap_results = umap.fit_transform(prot_embs)
pred_umap_results = umap.transform(pred_prot_embs)


In [None]:
# Lookup table built from the 'pdb_idx' and 'txgnn_idx' columns of the CSV.
pdb2idx_file = 'helper_files/pdb2txgnnIdx.csv'
pdb2idx_df = pd.read_csv(pdb2idx_file)
pdb2idx_dict = {
    pdb_idx: txgnn_idx
    for pdb_idx, txgnn_idx in zip(pdb2idx_df['pdb_idx'], pdb2idx_df['txgnn_idx'])
}

In [None]:
def get_target_embeds(data, map_dict, dtype = 'tensor', unique = False):
    """Gather entries of ``data`` selected through ``map_dict``.

    Parameters
    ----------
    data : indexable
        Source container; ``data[j]`` is read for every selected index ``j``.
    map_dict : dict
        Maps integer keys to indices into ``data``.
    dtype : str
        ``'tensor'`` stacks the gathered entries with ``torch.stack``,
        ``'numpy'`` stacks them with ``np.stack``; any other value
        (e.g. ``'list'``) returns the plain Python list.
    unique : bool
        If True, return one entry per distinct value of ``map_dict`` in
        first-seen order. If False, return one entry for each key
        ``0 .. max(key)`` in key order, so duplicates are kept.

    Returns
    -------
    torch.Tensor, numpy.ndarray or list, depending on ``dtype``.

    Raises
    ------
    KeyError
        If ``unique`` is False and some integer in ``0 .. max(key)`` is not a
        key of ``map_dict``.
    """
    if unique:
        # dict.fromkeys drops repeated values while preserving their order.
        indices = list(dict.fromkeys(map_dict.values()))
    else:
        # default=-1 makes an empty mapping select nothing instead of
        # failing inside max(), matching the behaviour of the unique branch.
        n_keys = max(map_dict.keys(), default=-1) + 1
        try:
            indices = [map_dict[k] for k in range(n_keys)]
        except KeyError as err:
            raise KeyError(
                f"map_dict has no entry for key {err.args[0]!r}; "
                f"keys must cover every integer in 0..{n_keys - 1}"
            ) from err

    new_data = [data[j] for j in indices]

    if dtype == 'tensor':
        new_data = torch.stack(new_data, dim=0)
    elif dtype == 'numpy':
        new_data = np.stack(new_data, axis=0)

    return new_data

In [None]:
# One target embedding per distinct index referenced by the pdb lookup table.
filtered_prot_embs = get_target_embeds(
    data=prot_embs,
    map_dict=pdb2idx_dict,
    dtype='tensor',
    unique=True,
)

In [None]:
# t-SNE fitted only on the filtered target embeddings (filter first, then project).
# Fixed random_state so the layout is reproducible between runs.
tsne = TSNE(n_components=2, verbose=0, perplexity=40, n_iter=300, random_state=42)
prefiltered_prot_tsne_results = tsne.fit_transform(filtered_prot_embs)

In [None]:
# prot_pca_results is not assigned by any earlier cell, so a fresh
# "Restart & Run All" failed here with NameError. Fit the 2-D PCA projection
# of the target embeddings before it is filtered below.
prot_pca_results = PCA(n_components=2).fit_transform(prot_embs)

# Keep one projected point per distinct index referenced by the lookup table.
unique_prot_tsne_results = get_target_embeds(prot_tsne_results, pdb2idx_dict, dtype='numpy', unique = True)
unique_prot_umap_results = get_target_embeds(prot_umap_results, pdb2idx_dict, dtype='numpy', unique = True)
unique_prot_pca_results = get_target_embeds(prot_pca_results, pdb2idx_dict, dtype='numpy', unique = True)


In [None]:
# Locations of the fine-tuned embedding pickles (MF, BP, and MF+BP variants).
mf_embed_path = '../Data/embeddings/finetune_MF_emb.pkl'
bp_embed_path = '../Data/embeddings/finetune_BP_emb.pkl'
mf_bp_embed_path = '../Data/embeddings/finetune_MF_BP_emb.pkl'

# Knowledge-graph data object handed to visualize_pipeline in later cells.
TxData_inst = TxData(data_folder_path='../Data/PrimeKG/')
TxData_inst.prepare_split(split='random', seed=42, no_kg=False)

In [None]:
# Cluster assignments for biological-process nodes, then the full t-SNE plot.
# The palette returned by plot() is kept so later figures reuse the same colors.
node_type = 'biological_process'
prot_assignments = visualize_pipeline(
    embed_path=mf_bp_embed_path,
    node_type=node_type,
    TxData_inst=TxData_inst,
    kmeans=True,
    filter=None,
    return_clusters=True,
)
color_palette = plot(
    prot_assignments,
    prot_tsne_results,
    num_classes=len(set(prot_assignments)),
    title='Protein Embeddings: Clustered by Secondary Level Biological Processes',
    return_colors=True,
)

In [None]:
# Re-index the BP assignments through the pdb lookup table:
# one entry per lookup key (duplicates kept) and one per distinct index.
new_prot_assignments = get_target_embeds(
    prot_assignments, pdb2idx_dict, dtype='list'
)
unique_prot_assignments = get_target_embeds(
    prot_assignments, pdb2idx_dict, dtype='list', unique=True
)

# Both figures reuse the BP palette returned by the full plot.
plot(
    unique_prot_assignments,
    unique_prot_tsne_results,
    num_classes=len(set(new_prot_assignments)),
    title='Filtered Target Protein Embeddings from ProtGNN: Biological Process',
    colors=color_palette,
)
plot(
    new_prot_assignments,
    pred_tsne_results,
    num_classes=len(set(new_prot_assignments)),
    title='Predicted Protein Embeddings from GearNet-ProtGNN: Biological Process',
    colors=color_palette,
)

In [None]:
# Draws the full biological-process plot again and refreshes color_palette.
color_palette = plot(
    prot_assignments,
    prot_tsne_results,
    num_classes=len(set(prot_assignments)),
    title='Protein Embeddings: Clustered by Secondary Level Biological Processes',
    return_colors=True,
)

In [None]:
# Filtered-target and predicted BP figures, colored with the BP palette.
plot(
    unique_prot_assignments,
    unique_prot_tsne_results,
    num_classes=len(set(new_prot_assignments)),
    title='Filtered Target Protein Embeddings from ProtGNN: Biological Process',
    colors=color_palette,
)
plot(
    new_prot_assignments,
    pred_tsne_results,
    num_classes=len(set(new_prot_assignments)),
    title='Predicted Protein Embeddings from GearNet-ProtGNN: Biological Process',
    colors=color_palette,
)

In [None]:
# Same clustering pipeline as for biological processes, now for molecular functions.
node_type = 'molecular_function'
prot_assignments_mf = visualize_pipeline(
    embed_path=mf_bp_embed_path,
    node_type=node_type,
    TxData_inst=TxData_inst,
    kmeans=True,
    filter=None,
    return_clusters=True,
)

In [None]:
# Re-index the MF assignments through the pdb lookup table:
# one entry per lookup key (duplicates kept) and one per distinct index.
new_prot_assignments_mf = get_target_embeds(
    prot_assignments_mf, pdb2idx_dict, dtype='list'
)
unique_prot_assignments_mf = get_target_embeds(
    prot_assignments_mf, pdb2idx_dict, dtype='list', unique=True
)

In [None]:
# Full molecular-function plot; its palette is kept separately from the BP one.
color_palette_mf = plot(
    prot_assignments_mf,
    prot_tsne_results,
    num_classes=len(set(prot_assignments_mf)),
    title='Protein Embeddings: Clustered by Secondary Level Molecular Functions',
    return_colors=True,
)

In [None]:
# Filtered-target and predicted molecular-function figures.
# Use the MF palette (color_palette_mf) so cluster colors match the full MF
# plot; the BP palette used previously belongs to a different clustering.
plot(
    unique_prot_assignments_mf,
    unique_prot_tsne_results,
    num_classes=len(set(new_prot_assignments_mf)),
    colors=color_palette_mf,
    title='Filtered Target Protein Embeddings from ProtGNN: Molecular Function',
)
plot(
    new_prot_assignments_mf,
    pred_tsne_results,
    num_classes=len(set(new_prot_assignments_mf)),
    colors=color_palette_mf,
    title='Predicted Protein Embeddings from GearNet-ProtGNN: Molecular Function',
)

### Silhouette analysis


In [None]:
def plot_silhouette(pred_tsne_results_filtered, recluster_array, color_palette, title="Silhouette analysis for Function Clusters for Protein Embeddings", range_x = [-1,1], x_ticks=[-1, -0.8, -0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6, 0.8, 1]):
    """Draw a per-cluster silhouette plot and return the mean score of each cluster.

    Parameters
    ----------
    pred_tsne_results_filtered : array-like
        Sample matrix handed to ``silhouette_score`` / ``silhouette_samples``
        (one row per sample).
    recluster_array : numpy.ndarray
        Cluster label of every sample; compared element-wise against each
        palette key, so it must support ``recluster_array == label``.
    color_palette : dict
        Maps cluster label -> color. One band is drawn per key, in key order.
    title : str
        Figure title.
    range_x : sequence of two floats
        Limits of the x axis (silhouette coefficient).
    x_ticks : sequence of floats
        Tick positions on the x axis.

    Returns
    -------
    dict
        Maps each key of ``color_palette`` to the mean silhouette value of the
        samples carrying that label (NaN when no sample has the label).
    """
    band_gap = 10  # blank rows left between two neighbouring cluster bands
    n_bands = len(color_palette)
    # Number of clusters actually present in the labels; previously the
    # message and the axis height were hard-coded for exactly 10 clusters.
    n_clusters = len(np.unique(recluster_array))

    fig, ax1 = plt.subplots(1, 1)
    fig.set_size_inches(12, 8)
    ax1.set_xlim(range_x)
    # Room for every sample plus one gap around each band.
    ax1.set_ylim([0, len(pred_tsne_results_filtered) + (n_bands + 1) * band_gap])

    silhouette_avg = silhouette_score(pred_tsne_results_filtered, recluster_array)
    print(
        "For n_clusters =",
        n_clusters,
        "The average silhouette_score is :",
        silhouette_avg,
    )

    # Compute the silhouette scores for each sample
    sample_silhouette_values = silhouette_samples(pred_tsne_results_filtered, recluster_array)

    y_lower = band_gap
    legend_handles = []
    silhouette_clusters = {}
    for label in color_palette:
        # Scores of the samples in this cluster, sorted ascending so the band
        # is drawn as a smooth profile.
        cluster_values = np.sort(sample_silhouette_values[recluster_array == label])

        size_cluster = cluster_values.shape[0]
        y_upper = y_lower + size_cluster

        # Avoid numpy's "mean of empty slice" warning for labels that appear in
        # the palette but not in the data.
        silhouette_clusters[label] = np.mean(cluster_values) if size_cluster else np.nan

        color = color_palette[label]
        ax1.fill_betweenx(
            np.arange(y_lower, y_upper),
            0,
            cluster_values,
            facecolor=color,
            edgecolor=color,
            alpha=0.7,
        )

        legend_handles.append(Patch(color=color, label=str(label)))

        # Start of the next band, leaving a blank gap after this one.
        y_lower = y_upper + band_gap

    ax1.set_xlabel("Silhouette coefficient values")
    ax1.set_ylabel("Cluster label")
    ax1.set_title(title)

    # The vertical line for average silhouette score of all the values
    ax1.axvline(x=silhouette_avg, color="red", linestyle="--")
    # Bands are stacked bottom-up, so reverse the handles to list them top-down.
    ax1.legend(handles=legend_handles[::-1], loc='lower right', bbox_to_anchor=(1, 0))

    ax1.set_yticks([])  # Clear the yaxis labels / ticks
    ax1.set_xticks(x_ticks)

    plt.show()
    return silhouette_clusters

In [None]:
# Silhouette profile of the BP clusters on the full target embeddings.
silhouette_clusters_bp = plot_silhouette(
    prot_embs,
    np.array(prot_assignments),
    color_palette,
    range_x=[-0.4, 0.5],
    x_ticks=[-0.4, -0.2, 0, 0.2, 0.4, 0.5],
    title='Silhouette Analysis of ProtGNN Protein Embeddings for Biological Processes Clusters',
)

In [None]:
# Silhouette profile of the MF clusters on the full target embeddings.
silhouette_clusters_mf = plot_silhouette(
    prot_embs,
    np.array(prot_assignments_mf),
    color_palette_mf,
    range_x=[-0.3, 0.3],
    x_ticks=[-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3],
    title='Silhouette Analysis of ProtGNN Protein Embeddings for Molecular Function Clusters',
)

In [None]:
# Display the per-cluster mean silhouette values for the biological-process clusters.
silhouette_clusters_bp

In [None]:
# Display the per-cluster mean silhouette values for the molecular-function clusters.
silhouette_clusters_mf

In [None]:
# Overall silhouette score of the BP and MF cluster labels in three latent
# spaces: predicted embeddings, filtered target embeddings, and GearNet embeddings.
# NOTE(review): `gearnet_embs` is not assigned in any cell visible in this
# notebook; unless it is loaded beforehand, the GearNet lines raise
# NameError - TODO confirm where it should come from.
print('Latent Space: Predicted, Cluster: BP',silhouette_score(pred_prot_embs, new_prot_assignments))
print('Latent Space: Target, Cluster: BP',silhouette_score(filtered_prot_embs, unique_prot_assignments))
print('Latent Space: GearNet, Cluster: BP',silhouette_score(gearnet_embs, new_prot_assignments))
print('Latent Space: Predicted, Cluster: MF',silhouette_score(pred_prot_embs, new_prot_assignments_mf))
print('Latent Space: Target, Cluster: MF',silhouette_score(filtered_prot_embs, unique_prot_assignments_mf))
print('Latent Space: GearNet, Cluster: MF',silhouette_score(gearnet_embs, new_prot_assignments_mf))

In [None]:
# Silhouette profile of the MF clusters restricted to the filtered target embeddings.
silhouette_clusters_mf_filtered = plot_silhouette(
    filtered_prot_embs,
    np.array(unique_prot_assignments_mf),
    color_palette_mf,
    range_x=[-0.2, 0.2],
    x_ticks=[-0.2, 0, 0.2],
    title='Silhouette Analysis of ProtGNN Protein Embeddings for Molecular Function Clusters',
)

In [None]:
# Display the per-cluster mean silhouette values for the filtered molecular-function run.
silhouette_clusters_mf_filtered