1. 准备语料，初始化 vocab

In [43]:
# The toy training corpus: a single English sentence. Every character in it
# becomes a base vocabulary entry, and BPE merges are learned from it.
corpus = (
    "Transformers: the model-definition framework for state-of-the-art "
    "machine learning models in text, vision, audio, and multimodal models, "
    "for both inference and training."
)

In [44]:
# Base vocabulary: the special end-of-text token at id 0, followed by every
# distinct character of the corpus in sorted (code-point) order.
vocab = ["<|endoftext|>", *sorted(set(corpus))]
vocab

['<|endoftext|>',
 ' ',
 ',',
 '-',
 '.',
 ':',
 'T',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'k',
 'l',
 'm',
 'n',
 'o',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x']

2. 计算词频（便于后续计算）：先用 GPT-2 的 pre-tokenizer 对语料做预分词（pre-tokenize）

In [45]:
import regex as re
# GPT-2 原始 pre-tokenizer 正则
gpt2_regex = re.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
def pre_tokenize_str(text: str) -> list[str]:
    """Split ``text`` into GPT-2 style pre-tokens.

    Every non-overlapping match of ``gpt2_regex`` becomes one pre-token, in
    order of appearance; a single leading space stays attached to the word
    that follows it (e.g. ``' the'``).
    """
    return gpt2_regex.findall(text)

print(pre_tokenize_str(corpus))
# Expected:
# ['Transformers', ':', ' the', ' model', '-', 'definition', ' framework', ' for', ' state', '-', 'of', '-', 'the', '-', 'art', ' machine', ' learning', ' models', ' in', ' text', ',', ' vision', ',', ' audio', ',', ' and', ' multimodal', ' models', ',', ' for', ' both', ' inference', ' and', ' training', '.']

['Transformers', ':', ' the', ' model', '-', 'definition', ' framework', ' for', ' state', '-', 'of', '-', 'the', '-', 'art', ' machine', ' learning', ' models', ' in', ' text', ',', ' vision', ',', ' audio', ',', ' and', ' multimodal', ' models', ',', ' for', ' both', ' inference', ' and', ' training', '.']


In [46]:
from collections import defaultdict

def get_word_freqs(text: str):
    """Count how often each pre-token occurs in ``text``.

    Each pre-token (from ``pre_tokenize_str``) is used as a key in the form
    of a tuple of its single characters, which is the starting segmentation
    for BPE merging. Values are occurrence counts.

    Returns:
        defaultdict(int) keyed in first-occurrence order.
    """
    counts = defaultdict(int)
    for pre_token in pre_tokenize_str(text):
        counts[tuple(pre_token)] += 1
    return counts

print(get_word_freqs(corpus))

defaultdict(<class 'int'>, {('T', 'r', 'a', 'n', 's', 'f', 'o', 'r', 'm', 'e', 'r', 's'): 1, (':',): 1, (' ', 't', 'h', 'e'): 1, (' ', 'm', 'o', 'd', 'e', 'l'): 1, ('-',): 4, ('d', 'e', 'f', 'i', 'n', 'i', 't', 'i', 'o', 'n'): 1, (' ', 'f', 'r', 'a', 'm', 'e', 'w', 'o', 'r', 'k'): 1, (' ', 'f', 'o', 'r'): 2, (' ', 's', 't', 'a', 't', 'e'): 1, ('o', 'f'): 1, ('t', 'h', 'e'): 1, ('a', 'r', 't'): 1, (' ', 'm', 'a', 'c', 'h', 'i', 'n', 'e'): 1, (' ', 'l', 'e', 'a', 'r', 'n', 'i', 'n', 'g'): 1, (' ', 'm', 'o', 'd', 'e', 'l', 's'): 2, (' ', 'i', 'n'): 1, (' ', 't', 'e', 'x', 't'): 1, (',',): 4, (' ', 'v', 'i', 's', 'i', 'o', 'n'): 1, (' ', 'a', 'u', 'd', 'i', 'o'): 1, (' ', 'a', 'n', 'd'): 2, (' ', 'm', 'u', 'l', 't', 'i', 'm', 'o', 'd', 'a', 'l'): 1, (' ', 'b', 'o', 't', 'h'): 1, (' ', 'i', 'n', 'f', 'e', 'r', 'e', 'n', 'c', 'e'): 1, (' ', 't', 'r', 'a', 'i', 'n', 'i', 'n', 'g'): 1, ('.',): 1})


3. 计算相邻 pair 的频率

In [47]:
def compute_pair_freqs(word_freqs):
    """Count adjacent symbol pairs, weighted by word frequency.

    Args:
        word_freqs: mapping from a word (tuple of symbols) to its count.

    Returns:
        defaultdict(int) mapping each ``(left, right)`` pair of adjacent
        symbols to the sum of the counts of the words it occurs in, once per
        occurrence. Keys appear in first-seen order, which downstream
        tie-breaking relies on. Words with fewer than two symbols contribute
        nothing.
    """
    counts = defaultdict(int)
    for symbols, count in word_freqs.items():
        # zip over the word and its one-step shift yields every adjacent pair;
        # it is empty for 0- or 1-symbol words.
        for left, right in zip(symbols, symbols[1:]):
            counts[left, right] += count
    return counts
# Initial adjacent-pair counts over the whole corpus (shown as cell output).
compute_pair_freqs(word_freqs=get_word_freqs(text=corpus))

defaultdict(int,
            {('T', 'r'): 1,
             ('r', 'a'): 3,
             ('a', 'n'): 3,
             ('n', 's'): 1,
             ('s', 'f'): 1,
             ('f', 'o'): 3,
             ('o', 'r'): 4,
             ('r', 'm'): 1,
             ('m', 'e'): 2,
             ('e', 'r'): 2,
             ('r', 's'): 1,
             (' ', 't'): 3,
             ('t', 'h'): 3,
             ('h', 'e'): 2,
             (' ', 'm'): 5,
             ('m', 'o'): 4,
             ('o', 'd'): 4,
             ('d', 'e'): 4,
             ('e', 'l'): 3,
             ('e', 'f'): 1,
             ('f', 'i'): 1,
             ('i', 'n'): 7,
             ('n', 'i'): 3,
             ('i', 't'): 1,
             ('t', 'i'): 2,
             ('i', 'o'): 3,
             ('o', 'n'): 2,
             (' ', 'f'): 3,
             ('f', 'r'): 1,
             ('a', 'm'): 1,
             ('e', 'w'): 1,
             ('w', 'o'): 1,
             ('r', 'k'): 1,
             (' ', 's'): 1,
             ('s', 't'): 1,
   

4. 找到频率最高的 pair 并合并

In [48]:
# Learned merge rules in the order they were learned, each a (left, right)
# pair of symbols. train_step appends to it; bpe_encode_word replays it in order.
merges: list[tuple[str, str]] = []

In [49]:
def merge_pair(a, b, word_freqs):
    """Apply the merge rule ``(a, b) -> a + b`` to every word.

    Each word (a tuple of symbols) is scanned left to right, and every
    non-overlapping adjacent occurrence of ``a`` followed by ``b`` is replaced
    by the single symbol ``a + b``. Counts are carried over unchanged.

    Returns:
        A new plain ``dict`` mapping the re-segmented words to their counts,
        in the same order; ``word_freqs`` itself is not modified.
    """
    result = {}
    for symbols, count in word_freqs.items():
        rebuilt = []
        pos = 0
        while pos < len(symbols):
            if pos + 1 < len(symbols) and symbols[pos] == a and symbols[pos + 1] == b:
                rebuilt.append(a + b)
                pos += 2  # consume both halves so matches never overlap
            else:
                rebuilt.append(symbols[pos])
                pos += 1
        result[tuple(rebuilt)] = count
    return result


In [50]:
def train_step(word_freqs, pair_freqs):
    """Run one BPE training step: learn the most frequent pair and merge it.

    Picks the pair with the highest count in ``pair_freqs`` (ties go to the
    pair that comes first in iteration order). It records that pair in the
    global ``merges`` list, appends the merged symbol to the global ``vocab``,
    and returns ``word_freqs`` with the pair merged in every word (via
    ``merge_pair``).

    Raises:
        ValueError: if ``pair_freqs`` is empty (every word is already a single
            symbol, so nothing is left to merge). This is checked before
            touching ``merges``/``vocab`` so no bogus rule is recorded.
            Previously a placeholder ``""`` was appended to ``merges`` before
            an IndexError, which later broke encoding.
    """
    if not pair_freqs:
        raise ValueError("no adjacent pairs left to merge")
    # max() returns the first of several equal maxima, matching a strict '<'
    # linear scan.
    best_pair = max(pair_freqs, key=pair_freqs.get)
    left, right = best_pair
    merges.append(best_pair)
    vocab.append(left + right)
    return merge_pair(left, right, word_freqs)

5. 设定 vocab_size，重复步骤 3、4，训练得到最终的 vocab

In [51]:
def train(vocab_size, word_freqs):
    """Learn BPE merges until the global ``vocab`` has ``vocab_size`` entries.

    Each iteration counts adjacent pairs (``compute_pair_freqs``) and merges
    the most frequent one (``train_step``). This grows the global ``vocab``
    and ``merges`` lists in place. Training stops early once no adjacent
    pairs remain (every word is a single symbol). A ``vocab_size`` larger
    than the corpus can support therefore ends training instead of crashing.

    Note: the training state lives in the globals ``vocab`` and ``merges``.
    Re-running this without first resetting both continues from the previous
    state, and does nothing if ``vocab`` is already large enough.
    """
    while len(vocab) < vocab_size:
        pair_freqs = compute_pair_freqs(word_freqs)
        if not pair_freqs:
            # Corpus fully merged: nothing left to learn.
            break
        word_freqs = train_step(word_freqs, pair_freqs)

# 80 = target vocabulary size for this toy corpus.
train(80, get_word_freqs(corpus))

6. 使用训练好的 merges 和 vocab 对文本进行 tokenize

In [52]:
def bpe_encode_word(word, merges):
    """Segment a single pre-token into BPE symbols.

    Starts from the individual characters of ``word`` and replays every rule
    in ``merges`` in the order it was learned. Each rule replaces, left to
    right, every non-overlapping adjacent occurrence of its two symbols with
    their concatenation.

    Returns:
        list of symbols; for string inputs their concatenation equals ``word``.
    """
    symbols = list(word)
    for left, right in merges:
        out = []
        pos = 0
        while pos < len(symbols):
            pair_here = (
                pos + 1 < len(symbols)
                and symbols[pos] == left
                and symbols[pos + 1] == right
            )
            if pair_here:
                out.append(left + right)
                pos += 2
            else:
                out.append(symbols[pos])
                pos += 1
        symbols = out
    return symbols

In [53]:
def tokenize(text: str):
    """Split ``text`` into BPE token strings.

    Pre-tokenizes with ``pre_tokenize_str``, then BPE-encodes each pre-token
    independently with the learned global ``merges``. Merges therefore never
    cross pre-token boundaries. The per-pre-token results are flattened in order.
    """
    return [
        symbol
        for pre_token in pre_tokenize_str(text)
        for symbol in bpe_encode_word(pre_token, merges)
    ]
tokenize(corpus)

['Transformers',
 ':',
 ' the',
 ' model',
 '-',
 'definition',
 ' framework',
 ' for',
 ' state',
 '-',
 'of',
 '-',
 'the',
 '-',
 'art',
 ' mac',
 'h',
 'in',
 'e',
 ' ',
 'l',
 'e',
 'ar',
 'n',
 'ing',
 ' models',
 ' in',
 ' t',
 'e',
 'x',
 't',
 ',',
 ' ',
 'v',
 'i',
 's',
 'ion',
 ',',
 ' a',
 'u',
 'd',
 'io',
 ',',
 ' and',
 ' m',
 'u',
 'l',
 't',
 'i',
 'm',
 'od',
 'a',
 'l',
 ' models',
 ',',
 ' for',
 ' ',
 'b',
 'o',
 't',
 'h',
 ' in',
 'f',
 'e',
 'r',
 'e',
 'n',
 'c',
 'e',
 ' and',
 ' t',
 'ra',
 'in',
 'ing',
 '.']

In [54]:
def encode(text: str):
    """Full encoding pipeline: text -> BPE tokens -> token ids.

    Each token from ``tokenize`` is mapped to its position in the global
    ``vocab``. If the same string appears more than once in ``vocab``, its
    first position is used (the same id ``vocab.index`` would return).

    Raises:
        ValueError: if a token is not in ``vocab``, e.g. a character that
            never appeared in the training corpus.
    """
    # Build the lookup once per call instead of scanning vocab linearly for
    # every token; setdefault keeps the first id of any duplicated string.
    token_to_id = {}
    for idx, vocab_token in enumerate(vocab):
        token_to_id.setdefault(vocab_token, idx)

    token_ids = []
    for token in tokenize(text):
        if token not in token_to_id:
            raise ValueError(f"token {token!r} is not in vocab")
        token_ids.append(token_to_id[token])
    return token_ids

print(encode(corpus))

[56, 5, 57, 37, 3, 63, 69, 43, 74, 3, 75, 3, 76, 3, 77, 79, 14, 28, 11, 1, 17, 11, 44, 19, 45, 46, 47, 34, 11, 27, 23, 2, 1, 25, 15, 22, 42, 2, 39, 24, 10, 38, 2, 49, 29, 24, 17, 23, 15, 18, 31, 7, 17, 46, 2, 43, 1, 8, 20, 23, 14, 47, 12, 11, 21, 11, 19, 9, 11, 49, 34, 32, 28, 45, 4]


In [None]:
def decode(token_ids: list[int]):
    """Convert token ids back to text by concatenating their ``vocab`` strings.

    Raises:
        IndexError: if an id is outside ``0 .. len(vocab) - 1``. Without this
            check a negative id would silently wrap around (Python negative
            indexing) and decode to the wrong token.
    """
    size = len(vocab)
    pieces = []
    for token_id in token_ids:
        if not 0 <= token_id < size:
            raise IndexError(
                f"token id {token_id} out of range for vocab of size {size}"
            )
        pieces.append(vocab[token_id])
    return "".join(pieces)

print(decode(encode(corpus)) == corpus)  # True

True
