In [1]:
import os
import pickle
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn import preprocessing
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
#import umap
import umap.umap_ as umap
import torch
import pdb
from IPython.display import display

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
# matplotlib inline

# Parameter Setting

In [2]:
# Notebook-wide seaborn theme: white style, "notebook" context and a
# default figure size of (14, 10).
sns.set(
    style='white',
    context='notebook',
    rc={'figure.figsize': (14, 10)},
)

In [3]:
# Name of the embedding set to analyse (alternative: "embeddings_with_align")
theme = "embeddings"
root_embedding_dir = "/share/liyu/RNA/Data/PDB_RNA-FM_cluster-experiment/{}".format(theme)

# The three embedding variants share one file name and differ only in their sub-directory.
pretrained_embed_file, random_embed_file, onehot_embed_file = (
    os.path.join(root_embedding_dir, sub_dir, "representations-collection-20000-0-left13829.npy")
    for sub_dir in (
        "rna-fm-pretrained_embedding",
        "rna-fm-random-exist_embedding",
        "rna-fm-onehot_embedding",
    )
)

# 1. Load the per-cluster sequence lists and keep only the largest clusters.
# cluster.pickle is located one directory above the embedding directory.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
cluster_ann_pickle = os.path.join(root_embedding_dir, "..", "cluster.pickle")
with open(cluster_ann_pickle, 'rb') as f:
    cluster_dict = pickle.load(f)

# One [cluster_id, number-of-entries] row per cluster
cluster_stats = [[cluster_id, len(members)] for cluster_id, members in cluster_dict.items()]

# Sort by size (largest first) and truncate to the first max_clusters rows
max_clusters = 2
cluster_df = (
    pd.DataFrame(cluster_stats, columns=["cluster_id", "number"])
    .sort_values(by="number", ascending=False)
    .iloc[:max_clusters, :]
)
print("We only choose the first {} large cluster".format(max_clusters))
display(cluster_df)

# Flatten the selected clusters into one [seq_name, cluster_name] row per member sequence
instance_stats = [
    [member_name, selected_cluster]
    for selected_cluster in cluster_df["cluster_id"].values
    for member_name in cluster_dict[selected_cluster]
]
instance_df = pd.DataFrame(instance_stats, columns=["seq_name", "cluster_name"])
display(instance_df)


# Display label of each embedding variant -> path of its .npy collection
file_list = dict([
    ("RNA-FM-pretrained", pretrained_embed_file),
    ("RNA-FM-random", random_embed_file),
    ("RNA-FM-onehot", onehot_embed_file),
])

# Output directory for the generated figures.
# exist_ok=True keeps the cell safe to re-run and avoids the gap between
# checking for the directory and creating it.
save_dir = os.path.join(root_embedding_dir, "figure")
os.makedirs(save_dir, exist_ok=True)

#a = np.load(pretrained_embed_file, allow_pickle=True).item()
#a['3ja1_LB'].mean(axis=0).shape

We only choose the first 2 large cluster


Unnamed: 0,cluster_id,number
215,215,633
76,76,425


Unnamed: 0,seq_name,cluster_name
0,4v9q_DV,215
1,4v9q_BV,215
2,4v9q_DW,215
3,4v9q_BW,215
4,6cfj_1x,215
...,...,...
1053,4v5n_BB,76
1054,4v5m_BB,76
1055,5a9z_AB,76
1056,5aa0_AB,76


# Load Data & Create DataFrames Including Embeddings and Labels

In [4]:
# 1. Load every embedding collection and convert it into a DataFrame
#    (one row per sequence, feature columns followed by a "cluster" label column).
df_dict = {}
for key, embedding_file in file_list.items():
    # .item() unwraps the pickled object stored in the .npy file; it is indexed by sequence name below.
    # NOTE(review): allow_pickle=True can execute arbitrary code while loading; only use trusted files.
    embedding_dict = np.load(embedding_file, allow_pickle=True).item()

    embeddings = []
    clusters = []
    missing_seq_names = []

    for seq_name, cluster_name in zip(instance_df["seq_name"], instance_df["cluster_name"]):
        # Only a missing sequence name is an expected, skippable condition.
        # Any other exception indicates a real problem and must surface
        # instead of being silently swallowed by a bare `except`.
        try:
            seq_embedding = embedding_dict[seq_name]
        except KeyError:
            missing_seq_names.append(seq_name)
            continue
        embeddings.append(seq_embedding)
        clusters.append(cluster_name)

    if missing_seq_names:
        print("{}: skipped {} sequences without an embedding".format(key, len(missing_seq_names)))

    df = pd.DataFrame(embeddings)
    df["cluster"] = clusters

#     # join with clan information
#     df = pd.merge(df, family_clan_df, left_on="family", right_on="Accession", how="inner")
#     #display(df.groupby("Clan")["family"].value_counts())

#     # filter non-gene
#     df = df[df["Type"].str.contains("Gene")]
#     df["Family"] = df["family"] + "(" + df["ID"] + ")"

    # Drop rows that contain missing values before storing the frame
    df = df.dropna()
    df_dict[key] = df
    display(key)
    display(df)

'RNA-FM-pretrained'

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,631,632,633,634,635,636,637,638,639,cluster
0,-0.016244,-0.119352,0.166237,0.182518,-0.160660,-0.043056,-0.061392,0.008281,-0.288199,0.147157,...,0.049340,-0.284668,-0.400485,0.267467,-0.082507,0.121810,-0.037574,0.477329,-0.024944,215
1,-0.016244,-0.119352,0.166237,0.182518,-0.160660,-0.043056,-0.061392,0.008281,-0.288199,0.147157,...,0.049340,-0.284668,-0.400485,0.267467,-0.082507,0.121810,-0.037574,0.477329,-0.024944,215
2,-0.016244,-0.119352,0.166237,0.182518,-0.160660,-0.043056,-0.061392,0.008281,-0.288199,0.147157,...,0.049340,-0.284668,-0.400485,0.267467,-0.082507,0.121810,-0.037574,0.477329,-0.024944,215
3,-0.016244,-0.119352,0.166237,0.182518,-0.160660,-0.043056,-0.061392,0.008281,-0.288199,0.147157,...,0.049340,-0.284668,-0.400485,0.267467,-0.082507,0.121810,-0.037574,0.477329,-0.024944,215
4,-0.016244,-0.119352,0.166237,0.182518,-0.160660,-0.043056,-0.061392,0.008281,-0.288199,0.147157,...,0.049340,-0.284668,-0.400485,0.267467,-0.082507,0.121810,-0.037574,0.477329,-0.024944,215
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1052,-0.013484,-0.124481,0.293459,0.032662,-0.381013,0.031851,0.082504,-0.030405,-0.117953,-0.284297,...,-0.089157,-0.126622,-0.128834,-0.053730,0.055334,-0.048758,-0.146522,0.092958,-0.409054,76
1053,-0.009093,-0.120671,0.305129,0.030201,-0.381458,0.034400,0.103066,-0.020420,-0.116489,-0.276178,...,-0.074612,-0.132630,-0.112402,-0.060486,0.071479,-0.029586,-0.160617,0.103573,-0.418223,76
1054,-0.009093,-0.120671,0.305129,0.030201,-0.381458,0.034400,0.103066,-0.020420,-0.116489,-0.276178,...,-0.074612,-0.132630,-0.112402,-0.060486,0.071479,-0.029586,-0.160617,0.103573,-0.418223,76
1055,-0.013484,-0.124481,0.293459,0.032662,-0.381013,0.031851,0.082504,-0.030405,-0.117953,-0.284297,...,-0.089157,-0.126622,-0.128834,-0.053730,0.055334,-0.048758,-0.146522,0.092958,-0.409054,76


'RNA-FM-random'

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,631,632,633,634,635,636,637,638,639,cluster
0,0.572214,-1.214358,-0.188109,0.222100,-1.606857,0.443725,-0.518534,1.020193,-0.527351,-0.200326,...,-1.563706,0.034105,-0.513810,0.397318,-0.248958,-0.299501,0.166479,1.099971,0.846392,215
1,0.572214,-1.214358,-0.188109,0.222100,-1.606857,0.443725,-0.518534,1.020193,-0.527351,-0.200326,...,-1.563706,0.034105,-0.513810,0.397318,-0.248958,-0.299501,0.166479,1.099971,0.846392,215
2,0.572214,-1.214358,-0.188109,0.222100,-1.606857,0.443725,-0.518534,1.020193,-0.527351,-0.200326,...,-1.563706,0.034105,-0.513810,0.397318,-0.248958,-0.299501,0.166479,1.099971,0.846392,215
3,0.572214,-1.214358,-0.188109,0.222100,-1.606857,0.443725,-0.518534,1.020193,-0.527351,-0.200326,...,-1.563706,0.034105,-0.513810,0.397318,-0.248958,-0.299501,0.166479,1.099971,0.846392,215
4,0.572214,-1.214358,-0.188109,0.222100,-1.606857,0.443725,-0.518534,1.020193,-0.527351,-0.200326,...,-1.563706,0.034105,-0.513810,0.397318,-0.248958,-0.299501,0.166479,1.099971,0.846392,215
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1052,0.313894,-1.665140,-0.384488,0.187454,-1.616546,0.198687,-0.616450,0.774985,-0.539722,0.076037,...,-1.657815,0.132780,-0.438964,0.260487,-0.366507,-0.129856,-0.147658,1.015535,0.920925,76
1053,0.288179,-1.667371,-0.419578,0.199051,-1.622811,0.177805,-0.629422,0.759017,-0.555982,0.109812,...,-1.657301,0.150668,-0.414556,0.266523,-0.413512,-0.140830,-0.136770,1.017482,0.917709,76
1054,0.288179,-1.667371,-0.419578,0.199051,-1.622811,0.177805,-0.629422,0.759017,-0.555982,0.109812,...,-1.657301,0.150668,-0.414556,0.266523,-0.413512,-0.140830,-0.136770,1.017482,0.917709,76
1055,0.313894,-1.665140,-0.384488,0.187454,-1.616546,0.198687,-0.616450,0.774985,-0.539722,0.076037,...,-1.657815,0.132780,-0.438964,0.260487,-0.366507,-0.129856,-0.147658,1.015535,0.920925,76


'RNA-FM-onehot'

Unnamed: 0,0,1,2,3,cluster
0,0.181818,0.155844,0.337662,0.324675,215
1,0.181818,0.155844,0.337662,0.324675,215
2,0.181818,0.155844,0.337662,0.324675,215
3,0.181818,0.155844,0.337662,0.324675,215
4,0.181818,0.155844,0.337662,0.324675,215
...,...,...,...,...,...
1052,0.195122,0.138211,0.308943,0.357724,76
1053,0.196721,0.131148,0.311475,0.360656,76
1054,0.196721,0.131148,0.311475,0.360656,76
1055,0.195122,0.138211,0.308943,0.357724,76


# Statistics

In [5]:
# family_df = df["Family"]
# family_count_df = family_df.value_counts()
# display(family_count_df)

# Subsample

In [6]:
# selected_family_df = family_count_df[family_count_df.iloc[:]>500]
# print("selected_familes:")
# display(selected_family_df)
# selected_family_names = list(selected_family_df.index)
# #print(selected_family_names)

# print("selected_clans:")
# selected_clan_names = ["CL00002", "CL00011"]
# selected_clan_df = df[df["Clan"].isin(selected_clan_names)].groupby("Clan")["Family"].value_counts()
# display(selected_clan_df)


# for key in df_dict.keys():
#     df = df_dict[key]    
#     #sub_df = df[(df["Family"].isin(selected_family_names))]
#     #sub_df = df[(df["Clan"].isin(selected_clan_names))]
#     sub_df = df[(df["Family"].isin(selected_family_names)) | (df["Clan"].isin(selected_clan_names))]    
#     df_dict[key] = sub_df
    
#     #print(key)
#     #display(sub_df["Family"].value_counts())
#     #display(sub_df["Clan"].value_counts())
#     print(key, sub_df.shape)

# Plot UMAP

# Generate UMAP embedding for visualization

In [None]:
# UMAP parameters (the trailing comments keep the alternative values noted in this cell)
n_neighbors, min_dist = 100, 0.5   # alternatives: 50, 0.2
n_components, metric = 2, 'euclidean'

random_state = 10   # alternative: 2022

visual_df_dict = {}
for key, df in df_dict.items():
    # Feature matrix: every column except the last one; "cluster" supplies the labels
    embeddings = df.iloc[:, :-1].values
    print(embeddings.shape)
    hue_labels = df["cluster"].values
    style_labels = None  # df.loc[:, "Clan"].values

    # Normalization: rescale the features into the (0, 1) range
    min_max_scaler = preprocessing.MinMaxScaler(feature_range=(0, 1))
    scaled_embeddings = min_max_scaler.fit_transform(embeddings)

    # UMAP dimension reduction down to n_components dimensions
    fit = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=n_components,
        metric=metric,
        random_state=random_state,
    )
    umap_embedding = fit.fit_transform(scaled_embeddings)

    # Keep the projected coordinates together with their labels for plotting
    visual_df = pd.DataFrame(umap_embedding, columns=['x', 'y'])
    visual_df['cluster'] = hue_labels
    # visual_df['style'] = style_labels

    visual_df_dict[key] = visual_df

(1057, 640)
(1057, 640)


# Visualize Embeddings

In [None]:
"""
# old version - deprecated
from PIL import Image

RGB格式颜色转换为16进制颜色格式
def RGB_to_Hex(rgb):
    rgb = list(rgb) 
    RGB = [i* 255 for i in rgb]
    color = '#'
    for i in RGB:
        num = int(i)
        # 将R、G、B分别转化为16进制拼接转换并大写  hex() 函数用于将10进制整数转换成16进制，以字符串形式表示
        color += str(hex(num))[-2:].replace('x', '0').upper()
    return color

def generate_color_dict(labels):
    color_dict = {}
    rgb_list = sns.color_palette("husl", len(labels))
    
    for i in range(len(rgb_list)):
        temp = RGB_to_Hex(rgb_list[i])
        #print(labels[i])    
        color_dict[labels[i]] = temp

    print(color_dict)
    return color_dict
    
color_dict = generate_color_dict(selected_family_names)

def visualize_embeddings(visual_df, color_dict, title, save_path):
    plt.figure(figsize=(8,8))    

    sub_handles = []
    key_list = []
    for key in color_dict.keys():           
        sub_df = visual_df[visual_df['label']== key]        
        sub_handle = plt.scatter(sub_df['x'], sub_df['y'], c = color_dict[key], s = 10)# 10 #0.1)
        sub_handles.append(sub_handle)
        key_list.append(key) 
   
    plt.title(title, fontsize=18)    
    plt.legend(sub_handles, key_list, bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0)
    plt.xticks(())
    plt.yticks(())
    
    if save_path is not None:
        plt.savefig(save_path, dpi=600, bbox_inches='tight') 
"""

def visualize_embeddings(visual_df, hue=None, style=None, title="", save_path=None):
    """Draw a scatter plot of 2-D embedding coordinates and optionally save it.

    Parameters
    ----------
    visual_df : DataFrame with the columns "x" and "y", plus the columns
        named by ``hue`` / ``style`` when those are given.
    hue : name of the column mapped to marker color, or None.
    style : name of the column mapped to marker style, or None.
    title : text used as the axes title.
    save_path : file path the figure is written to (dpi=300, tight bounding
        box); nothing is saved when None.
    """
    fig, ax = plt.subplots(figsize=(8, 8))

    # Pass the frame by keyword so the call does not depend on `data`
    # being accepted as the first positional argument.
    sns.scatterplot(data=visual_df, x="x", y="y", hue=hue, style=style, s=100, ax=ax)
    ax.set_title(title)

    # Without hue/style there is no legend on the axes, and move_legend
    # raises ValueError in that case, so only relocate an existing legend.
    if ax.get_legend() is not None:
        sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))

    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        
# Draw one scatter plot per embedding variant and save it as a PDF inside save_dir
for key, visual_df in visual_df_dict.items():
    save_path = os.path.join(save_dir, "cluster_{}_{}.pdf".format(theme, key))
    visualize_embeddings(
        visual_df,
        "cluster",
        title='{}:n_{};d_{}'.format(key, n_neighbors, min_dist),
        save_path=save_path,
    )