In [1]:
from collections import defaultdict
from copy import deepcopy
from transformers import AutoTokenizer

# Simple BPE implementation

In [2]:
corpus = [
    "This is the Hugging Face Course.",
    "This chapter is about tokenization.",
    "This section shows several tokenizer algorithms.",
    "Hopefully, you will be able to understand how they are trained and generate tokens.",
]

# Pre-tokenization

In [3]:
tokenizer_checkpoint = "gpt2"

In [4]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

## Compute word frequencies

In [5]:
# Count how many times each pre-tokenized word occurs across the whole corpus.
word_freqs = defaultdict(lambda: 0)
for text in corpus:
    # pre_tokenize_str yields (word, offsets) pairs; only the word is counted.
    for word, _offsets in tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text):
        word_freqs[word] += 1

In [6]:
print(word_freqs)

defaultdict(<function <lambda> at 0x7f172e7a9ea0>, {'This': 3, 'Ġis': 2, 'Ġthe': 1, 'ĠHugging': 1, 'ĠFace': 1, 'ĠCourse': 1, '.': 4, 'Ġchapter': 1, 'Ġabout': 1, 'Ġtokenization': 1, 'Ġsection': 1, 'Ġshows': 1, 'Ġseveral': 1, 'Ġtokenizer': 1, 'Ġalgorithms': 1, 'Hopefully': 1, ',': 1, 'Ġyou': 1, 'Ġwill': 1, 'Ġbe': 1, 'Ġable': 1, 'Ġto': 1, 'Ġunderstand': 1, 'Ġhow': 1, 'Ġthey': 1, 'Ġare': 1, 'Ġtrained': 1, 'Ġand': 1, 'Ġgenerate': 1, 'Ġtokens': 1})


## Compute base vocabulary

In [7]:
# Base vocabulary: every distinct character seen in the corpus words, sorted,
# followed by the end-of-text marker.
alphabet = sorted({letter for word in word_freqs for letter in word})
end_token = "<|endoftext|>"
vocab = [*alphabet, end_token]

In [8]:
print(alphabet)

[',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z', 'Ġ']


In [9]:
print(vocab)

[',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z', 'Ġ', '<|endoftext|>']


## Split words into characters

In [10]:
# Start with every word broken into its individual characters.
splits = {word: list(word) for word in word_freqs}

## Compute pair frequencies

In [11]:
def compute_pair_freqs(splits, word_freqs):
    """Count adjacent symbol pairs, weighted by how often each word occurs.

    Args:
        splits: Mapping from word to its current list of symbols.
        word_freqs: Mapping from word to its frequency in the corpus.

    Returns:
        A defaultdict mapping ``(left, right)`` symbol tuples to the summed
        frequency of the words in which that adjacent pair appears.
    """
    pair_freqs = defaultdict(lambda: 0)
    for word, freq in word_freqs.items():
        symbols = splits[word]
        # Zipping a list with its own tail walks the adjacent pairs; a word
        # made of a single symbol produces none.
        for left, right in zip(symbols, symbols[1:]):
            pair_freqs[left, right] += freq
    return pair_freqs

In [12]:
pair_freqs = compute_pair_freqs(splits, word_freqs)

In [13]:
# Preview the first six pair counts, in insertion order.
for k, v in list(pair_freqs.items())[:6]:
    print(f"{k}: {v}")

('T', 'h'): 3
('h', 'i'): 3
('i', 's'): 5
('Ġ', 'i'): 2
('Ġ', 't'): 7
('t', 'h'): 3


## Find the most frequent pair

In [14]:
def max_pair_freq(pair_freqs):
    """Return the most frequent pair together with its count.

    Args:
        pair_freqs: Mapping from symbol pair to frequency.

    Returns:
        ``(best_pair, max_freq)``. Ties go to the pair that comes first in
        the mapping's iteration order. An empty mapping gives ``("", None)``.
    """
    if not pair_freqs:
        return "", None
    # max() over the keys keeps the first key whose count is the largest.
    best_pair = max(pair_freqs, key=pair_freqs.get)
    return best_pair, pair_freqs[best_pair]

In [15]:
print(*max_pair_freq(pair_freqs), sep="  ")

('Ġ', 't')  7


## Merge target pairs

In [16]:
def merge_pair(a, b, splits, word_freqs):
    """Fuse every adjacent ``a``, ``b`` occurrence into the single symbol ``a + b``.

    Args:
        a: Left symbol of the pair to merge.
        b: Right symbol of the pair to merge.
        splits: Mapping from word to its current list of symbols. Entries are
            replaced in this mapping.
        word_freqs: Mapping whose keys are the words to process.

    Returns:
        The same ``splits`` mapping that was passed in, after updating it.
    """
    fused = a + b
    for word in word_freqs:
        symbols = splits[word]
        if len(symbols) != 1:
            pos = 0
            while pos + 1 < len(symbols):
                if (symbols[pos], symbols[pos + 1]) == (a, b):
                    # Build a new list with the pair collapsed and look at the
                    # same position again without advancing.
                    symbols = [*symbols[:pos], fused, *symbols[pos + 2:]]
                    continue
                pos += 1
            splits[word] = symbols
    return splits

## Learn merges until the vocabulary reaches the desired size

In [17]:
# Fresh training state: target size, empty merge table, the base vocabulary,
# and a private copy of the character splits so `splits` itself stays untouched.
vocab_size = 50
merges = {}
vocab = [*alphabet, end_token]
splits_copy = deepcopy(splits)

In [18]:
# Learn one merge per iteration until the vocabulary reaches `vocab_size`,
# or until there is nothing left to merge.
while len(vocab) < vocab_size:
    pair_freqs = compute_pair_freqs(splits_copy, word_freqs)
    if not pair_freqs:
        # Every word is already a single symbol. max_pair_freq would return
        # ("", None), and unpacking "" into (best_a, best_b) raises ValueError,
        # so stop early with a smaller vocabulary instead.
        break
    (best_a, best_b), max_freq = max_pair_freq(pair_freqs)
    splits_copy = merge_pair(best_a, best_b, splits_copy, word_freqs)
    new_merge = best_a + best_b
    # NOTE: the key is the two symbols joined by ',' and the cells that read
    # `merges` recover the pair with `key.split(',')`. A symbol that itself
    # contains ',' would not round-trip through this encoding.
    merges[f"{best_a},{best_b}"] = new_merge
    vocab.append(new_merge)

In [19]:
merges

{'Ġ,t': 'Ġt',
 'i,s': 'is',
 'e,r': 'er',
 'Ġ,a': 'Ġa',
 'Ġt,o': 'Ġto',
 'e,n': 'en',
 'T,h': 'Th',
 'Th,is': 'This',
 'o,u': 'ou',
 's,e': 'se',
 'Ġto,k': 'Ġtok',
 'Ġtok,en': 'Ġtoken',
 'n,d': 'nd',
 'Ġ,is': 'Ġis',
 'Ġt,h': 'Ġth',
 'Ġth,e': 'Ġthe',
 'i,n': 'in',
 'Ġa,b': 'Ġab',
 'Ġtoken,i': 'Ġtokeni'}

In [20]:
print(vocab)

[',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z', 'Ġ', '<|endoftext|>', 'Ġt', 'is', 'er', 'Ġa', 'Ġto', 'en', 'Th', 'This', 'ou', 'se', 'Ġtok', 'Ġtoken', 'nd', 'Ġis', 'Ġth', 'Ġthe', 'in', 'Ġab', 'Ġtokeni']


## Tokenize all text

In [21]:
# List each learned merge as "left right merged", in the order it was learned.
for merge_key, merged_token in merges.items():
    left, right = merge_key.split(',')
    print(left, right, merged_token)

Ġ t Ġt
i s is
e r er
Ġ a Ġa
Ġt o Ġto
e n en
T h Th
Th is This
o u ou
s e se
Ġto k Ġtok
Ġtok en Ġtoken
n d nd
Ġ is Ġis
Ġt h Ġth
Ġth e Ġthe
i n in
Ġa b Ġab
Ġtoken i Ġtokeni


In [24]:
def tokenize(text):
    """Tokenize `text` by replaying the learned BPE merges in training order.

    The text is pre-tokenized with the notebook-level `tokenizer`, each word is
    split into single characters, and then every entry of the notebook-level
    `merges` dict is applied to every word, in insertion order.

    Args:
        text: The string to tokenize.

    Returns:
        A flat list of token strings covering the whole text.
    """
    # Go through the public `backend_tokenizer` accessor, matching the cell
    # that computes the word frequencies, rather than the private `_tokenizer`.
    pre_tokenize_result = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    splits_t = [list(word) for word, _offset in pre_tokenize_result]
    for pair, merge in merges.items():
        # Keys are stored as "left,right".
        # NOTE: a symbol containing ',' cannot be recovered by this split.
        pair_a, pair_b = pair.split(',')
        for idx, split in enumerate(splits_t):
            i = 0
            while i < len(split) - 1:
                if split[i] == pair_a and split[i + 1] == pair_b:
                    # Collapse the pair and re-check the same position.
                    split = split[:i] + [merge] + split[i + 2:]
                else:
                    i += 1
            # Write the word back once, after all occurrences were merged.
            splits_t[idx] = split
    # Flatten in a single pass; sum(list_of_lists, []) re-copies the growing
    # result on every addition.
    return [token for split in splits_t for token in split]

In [25]:
tokenize("This is not a token.")

['This', 'Ġis', 'Ġ', 'n', 'o', 't', 'Ġa', 'Ġtoken', '.']