Import dependencies

In [16]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle
from collections import defaultdict, Counter
import more_itertools
import heapq
from Bio import Align
from Bio.Align import substitution_matrices
import uuid
import json
from sklearn.model_selection import train_test_split
import time

Load sequence data

In [17]:
with open('uniref_taxonomy_id_9606_AND_identity_2024_09_13.json') as f:
# with open('../../RSRC/ECCB/uniref_taxonomy_id_9606_AND_identity_2024_09_13.json') as f:
    human_proteins_json = json.load(f)['results']
    
human_proteins_df = []
for prot in human_proteins_json:
    human_proteins_df.append({'id': prot['id'], 'sequence': prot['representativeMember']['sequence']['value']})
human_proteins_df = pd.DataFrame(human_proteins_df)
human_proteins_df = human_proteins_df[~human_proteins_df['sequence'].str.contains('U')]

df_ds_train, df_ds_test = train_test_split(human_proteins_df, test_size=0.2, random_state=42)

corpus = df_ds_train['sequence']

Extract the initial symbols

In [18]:
alphabet = []
for seq in corpus:
    for letter in seq:
        if letter not in alphabet:
            alphabet.append(letter)
alphabet.sort()
print(alphabet)

['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'X', 'Y']


Import custom classes for efficient BPE

In [19]:
from helper_classes import Sym, SymList, SymPair, MaxHeapMap

Convert every protein sequence into a SymList (a doubly linked list containing tokens as nodes)

In [20]:
sequences = []
for seq in corpus:
    if len(seq) == 0: continue
    symlist = SymList()
    for sym_str in seq:
        symlist.append(Sym(sym_str))
    sequences.append(symlist)


In [21]:
print(sequences[0])

MEVLRRSSVFAAEIMDAFDRCGDAADGLMSSSVWSAQTLASAPTGWWLHSAASAAS


Construct the initial MaxHeapMap

This data structure maintains a max heap of SymbolPairs with respect to their occurence counts.
- A SymbolPair is basically a struct that keeps track of every occurence of a token pair in the corpus

In [22]:
# Generate the data structure for bookkeeping of all symbol pairs found in the data.

# Maps ("A", "B") -> SymPair("A", "B")
pair_database = {}

def add_entry(db, sym1, sym2):
    sym1_str = sym1.literal
    sym2_str = sym2.literal
    curr_pair = db.get((sym1_str, sym2_str), None)
    if curr_pair is None:
        curr_pair = SymPair(sym1_str, sym2_str)
        db[(sym1_str, sym2_str)] = curr_pair
    curr_pair.add_pos((sym1, sym2))


for seq in sequences:
    for sym1, sym2 in more_itertools.pairwise(seq):
        add_entry(pair_database, sym1, sym2)


merge_heap = MaxHeapMap()
for pair in pair_database.values():
    merge_heap.push(pair)

del pair_database

str(merge_heap.peek())

'Pair: (L, L), Count: 153309'

In [23]:
[str(elem) for elem in merge_heap.heap[:10]]

['Pair: (L, L), Count: 153309',
 'Pair: (S, S), Count: 135477',
 'Pair: (L, S), Count: 111026',
 'Pair: (L, A), Count: 93549',
 'Pair: (S, L), Count: 115163',
 'Pair: (A, L), Count: 99568',
 'Pair: (E, E), Count: 107112',
 'Pair: (S, G), Count: 82440',
 'Pair: (G, G), Count: 73809',
 'Pair: (A, A), Count: 98339']

Define the merging function.

This function is the core of the algorithm.

In [24]:
def merge_pair(sym_pair):
    # print(f"merging: {str(sym_pair)}")
    add_database = {}
    remove_database = {}
    for sym1, sym2 in sym_pair.positions:
        
        # Due to lazy removing, check if we are trying to merge an already removed pair
        if sym1.next.literal != sym_pair.right: continue
        if sym2.prev.literal != sym_pair.left: continue

        merged_sym = Sym(sym1.literal + sym2.literal)
        merged_sym.prev = sym1.prev
        merged_sym.next = sym2.next
        if sym1.prev is not None:
            pre = sym1.prev
            pre.next = merged_sym
            # print(f"Adding new entries to databases", pre, merged_sym)
            # print(f"Adding new entries to databases", pre, sym1)
            add_entry(add_database, pre, merged_sym)
            # Don't attempt to remove currently processed merge pair
            if SymPair(pre.literal, sym1.literal) != sym_pair: add_entry(remove_database, pre, sym1)
        if sym2.next is not None:
            nex = sym2.next
            nex.prev = merged_sym
            # print(f"Adding new entries to databases", pre, merged_sym)
            # print(f"Adding new entries to databases", pre, sym1)
            add_entry(add_database, merged_sym, nex)
             # Don't attempt to remove currently processed merge pair
            if SymPair(sym2.literal, nex.literal) != sym_pair: add_entry(remove_database, sym2, nex)
    # print(f"Sizes of the new databases add and remove: {len(add_database)}, {len(remove_database)}")
    for val in add_database.values():
        # print(f"adding: {str(val)}")
        merge_heap.push(val)
    for r_val in remove_database.values():
        # print(f"trying to remove: {str(r_val)}")
        inner_val = merge_heap.remove_by_value(r_val)
        inner_val.count -= r_val.count
        merge_heap.push(inner_val)



Add words by vocabulary size

In [25]:
# # %%prun -s cumulative
# init_vocab = alphabet.copy()

# # Add 10000 words
# for i in range(100):
#     best_pair = merge_heap.pop()
    
#     merge_pair(best_pair)
#     init_vocab.append(best_pair.left + best_pair.right)

# print(init_vocab)

Add words by cutoff frequency threshold

In [26]:
# # %%prun -s cumulative
# init_vocab = alphabet.copy()

# # Add 10000 words
# best_pair = merge_heap.pop()
# max_count = best_pair.count
# while best_pair.count > max_count / 1000:
#     merge_pair(best_pair)
#     init_vocab.append(best_pair.left + best_pair.right)
#     best_pair = merge_heap.pop()

# print(init_vocab)

In [27]:
# [print(str(k)) for k in merge_heap.map]
print(len(init_vocab))

cnts = Counter(init_vocab)

3815


## Code blocks related to addition of mutated sequences

In [28]:
from itertools import product

# Generates a sorted list of mutations for the given sequence.
# Each resulting element is a tuple (m_sequence, score), where:
# - m_sequence: the mutated sequence
# - score: the alignment score with the original sequence
# Results are sorted from highest to lowest score.
# Mutations are based only on substitutions with positive matrix values.
def generate_mutations(seq, matrix):
    alp = matrix.alphabet
    candidates = []

    # Create mutation candidates for each symbol in the sequence
    for i, aa in enumerate(seq):
        candidates.append([])
        # Consider substitutions with positive scores
        for c_aa in alp:
            score = matrix[aa][c_aa] - matrix[aa][aa]
            if score > -matrix[aa][aa]:
                candidates[i].append((c_aa, score))

    # Generate all possible mutated sequences and their scores
    combinations = product(*candidates)
    final_mutations = []
    for combination in combinations:
        m_sequence = ""
        total_score = 0
        for c_aa, score in combination:
            m_sequence += c_aa
            total_score += score
        final_mutations.append((m_sequence, total_score))
    
    # Sort mutations by score in descending order
    final_mutations.sort(key=lambda x: x[1], reverse=True)

    return final_mutations


Example usage of generate_mutations

In [29]:
from Bio.Align import substitution_matrices

blosum62 = substitution_matrices.load("BLOSUM62")
pam250 = substitution_matrices.load("PAM250")

print(generate_mutations("LL", blosum62))
print(generate_mutations("LL", pam250))


[('LL', 0.0), ('IL', -2.0), ('LI', -2.0), ('LM', -2.0), ('ML', -2.0), ('LV', -3.0), ('VL', -3.0), ('II', -4.0), ('IM', -4.0), ('MI', -4.0), ('MM', -4.0), ('IV', -5.0), ('MV', -5.0), ('VI', -5.0), ('VM', -5.0), ('VV', -6.0)]
[('LL', 0.0), ('LM', -2.0), ('ML', -2.0), ('IL', -4.0), ('LI', -4.0), ('LF', -4.0), ('LV', -4.0), ('MM', -4.0), ('FL', -4.0), ('VL', -4.0), ('IM', -6.0), ('MI', -6.0), ('MF', -6.0), ('MV', -6.0), ('FM', -6.0), ('VM', -6.0), ('II', -8.0), ('IF', -8.0), ('IV', -8.0), ('FI', -8.0), ('FF', -8.0), ('FV', -8.0), ('VI', -8.0), ('VF', -8.0), ('VV', -8.0)]


In [30]:
# Given a list of SymPairs and a string sequence seq_str
# Returns the SymPairs whose merge would result in the same string as seq_str
def search_existing_merge(sympair_list , seq_str):
    result = []
    for sympair in sympair_list:
        merged_pair = sympair.left + sympair.right
        if merged_pair == seq_str:
            result.append(sympair)
    return result


An example code block for also adding the mutated sequences during the training of BPE tokenizer

Currently, it adds all the mutations that are generated from positive values in the substitution matrix.

In [31]:
# %%prun -s cumulative
init_vocab = alphabet.copy()

# Add 100 words + mutations
for i in range(100):
    best_pair = merge_heap.pop()
    merge_pair(best_pair)
    init_vocab.append(best_pair.left + best_pair.right)

    # For the mutations:
    merged_string = best_pair.left + best_pair.right
    # [1:] ignores the original string
    mutations = generate_mutations(merged_string, blosum62)[1:]
    for mutated_str, score in mutations:
        pairs_to_merge = search_existing_merge(list(merge_heap.map.keys()), mutated_str)
        for pair in pairs_to_merge:
            merge_heap.remove_by_value(pair)
            merge_pair(pair)
        init_vocab.append(mutated_str)

print(init_vocab)

['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'X', 'Y', 'LL', 'IL', 'LI', 'LM', 'ML', 'LV', 'VL', 'II', 'IM', 'MI', 'MM', 'IV', 'MV', 'VI', 'VM', 'VV', 'SS', 'AS', 'NS', 'SA', 'SN', 'ST', 'TS', 'AA', 'AN', 'AT', 'NA', 'NN', 'NT', 'TA', 'TN', 'TT', 'EE', 'EZ', 'ZE', 'ZZ', 'DE', 'QE', 'ED', 'EQ', 'DZ', 'QZ', 'EK', 'EB', 'KE', 'BE', 'ZD', 'ZQ', 'KZ', 'BZ', 'ZK', 'ZB', 'DD', 'DQ', 'QD', 'QQ', 'DK', 'DB', 'QK', 'QB', 'KD', 'KQ', 'BD', 'BQ', 'KK', 'KB', 'BK', 'BB', 'PP', 'GG', 'SL', 'SI', 'SM', 'AL', 'NL', 'SV', 'TL', 'AI', 'AM', 'NI', 'NM', 'TI', 'TM', 'AV', 'NV', 'TV', 'PG', 'RR', 'RK', 'KR', 'RQ', 'QR', 'KK', 'QK', 'KQ', 'QQ', 'GL', 'GI', 'GM', 'GV', 'PL', 'PI', 'PM', 'PV', 'RL', 'RI', 'RM', 'RV', 'KL', 'QL', 'KI', 'KM', 'QI', 'QM', 'KV', 'QV', 'EL', 'ZL', 'EI', 'EM', 'DL', 'QL', 'EV', 'ZI', 'ZM', 'KL', 'BL', 'ZV', 'DI', 'DM', 'QI', 'QM', 'DV', 'QV', 'KI', 'KM', 'BI', 'BM', 'KV', 'BV', 'SG', 'AG', 'NG', 'TG', 'SP', 'AP', 'NP', 'TP', 'FL', 

In [33]:
# [print(str(k)) for k in merge_heap.map]
print(len(init_vocab))

cnts = Counter(init_vocab)
cnts

2803


Counter({'NY': 5,
         'NN': 4,
         'QQ': 4,
         'QK': 4,
         'KQ': 4,
         'KK': 4,
         'NK': 4,
         'NQ': 4,
         'YY': 4,
         'YK': 4,
         'YQ': 4,
         'YN': 4,
         'QS': 4,
         'KS': 4,
         'NP': 3,
         'NF': 3,
         'NW': 3,
         'KY': 3,
         'QY': 3,
         'NS': 2,
         'SN': 2,
         'AN': 2,
         'NA': 2,
         'NT': 2,
         'TN': 2,
         'QE': 2,
         'EQ': 2,
         'QZ': 2,
         'EK': 2,
         'KE': 2,
         'ZQ': 2,
         'KZ': 2,
         'ZK': 2,
         'DQ': 2,
         'QD': 2,
         'DK': 2,
         'QB': 2,
         'KD': 2,
         'BQ': 2,
         'KB': 2,
         'BK': 2,
         'NL': 2,
         'NI': 2,
         'NM': 2,
         'NV': 2,
         'RK': 2,
         'KR': 2,
         'RQ': 2,
         'QR': 2,
         'KL': 2,
         'QL': 2,
         'KI': 2,
         'KM': 2,
         'QI': 2,
         'QM': 2,
         '