Heavily inspired by Karpathy's minbpe, here's my late-night take on a simple, clean, and fast BPE implementation.

I'm borrowing a lot from Karpathy's code, but we'll use more efficient data structures:
1. We'll hold the sequence in a [*Leap*](leap.py) so we can efficiently _find_ the pairs to merge and _perform_ the merge.
2. We'll compute the stats only _once_ at the beginning, and during training just _update_ the relevant parts.
3. We'll maintain the counts in a [slightly-modified max-heap](heapykiyay.py), so finding the max element is O(1) and updating a count is logarithmic.
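
To make (1) concrete, here's a rough sketch of what a Leap-like structure could look like: a doubly linked list plus a per-value index of nodes. The real `leap.py` isn't shown here, so `TinyLeap` and `Node` are hypothetical names and the internals are an assumption; only the method names mirror how `Leap` is used in the cells below.

```python
class Node:
    """One element of the sequence (hypothetical, for illustration only)."""
    __slots__ = ("val", "prev", "next")

    def __init__(self, val):
        self.val, self.prev, self.next = val, None, None


class TinyLeap:
    """Doubly linked list with an index from value to the nodes holding it."""

    def __init__(self, values):
        self.head = self.tail = None
        self.index = {}  # val -> {node: None}; a dict keeps insertion order
        for v in values:
            node = Node(v)
            node.prev = self.tail
            if self.tail is None:
                self.head = node
            else:
                self.tail.next = node
            self.tail = node
            self.index.setdefault(v, {})[node] = None

    def __iter__(self):
        node = self.head
        while node is not None:
            yield node
            node = node.next

    def occurrences(self, val):
        # Return a copy so callers can mutate the structure while looping.
        return list(self.index.get(val, ()))

    def set_value(self, node, val):
        del self.index[node.val][node]
        node.val = val
        self.index.setdefault(val, {})[node] = None

    def delete(self, node):
        # O(1) unlink: only the two neighbours are touched.
        del self.index[node.val][node]
        if node.prev is None:
            self.head = node.next
        else:
            node.prev.next = node.next
        if node.next is None:
            self.tail = node.prev
        else:
            node.next.prev = node.prev

    def to_python_list(self):
        return [n.val for n in self]


# Merge every (a, b) pair into a new token 256 without scanning the whole list.
leap = TinyLeap(b"abcab")
for node in leap.occurrences(ord("a")):
    if node.next is not None and node.next.val == ord("b"):
        leap.delete(node.next)
        leap.set_value(node, 256)
print(leap.to_python_list())  # [256, 99, 256]
```

The point of the index is that finding candidates for a merge costs time proportional to the number of occurrences of the pair's first element, not to the length of the whole sequence.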

If we are to perform N merges, and the length of the training text is L, Karpathy's original impl does (I think):
```python
for i in range(N):
    calc_stats()        # O(L)
    find_max()          # O(L)
    do_merges()         # O(L)
```
For a total complexity of O(N*L) (maybe I'm neglecting some factors).
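
For reference, here's a minimal sketch of that naive loop on plain Python lists. It's written in the spirit of Karpathy's `BasicTokenizer`, not copied from it; `get_stats`, `merge_list`, and `naive_train` are illustrative names chosen for this sketch.

```python
def get_stats(ids):
    """Count every adjacent pair in the sequence: O(L)."""
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge_list(ids, pair, new_id):
    """Rebuild the whole list, replacing each occurrence of pair: O(L)."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out


def naive_train(text, n_merges):
    ids = list(text.encode("utf-8"))
    merges = []
    for i in range(n_merges):
        stats = get_stats(ids)                # O(L), recomputed from scratch
        if not stats:
            break
        pair = max(stats, key=stats.get)      # O(L) linear scan
        ids = merge_list(ids, pair, 256 + i)  # O(L) full rebuild
        merges.append((pair, 256 + i))
    return merges, ids


merges, ids = naive_train("aaabdaaabac", 3)
print(ids)  # [258, 100, 258, 97, 99]
```

Every iteration touches the full sequence three times, which is where the O(N*L) comes from.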

With (1), (2), and (3) above, we instead get:
```python
stats = calc_stats()            # O(L)
for i in range(N):
    find_max()                  # O(1)
    do_merges_and_update_stats()  # O(Mi * log(L))
```
Where Mi denotes the actual number of merges we perform at the ith iteration. Note that M1+M2+...+MN <= L - 1, since every merge shortens the sequence by exactly one token, so the overall complexity of everything (again neglecting logarithmic factors) is O(L)!
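
The bound on the total number of merges is easy to check empirically. The sketch below (a throwaway helper named `merge_counts`, not part of the notebook's code) runs greedy BPE on a plain list, records how many merges each round performs, and compares the total with L - 1.

```python
from collections import Counter


def merge_counts(ids, n_rounds):
    """Return [M_1, M_2, ...]: the number of merges performed in each round."""
    ids = list(ids)
    counts = []
    next_id = 256
    for _ in range(n_rounds):
        pairs = Counter(zip(ids, ids[1:]))
        if not pairs:  # fewer than two tokens left, nothing to merge
            break
        top = max(pairs, key=pairs.get)
        out, i, merged = [], 0, 0
        while i < len(ids):
            if i + 1 < len(ids) and (ids[i], ids[i + 1]) == top:
                out.append(next_id)
                i += 2
                merged += 1  # each merge removes exactly one token
            else:
                out.append(ids[i])
                i += 1
        ids = out
        counts.append(merged)
        next_id += 1
    return counts


data = "aaabdaaabac".encode("utf-8")
counts = merge_counts(data, n_rounds=100)
assert sum(counts) <= len(data) - 1
```

With enough rounds the sequence collapses to a single token and the bound is met with equality.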

Unfortunately, we'll have to give up this lovely line from Karpathy's code:
```python
pair = max(stats, key=stats.get)
```
Karpathy's concise code is so nice, but this linear `max` operation is costing us a lot :)
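
To get a feel for the alternative, here's a minimal sketch built on the standard library's `heapq` with lazy deletion. This is a swapped-in technique for illustration and not how `HeapyKiYay` works (judging from the `increase`/`decrease`/`delete` calls used below, it updates entries in place); `LazyMaxCounts` is a hypothetical name.

```python
import heapq


class LazyMaxCounts:
    """Pair counts with a heap-backed max lookup; stale heap entries are skipped lazily."""

    def __init__(self):
        self.counts = {}  # pair -> current count
        self.heap = []    # (-count, pair) entries, some of which may be out of date

    def add(self, pair, delta):
        new = self.counts.get(pair, 0) + delta
        if new > 0:
            self.counts[pair] = new
            heapq.heappush(self.heap, (-new, pair))  # O(log n)
        else:
            self.counts.pop(pair, None)

    def max(self):
        # Discard entries whose recorded count no longer matches the dict.
        while self.heap:
            neg_count, pair = self.heap[0]
            if self.counts.get(pair) == -neg_count:
                return pair
            heapq.heappop(self.heap)
        return None


c = LazyMaxCounts()
for _ in range(3):
    c.add((1, 2), +1)
for _ in range(2):
    c.add((2, 3), +1)
print(c.max())   # (1, 2)
c.add((1, 2), -2)
print(c.max())   # (2, 3)
```

One design caveat: ties are broken by tuple order here, whereas `max(stats, key=stats.get)` returns the first maximal pair in dict order, so the two approaches can pick different pairs when counts are equal.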

Note that I only implement the functionality of Karpathy's `BasicTokenizer`.

You can find some more details in [this post](https://yanivle.github.io/ai/2024/02/23/fast_minbpe.html).

In [1]:
import time
from leap import Leap
from heapykiyay import HeapyKiYay

In [2]:
text = open("taylorswift.txt", "r", encoding="utf-8").read()
print(f'Source text is of length: {len(text):,}')
print(f'First 50 chars: {repr(text[:50])}')

Source text is of length: 185,561
First 50 chars: 'Copy paste of the Wikipedia article on Taylor Swif'


In [3]:
def build_leap(text):
    return Leap((t for t in text.encode('utf-8')))

leap = build_leap(text[:10])
assert bytes(leap.to_python_list()) == text[:10].encode('utf-8')

In [4]:
class Stats:
    def __init__(self, leap):
        self.d, self.heapy = {}, HeapyKiYay()
        for x in leap:
            if x.next is not None:
                self.inc((x.val, x.next.val))

    def inc(self, pair):
        if pair not in self.d:
            self.d[pair] = self.heapy.insert(pair)
        else:
            self.heapy.increase(self.d[pair])

    def dec(self, pair):
        self.heapy.decrease(self.d[pair])
        if self.d[pair][0] == 0:
            self.heapy.delete(self.d[pair])
            del self.d[pair]

    def __bool__(self):
        return bool(self.d)

    @property
    def max(self):
        return self.heapy.max

stats = Stats(leap)
for pair in sorted(stats.d):
    print(stats.d[pair])

assert sum((n[0] for n in stats.d.values())) == len(text[:10].encode('utf-8')) - 1

[1, (32, 112), 4]
[1, (67, 111), 7]
[1, (97, 115), 6]
[1, (111, 112), 5]
[1, (112, 97), 2]
[1, (112, 121), 8]
[1, (115, 116), 3]
[1, (116, 101), 1]
[1, (121, 32), 0]


In [5]:
def timeit(f, name):
    start = time.time()
    res = f()
    print(f'{name} took {time.time() - start:.2f} seconds.')
    return res

In [6]:
def merge(pair, new_id, leap, stats=None):
    # Merge occurrences of pair to new_id, and update stats (if provided).
    prev_deleted = None
    for node in list(leap.occurrences(pair[0])):  # Copy: the leap is mutated below.
        if node is prev_deleted:  # Can happen if pair[0] == pair[1]
            continue
        if node.next is not None and node.next.val == pair[1]:  # Merge!
            prev_deleted = node.next
            leap.delete(node.next)
            leap.set_value(node, new_id)
            if stats is not None:  # Update stats:
                stats.dec(pair)
                if node.prev is not None:
                    stats.dec((node.prev.val, pair[0]))
                    stats.inc((node.prev.val, new_id))
                if node.next is not None:
                    stats.dec((pair[1], node.next.val))
                    stats.inc((new_id, node.next.val))

def train(text, vocab_size, verbose=False):
    print(f'Training tokenizer on text of length {len(text):,} with vocab of size {vocab_size:,}.')
    n_merges = vocab_size - 256
    vocab = {i: bytes([i]) for i in range(256)}
    merge_tree = []
    leap = timeit(lambda: build_leap(text), 'build_leap')
    stats = timeit(lambda: Stats(leap), 'init_stats')
    for i in range(n_merges):
        if not stats.d: break
        top_pair = stats.max
        new_id = len(vocab)
        merge_tree.append((top_pair, new_id))
        vocab[new_id] = vocab[top_pair[0]] + vocab[top_pair[1]]
        if verbose:
            print(f"Merge {i+1}/{n_merges}: {top_pair} -> {new_id} ({vocab[new_id]}) had {stats.d[top_pair][0]} occurrences")
        merge(top_pair, new_id, leap, stats)
    
    return merge_tree, vocab

def tokenize(text, merge_tree):
    leap = build_leap(text)
    for pair, new_id in merge_tree:
        merge(pair, new_id, leap)
    return leap.to_python_list()

def detokenize(lst, vocab):
    return b''.join((vocab[t] for t in lst)).decode('utf-8')
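
The stats update inside `merge` relies on the fact that a single merge only changes the counts of pairs touching the merged position. Here's a tiny standalone check of that claim on plain lists (`pair_counts` is just a helper for this sketch): merging `(2, 3)` into `256` inside `[1, 2, 3, 4]` changes exactly five counts, matching the three `dec` and two `inc` calls above.

```python
from collections import Counter


def pair_counts(ids):
    """Count adjacent pairs in a plain list."""
    return Counter(zip(ids, ids[1:]))


before = [1, 2, 3, 4]
after = [1, 256, 4]  # (2, 3) merged into the new token 256

delta = pair_counts(after)
delta.subtract(pair_counts(before))
changed = {pair: d for pair, d in delta.items() if d != 0}

for pair, d in sorted(changed.items()):
    print(pair, f'{d:+d}')
# (1, 2) -1
# (1, 256) +1
# (2, 3) -1
# (3, 4) -1
# (256, 4) +1
```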

In [7]:
merge_tree, vocab = train(text, 512, verbose=True)

Training tokenizer on text of length 185,561 with vocab of size 512.
build_leap took 0.21 seconds.
init_stats took 0.09 seconds.
Merge 1/256: (101, 32) -> 256 (b'e ') had 2981 occurrences
Merge 2/256: (44, 32) -> 257 (b', ') had 2961 occurrences
Merge 3/256: (100, 32) -> 258 (b'd ') had 2617 occurrences
Merge 4/256: (46, 32) -> 259 (b'. ') had 2560 occurrences
Merge 5/256: (114, 32) -> 260 (b'r ') had 2428 occurrences
Merge 6/256: (50, 48) -> 261 (b'20') had 2365 occurrences
Merge 7/256: (115, 32) -> 262 (b's ') had 2053 occurrences
Merge 8/256: (105, 110) -> 263 (b'in') had 2006 occurrences
Merge 9/256: (111, 110) -> 264 (b'on') had 1815 occurrences
Merge 10/256: (114, 105) -> 265 (b'ri') had 1805 occurrences
Merge 11/256: (116, 32) -> 266 (b't ') had 1802 occurrences
Merge 12/256: (116, 104) -> 267 (b'th') had 1737 occurrences
Merge 13/256: (101, 258) -> 268 (b'ed ') had 1736 occurrences
Merge 14/256: (257, 261) -> 269 (b', 20') had 1705 occurrences
Merge 15/256: (97, 110) -> 270 (b'

Timing time!

In [8]:
for vocab_size in [300, 1000, 10_000]:
    merge_tree, vocab = timeit(lambda: train(text, vocab_size), 'Training')
    tokenized_text = timeit(lambda: tokenize(text, merge_tree), 'Tokenization')
    print(f'Tokenized text has {len(tokenized_text)} tokens.')
    detokenized_text = timeit(lambda: detokenize(tokenized_text, vocab), 'Detokenize')
    assert detokenized_text == text
    print()

Training tokenizer on text of length 185,561 with vocab of size 300.
build_leap took 0.25 seconds.
init_stats took 0.10 seconds.
Training took 0.74 seconds.
Tokenization took 0.54 seconds.
Tokenized text has 128451 tokens.
Detokenize took 0.01 seconds.

Training tokenizer on text of length 185,561 with vocab of size 1,000.
build_leap took 0.28 seconds.
init_stats took 0.11 seconds.
Training took 1.49 seconds.
Tokenization took 0.69 seconds.
Tokenized text has 58298 tokens.
Detokenize took 0.00 seconds.

Training tokenizer on text of length 185,561 with vocab of size 10,000.
build_leap took 0.21 seconds.
init_stats took 0.10 seconds.
Training took 1.83 seconds.
Tokenization took 0.85 seconds.
Tokenized text has 24338 tokens.
Detokenize took 0.00 seconds.



In [9]:
# Let's inspect our tokenized text:
def debug(tokenized_text, vocab):
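    # A single token may hold only part of a multi-byte UTF-8 character, so decode leniently.
    print('🍔'.join([vocab[t].decode('utf-8', errors='replace') for t in tokenized_text]))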

debug(tokenized_text[:100], vocab)

C🍔op🍔y p🍔ast🍔e of the 🍔Wikipe🍔dia 🍔article 🍔on 🍔Taylor Swift, 🍔as of 🍔F🍔eb🍔 🍔16🍔, 2024🍔.
🍔--🍔-🍔

🍔Main 🍔m🍔enu🍔

🍔Wikipedia🍔The F🍔ree 🍔Enc🍔yclopedia🍔

🍔Search🍔
🍔C🍔re🍔ate 🍔account🍔
🍔L🍔og🍔 🍔in🍔

🍔Personal 🍔tool🍔s
🍔Cont🍔ents 🍔 🍔h🍔ide🍔
🍔(🍔Top🍔)
🍔Life and career
🍔Toggle 🍔Life and 🍔career 🍔subsection
🍔Artistry
🍔Toggle 🍔Artist🍔ry 🍔subsection
🍔Accolades and achievements
🍔Cultural status
🍔Toggle 🍔Cultural stat🍔us 🍔subsection
🍔Wealth
🍔Toggle 🍔W🍔ealth 🍔subsection
🍔Discography
🍔Filmography
🍔Tours
🍔See also🍔
F🍔ootnotes
🍔References
🍔Toggle 🍔Referenc🍔es 🍔subsection
🍔External links
🍔Taylor Swift

🍔13🍔6 🍔l🍔ang🍔u🍔ag🍔es
🍔Ar🍔tic🍔le
🍔Tal🍔k


In [10]:
# What about a GPT-4-like vocabulary with 100K tokens?
merge_tree, vocab = timeit(lambda: train(text, 100_000), 'Training')
tokenized_text = timeit(lambda: tokenize(text, merge_tree), 'Tokenization')
print(f'Tokenized text has {len(tokenized_text)} tokens.')
detokenized_text = timeit(lambda: detokenize(tokenized_text, vocab), 'Detokenize')
assert detokenized_text == text

Training tokenizer on text of length 185,561 with vocab of size 100,000.
build_leap took 0.20 seconds.
init_stats took 0.11 seconds.
Training took 3.14 seconds.
Tokenization took 0.90 seconds.
Tokenized text has 1 tokens.
Detokenize took 0.00 seconds.


With a vocabulary of 100K, the merges keep going until the whole text collapses into a single token :) We need a longer training text to actually make use of that many merges :)