In [1]:
from collections import defaultdict
import re

In [2]:
def initialize_vocabulary(corpus: list[str]):
    """Split every word of ``corpus`` into space-separated single characters.

    Each word is prefixed with the word-boundary marker ``'_'`` before it is
    split, so ``"fun"`` becomes the vocabulary key ``'_ f u n'``.

    Args:
        corpus: Words to build the initial (character-level) vocabulary from.

    Returns:
        A ``(vocabulary, charset)`` tuple. ``vocabulary`` is a
        ``defaultdict(int)`` mapping each space-separated character sequence
        to the number of times that word occurs in ``corpus``. ``charset`` is
        the set of every character seen, including the ``'_'`` marker.
    """
    charset: set[str] = set()
    vocabulary: defaultdict[str, int] = defaultdict(int)
    for marked_word in ('_' + raw_word for raw_word in corpus):
        charset |= set(marked_word)
        # str.join over a string interleaves the separator between its characters
        vocabulary[' '.join(marked_word)] += 1
    return vocabulary, charset


def get_pair_counts(vocabulary: dict[str, int]):
    """Count adjacent token pairs across a tokenized vocabulary.

    Args:
        vocabulary: Maps space-separated token sequences to word frequencies.

    Returns:
        A ``defaultdict(int)`` mapping each ``(left, right)`` pair of adjacent
        tokens to the sum, over every occurrence of that pair, of the
        frequency of the word it occurs in. Keys appear in the order the pairs
        are first encountered.
    """
    pair_counts: defaultdict[tuple[str, str], int] = defaultdict(int)
    for tokenized_word, frequency in vocabulary.items():
        symbols = tokenized_word.split()
        # zip with the shifted list yields each adjacent (left, right) pair
        for left, right in zip(symbols, symbols[1:]):
            pair_counts[left, right] += frequency
    return pair_counts

In [3]:
# Build the character-level vocabulary for a small demo corpus, then list the
# adjacent-pair frequencies that the first BPE merge will choose from.
vocabulary, charset = initialize_vocabulary(
    corpus=["Movies", "are", "fun", "for", "everyone", "every", "time", "one"]
)
print("vocabulary:", vocabulary)
print("charset:", charset)
print("----- get_pair_counts(vocabulary) ----")
for pair_count in get_pair_counts(vocabulary).items():
    print(pair_count)

vocabulary: defaultdict(<class 'int'>, {'_ M o v i e s': 1, '_ a r e': 1, '_ f u n': 1, '_ f o r': 1, '_ e v e r y o n e': 1, '_ e v e r y': 1, '_ t i m e': 1, '_ o n e': 1})
charset: {'e', 'M', '_', 'v', 'i', 'a', 'f', 's', 'r', 'o', 'n', 'y', 't', 'm', 'u'}
----- get_pair_counts(vocabulary) ----
(('_', 'M'), 1)
(('M', 'o'), 1)
(('o', 'v'), 1)
(('v', 'i'), 1)
(('i', 'e'), 1)
(('e', 's'), 1)
(('_', 'a'), 1)
(('a', 'r'), 1)
(('r', 'e'), 1)
(('_', 'f'), 2)
(('f', 'u'), 1)
(('u', 'n'), 1)
(('f', 'o'), 1)
(('o', 'r'), 1)
(('_', 'e'), 2)
(('e', 'v'), 2)
(('v', 'e'), 2)
(('e', 'r'), 2)
(('r', 'y'), 2)
(('y', 'o'), 1)
(('o', 'n'), 2)
(('n', 'e'), 2)
(('_', 't'), 1)
(('t', 'i'), 1)
(('i', 'm'), 1)
(('m', 'e'), 1)
(('_', 'o'), 1)


In [4]:
def merge_pair(vocabulary: dict[str, int], pair: tuple[str, str]) -> dict[str, int]:
    """
    Merges a given pair of symbols in the vocabulary.

    Every occurrence of the two adjacent whole tokens ``pair[0]`` and ``pair[1]``
    in a space-separated key is replaced by the single token ``pair[0] + pair[1]``.
    For example, with ``pair=('o', 'v')`` the key ``'_ M o v i e s'`` becomes
    ``'_ M ov i e s'``. Matches are replaced left to right without overlapping,
    so merging ``('a', 'a')`` turns ``'a a a'`` into ``'aa a'``.

    Args:
        vocabulary (dict): A dict (or defaultdict) whose keys are space-separated
            tokenized words and whose values are their frequencies. Not modified.
        pair (tuple): A tuple containing the pair of symbols to merge.
    Returns:
        dict: A new vocabulary with the pair merged. If two keys become identical
        after the merge, their frequencies are summed instead of one overwriting
        the other.
    """
    merged_token = "".join(pair)
    # ('o', 'v') => 'o\\ v'
    bigram = re.escape(' '.join(pair))
    # The lookarounds match whole tokens only: the bigram must be bounded by
    # whitespace or the ends of the string, so 'o v' does not match inside
    # 'fo v' or 'o vi'.
    pattern = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
    new_vocabulary: dict[str, int] = {}
    for tokenized_word, count in vocabulary.items():
        # Use a callable replacement so merged_token is inserted literally. A
        # plain string would be parsed as a replacement template, and tokens
        # containing a backslash (e.g. 'a\' or '\1') would raise re.error or
        # insert the wrong text.
        new_tokenized_word = pattern.sub(lambda _match: merged_token, tokenized_word)
        # Accumulate rather than assign, so distinct keys that collapse to the
        # same tokenization (e.g. 'a b c' and 'ab c' when merging ('a', 'b'))
        # keep their combined frequency.
        new_vocabulary[new_tokenized_word] = new_vocabulary.get(new_tokenized_word, 0) + count
    return new_vocabulary

In [5]:
# Merge the pair ('o', 'v'); in this demo corpus only "Movies" contains it.
print(merge_pair(vocabulary=vocabulary, pair=('o', 'v')))

_ M o v i e s ~~> _ M ov i e s
{'_ M ov i e s': 1, '_ a r e': 1, '_ f u n': 1, '_ f o r': 1, '_ e v e r y o n e': 1, '_ e v e r y': 1, '_ t i m e': 1, '_ o n e': 1}


In [5]:
def byte_pair_encoding(corpus: list[str], vocab_size: int):
    """Learn BPE merges from ``corpus`` until ``vocab_size`` tokens exist.

    Training starts from the single characters of every ``'_'``-prefixed word
    and repeatedly merges the most frequent adjacent pair. It stops as soon as
    the token set holds ``vocab_size`` entries, or when no adjacent pairs
    remain (every word has become a single token). When several pairs share
    the top count, ``max`` picks the one ``get_pair_counts`` encountered first.

    Args:
        corpus: Training words.
        vocab_size: Target number of distinct tokens (characters plus merged
            tokens).

    Returns:
        ``(vocabulary, merges, charset, tokens)``: the final tokenized
        vocabulary, the ordered list of merged pairs, the initial character
        set, and the set of all tokens (characters plus merge results).
    """
    vocabulary, charset = initialize_vocabulary(corpus)
    merges = []
    tokens = set(charset)  # separate copy, so the returned charset stays untouched
    # An empty pair table is falsy, which ends training once nothing can merge.
    while len(tokens) < vocab_size and (pair_counts := get_pair_counts(vocabulary)):
        best_pair: tuple[str, str] = max(pair_counts, key=pair_counts.get)
        merges.append(best_pair)
        vocabulary = merge_pair(vocabulary, best_pair)
        tokens.add(''.join(best_pair))
    return vocabulary, merges, charset, tokens

In [6]:
# Train with a target of 50 tokens. This tiny corpus runs out of pairs first:
# every word ends up as a single token, with fewer than 50 tokens in total.
byte_pair_encoding(
    corpus=["Movies", "are", "fun", "for", "everyone", "every", "time", "one"],
    vocab_size=50,
)

({'_Movies': 1,
  '_are': 1,
  '_fun': 1,
  '_for': 1,
  '_everyone': 1,
  '_every': 1,
  '_time': 1,
  '_one': 1},
 [('_', 'f'),
  ('_', 'e'),
  ('_e', 'v'),
  ('_ev', 'e'),
  ('_eve', 'r'),
  ('_ever', 'y'),
  ('o', 'n'),
  ('on', 'e'),
  ('_', 'M'),
  ('_M', 'o'),
  ('_Mo', 'v'),
  ('_Mov', 'i'),
  ('_Movi', 'e'),
  ('_Movie', 's'),
  ('_', 'a'),
  ('_a', 'r'),
  ('_ar', 'e'),
  ('_f', 'u'),
  ('_fu', 'n'),
  ('_f', 'o'),
  ('_fo', 'r'),
  ('_every', 'one'),
  ('_', 't'),
  ('_t', 'i'),
  ('_ti', 'm'),
  ('_tim', 'e'),
  ('_', 'one')],
 {'M', '_', 'a', 'e', 'f', 'i', 'm', 'n', 'o', 'r', 's', 't', 'u', 'v', 'y'},
 {'M',
  '_',
  '_M',
  '_Mo',
  '_Mov',
  '_Movi',
  '_Movie',
  '_Movies',
  '_a',
  '_ar',
  '_are',
  '_e',
  '_ev',
  '_eve',
  '_ever',
  '_every',
  '_everyone',
  '_f',
  '_fo',
  '_for',
  '_fu',
  '_fun',
  '_one',
  '_t',
  '_ti',
  '_tim',
  '_time',
  'a',
  'e',
  'f',
  'i',
  'm',
  'n',
  'o',
  'on',
  'one',
  'r',
  's',
  't',
  'u',
  'v',
  'y'})

In [20]:
def tokenize_word(word, merges, vocabulary, charset, unk_token="<UNK>"):
    """Split ``word`` into BPE tokens by replaying learned ``merges`` in order.

    Args:
        word: The raw word, without the ``'_'`` boundary marker (added here).
        merges: Ordered ``(left, right)`` pairs, applied in the order given.
        vocabulary: Tokenized vocabulary. If the marked word is itself a key
            (a key with no spaces, i.e. already a single token), it is
            returned as one token without replaying merges.
        charset: Known single characters. Any other character becomes
            ``unk_token``.
        unk_token: Placeholder used for characters outside ``charset``.

    Returns:
        The list of token strings.
    """
    marked = '_' + word
    if marked in vocabulary:
        return [marked]
    pieces = [symbol if symbol in charset else unk_token for symbol in marked]
    for left, right in merges:
        target = [left, right]
        pos = 0
        while pos + 1 < len(pieces):
            if pieces[pos:pos + 2] != target:
                pos += 1
                continue
            # Collapse the pair in place. pos stays put, so the new token is
            # compared with its next neighbour on the following iteration.
            pieces[pos:pos + 2] = [left + right]
    return pieces

# Retrain on the demo corpus, then tokenize "forum", which was not in it.
# Note: this rebinds the notebook-level `vocabulary` and `charset` set earlier.
vocabulary, merges, charset, tokens = byte_pair_encoding(
    corpus=["Movies", "are", "fun", "for", "everyone", "every", "time", "one"],
    vocab_size=50,
)

tokenize_word("forum", merges=merges, vocabulary=vocabulary, charset=charset)

['_for', 'u', 'm']

In [18]:
# NOTE(review): this cell repeats the previous cell verbatim (same training
# call, same word, same result); it appears redundant and could be removed.
vocabulary, merges, charset, tokens = byte_pair_encoding(
    vocab_size=50,
    corpus=["Movies", "are", "fun", "for", "everyone", "every", "time", "one"],
)

tokenize_word(word="forum", merges=merges, vocabulary=vocabulary, charset=charset)

['_for', 'u', 'm']