In [4]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle
from collections import defaultdict

from Bio import Align
from Bio.Align import substitution_matrices

import json
from sklearn.model_selection import train_test_split
import time

In [5]:
def load_aligner(mode='global', sub_matrix='BLOSUM62', open_gap_score=-11, extend_gap_score=-1):
    '''
    Build a Biopython ``Align.PairwiseAligner`` from a named substitution matrix.

    Parameters
    ----------
    mode : str
        Alignment mode passed straight to PairwiseAligner (default 'global').
    sub_matrix : str
        Name understood by ``substitution_matrices.load`` (default 'BLOSUM62').
    open_gap_score, extend_gap_score : number
        Gap penalties passed straight to PairwiseAligner.

    Returns
    -------
    Align.PairwiseAligner

    Gap settings recorded alongside this helper (all with mode='global'):

        matrix     open_gap_score  extend_gap_score
        PAM30            -9              -1
        PAM70           -10              -1
        PAM250          -14              -2
        BLOSUM80        -10              -1
        BLOSUM62        -11              -1      <- defaults
        BLOSUM45        -15              -2
        BLOSUM50        -13              -2
        BLOSUM90        -10              -1
    '''
    matrix = substitution_matrices.load(sub_matrix)
    return Align.PairwiseAligner(
        mode=mode,
        open_gap_score=open_gap_score,
        extend_gap_score=extend_gap_score,
        substitution_matrix=matrix,
    )

In [6]:
def pair_scoring(token1, token2, aligner, cache=None):
    '''
    Return ``aligner.score(token1, token2)``, memoised by ``(token1, token2)``.

    Parameters
    ----------
    token1, token2 : str
        Sequences to score against each other.
    aligner : object
        Anything with a ``score(a, b)`` method (here a Bio.Align.PairwiseAligner).
    cache : dict, optional
        Memo table keyed by ``(token1, token2)``. When omitted, the notebook-level
        ``alignment_scores_cache`` is used, so existing callers behave as before.

    Returns
    -------
    The (possibly cached) alignment score.

    Notes
    -----
    The key does not include the aligner. Whenever the substitution matrix or the
    gap scores change, the cache must be reset, otherwise stale scores are returned.
    '''
    if cache is None:
        cache = alignment_scores_cache
    key = (token1, token2)
    if key not in cache:
        cache[key] = aligner.score(token1, token2)
    return cache[key]

In [7]:
def genetic_aa_mutation2(ref_pair, alphabet, pair_freqs, mutation_perc, sample_count, aligner=None):
    '''
    Propose up to ``sample_count`` random variants of ``ref_pair`` by uniform residue substitution.

    The procedure makes ``2 * sample_count`` attempts. Each attempt:
      - copies the concatenation ``ref_pair[0] + ref_pair[1]``;
      - for every position independently, with probability ``mutation_perc``,
        replaces the residue with one drawn uniformly from ``aligner.alphabet[:-4]``
        (the last four symbols of the matrix alphabet are excluded; for Biopython's
        BLOSUM62 these are presumably B, Z, X and '*' -- verify for other matrices);
      - splits the result back at ``len(ref_pair[0])``.

    Attempts identical to ``ref_pair`` are discarded. Duplicates are collapsed,
    keeping first-seen order.

    Parameters
    ----------
    ref_pair : tuple[str, str]
        Pair to mutate.
    alphabet : unused
        Kept only for call compatibility with ``genetic_aa_mutation``.
    pair_freqs : mapping
        (left, right) -> count. Read only; unseen pairs count as 0.
    mutation_perc : float
        Per-position mutation probability.
    sample_count : int
        Maximum number of variants returned.
    aligner : object, optional
        Provides ``alphabet``. Defaults to the notebook-level ``aligner``.

    Returns
    -------
    list of ((str, str), int)
        Variant pairs with their frequency. The content depends on the NumPy
        global RNG state, e.g. ``[(('AB', 'CN'), 0), ...]``.
    '''
    if aligner is None:
        aligner = globals()['aligner']
    residues = list(aligner.alphabet[:-4])  # loop-invariant, so build it once
    ref_joined = ''.join(ref_pair)
    split_at = len(ref_pair[0])

    samples = []
    for _ in range(sample_count * 2):
        new_sample = list(ref_pair[0] + ref_pair[1])
        for i in range(len(new_sample)):
            if np.random.rand() < mutation_perc:
                new_sample[i] = residues[np.random.randint(len(residues))]
        new_sample = ''.join(new_sample)
        if new_sample == ref_joined:
            continue
        pair = (new_sample[:split_at], new_sample[split_at:])
        # .get avoids inserting zero entries when pair_freqs is a defaultdict.
        samples.append((pair, pair_freqs.get(pair, 0)))
    # dict.fromkeys de-duplicates while preserving insertion order. set() order for str
    # tuples depends on PYTHONHASHSEED, which made the kept subset irreproducible.
    return list(dict.fromkeys(samples))[:sample_count]

def genetic_aa_mutation(ref_pair, alphabet, pair_freqs, mutation_perc, sample_count, aligner=None):
    '''
    Propose up to ``sample_count`` variants of ``ref_pair`` by substitution-matrix-weighted mutation.

    The procedure makes ``2 * sample_count`` attempts. Each attempt:
      - copies the concatenation ``ref_pair[0] + ref_pair[1]``;
      - for every position independently, with probability ``mutation_perc``,
        replaces the residue with one drawn from ``aligner.alphabet[:-4]``;
      - splits the result back at ``len(ref_pair[0])``.

    Replacement weights come from the substitution-matrix row of the current
    residue: the last four columns are dropped, then the row is shifted by
    ``abs(min) + 1`` so every weight is >= 1, then normalised. Likely substitutions
    are therefore favoured, but every residue stays possible.

    Attempts identical to ``ref_pair`` are discarded. Duplicates are collapsed,
    keeping first-seen order.

    Parameters
    ----------
    ref_pair : tuple[str, str]
        Pair to mutate.
    alphabet : unused
        Kept for call compatibility.
    pair_freqs : mapping
        (left, right) -> count. Read only; unseen pairs count as 0.
    mutation_perc : float
        Per-position mutation probability.
    sample_count : int
        Maximum number of variants returned.
    aligner : object, optional
        Provides ``alphabet`` and ``substitution_matrix``. Defaults to the
        notebook-level ``aligner``. Every residue in ``ref_pair`` must index a
        matrix row.

    Returns
    -------
    list of ((str, str), int)
        Variant pairs with their frequency. The content depends on the NumPy
        global RNG state.
    '''
    if aligner is None:
        aligner = globals()['aligner']
    residues = list(aligner.alphabet[:-4])  # loop-invariant, so build it once
    ref_joined = ''.join(ref_pair)
    split_at = len(ref_pair[0])

    samples = []
    for _ in range(sample_count * 2):
        new_sample = list(ref_pair[0] + ref_pair[1])
        for i in range(len(new_sample)):
            if np.random.rand() < mutation_perc:
                subs = aligner.substitution_matrix[new_sample[i]][:-4]
                subs = subs + abs(min(subs)) + 1  # make all weights strictly positive
                new_sample[i] = np.random.choice(residues, p=subs / sum(subs))
        new_sample = ''.join(new_sample)
        if new_sample == ref_joined:
            continue
        pair = (new_sample[:split_at], new_sample[split_at:])
        # .get avoids inserting zero entries when pair_freqs is a defaultdict.
        samples.append((pair, pair_freqs.get(pair, 0)))
    # dict.fromkeys de-duplicates while preserving insertion order. set() order for str
    # tuples depends on PYTHONHASHSEED, which made the kept subset irreproducible.
    return list(dict.fromkeys(samples))[:sample_count]

In [8]:
def choice_of_freqs(pair_freqs_sorted_list, sample_count, search_size=None):
    '''
    Draw distinct pairs at random, with probability proportional to their frequency.

    Parameters
    ----------
    pair_freqs_sorted_list : list of ((str, str), int)
        Candidate entries ``(pair, frequency)``.
    sample_count : int
        Number of entries to draw. Capped at the number of eligible entries.
    search_size : int or None
        Only the first ``search_size`` entries are eligible (None = all).

    Returns
    -------
    list of ((str, str), int)
        Drawn without replacement, using the NumPy global RNG. Returns [] when no
        entry is eligible.

    Raises
    ------
    ValueError
        Raised by ``np.random.choice`` if the eligible frequencies cannot support
        the draw, e.g. all zero, or fewer non-zero entries than the capped
        ``sample_count``.

    Example::

        choice_of_freqs(pair_freqs_sorted_list, sample_count=2)
        [(('Y', 'Y'), 12), (('C', 'YN'), 10)]
    '''
    candidates = pair_freqs_sorted_list[:search_size]
    if len(candidates) == 0:
        return []
    freqs = [pf[1] for pf in candidates]
    # Cap on the searched window, not the full list: with replace=False,
    # np.random.choice raises if size exceeds the population.
    size = min(sample_count, len(candidates))
    index_choices = np.random.choice(
        list(range(len(candidates))),
        size=size,
        replace=False,
        p=np.asarray(freqs) / np.sum(freqs),
    )
    return [candidates[ic] for ic in index_choices]

In [9]:
def calculate_alignment_scores(ref_pair, comp_pairs_list, aligner):
    '''
    Score every candidate pair against the reference pair by aligning the joined tokens.

    ``ref_pair`` and each element of ``comp_pairs_list`` are
    ``((left_token, right_token), frequency)``. For each candidate, the two tokens
    are concatenated on both sides and scored with ``pair_scoring`` (memoised
    ``aligner.score``).

    Returns
    -------
    list of 4-tuples, in input order::

        (ref_pair, comp_pair, align_score,
         (np.log(ref_freq) + np.log(comp_freq)) * align_score)

    Frequencies go through ``np.log``, so a frequency of 0 yields ``-inf`` (with a
    NumPy RuntimeWarning). Callers should pass only entries with frequency > 0.
    '''
    ref_seq = ''.join(ref_pair[0])
    results = []
    for comp_pair in comp_pairs_list:
        align_score = pair_scoring(ref_seq, ''.join(comp_pair[0]), aligner)
        weighted = (np.log(ref_pair[1]) + np.log(comp_pair[1])) * align_score
        results.append((ref_pair, comp_pair, align_score, weighted))
    return results

In [10]:
def calculate_alignment_scores_gen(ref_pair, comp_pairs_list, aligner):
    '''
    Like ``calculate_alignment_scores``, but align only the positions that differ.

    The joined reference tokens and the joined candidate tokens are compared
    position by position (``zip``, so the shorter length wins). Only the residues
    at mismatching positions are kept, and those two residue strings are scored
    with ``pair_scoring``. This fits candidates produced by point mutation of the
    reference, where both strings have the same length.

    Returns
    -------
    list of 4-tuples, in input order::

        (ref_pair, comp_pair, align_score,
         (np.log(ref_freq) + np.log(comp_freq)) * align_score)

    A frequency of 0 yields ``-inf`` through ``np.log``. Callers should pass only
    entries with frequency > 0.
    '''
    def _mismatched_residues(seq_a, seq_b):
        # Residues of each sequence at the positions where the two differ.
        kept = [(x, y) for x, y in zip(seq_a, seq_b) if x != y]
        return ''.join(x for x, _ in kept), ''.join(y for _, y in kept)

    # Materialise the diffs in one pass so comp_pairs_list is iterated only once.
    staged = []
    for comp_pair in comp_pairs_list:
        diff_ref, diff_comp = _mismatched_residues(''.join(ref_pair[0]), ''.join(comp_pair[0]))
        staged.append((comp_pair, diff_ref, diff_comp))

    scored = [(ref_pair, comp_pair, pair_scoring(diff_ref, diff_comp, aligner))
              for comp_pair, diff_ref, diff_comp in staged]
    return [(*entry, (np.log(ref_pair[1]) + np.log(entry[1][1])) * entry[2]) for entry in scored]

In [11]:
def compute_pair_freqs(splits, word_freqs=None):
    '''
    Count adjacent token pairs over the corpus, weighted by word frequency.

    Parameters
    ----------
    splits : dict[str, list[str]]
        Current tokenisation of each word (word -> list of tokens). Every word in
        ``word_freqs`` must have an entry.
    word_freqs : mapping[str, int], optional
        Word -> occurrence count. Defaults to the notebook-level ``word_freqs``, so
        existing callers behave as before.

    Returns
    -------
    collections.defaultdict(int)
        (left_token, right_token) -> summed frequency. Keys are inserted in
        first-seen order (walking words in ``word_freqs`` order). Looking up an
        unseen pair returns 0 and inserts it.
    '''
    if word_freqs is None:
        word_freqs = globals()['word_freqs']
    pair_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        tokens = splits[word]
        # Consecutive token pairs; single-token words contribute nothing.
        for pair in zip(tokens, tokens[1:]):
            pair_freqs[pair] += freq
    return pair_freqs

def merge_pair(a, b, splits, word_freqs=None):
    '''
    Replace every adjacent ``(a, b)`` token pair with the single token ``a + b``.

    Matching is greedy, left to right and non-overlapping: in ``['A', 'A', 'A']``,
    merging ``('A', 'A')`` gives ``['AA', 'A']``. Each word's token list is rebuilt
    in a single pass (O(len) per word), instead of re-slicing the list at every hit.

    Parameters
    ----------
    a, b : str
        Left and right token of the pair to merge. Assumed non-empty.
    splits : dict[str, list[str]]
        Word -> token list. Updated in place.
    word_freqs : mapping, optional
        Provides the words to process (its keys). Defaults to the notebook-level
        ``word_freqs``.

    Returns
    -------
    dict
        The same ``splits`` object, for call-chaining compatibility.
    '''
    if word_freqs is None:
        word_freqs = globals()['word_freqs']
    merged_token = a + b
    for word in word_freqs:
        tokens = splits[word]
        if len(tokens) == 1:
            continue
        merged = []
        i = 0
        n = len(tokens)
        while i < n:
            if i + 1 < n and tokens[i] == a and tokens[i + 1] == b:
                merged.append(merged_token)
                i += 2  # both tokens consumed; never re-match the merged one
            else:
                merged.append(tokens[i])
                i += 1
        splits[word] = merged
    return splits

In [14]:
# UniRef export for human (taxonomy id 9606); only its 'results' list is kept.
with open('uniref_taxonomy_id_9606_AND_identity_2024_09_13.json') as f:
    # alternative location: '../../RSRC/ECCB/uniref_taxonomy_id_9606_AND_identity_2024_09_13.json'
    human_proteins_json = json.load(f)['results']

# One row per result: its id and the sequence of its representative member.
human_proteins_df = pd.DataFrame([
    {'id': prot['id'], 'sequence': prot['representativeMember']['sequence']['value']}
    for prot in human_proteins_json
])
# Drop sequences containing 'U' (selenocysteine); presumably because the
# substitution matrices used later have no 'U' row -- confirm.
human_proteins_df = human_proteins_df.loc[~human_proteins_df['sequence'].str.contains('U')]

# 80/20 split with a fixed seed; the tokenizer is trained on the training sequences only.
df_ds_train, df_ds_test = train_test_split(human_proteins_df, test_size=0.2, random_state=42)
corpus = df_ds_train['sequence']

In [24]:
# Occurrence count of every distinct sequence in the training corpus
# (identical sequences are counted together).
word_freqs = defaultdict(int)
for sequence in corpus:
    word_freqs[sequence] += 1

# Base alphabet: every distinct character seen in the corpus, sorted.
alphabet = sorted({residue for seq in word_freqs for residue in seq})

print(alphabet)

['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'X', 'Y']


In [25]:
# Alignment-aware BPE training.
# Every iteration merges the most frequent adjacent pair (the "best pair").
# If the best pair's merged token is at least `method_min_thr` characters long,
# up to `extra_merge_count` further pairs are merged in the same iteration:
#   - candidates come from `candidate_method`;
#   - they are ranked by alignment score against the best pair ('align'), or by
#     log-frequency-weighted alignment score ('freq_align').
np.random.seed(10)
vocab_size = 50
merges = {}
vocab = alphabet.copy()
splits = {word: list(word) for word in word_freqs.keys()}

method_min_thr = 3       # min merged-token length of the best pair before extra merges are considered
mutation_perc = .33      # per-residue mutation probability for the 'genetic' candidate generator
sample_count = 100       # candidate pairs generated / drawn per iteration
extra_merge_count = 1    # extra merges per iteration on top of the best pair
mode='global'
sub_matrix='BLOSUM62'
open_gap_score=-11
extend_gap_score=-1
candidate_method = 'genetic' # 'genetic', 'choice', 'freqs'
scoring_method = 'align' # 'align', 'freq_align'

# Shared suffix of the printed run name and of every pickle file name.
run_tag = f'{candidate_method}_{scoring_method}_{sample_count}_{extra_merge_count}_{mutation_perc}_{method_min_thr}_{sub_matrix}'
print(f'{vocab_size}_{run_tag}')

# Scores are memoised by token pair only, so the cache is reset together with the aligner.
alignment_scores_cache = defaultdict(int)
aligner = load_aligner(mode=mode, sub_matrix=sub_matrix, open_gap_score=open_gap_score, extend_gap_score=extend_gap_score)


def _save_vocab_and_merges(size_label):
    '''Pickle the current `vocab` and `merges`; both file names carry `size_label` and `run_tag`.'''
    with open(f'vocab_bpe_human_{size_label}_{run_tag}.pickle', 'wb') as f:
        pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
    with open(f'merges_bpe_human_{size_label}_{run_tag}.pickle', 'wb') as f:
        pickle.dump(merges, f, pickle.HIGHEST_PROTOCOL)


with tqdm(total=vocab_size-len(vocab)) as pbar:
    while len(vocab) < vocab_size:
        pair_freqs = compute_pair_freqs(splits)
        if not pair_freqs:
            # Every word is already a single token; pair_freqs_sorted_list[0] would raise IndexError.
            print(f'No adjacent pairs left to merge; stopping at {len(vocab)} tokens.')
            break

        # Most frequent first; ties go to the shorter merged token, then to first occurrence (sorted is stable).
        pair_freqs_sorted_list = sorted(pair_freqs.items(), key=lambda x: (-x[1], len(x[0][0] + x[0][1])))
        best_pair = pair_freqs_sorted_list[0]

        if len(best_pair[0][0]) + len(best_pair[0][1]) >= method_min_thr:
            if candidate_method == 'genetic':
                candidate_pairs = genetic_aa_mutation(best_pair[0], alphabet, pair_freqs, mutation_perc, sample_count)
            elif candidate_method == 'choice':
                candidate_pairs = choice_of_freqs(pair_freqs_sorted_list[1:], sample_count)
            else:
                candidate_pairs = pair_freqs_sorted_list[1:1+sample_count]

            # Only pairs that actually occur can be merged (and np.log in the scorers needs freq > 0).
            candidate_pairs = [cp for cp in candidate_pairs if cp[1] > 0]

            if candidate_method == 'genetic':
                # Mutated candidates: align only the positions that differ from the best pair.
                alignment_scores = calculate_alignment_scores_gen(best_pair, candidate_pairs, aligner)
            else:
                alignment_scores = calculate_alignment_scores(best_pair, candidate_pairs, aligner)

            alignment_scores = [s for s in alignment_scores if s[2] > 0]

            if scoring_method == 'align':
                alignment_scores = sorted(alignment_scores, key=lambda x: (-x[2], -x[3])) # alignment
            else:
                alignment_scores = sorted(alignment_scores, key=lambda x: (-x[3], -x[2])) # frequency * alignment

            new_pairs = [best_pair[0]] + [s[1][0] for s in alignment_scores[:extra_merge_count]]
        else:
            new_pairs = [best_pair[0]]

        # Extra merges must not push the vocabulary past vocab_size (the best pair always comes first).
        new_pairs = new_pairs[:vocab_size - len(vocab)]

        # NOTE(review): extra merges are applied after the best pair has been merged, and are not
        # checked against existing vocab entries. A candidate may therefore no longer occur in `splits`,
        # or duplicate an existing token string; it is still appended to `vocab`.
        vocab_len_before = len(vocab)
        for new_pair in new_pairs:
            splits = merge_pair(*new_pair, splits)
            merges[new_pair] = ''.join(new_pair)
            vocab.append(''.join(new_pair))

        # Checkpoint whenever the vocab size reaches or passes a value with n % 1000 == 25.
        # With extra merges the size can grow by more than one per iteration, so an exact
        # `len(vocab) % 1000 == 25` test could miss the checkpoint.
        if any(n % 1000 == 25 for n in range(vocab_len_before + 1, len(vocab) + 1)):
            _save_vocab_and_merges(len(vocab))

        pbar.update(len(new_pairs))

_save_vocab_and_merges(vocab_size)

50_genetic_freq_align_100_1_0.33_3_BLOSUM62


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [01:40<00:00,  3.45s/it]
