Import dependencies

In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle
from collections import defaultdict, Counter
import more_itertools
import heapq
from Bio import Align
from Bio.Align import substitution_matrices
import uuid
import json
from sklearn.model_selection import train_test_split
import time

Load sequence data

In [2]:
with open('uniref_taxonomy_id_9606_AND_identity_2024_09_13.json') as f:
# with open('../../RSRC/ECCB/uniref_taxonomy_id_9606_AND_identity_2024_09_13.json') as f:
    human_proteins_json = json.load(f)['results']
    
human_proteins_df = []
for prot in human_proteins_json:
    human_proteins_df.append({'id': prot['id'], 'sequence': prot['representativeMember']['sequence']['value']})
human_proteins_df = pd.DataFrame(human_proteins_df)
human_proteins_df = human_proteins_df[~human_proteins_df['sequence'].str.contains('U')]

df_ds_train, df_ds_test = train_test_split(human_proteins_df, test_size=0.2, random_state=42)

corpus = df_ds_train['sequence']

Extract the initial symbols

In [3]:
alphabet = []
for seq in corpus:
    for letter in seq:
        if letter not in alphabet:
            alphabet.append(letter)
alphabet.sort()
print(alphabet)

['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'X', 'Y']


Import custom classes for efficient BPE

In [4]:
from helper_classes import Sym, SymList, SymPair, MaxHeapMap

Convert every protein sequence into a SymList (a doubly linked list containing tokens as nodes)

In [5]:
sequences = []
for seq in corpus:
    if len(seq) == 0: continue
    symlist = SymList()
    for sym_str in seq:
        symlist.append(Sym(sym_str))
    sequences.append(symlist)


In [6]:
print(sequences[0])

MEVLRRSSVFAAEIMDAFDRCGDAADGLMSSSVWSAQTLASAPTGWWLHSAASAAS


Construct the initial MaxHeapMap

This data structure maintains a max heap of SymbolPairs with respect to their occurence counts.
- A SymbolPair is basically a struct that keeps track of every occurence of a token pair in the corpus

In [7]:
# Generate the data structure for bookkeeping of all symbol pairs found in the data.

# Maps ("A", "B") -> SymPair("A", "B")
pair_database = {}

def add_entry(db, sym1, sym2):
    sym1_str = sym1.literal
    sym2_str = sym2.literal
    curr_pair = db.get((sym1_str, sym2_str), None)
    if curr_pair is None:
        curr_pair = SymPair(sym1_str, sym2_str)
        db[(sym1_str, sym2_str)] = curr_pair
    curr_pair.add_pos((sym1, sym2))


for seq in sequences:
    for sym1, sym2 in more_itertools.pairwise(seq):
        add_entry(pair_database, sym1, sym2)


merge_heap = MaxHeapMap()
for pair in pair_database.values():
    merge_heap.push(pair)

del pair_database

str(merge_heap.peek())

'Pair: (L, L), Count: 153309'

In [8]:
[str(elem) for elem in merge_heap.heap[:10]]

['Pair: (L, L), Count: 153309',
 'Pair: (S, S), Count: 135477',
 'Pair: (L, S), Count: 111026',
 'Pair: (L, A), Count: 93549',
 'Pair: (S, L), Count: 115163',
 'Pair: (A, L), Count: 99568',
 'Pair: (E, E), Count: 107112',
 'Pair: (S, G), Count: 82440',
 'Pair: (G, G), Count: 73809',
 'Pair: (A, A), Count: 98339']

Define the merging function.

This function is the core of the algorithm.

In [9]:
def merge_pair(sym_pair):
    # print(f"merging: {str(sym_pair)}")
    add_database = {}
    remove_database = {}
    for sym1, sym2 in sym_pair.positions:
        
        # Due to lazy removing, check if we are trying to merge an already removed pair
        if sym1.next.literal != sym_pair.right: continue
        if sym2.prev.literal != sym_pair.left: continue

        merged_sym = Sym(sym1.literal + sym2.literal)
        merged_sym.prev = sym1.prev
        merged_sym.next = sym2.next
        if sym1.prev is not None:
            pre = sym1.prev
            pre.next = merged_sym
            # print(f"Adding new entries to databases", pre, merged_sym)
            # print(f"Adding new entries to databases", pre, sym1)
            add_entry(add_database, pre, merged_sym)
            # Don't attempt to remove currently processed merge pair
            if SymPair(pre.literal, sym1.literal) != sym_pair: add_entry(remove_database, pre, sym1)
        if sym2.next is not None:
            nex = sym2.next
            nex.prev = merged_sym
            # print(f"Adding new entries to databases", pre, merged_sym)
            # print(f"Adding new entries to databases", pre, sym1)
            add_entry(add_database, merged_sym, nex)
             # Don't attempt to remove currently processed merge pair
            if SymPair(sym2.literal, nex.literal) != sym_pair: add_entry(remove_database, sym2, nex)
    # print(f"Sizes of the new databases add and remove: {len(add_database)}, {len(remove_database)}")
    for val in add_database.values():
        # print(f"adding: {str(val)}")
        merge_heap.push(val)
    for r_val in remove_database.values():
        # print(f"trying to remove: {str(r_val)}")
        inner_val = merge_heap.remove_by_value(r_val)
        inner_val.count -= r_val.count
        merge_heap.push(inner_val)



Add words by vocabulary size

In [10]:
# # %%prun -s cumulative
# init_vocab = alphabet.copy()

# # Add 10000 words
# for i in range(100):
#     best_pair = merge_heap.pop()
    
#     merge_pair(best_pair)
#     init_vocab.append(best_pair.merged())

# print(init_vocab)

Add words by cutoff frequency threshold

In [11]:
# # %%prun -s cumulative
# init_vocab = alphabet.copy()

# # Add 10000 words
# best_pair = merge_heap.pop()
# max_count = best_pair.count
# while best_pair.count > max_count / 1000:
#     merge_pair(best_pair)
#     init_vocab.append(best_pair.merged())
#     best_pair = merge_heap.pop()

# print(init_vocab)

In [12]:
# [print(str(k)) for k in merge_heap.map]
# print(len(init_vocab))

# cnts = Counter(init_vocab)

## Code blocks related to addition of mutated sequences

In [13]:
from itertools import product

# Generates a sorted list of mutations for the given sequence.
# Each resulting element is a tuple (m_sequence, score), where:
# - m_sequence: the mutated sequence
# - score: similarity score to the original sequence, a number between 0 and 1
# Results are sorted from highest to lowest score.
# First element of the results is the original sequence
# Mutations are based only on substitutions with non-negative matrix values.
def generate_mutations(seq, matrix):
    alp = matrix.alphabet
    candidates = []

    max_score = 0
    for aa in seq:
        max_score += matrix[aa][aa]

    # Create mutation candidates for each symbol in the sequence
    for i, aa in enumerate(seq):
        candidates.append([])
        # Consider substitutions with non-negative scores
        for c_aa in alp:
            score = matrix[aa][c_aa]
            if score > 0:
                candidates[i].append((c_aa, score))

    # Generate all possible mutated sequences and their scores
    combinations = product(*candidates)
    final_mutations = []
    for combination in combinations:
        m_sequence = ""
        total_score = 0
        for c_aa, score in combination:
            m_sequence += c_aa
            total_score += score
        # Normalize the total score by max_score
        final_mutations.append((m_sequence, round(total_score / max_score, 3)))
    
    # Sort mutations by score in descending order
    final_mutations.sort(key=lambda x: x[1], reverse=True)

    return final_mutations


Example usage of generate_mutations

In [14]:
from Bio.Align import substitution_matrices

blosum62 = substitution_matrices.load("BLOSUM62")
pam250 = substitution_matrices.load("PAM250")

print(generate_mutations("LLDFLL", blosum62))
print(generate_mutations("LL", pam250))


[('LLDFLL', 1.0), ('ILDFLL', 0.929), ('LIDFLL', 0.929), ('LLDFIL', 0.929), ('LLDFLI', 0.929), ('LLDFLM', 0.929), ('LLDFML', 0.929), ('LLBFLL', 0.929), ('LMDFLL', 0.929), ('MLDFLL', 0.929), ('LLDFLV', 0.893), ('LLDFVL', 0.893), ('LLDYLL', 0.893), ('LVDFLL', 0.893), ('VLDFLL', 0.893), ('IIDFLL', 0.857), ('ILDFIL', 0.857), ('ILDFLI', 0.857), ('ILDFLM', 0.857), ('ILDFML', 0.857), ('ILBFLL', 0.857), ('IMDFLL', 0.857), ('LIDFIL', 0.857), ('LIDFLI', 0.857), ('LIDFLM', 0.857), ('LIDFML', 0.857), ('LIBFLL', 0.857), ('LLDFII', 0.857), ('LLDFIM', 0.857), ('LLDFMI', 0.857), ('LLDFMM', 0.857), ('LLEFLL', 0.857), ('LLBFIL', 0.857), ('LLBFLI', 0.857), ('LLBFLM', 0.857), ('LLBFML', 0.857), ('LMDFIL', 0.857), ('LMDFLI', 0.857), ('LMDFLM', 0.857), ('LMDFML', 0.857), ('LMBFLL', 0.857), ('MIDFLL', 0.857), ('MLDFIL', 0.857), ('MLDFLI', 0.857), ('MLDFLM', 0.857), ('MLDFML', 0.857), ('MLBFLL', 0.857), ('MMDFLL', 0.857), ('ILDFLV', 0.821), ('ILDFVL', 0.821), ('ILDYLL', 0.821), ('IVDFLL', 0.821), ('LIDFLV', 0.

In [15]:
# Given a list of SymPairs and a string sequence seq_str
# Returns the SymPairs whose merge would result in the same string as seq_str
def search_existing_merge(sympair_list , seq_str):
    result = []
    for sympair in sympair_list:
        merged_pair = sympair.merged()
        if merged_pair == seq_str:
            result.append(sympair)
    return result


An example code block for also adding the mutated sequences during the training of BPE tokenizer

Currently, it adds **all** the mutations that are generated from positive values in the substitution matrix.

In [None]:
# %%prun -s cumulative
init_vocab = alphabet.copy()

# Add 100 words + their mutations
for i in range(100):
    best_pair = merge_heap.pop()
    merge_pair(best_pair)
    merged_string = best_pair.merged()
    init_vocab.append(merged_string)

    # For the mutations:
    
    # [1:] ignores the original string
    mutations = generate_mutations(merged_string, blosum62)[1:]
    # Consider only the mutations with a similarity score larger than 0.8
    mutations = [(mut, sc) for mut, sc in mutations if sc > 0.8]
    for mutated_str, score in mutations:
        pairs_to_merge = search_existing_merge(list(merge_heap.map.keys()), mutated_str)
        for pair in pairs_to_merge:
            merge_heap.remove_by_value(pair)
            merge_pair(pair)
        if len(pairs_to_merge) > 0:
            init_vocab.append(mutated_str)

print(init_vocab)

['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'X', 'Y', 'LL', 'SS', 'EE', 'AA', 'SL', 'PP', 'AL', 'VL', 'IL', 'GL', 'SP', 'GG', 'EL', 'RL', 'TL', 'EK', 'SG', 'RR', 'DL', 'KK', 'QL', 'PL', 'PI', 'PM', 'SA', 'PG', 'SV', 'SI', 'ST', 'EA', 'SR', 'FL', 'KL', 'ED', 'PA', 'SQ', 'EV', 'EI', 'EG', 'NL', 'TV', 'TI', 'AV', 'AI', 'SD', 'SK', 'TG', 'ER', 'PV', 'AG', 'HL', 'HI', 'HM', 'PR', 'QQ', 'SF', 'EN', 'RG', 'TT', 'VV', 'IV', 'VI', 'DG', 'EQ', 'RK', 'YL', 'YI', 'YM', 'PQ', 'TA', 'SN', 'CL', 'CI', 'CM', 'KG', 'DV', 'DI', 'ET', 'KA', 'KV', 'KI', 'RA', 'RV', 'RI', 'QG', 'QA', 'PT', 'PD', 'SH', 'QV', 'QI', 'FG', 'CG', 'KD', 'NG', 'SE', 'ML', 'RD', 'FV', 'FI', 'SY', 'KT', 'SC', 'NI', 'NV', 'RT', 'PE', 'PK', 'QD', 'AT', 'QK', 'HG', 'QR', 'YG', 'PF', 'DD', 'SM', 'VG', 'IG', 'FT', 'WL', 'WI', 'WM', 'QT', 'KN', 'AD', 'PN', 'AR', 'KR', 'PY', 'EM', 'EF', 'HV', 'QN', 'TD', 'LLL', 'ILL', 'LIL', 'LLI', 'LLM', 'LML', 'MLL', 'YV', 'RN', 'SW', 'PH', 'CV', 'AF',

         113316287 function calls (113170590 primitive calls) in 50.951 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   50.951   50.951 {built-in method builtins.exec}
        1    1.841    1.841   50.951   50.951 <string>:4(<module>)
      317   24.514    0.077   45.674    0.144 2737463463.py:3(search_existing_merge)
 99643900   21.160    0.000   21.160    0.000 helper_classes.py:23(merged)
      275    0.854    0.003    3.401    0.012 701220899.py:1(merge_pair)
   118086    0.167    0.000    1.096    0.000 helper_classes.py:105(remove_by_value)
  1080016    0.339    0.000    0.709    0.000 598563906.py:6(add_entry)
263883/118186    0.205    0.000    0.470    0.000 helper_classes.py:126(_heapify_down)
   354107    0.155    0.000    0.430    0.000 helper_classes.py:117(_heapify_up)
   478853    0.261    0.000    0.406    0.000 helper_classes.py:74(_swap)
   236021    0.141    0.000    0.367  

371516

In [23]:
# For debugging purposes, to see if any duplicates are added
print(len(init_vocab))

cnts = Counter(init_vocab)
cnts

973


Counter({'A': 1,
         'C': 1,
         'D': 1,
         'E': 1,
         'F': 1,
         'G': 1,
         'H': 1,
         'I': 1,
         'K': 1,
         'L': 1,
         'M': 1,
         'N': 1,
         'P': 1,
         'Q': 1,
         'R': 1,
         'S': 1,
         'T': 1,
         'V': 1,
         'W': 1,
         'X': 1,
         'Y': 1,
         'LL': 1,
         'SS': 1,
         'EE': 1,
         'AA': 1,
         'SL': 1,
         'PP': 1,
         'AL': 1,
         'VL': 1,
         'IL': 1,
         'GL': 1,
         'SP': 1,
         'GG': 1,
         'EL': 1,
         'RL': 1,
         'TL': 1,
         'EK': 1,
         'SG': 1,
         'RR': 1,
         'DL': 1,
         'KK': 1,
         'QL': 1,
         'PL': 1,
         'PI': 1,
         'PM': 1,
         'SA': 1,
         'PG': 1,
         'SV': 1,
         'SI': 1,
         'ST': 1,
         'EA': 1,
         'SR': 1,
         'FL': 1,
         'KL': 1,
         'ED': 1,
         'PA': 1,
         'SQ'