# ESM Embeddigs

Given a sequence such as ABCDE, using ESM, get the embeddings for each aminoacid. Assume that mutbpe tokenizer segments this sequence into AB CDE and we know that our mutator is able to mutate AB into AN. Get the embeddings for sequence ANCDE. 
- Compare the embeddings of N and B. We expect the change in embeddings should be smaller compared to say any "non-mutbpe" mutation for B. It is possible to also consider the score or the observed frequency of this mutation during the training of mutBPE. Do the mutation scores or frequencies correlate positively with the change in embedding vectors?
- Observe the changes in embeddings of other aminoacids after the mutation. Also possible to consider "up to a range". How local or a global impact does this particular mutation have?

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["HF_HOME"] = "/cta/share/users/esm"

In [2]:
from time import time
import sqlite3
import pandas as pd
from tqdm import tqdm
import numpy as np
from tokenizers import Tokenizer
import json
from collections import Counter
from transformers import AutoTokenizer, AutoModel
import torch
from tqdm import tqdm
from protein_embedding_database import ProteinEmbeddingDatabase
from EfficientBPE.vocabulary_functions import get_mutated, get_parents, set_difference, set_intersection, load_tokenizers, calc_agreement, calc_dice_idx_only

In [3]:
torch.cuda.is_available(), torch.cuda.device_count(), torch.cuda.current_device(), torch.cuda.get_device_name(0)

(True, 1, 0, 'NVIDIA RTX A6000')

## Get Embeddings

In [5]:
# facebook/model_name
# Checkpoint name	    Num layers  Num parameters  Dim count   Have DB
# esm2_t48_15B_UR50D	48	        15B             -           No
# esm2_t36_3B_UR50D	    36          3B              -           No
# esm2_t33_650M_UR50D	33          650M            1280        Yes
# esm2_t30_150M_UR50D	30          150M            640         Yes
# esm2_t12_35M_UR50D	12          35M             480         No
# esm2_t6_8M_UR50D	    6           8M              320         Yes
model_name = "facebook/esm2_t30_150M_UR50D"

In [150]:
# def get_embeddings(text, model_name="facebook/esm2_t6_8M_UR50D"):
#     """
#     Compute embeddings for each token in the text using a specified model.
    
#     Parameters:
#     - text (str): The input text for which embeddings need to be computed.
#     - model_name (str): The path to the pretrained model.
    
#     Returns:
#     - numpy.ndarray: A matrix where each row is the embedding of a token in the text.
#     """
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#     model = AutoModel.from_pretrained(model_name)

#     # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=1024)
#     inputs = tokenizer(text, return_tensors="pt")
#     with torch.no_grad():
#         outputs = model(**inputs)

#     # Return embeddings after removing <cls> and <eos> tokens and converting to numpy.
#     return outputs.last_hidden_state[:, 1:-1, :].squeeze(0).numpy()

def get_embeddings_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    return tokenizer, model

def get_embeddings(text, tokenizer, model):
    # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=1024)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Return embeddings after removing <cls> and <eos> tokens and converting to numpy.
    return outputs.last_hidden_state[:, 1:-1, :].squeeze(0).numpy()

In [157]:
tokenizer, model = get_embeddings_model(model_name)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [158]:
get_embeddings("PVNPCCYYPC", tokenizer, model).shape

(10, 640)

## Load ESM Embeddings Faiss DB

In [None]:
def load_protein_embedding_db(model_name, root_path="/cta/share/users/uniprot/human/faiss"):
    faiss_path = f"{root_path}/{model_name.replace('/', '_')}_protein_embeddings.faiss"
    id_map_path = f"{root_path}/{model_name.replace('/', '_')}_id_mapping.csv"
    loaded_db = ProteinEmbeddingDatabase.load_database(faiss_path, id_map_path, model_name)

    return loaded_db

loaded_db = load_protein_embedding_db(model_name)

In [14]:
def get_aminoacid_embedding(uniprot_id, aa_index):
    return loaded_db.get_amino_acid_embedding(f"{uniprot_id}_{aa_index}")

def get_protein_embedding(uniprot_id, sequence_len):
    return np.array([loaded_db.get_amino_acid_embedding(f"{uniprot_id}_{i}")for i in range(sequence_len)])

def get_protein_embedding_slice(uniprot_id, start_index, end_index):
    return np.array([loaded_db.get_amino_acid_embedding(f"{uniprot_id}_{i}")for i in range(start_index, end_index)])

## Load Datasets

In [57]:
# Connect to DB
db_file = "/cta/share/users/uniprot/human/human.db"
conn = sqlite3.connect(db_file)

uniref_id = '50'
df_protein = pd.read_sql(f"""SELECT Entry as uniprot_id, Sequence as sequence
                          FROM proteins
                          WHERE Entry IN (SELECT uniprot_accession FROM uniref{uniref_id}_distilled)""", conn)
df_protein = df_protein[df_protein['sequence'].str.len() < 3000].reset_index(drop=True)

df_interpro_domain = pd.read_sql(f"SELECT uniprot_id, interpro_id as source, start_index, end_index FROM interpro_entries_v2 WHERE type='domain'", conn)
df_ted = pd.read_sql(f"SELECT uniprot_id, ted_id as source, start_index, end_index FROM ted_entries_summary WHERE plddt >= 70", conn)

conn.close()

In [58]:
df_domains = pd.concat([df_interpro_domain, df_ted])
interpro_ids = df_domains.loc[df_domains["source"].str.startswith("IPR"), "uniprot_id"].unique() # Find uniprot_ids that have "interpro" as a source
df_domains = df_domains[~((df_domains["uniprot_id"].isin(interpro_ids)) & (df_domains["source"].str.startswith("AF")))] # Filter the DataFrame to exclude rows with source "ted" for those uniprot_ids
df_domains = df_protein.set_index('uniprot_id').join(df_domains.set_index('uniprot_id'), how='inner').reset_index()
df_domains['domain_sequence'] = df_domains.apply(lambda row: row['sequence'][row['start_index']-1: row['end_index']], axis=1)
df_domains = df_domains[df_domains['domain_sequence'].str.len()>0].reset_index(drop=True)

# df_domains = df_domains[df_domains['source'].str.startswith('IPR')] # just keep interpro entries
# df_domains = df_domains[['uniprot_id', 'source', 'domain_sequence']].reset_index(drop=True)
df_domains

Unnamed: 0,uniprot_id,sequence,source,start_index,end_index,domain_sequence
0,A0A087X1C5,MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL...,AF-A0A087X1C5-F1-model_v4_TED01,32,333,RYPPGPLPLPGLGNLLHVDFQNTPYCFDQLRRRFGDVFSLQLAWTP...
1,A0A087X1C5,MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL...,AF-A0A087X1C5-F1-model_v4_TED01,344,362,VCPVRVQQEIDDVIGQVRR
2,A0A087X1C5,MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL...,AF-A0A087X1C5-F1-model_v4_TED01,376,515,AVIHEVQHFGDIVPLGVTHMTSRDIEVQGFRIPKGTTLITNLSSVL...
3,A0A087X296,MSRSLLLWFLLFLLLLPPLPVLLADPGAPTPVNPCCYYPCQHQGIC...,IPR000742,31,69,PVNPCCYYPCQHQGICVRFGLDRYQCDCTRTGYSGPNCT
4,A0A0B4J2F0,MFRRLTFAQLLFATVLGIAGGVYIFQPVFEQYAKDQKELKEKMQLV...,AF-A0A0B4J2F0-F1-model_v4_TED01,30,54,EQYAKDQKELKEKMQLVQESEEKKS
...,...,...,...,...,...,...
120430,X6RL26,MQPMSFGWDHSLHKRKRLPPVKRSLVYYLKNREVRLQNETSYSRVL...,IPR056151,54,189,LPSLLKEREFHLGTLNKVFASQWLNHRQVVCGTKCNTLFVVDVQTS...
120431,X6RL45,MVRCYVEIVEKLPERRPDPATIEGCAQLKPNNYLLAWHTPFNEKGS...,AF-X6RL45-F1-model_v4_TED01,1,161,MVRCYVEIVEKLPERRPDPATIEGCAQLKPNNYLLAWHTPFNEKGS...
120432,X6RL83,MLQEWLAAVGDDYAAVVWRPEGEPRFYPDEEGPKHWTKERHQFLME...,AF-X6RL83-F1-model_v4_TED01,2,218,LQEWLAAVGDDYAAVVWRPEGEPRFYPDEEGPKHWTKERHQFLMEL...
120433,X6RLN4,EVKGLFKSENCPKVISCEFAHNSNWYITFQSDTDAQQAFKYLREEV...,AF-X6RLN4-F1-model_v4_TED01,3,54,KGLFKSENCPKVISCEFAHNSNWYITFQSDTDAQQAFKYLREEVKT...


## Load Tokenizers

In [59]:
# 'dataset': {'uniref50', 'uniref90'}
# 'is_pretokenizer': {True, False}
# 'subs_matrix': {'blosum45', 'blosum62', 'pam70', 'pam250'}
# 'mutation_cutoff': {0.7, 0.8, 0.9}
# 'min_mutation_freq': {0, 0.05,. 0.005}
# 'min_mutation_len': {3}
# 'max_mutation_len': {12}
# 'vocab_size': list=[800, 1600, 3200, 6400, 12800, 25600, 51200]

vocab_sizes = [800, 3200, 12800]
uniref_id = "50"

tokenizer_opts_list = [
    # {
    #     'is_mut': False,
    #     'dataset': f'uniref{uniref_id}',
    #     'is_pretokenizer': False,
    #     'vocab_size': vocab_sizes
    # },
    {
        'is_mut': True,
        'dataset': f'uniref{uniref_id}',
        'is_pretokenizer': False,
        'subs_matrix': 'blosum62',
        'mutation_cutoff': 0.7,
        'min_mutation_freq': 0.05,
        'min_mutation_len': 3,
        'max_mutation_len': 12,
        'vocab_size': vocab_sizes
    },
    {
        'is_mut': True,
        'dataset': f'uniref{uniref_id}',
        'is_pretokenizer': False,
        'subs_matrix': 'pam70',
        'mutation_cutoff': 0.7,
        'min_mutation_freq': 0.05,
        'min_mutation_len': 3,
        'max_mutation_len': 12,
        'vocab_size': vocab_sizes
    },
    {
        'is_mut': True,
        'dataset': f'uniref{uniref_id}',
        'is_pretokenizer': True,
        'subs_matrix': 'blosum62',
        'mutation_cutoff': 0.7,
        'min_mutation_freq': 0.05,
        'min_mutation_len': 3,
        'max_mutation_len': 12,
        'vocab_size': vocab_sizes
    },
]

In [224]:
tokenizer_list = load_tokenizers(tokenizer_opts_list, 'hf')
inner_vocab_list = load_tokenizers(tokenizer_opts_list, 'vocab')

vocab_list = {}
for name, tokenizer in tokenizer_list.items():
    vocab_list[name] = list(set([token for token, idx in tokenizer.get_vocab().items()]))


inner_vocab_parents_list = {}
inner_vocab_mutated_list = {}
inner_vocab_family_list = {}
for k, v in inner_vocab_list.items():
    inner_vocab_parents_list[k] = get_parents(v)
    inner_vocab_mutated_list[k] = get_mutated(v)
    inner_vocab_family_list[k] = {p:0 for p in inner_vocab_parents_list[k].keys()}

In [225]:
for tokenizer_name in tokenizer_list.keys():
    for mutated_token, mutated_token_attr in inner_vocab_mutated_list[tokenizer_name].items():
        parent_token = mutated_token_attr['parent']
        inner_vocab_parents_list[tokenizer_name][parent_token]['mutations'] = inner_vocab_parents_list[tokenizer_name][parent_token].get('mutations', []) + [mutated_token]

In [61]:
for name, tokenizer in tqdm(list(tokenizer_list.items())):
    df_protein[name] = [enc.tokens for enc in tokenizer.encode_batch(df_protein['sequence'])]

  0%|          | 0/9 [00:00<?, ?it/s]

100%|██████████| 9/9 [00:26<00:00,  2.97s/it]


In [None]:
# df_protein_domain_sequences = df_domains[['uniprot_id', 'sequence']].drop_duplicates()
# for name, tokenizer in tqdm(list(tokenizer_list.items())):
#     df_protein_domain_sequences[name] = [enc.tokens for enc in tokenizer.encode_batch(df_protein_domain_sequences['sequence'])]
# df_domains = df_domains.set_index(['uniprot_id','sequence']).join(df_protein_domain_sequences.set_index(['uniprot_id','sequence']), how='inner').reset_index()
# df_domains.head()

100%|██████████| 12/12 [00:23<00:00,  2.00s/it]


In [68]:
np.random.seed(1)
df_protein_main = df_protein.sample(100)

In [124]:
import random

def generate_alternative_token(token: str, mutated_token: str, tabu_list: list, alphabet: str) -> str:
    """
    Generate an alternative token based on two input tokens, avoiding tokens in tabu list.
    
    Args:
        token (str): The original token
        mutated_token (str): The mutated version of the token
        tabu_list (list): List of tokens to avoid
        alphabet (str): String containing all possible characters
    
    Returns:
        str: A valid alternative token
    
    Raises:
        ValueError: If token and mutated_token have different lengths
        ValueError: If no valid alternative token can be generated after 100 attempts
    """
    if len(token) != len(mutated_token):
        raise ValueError("Token and mutated_token must have the same length")
    
    max_attempts = 100
    attempt = 0
    
    while attempt < max_attempts:
        # Initialize list to store characters of alternative token
        alternative_chars = []
        
        # Generate alternative token character by character
        for t, m in zip(token, mutated_token):
            if t == m:
                # If characters are same in both tokens, use that character
                alternative_chars.append(t)
            else:
                # If characters differ, randomly choose from alphabet
                alternative_chars.append(random.choice(alphabet))
        
        # Convert character list to string
        alternative = ''.join(alternative_chars)
        
        # Check if generated token is not in tabu list
        if alternative not in tabu_list:
            return alternative
            
        attempt += 1
    
    return mutated_token
    raise ValueError("Could not generate a valid alternative token after 100 attempts")

# Example usage
token = "hello"
mutated_token = "heppo"
tabu_list = ["hello", "heppo", "helpo"]
alphabet = "abcdefghijklmnopqrstuvwxyz"

alternative = generate_alternative_token(token, mutated_token, tabu_list, alphabet)
print(alternative)  # Might print something like "helao" or "helko"

heguo


In [None]:
tokenizer_name = 'mutBPE blosum62 0.7 0.05 3200'
token_set = df_protein_main[tokenizer_name].iloc[0]

In [255]:
def generate_mutated_alternative_token_set(token_set, tokenizer_name, random_seed=42):
    random.seed(random_seed)
    alphabet = "ARNDCEQGHILKMFPSTWYVUOXBZJ"
    mutated_token_set = []
    alternative_token_set = []
    for token in token_set:
        if token in inner_vocab_parents_list[tokenizer_name]:
            if 'mutations' in inner_vocab_parents_list[tokenizer_name][token]:
                mutated_token = inner_vocab_parents_list[tokenizer_name][token]['mutations'][0]
                tabu_list = [token] + inner_vocab_parents_list[tokenizer_name][token]['mutations']
                alternative_token = generate_alternative_token(token, mutated_token, tabu_list, alphabet)
            else:
                mutated_token = token
                alternative_token = token
        elif token in inner_vocab_mutated_list[tokenizer_name]:
            parent_token = inner_vocab_mutated_list[tokenizer_name][token]['parent']
            if len(inner_vocab_parents_list[tokenizer_name][parent_token]['mutations']) > 1:
                candidate_1 = inner_vocab_parents_list[tokenizer_name][parent_token]['mutations'][0]
                candidate_2 = inner_vocab_parents_list[tokenizer_name][parent_token]['mutations'][1]
                mutated_token = candidate_1 if candidate_1 != token else candidate_2
                tabu_list = [parent_token] + inner_vocab_parents_list[tokenizer_name][parent_token]['mutations']
                alternative_token = generate_alternative_token(token, mutated_token, tabu_list, alphabet)
            else:
                mutated_token = token
                alternative_token = token
        else:
            mutated_token = token
            alternative_token = token
        mutated_token_set.append(mutated_token)
        alternative_token_set.append(alternative_token)
    return mutated_token_set, alternative_token_set

In [256]:
aa = df_protein_main.apply(lambda row: {tokenizer_name:generate_mutated_alternative_token_set(row[tokenizer_name], tokenizer_name) for tokenizer_name in tokenizer_list.keys()}, axis=1)

In [267]:
aa = pd.DataFrame.from_dict(list(aa))

In [268]:
result = pd.concat(
    [aa[col].apply(pd.Series).add_prefix(f"{col}_") for col in aa.columns],
    axis=1
)

In [272]:
result

Unnamed: 0,mutBPE blosum62 0.7 0.05 800_0,mutBPE blosum62 0.7 0.05 800_1,mutBPE blosum62 0.7 0.05 3200_0,mutBPE blosum62 0.7 0.05 3200_1,mutBPE blosum62 0.7 0.05 12800_0,mutBPE blosum62 0.7 0.05 12800_1,mutBPE pam70 0.7 0.05 800_0,mutBPE pam70 0.7 0.05 800_1,mutBPE pam70 0.7 0.05 3200_0,mutBPE pam70 0.7 0.05 3200_1,mutBPE pam70 0.7 0.05 12800_0,mutBPE pam70 0.7 0.05 12800_1,mutBPE pre blosum62 0.7 0.05 800_0,mutBPE pre blosum62 0.7 0.05 800_1,mutBPE pre blosum62 0.7 0.05 3200_0,mutBPE pre blosum62 0.7 0.05 3200_1,mutBPE pre blosum62 0.7 0.05 12800_0,mutBPE pre blosum62 0.7 0.05 12800_1
0,"[M, EAI, SFI, KL, EV, NG, P, MV, TV, AL, SV, S...","[M, EAU, DFA, KL, EV, NG, P, MV, TV, AL, SV, B...","[M, EAI, SFI, KL, EV, NG, PVI, TV, AL, SV, SLI...","[M, EAU, DFA, KL, EV, NG, PBH, TV, AL, SV, GLG...","[M, EAI, SFI, KL, EV, NG, PVI, TIAL, SV, SLI, ...","[M, EAU, DFA, KL, EV, NG, PBH, TGAL, SV, GLC, ...","[M, DAL, TFL, KL, EV, NG, P, MV, TV, AL, SV, T...","[M, UAL, DFL, KL, EV, NG, P, MV, TV, AL, SV, B...","[M, DAL, TFL, KL, EV, NTP, MV, TV, AL, SV, TLL...","[M, UAL, DFL, KL, EV, NBP, MV, TV, AL, SV, HLL...","[KDAL, TFL, KL, EV, NTP, MV, TV, TLSI, TLL, TL...","[UDAL, BFL, KL, EV, NHP, MV, TV, GLSG, CLL, BL...","[M, EAI, SFI, KL, EV, NG, P, MV, TV, AL, SV, S...","[M, EAU, DFA, KL, EV, NG, P, MV, TV, AL, SV, B...","[M, EAI, SFI, KL, EV, NG, PVI, TV, AL, SV, SLI...","[M, EAU, DFA, KL, EV, NG, PBH, TV, AL, SV, GLG...","[M, EAI, SFI, KL, EV, NG, PVI, TIAL, SV, SLI, ...","[M, EAU, DFA, KL, EV, NG, PBH, TGAL, SV, GLC, ..."
1,"[M, AA, AA, PAC, SAS, SS, EA, PAC, SA, TA, EP,...","[M, AA, AA, PAU, DAA, SS, EA, PAB, SA, TA, EP,...","[M, AASA, PAC, SAS, SS, EA, PAC, SA, TA, EP, E...","[M, AAUA, PAD, AAB, SS, EA, PAH, SA, TA, EP, E...","[M, AASA, PAC, SAS, SSDA, PAC, SANA, SPQA, GD,...","[M, AAUA, PAD, AAB, SSGA, PAB, SAOA, BPWA, GD,...","[M, AA, AA, PAT, TAA, SS, EA, PAT, SA, TA, EP,...","[M, AA, AA, PAU, DAA, SS, EA, PAB, SA, TA, EP,...","[M, AAAT, PAT, TAA, SS, EA, PAT, SA, TA, EP, E...","[M, AAAU, PAD, BAA, SS, EA, PAH, SA, TA, EP, E...","[MAAAT, AAAAAT, STEG, PAT, TASA, EADA, GD, QD,...","[MAAAU, DAAAAA, SBEH, PAC, BADA, EOBA, GD, QD,...","[M, AA, AA, PAC, SAS, SS, EA, PAC, SA, TA, EP,...","[M, AA, AA, PAU, DAA, SS, EA, PAB, SA, TA, EP,...","[M, AASA, PAC, SAS, SS, EA, PAC, SA, TA, EP, E...","[M, AAUA, PAD, AAB, SS, EA, PAH, SA, TA, EP, E...","[M, AASA, PAC, SAS, SS, EA, PAC, SANA, EP, EA,...","[M, AAUA, PAD, AAB, SS, EA, PAH, SACA, EP, EA,..."
2,"[M, EEI, SD, EEM, D, HG, SED, D, SD, K, ED, Q,...","[M, EUD, SD, EEA, D, HG, BEH, D, SD, K, ED, Q,...","[M, EEI, SD, EEM, AHG, SED, DDD, KDD, KDI, D, ...","[M, EUD, SD, EEA, BHG, HEG, DGD, KCD, BDD, D, ...","[M, EEI, SD, EEM, AHG, SED, DDD, KDD, KDI, DQM...","[M, EUD, SD, EEA, BHG, HEG, DGD, KCD, BDD, DOM...","[M, EDL, SD, EE, ID, HG, SED, D, TEK, ED, Q, D...","[M, EUL, SD, EE, ID, HG, DEA, D, BHK, ED, Q, D...","[M, EDL, SD, EE, ID, HG, SED, D, TEK, ED, QDM,...","[M, EUL, SD, EE, ID, HG, DEA, D, BHK, ED, QDG,...","[M, EDL, SQED, ID, HG, SED, ESDK, ED, QDM, NRM...","[M, EUL, SDEA, ID, HG, BEH, GSDK, ED, QDG, CBM...","[M, EEI, SD, EEM, D, HG, SED, D, SD, K, ED, Q,...","[M, EUD, SD, EEA, D, HG, BEH, D, SD, K, ED, Q,...","[M, EEI, SD, EEM, AHG, SED, DDD, KDD, KDI, D, ...","[M, EUD, SD, EEA, BHG, HEG, DGD, KCD, BDD, D, ...","[M, EEI, SD, EEM, AHG, SED, DDD, KDD, KDI, DQM...","[M, EUD, SD, EEA, BHG, HEG, DGD, KCD, BDD, DOM..."
3,"[MG, A, MA, Y, P, LLI, C, LLI, A, QL, SIG, AV,...","[MG, A, MA, Y, P, LLU, C, LLD, A, QL, ABG, AV,...","[MG, AIA, Y, PLLI, C, LLI, A, QL, SIG, AV, GA,...","[MG, AUA, Y, PLLD, C, LLA, A, QL, BHG, AV, GA,...","[MG, AIA, Y, PLLI, C, LLI, AQM, SIG, AV, GA, S...","[MG, AUA, Y, PLLD, C, LLA, AQB, HGG, AV, GA, S...","[MG, A, MA, Y, P, LLM, C, LLM, A, QL, TLG, AV,...","[MG, A, MA, Y, P, LLU, C, LLD, A, QL, BLG, AV,...","[MAT, MA, Y, TLLL, C, LLM, TQM, TLG, AV, GA, S...","[MUD, MA, Y, BLLL, C, LLH, GQG, CLG, AV, GA, S...","[MAT, MTY, TLLL, CILL, TQM, TLG, AIAA, ARD, PQ...","[MUD, MBY, HLLL, CGLL, GQC, BLG, ADOA, BRD, PQ...","[MG, A, MA, Y, P, LLI, C, LLI, A, QL, SIG, AV,...","[MG, A, MA, Y, P, LLU, C, LLD, A, QL, ABG, AV,...","[MG, AIA, YP, LLI, C, LLI, AQI, SIG, AV, GA, S...","[MG, AUA, YP, LLD, C, LLA, AQB, HGG, AV, GA, S...","[MG, AIA, YP, LLI, C, LLI, AQI, SIG, AV, GA, S...","[MG, AUA, YP, LLD, C, LLA, AQB, HGG, AV, GA, S..."
4,"[M, SPI, YQ, AG, SLI, M, TV, N, TL, QG, KK, MI...","[M, UPD, YQ, AG, SAB, M, TV, N, TL, QG, KK, MI...","[M, SPI, YQ, AG, SLI, M, TV, NIL, QG, KK, MI, ...","[M, UPD, YQ, AG, SAB, M, TV, NHL, QG, KK, MI, ...","[M, SPI, YQ, AG, SLI, MSI, NIL, QG, KKLI, ESGM...","[M, UPD, YQ, AG, SAB, MHG, NGL, QG, KKCI, ESGB...","[M, DAL, YQ, AG, TLL, M, TV, N, TL, QG, KK, MI...","[M, UDL, YQ, AG, ABL, M, TV, N, TL, QG, KK, MI...","[M, DAL, YQ, AG, TLL, M, TV, NTM, QG, KK, MI, ...","[M, UDL, YQ, AG, ABL, M, TV, NTH, QG, KK, MI, ...","[M, DAL, YQ, AG, TLL, MTI, NTM, QG, KK, MI, ES...","[M, UDL, YQ, AG, ABL, MTH, NTG, QG, KK, MI, ES...","[M, SPI, YQ, AG, SLI, M, TV, N, TL, QG, KK, MI...","[M, UPD, YQ, AG, SAB, M, TV, N, TL, QG, KK, MI...","[M, SPI, YQ, AG, SLI, M, TV, NIL, QG, KK, MI, ...","[M, UPD, YQ, AG, SAB, M, TV, NHL, QG, KK, MI, ...","[M, SPI, YQ, AG, SLI, MSI, NIL, QG, KKLI, ESGI...","[M, UPD, YQ, AG, SAB, MHG, NGL, QG, KKCI, ESGB..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,"[C, IQ, PG, EA, QP, NV, D, KL, V, ED, HL, AV, ...","[C, IQ, PG, EA, QP, NV, D, KL, V, ED, HL, AV, ...","[C, IQ, PG, EA, QP, NV, DSI, V, ED, HL, AV, Q,...","[C, IQ, PG, EA, QP, NV, DUD, V, ED, HL, AV, Q,...","[C, IQ, PG, EA, QP, NV, DKLI, EDQI, AV, ESMV, ...","[C, IQ, PG, EA, QP, NV, DKLU, EDDA, AV, BSHG, ...","[C, IQ, PG, EA, Q, PTV, EKM, V, ED, HL, AV, Q,...","[C, IQ, PG, EA, Q, PUV, DKA, V, ED, HL, AV, Q,...","[C, IQ, PG, ETQ, PTV, EKM, V, ED, HL, TVQ, TLL...","[C, IQ, PG, EUQ, PDV, AKB, V, ED, HL, HVQ, GLG...","[CMQ, PG, ETQ, PTV, EKM, V, ED, HL, TVQ, TLL, ...","[CUQ, PG, EDQ, PBV, HKG, V, ED, HL, CVQ, BLD, ...","[CI, APG, EA, QP, NV, D, KL, V, ED, HL, AV, Q,...","[CI, UPG, EA, QP, NV, D, KL, V, ED, HL, AV, Q,...","[CI, APG, EA, QP, NV, DSI, V, ED, HL, AV, Q, S...","[CI, UPG, EA, QP, NV, DDA, V, ED, HL, AV, Q, S...","[CI, APG, EA, QP, NV, DKLI, ED, HIAA, ESMV, RA...","[CI, UPG, EA, QP, NV, DKLD, ED, HAAB, HSGG, RA..."
96,"[M, AL, T, QV, RL, TF, R, DV, AI, EF, SQ, EE, ...","[M, AL, T, QV, RL, TF, R, DV, AI, EF, SQ, EE, ...","[MAI, SQI, RL, TF, RDI, AI, EF, SQ, EDY, SCI, ...","[MAU, DQA, RL, TF, RDB, AI, EF, SQ, EHG, GCC, ...","[MAI, SQI, RL, TF, RDI, AI, EF, SQ, EDY, SCI, ...","[MAU, DQA, RL, TF, RDB, AI, EF, SQ, EHG, GCC, ...","[M, AL, T, QV, RL, TF, R, DV, AI, EF, SQ, EE, ...","[M, AL, T, QV, RL, TF, R, DV, AI, EF, SQ, EE, ...","[MAM, AQV, RL, TF, RNL, AI, EF, SQ, EE, WK, CL...","[MAU, DQV, RL, TF, RAB, AI, EF, SQ, EE, WK, CL...","[MAM, AQV, RL, TF, RNL, AI, EF, SQED, WK, CL, ...","[MAU, DQV, RL, TF, RAB, AI, EF, SQEH, WK, CL, ...","[M, AL, T, QV, RL, T, FR, DV, AI, EF, SQ, EE, ...","[M, AL, T, QV, RL, T, FR, DV, AI, EF, SQ, EE, ...","[MAI, SQI, RL, T, FR, DV, AI, EF, SQ, EE, W, S...","[MAU, DQA, RL, T, FR, DV, AI, EF, SQ, EE, W, B...","[MAI, SQI, RL, TFE, DVAV, EF, SQ, EDW, SCI, AP...","[MAU, DQA, RL, TFB, DVAH, EF, SQ, EGW, GCC, BP..."
97,"[M, C, FL, RR, PAC, PA, SW, I, W, W, RLI, RQ, ...","[M, C, FL, RR, PUD, PA, SW, I, W, W, RAB, RQ, ...","[M, CLI, RR, PAC, PA, SWM, WW, RLI, RQ, VL, RR...","[M, CUD, RR, PAB, PA, SWH, WW, RGG, RQ, VL, RR...","[M, CLI, RR, PAC, PA, SWM, WW, RLI, REVL, RRSI...","[M, CUD, RR, PAB, PA, SWH, WW, RGG, RCVL, RRBD...","[M, C, FL, RR, PAT, PA, SW, I, W, W, RLM, RQ, ...","[M, C, FL, RR, PUD, PA, SW, I, W, W, RAB, RQ, ...","[MC, FL, RR, PAT, PA, TWL, WW, RLM, RQ, VL, RR...","[MC, FL, RR, PUD, PA, AWB, WW, RHG, RQ, VL, RR...","[MC, FL, RR, PAT, PA, TWL, WW, RLM, RQIL, RR, ...","[MC, FL, RR, PUD, PA, AWB, WW, RHG, RQGL, RR, ...","[M, C, FL, RR, PAC, PA, SW, I, W, W, RLI, RQ, ...","[M, C, FL, RR, PUD, PA, SW, I, W, W, RAB, RQ, ...","[M, CLI, RR, PAC, PA, SWM, WW, RLI, RQ, VL, RR...","[M, CUD, RR, PAB, PA, SWH, WW, RGG, RQ, VL, RR...","[M, CLI, RR, PAC, PA, SWM, WW, RLI, RQIL, RRSI...","[M, CUD, RR, PAB, PA, SWH, WW, RGG, RQCL, RRBD..."
98,"[M, RLI, PR, LL, LL, LLI, VF, PAC, VL, FR, GG,...","[M, RLU, PR, LL, LL, LLD, VF, PAB, VL, FR, GG,...","[M, RLI, PR, LLLI, LLI, VF, PAC, VL, FR, GG, P...","[M, RLU, PR, LLLD, LLA, VF, PAB, VL, FR, GG, P...","[MKIL, PR, LLLI, LLI, VF, PAC, VL, FR, GG, PDG...","[MUDL, PR, LLLA, LLB, VF, PAH, VL, FR, GG, PGG...","[M, RLM, PR, LL, LL, LLM, VF, PA, SVM, FR, GG,...","[M, RLU, PR, LL, LL, LLD, VF, PA, AVB, FR, GG,...","[M, RLM, PR, LLLM, LLM, VF, PA, SVM, FR, GG, P...","[M, RLU, PR, LLLD, LLA, VF, PA, BVH, FR, GG, P...","[IRLL, PR, LLLM, LLM, VF, PA, SVM, FR, GG, PRA...","[URLL, PR, LLLD, LLA, VF, PA, BVH, FR, GG, PRC...","[M, RLI, PR, LL, LL, LLI, VF, PAC, VL, FR, GG,...","[M, RLU, PR, LL, LL, LLD, VF, PAB, VL, FR, GG,...","[M, RLI, PR, LLLI, LLI, VF, PAC, VL, FR, GG, P...","[M, RLU, PR, LLLD, LLA, VF, PAB, VL, FR, GG, P...","[MKIL, PR, LLLI, LLI, VF, PAC, VL, FR, GG, PDG...","[MUDL, PR, LLLA, LLB, VF, PAH, VL, FR, GG, PGG..."


In [261]:
import pandas as pd

# Example DataFrame with lists of lists
data = {
    'A': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
    'B': [[[9, 10], [11, 12]], [[13, 14], [15, 16]]]
}

df = pd.DataFrame(data)

# Exploding each column
exploded_df = pd.concat([df[col].apply(pd.Series).stack().reset_index(drop=True) for col in df], axis=1)
exploded_df.columns = df.columns

exploded_df


Unnamed: 0,A,B
0,"[1, 2]","[9, 10]"
1,"[3, 4]","[11, 12]"
2,"[5, 6]","[13, 14]"
3,"[7, 8]","[15, 16]"


In [264]:
import pandas as pd

# Example DataFrame with lists of lists
data = {
    'A': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
    'B': [[[9, 10], [11, 12]], [[13, 14], [15, 16]]]
}

df = pd.DataFrame(data)

# Exploding each column into new columns
result = pd.concat(
    [df[col].apply(pd.Series).add_prefix(f"{col}_") for col in df.columns],
    axis=1
)

result


Unnamed: 0,A_0,A_1,B_0,B_1
0,"[1, 2]","[3, 4]","[9, 10]","[11, 12]"
1,"[5, 6]","[7, 8]","[13, 14]","[15, 16]"


In [259]:
pd.DataFrame.from_dict(list(aa))['mutBPE blosum62 0.7 0.05 800'].explode()

0     [M, EAI, SFI, KL, EV, NG, P, MV, TV, AL, SV, S...
0     [M, EAU, DFA, KL, EV, NG, P, MV, TV, AL, SV, B...
1     [M, AA, AA, PAC, SAS, SS, EA, PAC, SA, TA, EP,...
1     [M, AA, AA, PAU, DAA, SS, EA, PAB, SA, TA, EP,...
2     [M, EEI, SD, EEM, D, HG, SED, D, SD, K, ED, Q,...
                            ...                        
97    [M, C, FL, RR, PUD, PA, SW, I, W, W, RAB, RQ, ...
98    [M, RLI, PR, LL, LL, LLI, VF, PAC, VL, FR, GG,...
98    [M, RLU, PR, LL, LL, LLD, VF, PAB, VL, FR, GG,...
99    [M, SRM, NK, NV, VL, SLI, TL, T, SAS, F, LL, F...
99    [M, SRU, NK, NV, VL, DLA, TL, T, SBH, F, LL, F...
Name: mutBPE blosum62 0.7 0.05 800, Length: 200, dtype: object

In [185]:
tokenizer_name = 'mutBPE blosum62 0.7 0.05 3200'
token_set = df_protein_main[tokenizer_name].iloc[0]
for t1, t2, t3 in zip(token_set, *generate_mutated_alternative_token_set(token_set, tokenizer_name)):
    if t1 == t2 == t3:
        pass
    else:
        print(t1, t2, t3, 'p' if t1 in inner_vocab_parents_list[tokenizer_name] else 'm')

EAL EAI EAS p
GFL SFI NFT m
PMV PVI PZY m
ALL SLI ILF m
ALL SLI FLD m
YST YAT YWT p
SAF AAY PAP m
SRL SRI SRA p
EKL EEI ENX m
PSP PAP PVP p
FIG FMG FRG m
NLT NLI NLN m
RQG RKG REG m
SQM SQI SQR m
RKL KKI LKB m
CGL CSI CIT m


In [186]:
inner_vocab_mutated_list[tokenizer_name]['RKL']

{'frequency': 4042,
 'order': 380,
 'pair': ['R', 'KL'],
 'parent': 'KKL',
 'similarity': 0.7857142857142857}

In [187]:
inner_vocab_parents_list[tokenizer_name]['KKL']

{'frequency': 7501,
 'order': 375,
 'pair': ['KK', 'L'],
 'is_parent': True,
 'mutations': ['KKI', 'KKM', 'KKV', 'KRL', 'RKL', 'KKF', 'KQL', 'QKL']}

In [81]:
aa = 0
bb = 0
for tok in seq:
    if tok in inner_vocab_parents_list[method]:
        aa += 1
    if tok in inner_vocab_mutated_list[method]:
        bb += 1
aa, bb

(3, 5)