In [1]:
import regex as re
from collections import defaultdict

In [60]:
sample_corpus = ["hug "*10+" bug"*20+" mug"*13, "bug "*30 + " burger"*20]

In [62]:
sample_corpus

['hug hug hug hug hug hug hug hug hug hug  bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug mug mug mug mug mug mug mug mug mug mug mug mug mug',
 'bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug bug  burger burger burger burger burger burger burger burger burger burger burger burger burger burger burger burger burger burger burger burger']

In [63]:
GPT4_WORD_SPLIT_REGEX = "|".join(
        [
            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
            r"""\p{N}{1,3}""",
            r""" ?[^\s\p{L}\p{N}]+[\r\n/]*""",
            r"""\s*[\r\n]+""",
            r"""\s+(?!\S)""",
            r"""\s+""",
        ]
    )

In [64]:
word_splitter = re.compile(GPT4_WORD_SPLIT_REGEX)

In [65]:
list(zip([1, 2, 3], [2, 3]))

[(1, 2), (2, 3)]

In [102]:
train_dict = defaultdict(lambda: 0)

idx = 0
for corpus in sample_corpus:
    for _, words in enumerate(word_splitter.findall(corpus)):
        train_dict[words.encode("utf-8")] += 1

In [103]:
train_corpus = defaultdict(list)
pairs_dict = {}

for idx, (word, count) in enumerate(train_dict.items()):
    train_corpus[idx].append(count)
    pairs = list(zip(word, word[1:]))
    train_corpus[idx].append(pairs)

    for pair in pairs:
        pair_entry = pairs_dict.get(pair, [0, set([])])
        pair_entry[0] = pair_entry[0] + count
        positioning = pair_entry[1]
        positioning.add(idx)
        pairs_dict[pair] = [counter, positioning]

In [104]:
train_dict

defaultdict(<function __main__.<lambda>()>,
            {b'hug': 1,
             b' hug': 9,
             b' ': 2,
             b' bug': 49,
             b' mug': 13,
             b'bug': 1,
             b' burger': 20})

In [105]:
 pairs_dict

{(104, 117): [10, {0, 1}],
 (117, 103): [10, {0, 1, 3, 4, 5}],
 (32, 104): [10, {1}],
 (32, 98): [10, {3, 6}],
 (98, 117): [10, {3, 5, 6}],
 (32, 109): [10, {4}],
 (109, 117): [10, {4}],
 (117, 114): [10, {6}],
 (114, 103): [10, {6}],
 (103, 101): [10, {6}],
 (101, 114): [10, {6}]}

In [128]:
MAX_VOCAB_NUM = 300

merges = {}
id_to_vocab = {idx: bytes([idx]) for idx in range(256)}
vocab = {v:k for k, v in id_to_vocab.items()}

In [115]:
num_merges = MAX_VOCAB_NUM - len(vocab)

In [116]:
base_vocab_size = len(vocab)

In [130]:
for i in range(num_merges):
    best = max(pairs_dict, key=lambda pair: pairs_dict[pair][0])
    print(best)
    vocab[b"".join(id_to_vocab[i] for i in best)] = base_vocab_size + i
    print(vocab)
    break
    # merge_pairs(train_dict, pairs_dict, best)

(104, 117)
{b'\x00': 0, b'\x01': 1, b'\x02': 2, b'\x03': 3, b'\x04': 4, b'\x05': 5, b'\x06': 6, b'\x07': 7, b'\x08': 8, b'\t': 9, b'\n': 10, b'\x0b': 11, b'\x0c': 12, b'\r': 13, b'\x0e': 14, b'\x0f': 15, b'\x10': 16, b'\x11': 17, b'\x12': 18, b'\x13': 19, b'\x14': 20, b'\x15': 21, b'\x16': 22, b'\x17': 23, b'\x18': 24, b'\x19': 25, b'\x1a': 26, b'\x1b': 27, b'\x1c': 28, b'\x1d': 29, b'\x1e': 30, b'\x1f': 31, b' ': 32, b'!': 33, b'"': 34, b'#': 35, b'$': 36, b'%': 37, b'&': 38, b"'": 39, b'(': 40, b')': 41, b'*': 42, b'+': 43, b',': 44, b'-': 45, b'.': 46, b'/': 47, b'0': 48, b'1': 49, b'2': 50, b'3': 51, b'4': 52, b'5': 53, b'6': 54, b'7': 55, b'8': 56, b'9': 57, b':': 58, b';': 59, b'<': 60, b'=': 61, b'>': 62, b'?': 63, b'@': 64, b'A': 65, b'B': 66, b'C': 67, b'D': 68, b'E': 69, b'F': 70, b'G': 71, b'H': 72, b'I': 73, b'J': 74, b'K': 75, b'L': 76, b'M': 77, b'N': 78, b'O': 79, b'P': 80, b'Q': 81, b'R': 82, b'S': 83, b'T': 84, b'U': 85, b'V': 86, b'W': 87, b'X': 88, b'Y': 89, b'Z': 90

In [129]:
[id_to_vocab[i] for i in best]

[b'h', b'u']