## GPT Tokeniser Development

A token is a sequence of characters that serves as a unit of text, and tokenisation is the process of converting text into a sequence of tokens. Tokenisation is critical to the correct functioning of transformers: bad tokenisation can degrade a model's performance irrespective of its architecture. If tokenisation is not done well, transformer models can struggle to spell words, struggle with non-English text, struggle with simple arithmetic, and even produce unintended outputs (see [SolidGoldMagikarp](https://www.lesswrong.com/posts/aPeJE8bSo6rAFoLqg/solidgoldmagikarp-plus-prompt-generation)).

Recall the GPT model in `gpt.py`:
```Python
class GPTLanguageModel(nn.Module):
    """GPT Decoder model. Consists of an embedding layer, transformer blocks, and a linear head."""

    def __init__(self):
        super().__init__()
        self.token_embed_table = nn.Embedding(vocab_size, n_embed) # (B,T) -> (B,T,C)
        # etc.
```
Tokens are the fundamental 'atoms' at the input of transformers. Each token (here, a character) is used as an index to look up the corresponding row in the embedding table; this row is a trainable vector (of size `n_embed`) that represents the token. Using characters as tokens is a naive approach because transformers have a limited context window (1024 tokens for GPT-2) in which tokens can attend to each other. Instead, text is tokenised into character chunks drawn from a chunk vocabulary, which is constructed using the Byte Pair Encoding (BPE) algorithm (popularised in the [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)). Using character chunks as tokens lets the same context window cover a wider portion of the text, which can improve performance.
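
As a rough illustration of the lookup described above, the sketch below uses a plain list of lists in place of `nn.Embedding`. The names `toy_vocab_size`, `toy_n_embed`, and `embed_table`, and the random values, are made up for illustration; in a real model the table is a trainable tensor.

```python
import random

random.seed(0)
toy_vocab_size, toy_n_embed = 5, 3  # hypothetical sizes, for illustration only

# One row per token id; in a real model these rows are trainable parameters.
embed_table = [
    [random.random() for _ in range(toy_n_embed)]
    for _ in range(toy_vocab_size)
]

token_ids = [2, 0, 2, 4]  # a sequence of T=4 token ids
embedded = [embed_table[t] for t in token_ids]  # (T,) -> (T, C)

print(len(embedded), len(embedded[0]))  # T rows, each of size toy_n_embed
print(embedded[0] == embedded[2])       # the same token id gives the same row
```

Because the lookup is just indexing, the model never sees the text itself, only the integer ids.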

Note that the tokeniser is completely separate from the transformer model. It has its own training dataset, on which the BPE algorithm learns the vocabulary. The tokeniser then encodes/decodes between text and sequences of tokens. The transformer model only sees the tokens and never deals with the text directly.
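
To make this separation concrete, here is a minimal sketch of the interface a tokeniser exposes. `ByteTokeniser` is a hypothetical class with no learned merges (one token per UTF-8 byte); it only illustrates that encoding and decoding need no knowledge of the model.

```python
class ByteTokeniser:
    """Hypothetical minimal tokeniser: one token per UTF-8 byte, no merges."""

    def encode(self, text: str) -> list[int]:
        # Text -> token ids (here simply the UTF-8 byte values)
        return list(text.encode('utf-8'))

    def decode(self, tokens: list[int]) -> str:
        # Token ids -> text; invalid byte sequences become U+FFFD
        return bytes(tokens).decode('utf-8', errors='replace')


tokeniser = ByteTokeniser()
ids = tokeniser.encode('hi there')
print(ids)                    # the only thing a model would ever see
print(tokeniser.decode(ids))  # back to text
```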

[Tiktokenizer](https://tiktokenizer.vercel.app/) provides a visualisation of the differences between the various tokenisers available for GPT models. Use 'gpt2' and 'cl100k_base' as the model names to compare the tokenisation of GPT-2 and GPT-4.

In [None]:
text = 'Hello 你好'

print([ord(x) for x in text]) # Unicode code points
print(list(text.encode('utf-8'))) # UTF-8 bytes

# For non-ASCII characters the UTF-8 bytes differ from the Unicode code points
# because UTF-8 uses a variable number of bytes per character (1 to 4).
# For example, '你' (code point 20320) is encoded as the bytes 228 189 160.
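
The variable-length behaviour can be checked per character. The sample characters below are chosen for illustration only.

```python
sample_chars = ['H', 'é', '你', '😀']  # illustrative characters of increasing width

byte_lengths = []
for ch in sample_chars:
    encoded = ch.encode('utf-8')
    byte_lengths.append(len(encoded))
    print(f'{ch!r}: code point {ord(ch)}, {len(encoded)} byte(s) {list(encoded)}')
```

ASCII characters take one byte each, which is why the code points and the UTF-8 bytes agree for 'Hello'.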

### Helper Functions

In [None]:
def consecutive_pairs(ints: list[int]) -> dict[tuple[int, int], int]:
    """
    Generate a dictionary of the frequencies of consecutive integers in the list.
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    """
    freq = {}
    for pair in zip(ints, ints[1:]):
        freq[pair] = freq.get(pair, 0) + 1
    return freq

text = 'abcab'
tokens = list(text.encode('utf-8'))

freq_pairs = consecutive_pairs(tokens)
print(freq_pairs)
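
An equivalent way to count pairs, shown here only as an alternative sketch, is `collections.Counter`. The name `consecutive_pairs_counter` is hypothetical and not used elsewhere in this notebook.

```python
from collections import Counter


def consecutive_pairs_counter(ints: list[int]) -> Counter:
    """Same counts as consecutive_pairs, built with collections.Counter."""
    return Counter(zip(ints, ints[1:]))


pairs = consecutive_pairs_counter([1, 2, 3, 1, 2])
print(pairs)
print(pairs.most_common(1))  # [((1, 2), 2)]
```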

In [None]:
def replace_pair(ints: list[int], pair: tuple[int, int], new_int: int) -> list[int]:
    """
    Replace all consecutive occurrences of a pair of integers in the list with a new integer.
    Example: ints=[1, 2, 3, 1, 2], pair=(1, 2), new_int=4 -> [4, 3, 4]
    """
    new_ints = []
    i = 0
    while i < len(ints):
        # If not at the last position AND the pair matches, replace it
        if (i < len(ints) - 1) and ints[i:i+2] == list(pair):
            new_ints.append(new_int)
            i += 2
        else:
            new_ints.append(ints[i])
            i += 1
    return new_ints

# Replace the most frequent pair with a new token (256)
max_pair = max(freq_pairs, key=freq_pairs.get)
new_tokens = replace_pair(tokens, pair=max_pair, new_int=256)
print(f'{tokens} -> {new_tokens}')
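
One edge case worth checking is a run of repeated tokens. Pair counting sees overlapping pairs, but replacement scans left to right and consumes two tokens at a time. The sketch below uses a hypothetical copy of the function, `replace_pair_demo`, so that it runs on its own.

```python
def replace_pair_demo(ints: list[int], pair: tuple[int, int], new_int: int) -> list[int]:
    """Left-to-right, non-overlapping replacement (same behaviour as replace_pair)."""
    out, i = [], 0
    while i < len(ints):
        if i < len(ints) - 1 and (ints[i], ints[i + 1]) == pair:
            out.append(new_int)
            i += 2  # skip both members of the matched pair
        else:
            out.append(ints[i])
            i += 1
    return out


# [1, 1, 1] contains the pair (1, 1) twice when counted with overlap,
# but only the first occurrence can be replaced.
print(replace_pair_demo([1, 1, 1], (1, 1), 9))     # [9, 1]
print(replace_pair_demo([1, 1, 1, 1], (1, 1), 9))  # [9, 9]
```

This means the frequency of a pair can overestimate how many replacements actually happen.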

### Training via Byte Pair Encoding (BPE)

In [None]:
# Load new text from a file.
with open('test.txt', 'r', encoding='utf-8') as file:
    text = file.read()

vocab_size = 265 # Desired vocabulary size

assert vocab_size >= 256
n_merges = vocab_size - 256
tokens = list(text.encode('utf-8'))
merges = {} # Dictionary to store the merges
vocab = {i: bytes([i]) for i in range(256)}

# Merge the most frequent pair n_merges times to create new tokens
for i in range(n_merges):
    # Find the most frequent consecutive pair of tokens
    freq_pairs = consecutive_pairs(tokens)
    max_pair = max(freq_pairs, key=freq_pairs.get)
    # Create a new token and assign it to an unused integer
    new_token = 256 + i
    tokens = replace_pair(tokens, max_pair, new_token)
    # Store the merge and the new token in the vocab
    merges[max_pair] = new_token
    vocab[new_token] = vocab[max_pair[0]] + vocab[max_pair[1]]
    print(f'{i+1}/{n_merges}: {max_pair} -> {new_token}')

print(vocab)
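
The training loop can also be wrapped in a function and traced by hand on a tiny string. The sketch below is self-contained; `train_bpe` is a hypothetical name, and the early `break` when no pairs remain is an added assumption, not part of the loop above. Ties between equally frequent pairs are broken by first occurrence, which follows from `max` and dictionary insertion order.

```python
def train_bpe(text: str, vocab_size: int):
    """Hypothetical wrapper around the BPE loop; returns (merges, vocab, tokens)."""
    assert vocab_size >= 256
    tokens = list(text.encode('utf-8'))
    merges = {}
    vocab = {i: bytes([i]) for i in range(256)}
    for i in range(vocab_size - 256):
        # Count consecutive pairs
        freq = {}
        for pair in zip(tokens, tokens[1:]):
            freq[pair] = freq.get(pair, 0) + 1
        if not freq:
            break  # fewer than two tokens left, nothing to merge (assumption)
        best = max(freq, key=freq.get)
        new_token = 256 + i
        # Replace every non-overlapping occurrence of the best pair
        merged, j = [], 0
        while j < len(tokens):
            if j < len(tokens) - 1 and (tokens[j], tokens[j + 1]) == best:
                merged.append(new_token)
                j += 2
            else:
                merged.append(tokens[j])
                j += 1
        tokens = merged
        merges[best] = new_token
        vocab[new_token] = vocab[best[0]] + vocab[best[1]]
    return merges, vocab, tokens


merges_demo, vocab_demo, tokens_demo = train_bpe('aaabdaaabac', vocab_size=259)
print(merges_demo)   # {(97, 97): 256, (256, 97): 257, (257, 98): 258}
print(tokens_demo)   # [258, 100, 258, 97, 99]
print(vocab_demo[258])  # b'aaab'
```

Three merges shrink the 11-byte string to 5 tokens, and each new vocabulary entry is the concatenation of the bytes of the two tokens it replaces.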


In [None]:
# The text is now represented by fewer tokens
print('New token length:', len(tokens))
print(f'Compression ratio: {len(list(text.encode("utf-8"))) / len(tokens):.2f}')

### Decoding and Encoding

In [None]:
def decode(tokens: list[int]) -> str:
    """Decode a sequence of tokens into a string."""
    bytes_ = b''.join(vocab[token] for token in tokens)
    text = bytes_.decode('utf-8', errors='replace') # Replace invalid bytes with U+FFFD
    return text

# Not every byte sequence is valid UTF-8. With errors='replace', invalid bytes are
# decoded as the replacement character U+FFFD ('�') instead of raising an error.
# This matters because the language model may generate token sequences whose bytes
# are not valid UTF-8. For example, byte 128 on its own is a continuation byte and
# is not valid UTF-8.
print(decode([128]))
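
The difference between the default strict decoding and `errors='replace'` can be seen directly on the byte 128. This is a small standalone check using only built-in `bytes.decode`.

```python
lone_continuation = bytes([128])  # 0x80 is only valid after a UTF-8 lead byte

# errors='replace' substitutes the replacement character U+FFFD
replaced = lone_continuation.decode('utf-8', errors='replace')
print(replaced == '\ufffd')

# The default (errors='strict') raises instead
strict_failed = False
try:
    lone_continuation.decode('utf-8')
except UnicodeDecodeError as exc:
    strict_failed = True
    print('strict decoding failed:', exc.reason)
```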

In [None]:
def encode(text: str) -> list[int]:
    """Encode a string into a sequence of tokens."""
    tokens = list(text.encode('utf-8'))
    while len(tokens) > 1:
        freq_pairs = consecutive_pairs(tokens)
        # Find the pair that was merged earliest in training (lowest new token id).
        # Pairs that were never merged rank as infinity.
        earliest = min(freq_pairs, key=lambda pair: merges.get(pair, float('inf')))
        if earliest not in merges:
            break # No more merges to apply
        # Merge the pair into its token
        new_token = merges[earliest]
        tokens = replace_pair(tokens, earliest, new_token)
    return tokens

print(encode('the quick brown fox'))

In [None]:
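# Check that decode(encode(text)) recovers the original text. The reverse direction,
# encode(decode(tokens)), is not guaranteed because not every token sequence is valid UTF-8.
text == decode(encode(text))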
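
Encoding applies merges in the order they were learned, not by how frequent a pair is in the input. The sketch below illustrates this with a hand-written merge table. `toy_merges` and `encode_with` are hypothetical names, and the table is written out by hand as if it had been learned from a tiny corpus.

```python
# Hypothetical merge table: 'aa' -> 256, then (256, 'a') -> 257, then (257, 'b') -> 258
toy_merges = {(97, 97): 256, (256, 97): 257, (257, 98): 258}


def encode_with(merges: dict, text: str) -> list[int]:
    """Encode text by repeatedly applying the earliest-learned applicable merge."""
    tokens = list(text.encode('utf-8'))
    while len(tokens) > 1:
        pairs = set(zip(tokens, tokens[1:]))
        # Lowest merged token id means the pair was learned first
        pair = min(pairs, key=lambda p: merges.get(p, float('inf')))
        if pair not in merges:
            break  # none of the remaining pairs was ever merged
        new_token, out, i = merges[pair], [], 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
                out.append(new_token)
                i += 2
            else:
                out.append(tokens[i])
                i += 1
        tokens = out
    return tokens


print(encode_with(toy_merges, 'aaab'))  # [258]
print(encode_with(toy_merges, 'aab'))   # [256, 98]
```

For 'aab', the pair (256, 98) was never learned, so encoding stops after a single merge.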

### Tokenisation Notes

The tokeniser is fully specified by the learned `merges` and `vocab` variables: `merges` drives encoding and `vocab` drives decoding.

**Splitting Text via RegEx Patterns (GPT-2)**

Instead of encoding each string directly, the string is first split into a list of substrings using a regular expression. Each substring in this list is processed independently by the tokeniser, so merges can only happen within a substring. The resulting token sequences are then concatenated to form the final token sequence. This ensures that certain consecutive pairs of characters are never merged (e.g. 'e ', a letter followed by a space).

The regular expression below is from the [GPT-2 tokeniser](https://github.com/openai/gpt-2/blob/master/src/encoder.py). The patterns `'s|'t|'re|'ve|'m|'ll|'d` match common contractions; however, they only match the ASCII apostrophe (') and not the Unicode apostrophe (’). Furthermore, the patterns are case-sensitive and so will not match `'S|'T|'RE|'VE|'M|'LL|'D`. That is, `I'm` is split into `I`, `'m`, whereas `I'M` is split into `I`, `'`, `M`. These are limitations of the GPT-2 tokeniser.

GPT-2 also uses one special token which denotes the end of text, `<|endoftext|>`.
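
A special token is typically mapped to a single reserved id rather than being built up from byte merges. The sketch below shows one way this could be done, assuming a plain byte-level encoding with no merges. `encode_with_special` is a hypothetical helper, and 50256 is the id GPT-2's vocabulary assigns to `<|endoftext|>`; any unused id would work for this sketch.

```python
END_OF_TEXT = '<|endoftext|>'
END_OF_TEXT_ID = 50256  # id of <|endoftext|> in GPT-2's vocabulary


def encode_with_special(text: str) -> list[int]:
    """Byte-level encode (no merges) that maps the special token to one id."""
    tokens = []
    parts = text.split(END_OF_TEXT)
    for i, part in enumerate(parts):
        tokens.extend(part.encode('utf-8'))  # ordinary text: raw UTF-8 bytes
        if i < len(parts) - 1:
            tokens.append(END_OF_TEXT_ID)    # boundary between two parts
    return tokens


print(encode_with_special('a<|endoftext|>b'))  # [97, 50256, 98]
```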

In [None]:
import regex as re

pattern = re.compile(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+")
print(pattern.findall("I'm. I'M")) # ['I', "'m", '.', ' I', "'", 'M']
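
The effect of splitting on merges can be checked without the third-party `regex` module. The sketch below uses the standard-library `re` with an ASCII-only approximation of the GPT-2 pattern; `simple_pattern` is an assumption for illustration and is not the real GPT-2 pattern, which relies on `\p{L}` and `\p{N}`.

```python
import re  # standard library; \p{L} and \p{N} need the third-party `regex` module

# ASCII-only approximation of the GPT-2 split pattern (assumption for this sketch)
simple_pattern = re.compile(
    r"'s|'t|'re|'ve|'m|'ll|'d| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+"
)

chunks = simple_pattern.findall('the dog')
print(chunks)  # ['the', ' dog']

# Collect every byte pair that occurs inside a single chunk
pairs_within_chunks = set()
for chunk in chunks:
    chunk_bytes = list(chunk.encode('utf-8'))
    pairs_within_chunks.update(zip(chunk_bytes, chunk_bytes[1:]))

# The pair ('e', ' ') straddles a chunk boundary, so it can never be merged
print((ord('e'), ord(' ')) in pairs_within_chunks)  # False
```

Because the space is attached to the start of the following word, pairs such as (' ', 'd') remain mergeable while ('e', ' ') does not.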

**Vocabulary Size**

A larger vocabulary increases the number of distinct tokens available, so each token can represent a longer chunk of text and the same text is expressed as a shorter token sequence. A fixed context window therefore covers more text, which improves the model's ability to learn long-range dependencies. However, a larger vocabulary makes the embedding table (and the linear head that predicts the next token) larger and hence more computationally expensive to train. Furthermore, with a large vocabulary each individual token is seen less often in the training data, so its vector representation may be under-trained, leading to worse overall performance.
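
The cost side of this trade-off is simple arithmetic: the embedding table holds `vocab_size * n_embed` parameters. The sketch below uses a hypothetical helper, `embedding_params`, with the GPT-2 small configuration (vocabulary of 50257 tokens, embedding size 768) as one data point; 256 corresponds to a byte-level vocabulary with no merges.

```python
def embedding_params(vocab_size: int, n_embed: int) -> int:
    """Number of trainable parameters in a (vocab_size, n_embed) embedding table."""
    return vocab_size * n_embed


for size in (256, 50257):
    print(f'vocab_size={size}: {embedding_params(size, 768):,} parameters')
```

A linear head of shape `(n_embed, vocab_size)` grows by the same factor, unless its weights are shared with the embedding table.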