In [1]:
from tiny_plm.model import SelfAttention
from tiny_plm.config import PLMConfig
import torch
import pandas as pd

In [2]:
# Smoke test: construct a SelfAttention module from a PLMConfig (n_head=8)
# and call it once on a random tensor.
config = PLMConfig(n_head=8)
attn_module = SelfAttention(config)

# NOTE(review): presumably x is (batch=1, seq_len=10, embed_dim=256) and 256
# matches the config's embedding size — TODO confirm against PLMConfig.
x = torch.rand(1, 10, 256)
y = attn_module(x)

In [30]:
# Lookup table from one-letter amino-acid code to integer token id.
# A letter's position in the string below is its id (A=0, R=1, ..., V=19),
# so the dict holds 20 entries inserted in that order.
AA_to_idx = {
    residue: position
    for position, residue in enumerate("ARNDCQEGHILKMFPSTWYV")
}


In [4]:
# Read only the first 10 rows of the CSV (nrows=10); later cells use its
# `sequence` column.
seqs = pd.read_csv('data/prok.csv', nrows=10)

In [35]:
# Take the `sequence` field of the first row as the working example.
seq = seqs.iloc[0].sequence

In [36]:
# Encode the sequence one character at a time via AA_to_idx.
# Any character that is not a key of AA_to_idx raises KeyError here.
toks = [AA_to_idx[a] for a in seq]

In [37]:
# Display the encoded token ids.
toks

[12, 11, 1, 9, 15, 16, 16, 9, 16, 16, 16, 9, 16, 9, 16, 16, 7, 2, 7, 0, 7]

In [38]:
def get_stats(ids):
    """Count occurrences of every adjacent pair in ``ids``.

    Args:
        ids: An indexable, sliceable sequence of hashable items
            (e.g. a list of token ids).

    Returns:
        A dict mapping ``(ids[k], ids[k + 1])`` to the number of positions
        ``k`` at which that pair occurs. Keys appear in order of first
        occurrence. Sequences with fewer than two items give ``{}``.
    """
    pair_counts = {}
    for left in range(len(ids) - 1):
        key = (ids[left], ids[left + 1])
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

# Tally adjacent-pair frequencies for the encoded sequence, then print them
# as (count, pair) tuples ordered from most to least frequent.
stats = get_stats(toks)
print(sorted(zip(stats.values(), stats.keys()), reverse=True))

[(4, (16, 16)), (3, (16, 9)), (3, (9, 16)), (1, (16, 7)), (1, (15, 16)), (1, (12, 11)), (1, (11, 1)), (1, (9, 15)), (1, (7, 2)), (1, (7, 0)), (1, (2, 7)), (1, (1, 9)), (1, (0, 7))]


In [39]:
# Pick the pair with the highest count. max() returns the first maximal key
# it encounters, so ties go to the pair that was inserted into `stats` first.
top_pair = max(stats, key=stats.get)
top_pair

(16, 16)

In [40]:
# Length of the first row's sequence, for comparison with the token list.
len(seqs.iloc[0].sequence)

21

In [41]:
def merge(ids, pair, idx):
    """Replace each non-overlapping occurrence of ``pair`` in ``ids`` with ``idx``.

    The sequence is scanned left to right; when positions ``k`` and ``k + 1``
    equal ``pair[0]`` and ``pair[1]``, a single ``idx`` is emitted and the
    scan resumes at ``k + 2``. Otherwise the item at ``k`` is copied through.

    Args:
        ids: Indexable sequence of tokens; it is not modified.
        pair: Two-item indexable holding the tokens to fuse.
        idx: Replacement token written for every match.

    Returns:
        A new list with the replacements applied.
    """
    out = []
    total = len(ids)
    pos = 0
    while pos < total:
        has_next = pos + 1 < total
        # pair is indexed lazily, only once a candidate position exists.
        if has_next and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
            out.append(idx)
            pos += 2
            continue
        out.append(ids[pos])
        pos += 1
    return out

In [42]:
# Demo: fuse every occurrence of the most frequent pair into the new id 20.
merge(toks, top_pair, 20)

[12, 11, 1, 9, 15, 20, 9, 20, 16, 9, 16, 9, 20, 7, 2, 7, 0, 7]

In [44]:
def bpe(tokens, vocab_size=10):
    """Learn byte-pair-encoding merges over a single token sequence.

    Each round finds the most frequent adjacent pair with ``get_stats``,
    replaces every occurrence of it with a fresh id using ``merge``, prints
    the merge, and records it.

    Args:
        tokens: Iterable of token ids. It is copied, never mutated.
        vocab_size: Controls how many merges are attempted:
            ``20 - vocab_size`` rounds (10 with the default).
            NOTE(review): this looks inverted relative to the usual
            "target size minus base size" rule, and new ids start at 21,
            leaving 20 unused. Both are kept as-is to preserve existing
            behaviour; confirm the intent before changing them.

    Returns:
        A ``(ids, merges)`` tuple: ``ids`` is the sequence after all merges
        and ``merges`` maps each merged pair to its new id, in the order
        the merges were learned.
    """
    num_merges = 20 - vocab_size
    ids = list(tokens)
    merges = {}
    for i in range(num_merges):
        stats = get_stats(ids)
        if not stats:
            # Fewer than two tokens remain, so there is nothing left to
            # merge; max() on the empty dict would raise ValueError.
            break
        pair = max(stats, key=stats.get)
        idx = 21 + i
        print(f"merging {pair} into {idx}")
        ids = merge(ids, pair, idx)
        merges[pair] = idx
    # Hand the results back so they can be used for encoding/decoding;
    # previously they were discarded when the function returned.
    return ids, merges

In [45]:
# Run BPE on the example sequence with the default vocab_size.
bpe(toks)

merging (16, 16) into 21
merging (9, 21) into 22
merging (12, 11) into 23
merging (23, 1) into 24
merging (24, 9) into 25
merging (25, 15) into 26
merging (26, 21) into 27
merging (27, 22) into 28
merging (28, 16) into 29
merging (29, 9) into 30


## Decoding BPE