# ESM Embeddings

Given a sequence such as ABCDE, use ESM to get the embedding of each amino acid. Assume that the mutBPE tokenizer segments this sequence into AB CDE and that our mutator is able to mutate AB into AN. Then get the embeddings for the mutated sequence ANCDE.
- Compare the embeddings of N and B. We expect the change in embedding to be smaller than for any "non-mutBPE" mutation of B. It is also possible to consider the score or the observed frequency of this mutation during the training of mutBPE: do the mutation scores or frequencies correlate positively with the change in the embedding vectors?
- Observe the changes in the embeddings of the other amino acids after the mutation, optionally only "up to a range" from the mutated position. How local or global is the impact of this particular mutation?

In [1]:
import os
# Uncomment to restrict this notebook to GPU 0.
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Point the Hugging Face cache at the shared directory. This cell runs before the
# `transformers` import in the next cell, so the variable is already set by then.
os.environ["HF_HOME"] = "/cta/share/users/esm"

In [2]:
from time import time
import sqlite3
import pandas as pd
from tqdm import tqdm
import numpy as np
from tokenizers import Tokenizer
import json
from collections import Counter
from transformers import AutoTokenizer, AutoModel
import torch
from tqdm import tqdm
from protein_embedding_database import ProteinEmbeddingDatabase
from EfficientBPE.vocabulary_functions import get_mutated, get_parents, set_difference, set_intersection, load_tokenizers, calc_agreement, calc_dice_idx_only

In [3]:
# Report the CUDA setup. `current_device()` and `get_device_name()` raise when no
# GPU is visible, so they are only queried when CUDA is actually available; on a
# GPU machine the displayed tuple is the same as before.
cuda_available = torch.cuda.is_available()
(
    cuda_available,
    torch.cuda.device_count(),
    torch.cuda.current_device() if cuda_available else None,
    torch.cuda.get_device_name(0) if cuda_available else None,
)

(True, 2, 0, 'NVIDIA RTX A6000')

## Get Embeddings

In [296]:
# Tokenizer/model pairs already loaded in this kernel, keyed by checkpoint name.
_ESM_MODEL_CACHE = {}


def _load_esm(model_name):
    """Return the (tokenizer, model) pair for `model_name`, loading it only once."""
    if model_name not in _ESM_MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
        # Inference only: make sure dropout is disabled.
        model.eval()
        _ESM_MODEL_CACHE[model_name] = (tokenizer, model)
    return _ESM_MODEL_CACHE[model_name]


def get_embeddings(text: str, model_name: str = "facebook/esm2_t6_8M_UR50D"):
    """
    Compute embeddings for each token in the text using a specified model.

    The tokenizer and model are cached per `model_name`, so repeated calls
    (e.g. one per mutated sequence) do not reload the checkpoint every time.

    Parameters:
    - text (str): The input text for which embeddings need to be computed.
    - model_name (str): Name or path of the pretrained model.

    Returns:
    - numpy.ndarray: A matrix where each row is the embedding of a token in the text
      (the first and last positions, <cls> and <eos>, are dropped).
    """
    tokenizer, model = _load_esm(model_name)

    # No truncation/padding: a single sequence is encoded at its full length.
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Drop the <cls> and <eos> positions, then remove the batch axis of size 1.
    return outputs.last_hidden_state[:, 1:-1, :].squeeze(0).numpy()

In [310]:
# Sanity check on a short 18-residue peptide with the esm2_t30_150M checkpoint.
# NOTE(review): the embedding database loaded further down uses
# facebook/esm2_t6_8M_UR50D, so these vectors are not directly comparable with it —
# confirm which checkpoint is intended.
aa = get_embeddings("MKWVTFISLLLLFSSAYS", model_name="facebook/esm2_t30_150M_UR50D")

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [311]:
# (number of residues, hidden size); <cls>/<eos> were already stripped by get_embeddings.
aa.shape

(18, 640)

## Load DB

In [270]:
# Build the paths of the FAISS index and its id-mapping CSV for the chosen
# checkpoint, then load the embedding database from them.
model_name = "facebook/esm2_t6_8M_UR50D"
root_path = "/cta/share/users/uniprot/human/faiss"
model_tag = model_name.replace('/', '_')  # checkpoint name with "/" replaced, used in file names
faiss_path = f"{root_path}/{model_tag}_protein_embeddings.faiss"
id_map_path = f"{root_path}/{model_tag}_id_mapping.csv"
loaded_db = ProteinEmbeddingDatabase.load_database(faiss_path, id_map_path, model_name)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [279]:
# Retrieve embedding for a specific amino acid
# The id is "<UniProt accession>_<position>"; presumably positions are 0-based —
# confirm against ProteinEmbeddingDatabase.
amino_acid_id = "P12345_0"  # First amino acid of P12345
embedding = loaded_db.get_amino_acid_embedding(amino_acid_id)
embedding

array([-4.33536470e-01,  1.08182698e-01,  7.22915351e-01,  8.09530437e-01,
       -1.68629721e-01, -3.37319136e-01,  2.20036283e-01, -4.06025648e-01,
       -3.45300995e-02,  6.42152607e-01,  5.28250158e-01,  3.64216790e-02,
       -5.65152586e-01, -7.89133534e-02, -8.40457201e-01, -2.66285449e-01,
        2.25273985e-02,  2.46186495e-01, -1.99174836e-01,  5.00954762e-02,
       -6.57048151e-02, -5.58021069e-01, -3.60115975e-01, -7.39178155e-03,
        3.92815083e-01, -3.10629513e-03,  1.40860766e-01,  4.58809912e-01,
       -7.36962184e-02, -3.18919033e-01,  6.67334378e-01, -2.52795011e-01,
       -1.68448556e-02,  6.68131053e-01,  7.13799834e-01, -4.80622858e-01,
       -4.00493264e-01,  4.53898162e-01, -8.88479128e-02,  2.53326744e-01,
        3.68214339e-01, -1.33581683e-01,  6.44511700e-01,  5.48404872e-01,
       -7.02140212e-01,  4.46021050e-01, -3.54176372e-01,  2.21346363e-01,
       -6.59330860e-02, -3.03493470e-01, -5.11039972e-01, -1.79900214e-01,
        2.08203066e-02,  

## Load Tokenizers

In [36]:
# Option space of the available tokenizers:
#   'dataset':           {'uniref50', 'uniref90'}
#   'is_pretokenizer':   {True, False}
#   'subs_matrix':       {'blosum45', 'blosum62', 'pam70', 'pam250'}
#   'mutation_cutoff':   {0.7, 0.8, 0.9}
#   'min_mutation_freq': {0, 0.05, 0.005}
#   'min_mutation_len':  {3}
#   'max_mutation_len':  {12}
#   'vocab_size':        [800, 1600, 3200, 6400, 12800, 25600, 51200]

vocab_sizes = [800, 3200, 12800]
uniref_id = "50"
uniref_dataset = f'uniref{uniref_id}'

# Standard BPE baseline (no mutations, no pre-tokenizer).
std_bpe_opts = {
    'is_mut': False,
    'dataset': uniref_dataset,
    'is_pretokenizer': False,
    'vocab_size': vocab_sizes,
}

# Reference mutBPE configuration; the other mutBPE variants override one field each.
mut_blosum62_opts = {
    'is_mut': True,
    'dataset': uniref_dataset,
    'is_pretokenizer': False,
    'subs_matrix': 'blosum62',
    'mutation_cutoff': 0.7,
    'min_mutation_freq': 0.05,
    'min_mutation_len': 3,
    'max_mutation_len': 12,
    'vocab_size': vocab_sizes,
}
mut_pam70_opts = {**mut_blosum62_opts, 'subs_matrix': 'pam70'}
mut_pre_blosum62_opts = {**mut_blosum62_opts, 'is_pretokenizer': True}

# Configurations explored earlier and currently disabled:
#   - standard BPE with 'is_pretokenizer': True
#   - mutBPE blosum62 with 'min_mutation_freq': 0 at 'mutation_cutoff' 0.7 and 0.9
tokenizer_opts_list = [
    std_bpe_opts,
    mut_blosum62_opts,
    mut_pam70_opts,
    mut_pre_blosum62_opts,
]

In [37]:
# Load every configured tokenizer twice: once as a tokenizer object ('hf') and once
# as its vocabulary ('vocab'). Both results are used as dicts keyed by tokenizer
# name — presumably the display name built by load_tokenizers; confirm there.
tokenizer_list = load_tokenizers(tokenizer_opts_list, 'hf')
inner_vocab_list = load_tokenizers(tokenizer_opts_list, 'vocab')

# Token strings per tokenizer. The vocab keys are already unique, so no set() is
# needed; sorting gives an order that is stable across kernel restarts (the old
# list(set(...)) order changed from run to run).
vocab_list = {}
for name, tokenizer in tokenizer_list.items():
    vocab_list[name] = sorted(tokenizer.get_vocab())

## Read Datasets

In [38]:
# Connect to DB
db_file = "/cta/share/users/uniprot/human/human.db"
conn = sqlite3.connect(db_file)
try:
    # Proteins whose accession appears in the distilled UniRef table; sequences of
    # 3000 residues or more are dropped.
    df_protein = pd.read_sql(f"""SELECT Entry as uniprot_id, Sequence as sequence
                              FROM proteins
                              WHERE Entry IN (SELECT uniprot_accession FROM uniref{uniref_id}_distilled)""", conn)
    df_protein = df_protein[df_protein['sequence'].str.len() < 3000].reset_index(drop=True)

    # df_protein_pre = pd.read_sql(f"SELECT * FROM uniref{uniref_id}_domain_sliced_plddt70", conn)
    # df_protein_pre = df_protein_pre[~df_protein_pre['uniprot_id'].isin(df_protein[df_protein['sequence'].str.len() > 3000]['uniprot_id'].unique())]

    # Domain annotations: InterPro entries of type 'domain' and TED entries with pLDDT >= 70.
    df_interpro_domain = pd.read_sql("SELECT uniprot_id, interpro_id as source, start_index, end_index FROM interpro_entries_v2 WHERE type='domain'", conn)
    df_ted = pd.read_sql("SELECT uniprot_id, ted_id as source, start_index, end_index FROM ted_entries_summary WHERE plddt >= 70", conn)
finally:
    # Always release the connection, even if one of the queries fails.
    conn.close()

In [None]:
# Pool the InterPro and TED domain annotations.
df_domains = pd.concat([df_interpro_domain, df_ted])

# For proteins that have at least one InterPro ("IPR...") annotation, drop their
# TED ("AF...") annotations.
# NOTE: the IPR-only filter further down removes every "AF" row anyway, so this
# step does not change the final table; it only matters if TED entries are re-enabled.
interpro_ids = df_domains.loc[df_domains["source"].str.startswith("IPR"), "uniprot_id"].unique()
is_redundant_ted = df_domains["uniprot_id"].isin(interpro_ids) & df_domains["source"].str.startswith("AF")
df_domains = df_domains[~is_redundant_ted]

# Attach the full protein sequence to every annotation; the inner join drops
# annotations of proteins that are not in df_protein.
df_domains = (
    df_protein.set_index('uniprot_id')
    .join(df_domains.set_index('uniprot_id'), how='inner')
    .reset_index()
)

# Cut the domain out of the full sequence. The slice [start - 1:end] treats the
# coordinates as 1-based and inclusive — TODO confirm against the source tables.
# A plain zip is used instead of a row-wise DataFrame.apply, which builds a Series
# for every row and is much slower on a table of this size.
df_domains['domain_sequence'] = [
    sequence[start - 1:end]
    for sequence, start, end in zip(
        df_domains['sequence'], df_domains['start_index'], df_domains['end_index']
    )
]

# Keep non-empty domains that come from InterPro only.
keep_row = (df_domains['domain_sequence'].str.len() > 0) & df_domains['source'].str.startswith('IPR')
df_domains = df_domains.loc[keep_row, ['uniprot_id', 'source', 'domain_sequence']].reset_index(drop=True)
df_domains

Unnamed: 0,uniprot_id,source,domain_sequence
0,A0A087X296,IPR000742,PVNPCCYYPCQHQGICVRFGLDRYQCDCTRTGYSGPNCT
1,A0A0K2S4Q6,IPR003599,PSTVMGAVGESLSVQCRYEEKYKTFNKYWCRQPCLPIWHEMVETGG...
2,A0A0K2S4Q6,IPR007110,PGCLTVSGPSTVMGAVGESLSVQCRYEEKYKTFNKYWCRQPCLPIW...
3,A0A0K2S4Q6,IPR013106,GPSTVMGAVGESLSVQCRYEEKYKTFNKYWCRQPCLPIWHEMVETG...
4,A0A3B3ISZ0,IPR001206,AQVKKATVFLNPAACKGKARTLFEKNAAPILHLSGMDVTIVKTDYE...
...,...,...,...
99798,X6RHN7,IPR028889,KGLSNEPGQNSCFLNSALQVLWHLDIFRRSFRQLTTHKCMGDSCIF...
99799,X6RIL1,IPR005302,RPRRPHQIADLFRPKDQIAYSDTSPFLILSEASLADLNSRLEKKVK...
99800,X6RK39,IPR025946,LPRVLRVCSGVYFEGSIYEISGNECCLSTGDLIKVTQVRLQKVVCE...
99801,X6RK39,IPR025946,ILEVPEGRPIFLSPWVGSLQKGQRLCVYGLASPPWRVLASSKGRKV...


In [40]:
# Tokenize every domain sequence with every tokenizer; each tokenizer contributes
# one column (named after the tokenizer) holding the list of token strings.
domain_sequences = df_domains['domain_sequence'].tolist()
for name, tokenizer in tqdm(tokenizer_list.items()):
    encodings = tokenizer.encode_batch(domain_sequences)
    df_domains[name] = [encoding.tokens for encoding in encodings]

100%|██████████| 12/12 [00:14<00:00,  1.21s/it]


In [41]:
# Preview: the domain sequence next to one token-list column per tokenizer.
df_domains.head()

Unnamed: 0,uniprot_id,source,domain_sequence,stdBPE 800,stdBPE 3200,stdBPE 12800,mutBPE blosum62 0.7 0.05 800,mutBPE blosum62 0.7 0.05 3200,mutBPE blosum62 0.7 0.05 12800,mutBPE pam70 0.7 0.05 800,mutBPE pam70 0.7 0.05 3200,mutBPE pam70 0.7 0.05 12800,mutBPE pre blosum62 0.7 0.05 800,mutBPE pre blosum62 0.7 0.05 3200,mutBPE pre blosum62 0.7 0.05 12800
0,A0A087X296,IPR000742,PVNPCCYYPCQHQGICVRFGLDRYQCDCTRTGYSGPNCT,"[PV, N, PC, C, YY, PC, QH, QG, I, CV, RF, GLD,...","[PVN, PCC, YY, PC, QH, QGI, CV, RF, GLD, RYQ, ...","[PVN, PCC, YY, PC, QH, QGI, CV, RF, GLD, RYQ, ...","[PV, N, PC, C, YY, PC, QH, QG, I, CV, RF, GL, ...","[PV, N, PC, C, YY, PC, QH, QGI, CV, RF, GLD, R...","[PV, NPC, C, YY, PC, QH, QGI, CV, RF, GLD, RY,...","[PV, N, PC, C, YY, PC, QH, QG, I, CV, RF, GL, ...","[PV, N, PC, C, YY, PC, QH, QGI, CV, RF, GLD, R...","[PV, NPC, CYY, PC, QH, QGI, CV, RF, GLD, RY, Q...","[PV, N, PC, C, YY, PC, QH, QG, I, CV, RF, GL, ...","[PV, N, PC, C, YY, PC, QH, QGI, CV, RF, GLD, R...","[PV, NPC, C, YY, PC, QH, QGI, CV, RF, GLD, RY,..."
1,A0A0K2S4Q6,IPR003599,PSTVMGAVGESLSVQCRYEEKYKTFNKYWCRQPCLPIWHEMVETGG...,"[PST, V, MG, AVG, ESL, SV, QC, RY, EE, KY, KT,...","[PST, V, MG, AVG, ESL, SV, QC, RY, EE, KY, KT,...","[PSTV, MG, AVG, ESL, SV, QC, RY, EE, KY, KT, F...","[PST, V, MG, AVG, ESL, SV, QC, RY, EE, KY, KT,...","[PST, VMG, AVG, ESL, SV, QC, RY, EE, KY, KT, F...","[PST, VMG, AVG, ESL, SV, QC, RY, EE, KY, KT, F...","[PST, V, MG, AV, G, ESL, SV, QC, RY, EE, KY, K...","[PST, VMG, AVG, ESL, SV, QC, RY, EE, KY, KT, F...","[PST, VMG, AVG, ESL, SV, QC, RY, EE, KY, KT, F...","[PST, V, MG, AVG, ESL, SV, QC, RY, EE, KY, KT,...","[PST, VMG, AVG, ESL, SV, QC, RY, EE, KY, KT, F...","[PST, VMG, AVG, ESL, SV, QC, RY, EE, KY, KT, F..."
2,A0A0K2S4Q6,IPR007110,PGCLTVSGPSTVMGAVGESLSVQCRYEEKYKTFNKYWCRQPCLPIW...,"[PG, CL, TV, SG, PST, V, MG, AVG, ESL, SV, QC,...","[PG, CL, TV, SG, PST, V, MG, AVG, ESL, SV, QC,...","[PG, CL, TVSG, PSTV, MG, AVG, ESL, SV, QC, RY,...","[PG, CL, TV, SG, PST, V, MG, AVG, ESL, SV, QC,...","[PG, CL, TV, SG, PST, VMG, AVG, ESL, SV, QC, R...","[PG, CL, TVSG, PST, VMG, AVG, ESL, SV, QC, RY,...","[PG, CL, TV, SG, PST, V, MG, AV, G, ESL, SV, Q...","[PG, CL, TV, SG, PST, VMG, AVG, ESL, SV, QC, R...","[PG, CL, TVSG, PST, VMG, AVG, ESL, SV, QC, RY,...","[PG, CL, TV, SG, PST, V, MG, AVG, ESL, SV, QC,...","[PG, CL, TV, SG, PST, VMG, AVG, ESL, SV, QC, R...","[PG, CL, TVSG, PST, VMG, AVG, ESL, SV, QC, RY,..."
3,A0A0K2S4Q6,IPR013106,GPSTVMGAVGESLSVQCRYEEKYKTFNKYWCRQPCLPIWHEMVETG...,"[G, PST, V, MG, AVG, ESL, SV, QC, RY, EE, KY, ...","[G, PST, V, MG, AVG, ESL, SV, QC, RY, EE, KY, ...","[G, PSTV, MG, AVG, ESL, SV, QC, RY, EE, KY, KT...","[GP, ST, V, MG, AVG, ESL, SV, QC, RY, EE, KY, ...","[GP, STV, MG, AVG, ESL, SV, QC, RY, EE, KY, KT...","[GP, STV, MG, AVG, ESL, SV, QC, RY, EE, KY, KT...","[G, PST, V, MG, AV, G, ESL, SV, QC, RY, EE, KY...","[G, PST, VMG, AVG, ESL, SV, QC, RY, EE, KY, KT...","[GPST, VMG, AVG, ESL, SV, QC, RY, EE, KY, KT, ...","[GP, ST, V, MG, AVG, ESL, SV, QC, RY, EE, KY, ...","[GP, STV, MG, AVG, ESL, SV, QC, RY, EE, KY, KT...","[GP, STV, MG, AVG, ESL, SV, QC, RY, EE, KY, KT..."
4,A0A3B3ISZ0,IPR001206,AQVKKATVFLNPAACKGKARTLFEKNAAPILHLSGMDVTIVKTDYE...,"[A, QV, KKA, TV, FL, NP, AA, C, KG, KA, RTL, F...","[A, QV, KKA, TV, FL, NP, AAC, KG, KA, RTL, F, ...","[A, QV, KKA, TVFL, NP, AAC, KGKA, RTL, F, EKN,...","[A, QV, KK, A, TV, FL, NP, AA, C, KG, KA, R, T...","[AQV, KK, ATV, FL, NP, AAC, KG, KA, RTL, F, EK...","[AQV, KK, ATV, FL, NP, AAC, KG, KA, RTLF, EK, ...","[A, QV, KK, ATV, FL, NP, AA, C, KG, KA, RTL, F...","[AQV, KK, ATV, FL, NP, AA, CKG, KA, RTL, FEK, ...","[AQV, KK, ATV, FL, NPAA, CKG, KA, RTL, FEK, NA...","[A, QV, KK, A, TV, FL, NP, AA, C, KG, KA, R, T...","[AQV, KK, ATV, FL, NP, AA, C, KG, KA, RTL, FEK...","[AQV, KK, ATV, FL, NP, AAC, KG, KA, RTL, FEK, ..."
