In [None]:
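"""
Byte-pair encoding (BPE): learn subword tokens by repeatedly merging the most
frequent pair of adjacent symbols in a word-frequency corpus.

Refs:
    https://towardsdatascience.com/byte-pair-encoding-subword-based-tokenization-algorithm-77828a70bee0
"""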

from collections import Counter

import numpy as np
from spacy.lang.en import English


__author__ = "__Girish_Hegde__"

In [None]:
def get_token_freq(words):
    """Count how often each symbol occurs across the corpus.

    Args:
        words: mapping from a word (a string, or a tuple of symbols once
            merges have been applied) to its corpus frequency.

    Returns:
        dict mapping each symbol to the summed frequency of the words it
        appears in, counted once per occurrence. Keys appear in order of
        first occurrence.
    """
    counts = {}
    for word, freq in words.items():
        for symbol in word:
            counts[symbol] = counts.get(symbol, 0) + freq
    return counts

In [None]:
def get_pairs(words):
    """Count adjacent symbol pairs across the corpus.

    Args:
        words: mapping from a word (a string, or a tuple of symbols once
            merges have been applied) to its corpus frequency.

    Returns:
        (pairs, split_pairs):
            pairs: dict mapping the concatenation ``a + b`` of each adjacent
                pair to its frequency-weighted count. Different splits with
                the same concatenation (e.g. ``('a', 'bc')`` and
                ``('ab', 'c')``) share one entry.
            split_pairs: dict mapping each adjacent pair ``(a, b)`` to its
                frequency-weighted count.
    """
    pairs = dict()
    split_pairs = dict()
    for w, f in words.items():
        # zip(w, w[1:]) yields nothing for empty or single-symbol words.
        for prev, nxt in zip(w, w[1:]):
            # Update each dict on its own key: two splits can share a
            # concatenation, so "concatenation seen" does not imply
            # "this split seen".
            split_pairs[(prev, nxt)] = split_pairs.get((prev, nxt), 0) + f
            pairs[prev + nxt] = pairs.get(prev + nxt, 0) + f
    return pairs, split_pairs

In [None]:
def get_frequent(pairs, split_pairs):
    """Pick the most frequent adjacent pair.

    Args:
        pairs: concatenated-pair counts as returned by ``get_pairs``. Not
            used for the selection; kept so existing callers keep working.
        split_pairs: dict mapping each adjacent pair ``(a, b)`` to its count.

    Returns:
        (max_pair, max_split, fmax): the concatenation ``a + b``, the pair
        ``(a, b)`` itself, and its count. Ties go to the pair inserted
        first. Returns ``(None, None, 0)`` when no pair has a positive count.
    """
    max_pair, max_split, fmax = None, None, 0
    # Select from split_pairs alone. Zipping the keys of two separate dicts
    # only lines up while they were filled in lockstep with no shared
    # concatenations; otherwise a pair would be reported with the wrong split.
    for split, f in split_pairs.items():
        if f > fmax:
            max_pair, max_split, fmax = split[0] + split[1], split, f
    return max_pair, max_split, fmax

In [None]:
def update_tokens(tokens, pair, split, f):
    """Record one merge in the token-frequency table.

    Adds ``f`` occurrences of the merged token ``pair`` and removes ``f``
    occurrences of each of its parts. A part whose count drops to zero or
    below is removed from the table.

    Args:
        tokens: dict mapping token -> frequency. Modified in place.
        pair: the merged token (``split[0] + split[1]``), or None when there
            is nothing left to merge.
        split: the two tokens being merged.
        f: number of occurrences of the pair being merged.

    Returns:
        The same ``tokens`` dict.
    """
    if pair is None or split is None:
        # Nothing to merge (get_frequent found no pair): leave tokens as-is.
        return tokens
    # Use the `f` argument (not a global) and add to any existing count.
    tokens[pair] = tokens.get(pair, 0) + f
    for tk in split:
        # Both parts may be the same token (e.g. ('a', 'a')), so re-read the
        # current count on every pass instead of assuming it is still present.
        remaining = tokens.get(tk, 0) - f
        if remaining > 0:
            tokens[tk] = remaining
        else:
            # Overlapping pair counts (e.g. in 'aaa') can overshoot; never
            # keep a zero or negative count.
            tokens.pop(tk, None)
    return tokens

In [None]:
def update_words(words, pair):
    """Apply one BPE merge to every word in the corpus.

    Scans each word left to right. Every non-overlapping pair of adjacent
    symbols whose concatenation equals ``pair`` is replaced by the single
    symbol ``pair``.

    Args:
        words: mapping from a word (string or tuple of symbols) to its
            frequency.
        pair: the merged symbol, i.e. the concatenation of two adjacent
            symbols.

    Returns:
        A new dict mapping each word, as a tuple of symbols, to its
        frequency. The input mapping is not modified.
    """
    out = dict()
    for w, f in words.items():
        merged, i = [], 0
        while i < len(w):
            # Greedy left-to-right: in ('a', 'a', 'a') with pair 'aa' only
            # the first two symbols merge.
            if i + 1 < len(w) and w[i] + w[i + 1] == pair:
                merged.append(pair)
                i += 2
            else:
                merged.append(w[i])
                i += 1
        key = tuple(merged)
        # Distinct inputs can collapse to the same segmentation; keep the
        # total frequency rather than letting the last one win.
        out[key] = out.get(key, 0) + f
    return out


In [None]:
# Toy corpus from the referenced article: word -> corpus frequency.
toy_words = {'old': 7, 'older': 3, 'finest': 9, 'lowest': 4}
# Same vocabulary with uniform frequencies; this is the value `words` keeps.
# NOTE: the next cell rebuilds `words` from a real sentence, so this toy
# corpus is only used if that cell is skipped.
words = dict.fromkeys(toy_words, 1)

In [None]:
# Corpus: word -> frequency, from spaCy's rule-based tokenization of one sentence.
text = "BPE ensures that the most common words are represented in the vocabulary as a single token while the rare words are broken down into two or more subword tokens and this is in agreement with what a subword-based tokenization algorithm does."
tokenizer = English()
# Counter counts in a single pass and keeps first-occurrence order. Iterating
# a set of strings would vary between runs (str hash randomization), which
# changes how ties between equally frequent pairs are broken later on.
words = dict(Counter(token.text for token in tokenizer(text)))
words

In [None]:
# Initial vocabulary: every character in the corpus, weighted by word frequency.
tokens = get_token_freq(words)
tokens

In [None]:
# Learn up to `max_itr` merges; each merge adds one token and may retire its parts.
total_words, max_itr = len(words), 50  # distinct words in the corpus, merge budget
for itr in range(1, max_itr + 1):
    pairs, split_pairs = get_pairs(words)
    max_pair, max_split, fmax = get_frequent(pairs, split_pairs)
    if max_pair is None:
        # No adjacent pair remains (every word is a single symbol): stop early
        # instead of passing None splits into update_tokens.
        break
    tokens = update_tokens(tokens, max_pair, max_split, fmax)
    words = update_words(words, max_pair)
    print(f'{itr = }, n = {len(tokens)}, tokens = {tuple(tokens.keys())}')