<a target="_blank" href="https://colab.research.google.com/github/holmrenser/deep_learning/blob/main/tokenization.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Tokenization
Machine learning approaches to natural language processing must represent long stretches of natural language (words, sentences, paragraphs, chapters, etc.) in a meaningful and computationally efficient way. A trivial approach is to split text on punctuation and whitespace, effectively selecting individual words. The downside of this approach is that closely related words are encoded as unrelated entities. For example, 'small', 'smaller', and 'smallest' would each be encoded as a different entity, forcing a model to learn their semantic similarity from data alone. An alternative approach would encode 'small' as one entity, and 'er' and 'est' as separate entities. The benefit of this 'subword' approach is that related words share tokens, which makes semantic similarity more straightforward to model; the downside is that it is not obvious how to select an optimal set of subwords.
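
To make the word-level problem concrete, here is a minimal sketch of splitting on punctuation and whitespace. The example sentence, the regular expression, and the names `words` and `word_to_id` are illustrative assumptions, not part of the notebook's later code.

```python
import re

# Hypothetical example sentence, chosen to contain three related word forms
text = "A small dog, a smaller cat, and the smallest mouse."

# Split into runs of word characters or single punctuation marks
words = re.findall(r"\w+|[^\w\s]", text.lower())

# Assign an integer ID to every distinct word
word_to_id = {word: i for i, word in enumerate(sorted(set(words)))}

print(words)
# The three related forms receive three unrelated IDs
print({w: word_to_id[w] for w in ("small", "smaller", "smallest")})
```

Nothing in these IDs tells a model that the three forms share a stem; that relation would have to be learned from data.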

In this notebook we will explore the [byte pair encoding (BPE)](https://en.wikipedia.org/wiki/Byte_pair_encoding) algorithm for creating subword representations, a.k.a. tokens. Put simply, byte pair encoding iteratively merges the most frequent pair of adjacent tokens into a new token, starting from the simplest tokens (e.g. letters or bytes) and continuing until the desired number of tokens is reached.
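
The merge loop described above can be sketched on a toy string before applying it to real data. This is a simplified illustration that works on characters rather than bytes; the function name `toy_bpe` and the input string are assumptions made for this example, and it is not the implementation used later in this notebook.

```python
from collections import Counter

def toy_bpe(text: str, num_merges: int) -> tuple[list[str], list[tuple[str, str]]]:
    """Character-level BPE sketch: returns the final tokens and the merged pairs."""
    tokens = list(text)
    merges = []
    for _ in range(num_merges):
        # Count every pair of adjacent tokens
        pair_counts = Counter(zip(tokens, tokens[1:]))
        if not pair_counts:
            break
        pair = pair_counts.most_common(1)[0][0]
        merges.append(pair)
        # Replace each occurrence of the pair with the concatenated token
        merged, i = [], 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
                merged.append(tokens[i] + tokens[i + 1])
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens, merges

tokens, merges = toy_bpe("aaabdaaabac", num_merges=3)
print(tokens)
print(merges)
```

After three merges the eleven characters are represented by five tokens, and joining the tokens still reproduces the original string.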

For computational reasons, characters are usually represented as integers, so a tokenizer ultimately maps text to a sequence of integer IDs.

In [10]:
# All dependencies for the whole notebook
from collections import Counter
from itertools import chain
from tqdm.auto import trange
import json
from dataclasses import dataclass, field
from typing import Generator
from tokenizers import (
    decoders,
    models,
    trainers,
    Tokenizer,
)

In [14]:
# Download the tiny shakespeare dataset
!wget -nc https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

File ‘input.txt’ already there; not retrieving.



In [24]:
# Load the tiny shakespeare data and show the first 30 characters
with open('input.txt', 'r') as fh:
    data = fh.read()

data[:30]

'First Citizen:\nBefore we proce'

## Naive 'tokens'
Probably the simplest tokenization strategy is to assign an integer to each individual character in a given dataset. We'll implement this strategy below with a few lines of code. We create two dictionaries that function as lookup tables: one encoding characters to integers, and one decoding integers back to characters. To apply our lookup tables we use a list comprehension.

In [26]:
# Calling 'set' on our data returns all individual characters, which are then lexicographically sorted
chars = sorted(set(data))
# 'enumerate' pairs each character with an increasing integer, which we'll use as its token
char_to_token = {char:token for token,char in enumerate(chars)}
# Reverse the mapping to be able to decode
token_to_char = {token:char for char,token in char_to_token.items()}

# Encode the first 30 characters of the tiny shakespeare dataset
tokens = [char_to_token[c] for c in data[:30]]
print(tokens)

# Check that we retrieve our original text when decoding
# (use a new name so the sorted vocabulary in 'chars' is not overwritten)
decoded_chars = [token_to_char[t] for t in tokens]
print(decoded_chars)

[18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43]
['F', 'i', 'r', 's', 't', ' ', 'C', 'i', 't', 'i', 'z', 'e', 'n', ':', '\n', 'B', 'e', 'f', 'o', 'r', 'e', ' ', 'w', 'e', ' ', 'p', 'r', 'o', 'c', 'e']


Below we add a bit more functionality and structure to the idea presented above. This allows us to use our mappings as a tokenizer more efficiently, makes the tokenizer easier to pass around, and aligns with code conventions used in many off-the-shelf tokenizer libraries.

In [62]:
class NaiveTokenizer:
    def __init__(self, encoding_dict: dict[str, int]=None):
        if encoding_dict is None:
            self.encoding_dict = dict()
        else:
            self.encoding_dict = encoding_dict

    def __repr__(self):
        return f'NaiveTokenizer(vocab_size={self.vocab_size})'

    @property
    def decoding_dict(self) -> dict[int, str]:
        return {token:char for char,token in self.encoding_dict.items()}

    @property
    def vocab_size(self) -> int:
        return len(self.encoding_dict)

    def train(self, data: str) -> None:
        chars = sorted(set(data))
        self.encoding_dict = {char:token for token,char in enumerate(chars)}

    def encode(self, data: str) -> list[int]:
        return [self.encoding_dict.get(char, -1) for char in data]

    def decode(self, tokens: list[int]) -> str:
        return ''.join(self.decoding_dict.get(token, '<unk>') for token in tokens)

# Initialize an empty NaiveTokenizer
tokenizer = NaiveTokenizer()
print(f'Untrained tokenizer: {tokenizer}')

# 'Train' on tiny shakespeare
tokenizer.train(data)
print(f'Trained tokenizer: {tokenizer}')

# Encode a string (that is not in the training data)
tokens = tokenizer.encode('Hi how are you')
print(tokens)

# Decode the encoding
print(tokenizer.decode(tokens))

Untrained tokenizer: NaiveTokenizer(vocab_size=0)
Trained tokenizer: NaiveTokenizer(vocab_size=65)
[20, 47, 1, 46, 53, 61, 1, 39, 56, 43, 1, 63, 53, 59]
Hi how are you


### Exercise 1
Can you come up with a string that cannot be effectively encoded by our naive tokenizer? What happens to this string? How would you circumvent this issue?

## Byte Pair Encoding (BPE)

In [5]:
text_bytes = data.encode('utf-8')
tokens = list(text_bytes)
tokens[:20]

[70,
 105,
 114,
 115,
 116,
 32,
 67,
 105,
 116,
 105,
 122,
 101,
 110,
 58,
 10,
 66,
 101,
 102,
 111,
 114]

In [7]:
[ord(c) for c in data[:20]]

[70,
 105,
 114,
 115,
 116,
 32,
 67,
 105,
 116,
 105,
 122,
 101,
 110,
 58,
 10,
 66,
 101,
 102,
 111,
 114]
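
The two cells above produce identical lists because the first 20 characters are plain ASCII, where the Unicode code point returned by `ord` and the UTF-8 byte value coincide. The sketch below, using an illustrative string that is not taken from the dataset, shows where the two diverge: a non-ASCII character has a single code point but is encoded as several bytes, each below 256.

```python
ascii_text = "Citizen"
accented_text = "caf\u00e9"  # 'café', written with an escape to avoid encoding ambiguity

# For pure ASCII text, code points and UTF-8 bytes are the same integers
print([ord(c) for c in ascii_text] == list(ascii_text.encode("utf-8")))  # True

code_points = [ord(c) for c in accented_text]
utf8_bytes = list(accented_text.encode("utf-8"))
print(code_points)  # [99, 97, 102, 233]
print(utf8_bytes)   # [99, 97, 102, 195, 169]
```

Working on UTF-8 bytes means any string can be represented with a base vocabulary of exactly 256 tokens, which is the starting point used by the `BytePairEncoder` below.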

In [44]:
def merge_tokens(tokens: list[int], token_pair: tuple[int,int], new_token: int) -> list[int]:
    """Takes a list of tokens and replaces every occurrence of token_pair with new_token"""
    new_tokens = []
    i = 0
    # Iterate in a while loop because we want to jump ahead two steps sometimes
    while i < len(tokens):
        token = tokens[i]
        # Edge case: final individual token
        if i == len(tokens) - 1:
            new_tokens.append(token)
            break
        # Look ahead one token to find a token pair
        next_token = tokens[i+1]
        # On match we should jump ahead two tokens to skip the original pair
        if token_pair == (token, next_token):
            new_tokens.append(new_token)
            i += 2
        else:
            new_tokens.append(token)
            i += 1
    return new_tokens    

class BytePairEncoder:
    def __init__(self, merges: dict[int, tuple[int, int]]=None):
        if merges is None:
            self.merges = dict()
        else:
            self.merges = merges

    def __repr__(self) -> str:
        vocab_size = self.vocab_size
        n_merges = len(self.merges)
        return f'BytePairEncoder({vocab_size=} {n_merges=})'
    
    @property
    def vocab_size(self) -> int:
        return len(self.get_vocab())

    def get_vocab(self) -> dict[str, int]:
        base_vocab = {chr(token): token for token in range(256)}        
        merge_vocab = {self.decode([token]): token for token in self.merges}
        return base_vocab | merge_vocab

    def _token_to_bytes(self, token: int) -> Generator[int, None, None]:
        if token not in self.merges:
            yield token
            return
        for pair_token in self.merges[token]:
            if pair_token >= 256:
                yield from self._token_to_bytes(pair_token)
            else:
                yield pair_token

    def train(self, input: str, vocab_size: int = 512) -> None:
        assert vocab_size > 256, f'Invalid vocab_size: {vocab_size}, must be larger than 256'
        tokens = list(input.encode('utf-8'))
        num_merges = vocab_size - 256
        for i in trange(num_merges):
            pair_counts = Counter(zip(tokens[:-1], tokens[1:]))
            merge_pair = pair_counts.most_common(1)[0][0]
            new_token = 256 + i
            self.merges[new_token] = merge_pair
            tokens = merge_tokens(tokens, merge_pair, new_token)

    def encode(self, input: str) -> list[int]:
        tokens = list(input.encode('utf-8'))
        for new_token, merge_pair in self.merges.items():
            tokens = merge_tokens(tokens, merge_pair, new_token)
        return tokens

    def decode(self, tokens: list[int]) -> str:
        decoded_tokens = chain.from_iterable(map(self._token_to_bytes, tokens))
        return bytes(decoded_tokens).decode('utf-8', errors='replace')

    def save(self, prefix: str) -> None:
        with open(f'{prefix}.vocab', 'w') as fh:
            json.dump(self.get_vocab(), fh)
        with open(f'{prefix}.model', 'w') as fh:
            json.dump(self.merges, fh)

    @classmethod
    def load(cls, model_filename: str) -> 'BytePairEncoder':
        with open(model_filename, 'r') as fh:
            merges = json.load(fh)
        sanitized_merges = {int(k):tuple(v) for k,v in merges.items()}
        return cls(sanitized_merges)

bpe = BytePairEncoder()
bpe

BytePairEncoder(vocab_size=256 n_merges=0)

In [158]:
bpe.train(data, vocab_size=512)
bpe.save('shakespeare_512')

  0%|          | 0/256 [00:00<?, ?it/s]

In [46]:
BytePairEncoder.load('./shakespeare_260.model')

BytePairEncoder(vocab_size=260 n_merges=4)
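
The `load` method converts keys back to `int` and values back to `tuple` because JSON cannot store either directly. A minimal sketch of that round trip follows; the merge table here is a hypothetical example, not the contents of any saved model file.

```python
import json

# Hypothetical merge table: new token -> pair of tokens it replaces
merges = {256: (101, 32), 257: (116, 104)}

# JSON turns integer keys into strings and tuples into lists
restored = json.loads(json.dumps(merges))
print(restored)

# Undo both conversions, as BytePairEncoder.load does
sanitized = {int(k): tuple(v) for k, v in restored.items()}
print(sanitized == merges)
```

Without this step, lookups such as `token in self.merges` would fail for integer tokens, since the restored keys are strings.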

In [47]:

tokenizer = Tokenizer(models.BPE(byte_fallback=True))
trainer = trainers.BpeTrainer(
    initial_alphabet=[chr(i) for i in range(256)],
    vocab_size=512
)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.train(["input.txt"], trainer=trainer)
tokenizer






<tokenizers.Tokenizer at 0x105d2e4e0>

In [293]:
[tokenizer.decode([i]) for i in tokenizer.encode("Hi how are you 1234").ids]

['H', 'i', ' h', 'ow', ' ', 'are ', 'you', ' ', '1', '2', '3', '4']

In [287]:
tokenizer.decode(tokenizer.encode("Hi how are you 1234").ids)

'Hi how are you 1234'

In [288]:
[bpe.decode([i]) for i in bpe.encode('Hi how are you 1234')]

['H', 'i', ' h', 'ow', ' ', 'are ', 'you', ' ', '1', '2', '3', '4']

In [289]:
bpe.decode(bpe.encode('Hi how are you 1234'))

'Hi how are you 1234'

In [290]:
tokenizer.encode('Hi how are you 1234').ids

[72, 105, 289, 284, 32, 420, 280, 32, 49, 50, 51, 52]

In [291]:
bpe.encode('Hi how are you 1234')

[72, 105, 290, 285, 32, 420, 281, 32, 49, 50, 51, 52]

In [294]:
tokenizer.get_vocab()

{'x': 120,
 'ow': 284,
 'W': 87,
 '[': 91,
 's\n': 469,
 '÷': 247,
 '¿': 191,
 'Â': 194,
 'sir': 497,
 'å': 229,
 'st': 295,
 '\x8b': 139,
 'lea': 508,
 'O:\n': 337,
 'res': 448,
 '-': 45,
 'À': 192,
 'ø': 248,
 '\x91': 145,
 'Ë': 203,
 'ck': 380,
 's, ': 376,
 '©': 169,
 'lo': 379,
 'oun': 442,
 '·': 183,
 'ê': 234,
 'om': 291,
 'im': 314,
 'up': 421,
 'us': 440,
 'z': 122,
 'as ': 352,
 '3': 51,
 '¥': 165,
 'ee': 326,
 '4': 52,
 'T': 84,
 ' and ': 434,
 '\x88': 136,
 'Ã': 195,
 't': 116,
 '#': 35,
 '\x18': 24,
 're': 308,
 'ut ': 335,
 'an': 266,
 'is ': 278,
 'him': 423,
 '\x87': 135,
 'qu': 458,
 'L': 76,
 '\x83': 131,
 'for': 299,
 ' s': 281,
 ':': 58,
 'ù': 249,
 'Á': 193,
 '\n': 10,
 'v': 118,
 'mor': 460,
 ' w': 287,
 'er': 263,
 'ill': 386,
 '\x99': 153,
 '; ': 374,
 ' the': 465,
 'y ': 265,
 '\x9a': 154,
 'et ': 447,
 'ou': 262,
 'ING': 433,
 'un': 346,
 '\x8d': 141,
 '\x80': 128,
 '\xa0': 160,
 'µ': 181,
 'Ö': 214,
 'Th': 306,
 'hat ': 333,
 '^': 94,
 'Ç': 199,
 'at ': 310,
