# Byte-Pair Encoding

In [1]:
text = "aaabdaaabac"
print(len(text))

11


In [2]:
tokens = list(text.encode("utf-8"))
print(tokens)

[97, 97, 97, 98, 100, 97, 97, 97, 98, 97, 99]


In [3]:
def calc_bigram_freqs(tokens: list[int]) -> dict[tuple[int, int], int]:
    bigram_freqs = {}
    for bigram in zip(tokens[:-1], tokens[1:]):
        if bigram not in bigram_freqs:
            bigram_freqs[bigram] = 0
        bigram_freqs[bigram] += 1

    return bigram_freqs

bigram_freqs = calc_bigram_freqs(tokens)
print(bigram_freqs)

{(97, 97): 4, (97, 98): 2, (98, 100): 1, (100, 97): 1, (98, 97): 1, (97, 99): 1}


In [4]:
max_bigram = max(bigram_freqs, key=bigram_freqs.get)
print(max_bigram, bigram_freqs[max_bigram])

(97, 97) 4


In [5]:
def merge(tokens: list[int], bigram: tuple[int, int], new_token: int) -> list[int]:
    new_tokens = []

    i = 0
    while i < len(tokens):
        if (i == (len(tokens) - 1)) or ((tokens[i], tokens[i + 1]) != bigram):
            new_tokens.append(tokens[i])
        else:
            new_tokens.append(new_token)
            i += 1
        
        i += 1

    return new_tokens

new_tokens = merge(tokens, max_bigram, 255 + 1)
new_tokens

[256, 97, 98, 100, 256, 97, 98, 97, 99]

## Regex

In [6]:
import re

split_pattern = r"""[ ']?[a-zA-Z]+|\d{1,4}|\s+(?!\S)|.+?"""
compiled_pattern = re.compile(split_pattern)

weird_str = "Hello, my NAme isn'tcool. \n\t  I like may number8834534s; and other things."
compiled_pattern.findall(weird_str)

['Hello',
 ',',
 ' my',
 ' NAme',
 ' isn',
 "'tcool",
 '.',
 ' \n\t ',
 ' I',
 ' like',
 ' may',
 ' number',
 '8834',
 '534',
 's',
 ';',
 ' and',
 ' other',
 ' things',
 '.']

In [7]:
with open("../data/shakespeare.txt") as f:
    shakespeare_text = f.read()

print(shakespeare_text[:100])
compiled_pattern.findall(shakespeare_text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


['First',
 ' Citizen',
 ':',
 'Before',
 ' we',
 ' proceed',
 ' any',
 ' further',
 ',',
 ' hear',
 ' me',
 ' speak',
 '.',
 '\n',
 'All',
 ':',
 'Speak',
 ',',
 ' speak',
 '.',
 '\n',
 'First',
 ' Citizen',
 ':',
 'You']

## Class

In [8]:
class BytePairEncoder:
    """Byte pair encoder."""

    def __init__(self, text: str, vocab_size: int) -> None:
        split_pattern = re.compile(r"""[ ']?[a-zA-Z]+|\d{1,4}|\s+(?!\S)|.+?""")
        text_chunks = split_pattern.findall(text)
        token_chunks = [list(chunk.encode("utf-8")) for chunk in text_chunks]

        self.merges: dict[tuple[int, int], int] = {}
        self.vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}

        for i in range(vocab_size - 256):
            bigram_freqs = {}
            for chunk in token_chunks:
                bigram_freqs = self.calc_bigram_freqs(chunk, bigram_freqs)
            
            max_bigram = max(bigram_freqs, key=bigram_freqs.get)
            if bigram_freqs[max_bigram] == 1:
                break

            new_token = 256 + i
            token_chunks = [self.merge(chunk, max_bigram, new_token) for chunk in token_chunks]

            self.merges[max_bigram] = new_token
            self.vocab[new_token] = self.vocab[max_bigram[0]] + self.vocab[max_bigram[1]]

    def encode(self, text: str) -> list[int]:
        tokens = list(text.encode("utf-8"))
        while len(tokens) >= 2:
            bigram_freqs = self.calc_bigram_freqs(tokens)

            bigram = min(bigram_freqs, key=lambda bigram: self.merges.get(bigram, float("inf")))
            if bigram not in self.merges:
                break

            tokens = self.merge(tokens, bigram, self.merges[bigram])

        return tokens
    
    def decode(self, tokens: list[int]) -> str:
        text_bytes = b"".join(self.vocab[i] for i in tokens)
        return text_bytes.decode("utf-8", errors="replace")

    def calc_bigram_freqs(self, tokens: list[int], bigram_freqs: dict = None) -> dict[tuple[int, int], int]:
        if not bigram_freqs:
            bigram_freqs = {}
        for bigram in zip(tokens[:-1], tokens[1:]):
            if bigram not in bigram_freqs:
                bigram_freqs[bigram] = 0
            bigram_freqs[bigram] += 1

        return bigram_freqs
    
    def merge(self, tokens: list[int], bigram: tuple[int, int], new_token: int) -> list[int]:
        new_tokens = []

        i = 0
        while i < len(tokens):
            if (i == (len(tokens) - 1)) or ((tokens[i], tokens[i + 1]) != bigram):
                new_tokens.append(tokens[i])
            else:
                new_tokens.append(new_token)
                i += 1
        
            i += 1

        return new_tokens

In [9]:
bpe = BytePairEncoder(text="aaabdaaabac", vocab_size=256 + 16)
print(bpe.merges)

tokens = bpe.encode("aaabdaaabac")
print(tokens)
print(bpe.decode(tokens))

{(97, 97): 256, (256, 97): 257, (257, 98): 258}
[258, 100, 258, 97, 99]
aaabdaaabac


In [10]:
bpe = BytePairEncoder(text=shakespeare_text, vocab_size=1024)
print(bpe.merges)
print(bpe.vocab)

tokens = bpe.encode(shakespeare_text[:200])
print(tokens)
print(bpe.decode(tokens))

{(32, 116): 256, (104, 101): 257, (32, 97): 258, (111, 117): 259, (32, 115): 260, (32, 109): 261, (105, 110): 262, (32, 119): 263, (114, 101): 264, (104, 97): 265, (110, 100): 266, (256, 257): 267, (32, 98): 268, (105, 115): 269, (111, 114): 270, (32, 102): 271, (101, 114): 272, (108, 108): 273, (105, 116): 274, (111, 110): 275, (32, 100): 276, (32, 99): 277, (101, 115): 278, (101, 110): 279, (32, 110): 280, (32, 108): 281, (32, 121): 282, (256, 104): 283, (97, 114): 284, (32, 104): 285, (32, 111): 286, (256, 111): 287, (282, 259): 288, (32, 112): 289, (265, 116): 290, (32, 73): 291, (32, 257): 292, (115, 116): 293, (118, 101): 294, (111, 116): 295, (258, 266): 296, (111, 119): 297, (262, 103): 298, (97, 110): 299, (286, 102): 300, (111, 109): 301, (32, 103): 302, (97, 116): 303, (268, 101): 304, (115, 101): 305, (261, 121): 306, (32, 262): 307, (99, 101): 308, (32, 265): 309, (108, 101): 310, (97, 121): 311, (108, 100): 312, (105, 114): 313, (101, 116): 314, (101, 100): 315, (117, 116

In [11]:
for i in range(256, len(bpe.vocab)):
    print(bpe.vocab[i])

b' t'
b'he'
b' a'
b'ou'
b' s'
b' m'
b'in'
b' w'
b're'
b'ha'
b'nd'
b' the'
b' b'
b'is'
b'or'
b' f'
b'er'
b'll'
b'it'
b'on'
b' d'
b' c'
b'es'
b'en'
b' n'
b' l'
b' y'
b' th'
b'ar'
b' h'
b' o'
b' to'
b' you'
b' p'
b'hat'
b' I'
b' he'
b'st'
b've'
b'ot'
b' and'
b'ow'
b'ing'
b'an'
b' of'
b'om'
b' g'
b'at'
b' be'
b'se'
b' my'
b' in'
b'ce'
b' ha'
b'le'
b'ay'
b'ld'
b'ir'
b'et'
b'ed'
b'ut'
b' me'
b'im'
b'ith'
b' not'
b'ch'
b' that'
b' is'
b'gh'
b'And'
b' for'
b"'s"
b'ke'
b' u'
b'our'
b' we'
b'oo'
b'ill'
b' e'
b'her'
b' with'
b'ent'
b' it'
b' your'
b'ad'
b'ri'
b' thou'
b' st'
b"'d"
b' k'
b'ome'
b' his'
b'ght'
b'EN'
b'ord'
b'id'
b'as'
b'The'
b' re'
b' have'
b'IN'
b'ly'
b'ra'
b' li'
b' him'
b'ur'
b' this'
b'al'
b'IO'
b' so'
b' as'
b' de'
b' on'
b'ore'
b'ro'
b'AR'
b'hi'
b'ould'
b'ood'
b'ck'
b'ain'
b'ver'
b'est'
b' thy'
b' sha'
b'ess'
b'ea'
b' do'
b' will'
b'am'
b' no'
b' but'
b'us'
b'and'
b'US'
b'if'
b' se'
b'ge'
b'Th'
b' all'
b' su'
b'ake'
b'To'
b' her'
b'ru'
b'ion'
b'th'
b' an'
b'ter'
b'ard'
b' lo'