Import dependencies

In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle
from collections import defaultdict, Counter
import more_itertools
import heapq
from Bio import Align
from Bio.Align import substitution_matrices
import uuid
import json
from sklearn.model_selection import train_test_split
import time

Load sequence data

In [2]:
# UniRef export restricted to taxonomy id 9606. JSON text is UTF-8, so the
# encoding is pinned instead of being left to the platform default.
# Alternate location of the same file:
#   '../../RSRC/ECCB/uniref_taxonomy_id_9606_AND_identity_2024_09_13.json'
with open('uniref_taxonomy_id_9606_AND_identity_2024_09_13.json', encoding='utf-8') as f:
    human_proteins_json = json.load(f)['results']

# Collect one record per cluster (id + representative sequence) and build the
# DataFrame once from the finished list.
human_protein_records = [
    {'id': prot['id'], 'sequence': prot['representativeMember']['sequence']['value']}
    for prot in human_proteins_json
]
human_proteins_all_df = pd.DataFrame(human_protein_records)

# Drop every sequence that contains the letter 'U'. regex=False makes this a
# plain substring test rather than a regular-expression search.
human_proteins_df = human_proteins_all_df[
    ~human_proteins_all_df['sequence'].str.contains('U', regex=False)
]

# Fixed random_state keeps the 80/20 split reproducible across runs.
df_ds_train, df_ds_test = train_test_split(human_proteins_df, test_size=0.2, random_state=42)

# Only the training sequences are used as the corpus.
corpus = df_ds_train['sequence']

Extract the initial symbols

In [3]:
# Distinct letters occurring anywhere in the training corpus, in sorted order.
# A set de-duplicates with constant-time membership checks instead of scanning
# a list once per residue.
alphabet = sorted({letter for seq in corpus for letter in seq})
print(alphabet)

['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'X', 'Y']


Import custom classes for efficient BPE

In [4]:
from helper_classes import Sym, SymList, SymPair, MaxHeapMap

Convert every protein sequence into a SymList (a doubly linked list containing tokens as nodes)

In [5]:
# Build one SymList per non-empty corpus entry, appending one Sym per letter.
sequences = []
for protein in corpus:
    if not protein:
        continue
    residues = SymList()
    for residue in protein:
        residues.append(Sym(residue))
    sequences.append(residues)


In [6]:
# Sanity check: print the first converted sequence (print() uses the SymList's
# string conversion).
print(sequences[0])

MEVLRRSSVFAAEIMDAFDRCGDAADGLMSSSVWSAQTLASAPTGWWLHSAASAAS


Construct the initial MaxHeapMap

This data structure maintains a max heap of SymPairs ordered by their occurrence counts.
- A SymPair is basically a struct that keeps track of every occurrence of a token pair in the corpus

In [7]:
# Bookkeeping structure for all adjacent symbol pairs found in the data.

# Maps a tuple of literals, e.g. ("A", "B"), to the SymPair("A", "B") object
# that collects the positions of that pair.
pair_database = {}

def add_entry(db, sym1, sym2):
    """Record one occurrence of the adjacent pair (sym1, sym2) in ``db``.

    ``db`` maps ``(left_literal, right_literal)`` tuples to SymPair objects.
    If the pair has no entry yet, a new SymPair is created and stored under
    that key. The node tuple ``(sym1, sym2)`` is then passed to the entry's
    ``add_pos``.

    Args:
        db: dict that is updated in place.
        sym1: left node; its ``literal`` attribute is read.
        sym2: right node; its ``literal`` attribute is read.
    """
    key = (sym1.literal, sym2.literal)
    entry = db.get(key)
    if entry is None:
        entry = SymPair(*key)
        db[key] = entry
    entry.add_pos((sym1, sym2))


# Register every adjacent pair of nodes of every sequence in pair_database.
for sym_seq in sequences:
    for left_sym, right_sym in more_itertools.pairwise(sym_seq):
        add_entry(pair_database, left_sym, right_sym)

# Push all collected pairs onto a fresh heap.
merge_heap = MaxHeapMap()
for collected_pair in pair_database.values():
    merge_heap.push(collected_pair)

# The dict was only needed to build the heap.
del pair_database

# Show the entry currently at the top of the heap.
str(merge_heap.peek())

'Pair: (L, L), Count: 153309'

In [8]:
# String form of the first ten slots of the heap's backing list.
list(map(str, merge_heap.heap[:10]))

['Pair: (L, L), Count: 153309',
 'Pair: (S, S), Count: 135477',
 'Pair: (L, S), Count: 111026',
 'Pair: (L, A), Count: 93549',
 'Pair: (S, L), Count: 115163',
 'Pair: (A, L), Count: 99568',
 'Pair: (E, E), Count: 107112',
 'Pair: (S, G), Count: 82440',
 'Pair: (G, G), Count: 73809',
 'Pair: (A, A), Count: 98339']

Define the merging function.

This function is the core of the algorithm.

In [None]:
def merge_pair(sym_pair):
    """Merge the still-valid occurrences of ``sym_pair`` and update ``merge_heap``.

    For every recorded position ``(sym1, sym2)`` that passes the staleness
    check, a new ``Sym`` whose literal is ``sym1.literal + sym2.literal`` is
    linked between ``sym1.prev`` and ``sym2.next``. Pairs that come into
    existence through the merge (neighbour + merged symbol) are collected in
    ``add_database``; neighbour pairs that cease to exist are collected in
    ``remove_database``. After the scan the new pairs are pushed onto the
    module-level ``merge_heap`` and each destroyed pair has its heap entry's
    ``count`` reduced by the count collected here.

    Args:
        sym_pair: pair to merge; its ``left``, ``right`` and ``positions``
            attributes are read.

    Returns:
        None. Mutates the linked symbol nodes and ``merge_heap``.

    NOTE(review): when ``sym1.prev`` is None nothing here tells the owning
    SymList about ``merged_sym``; whether SymList holds a head reference that
    would go stale cannot be seen from this cell - TODO confirm.
    """
    # print(f"merging: {str(sym_pair)}")
    add_database = {}
    remove_database = {}
    for sym1, sym2 in sym_pair.positions:
        
        # Due to lazy removing, check if we are trying to merge an already removed pair:
        # positions are not deleted when an earlier merge consumed one of their
        # symbols, but that merge re-pointed the neighbour links to a symbol
        # with a different literal, which these two checks detect.
        if not sym1.next.literal == sym_pair.right: continue
        if not sym2.prev.literal == sym_pair.left: continue

        # Splice the merged symbol in place of sym1 and sym2.
        merged_sym = Sym(sym1.literal + sym2.literal)
        merged_sym.prev = sym1.prev
        merged_sym.next = sym2.next
        if sym1.prev is not None:
            pre = sym1.prev
            pre.next = merged_sym
            # New pair on the left: (pre, merged_sym).
            add_entry(add_database, pre, merged_sym)
            # Old pair (pre, sym1) no longer exists.
            # Don't attempt to remove currently processed merge pair
            # NOTE(review): a temporary SymPair is built just for this comparison;
            # the profile recorded in this notebook lists SymPair __init__/__eq__
            # among the most expensive calls - worth revisiting.
            if not SymPair(pre.literal, sym1.literal) == sym_pair: add_entry(remove_database, pre, sym1)
        if sym2.next is not None:
            nex = sym2.next
            nex.prev = merged_sym
            # New pair on the right: (merged_sym, nex).
            add_entry(add_database, merged_sym, nex)
            # Old pair (sym2, nex) no longer exists.
            # Don't attempt to remove currently processed merge pair
            if not SymPair(sym2.literal, nex.literal) == sym_pair: add_entry(remove_database, sym2, nex)
    # print(f"Sizes of the new databases add and remove: {len(add_database)}, {len(remove_database)}")
    # Additions are pushed before removals are applied: a pair created earlier
    # in this scan can be destroyed later in the same scan, and then appears in
    # both databases.
    for val in add_database.values():
        # print(f"adding: {str(val)}")
        merge_heap.push(val)
    for r_val in remove_database.values():
        # print(f"trying to remove: {str(r_val)}")
        # Take the stored entry out, lower its count, and re-insert it so the
        # heap can reposition it.
        # NOTE(review): assumes remove_by_value returns the stored SymPair for
        # an equal pair; behaviour when the pair is absent is not visible here.
        inner_val = merge_heap.remove_by_value(r_val)
        inner_val.count -= r_val.count
        merge_heap.push(inner_val)



In [10]:
%%prun -s cumulative
# Number of merge rounds; every round appends exactly one new token.
NUM_MERGES = 10000

# Start from the single-letter alphabet and grow the vocabulary by merging.
init_vocab = alphabet.copy()

# NOTE: each round pops from merge_heap and merge_pair() rewires the symbol
# nodes in place, so re-running this cell continues from the current state
# rather than starting over. Re-run the earlier cells first for a fresh run.
for _ in range(NUM_MERGES):
    best_pair = merge_heap.pop()
    merge_pair(best_pair)
    init_vocab.append(best_pair.left + best_pair.right)

print(init_vocab)

Sizes of the new databases add and remove: 42, 41
Sizes of the new databases add and remove: 44, 43
Sizes of the new databases add and remove: 46, 45
Sizes of the new databases add and remove: 48, 47
Sizes of the new databases add and remove: 47, 46
Sizes of the new databases add and remove: 52, 51
Sizes of the new databases add and remove: 52, 51
Sizes of the new databases add and remove: 55, 54
Sizes of the new databases add and remove: 57, 56
Sizes of the new databases add and remove: 58, 57
Sizes of the new databases add and remove: 61, 60
Sizes of the new databases add and remove: 62, 61
Sizes of the new databases add and remove: 65, 64
Sizes of the new databases add and remove: 67, 66
Sizes of the new databases add and remove: 70, 69
Sizes of the new databases add and remove: 68, 67
Sizes of the new databases add and remove: 73, 72
Sizes of the new databases add and remove: 75, 74
Sizes of the new databases add and remove: 77, 76
Sizes of the new databases add and remove: 79, 78


         322035974 function calls (318786781 primitive calls) in 194.244 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000  194.244  194.244 {built-in method builtins.exec}
        1    2.263    2.263  194.244  194.244 <string>:1(<module>)
    10000   22.832    0.002  191.591    0.019 2141618480.py:1(merge_pair)
 22849909   84.253    0.000   84.398    0.000 helper_classes.py:25(__eq__)
 20838599   37.013    0.000   37.085    0.000 helper_classes.py:15(__init__)
  2013876    3.147    0.000   21.818    0.000 helper_classes.py:106(remove_by_value)
 33495518    9.215    0.000   19.288    0.000 598563906.py:6(add_entry)
5273068/2023875    4.485    0.000   10.232    0.000 helper_classes.py:127(_heapify_down)
  9545482    5.483    0.000    8.607    0.000 helper_classes.py:75(_swap)
  6044193    3.123    0.000    8.485    0.000 helper_classes.py:118(_heapify_up)
  4030318    2.593    0.000    7.036    0.

In [11]:
# [print(str(k)) for k in merge_heap.map]