I just watched [Karpathy's BPE video](https://www.youtube.com/watch?v=zduSFxRajkE) (thanks Andrej for all the fun videos!) and got inspired. Karpathy's code is really nice (as always), and [his implementation](https://github.com/karpathy/minbpe) is obviously not meant to be fast, _but_ we **can** make it much faster without sacrificing too much clarity. Hopefully, faster code makes experimentation easier (e.g. with different scores, instead of always merging the most frequent pair). So... here's my late-night take on a simple, clean, but faster BPE!

For example, on my laptop, training/tokenizing the taylorswift.txt file with a vocab size of 10K takes:

|              |  minbpe (Karpathy's)       |   fast_minbpe (this colab)|
|--------------|---------------|--------------|
|Training      |  110.10 secs  | 13.65 secs   |
|Tokenizing    |  190.91 secs  | 0.84 secs    |
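For a rough sense of scale, the speedups implied by the table are just ratios of the timings (the numbers below are copied straight from the table):

```python
# Timings in seconds, copied from the table above: (minbpe, fast_minbpe).
timings = {
    "Training": (110.10, 13.65),
    "Tokenizing": (190.91, 0.84),
}

# Speedup = minbpe time / fast_minbpe time.
speedups = {task: slow / fast for task, (slow, fast) in timings.items()}
for task, s in speedups.items():
    print(f"{task}: {s:.1f}x faster")
```

That is roughly 8x for training and over 200x for tokenization; treat these as ballpark figures for this particular laptop.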


In [1]:
with open("taylorswift.txt", "r", encoding="utf-8") as f:
    text = f.read()
print(f'Source text is of length: {len(text)}')
print(f'First 50 chars: {repr(text[:50])}')

Source text is of length: 185561
First 50 chars: 'Copy paste of the Wikipedia article on Taylor Swif'
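Note that `len(text)` counts Unicode characters, while the tokenizer below starts from the UTF-8 *bytes* of the text (that's why `preprocess_text` calls `text.encode('utf-8')`), so a non-ASCII character starts out as more than one token. A quick illustration with a few arbitrary sample strings:

```python
# Characters vs. UTF-8 bytes: the BPE here starts from bytes, not characters.
samples = ["Taylor", "café", "🍔"]
sizes = {s: (len(s), len(s.encode("utf-8"))) for s in samples}
for s, (n_chars, n_bytes) in sizes.items():
    print(f"{s!r}: {n_chars} char(s), {n_bytes} byte(s)")
```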


I'm borrowing a lot from Karpathy's code. The main differences are:
1. We'll hold the sequence in a *linked list* (`Node` below) so that we can efficiently delete elements in the middle, and
2. We'll compute the stats dict only _once_ at the beginning and, during training, just _update_ the parts affected by each merge. (Tokenization doesn't need the stats at all: it simply replays the learned merges in order.)
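To see why updating the stats is enough, here is a tiny standalone check. It uses plain Python lists and `collections.Counter` rather than the linked list used below: merging `(b, c)` into `x` in `a b c d` only touches the pairs around the merge, and applying those deltas yields the same counts as recomputing from scratch.

```python
from collections import Counter

def pair_counts(seq):
    # Count all adjacent pairs from scratch.
    return Counter(zip(seq, seq[1:]))

before = ["a", "b", "c", "d"]
after = ["a", "x", "d"]  # (b, c) merged into x

stats = pair_counts(before)
# Incremental update: remove the pairs that disappeared, add the new ones.
for pair in [("a", "b"), ("b", "c"), ("c", "d")]:
    stats[pair] -= 1
for pair in [("a", "x"), ("x", "d")]:
    stats[pair] += 1
stats = +stats  # Unary plus drops entries whose count is zero (or negative).

assert stats == pair_counts(after)
print(dict(stats))
```

Recomputing costs time proportional to the whole sequence on every merge, whereas the update only touches the neighbours of each merged occurrence.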

In [2]:
class Node:
    def __init__(self, id, prev=None, next=None):
        self.id = id
        self.prev = prev
        self.next = next

def preprocess_text(text):
    # Build a doubly linked list from the bytes of the text, plus a map from each byte to all of its occurrences.
    tok_to_pos = {}  # Maps a token to all of its nodes in the linked list, in text order.
    # For simplicity, a single sentinel node (id=None) marks both the start and the end of the list.
    prev = sent = Node(None)
    sent.next = sent  # So that empty text yields an empty (but well-formed) list.
    for t in text.encode('utf-8'):
        node = Node(t, prev, sent)
        tok_to_pos.setdefault(t, []).append(node)
        prev.next = prev = node  # Fine: targets are assigned left to right, so prev.next is set before prev is rebound.
    return sent.next, tok_to_pos

def to_python_list(ll):
    res = []
    while ll.id is not None:
        res.append(ll.id)
        ll = ll.next
    return res

ll, tok_to_pos = preprocess_text(text[:10])
assert bytes(to_python_list(ll)) == text[:10].encode('utf-8')
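The point of the linked list is that removing an element from the middle only requires rewiring its neighbours' pointers, which is O(1), whereas deleting from the middle of a Python list shifts every element after it. A minimal standalone sketch: its `Node` mirrors the class above, while `unlink`, `from_list` and `to_ids` are hypothetical helpers for illustration only.

```python
class Node:
    def __init__(self, id, prev=None, next=None):
        self.id = id
        self.prev = prev
        self.next = next

def unlink(node):
    # Hypothetical helper: splice `node` out by pointing its neighbours at each other.
    node.prev.next = node.next
    node.next.prev = node.prev

def from_list(ids):
    # Build a list closed at both ends by a single sentinel (id=None).
    sent = Node(None)
    sent.next = sent.prev = sent
    prev, nodes = sent, []
    for i in ids:
        node = Node(i, prev, sent)
        prev.next = node
        sent.prev = node
        prev = node
        nodes.append(node)
    return sent, nodes

def to_ids(sent):
    # Walk from the sentinel back around to the sentinel.
    out, node = [], sent.next
    while node is not sent:
        out.append(node.id)
        node = node.next
    return out

sent, nodes = from_list([1, 2, 3, 4])
unlink(nodes[1])  # Remove the node holding 2 in O(1).
print(to_ids(sent))  # [1, 3, 4]
```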

In [3]:
def update_stats(stats, pair, delta):
    stats[pair] = stats.get(pair, 0) + delta
    if stats[pair] == 0:
        del stats[pair]

def init_stats(ll):
    stats = {}
    while ll.next.id is not None:
        update_stats(stats, (ll.id, ll.next.id), 1)
        ll = ll.next
    return stats

stats = init_stats(ll)
for pair in sorted(stats, key=stats.get, reverse=True):
    print(pair, stats[pair])

assert sum(stats.values()) == len(text[:10].encode('utf-8')) - 1

(67, 111) 1
(111, 112) 1
(112, 121) 1
(121, 32) 1
(32, 112) 1
(112, 97) 1
(97, 115) 1
(115, 116) 1
(116, 101) 1
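As a sanity check, the same counts can be computed with `collections.Counter` over adjacent byte pairs. This is only a cross-check and not part of the notebook's code; `"Copy paste"` is the same string as `text[:10]` above:

```python
from collections import Counter

sample = "Copy paste"  # Same as text[:10] above.
data = sample.encode("utf-8")
counts = Counter(zip(data, data[1:]))  # Iterating bytes yields ints, so pairs are (int, int).
for pair, n in counts.most_common():
    print(pair, n)
```

The `Counter` version works on a flat `bytes` object; `init_stats` instead walks the linked list, which is the structure the rest of the notebook updates in place.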


In [4]:
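def merge(pair, new_id, tok_to_pos, stats=None):
    # Merge every occurrence of pair in our linked list (reachable through tok_to_pos) into new_id,
    # and update the stats dict (if provided).
    if pair[0] not in tok_to_pos or pair[1] not in tok_to_pos:
        return  # The pair never occurs (possible when tokenizing text other than the training text).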
    pos0_to_delete = set()
    pos1_to_delete = set()

    for pos in tok_to_pos[pair[0]]:
        if pos in pos1_to_delete:
            continue  # Already deleted (can happen if pair[0] == pair[1]).
        if pos.next.id == pair[1]:
            pos0_to_delete.add(pos)
            pos1_to_delete.add(pos.next)
            tok_to_pos.setdefault(new_id, []).append(pos)
            pos.next.next.prev = pos
            pos.next = pos.next.next
            pos.id = new_id
            if stats is not None:
                # When merging (b, c) into x, so "a b c d" becomes "a x d", we need to:
                # - Decrement (a, b), (b, c), (c, d), and
                # - Increment (a, x), (x, d)
                update_stats(stats, pair, -1)
                if pos.prev.id is not None:
                    update_stats(stats, (pos.prev.id, pair[0]), -1)
                    update_stats(stats, (pos.prev.id, new_id), 1)
                if pos.next.id is not None:
                    update_stats(stats, (pair[1], pos.next.id), -1)
                    update_stats(stats, (new_id, pos.next.id), 1)

    # There's probably a better way to do this :)
    tok_to_pos[pair[0]] = [pos for pos in tok_to_pos[pair[0]] if pos not in pos0_to_delete]
    tok_to_pos[pair[1]] = [pos for pos in tok_to_pos[pair[1]] if pos not in pos1_to_delete]

def train(text, vocab_size, verbose=False):
    print(f'Training tokenizer on text of length {len(text):,} with vocab of size {vocab_size:,}.')
    n_merges = vocab_size - 256
    vocab = {i: bytes([i]) for i in range(256)}
    merge_tree = []
    ll, tok_to_pos = preprocess_text(text)
    stats = init_stats(ll)
    for i in range(n_merges):
        if not stats: break
        top_pair = max(stats, key=stats.get)
        new_id = len(vocab)
        merge_tree.append((top_pair, new_id))
        vocab[new_id] = vocab[top_pair[0]] + vocab[top_pair[1]]
        if verbose:
            print(f"Merge {i+1}/{n_merges}: {top_pair} -> {new_id} ({vocab[new_id]}) had {stats[top_pair]} occurrences")
        merge(top_pair, new_id, tok_to_pos, stats)
    
    return merge_tree, vocab

def tokenize(text, merge_tree):
    ll, tok_to_pos = preprocess_text(text)
    for pair, new_id in merge_tree:
        merge(pair, new_id, tok_to_pos, None)
    return to_python_list(ll)

def detokenize(lst, vocab):
    return b''.join((vocab[t] for t in lst)).decode('utf-8')
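To test the linked-list `merge` against something obviously correct, one option is a plain-list reference that replaces non-overlapping occurrences of the pair from left to right, in the same spirit as minbpe's list-based merge. `merge_list` is a hypothetical name for illustration. Note the `pair[0] == pair[1]` case, where `a a a` should become `x a`, not `a x` or `x x`:

```python
def merge_list(ids, pair, new_id):
    # Hypothetical reference: scan left to right, replacing each non-overlapping
    # occurrence of `pair` with `new_id`.
    out = []
    i = 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2  # Skip both elements of the merged pair.
        else:
            out.append(ids[i])
            i += 1
    return out

print(merge_list([97, 97, 97], (97, 97), 256))      # [256, 97]
print(merge_list([97, 97, 97, 97], (97, 97), 256))  # [256, 256]
print(merge_list([1, 2, 3, 1, 2], (1, 2), 256))     # [256, 3, 256]
```

Applying `merge_list` for each `(pair, new_id)` in `merge_tree` should give the same tokens as `tokenize` on the training text, just much more slowly, since each call rescans the whole list.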

In [5]:
merge_tree, vocab = train(text, 512, verbose=True)

Training tokenizer on text of length 185,561 with vocab of size 512.
Merge 1/256: (101, 32) -> 256 (b'e ') had 2981 occurrences
Merge 2/256: (44, 32) -> 257 (b', ') had 2961 occurrences
Merge 3/256: (100, 32) -> 258 (b'd ') had 2617 occurrences
Merge 4/256: (46, 32) -> 259 (b'. ') had 2560 occurrences
Merge 5/256: (114, 32) -> 260 (b'r ') had 2428 occurrences
Merge 6/256: (50, 48) -> 261 (b'20') had 2365 occurrences
Merge 7/256: (115, 32) -> 262 (b's ') had 2053 occurrences
Merge 8/256: (105, 110) -> 263 (b'in') had 2006 occurrences
Merge 9/256: (111, 110) -> 264 (b'on') had 1815 occurrences
Merge 10/256: (114, 105) -> 265 (b'ri') had 1805 occurrences
Merge 11/256: (116, 32) -> 266 (b't ') had 1802 occurrences
Merge 12/256: (116, 104) -> 267 (b'th') had 1737 occurrences
Merge 13/256: (101, 258) -> 268 (b'ed ') had 1736 occurrences
Merge 14/256: (257, 261) -> 269 (b', 20') had 1705 occurrences
Merge 15/256: (97, 110) -> 270 (b'an') had 1487 occurrences
Merge 16/256: (97, 114) -> 271 (b'
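Each new token's bytes are simply the concatenation of its two parents' bytes (`vocab[new_id] = vocab[top_pair[0]] + vocab[top_pair[1]]` in `train`). Replaying a subset of the merges printed above by hand shows how, for example, token 269 ends up as `b', 20'`:

```python
vocab = {i: bytes([i]) for i in range(256)}
# A subset of the merges printed above, as (pair, new_id).
merges = [
    ((101, 32), 256),   # b'e '
    ((44, 32), 257),    # b', '
    ((100, 32), 258),   # b'd '
    ((50, 48), 261),    # b'20'
    ((101, 258), 268),  # b'ed '
    ((257, 261), 269),  # b', 20'
]
for (a, b), new_id in merges:
    vocab[new_id] = vocab[a] + vocab[b]

print(vocab[268], vocab[269])
```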

Timing time!

In [6]:
import time

def timeit(f, name):
    start = time.perf_counter()  # Monotonic, high-resolution clock meant for measuring durations.
    res = f()
    print(f'{name} took {time.perf_counter() - start:.2f} seconds.')
    return res

for vocab_size in [300, 1000, 10_000]:
    merge_tree, vocab = timeit(lambda: train(text, vocab_size), 'Training')
    tokenized_text = timeit(lambda: tokenize(text, merge_tree), 'Tokenization')
    print(f'Tokenized text has {len(tokenized_text)} tokens.')
    detokenized_text = timeit(lambda: detokenize(tokenized_text, vocab), 'Detokenize')
    assert detokenized_text == text
    print()

Training tokenizer on text of length 185,561 with vocab of size 300.
Training took 0.48 seconds.
Tokenization took 0.41 seconds.
Tokenized text has 128451 tokens.
Detokenize took 0.01 seconds.

Training tokenizer on text of length 185,561 with vocab of size 1,000.
Training took 1.45 seconds.
Tokenization took 0.70 seconds.
Tokenized text has 58341 tokens.
Detokenize took 0.00 seconds.

Training tokenizer on text of length 185,561 with vocab of size 10,000.
Training took 13.65 seconds.
Tokenization took 0.84 seconds.
Tokenized text has 24372 tokens.
Detokenize took 0.01 seconds.



In [7]:
# Let's inspect our tokenized text:
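def debug(tokenized_text, vocab):
    # A token may hold only part of a multi-byte UTF-8 character, so decode leniently.
    print('🍔'.join([vocab[t].decode('utf-8', errors='replace') for t in tokenized_text]))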

debug(tokenized_text[:100], vocab)

Cop🍔y p🍔ast🍔e of the 🍔Wikipe🍔dia 🍔article 🍔on 🍔Taylor Swift, 🍔as of 🍔Feb🍔 16🍔, 2024.
🍔--🍔-🍔

🍔Main 🍔m🍔enu🍔

🍔Wikipedia🍔The F🍔ree 🍔Enc🍔yclopedia
🍔
🍔Search🍔
🍔Cre🍔ate 🍔account🍔
L🍔og🍔 in🍔

🍔Personal 🍔tool🍔s
🍔Cont🍔ents 🍔 h🍔ide🍔
(🍔Top🍔)
🍔Life and 🍔career🍔
Toggle 🍔Life and 🍔career 🍔subsection
🍔Artistry🍔
Toggle 🍔Artist🍔ry 🍔subsection
🍔Accolades and achievements
🍔Cultural status🍔
Toggle 🍔Cultural 🍔status 🍔subsection
🍔Wealth🍔
Toggle 🍔Weal🍔th 🍔subsection
🍔Discography
🍔Filmography
🍔Tours
🍔See also🍔
F🍔ootnotes
🍔References
🍔Toggle 🍔Refer🍔ences 🍔subsection
🍔External links
🍔Taylor Swift🍔

🍔13🍔6 🍔l🍔ang🍔uag🍔es
🍔Artic🍔le
🍔Tal🍔k
🍔Read🍔
View 🍔sour🍔ce🍔
View 🍔history🍔

🍔T🍔ool
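One thing to keep in mind when printing tokens one by one: a token is a byte string, and BPE can split a multi-byte UTF-8 character across tokens. Decoding such a token on its own raises `UnicodeDecodeError`; passing `errors='replace'` to `bytes.decode` substitutes the replacement character U+FFFD instead. A small illustration with a hand-picked byte split (not actual tokens from this vocab):

```python
# "é" is two bytes in UTF-8 (0xC3 0xA9); pretend BPE put them in separate tokens.
tokens = ["caf".encode("utf-8"), b"\xc3", b"\xa9"]

try:
    print("|".join(t.decode("utf-8") for t in tokens))
except UnicodeDecodeError as e:
    print("strict decoding failed:", e.reason)

# Lossy but never fails: undecodable pieces become U+FFFD.
pieces = [t.decode("utf-8", errors="replace") for t in tokens]
print("|".join(pieces))

# Joining the bytes first (as detokenize does) decodes cleanly.
print(b"".join(tokens).decode("utf-8"))
```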
