<a target="_blank" href="https://colab.research.google.com/github/holmrenser/deep_learning/blob/main/tokenization.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Tokenization
Machine learning approaches for natural language processing face the problem of representing long stretches of natural language (e.g. words, sentences, paragraphs, chapters, etc.) in a meaningful and computationally efficient way. A trivial approach is to split text on punctuation and whitespace, effectively selecting individual words. The downside of this approach is that semantically similar words are encoded differently. For example, 'small', 'smaller', and 'smallest' would all be encoded as different entities, forcing a model to learn any semantic similarity from data alone. An alternative approach would encode 'small' as one entity, and 'er' and 'est' as separate entities. The benefit of this 'subword' approach is that it is more straightforward to model semantic similarity; the downside is that it is not obvious how to choose an optimal set of subwords.

In this notebook we will explore the [byte pair encoding (BPE)](https://en.wikipedia.org/wiki/Byte_pair_encoding) algorithm for creating subword representations, a.k.a. tokens. Put simply, byte pair encoding starts from the simplest tokens (e.g. letters) and iteratively merges the most frequent pair of adjacent tokens into a new token, continuing until the desired number of tokens is reached.
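
As a minimal sketch of this loop (an illustration only, not the implementation used later in this notebook), the function below runs BPE directly on the characters of a toy string and represents each merged token as the concatenation of its two parts. The function name `toy_bpe` and the choice to stop once no pair occurs more than once are our own assumptions.

```python
from collections import Counter


def toy_bpe(text: str, num_merges: int) -> tuple[list[str], list[tuple[str, str]]]:
    """Toy BPE on characters: repeatedly merge the most frequent adjacent pair."""
    symbols = list(text)  # start from individual characters
    merges: list[tuple[str, str]] = []
    for _ in range(num_merges):
        pair_counts = Counter(zip(symbols[:-1], symbols[1:]))
        if not pair_counts:
            break  # fewer than two symbols left
        (left, right), count = pair_counts.most_common(1)[0]
        if count < 2:
            break  # assumption: stop once no pair repeats
        merges.append((left, right))
        # Replace non-overlapping occurrences of the pair, scanning left to right
        merged: list[str] = []
        i = 0
        while i < len(symbols):
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == (left, right):
                merged.append(left + right)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols, merges


symbols, merges = toy_bpe('aaabdaaabac', num_merges=3)
print(merges)   # [('a', 'a'), ('aa', 'a'), ('aaa', 'b')]
print(symbols)  # ['aaab', 'd', 'aaab', 'a', 'c']
```

In the second round both ('aa', 'a') and ('a', 'b') occur twice; `Counter.most_common` lists equal counts in the order they were first encountered, so ('aa', 'a') wins. Other implementations may break ties differently and end up with different merges.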

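Because models operate on numbers rather than text, letters (and, later, tokens) are represented as integers; such an integer can, for example, serve as an index into a table of learned vectors.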

In [95]:
# All dependencies for the whole notebook
from collections import Counter
from itertools import chain
from tqdm.auto import trange
import json
from dataclasses import dataclass, field
from typing import Generator
from tokenizers import (
    decoders,
    models,
    trainers,
    Tokenizer,
)

In [96]:
# Download the tiny shakespeare dataset
!wget -nc https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

File ‘input.txt’ already there; not retrieving.





In [97]:
# Load the tiny shakespeare data and show the first 30 characters
with open('input.txt', 'r') as fh:
    data = fh.read()

data[:30]

'First Citizen:\nBefore we proce'

## Naive 'tokens'
Probably the simplest tokenization strategy is to assign an integer to each unique character in a given dataset. We'll implement this strategy below with a few lines of code. We create two dictionaries that function as lookup tables: one encoding characters to integers, and one decoding integers back to characters. To apply our lookup tables we use list comprehensions.

In [98]:
# Calling 'set' on our data returns all unique characters, which are then sorted by codepoint
chars = sorted(set(data))
# Calling 'enumerate' yields (index, character) pairs; the indices serve as tokens
char_to_token = {char:token for token,char in enumerate(chars)}
# Reverse the mapping to be able to decode
token_to_char = {token:char for char,token in char_to_token.items()}

some_text = data[:30]
print(f'{some_text = }')

# Encode the first 30 characters of the tiny shakespeare dataset
tokens = [char_to_token[c] for c in some_text]
print(f'{tokens = }')

# Check that we retrieve our original text when decoding
chars = [token_to_char[t] for t in tokens]
print(f'{chars = }')

some_text = 'First Citizen:\nBefore we proce'
tokens = [18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43]
chars = ['F', 'i', 'r', 's', 't', ' ', 'C', 'i', 't', 'i', 'z', 'e', 'n', ':', '\n', 'B', 'e', 'f', 'o', 'r', 'e', ' ', 'w', 'e', ' ', 'p', 'r', 'o', 'c', 'e']


Below we add a bit more functionality and structure to the idea presented above. Wrapping our mappings in a tokenizer class makes them easier to use and to pass around, and aligns with code conventions used in many off-the-shelf tokenizer libraries.

In [99]:
class NaiveTokenizer:
    """Character level tokenizer that enumerates unique characters in a training text"""
    def __init__(self, encoding_dict: dict[str, int] | None = None):
        if encoding_dict is None:
            self.encoding_dict = dict()
        else:
            self.encoding_dict = encoding_dict

    def __repr__(self):
        return f'NaiveTokenizer(vocab_size={self.vocab_size})'

    @property
    def decoding_dict(self) -> dict[int, str]:
        """Decoding dict is implemented as property to automatically sync with changed encoding dict"""
        return {token:char for char,token in self.encoding_dict.items()}

    @property
    def vocab_size(self) -> int:
        return len(self.encoding_dict)

    def train(self, data: str) -> None:
        """Train on a piece of text by enumerating unique characters"""
        chars = sorted(set(data))
        self.encoding_dict = {char:token for token,char in enumerate(chars)}

    def encode(self, data: str) -> list[int]:
        """Convert text to tokens"""
        return [self.encoding_dict.get(char, -1) for char in data]

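    def decode(self, tokens: list[int]) -> str:
        """Convert tokens to text"""
        # Build the reverse mapping once, instead of once per token via the property
        decoding_dict = self.decoding_dict
        return ''.join(decoding_dict.get(token, '<unk>') for token in tokens)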

# Initialize an empty NaiveTokenizer
tokenizer = NaiveTokenizer()
print(f'Untrained tokenizer: {tokenizer}')

# 'Train' on tiny shakespeare
tokenizer.train(data)
print(f'Trained tokenizer: {tokenizer}')

# Encode a string that is not in the training data
tokens = tokenizer.encode('Hi how are you')
print(f'{tokens = }')

# Decode the encoding
print(tokenizer.decode(tokens))

Untrained tokenizer: NaiveTokenizer(vocab_size=0)
Trained tokenizer: NaiveTokenizer(vocab_size=65)
tokens = [20, 47, 1, 46, 53, 61, 1, 39, 56, 43, 1, 63, 53, 59]
Hi how are you


### Exercise 1
Investigate the vocabulary of the trained tokenizer by printing `tokenizer.encoding_dict`. Can you come up with a string that cannot be effectively encoded by our naive tokenizer? What happens to this string? How would you circumvent this issue?

## Byte Pair Encoding (BPE)

#### Converting to and from bytes
Apart from not being able to process unseen characters, tokenizing by enumerating characters in a training text has another issue: different training datasets can assign different tokens to the same character. The most common solution to both problems is to start from a fixed, dataset-independent encoding: [Unicode](https://en.wikipedia.org/wiki/Unicode) assigns every character an integer codepoint, and [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encodes each codepoint as a sequence of 1 to 4 bytes. Every byte can be interpreted as an integer between 0 and 255. The byte pair encoding algorithm iteratively merges these UTF-8 bytes.

In Python, converting individual characters to Unicode codepoints and back can be done with the built-in `ord` and `chr` functions, respectively.

In [100]:
# Single character to its Unicode codepoint
ord('H')

72

In [101]:
# Single unicode codepoint to character
chr(64)

'@'

In [102]:
# Converting a string to unicode and back
some_text = 'Deep learning is awesome'

unicode_codepoints = [ord(letter) for letter in some_text]
print(f'{unicode_codepoints = }')

characters = [chr(codepoint) for codepoint in unicode_codepoints]
print(f'{characters = }')

unicode_codepoints = [68, 101, 101, 112, 32, 108, 101, 97, 114, 110, 105, 110, 103, 32, 105, 115, 32, 97, 119, 101, 115, 111, 109, 101]
characters = ['D', 'e', 'e', 'p', ' ', 'l', 'e', 'a', 'r', 'n', 'i', 'n', 'g', ' ', 'i', 's', ' ', 'a', 'w', 'e', 's', 'o', 'm', 'e']


The same principles can be applied to whole strings with a slightly different syntax: the string method `encode` converts text to UTF-8 bytes, and `bytes.decode` converts bytes back to text:

In [103]:
some_text = 'Deep learning is awesome'

# Using the 'encode' method on a string converts to bytes, note the leading 'b' when printing
text_bytes = some_text.encode('utf-8')
print(f'{text_bytes = }')

# The list constructor iterates over the bytes, automatically converting to integers
tokens = list(text_bytes)
print(f'{tokens = }')

# Turning a list of integers into bytes and subsequently 'decoding' into text
reconstructed_text = bytes(tokens).decode('utf-8')
print(f'{reconstructed_text = }')

text_bytes = b'Deep learning is awesome'
tokens = [68, 101, 101, 112, 32, 108, 101, 97, 114, 110, 105, 110, 103, 32, 105, 115, 32, 97, 119, 101, 115, 111, 109, 101]
reconstructed_text = 'Deep learning is awesome'
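
For the ASCII-only sentence above, every byte equals the codepoint of the corresponding character. A short sketch with a few non-ASCII characters (chosen purely for illustration) shows where the two diverge, and why a single byte on its own does not always decode to a character:

```python
text = 'café €'

# One codepoint per character...
codepoints = [ord(char) for char in text]
print(codepoints)  # [99, 97, 102, 233, 32, 8364]

# ...but UTF-8 needs 2 bytes for 'é' and 3 bytes for '€'
utf8_bytes = list(text.encode('utf-8'))
print(utf8_bytes)  # [99, 97, 102, 195, 169, 32, 226, 130, 172]

# An incomplete multi-byte sequence is not valid UTF-8; without errors='replace'
# this would raise UnicodeDecodeError, with it we get U+FFFD instead
print(bytes([195]).decode('utf-8', errors='replace'))  # prints the replacement character U+FFFD

# The full byte sequence round-trips to the original text
assert bytes(utf8_bytes).decode('utf-8') == text
```

Bytes above 127 only occur for non-ASCII text. This is relevant for the `BytePairEncoder` below, which numbers its merge tokens from 128 upwards.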


#### Counting pairs
Byte pair encoding iteratively merges the most frequent pairs of adjacent bytes. We use some built-in Python functionality to count pairs of adjacent elements in a sequence (the example uses characters, BPE uses bytes).

In [104]:
example_sequence = 'A text with some repetition somesome reprepreprepetition'

# zip together original and shifted sequence
pairs = zip(example_sequence[:-1], example_sequence[1:])

# count using the built-in Counter class from the collections module (part of the Python standard library)
pair_counts = Counter(pairs)

# select most common pair
most_common_pair = pair_counts.most_common(1)
print(f'{most_common_pair = }')

most_common_pair = [(('r', 'e'), 5)]


#### BPE implementation
Below we implement a function to merge token pairs and some functionality to train the tokenizer, encode strings, decode tokens, and save and load trained models.

In [153]:
def merge_tokens(tokens: list[int], token_pair: tuple[int,int], new_token: int) -> list[int]:
    """Takes a list of tokens and replaces every non-overlapping occurrence of token_pair (scanning left to right) with new_token"""
    new_tokens = []
    i = 0
    # Iterate in a while loop because we want to jump ahead two steps sometimes
    while i < len(tokens):
        token = tokens[i]
        # Edge case: final individual token
        if i == len(tokens) - 1:
            new_tokens.append(token)
            break
        # Look ahead one token to find a token pair
        next_token = tokens[i+1]
        # On match we should jump ahead two tokens to skip the original pair
        if token_pair == (token, next_token):
            new_tokens.append(new_token)
            i += 2
        else:
            new_tokens.append(token)
            i += 1
    return new_tokens    

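class BytePairEncoder:
    """Byte pair encoder with a base vocabulary of the 128 ASCII characters (byte values 0-127).

    Merge tokens are numbered from 128 upwards. UTF-8 encodes non-ASCII characters as bytes
    in the range 128-255, which would collide with merge tokens, so this simple implementation
    assumes ASCII-only input (the tiny shakespeare text is ASCII-only).
    """
    def __init__(self, merges: dict[int, tuple[int, int]] | None = None):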
        if merges is None:
            self.merges = dict()
        else:
            self.merges = merges

    def __repr__(self) -> str:
        vocab_size = self.vocab_size
        n_merges = len(self.merges)
        return f'BytePairEncoder({vocab_size=}, {n_merges=})'
    
    @property
    def vocab_size(self) -> int:
        return len(self.get_vocab())

    def get_vocab(self) -> dict[str, int]:
        # Base vocabulary of the 128 ASCII characters
        base_vocab = {chr(token): token for token in range(128)}
        # Additional vocabulary is determined by trained merges
        merge_vocab = {self.decode([token]): token for token in self.merges}
        # Total vocabulary is the union of base and merge vocabs
        vocab = base_vocab | merge_vocab
        return vocab

    def _get_parent_tokens(self, token: int) -> Generator[int, None, None]:
        """Recursively identify whether a token is made up of parent tokens"""
        if token not in self.merges:
            yield token
            return
        for pair_token in self.merges[token]:
            yield from self._get_parent_tokens(pair_token)

    def train(self, input: str, vocab_size: int = 512) -> None:
        """Training proceeds by iteratively merging the most frequent token pair until the desired number of tokens is reached"""
        assert vocab_size > 128, f'Invalid vocab_size: {vocab_size}, must be larger than 128'
        tokens = list(input.encode('utf-8'))
        num_merges = vocab_size - 128
        for i in trange(num_merges):
            pair_counts = Counter(zip(tokens[:-1], tokens[1:]))
            if not pair_counts:
                break  # fewer than two tokens left, nothing to merge
            merge_pair = pair_counts.most_common(1)[0][0]
            new_token = 128 + i
            self.merges[new_token] = merge_pair
            tokens = merge_tokens(tokens, merge_pair, new_token)

    def encode(self, input: str) -> list[int]:
        """Convert text to tokens"""
        tokens = list(input.encode('utf-8'))
        for new_token, merge_pair in self.merges.items():
            tokens = merge_tokens(tokens, merge_pair, new_token)
        return tokens

    def decode(self, tokens: list[int]) -> str:
        """Convert tokens to text"""
        decoded_tokens = chain.from_iterable(map(self._get_parent_tokens, tokens))
        return bytes(decoded_tokens).decode('utf-8', errors='replace')

    def save(self, prefix: str) -> None:
        """Save a trained model"""
        with open(f'{prefix}.vocab', 'w') as fh:
            json.dump(self.get_vocab(), fh)
        with open(f'{prefix}.model', 'w') as fh:
            json.dump(self.merges, fh)

    @classmethod
    def load(cls, model_filename: str) -> 'BytePairEncoder':
        """Load a pretrained model from a .model file"""
        assert model_filename.endswith('.model'), f'{model_filename} is not a valid model file, must end with .model'
        with open(model_filename, 'r') as fh:
            merges = json.load(fh)
        # The json fileformat does not accept integers as dict keys, and does not have tuples
        sanitized_merges = {int(k):tuple(v) for k,v in merges.items()}
        return cls(sanitized_merges)

bpe = BytePairEncoder()
bpe.get_vocab()

{'\x00': 0,
 '\x01': 1,
 '\x02': 2,
 '\x03': 3,
 '\x04': 4,
 '\x05': 5,
 '\x06': 6,
 '\x07': 7,
 '\x08': 8,
 '\t': 9,
 '\n': 10,
 '\x0b': 11,
 '\x0c': 12,
 '\r': 13,
 '\x0e': 14,
 '\x0f': 15,
 '\x10': 16,
 '\x11': 17,
 '\x12': 18,
 '\x13': 19,
 '\x14': 20,
 '\x15': 21,
 '\x16': 22,
 '\x17': 23,
 '\x18': 24,
 '\x19': 25,
 '\x1a': 26,
 '\x1b': 27,
 '\x1c': 28,
 '\x1d': 29,
 '\x1e': 30,
 '\x1f': 31,
 ' ': 32,
 '!': 33,
 '"': 34,
 '#': 35,
 '$': 36,
 '%': 37,
 '&': 38,
 "'": 39,
 '(': 40,
 ')': 41,
 '*': 42,
 '+': 43,
 ',': 44,
 '-': 45,
 '.': 46,
 '/': 47,
 '0': 48,
 '1': 49,
 '2': 50,
 '3': 51,
 '4': 52,
 '5': 53,
 '6': 54,
 '7': 55,
 '8': 56,
 '9': 57,
 ':': 58,
 ';': 59,
 '<': 60,
 '=': 61,
 '>': 62,
 '?': 63,
 '@': 64,
 'A': 65,
 'B': 66,
 'C': 67,
 'D': 68,
 'E': 69,
 'F': 70,
 'G': 71,
 'H': 72,
 'I': 73,
 'J': 74,
 'K': 75,
 'L': 76,
 'M': 77,
 'N': 78,
 'O': 79,
 'P': 80,
 'Q': 81,
 'R': 82,
 'S': 83,
 'T': 84,
 'U': 85,
 'V': 86,
 'W': 87,
 'X': 88,
 'Y': 89,
 'Z': 90,
 '[': 91,
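
Decoding relies on `_get_parent_tokens`, which recursively expands a merge token into the base bytes it was built from. The standalone sketch below reproduces that recursion with a small, hypothetical merge table (token 128 = 'h' + 'e', token 129 = token 128 + ' '); the function name `expand` is our own.

```python
from collections.abc import Iterator

# Hypothetical merge table: 128 = 'h' + 'e', 129 = 128 + ' ' (i.e. 'he ')
merges = {128: (104, 101), 129: (128, 32)}


def expand(token: int) -> Iterator[int]:
    """Yield the base bytes a token is built from, left to right."""
    if token not in merges:
        yield token  # base token: a single byte
        return
    left, right = merges[token]
    yield from expand(left)
    yield from expand(right)


tokens = [129, 116, 128]  # 'he ' + 't' + 'he'
raw_bytes = [byte for token in tokens for byte in expand(token)]
print(raw_bytes)                         # [104, 101, 32, 116, 104, 101]
print(bytes(raw_bytes).decode('utf-8'))  # he the
```

Because every merge token refers only to tokens created before it, the recursion always terminates at base bytes.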


In [154]:
bpe.train(data, vocab_size=256)
bpe.save('shakespeare_256')
bpe = BytePairEncoder.load('./shakespeare_256.model')
bpe.get_vocab()

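The `load` method has to undo two side effects of JSON serialization: integer dict keys become strings, and tuples become lists. A standalone sketch with a hypothetical merge table shows both, and also that key order survives the round trip, which matters because `encode` applies merges in insertion order:

```python
import json

# Hypothetical merges: 128 = 'e' + ' ', 129 = 't' + 'h'
merges = {128: (101, 32), 129: (116, 104)}

serialized = json.dumps(merges)
print(serialized)  # {"128": [101, 32], "129": [116, 104]}

# Undo the conversions, as BytePairEncoder.load does
restored = {int(k): tuple(v) for k, v in json.loads(serialized).items()}
print(restored == merges)  # True
print(list(restored))      # [128, 129], insertion order is preserved
```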

### Exercise 2
Train a byte pair encoder on the tiny shakespeare dataset with a vocab_size of 256 and inspect the vocabulary. Can you identify tokens that encode some semantically meaningful identity?

## Optimized BPE using the Hugging Face tokenizers library
As you can imagine, our Python implementation is not optimized for speed. Several optimized tokenizer implementations are commonly used, most of which have Python bindings for ease of use. Below we approximately reproduce the configuration of our Python BPE tokenizer using the Hugging Face `tokenizers` library, which is implemented in Rust. This allows us to train larger vocabularies in a shorter amount of time.

In [160]:
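tokenizer = Tokenizer(models.BPE(byte_fallback=True))
trainer = trainers.BpeTrainer(
    # Same base alphabet as our BytePairEncoder: the 128 ASCII characters
    initial_alphabet=[chr(i) for i in range(128)],
    # Same vocab_size as bpe above, so the outputs below can be compared; increase it for larger vocabularies
    vocab_size=256
)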
tokenizer.decoder = decoders.ByteLevel()
tokenizer.train(["input.txt"], trainer=trainer)






In [144]:
[tokenizer.decode([i]) for i in tokenizer.encode("Hi how are you 1234").ids]

['H', 'i', ' h', 'ow', ' ', 'ar', 'e ', 'you', ' ', '1', '2', '3', '4']

In [145]:
tokenizer.decode(tokenizer.encode("Hi how are you 1234").ids)

'Hi how are you 1234'

In [146]:
[bpe.decode([i]) for i in bpe.encode('Hi how are you 1234')]

['H', 'i', ' h', 'ow', ' ', 'ar', 'e ', 'you', ' ', '1', '2', '3', '4']

In [147]:
bpe.decode(bpe.encode('Hi how are you 1234'))

'Hi how are you 1234'

In [148]:
tokenizer.encode('Hi how are you 1234').ids

[72, 105, 161, 156, 32, 143, 128, 152, 32, 49, 50, 51, 52]

In [149]:
bpe.encode('Hi how are you 1234')

[72, 105, 162, 157, 32, 144, 128, 153, 32, 49, 50, 51, 52]

In [159]:
sorted(tokenizer.get_vocab().items(), key=lambda x: x[1])

[('\x00', 0),
 ('\x01', 1),
 ('\x02', 2),
 ('\x03', 3),
 ('\x04', 4),
 ('\x05', 5),
 ('\x06', 6),
 ('\x07', 7),
 ('\x08', 8),
 ('\t', 9),
 ('\n', 10),
 ('\x0b', 11),
 ('\x0c', 12),
 ('\r', 13),
 ('\x0e', 14),
 ('\x0f', 15),
 ('\x10', 16),
 ('\x11', 17),
 ('\x12', 18),
 ('\x13', 19),
 ('\x14', 20),
 ('\x15', 21),
 ('\x16', 22),
 ('\x17', 23),
 ('\x18', 24),
 ('\x19', 25),
 ('\x1a', 26),
 ('\x1b', 27),
 ('\x1c', 28),
 ('\x1d', 29),
 ('\x1e', 30),
 ('\x1f', 31),
 (' ', 32),
 ('!', 33),
 ('"', 34),
 ('#', 35),
 ('$', 36),
 ('%', 37),
 ('&', 38),
 ("'", 39),
 ('(', 40),
 (')', 41),
 ('*', 42),
 ('+', 43),
 (',', 44),
 ('-', 45),
 ('.', 46),
 ('/', 47),
 ('0', 48),
 ('1', 49),
 ('2', 50),
 ('3', 51),
 ('4', 52),
 ('5', 53),
 ('6', 54),
 ('7', 55),
 ('8', 56),
 ('9', 57),
 (':', 58),
 (';', 59),
 ('<', 60),
 ('=', 61),
 ('>', 62),
 ('?', 63),
 ('@', 64),
 ('A', 65),
 ('B', 66),
 ('C', 67),
 ('D', 68),
 ('E', 69),
 ('F', 70),
 ('G', 71),
 ('H', 72),
 ('I', 73),
 ('J', 74),
 ('K', 75),
 ('L', 7

In [151]:
bpe.get_vocab()

{'\x00': 0,
 '\x01': 1,
 '\x02': 2,
 '\x03': 3,
 '\x04': 4,
 '\x05': 5,
 '\x06': 6,
 '\x07': 7,
 '\x08': 8,
 '\t': 9,
 '\n': 10,
 '\x0b': 11,
 '\x0c': 12,
 '\r': 13,
 '\x0e': 14,
 '\x0f': 15,
 '\x10': 16,
 '\x11': 17,
 '\x12': 18,
 '\x13': 19,
 '\x14': 20,
 '\x15': 21,
 '\x16': 22,
 '\x17': 23,
 '\x18': 24,
 '\x19': 25,
 '\x1a': 26,
 '\x1b': 27,
 '\x1c': 28,
 '\x1d': 29,
 '\x1e': 30,
 '\x1f': 31,
 ' ': 32,
 '!': 33,
 '"': 34,
 '#': 35,
 '$': 36,
 '%': 37,
 '&': 38,
 "'": 39,
 '(': 40,
 ')': 41,
 '*': 42,
 '+': 43,
 ',': 44,
 '-': 45,
 '.': 46,
 '/': 47,
 '0': 48,
 '1': 49,
 '2': 50,
 '3': 51,
 '4': 52,
 '5': 53,
 '6': 54,
 '7': 55,
 '8': 56,
 '9': 57,
 ':': 58,
 ';': 59,
 '<': 60,
 '=': 61,
 '>': 62,
 '?': 63,
 '@': 64,
 'A': 65,
 'B': 66,
 'C': 67,
 'D': 68,
 'E': 69,
 'F': 70,
 'G': 71,
 'H': 72,
 'I': 73,
 'J': 74,
 'K': 75,
 'L': 76,
 'M': 77,
 'N': 78,
 'O': 79,
 'P': 80,
 'Q': 81,
 'R': 82,
 'S': 83,
 'T': 84,
 'U': 85,
 'V': 86,
 'W': 87,
 'X': 88,
 'Y': 89,
 'Z': 90,
 '[': 91,


In [132]:
tokenizer.pre_tokenizer