In [1]:
from collections import defaultdict
from copy import deepcopy
from transformers import AutoTokenizer

In [2]:
# Toy training corpus: each string is passed to the pre-tokenizer as one text.
corpus = [
    "This is the Hugging Face Course.",
    "This chapter is about tokenization.",
    "This section shows several tokenizer algorithms.",
    "Hopefully, you will be able to understand how they are trained and generate tokens.",
]

In [3]:
# Checkpoint name handed to AutoTokenizer.from_pretrained in the next cell.
model_checkpoint = "bert-base-cased"

In [4]:
# Only this tokenizer's pre-tokenizer is used below (to split text into words);
# the vocabulary itself is rebuilt from scratch from `corpus`.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Pre-tokenization

## Compute word frequencies

In [5]:
# Count how many times each pre-tokenized word occurs across the whole corpus.
word_freqs = defaultdict(lambda: 0)
pre_tokenizer = tokenizer.backend_tokenizer.pre_tokenizer
for text in corpus:
    # pre_tokenize_str yields (word, offsets) pairs; the offsets are not needed.
    for word, _offsets in pre_tokenizer.pre_tokenize_str(text):
        word_freqs[word] += 1

In [6]:
# Inspect the word counts gathered from the corpus.
word_freqs

defaultdict(<function __main__.<lambda>()>,
            {'This': 3,
             'is': 2,
             'the': 1,
             'Hugging': 1,
             'Face': 1,
             'Course': 1,
             '.': 4,
             'chapter': 1,
             'about': 1,
             'tokenization': 1,
             'section': 1,
             'shows': 1,
             'several': 1,
             'tokenizer': 1,
             'algorithms': 1,
             'Hopefully': 1,
             ',': 1,
             'you': 1,
             'will': 1,
             'be': 1,
             'able': 1,
             'to': 1,
             'understand': 1,
             'how': 1,
             'they': 1,
             'are': 1,
             'trained': 1,
             'and': 1,
             'generate': 1,
             'tokens': 1})

## Compute alphabet and vocabulary

In [7]:
# Base alphabet: a word's first character is kept as-is, while every later
# character is stored with the "##" continuation prefix.
alphabet = set()
for word in word_freqs:
    alphabet.add(word[0])
    alphabet.update(f"##{letter}" for letter in word[1:])
alphabet = sorted(alphabet)

# Include special tokens
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + alphabet

In [8]:
# Show the sorted base alphabet (prefixed and unprefixed characters).
print(alphabet)

['##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y']


In [9]:
# Initial vocabulary: the five special tokens followed by the alphabet.
print(vocab)

['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y']


## Split words

In [10]:
# Start every word as a sequence of single characters: the first character is
# bare, all following ones carry the "##" continuation prefix.
splits = {}
for word in word_freqs:
    splits[word] = list(word[:1]) + [f"##{c}" for c in word[1:]]

In [11]:
# One (word, character-level split) pair per line.
print(*splits.items(), sep='\n')

('This', ['T', '##h', '##i', '##s'])
('is', ['i', '##s'])
('the', ['t', '##h', '##e'])
('Hugging', ['H', '##u', '##g', '##g', '##i', '##n', '##g'])
('Face', ['F', '##a', '##c', '##e'])
('Course', ['C', '##o', '##u', '##r', '##s', '##e'])
('.', ['.'])
('chapter', ['c', '##h', '##a', '##p', '##t', '##e', '##r'])
('about', ['a', '##b', '##o', '##u', '##t'])
('tokenization', ['t', '##o', '##k', '##e', '##n', '##i', '##z', '##a', '##t', '##i', '##o', '##n'])
('section', ['s', '##e', '##c', '##t', '##i', '##o', '##n'])
('shows', ['s', '##h', '##o', '##w', '##s'])
('several', ['s', '##e', '##v', '##e', '##r', '##a', '##l'])
('tokenizer', ['t', '##o', '##k', '##e', '##n', '##i', '##z', '##e', '##r'])
('algorithms', ['a', '##l', '##g', '##o', '##r', '##i', '##t', '##h', '##m', '##s'])
('Hopefully', ['H', '##o', '##p', '##e', '##f', '##u', '##l', '##l', '##y'])
(',', [','])
('you', ['y', '##o', '##u'])
('will', ['w', '##i', '##l', '##l'])
('be', ['b', '##e'])
('able', ['a', '##b', '##l', '##e'])

## Compute pair scores

In [12]:
def compute_pair_scores(splits, word_freqs):
    """Score every adjacent token pair found in the current splits.

    Args:
        splits: mapping word -> list of tokens the word is currently split into.
        word_freqs: mapping word -> number of occurrences of that word.

    Returns:
        dict mapping (left, right) -> pair_count / (count(left) * count(right)),
        where all counts are weighted by the word frequencies. Pairs appear in
        the order in which they are first encountered.
    """
    token_counts = defaultdict(int)
    adjacent_counts = defaultdict(int)

    for word, freq in word_freqs.items():
        pieces = splits[word]
        # Every token but the last is the left member of exactly one pair.
        for left, right in zip(pieces, pieces[1:]):
            token_counts[left] += freq
            adjacent_counts[(left, right)] += freq
        # The last token (or the only one) still counts as an occurrence.
        token_counts[pieces[-1]] += freq

    scores = {}
    for (pair_a, pair_b), freq in adjacent_counts.items():
        scores[(pair_a, pair_b)] = freq / (token_counts[pair_a] * token_counts[pair_b])
    return scores

In [13]:
# Scores for the initial, character-level splits.
pair_scores = compute_pair_scores(splits, word_freqs)

In [14]:
# Peek at the first six (pair, score) entries.
for pair, score in list(pair_scores.items())[:6]:
    print(pair, score)

('T', '##h') 0.125
('##h', '##i') 0.03409090909090909
('##i', '##s') 0.02727272727272727
('i', '##s') 0.1
('t', '##h') 0.03571428571428571
('##h', '##e') 0.011904761904761904


## Find the highest-scoring pair

In [15]:
def max_pair_score(pair_scores):
    """Return the best-scoring pair together with its score.

    Args:
        pair_scores: mapping (left, right) -> score.

    Returns:
        (pair, score) for the highest score; on ties the pair that comes first
        in the mapping wins. For an empty mapping the sentinel ("", None) is
        returned, so callers must check for it before unpacking the pair.
    """
    if not pair_scores:
        return "", None
    return max(pair_scores.items(), key=lambda item: item[1])

In [16]:
# Best pair to merge first, given the character-level splits.
max_pair_score(pair_scores)

(('a', '##b'), 0.2)

## Merge pairs

In [17]:
def merge_pair(a, b, splits, word_freqs):
    """Merge every adjacent occurrence of the pair (a, b) in the splits, in place.

    Args:
        a: left token of the pair.
        b: right token of the pair; a leading "##" is dropped when it is glued
            onto `a`.
        splits: mapping word -> list of tokens; the entry of every word listed
            in `word_freqs` is rewritten.
        word_freqs: mapping (or iterable) of the words to process.

    Returns:
        None. `splits` is modified in place.
    """
    # The merged token depends only on (a, b), so build it once.
    merged = a + b[2:] if b.startswith("##") else a + b

    for word in word_freqs:
        split = splits[word]
        if len(split) == 1:
            continue
        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i + 1] == b:
                # Do not advance: the merged token is re-examined at index i.
                split = split[:i] + [merged] + split[i + 2:]
            else:
                i += 1
        # Store the result once per word rather than on every loop iteration.
        splits[word] = split

In [18]:
# Work on a copy so the demo merge below leaves `splits` untouched.
splits_copy = deepcopy(splits)

In [19]:
# Apply the best-scoring pair found above to the copy.
merge_pair("a", "##b", splits_copy, word_freqs)

In [20]:
# "about" should now begin with the merged token.
splits_copy["about"]

['ab', '##o', '##u', '##t']

## Apply merges until the desired vocabulary size is reached

In [22]:
# (Re)initialise the training state consumed by the merge loop in the next cell.
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + alphabet.copy()
vocab_size = 70  # target len(vocab); the special tokens count towards it
merges = {}  # NOTE(review): not referenced by any other visible cell
splits_copy = deepcopy(splits)  # fresh copy, discarding the demo merge above

In [24]:
# Grow the vocabulary by one merged token per iteration until it reaches
# `vocab_size`.
while len(vocab) < vocab_size:
    scores = compute_pair_scores(splits_copy, word_freqs)
    if not scores:
        # Every word is already a single token, so there is nothing left to
        # merge. Without this guard max_pair_score returns ("", None) and the
        # tuple unpacking below raises ValueError.
        break
    (pair_a, pair_b), max_score = max_pair_score(scores)
    merge_pair(pair_a, pair_b, splits_copy, word_freqs)
    # Same gluing rule as merge_pair: drop the "##" of the right-hand token.
    new_token = (
        pair_a + pair_b[2:] if pair_b.startswith("##")
        else pair_a + pair_b
    )
    vocab.append(new_token)

In [25]:
# Final vocabulary: special tokens, alphabet, then the merged tokens in the
# order they were learned.
print(vocab)

['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y', 'ab', '##fu', 'Fa', 'Fac', '##ct', '##ful', '##full', '##fully', 'Th', 'ch', '##hm', 'cha', 'chap', 'chapt', '##thm', 'Hu', 'Hug', 'Hugg', 'sh', 'th', 'is', '##thms', '##za', '##zat', '##ut']


# Tokenization

In [31]:
def encode_word(word, vocab):
    """Split one word into vocabulary tokens by greedy longest-prefix matching.

    Args:
        word: the word to encode (already pre-tokenized).
        vocab: container of known tokens; tokens that continue a word are
            stored with a leading "##".

    Returns:
        The list of matched tokens, or ["[UNK]"] as soon as some part of the
        word cannot be matched. An empty word yields an empty list.
    """
    tokens = []
    start = 0
    while start < len(word):
        # Anything after the first token is a continuation piece.
        prefix = "##" if start > 0 else ""
        end = len(word)
        # Shrink the candidate until it is a known token. The prefix is added
        # to each candidate instead of being glued onto the remaining string,
        # so a vocabulary entry such as "#" can never match a fragment of the
        # prefix itself (which would make the remainder grow forever).
        while end > start and prefix + word[start:end] not in vocab:
            end -= 1
        if end == start:
            return ["[UNK]"]
        tokens.append(prefix + word[start:end])
        start = end
    return tokens

In [32]:
# A word covered by the vocabulary, then one containing an unseen character
# (no "##O" token exists), which collapses to [UNK].
print(encode_word("Hugging", vocab))
print(encode_word("HOgging", vocab))

['Hugg', '##i', '##n', '##g']
['[UNK]']


In [33]:
def tokenize(text):
    """Tokenize raw text with the vocabulary trained in this notebook.

    The text is split into words by the pretrained tokenizer's pre-tokenizer,
    each word is encoded with `encode_word` against the module-level `vocab`,
    and the per-word token lists are concatenated in order.

    Args:
        text: the string to tokenize.

    Returns:
        A flat list of tokens.
    """
    # Use the public `backend_tokenizer` accessor, as the word-frequency cell
    # does, rather than the private `_tokenizer` attribute.
    pre_tokenize_result = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    # Flatten in a single pass; sum(list_of_lists, []) copies the accumulated
    # list on every addition.
    return [
        token
        for word, _ in pre_tokenize_result
        for token in encode_word(word, vocab)
    ]

In [35]:
# "!" never occurred in the training corpus, so it is encoded as [UNK].
print(tokenize("This is the Hugging Face course!"))

['Th', '##i', '##s', 'is', 'th', '##e', 'Hugg', '##i', '##n', '##g', 'Fac', '##e', 'c', '##o', '##u', '##r', '##s', '##e', '[UNK]']
