In [None]:
import pandas as pd
import sys

sys.path.insert(0, '../ProtGNN')

from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pickle
import torch
from collections import Counter
import numpy as np
from sklearn.metrics import silhouette_score, silhouette_samples
from matplotlib.patches import Patch

## Visualizing predicted embedding space

In [None]:
# NOTE: pickle.load can execute arbitrary code while deserialising, so only
# point these paths at files you produced yourself or otherwise trust.

# Predicted protein embeddings (file name suggests GearNet -> ProtGNN, no ESM).
pred_emb_path = '../Data/embeddings/gearnet_protgnn_embeds_noesm.pkl'
with open(pred_emb_path, mode='rb') as pred_file:
    pred_prot_embs = pickle.load(pred_file)

# Fine-tuned ProtGNN embeddings; the pickle is indexed by a string key and
# only the 'gene/protein' entry is used in this notebook.
emb_path = '../Data/embeddings/protgnn_finetuned_noesm.pkl'
with open(emb_path, mode='rb') as emb_file:
    embeddings = pickle.load(emb_file)
prot_embs = embeddings['gene/protein']

In [None]:
# GearNet embeddings (presumably the raw structure-encoder output — confirm).
# NOTE: only unpickle files from a trusted source.
gearnet_path = '../Data/embeddings/gearnet_embeds.pkl'
with open(gearnet_path, mode='rb') as gearnet_file:
    gearnet_embs = pickle.load(gearnet_file)

In [None]:
# 2-D t-SNE projection of `prot_embs` on its own.
# random_state is pinned because t-SNE is stochastic: without a seed every
# Restart & Run All produces a different layout and the figures cannot be
# reproduced.
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter` —
# confirm against the installed version before upgrading.
tsne = TSNE(n_components=2, verbose=0, perplexity=40, n_iter=300, random_state=42)
just_prot_tsne_results = tsne.fit_transform(prot_embs)

In [None]:
# 2-D t-SNE projection of `gearnet_embs` on its own.
# random_state is pinned because t-SNE is stochastic: without a seed every
# Restart & Run All produces a different layout and the figures cannot be
# reproduced.
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter` —
# confirm against the installed version before upgrading.
tsne = TSNE(n_components=2, verbose=0, perplexity=40, n_iter=300, random_state=42)
gearnet_tsne_results = tsne.fit_transform(gearnet_embs)

In [None]:
# Stack the two embedding sets and project them in ONE t-SNE run, so both sets
# end up in the same 2-D coordinate system.
combined_embs = np.vstack((prot_embs, pred_prot_embs))

# random_state is pinned because t-SNE is stochastic: without a seed every
# Restart & Run All produces a different layout and the figures cannot be
# reproduced.
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter` —
# confirm against the installed version before upgrading.
tsne = TSNE(n_components=2, verbose=0, perplexity=40, n_iter=300, random_state=42)
combined_tsne_results = tsne.fit_transform(combined_embs)

# Rows were stacked with `prot_embs` first, so the first len(prot_embs) rows
# belong to it and the remainder to `pred_prot_embs`.
prot_tsne_results = combined_tsne_results[:len(prot_embs)]
pred_tsne_results = combined_tsne_results[len(prot_embs):]

In [None]:
# Lookup table read from CSV: column 'pdb_idx' -> column 'txgnn_idx'.
# If a 'pdb_idx' value occurs more than once, the last row wins.
pdb2idx_file = 'helper_files/pdb2txgnnIdx.csv'
pdb2idx_df = pd.read_csv(pdb2idx_file)
pdb2idx_dict = {
    pdb_idx: txgnn_idx
    for pdb_idx, txgnn_idx in zip(pdb2idx_df['pdb_idx'], pdb2idx_df['txgnn_idx'])
}

def get_target_embeds(data, map_dict, dtype = 'tensor', unique = False):
    """Gather entries of ``data`` selected through ``map_dict``.

    Parameters
    ----------
    data : indexable
        Source container; it is indexed with the *values* of ``map_dict``.
    map_dict : dict
        Mapping whose values are indices into ``data``.
    dtype : str
        ``'tensor'`` stacks the gathered entries with ``torch.stack``,
        ``'numpy'`` stacks them with ``np.stack``; any other value returns
        the plain Python list of gathered entries.
    unique : bool
        If True, each distinct value of ``map_dict`` is used once, in order of
        first appearance. If False, keys ``0 .. max(map_dict)`` are visited in
        ascending order and ``data[map_dict[key]]`` is gathered for each, so
        output position ``i`` corresponds to key ``i``.

    Returns
    -------
    torch.Tensor, numpy.ndarray or list
        The gathered entries, stacked along a new leading axis when ``dtype``
        is ``'tensor'`` or ``'numpy'``.

    Raises
    ------
    KeyError
        With ``unique=False``, if a key in ``0 .. max(map_dict)`` is missing.
    """
    if unique:
        # dict.fromkeys de-duplicates while preserving first-seen order.
        source_indices = list(dict.fromkeys(map_dict.values()))
    else:
        last_key = max(map_dict.keys())
        source_indices = [map_dict[key] for key in range(last_key + 1)]

    gathered = [data[source_idx] for source_idx in source_indices]

    if dtype == 'tensor':
        return torch.stack(gathered, axis=0)
    if dtype == 'numpy':
        return np.stack(gathered, axis=0)
    return gathered

In [None]:
# Re-order the target embeddings so that row i is the embedding mapped to
# PDB index i (duplicates kept); the result is a stacked torch tensor.
filtered_prot_embs = get_target_embeds(
    data=prot_embs,
    map_dict=pdb2idx_dict,
    dtype='tensor',
    unique=False,
)

In [None]:
# Same re-ordering as for the embeddings, applied to their 2-D t-SNE
# coordinates; the result is a stacked numpy array.
filtered_prot_tsne_results = get_target_embeds(
    data=prot_tsne_results,
    map_dict=pdb2idx_dict,
    dtype='numpy',
    unique=False,
)

In [None]:
# CATH superfamily annotations, keyed by the 'PDB IDX' column.
# Path fixed: the original '..Data/...' was missing the separator after '..'
# and did not match the '../Data/...' layout used by every other input file.
pdb2cath_df = pd.read_csv('../Data/cath_superfamily_data.csv')
pdb2cath_dict = dict(zip(pdb2cath_df['PDB IDX'], pdb2cath_df['Superfamily Name']))
pdb2cathid_dict = dict(zip(pdb2cath_df['PDB IDX'], pdb2cath_df['Superfamily ID']))

# One superfamily ID per row of `gearnet_embs`; rows whose index has no entry
# in the annotation table are labelled 'Other'.
cluster_list = [pdb2cathid_dict.get(i, 'Other') for i in range(len(gearnet_embs))]

In [None]:
# Truncate every dotted ID to its first three components (e.g. 'a.b.c.d' ->
# 'a.b.c'). Anything that is not a string — such as the float NaN pandas uses
# for a missing ID — becomes 'Other'. isinstance(..., str) replaces the
# original `type(x) != float` check, which crashed on any other non-string.
consolidated_clusters = [
    '.'.join(cluster.split('.')[:3]) if isinstance(cluster, str) else 'Other'
    for cluster in cluster_list
]

In [None]:
# Human-readable legend labels for the three-component cluster IDs.
# Each label lists three comma-separated descriptors for the ID.
# Spelling fixed in two labels that appear verbatim in figure legends:
# "Immunoglobin" -> "Immunoglobulin" and "Rossman" -> "Rossmann".
cluster2name = {
    '2.60.40': 'Mainly Beta, Sandwich, Immunoglobulin-like',
    '3.40.50': 'Alpha-Beta, 3-layer(aba) Sandwich, Rossmann Fold',
    '2.30.30': 'Mainly Beta, Roll, SH3 Type Barrels',
    '3.30.70': 'Alpha-Beta, 2-layer Sandwich, Alpha-Beta Plaits',
    '3.10.20': 'Alpha-Beta, Roll, Ubiquitin-like',
    '1.10.10': 'Mainly Alpha, Orthogonal Bundle, Arc Repressor Mutant',
    '3.40.30': 'Alpha-Beta, 3-layer(aba) Sandwich, Glutaredoxin',
    '3.30.200': 'Alpha-Beta, 2-layer Sandwich, Phosphorylase Kinase',
    '2.60.120': 'Mainly Beta, Sandwich, Jelly Rolls',
    'Other': 'Other'
}

In [None]:
cluster_counts = Counter(consolidated_clusters)

# NOTE: despite the name, this holds at most 10 entries — the 10 most frequent
# labels minus 'Other'. The name is kept because other cells may reference it.
# `consolidated_clusters` contains only strings, so the former np.nan
# exclusion could never match and has been dropped.
top_19_clusters = [
    cluster
    for cluster, _count in cluster_counts.most_common(10)
    if cluster != 'Other'
]

# Collapse every label outside the top set into 'Other'.
top_cluster_set = set(top_19_clusters)
recluster_list = [
    cluster if cluster in top_cluster_set else 'Other'
    for cluster in consolidated_clusters
]

# Map IDs to display names. A frequent cluster with no entry in `cluster2name`
# falls back to its raw ID instead of raising KeyError.
recluster_list_name = [cluster2name.get(cluster, cluster) for cluster in recluster_list]

In [None]:
def plot(cluster_list, tsne_results, title= 'Protein Embeddings', colors = None, legend="full", other=True, num_classes=12, transform_type='t-SNE', return_colors=False):
    """Scatter-plot a 2-D projection coloured by cluster label.

    Points labelled ``'Other'`` are drawn first with a smaller marker so that
    the named clusters are rendered on top of them.

    Parameters
    ----------
    cluster_list : sequence
        One label per row of ``tsne_results``.
    tsne_results : 2-D array-like
        Coordinates; column 0 is plotted on x and column 1 on y.
    title : str
        Figure title.
    colors : dict, optional
        Mapping label -> colour. When omitted (or empty) a ``tab10`` palette is
        assigned to the labels in order of first appearance and ``'Other'`` is
        set to light grey.
    legend : str or bool
        Forwarded to seaborn for the named clusters (e.g. ``"full"``). A falsy
        value suppresses the legend entirely.
    other : bool
        Whether to draw the ``'Other'`` points at all.
    num_classes : int
        Minimum number of palette colours to generate. The palette is widened
        automatically when there are more distinct labels than this.
    transform_type : str
        Prefix for the axis labels. ``'UMAP'`` additionally selects a smaller
        marker and applies fixed axis limits.
    return_colors : bool
        If True, return the label -> colour mapping that was used.

    Returns
    -------
    dict or None
        The colour mapping when ``return_colors`` is True, otherwise None.
    """
    x_col = f'{transform_type} 1'
    y_col = f'{transform_type} 2'

    df = pd.DataFrame()
    df[x_col] = tsne_results[:,0]
    df[y_col] = tsne_results[:,1]
    df['cluster'] = cluster_list

    if not colors:
        labels = df['cluster'].unique()
        # Never build fewer colours than labels: zip() would silently drop the
        # surplus labels and seaborn then fails on the missing palette keys.
        palette = sns.color_palette("tab10", max(num_classes, len(labels)))
        colors = dict(zip(labels, palette))
        colors['Other'] = (211/255.0, 211/255.0, 211/255.0, 1.0)

    plt.figure(figsize=(10,10))

    marker_size = 15 if transform_type == 'UMAP' else 20
    is_other = df['cluster'] == 'Other'

    # Background layer: the unlabelled points.
    if other:
        sns.scatterplot(
            x=x_col, y=y_col,
            hue='cluster',
            palette=colors,
            data=df[is_other],
            legend=bool(legend),
            alpha=0.6,
            s=10
        )

    # Foreground layer: the named clusters.
    sns.scatterplot(
        x=x_col, y=y_col,
        hue='cluster',
        palette=colors,
        data=df[~is_other],
        legend=legend,
        alpha=0.6,
        s=marker_size
    )

    if transform_type == 'UMAP':
        # Fixed viewport used for UMAP projections.
        plt.xlim((0,20))
        plt.ylim((-2,10))
    plt.title(title)

    if legend:
        # Place the legend below the axes and leave room for it.
        plt.legend(title='Structure Types', loc='upper center', bbox_to_anchor=(0.5, -0.1), ncol=2)
    plt.subplots_adjust(bottom=0.3)
    plt.show()

    if return_colors:
        return colors

In [None]:
# Draw the predicted-embedding map and keep the label -> colour mapping so the
# silhouette plots further down use the same colours.
color_palette = plot(
    cluster_list=recluster_list_name,
    tsne_results=pred_tsne_results,
    title='Predicted Protein Embeddings from GearNet-ProtGNN colored by CAT Structural Family',
    num_classes=10,
    return_colors=True,
)

In [None]:
# Predicted vs. target embeddings, both taken from the joint t-SNE run and
# coloured with the same structural-family labels.
plot(
    cluster_list=recluster_list_name,
    tsne_results=pred_tsne_results,
    title='Predicted Protein Embeddings from GearNet-ProtGNN colored by Protein Structural Family',
    num_classes=10,
)
plot(
    cluster_list=recluster_list_name,
    tsne_results=filtered_prot_tsne_results,
    title='Target Protein Embeddings from ProtGNN colored by Protein Structural Family',
    num_classes=10,
)


## Silhouette Score

In [None]:
def plot_silhouette(pred_tsne_results_filtered, recluster_array, color_palette, title="Silhouette analysis for Function Clusters for Protein Embeddings", range_x = [-1,1], x_ticks=[-1, -0.8, -0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6, 0.8, 1]):
    """Draw a per-cluster silhouette plot and return the mean score per cluster.

    Samples labelled ``'Other'`` are removed before any score is computed.

    Parameters
    ----------
    pred_tsne_results_filtered : 2-D array-like
        Feature matrix, one row per sample. The parameter name is historical:
        whatever is passed is handed to ``silhouette_score`` unchanged, and it
        must support boolean-mask indexing.
    recluster_array : array-like
        One cluster label per sample (converted with ``np.asarray``).
    color_palette : dict
        Mapping label -> colour. Its key order sets the bottom-to-top order of
        the bands. Keys with no remaining samples are skipped; labels absent
        from this mapping are scored but not drawn.
    title : str
        Axes title.
    range_x : sequence of two floats
        x-axis limits.
    x_ticks : sequence of floats
        x-axis tick positions.

    Returns
    -------
    dict
        Mapping label -> mean silhouette coefficient, for each drawn cluster.
    """
    # Accept plain lists as labels; the mask below needs element-wise comparison.
    recluster_array = np.asarray(recluster_array)

    # Drop the catch-all 'Other' group.
    keep_mask = recluster_array != 'Other'
    pred_tsne_results_filtered = pred_tsne_results_filtered[keep_mask]
    recluster_array = recluster_array[keep_mask]

    # Only clusters that still have samples can be drawn. Iterating the palette
    # blindly would hit 'Other' (just filtered out), giving an empty slice and
    # a NaN mean.
    drawn_clusters = [label for label in color_palette if np.any(recluster_array == label)]
    n_clusters = len(np.unique(recluster_array))

    fig, ax1 = plt.subplots(1, 1)
    fig.set_size_inches(12, 8)
    ax1.set_xlim(range_x)
    # Room for every sample plus a 10-unit gap around each band.
    ax1.set_ylim([0, len(pred_tsne_results_filtered) + (len(drawn_clusters) + 1) * 10])

    silhouette_avg = silhouette_score(pred_tsne_results_filtered, recluster_array)
    print(
        "For n_clusters =",
        n_clusters,
        "The average silhouette_score is :",
        silhouette_avg,
    )

    sample_silhouette_values = silhouette_samples(pred_tsne_results_filtered, recluster_array)

    y_lower = 10
    silhouette_clusters = {}
    for label in drawn_clusters:
        # Sorted ascending so each band has the usual knife shape.
        cluster_values = np.sort(sample_silhouette_values[recluster_array == label])

        y_upper = y_lower + cluster_values.shape[0]
        silhouette_clusters[label] = np.mean(cluster_values)

        color = color_palette[label]
        ax1.fill_betweenx(
            np.arange(y_lower, y_upper),
            0,
            cluster_values,
            facecolor=color,
            edgecolor=color,
            alpha=0.7,
        )

        # Leave a 10-unit gap before the next band.
        y_lower = y_upper + 10

    ax1.set_xlabel("Silhouette coefficient values")
    ax1.set_ylabel("Cluster label")
    ax1.set_title(title)

    # Dashed red line marks the average over all retained samples.
    ax1.axvline(x=silhouette_avg, color="red", linestyle="--")

    ax1.set_yticks([])
    ax1.set_xticks(x_ticks)

    plt.show()
    return silhouette_clusters

In [None]:
# Overall silhouette score of the structural-family labels in each latent
# space. Note that 'Other' is kept as a cluster here.
latent_spaces = [
    ('Latent Space: Predicted, Cluster: Structure', pred_prot_embs),
    ('Latent Space: Target, Cluster: Structure', filtered_prot_embs),
    ('Latent Space: GearNet, Cluster: Structure', gearnet_embs),
]
for space_label, space_embs in latent_spaces:
    print(space_label, silhouette_score(space_embs, recluster_list))

In [None]:
# Per-family silhouette plot for the predicted embeddings, coloured with the
# palette returned by the earlier plot() call.
silhouette_clusters_pred = plot_silhouette(
    pred_tsne_results_filtered=pred_prot_embs,
    recluster_array=np.array(recluster_list_name),
    color_palette=color_palette,
    title='Silhouette Analysis of Predicted Protein Embeddings by Protein Structural Family',
    range_x=[-0.4, 0.4],
    x_ticks=[-0.4, -0.2, 0, 0.2, 0.4],
)

In [None]:
# Display the dict returned by plot_silhouette for the predicted embeddings:
# cluster label -> mean silhouette coefficient of that cluster.
silhouette_clusters_pred

In [None]:
# Per-family silhouette plot for the GearNet embeddings.
# Fixed: this cell passed `filtered_prot_embs` (the target embeddings) although
# both the result name and the title refer to GearNet. np.asarray makes sure
# the embeddings support the boolean-mask filtering done in plot_silhouette.
# NOTE(review): range_x / x_ticks were chosen for the previously plotted data —
# confirm they still frame the GearNet scores.
silhouette_clusters_gearnet = plot_silhouette(np.asarray(gearnet_embs), np.array(recluster_list_name), color_palette, range_x = [-0.4, 0.5], x_ticks = [-0.4, -0.2, 0, 0.2, 0.4, 0.5], title='Silhouette Analysis of GearNet-Edge Protein Embeddings by Protein Structural Family')