# Tokenizers playground

This notebook mostly follows [this tokenizer tutorial](https://www.youtube.com/watch?v=zduSFxRajkE).

### Get some text data

<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/9/98/International_Pok%C3%A9mon_logo.svg/800px-International_Pok%C3%A9mon_logo.svg.png" width=30% height=30% />

In [1]:
from wikipediaapi import Wikipedia

# Use the English Wikipedia article on Pokémon as a small training corpus.
wiki_api = Wikipedia("Tokenizers, yay!", "en")
wiki_page = wiki_api.page("Pokémon")
pokemon_text = wiki_page.text

# The base tokens are the raw UTF-8 byte values (0..255). Non-ASCII characters
# such as "é" take more than one byte, so there are more tokens than characters.
tokens = [*pokemon_text.encode("utf-8")]
print(len(pokemon_text), len(tokens))

85792 86236


In [2]:
# One code point, four UTF-8 bytes: the last byte on its own (0x82) is not the
# emoji, while decoding the complete byte string gives the emoji back.
laugh_bytes = "😂".encode("utf-8")
ord("😂"), chr(laugh_bytes[3]), laugh_bytes.decode()

(128514, '\x82', '😂')

### Basic BPE tokenization

In [3]:
from collections import defaultdict
from typing import Dict, List, Tuple

def get_pairs(tokens: List[int]) -> Dict[Tuple[int, int], int]:
    """Count how often each pair of adjacent token ids occurs.

    Args:
        tokens: sequence of token ids.

    Returns:
        A ``defaultdict(int)`` mapping ``(left, right)`` to its number of
        occurrences; keys appear in order of first occurrence. Empty when
        ``tokens`` has fewer than two elements.
    """
    counts = defaultdict(int)
    for position in range(len(tokens) - 1):
        counts[tokens[position], tokens[position + 1]] += 1
    return counts

# Count every adjacent byte pair in the corpus, then show the most frequent one.
token_pairs = get_pairs(tokens)
top_pair = max(token_pairs, key=token_pairs.get)
print(token_pairs)
print(top_pair)

defaultdict(<class 'int'>, {(80, 111): 408, (111, 107): 414, (107, 195): 275, (195, 169): 275, (169, 109): 267, (109, 111): 543, (111, 110): 1148, (110, 32): 1304, (32, 105): 736, (105, 115): 374, (115, 32): 1333, (32, 97): 1376, (97, 32): 393, (32, 74): 107, (74, 97): 91, (97, 112): 183, (112, 97): 226, (97, 110): 1136, (110, 101): 372, (101, 115): 680, (115, 101): 481, (101, 32): 2295, (32, 109): 343, (109, 101): 672, (101, 100): 955, (100, 105): 235, (105, 97): 140, (32, 102): 425, (102, 114): 92, (114, 97): 414, (110, 99): 276, (99, 104): 360, (104, 105): 449, (32, 99): 652, (99, 111): 373, (110, 115): 237, (115, 105): 251, (115, 116): 456, (116, 105): 529, (105, 110): 1257, (110, 103): 510, (103, 32): 401, (32, 111): 772, (111, 102): 440, (102, 32): 434, (32, 118): 94, (118, 105): 131, (105, 100): 221, (100, 101): 417, (101, 111): 44, (111, 32): 587, (32, 103): 287, (103, 97): 261, (97, 109): 388, (115, 44): 181, (44, 32): 1027, (110, 105): 234, (105, 109): 238, (109, 97): 258, (9

In [4]:
def merge(tokens: List[int], pair: Tuple[int, int], idx: int) -> List[int]:
    """Return a copy of ``tokens`` with every occurrence of ``pair`` replaced by ``idx``.

    Simple BPE merging: the scan runs left to right and resumes right after a
    replaced pair, so overlapping matches are not merged twice
    (``[3, 3, 3]`` with pair ``(3, 3)`` becomes ``[idx, 3]``).
    The input list is left untouched.
    """
    out = []
    pos, n = 0, len(tokens)
    while pos < n:
        is_match = pos + 1 < n and (tokens[pos], tokens[pos + 1]) == pair
        if is_match:
            out.append(idx)
            pos += 2
        else:
            out.append(tokens[pos])
            pos += 1
    return out

# Checked example: the adjacent (3, 3) collapses into the new id 0.
merged_example = merge([1, 3, 3, 7], (3, 3), 0)
assert merged_example == [1, 0, 7]
print(merged_example)


[1, 0, 7]


In [5]:
def get_tokenizer(tokens, num_merges=20, starting_idx=256, verbose=True) -> Tuple[List[int], Dict[Tuple[int, int], int]]:
    """
    Learn BPE merges on ``tokens``; return the merged sequence and the merge table.

    TODO: wrap this in a tokenizer class; for now it returns plain data.

    Args:
        tokens: starting token ids (e.g. raw UTF-8 byte values). The list itself
            is not modified; ``merge`` always builds a new list.
        num_merges: how many merges to learn, i.e. how much to extend the vocabulary.
        starting_idx: the id given to the first merged pair; later merges use
            ``starting_idx + 1``, ``starting_idx + 2``, ...
        verbose: print each merge as it is learned.

    Returns:
        ``(merged_tokens, merges)`` where ``merges`` maps ``(left_id, right_id)``
        to the new id, in the order the merges were learned. Fewer than
        ``num_merges`` merges are returned if the sequence shrinks to a single
        token before all merges are done.

    Raises:
        ValueError: if ``num_merges`` is negative.
    """
    if num_merges < 0:
        raise ValueError(f"num_merges must be non-negative, got {num_merges}")

    merges = {}

    for i in range(num_merges):
        pairs = get_pairs(tokens)
        if not pairs:
            # Fewer than two tokens left: nothing to merge (max() would raise).
            break
        # Ties go to the pair seen first, since get_pairs keys keep first-seen order.
        most_frequent_pair = max(pairs, key=pairs.get)
        new_index = starting_idx + i
        tokens = merge(tokens, most_frequent_pair, new_index)
        merges[most_frequent_pair] = new_index

        if verbose:
            print(f"Merging {most_frequent_pair} -> {new_index}")

    return tokens, merges

# Learn 20 merges on the Pokémon bytes; `merges` is reused by the cells below.
new_tokens, merges = get_tokenizer(tokens)

n_before, n_after = len(tokens), len(new_tokens)
print(f"\nOriginal tokens length {n_before} vs compressed length {n_after}")
print(f"Compression ratio: {n_before / n_after:.2f}")

Merging (101, 32) -> 256
Merging (100, 32) -> 257
Merging (116, 104) -> 258
Merging (115, 32) -> 259
Merging (110, 32) -> 260
Merging (101, 114) -> 261
Merging (44, 32) -> 262
Merging (97, 110) -> 263
Merging (105, 110) -> 264
Merging (116, 32) -> 265
Merging (101, 257) -> 266
Merging (258, 256) -> 267
Merging (114, 101) -> 268
Merging (121, 32) -> 269
Merging (111, 32) -> 270
Merging (97, 114) -> 271
Merging (46, 32) -> 272
Merging (111, 260) -> 273
Merging (111, 110) -> 274
Merging (111, 114) -> 275

Original tokens length 86236 vs compressed length 67071
Compression ratio: 1.29


### Decoding & Encoding

In [6]:
def decode(tokens: List[int], merges: Dict[Tuple[int, int], int]) -> str:
    """Return the text represented by a list of token ids.

    Args:
        tokens: token ids, either raw byte values 0..255 or ids created by ``merges``.
        merges: dict mapping ``(left_id, right_id)`` to the merged id.

    Byte sequences that are not valid UTF-8 (e.g. a lone byte of a multi-byte
    character) become U+FFFD instead of raising.

    Raises:
        KeyError: if a token id is neither a byte value nor a merged id.
    """
    # id -> bytes. Merged entries are built in increasing id order, so both
    # halves of a pair already exist whatever order the dict iterates in
    # (a merged id is always larger than the ids it combines).
    vocab = {idx: bytes([idx]) for idx in range(256)}
    for (left, right), idx in sorted(merges.items(), key=lambda item: item[1]):
        vocab[idx] = vocab[left] + vocab[right]

    return b"".join(vocab[idx] for idx in tokens).decode("utf-8", errors="replace")

# Ids 0..255 are single bytes (128..255 are incomplete UTF-8 on their own and
# show up as the replacement character); 256..259 are the first learned merges.
decoded_first_260 = [decode([token_id], merges) for token_id in range(260)]

print(decoded_first_260)

['\x00', '\x01', '\x02', '\x03', '\x04', '\x05', '\x06', '\x07', '\x08', '\t', '\n', '\x0b', '\x0c', '\r', '\x0e', '\x0f', '\x10', '\x11', '\x12', '\x13', '\x14', '\x15', '\x16', '\x17', '\x18', '\x19', '\x1a', '\x1b', '\x1c', '\x1d', '\x1e', '\x1f', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '\x7f', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', '�', 

In [7]:
def encode(s: str, merges) -> List[int]:
    """Return the token ids for ``s`` by replaying the learned BPE merges.

    Args:
        s: text to encode; it is first converted to its UTF-8 bytes (ids 0..255).
        merges: dict mapping ``(left_id, right_id)`` to the merged id.

    Merges are applied in increasing order of merged id (the order in which
    they were learned), not in dict-iteration order, because a later merge
    may consume ids produced by an earlier one.
    """
    tokens = list(s.encode("utf-8"))
    for pair, idx in sorted(merges.items(), key=lambda item: item[1]):
        if len(tokens) < 2:
            break  # no adjacent pairs left, remaining merges cannot apply
        tokens = merge(tokens, pair, idx)
    return tokens

# Round-trip check: a word with a non-ASCII character must survive encode -> decode.
round_trip = decode(encode("Pokémon", merges), merges)
assert round_trip == "Pokémon"
print(round_trip)

Pokémon


## Exercises

[https://github.com/karpathy/minbpe/blob/master/exercise.md](https://github.com/karpathy/minbpe/blob/master/exercise.md)

### Step 1

In [10]:
from typing import Dict, List, Tuple

class BasicTokenizer:
    """Minimal byte-level BPE tokenizer (no regex pre-splitting, no special tokens).

    1. Convert the string (a sequence of Unicode code points) to its UTF-8
       bytes; each byte value 0..255 is a base token id.
    2. ``train`` repeatedly replaces the most frequent adjacent pair of ids
       with a fresh id (256, 257, ...) and records it in ``self.merges``.
    3. ``encode`` replays those merges on new text; ``decode`` maps ids back to
       bytes through ``self.vocab`` and decodes the result as UTF-8.
    """

    DEFAULT_VOCAB_SIZE = 256  # base vocabulary: one token per possible byte value

    def __init__(self):
        # (left_id, right_id) -> merged id, in the order the merges were learned
        self.merges = {}
        # token id -> the bytes it stands for. Starts with the 256 single bytes
        # so an untrained tokenizer can already encode/decode (one id per byte).
        self.vocab = self._build_vocab(self.merges)

    def _get_pairs(self, tokens: List[int]) -> Dict[Tuple[int, int], int]:
        """Return how often each adjacent pair of token ids occurs.

        Keys keep first-appearance order, so ``max`` breaks ties in favour of
        the pair seen first. A plain dict avoids depending on an import made
        in another notebook cell.
        """
        token_pairs = {}
        for pair in zip(tokens, tokens[1:]):
            token_pairs[pair] = token_pairs.get(pair, 0) + 1
        return token_pairs

    def _merge(self, tokens: List[int], pair: Tuple[int, int], idx: int) -> List[int]:
        """Return a new list with every non-overlapping ``pair`` replaced by ``idx``.

        Simple BPE merging, scanning left to right; ``tokens`` is not modified.
        """
        new_tokens = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
                new_tokens.append(idx)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        return new_tokens

    def _build_vocab(self, merges: Dict[Tuple[int, int], int]) -> Dict[int, bytes]:
        """Return the id -> bytes table for the base bytes plus ``merges``.

        Merged ids are processed in increasing order; each merged id is larger
        than the two ids it combines, so both halves are already present.
        """
        vocab = {idx: bytes([idx]) for idx in range(self.DEFAULT_VOCAB_SIZE)}
        for (left, right), idx in sorted(merges.items(), key=lambda item: item[1]):
            vocab[idx] = vocab[left] + vocab[right]
        return vocab

    def train(self, text, vocab_size, verbose=False):
        """Learn up to ``vocab_size - 256`` BPE merges from ``text``.

        Replaces any previously learned merges and vocabulary.

        Args:
            text: training corpus.
            vocab_size: desired total vocabulary size (256 base bytes + merges).
            verbose: print each merge as it is learned.

        Returns:
            ``(merges, vocab)``, the same objects stored on ``self.merges`` and
            ``self.vocab``. Fewer merges than requested are learned if the text
            shrinks to a single token first.

        Raises:
            ValueError: if ``vocab_size`` is smaller than ``DEFAULT_VOCAB_SIZE``.
        """
        num_merges = vocab_size - self.DEFAULT_VOCAB_SIZE
        if num_merges < 0:
            raise ValueError(
                f"vocab_size must be at least {self.DEFAULT_VOCAB_SIZE}, got {vocab_size}"
            )

        tokens = list(text.encode("utf-8"))
        new_merges = {}

        for i in range(num_merges):
            pairs = self._get_pairs(tokens)
            if not pairs:
                # Fewer than two tokens left: nothing to merge (max() would raise).
                break
            most_frequent_pair = max(pairs, key=pairs.get)
            new_index = self.DEFAULT_VOCAB_SIZE + i
            tokens = self._merge(tokens, most_frequent_pair, new_index)
            new_merges[most_frequent_pair] = new_index

            if verbose:
                print(f"Merging {most_frequent_pair} -> {new_index}")

        self.merges = new_merges
        # Build the vocabulary from the merges learned here, not from any other
        # `merges` that happens to exist in the notebook namespace.
        self.vocab = self._build_vocab(new_merges)
        return new_merges, self.vocab

    def encode(self, text):
        """Return a list of token ids for ``text`` by applying the learned merges.

        Merges are applied in increasing merged-id order (the order they were
        learned), because later merges may consume ids made by earlier ones.
        """
        tokens = list(text.encode("utf-8"))
        for pair, idx in sorted(self.merges.items(), key=lambda item: item[1]):
            if len(tokens) < 2:
                break  # no adjacent pairs left to merge
            tokens = self._merge(tokens, pair, idx)
        return tokens

    def decode(self, ids):
        """Return the string for a list of token ids.

        Invalid UTF-8 byte sequences become U+FFFD; an id missing from
        ``self.vocab`` raises KeyError.
        """
        joined = b"".join(self.vocab[idx] for idx in ids)
        return joined.decode("utf-8", errors="replace")

tokenizer = BasicTokenizer()
tokenizer.train(pokemon_text, 280, verbose=True)

# Round-trip check on text outside the training data (includes a 4-byte emoji).
sample_text = "Hello world! 🥹"
round_tripped = tokenizer.decode(tokenizer.encode(sample_text))
assert round_tripped == sample_text
round_tripped

Merging (101, 32) -> 256
Merging (100, 32) -> 257
Merging (116, 104) -> 258
Merging (115, 32) -> 259
Merging (110, 32) -> 260
Merging (101, 114) -> 261
Merging (44, 32) -> 262
Merging (97, 110) -> 263
Merging (105, 110) -> 264
Merging (116, 32) -> 265
Merging (101, 257) -> 266
Merging (258, 256) -> 267
Merging (114, 101) -> 268
Merging (121, 32) -> 269
Merging (111, 32) -> 270
Merging (97, 114) -> 271
Merging (46, 32) -> 272
Merging (111, 260) -> 273
Merging (111, 110) -> 274
Merging (111, 114) -> 275
Merging (101, 110) -> 276
Merging (97, 108) -> 277
Merging (263, 257) -> 278
Merging (111, 102) -> 279


'Hello world! 🥹'