Notebook prepared by Henrique Lopes Cardoso (hlc@fe.up.pt), based on [minbpe](https://github.com/karpathy/minbpe) by Andrej Karpathy. He also has a very nice introduction to GPT tokenizers [here](https://www.youtube.com/watch?v=zduSFxRajkE).

# BYTE PAIR ENCODING

Byte Pair Encoding (BPE) is a well-known tokenization algorithm that repeatedly merges the most frequent pair of adjacent tokens. It starts with a vocabulary containing every individual symbol (in our case, every possible byte value, since we will work on the UTF-8 encoding of the text), and then adds new tokens to the vocabulary by merging adjacent ones, as found in the training data.
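
Because we operate on bytes rather than characters, the initial vocabulary has exactly 256 entries, and a single non-ASCII character may correspond to several initial tokens. A small illustration (the sample string is arbitrary):

```python
# Byte-level starting point of BPE: one token per possible byte value (0..255)
base_vocab = {idx: bytes([idx]) for idx in range(256)}

text = "café"  # arbitrary example; "é" is outside ASCII
ids = list(text.encode("utf-8"))  # the token ids before any merge is applied

print(len(base_vocab))      # 256
print(len(text), len(ids))  # 4 5  ("é" takes two bytes in UTF-8)
print(ids)                  # [99, 97, 102, 195, 169]

# Concatenating the byte tokens and decoding them recovers the original text
print(b"".join(base_vocab[i] for i in ids).decode("utf-8"))  # café
```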

In this notebook, you will implement the Byte Pair Encoding algorithm based on three main functions:

- `def train(text, vocab_size, verbose=False)`: trains the BPE tokenizer on some data and returns the *merges* that were made and the resulting *vocabulary*.
- `def encode(text)`: given some text, tokenizes it using the merges learned during training.
- `def decode(ids)`: given a sequence of token ids, reconstructs the original text.

> Note that GPT tokenizers, such as GPT-4's, include additional complexity beyond what is covered in this notebook, including:
> - Splitting the text into chunks using a regex pattern before tokenization, meant to ensure that no merges will happen across category boundaries (letters, numbers, punctuation).
> - Including special tokens that are not part of the text and are specific to the functioning of the language model.
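
The regex pre-splitting mentioned in the note can be sketched roughly as follows. This is a simplification under our own assumptions: `SPLIT_PATTERN` is a made-up, ASCII-only pattern and is *not* the pattern GPT-4 actually uses (which relies on Unicode character classes). Pairs are counted within each chunk only, accumulating into a single dictionary through the optional `counts` argument of `get_pair_counts` (defined later in this notebook), so that no pair crossing a chunk boundary is ever counted.

```python
import re

# Assumption: a simplified ASCII-only chunking pattern, NOT GPT-4's actual pattern.
# Chunks are: optional space + letters, optional space + digits,
# optional space + other symbols, or runs of whitespace.
SPLIT_PATTERN = r" ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+"


def get_pair_counts(ids, counts=None):
    counts = {} if counts is None else counts
    for i in range(len(ids) - 1):
        pair = (ids[i], ids[i + 1])
        counts[pair] = counts.get(pair, 0) + 1
    return counts


text = "Hello, world 123!"
chunks = re.findall(SPLIT_PATTERN, text)
print(chunks)  # ['Hello', ',', ' world', ' 123', '!']

# Count pairs inside each chunk, accumulating into one shared dictionary
counts = {}
for chunk in chunks:
    get_pair_counts(list(chunk.encode("utf-8")), counts)

# ('o', ',') = (111, 44) is adjacent in the raw text but lies across a chunk boundary
print((111, 44) in counts)  # False
print((32, 119) in counts)  # True: (' ', 'w') is inside the chunk ' world'
```

With this scheme, BPE can never learn a token such as `b'o,'` that mixes letters and punctuation.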

## Auxiliary functions

We'll start by defining two auxiliary functions that will be helpful later on.

The first of these functions is `merge`: given a list of integers `ids` (token identifiers) and a `pair`, it replaces every occurrence of that pair as two adjacent elements of the list with a new identifier `idx`.

In [3]:
def merge(ids, pair, idx):
    """
    In the list of integers ids, replace all consecutive occurrences of pair with the new integer token idx
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    newids = []
    i = 0

    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2  # Skip the next element as it's part of the pair
        else:
            newids.append(ids[i])
            i += 1  # Move to the next element

    return newids


Make sure to test your `merge` function:

In [4]:
assert merge(ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4) == [4,3,4]

# other tests here
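
A few additional edge cases one might test (the choice of cases is ours). The third one shows that overlapping occurrences are resolved left to right, without reusing an element that has already been merged:

```python
def merge(ids, pair, idx):
    """Replace each non-overlapping occurrence of pair (scanning left to right) with idx."""
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


assert merge([3, 1, 2], (1, 2), 4) == [3, 4]     # pair at the very end
assert merge([1, 3, 2], (1, 2), 4) == [1, 3, 2]  # both elements present, but not adjacent
assert merge([1, 1, 1], (1, 1), 4) == [4, 1]     # overlap: leftmost occurrence wins
assert merge([1, 1, 1, 1], (1, 1), 4) == [4, 4]
assert merge([], (1, 2), 4) == []                # empty input
assert merge([1], (1, 2), 4) == [1]              # single element
print("merge: all tests passed")
```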


The second function, `get_pair_counts`, counts the number of times each consecutive pair occurs in a sequence of integers.

In [5]:
def get_pair_counts(ids, counts=None):
    """
    Given a list of integers ids, return a dictionary of counts of consecutive pairs.
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    Optionally allows updating an existing dictionary of counts.
    """
    counts = {} if counts is None else counts

    for i in range(len(ids) - 1):
        pair = (ids[i], ids[i + 1])
        counts[pair] = counts.get(pair, 0) + 1

    return counts

Make sure to test your `get_pair_counts` function:

In [6]:
assert get_pair_counts([1, 2, 3, 1, 2]) == {(1, 2): 2, (2, 3): 1, (3, 1): 1}

# other tests here
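
Some further tests, including the optional `counts` argument (the specific cases are our own choice). Note that `get_pair_counts([1, 1, 1])` reports the pair `(1, 1)` twice, while `merge` would replace it only once, so the counts of overlapping pairs can overestimate how many replacements a merge will actually make:

```python
def get_pair_counts(ids, counts=None):
    counts = {} if counts is None else counts
    for i in range(len(ids) - 1):
        pair = (ids[i], ids[i + 1])
        counts[pair] = counts.get(pair, 0) + 1
    return counts


assert get_pair_counts([]) == {}                 # no pairs in an empty list
assert get_pair_counts([7]) == {}                # nor in a single-element list
assert get_pair_counts([1, 1, 1]) == {(1, 1): 2}  # overlapping pairs are both counted

# Accumulating counts over several sequences through the optional argument
counts = get_pair_counts([1, 2])
counts = get_pair_counts([1, 2, 3], counts)
assert counts == {(1, 2): 2, (2, 3): 1}
print("get_pair_counts: all tests passed")
```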


## Train

The `train` function will build the tokenizer. That is, given some training corpus and the desired vocabulary size, it will determine what merges to make based on the frequency of adjacent tokens. It will also provide us with the resulting vocabulary.

In [8]:
def train(text, vocab_size, verbose=False):
    assert vocab_size >= 256  # The vocabulary must at least contain the 256 single-byte tokens

    vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
    num_merges = vocab_size - 256

    # Input text preprocessing: encode text using UTF-8
    text_bytes = text.encode("utf-8")  # Raw bytes
    ids = list(text_bytes)  # List of integers in range 0..255

    # Iteratively merge the most common pairs to create new tokens
    merges = {}  # (int, int) -> int
    for i in range(num_merges):
        # a) Count the number of times every consecutive pair appears
        counts = get_pair_counts(ids)  # Use previously defined function

        # b) Find the pair with the highest count
        pair = max(counts, key=counts.get)

        # Use the next available id for the new token
        idx = 256 + i

        # Replace all occurrences of pair in ids with idx
        ids = merge(ids, pair, idx)

        # Save the merge
        merges[pair] = idx
        vocab[idx] = vocab[pair[0]] + vocab[pair[1]]  # Update vocabulary

        # Print progress if verbose mode is enabled
        if verbose:
            print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {counts[pair]} occurrences")

    return merges, vocab  # Return the merges and vocabulary

Let's train the tokenizer on some simple text:

In [9]:
text = "aaabdaaabac"
merges, vocab = train(text, 256 + 3, verbose=True)

merge 1/3: (97, 97) -> 256 (b'aa') had 4 occurrences
merge 2/3: (256, 97) -> 257 (b'aaa') had 2 occurrences
merge 3/3: (257, 98) -> 258 (b'aaab') had 2 occurrences


And let's see what merges were made and what has been added to the vocabulary:

In [10]:
merges

{(97, 97): 256, (256, 97): 257, (257, 98): 258}

In [11]:
for key, value in vocab.items():
    if key > 255:
        print(key, value)

256 b'aa'
257 b'aaa'
258 b'aaab'
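
The vocabulary is fully determined by the merges: replaying them in the order they were learned rebuilds every new entry. A sketch with a hypothetical helper `build_vocab` (not part of the notebook), using the merges printed above:

```python
def build_vocab(merges):
    """Hypothetical helper: rebuild the vocabulary from the merges alone."""
    vocab = {idx: bytes([idx]) for idx in range(256)}
    # dicts keep insertion order (Python 3.7+), so merges are replayed in the
    # order they were learned; each new token only refers to earlier tokens
    for (p0, p1), idx in merges.items():
        vocab[idx] = vocab[p0] + vocab[p1]
    return vocab


toy_merges = {(97, 97): 256, (256, 97): 257, (257, 98): 258}  # merges learned above
rebuilt = build_vocab(toy_merges)
for idx in range(256, 259):
    print(idx, rebuilt[idx])  # 256 b'aa', 257 b'aaa', 258 b'aaab'
```

This is why only the merges really need to be saved to reuse a trained tokenizer.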


Now train the tokenizer on some longer text, with a vocabulary size of 512.

In [13]:
def train(text, vocab_size, verbose=False):
    assert vocab_size >= 256  # The vocabulary must at least contain the 256 single-byte tokens

    vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
    num_merges = vocab_size - 256

    # Input text preprocessing: encode text using UTF-8
    text_bytes = text.encode("utf-8")  # Raw bytes
    ids = list(text_bytes)  # List of integers in range 0..255

    # Iteratively merge the most common pairs to create new tokens
    merges = {}  # (int, int) -> int
    for i in range(num_merges):
        # Count the number of times every consecutive pair appears
        counts = get_pair_counts(ids)  # Use previously defined function

        # Find the pair with the highest count
        if not counts:
            break  # Stop if no more pairs can be merged
        pair = max(counts, key=counts.get)

        # Use the next available id for the new token
        idx = 256 + i

        # Replace all occurrences of pair in ids with idx
        ids = merge(ids, pair, idx)

        # Save the merge
        merges[pair] = idx
        vocab[idx] = vocab[pair[0]] + vocab[pair[1]]  # Update vocabulary

        # Print progress if verbose mode is enabled
        if verbose:
            print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {counts[pair]} occurrences")

    return merges, vocab  # Return the merges and vocabulary


# A longer text (here, a short passage from a book)
text = """In the beginning, the universe was created. This has made a lot of people very angry and been widely regarded as a bad move.
It is said that if someone discovers exactly what the Universe is for and why it is here, it will instantly disappear and be replaced by something even more bizarre and inexplicable.
There is another theory which states that this has already happened."""

# Train the tokenizer with a vocabulary size of 512
vocab_size = 512
merges, vocab = train(text, vocab_size, verbose=True)


merge 1/256: (101, 32) -> 256 (b'e ') had 12 occurrences
merge 2/256: (115, 32) -> 257 (b's ') had 12 occurrences
merge 3/256: (116, 104) -> 258 (b'th') had 9 occurrences
merge 4/256: (121, 32) -> 259 (b'y ') had 9 occurrences
merge 5/256: (100, 32) -> 260 (b'd ') had 8 occurrences
merge 6/256: (101, 114) -> 261 (b'er') had 7 occurrences
merge 7/256: (116, 32) -> 262 (b't ') had 7 occurrences
merge 8/256: (97, 110) -> 263 (b'an') had 7 occurrences
merge 9/256: (105, 257) -> 264 (b'is ') had 6 occurrences
merge 10/256: (105, 110) -> 265 (b'in') had 5 occurrences
merge 11/256: (118, 261) -> 266 (b'ver') had 4 occurrences
merge 12/256: (97, 257) -> 267 (b'as ') had 4 occurrences
merge 13/256: (114, 101) -> 268 (b're') had 4 occurrences
merge 14/256: (263, 260) -> 269 (b'and ') had 4 occurrences
merge 15/256: (110, 32) -> 270 (b'n ') had 3 occurrences
merge 16/256: (258, 256) -> 271 (b'the ') had 3 occurrences
merge 17/256: (264, 104) -> 272 (b'is h') had 3 occurrences
...

## Encode

The `encode` function converts a text into a sequence of tokens by applying the merges learned during training.

In [14]:
def encode(text):
    # given a string text, return the token ids
    text_bytes = text.encode("utf-8") # raw bytes
    ids = list(text_bytes) # list of integers in range 0..255
    while len(ids) >= 2:
        # find the pair with the lowest merge index
        counts = get_pair_counts(ids)   # note that here we don't really care about the counts, we just need the observed pairs
        pair = min(counts.keys(), key=lambda p: merges.get(p, float("inf")))
        # subtle: if there are no more merges available, the key will
        # result in an inf for every single pair, and the min will be
        # just the first pair in the list, arbitrarily
        # we can detect this terminating case by a membership check
        if pair not in merges:
            break # nothing else can be merged anymore
        # otherwise let's merge the best pair (lowest merge index)
        idx = merges[pair]
        ids = merge(ids, pair, idx)
    
    return ids
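
The `encode` above reads the global `merges`. A variant that receives the merges as an argument (the name `encode_with` is our own, hypothetical) is easier to test in isolation. Encoding the toy training string with its three merges reproduces the sequence the training loop ended with:

```python
def get_pair_counts(ids, counts=None):
    counts = {} if counts is None else counts
    for i in range(len(ids) - 1):
        pair = (ids[i], ids[i + 1])
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


def encode_with(text, merges):
    """Same logic as encode, but with the merges passed in explicitly."""
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        counts = get_pair_counts(ids)
        # apply the earliest-learned merge first: later merges may build on earlier ones
        pair = min(counts, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # no remaining pair has a merge
        ids = merge(ids, pair, merges[pair])
    return ids


toy_merges = {(97, 97): 256, (256, 97): 257, (257, 98): 258}  # learned from "aaabdaaabac"
print(encode_with("aaabdaaabac", toy_merges))  # [258, 100, 258, 97, 99]
print(encode_with("aaac", toy_merges))         # [257, 99]
```

Choosing the pair with the lowest merge index, rather than the most frequent one, matters: token 257 is defined in terms of 256, so the merges must be applied in the order they were learned.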

Try encoding some piece of text and compare the number of tokens obtained with the length of the text (more precisely, the number of bytes in its UTF-8 encoding). The ratio between the two can be seen as a compression ratio (the original version of the [BPE algorithm](https://en.wikipedia.org/wiki/Byte_pair_encoding) was designed for data compression).
You can also compare the number of tokens with the number of words (e.g., as provided by NLTK's word tokenizer).

In [None]:
# your code here
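
One possible way to compute these quantities, sketched with a hypothetical helper `compression_stats`. As a swap-in, words are counted by plain whitespace splitting rather than NLTK, so punctuation is not counted separately; `nltk.word_tokenize` would give a finer count but may require downloading its tokenizer models first.

```python
def compression_stats(text, ids):
    """Hypothetical helper: compare UTF-8 byte length, token count and rough word count."""
    n_bytes = len(text.encode("utf-8"))
    n_tokens = len(ids)
    # whitespace splitting: a rough stand-in for nltk.word_tokenize
    n_words = len(text.split())
    return {
        "bytes": n_bytes,
        "tokens": n_tokens,
        "words": n_words,
        "bytes_per_token": n_bytes / n_tokens if n_tokens else 0.0,
    }


# Token ids for "aaabdaaabac" under the 3 toy merges learned earlier
stats = compression_stats("aaabdaaabac", [258, 100, 258, 97, 99])
print(stats)  # {'bytes': 11, 'tokens': 5, 'words': 1, 'bytes_per_token': 2.2}
```

In the notebook itself, one would call `compression_stats(sample, encode(sample))` for some string `sample`.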


For each of the first 25 tokens, print individually the corresponding textual entry in the vocabulary.

In [None]:
# your code here
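
A possible sketch, with a hypothetical helper `token_strings` and a toy vocabulary standing in for the trained one. Each token is decoded on its own, and `repr` makes leading or trailing spaces visible:

```python
def token_strings(ids, vocab, n=25):
    """Hypothetical helper: (token id, text of that token) for the first n tokens."""
    # a token holding only part of a multi-byte UTF-8 character cannot be decoded
    # on its own; errors="replace" shows it as the replacement character U+FFFD
    return [(idx, vocab[idx].decode("utf-8", errors="replace")) for idx in ids[:n]]


# Toy vocabulary: the 256 single bytes plus the 3 merges learned from "aaabdaaabac"
toy_vocab = {idx: bytes([idx]) for idx in range(256)}
toy_vocab.update({256: b"aa", 257: b"aaa", 258: b"aaab"})

for idx, piece in token_strings([258, 100, 258, 97, 99], toy_vocab):
    print(idx, repr(piece))
```

With the trained tokenizer, the call would be `token_strings(encode(sample), vocab)`.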


## Decode

The `decode` function converts a sequence of tokens back into text, using the tokenizer's vocabulary.

In [None]:
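def decode(ids):
    # given ids (list of integers), return Python string
    text_bytes = b"".join(vocab[idx] for idx in ids)
    # not every token sequence is valid UTF-8 (e.g., a token holding only part of a
    # multi-byte character), so invalid bytes become U+FFFD instead of raising an error
    text = text_bytes.decode("utf-8", errors="replace")
    return text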
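
The reason for `errors="replace"` can be seen with a character that spans two bytes: decoding only its first byte is not valid UTF-8. A sketch using a hypothetical `decode_with` that takes the vocabulary as an argument:

```python
def decode_with(ids, vocab):
    """Same logic as decode, but with the vocabulary passed in explicitly."""
    return b"".join(vocab[idx] for idx in ids).decode("utf-8", errors="replace")


base_vocab = {idx: bytes([idx]) for idx in range(256)}
accent_ids = list("é".encode("utf-8"))  # [195, 169]: one character, two byte tokens

print(decode_with(accent_ids, base_vocab))      # é
print(decode_with(accent_ids[:1], base_vocab))  # the replacement character U+FFFD

try:
    bytes([195]).decode("utf-8")  # strict decoding raises instead of replacing
except UnicodeDecodeError as err:
    print("strict decoding fails:", err.reason)
```

Such partial sequences can appear when decoding an arbitrary slice of token ids, for instance when a language model generates tokens one at a time.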

Decode the tokens you obtained in the previous step and check that you get back the original text.

In [None]:
# your code here
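
In the notebook, the check amounts to `assert decode(encode(text)) == text`. Below is a self-contained sketch of the same round trip; `train_bpe`, `encode_with` and `decode_with` are hypothetical compact versions of the notebook's functions, and the samples are arbitrary. The round trip also holds for characters never seen during training, because all 256 byte values are always in the vocabulary:

```python
def get_pair_counts(ids, counts=None):
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


def train_bpe(text, vocab_size):
    """Hypothetical compact trainer: same logic as train, without verbose output."""
    vocab = {idx: bytes([idx]) for idx in range(256)}
    ids, merges = list(text.encode("utf-8")), {}
    for i in range(vocab_size - 256):
        counts = get_pair_counts(ids)
        if not counts:  # fewer than two tokens left: nothing to merge
            break
        pair = max(counts, key=counts.get)
        idx = 256 + i
        ids = merge(ids, pair, idx)
        merges[pair] = idx
        vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
    return merges, vocab


def encode_with(text, merges):
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        counts = get_pair_counts(ids)
        pair = min(counts, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        ids = merge(ids, pair, merges[pair])
    return ids


def decode_with(ids, vocab):
    return b"".join(vocab[idx] for idx in ids).decode("utf-8", errors="replace")


toy_merges, toy_vocab = train_bpe("aaabdaaabac", 256 + 3)
samples = ["aaabdaaabac", "abc", "naïve café", ""]
for sample in samples:
    sample_ids = encode_with(sample, toy_merges)
    assert decode_with(sample_ids, toy_vocab) == sample, sample
print("round trip OK for", len(samples), "samples")
```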
