# exercise

Build your own GPT-4 Tokenizer!


### Step 1

Write the `BasicTokenizer` class, with the following three core functions:

- `def train(self, text, vocab_size, verbose=False)`
- `def encode(self, text)`
- `def decode(self, ids)`

Train your tokenizer on whatever text you like and visualize the merged tokens. Do they look reasonable? One default test you may wish to use is the text file `tests/taylorswift.txt`.
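One way to eyeball what the tokenizer learned is to print each merge as `[left][right] -> [merged]`. Bytes are decoded with `errors="replace"` because a single token can end in the middle of a multi-byte UTF-8 character. This is a minimal sketch that assumes merges and vocab are stored as in this notebook (`merges`: `(int, int) -> int`, `vocab`: `int -> bytes`). `render` and `show_merges` are hypothetical helper names, and the three merges are made up for illustration.

```python
def render(token_bytes):
    """Make a token printable: decode as UTF-8, replacing invalid sequences with U+FFFD."""
    return token_bytes.decode("utf-8", errors="replace")

def show_merges(merges, vocab):
    """Return one '[left][right] -> [merged]' line per merge, in the order the merges were learned."""
    lines = []
    for (left, right), new_id in sorted(merges.items(), key=lambda kv: kv[1]):
        lines.append(f"[{render(vocab[left])}][{render(vocab[right])}] -> [{render(vocab[new_id])}]")
    return lines

# Hypothetical merges, stored the way a trained BasicTokenizer would store them.
merges = {(101, 32): 256, (116, 104): 257, (257, 256): 258}
vocab = {i: bytes([i]) for i in range(256)}
for (a, b), idx in sorted(merges.items(), key=lambda kv: kv[1]):
    vocab[idx] = vocab[a] + vocab[b]

for line in show_merges(merges, vocab):
    print(line)
```

Sorting by the new token id replays the merges in training order, so a merged token is always printed after the pieces it was built from.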

In [1]:
# download dataset
# !wget https://raw.githubusercontent.com/karpathy/minbpe/master/tests/taylorswift.txt

In [2]:
from collections import defaultdict

def get_stats(ids, counts=None):
    """ Count the frequency of each consecutive pair of tokens. Update counts in-place if provided """
    pair_frequency = counts if counts is not None else defaultdict(int) # (int, int) -> int
    for id1, id2 in zip(ids, ids[1:]):
        pair_frequency[(id1, id2)] += 1
    return pair_frequency

def merge(ids, pair, new_tok):
    """ Given a list of tokens, a pair to merge, and a new token, replace all occurrences of pair with new_tok. """
    new_ids = []
    idx = 0
    while idx < len(ids):
        if idx != len(ids) - 1 and (ids[idx], ids[idx+1]) == pair:
            new_ids.append(new_tok)
            idx += 2
        else:
            new_ids.append(ids[idx])
            idx += 1
    return new_ids
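To see the two helpers in action, here is one BPE step on the toy string `aaabdaaabac` (a common textbook example). The functions are restated so the cell runs on its own. Note that `merge` scans left to right, so in a run like `aaa` only the first two `a`s are merged.

```python
from collections import defaultdict

def get_stats(ids, counts=None):
    """Count consecutive pairs; update `counts` in place if given."""
    counts = counts if counts is not None else defaultdict(int)
    for pair in zip(ids, ids[1:]):
        counts[pair] += 1
    return counts

def merge(ids, pair, new_tok):
    """Replace every non-overlapping occurrence of `pair` with `new_tok`, scanning left to right."""
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(new_tok)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

ids = list("aaabdaaabac".encode("utf-8"))
stats = get_stats(ids)
top = max(stats, key=stats.get)   # most frequent adjacent pair
print(top, stats[top])            # (97, 97) 4
ids = merge(ids, top, 256)        # mint token 256 for "aa"
print(ids)                        # [256, 97, 98, 100, 256, 97, 98, 97, 99]
```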

In [3]:
from collections import defaultdict

class BasicTokenizer:
    INITIAL_VOCAB_SIZE = 256
    
    def __init__(self):
        # bytes([i]) works on every Python 3 version; int.to_bytes() only has default arguments since 3.11
        self.vocab = {i: bytes([i]) for i in range(self.INITIAL_VOCAB_SIZE)} # int -> bytes
        self.merges = {} # (int, int) -> int
    
    def train(self, text, vocab_size, verbose=False):
        """ Given training data as a single string, perform BPE merges until we reach vocab_size tokens,
        and store the information in self.vocab and self.merges.
        """
        ids = list(text.encode("utf-8"))
        for i in range(vocab_size - self.INITIAL_VOCAB_SIZE):
            # find pair to merge
            pair_frequency = get_stats(ids)
            if not pair_frequency:
                break  # fewer than two tokens left, nothing to merge
            most_frequent_pair = max(pair_frequency, key=pair_frequency.get)
        
            # mint new token and add it to our vocab
            next_token = self.INITIAL_VOCAB_SIZE + i
            self.vocab[next_token] = self.vocab[most_frequent_pair[0]] + self.vocab[most_frequent_pair[1]]
            self.merges[most_frequent_pair] = next_token
            if verbose:
                print(f"Merging {most_frequent_pair} -> {next_token}")
            # perform merge on our input for the next round
            ids = merge(ids, most_frequent_pair, next_token)
            
    def encode(self, text):
        """ text -> bytes using self.vocab and self.merges """
        ids = list(text.encode("utf-8"))
        # perform all merges
        while len(ids) > 1:
            pair_frequency = get_stats(ids)
            # apply merges in the order they were learned (lowest merge index first),
            # because later merges may build on tokens minted by earlier ones
            pair_to_merge = min(pair_frequency, key=lambda p: self.merges.get(p, float("inf")))
            if pair_to_merge not in self.merges:
                break
            ids = merge(ids, pair_to_merge, self.merges[pair_to_merge])
        
        return ids
    
    def decode(self, ids):
        """ bytes -> text using self.vocab and self.merges """
        return b"".join([self.vocab[tok] for tok in ids]).decode("utf-8", errors="replace")


text = open("taylorswift.txt", "r", encoding="utf-8").read()
    
tok = BasicTokenizer()
tok.train(text, 512, verbose=True)

Merging (101, 32) -> 256
Merging (44, 32) -> 257
Merging (100, 32) -> 258
Merging (46, 32) -> 259
Merging (114, 32) -> 260
Merging (50, 48) -> 261
Merging (115, 32) -> 262
Merging (105, 110) -> 263
Merging (111, 110) -> 264
Merging (114, 105) -> 265
Merging (116, 32) -> 266
Merging (116, 104) -> 267
Merging (101, 258) -> 268
Merging (257, 261) -> 269
Merging (97, 110) -> 270
Merging (97, 114) -> 271
Merging (101, 260) -> 272
Merging (121, 32) -> 273
Merging (97, 108) -> 274
Merging (267, 256) -> 275
Merging (118, 268) -> 276
Merging (119, 105) -> 277
Merging (101, 114) -> 278
Merging (264, 32) -> 279
Merging (277, 102) -> 280
Merging (82, 101) -> 281
Merging (83, 280) -> 282
Merging (111, 260) -> 283
Merging (99, 104) -> 284
Merging (269, 49) -> 285
Merging (111, 109) -> 286
Merging (98, 272) -> 287
Merging (32, 275) -> 288
Merging (97, 121) -> 289
Merging (101, 110) -> 290
Merging (111, 114) -> 291
Merging (274, 32) -> 292
Merging (101, 109) -> 293
Merging (46, 10) -> 294
... (output truncated)

In [4]:
print(tok.encode("hello taylor swift"))
print(tok.decode(tok.encode("hello taylor swift")))

[104, 101, 301, 369, 116, 299, 283, 115, 280, 116]
hello taylor swift


In [5]:
print({k:v for k, v in tok.vocab.items() if k > 255}) # print the merged tokens

{256: b'e ', 257: b', ', 258: b'd ', 259: b'. ', 260: b'r ', 261: b'20', 262: b's ', 263: b'in', 264: b'on', 265: b'ri', 266: b't ', 267: b'th', 268: b'ed ', 269: b', 20', 270: b'an', 271: b'ar', 272: b'er ', 273: b'y ', 274: b'al', 275: b'the ', 276: b'ved ', 277: b'wi', 278: b'er', 279: b'on ', 280: b'wif', 281: b'Re', 282: b'Swif', 283: b'or ', 284: b'ch', 285: b', 201', 286: b'om', 287: b'ber ', 288: b' the ', 289: b'ay', 290: b'en', 291: b'or', 292: b'al ', 293: b'em', 294: b'.\n', 295: b'rie', 296: b'ing', 297: b', 202', 298: b'ti', 299: b'ayl', 300: b'". ', 301: b'll', 302: b'Tayl', 303: b'trie', 304: b'.\n ', 305: b'to', 306: b'. Re', 307: b'. Retrie', 308: b'. Retrieved ', 309: b'Taylor ', 310: b'es', 311: b'Taylor Swif', 312: b'us', 313: b'rom', 314: b'ember ', 315: b'). ', 316: b'Ar', 317: b'from', 318: b'). "', 319: b'and ', 320: b're', 321: b'ou', 322: b'ori', 323: b'of', 324: b'gin', 325: b'ing ', 326: b'chi', 327: b'] ', 328: b'ginal ', 329: b'from the ', 330: b'original

Note that without the regex-forced split, some merges cross categories (punctuation, whitespace, and letters), for example `337: b'. Archived from the original on '` or `412: b'). "Taylor Swift '`.
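A toy illustration of why this happens: if the whole text is one chunk, a punctuation-plus-space pair such as `". "` (bytes 46, 32) can easily become the most frequent pair. If the text is first split so that a space attaches to the following word, that pair never occurs inside any chunk and can never be merged. The chunks below are hand-made in the spirit of the GPT-4 pattern, not produced by it, and `pair_counts` is a hypothetical helper.

```python
from collections import Counter

def pair_counts(chunks):
    """Count adjacent byte pairs inside each chunk; pairs never span two chunks."""
    counts = Counter()
    for chunk in chunks:
        ids = list(chunk.encode("utf-8"))
        counts.update(zip(ids, ids[1:]))
    return counts

text = "a. b. c. d."

# No split: the whole text is a single chunk, so ". " is the most frequent pair.
unsplit = pair_counts([text])
print(unsplit.most_common(1))  # [((46, 32), 3)]

# Hand-made split: punctuation stands alone, spaces attach to the next word.
chunks = ["a", ".", " b", ".", " c", ".", " d", "."]
split = pair_counts(chunks)
print(split[(46, 32)])  # 0
```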

### Step 2

Convert your `BasicTokenizer` into a `RegexTokenizer`, which takes a regex pattern and splits the text exactly as GPT-4 would. Process the parts separately as before, then concatenate the results. Retrain your tokenizer and compare the results before and after. You should now see no tokens that span categories (numbers, letters, punctuation, more than one whitespace). Use the GPT-4 pattern:

```
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
```
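Python's built-in `re` module does not understand `\p{L}` or `\p{N}`, which is why the notebook uses the third-party `regex` package. To see what the pattern does without that dependency, here is an ASCII-only approximation for `re`. It is an assumption for illustration, not the real pattern: `\p{L}` becomes `[A-Za-z]`, `\p{N}` becomes `[0-9]`, the possessive quantifiers are dropped (for this pattern they should not change what is matched), and non-ASCII letters such as Hangul are not treated as letters.

```python
import re

# Hypothetical ASCII-only approximation of GPT4_SPLIT_PATTERN for the standard-library `re`.
APPROX_SPLIT_PATTERN = (
    r"'(?i:[sdmt]|ll|ve|re)"          # contractions: 's 'd 'm 't 'll 've 're
    r"|[^\r\nA-Za-z0-9]?[A-Za-z]+"    # a word, optionally preceded by one non-letter/non-digit char (often a space)
    r"|[0-9]{1,3}"                    # numbers, at most 3 digits per chunk
    r"| ?[^\sA-Za-z0-9]+[\r\n]*"      # punctuation runs, optionally with one leading space
    r"|\s*[\r\n]"                     # newlines, with any whitespace before them
    r"|\s+(?!\S)"                     # whitespace run, leaving the last space for the next word
    r"|\s+"                           # any remaining whitespace
)

for s in ["hello world!!!? lol123", "I'm fine", "a  b", "1234"]:
    print(re.findall(APPROX_SPLIT_PATTERN, s))
```

The `\s+(?!\S)` branch is what turns `"a  b"` into `["a", " ", " b"]`: the last space before a word is left for that word, so ` b` can still be a single token.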

In [6]:
from collections import defaultdict
import regex as re

class RegexTokenizer:
    INITIAL_VOCAB_SIZE = 256
    GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
    
    def __init__(self):
        self.vocab = {i: bytes([i]) for i in range(self.INITIAL_VOCAB_SIZE)} # int -> bytes
        self.merges = {} # (int, int) -> int
    
    def train(self, text, vocab_size, verbose=False):
        """ Given training data as a single string, perform BPE merges until we reach vocab_size tokens,
        and store the information in self.vocab and self.merges.
        """
        chunks = re.findall(self.GPT4_SPLIT_PATTERN, text)
        ids = [list(chunk.encode("utf-8")) for chunk in chunks]
        for i in range(vocab_size - self.INITIAL_VOCAB_SIZE):
            # find most common pair across all chunks
            pair_frequency = defaultdict(int)
            for ids_chunk in ids:
                get_stats(ids_chunk, pair_frequency) # update in-place
            if not pair_frequency:
                break  # every chunk is down to a single token, nothing left to merge
            most_frequent_pair = max(pair_frequency, key=pair_frequency.get)

            # mint new token and add it to our vocab
            next_token = self.INITIAL_VOCAB_SIZE + i
            self.vocab[next_token] = self.vocab[most_frequent_pair[0]] + self.vocab[most_frequent_pair[1]]
            self.merges[most_frequent_pair] = next_token
            if verbose:
                print(f"Merging {most_frequent_pair} -> {next_token}")
            # perform merge on our input for the next round
            ids = [merge(ids_chunk, most_frequent_pair, next_token) for ids_chunk in ids]

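    def _encode_chunk(self, ids):
        """ Apply the learned merges to the byte ids of a single regex chunk """
        while len(ids) > 1:
            pair_frequency = get_stats(ids)
            # apply merges in the order they were learned (lowest merge index first)
            pair_to_merge = min(pair_frequency, key=lambda p: self.merges.get(p, float("inf")))
            if pair_to_merge not in self.merges:
                break
            ids = merge(ids, pair_to_merge, self.merges[pair_to_merge])
        return ids

    def encode(self, text):
        """ text -> list of token ids: split with the regex exactly as in training,
        encode each chunk separately, then concatenate the results.
        """
        ids = []
        for chunk in re.findall(self.GPT4_SPLIT_PATTERN, text):
            ids.extend(self._encode_chunk(list(chunk.encode("utf-8"))))
        return ids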
    
    def decode(self, ids):
        """ bytes -> text using self.vocab and self.merges """
        return b"".join([self.vocab[tok] for tok in ids]).decode("utf-8", errors="replace")


text = open("taylorswift.txt", "r", encoding="utf-8").read()
    
tok = RegexTokenizer()
tok.train(text, 512, verbose=True)

Merging (101, 114) -> 256
Merging (50, 48) -> 257
Merging (111, 114) -> 258
Merging (105, 110) -> 259
Merging (101, 100) -> 260
Merging (32, 116) -> 261
Merging (111, 110) -> 262
Merging (104, 101) -> 263
Merging (32, 83) -> 264
Merging (97, 114) -> 265
Merging (97, 110) -> 266
Merging (32, 65) -> 267
Merging (261, 263) -> 268
Merging (97, 108) -> 269
Merging (114, 105) -> 270
Merging (118, 260) -> 271
Merging (115, 116) -> 272
Merging (119, 105) -> 273
Merging (32, 82) -> 274
Merging (257, 49) -> 275
Merging (32, 102) -> 276
Merging (257, 50) -> 277
Merging (32, 84) -> 278
Merging (102, 116) -> 279
Merging (97, 121) -> 280
Merging (32, 34) -> 281
Merging (273, 279) -> 282
Merging (101, 116) -> 283
Merging (264, 282) -> 284
Merging (99, 104) -> 285
Merging (98, 256) -> 286
Merging (97, 116) -> 287
Merging (111, 109) -> 288
Merging (101, 115) -> 289
Merging (101, 110) -> 290
Merging (101, 109) -> 291
Merging (34, 46) -> 292
Merging (32, 40) -> 293
Merging (46, 10) -> 294
... (output truncated)

In [7]:
print(tok.encode(". Archived from the original on "))
print(tok.decode(tok.encode(". Archived from the original on ")))

[46, 329, 317, 268, 327, 299, 32]
. Archived from the original on 


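We can see that the string `". Archived from the original on "`, which was a single token in the `BasicTokenizer`, is now broken up into multiple tokens. The regex first splits it into the chunks `"."`, `" Archived"`, `" from"`, `" the"`, `" original"`, `" on"` and `" "`, and BPE merges are only learned within chunks.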

In [8]:
print(text == tok.decode(tok.encode(text))) # sanity check

True



### Step 3

You're now ready to load the merges from the GPT-4 tokenizer and show that your tokenizer produces identical results for both `encode` and `decode`, matching [tiktoken](https://github.com/openai/tiktoken).

```
# match this
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("hello world!!!? (안녕하세요!) lol123 😉")
text = enc.decode(ids) # get the same text back
```

Unfortunately, you will run into two issues:

1. It is not trivial to recover the raw merges from the GPT-4 tokenizer. You can easily recover what we call `vocab` here, and what they call and store under `enc._mergeable_ranks`. Feel free to copy paste the `recover_merges` function in `minbpe/gpt4.py`, which takes these ranks and returns the raw merges. If you wish to know how this function works, read [this](https://github.com/openai/tiktoken/issues/60) and [this](https://github.com/karpathy/minbpe/issues/11#issuecomment-1950805306). Basically, under some conditions it is enough to only store the parent nodes (and their rank) and get rid of the precise details of which children merged up to any parent.
2. Second, the GPT-4 tokenizer for some reason permutes its raw bytes. It stores this permutation in the first 256 elements of the mergeable ranks, so you can recover this byte shuffle relatively simply as `byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}`. In both your encode and decode, you'll have to shuffle bytes around accordingly. If you're stuck, refer to the `minbpe/gpt4.py` file for hints.
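To see why recovering the merges works, and how the byte shuffle fits in, here is a toy run on a hand-made `mergeable_ranks` dict (hypothetical, five entries instead of cl100k_base's roughly 100k). `bpe` and `recover_merges` are restated in a slightly restructured but equivalent form so the cell runs without tiktoken. The idea: re-running BPE on a token while only allowing merges ranked below that token's own rank should stop at exactly two parts, and those two parts are the pair that produced it.

```python
def bpe(mergeable_ranks, token, max_rank):
    """Re-run BPE on `token` using only merges ranked below `max_rank`; return the final parts."""
    parts = [bytes([b]) for b in token]
    while True:
        best = None  # (rank, index) of the lowest-ranked adjacent pair that may still merge
        for i in range(len(parts) - 1):
            rank = mergeable_ranks.get(parts[i] + parts[i + 1])
            if rank is not None and rank < max_rank and (best is None or rank < best[0]):
                best = (rank, i)
        if best is None:
            return parts
        i = best[1]
        parts = parts[:i] + [parts[i] + parts[i + 1]] + parts[i + 2:]


def recover_merges(mergeable_ranks):
    """Map each multi-byte token back to the (left rank, right rank) pair it was merged from."""
    merges = {}
    for token, rank in mergeable_ranks.items():
        if len(token) == 1:
            continue  # raw bytes are leaves, not merges
        left, right = bpe(mergeable_ranks, token, max_rank=rank)  # expected to stop at exactly 2 parts
        merges[(mergeable_ranks[left], mergeable_ranks[right])] = rank
    return merges


# Hypothetical toy ranks: single bytes get permuted ranks (b is 0, a is 1), mimicking cl100k_base.
mergeable_ranks = {b"b": 0, b"a": 1, b"c": 2, b"ab": 3, b"abc": 4}
merges = recover_merges(mergeable_ranks)
print(merges)  # {(1, 0): 3, (3, 2): 4}

# Byte shuffle: raw byte value -> rank of that single byte, plus its inverse.
byte_shuffle = {b: mergeable_ranks[bytes([b])] for b in b"abc"}
inverse_byte_shuffle = {v: k for k, v in byte_shuffle.items()}

# Build the vocab in "shuffled" byte space, then undo the shuffle when decoding.
vocab = {i: bytes([i]) for i in byte_shuffle.values()}
for (left, right), idx in merges.items():
    vocab[idx] = vocab[left] + vocab[right]
raw = bytes(inverse_byte_shuffle[x] for x in vocab[4])
print(raw)  # b'abc'
```

This mirrors the scheme used by the `GPT4Tokenizer` cell: the vocab stores shuffled byte values, so decoding maps every byte back through the inverse permutation before UTF-8 decoding.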

In [9]:
import tiktoken

# ---------- Copied from `minbpe/gpt4.py` in the original repo ----------
def bpe(mergeable_ranks, token, max_rank):
    # helper function used in recover_merges() to reconstruct the merge forest
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        assert min_idx is not None
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts


def recover_merges(mergeable_ranks):
    # the `merges` are already the byte sequences in their merged state.
    # so we have to recover the original pairings. We can do this by doing
    # a small BPE training run on all the tokens, in their order.
    # also see https://github.com/openai/tiktoken/issues/60
    # also see https://github.com/karpathy/minbpe/issues/11#issuecomment-1950805306
    merges = {}
    for token, rank in mergeable_ranks.items():
        if len(token) == 1:
            continue # skip raw bytes
        pair = tuple(bpe(mergeable_ranks, token, max_rank=rank))
        assert len(pair) == 2
        # recover the integer ranks of the pair
        ix0 = mergeable_ranks[pair[0]]
        ix1 = mergeable_ranks[pair[1]]
        merges[(ix0, ix1)] = rank

    return merges
# ------------------------------------------------------------------------

class GPT4Tokenizer:
    INITIAL_VOCAB_SIZE = 256
    
    def __init__(self):
        # get the official tokenizer and its merges
        enc = tiktoken.get_encoding("cl100k_base")
        mergeable_ranks = enc._mergeable_ranks
        self.merges = recover_merges(mergeable_ranks)
        self.vocab = {i: bytes([i]) for i in range(self.INITIAL_VOCAB_SIZE)}
        for pair, tok in self.merges.items():
            self.vocab[tok] = self.vocab[pair[0]] + self.vocab[pair[1]]
        self.byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}
        self.inverse_byte_shuffle = {v:k for k,v in self.byte_shuffle.items()}

    def train(self, text):
        raise NotImplementedError("This is a pretrained tokenizer, not meant to be trained")
        
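    def _encode_chunk(self, ids):
        """ Apply the recovered GPT-4 merges to the (shuffled) byte ids of a single chunk """
        while len(ids) > 1:
            pair_frequency = get_stats(ids)
            # apply merges in rank order: the pair with the lowest rank goes first
            pair_to_merge = min(pair_frequency, key=lambda p: self.merges.get(p, float("inf")))
            if pair_to_merge not in self.merges:
                break
            ids = merge(ids, pair_to_merge, self.merges[pair_to_merge])
        return ids

    def encode(self, text):
        """ text -> list of token ids, intended to match tiktoken's cl100k_base (no special tokens) """
        ids = []
        # split exactly like GPT-4 (same pattern as RegexTokenizer), then run BPE on each chunk separately
        for chunk in re.findall(RegexTokenizer.GPT4_SPLIT_PATTERN, text):
            chunk_ids = [self.byte_shuffle[b] for b in chunk.encode("utf-8")]
            ids.extend(self._encode_chunk(chunk_ids))
        return ids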
    
    def decode(self, ids):
        """ bytes -> text using self.vocab and self.merges """
        text_bytes = b"".join([self.vocab[tok] for tok in ids])
        text_bytes = bytes([self.inverse_byte_shuffle[b] for b in text_bytes])
        return text_bytes.decode("utf-8", errors="replace")
        

In [10]:
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
gpt4_ids = enc.encode("hello world!!!? (안녕하세요!) lol123 😉")
gpt4_text = enc.decode(gpt4_ids) # get the same text back

In [11]:
tok = GPT4Tokenizer()
ids = tok.encode("hello world!!!? (안녕하세요!) lol123 😉")
text = tok.decode(ids)

In [12]:
print(gpt4_ids == ids)
print(gpt4_text == text)
print(text == "hello world!!!? (안녕하세요!) lol123 😉")

True
True
True


### Step 4 (Will not do)

(Optional, irritating, not obviously useful) Add the ability to handle special tokens. You'll then be able to match the output of tiktoken even when special tokens are present, e.g.:

```
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("<|endoftext|>hello world", allowed_special="all")
```

Without `allowed_special`, tiktoken raises an error when the text contains a special token.
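A minimal sketch of one way to support special tokens: split the text on the special-token strings first, emit their ids directly, and run the ordinary BPE encoder on the pieces in between. `encode_with_specials` and `byte_encoder` are hypothetical names; the stand-in encoder just returns raw UTF-8 bytes, and this version behaves roughly like `allowed_special="all"` (every registered special token is recognized). 100257 is the id cl100k_base assigns to `<|endoftext|>`.

```python
import re

def encode_with_specials(text, special_tokens, encode_ordinary):
    """Split `text` on special-token strings, map those to their ids,
    and encode everything in between with `encode_ordinary`."""
    if not special_tokens:
        return encode_ordinary(text)
    # Longest first, so a special token that is a prefix of another cannot shadow it.
    alternation = "|".join(re.escape(s) for s in sorted(special_tokens, key=len, reverse=True))
    ids = []
    for part in re.split(f"({alternation})", text):  # the capture group keeps the separators
        if part in special_tokens:
            ids.append(special_tokens[part])
        elif part:  # skip empty strings produced at the edges
            ids.extend(encode_ordinary(part))
    return ids

# Hypothetical stand-in for a real ordinary (non-special) encoder: raw UTF-8 bytes.
def byte_encoder(s):
    return list(s.encode("utf-8"))

specials = {"<|endoftext|>": 100257}
print(encode_with_specials("<|endoftext|>hello", specials, byte_encoder))
# [100257, 104, 101, 108, 108, 111]
```

In a real tokenizer, `byte_encoder` would be replaced by the regex-split BPE `encode`, so special tokens never take part in merges.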

### Step 5 (Will not do)


If you've made it this far, you're now a pro at LLM Tokenization! Sadly, you're not exactly done yet, because a lot of LLMs outside of OpenAI (e.g. Llama, Mistral) use [sentencepiece](https://github.com/google/sentencepiece) instead. The primary difference is that sentencepiece runs BPE directly on Unicode code points instead of on UTF-8 encoded bytes. Feel free to explore sentencepiece on your own (good luck, it's not too pretty), and, as a stretch goal if you have a lot of time on your hands, rewrite your BPE to operate on Unicode code points and match the Llama 2 tokenizer.
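The difference shows up before any merge happens. Here is a minimal illustration in plain Python (not sentencepiece's actual implementation): the same Korean string is 13 UTF-8 bytes but only 5 code points, and on code points the very first pair count already finds the syllable pair `안녕`.

```python
from collections import Counter

text = "안녕 안녕"

byte_ids = list(text.encode("utf-8"))   # starting symbols for byte-level BPE (this notebook)
code_points = [ord(ch) for ch in text]  # starting symbols for code-point-level BPE

print(len(byte_ids), len(code_points))  # 13 5

# Pair counts on code points: the syllable pair (안, 녕) is already the most frequent pair.
cp_pairs = Counter(zip(code_points, code_points[1:]))
print(cp_pairs.most_common(1))  # [((50504, 45397), 2)]
```

A code-point vocabulary also has to decide what to do with rare code points that never make it into the vocabulary, which byte-level BPE avoids by always having all 256 bytes available.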