Import dependencies

In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle
from collections import defaultdict, Counter
import more_itertools
import heapq
from Bio import Align
from Bio.Align import substitution_matrices
import uuid
import json
from sklearn.model_selection import train_test_split
import time

Load sequence data

In [2]:
with open('uniref_taxonomy_id_9606_AND_identity_2024_09_13.json') as f:
# with open('../../RSRC/ECCB/uniref_taxonomy_id_9606_AND_identity_2024_09_13.json') as f:
    human_proteins_json = json.load(f)['results']
    
human_proteins_df = []
for prot in human_proteins_json:
    human_proteins_df.append({'id': prot['id'], 'sequence': prot['representativeMember']['sequence']['value']})
human_proteins_df = pd.DataFrame(human_proteins_df)
human_proteins_df = human_proteins_df[~human_proteins_df['sequence'].str.contains('U')]

df_ds_train, df_ds_test = train_test_split(human_proteins_df, test_size=0.2, random_state=42)

corpus = df_ds_train['sequence']

Extract the initial symbols

In [3]:
alphabet = []
for seq in corpus:
    for letter in seq:
        if letter not in alphabet:
            alphabet.append(letter)
alphabet.sort()
print(alphabet)

['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'X', 'Y']


Import custom classes for efficient BPE

In [4]:
from helper_classes import Sym, SymList, SymPair, MaxHeapMap

Convert every protein sequence into a SymList (a doubly linked list containing tokens as nodes)

In [5]:
sequences = []
for seq in corpus:
    if len(seq) == 0: continue
    symlist = SymList()
    for sym_str in seq:
        symlist.append(Sym(sym_str))
    sequences.append(symlist)


In [6]:
print(sequences[0])

MEVLRRSSVFAAEIMDAFDRCGDAADGLMSSSVWSAQTLASAPTGWWLHSAASAAS


Construct the initial MaxHeapMap

This data structure maintains a max heap of SymbolPairs with respect to their occurence counts.
- A SymbolPair is basically a struct that keeps track of every occurence of a token pair in the corpus

In [7]:
# Generate the data structure for bookkeeping of all symbol pairs found in the data.

# Maps ("A", "B") -> SymPair("A", "B")
pair_database = {}

def add_entry(db, sym1, sym2):
    sym1_str = sym1.literal
    sym2_str = sym2.literal
    curr_pair = db.get((sym1_str, sym2_str), None)
    if curr_pair is None:
        curr_pair = SymPair(sym1_str, sym2_str)
        db[(sym1_str, sym2_str)] = curr_pair
    curr_pair.add_pos((sym1, sym2))


for seq in sequences:
    for sym1, sym2 in more_itertools.pairwise(seq):
        add_entry(pair_database, sym1, sym2)


merge_heap = MaxHeapMap()
for pair in pair_database.values():
    merge_heap.push(pair)

del pair_database

str(merge_heap.peek())

'Pair: (L, L), Count: 153309'

In [8]:
[str(elem) for elem in merge_heap.heap[:10]]

['Pair: (L, L), Count: 153309',
 'Pair: (S, S), Count: 135477',
 'Pair: (L, S), Count: 111026',
 'Pair: (L, A), Count: 93549',
 'Pair: (S, L), Count: 115163',
 'Pair: (A, L), Count: 99568',
 'Pair: (E, E), Count: 107112',
 'Pair: (S, G), Count: 82440',
 'Pair: (G, G), Count: 73809',
 'Pair: (A, A), Count: 98339']

Define the merging function.

This function is the core of the algorithm.

In [20]:
def merge_pair(sym_pair):
    # print(f"merging: {str(sym_pair)}")
    add_database = {}
    remove_database = {}
    for sym1, sym2 in sym_pair.positions:
        
        # Due to lazy removing, check if we are trying to merge an already removed pair
        if sym1.next.literal != sym_pair.right: continue
        if sym2.prev.literal != sym_pair.left: continue

        merged_sym = Sym(sym1.literal + sym2.literal)
        merged_sym.prev = sym1.prev
        merged_sym.next = sym2.next
        if sym1.prev is not None:
            pre = sym1.prev
            pre.next = merged_sym
            # print(f"Adding new entries to databases", pre, merged_sym)
            # print(f"Adding new entries to databases", pre, sym1)
            add_entry(add_database, pre, merged_sym)
            # Don't attempt to remove currently processed merge pair
            if (pre.literal, sym1.literal) != (sym_pair.left, sym_pair.right): add_entry(remove_database, pre, sym1)
        if sym2.next is not None:
            nex = sym2.next
            nex.prev = merged_sym
            # print(f"Adding new entries to databases", pre, merged_sym)
            # print(f"Adding new entries to databases", pre, sym1)
            add_entry(add_database, merged_sym, nex)
             # Don't attempt to remove currently processed merge pair
            if (sym2.literal, nex.literal) != (sym_pair.left, sym_pair.right): add_entry(remove_database, sym2, nex)
    # print(f"Sizes of the new databases add and remove: {len(add_database)}, {len(remove_database)}")
    for val in add_database.values():
        # print(f"adding: {str(val)}")
        merge_heap.push(val)
    for r_val in remove_database.values():
        # print(f"trying to remove: {str(r_val)}")
        inner_val = merge_heap.remove_by_value(r_val)
        inner_val.count -= r_val.count
        merge_heap.push(inner_val)



Add words by vocabulary size

In [10]:
# # %%prun -s cumulative
# init_vocab = alphabet.copy()

# # Add 10000 words
# for i in range(100):
#     best_pair = merge_heap.pop()
    
#     merge_pair(best_pair)
#     init_vocab.append(best_pair.merged())

# print(init_vocab)

Add words by cutoff frequency threshold

In [11]:
# # %%prun -s cumulative
# init_vocab = alphabet.copy()

# # Add 10000 words
# best_pair = merge_heap.pop()
# max_count = best_pair.count
# while best_pair.count > max_count / 1000:
#     merge_pair(best_pair)
#     init_vocab.append(best_pair.merged())
#     best_pair = merge_heap.pop()

# print(init_vocab)

In [12]:
# [print(str(k)) for k in merge_heap.map]
# print(len(init_vocab))

# cnts = Counter(init_vocab)

## Code blocks related to addition of mutated sequences

In [13]:
from itertools import product

# Generates a sorted list of mutations (above the cutoff similarity) for the given sequence.
# Each resulting element is a tuple (m_sequence, score), where:
# - m_sequence: the mutated sequence as string
# - score: similarity score to the original sequence, a number between 0 and 1
# Results are sorted from highest to lowest score.
# First element of the results is the original sequence
# Mutations are based only on substitutions with non-negative matrix values.
def generate_mutations(seq, matrix, cutoff):
    # Hard cutoff for sequences longer than 10
    if len(seq) > 10:
        return []
    alp = matrix.alphabet
    candidates = []

    max_score = 0
    for aa in seq:
        max_score += matrix[aa][aa]

    # Create mutation candidates for each symbol in the sequence
    for i, aa in enumerate(seq):
        candidates.append([])
        # Ignore X from calculation
        if aa == "X":
            candidates[i].append((aa, 0.0))
            continue
        # Consider substitutions with non-negative scores
        for c_aa in alp:
            score = matrix[aa][c_aa]
            if score >= -1e-4: # for floating point precision
                similarity_loss = (matrix[aa][aa] - score)/max_score
                # if the similarity loss from this particular aminoacid is large enough
                # to go under the cutoff, don't even consider it
                if similarity_loss < 1 - cutoff:
                    candidates[i].append((c_aa, similarity_loss))

    # Generate all possible mutated sequences and their scores
    combinations = product(*candidates)
    final_mutations = []
    for combination in combinations:
        m_sequence = ""
        cumulative_similarity_loss = 0
        for c_aa, similarity_loss in combination:
            m_sequence += c_aa
            cumulative_similarity_loss += similarity_loss
            # Stop if enough dissimilarity is accumulated to
            # go under cutoff threshold
            if 1 - cumulative_similarity_loss < cutoff:
                break
        # Normalize the total score by max_score
        if 1 - cumulative_similarity_loss > cutoff:
            final_mutations.append((m_sequence, round(1 - cumulative_similarity_loss, 3)))
    
    # Sort mutations by score in descending order
    final_mutations.sort(key=lambda x: x[1], reverse=True)

    return final_mutations


Example usage of generate_mutations

In [14]:
from Bio.Align import substitution_matrices

blosum62 = substitution_matrices.load("BLOSUM62")
pam250 = substitution_matrices.load("PAM250")

print(generate_mutations("SDSXXXXXXXXXXXXX", blosum62, 0.8))
print(generate_mutations("LLL", pam250, 0.8))



[]
[('LLL', 1.0), ('LLM', 0.889), ('LML', 0.889), ('MLL', 0.889)]


An example code block for also adding the mutated sequences during the training of BPE tokenizer

Currently, it adds **all** the mutations that are generated from positive values in the substitution matrix.

In [None]:
# %%prun -s cumulative
init_vocab = alphabet.copy()

# Add 100 words + their mutations
for i in range(1000):
    best_pair = merge_heap.pop()
    merge_pair(best_pair)
    merged_string = best_pair.merged()
    init_vocab.append(merged_string)

    # For the mutations:
    
    # Consider only the mutations with a similarity score larger than 0.8
    # [1:] ignores the original string
    mutations = generate_mutations(merged_string, blosum62, 0.8)[1:]
    for mutated_str, score in mutations:
        pairs_to_merge = merge_heap.merged_to_pair.get(mutated_str, [])
        if len(pairs_to_merge) > 0:
            init_vocab.append(mutated_str)
        for pair in pairs_to_merge:
            merge_heap.remove_by_value(pair)
            merge_pair(pair)

print(init_vocab)

['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'X', 'Y', 'LL', 'SS', 'EE', 'AA', 'SL', 'PP', 'AL', 'VL', 'IL', 'GL', 'SP', 'GG', 'EL', 'RL', 'TL', 'EK', 'SG', 'RR', 'DL', 'KK', 'QL', 'PL', 'PI', 'PM', 'SA', 'PG', 'SV', 'SI', 'ST', 'EA', 'SR', 'FL', 'KL', 'ED', 'PA', 'SQ', 'EV', 'EI', 'EG', 'NL', 'TV', 'TI', 'AV', 'AI', 'SD', 'SK', 'TG', 'ER', 'PV', 'AG', 'HL', 'HI', 'HM', 'PR', 'QQ', 'SF', 'EN', 'RG', 'TT', 'VV', 'IV', 'VI', 'DG', 'EQ', 'RK', 'YL', 'YI', 'YM', 'PQ', 'TA', 'SN', 'CL', 'CI', 'CM', 'KG', 'DV', 'DI', 'ET', 'KA', 'KV', 'KI', 'RA', 'RV', 'RI', 'QG', 'QA', 'PT', 'PD', 'SH', 'QV', 'QI', 'FG', 'CG', 'KD', 'NG', 'SE', 'ML', 'RD', 'FV', 'FI', 'SY', 'KT', 'SC', 'NI', 'NV', 'RT', 'PE', 'PK', 'QD', 'AT', 'QK', 'HG', 'QR', 'YG', 'PF', 'DD', 'SM', 'VG', 'IG', 'FT', 'WL', 'WI', 'WM', 'QT', 'KN', 'AD', 'PN', 'AR', 'KR', 'PY', 'EM', 'EF', 'HV', 'QN', 'TD', 'LLL', 'ILL', 'LIL', 'LLI', 'LLM', 'LML', 'MLL', 'YV', 'RN', 'SW', 'PH', 'CV', 'AF',

         64873938 function calls (63666976 primitive calls) in 212.790 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000  212.790  212.790 {built-in method builtins.exec}
        1    0.209    0.209  212.789  212.789 <string>:4(<module>)
     1000  196.705    0.197  196.968    0.197 4011113511.py:10(generate_mutations)
     4626    2.129    0.000   15.510    0.003 1314981216.py:1(merge_pair)
   655415    1.428    0.000    8.075    0.000 helper_classes.py:116(remove_by_value)
1863367/656415    1.576    0.000    3.547    0.000 helper_classes.py:142(_heapify_down)
  3486659    1.952    0.000    3.067    0.000 helper_classes.py:75(_swap)
  1304318    1.376    0.000    3.065    0.000 helper_classes.py:80(push)
  1959733    1.080    0.000    2.997    0.000 helper_classes.py:133(_heapify_up)
  2330998    0.880    0.000    2.125    0.000 598563906.py:6(add_entry)
 10244881    1.243    0.000    1.611    0

In [24]:
# For debugging purposes, to see if any duplicates are added
print(len(init_vocab))

cnts = Counter(init_vocab)
cnts

# import json

# with open("test.json", "w") as f:
#     json.dump(list(cnts.keys()), f, indent=4)

7282


Counter({'A': 1,
         'C': 1,
         'D': 1,
         'E': 1,
         'F': 1,
         'G': 1,
         'H': 1,
         'I': 1,
         'K': 1,
         'L': 1,
         'M': 1,
         'N': 1,
         'P': 1,
         'Q': 1,
         'R': 1,
         'S': 1,
         'T': 1,
         'V': 1,
         'W': 1,
         'X': 1,
         'Y': 1,
         'LL': 1,
         'SS': 1,
         'EE': 1,
         'AA': 1,
         'SL': 1,
         'PP': 1,
         'AL': 1,
         'VL': 1,
         'IL': 1,
         'GL': 1,
         'SP': 1,
         'GG': 1,
         'EL': 1,
         'RL': 1,
         'TL': 1,
         'EK': 1,
         'SG': 1,
         'RR': 1,
         'DL': 1,
         'KK': 1,
         'QL': 1,
         'PL': 1,
         'PI': 1,
         'PM': 1,
         'SA': 1,
         'PG': 1,
         'SV': 1,
         'SI': 1,
         'ST': 1,
         'EA': 1,
         'SR': 1,
         'FL': 1,
         'KL': 1,
         'ED': 1,
         'PA': 1,
         'SQ'