## WordPiece
WordPiece 与 BPE 算法类似，但主要有以下两点不同：
1. 合并规则不同：BPE 每次选取当前所有 pair 中频率最高的 pair 进行合并；而 WordPiece 按如下公式为每个 pair 计算分数，每次合并分数最高的 pair（分母会惩罚那些组成部分本身就很常见的 pair）：
$$ score = \frac{freq\_of\_pair}{freq\_of\_first\_element \times freq\_of\_second\_element} $$
2. 表示方式不同：不在词首的子词带有 "##" 前缀，以表明它接在前一个子词之后，如 'hugging' -> ("hu", "##gging")

## Demo

### 一、准备语料，计算word_freqs,初始化vocab

In [388]:
# Toy training corpus (four English sentences) used by every step of the demo.
corpus = [
    "This is the Hugging Face Course.",
    "This chapter is about tokenization.",
    "This section shows several tokenizer algorithms.",
    "Hopefully, you will be able to understand how they are trained and generate tokens.",
]

首先使用pre_tokenize处理字符串，转为word_list

In [389]:
import regex as re
from typing import List

def pre_tokenize(text: str, do_lower_case: bool = False) -> List[str]:
    """Split raw text into words and single punctuation / CJK characters.

    Steps:
      1. optionally lower-case the text;
      2. delete control characters, except the ASCII whitespace controls
         U+0009..U+000D (tab, LF, VT, FF, CR), which are kept so that they
         still separate words in step 4;
      3. surround every Han / Hiragana / Katakana / Hangul character and
         every Unicode punctuation character with spaces, so each one
         becomes its own token;
      4. split on runs of whitespace and drop empty strings.

    Args:
        text: the raw input string.
        do_lower_case: lower-case the text before splitting.

    Returns:
        The pre-tokens in their original order.
    """
    if do_lower_case:
        text = text.lower()
    # Drop C0 control characters and DEL. Tab/newline/CR (and VT/FF) are
    # deliberately NOT deleted: removing them would glue neighbouring words
    # together ("a\nb" -> "ab") instead of splitting them.
    text = re.sub(r"[\u0000-\u0008\u000E-\u001F\u007F]", "", text)
    # Put spaces around CJK characters so each is a separate token.
    text = re.sub(r"([\p{Han}\p{Hiragana}\p{Katakana}\p{Hangul}])", r" \1 ", text)
    # Put spaces around Unicode punctuation so each mark is a separate token.
    text = re.sub(r"([\p{P}])", r" \1 ", text)
    # Split on whitespace (including the \t\n\r\v\f kept above).
    tokens = re.split(r"\s+", text.strip())
    return [t for t in tokens if t]

# Demo: pre-tokenize the fourth corpus sentence (the result is shown below).
pre_tokenize(corpus[3])

['Hopefully',
 ',',
 'you',
 'will',
 'be',
 'able',
 'to',
 'understand',
 'how',
 'they',
 'are',
 'trained',
 'and',
 'generate',
 'tokens',
 '.']

计算word_freqs

In [390]:
from collections import defaultdict
def word_split(word):
    """Split *word* into characters, prefixing every non-initial one with "##".

    Example: "hug" -> ["h", "##u", "##g"]. An empty string raises IndexError.
    """
    first, rest = word[0], word[1:]
    return [first, *("##" + ch for ch in rest)]

def get_word_freqs(corpus):
    """Count the occurrences of every word in *corpus*.

    Each text is pre-tokenized with ``pre_tokenize`` (default options). Each
    word is keyed by its initial character-level segmentation from
    ``word_split``, as a tuple, e.g. "is" -> ("i", "##s").

    Returns:
        A ``defaultdict(int)`` mapping segmentation tuple -> count, in
        first-seen order.
    """
    counts = defaultdict(int)
    all_words = (w for sentence in corpus for w in pre_tokenize(sentence))
    for w in all_words:
        key = tuple(word_split(w))
        counts[key] = counts[key] + 1
    return counts

# Demo: show each word's initial segmentation together with its count.
print(get_word_freqs(corpus))

defaultdict(<class 'int'>, {('T', '##h', '##i', '##s'): 3, ('i', '##s'): 2, ('t', '##h', '##e'): 1, ('H', '##u', '##g', '##g', '##i', '##n', '##g'): 1, ('F', '##a', '##c', '##e'): 1, ('C', '##o', '##u', '##r', '##s', '##e'): 1, ('.',): 4, ('c', '##h', '##a', '##p', '##t', '##e', '##r'): 1, ('a', '##b', '##o', '##u', '##t'): 1, ('t', '##o', '##k', '##e', '##n', '##i', '##z', '##a', '##t', '##i', '##o', '##n'): 1, ('s', '##e', '##c', '##t', '##i', '##o', '##n'): 1, ('s', '##h', '##o', '##w', '##s'): 1, ('s', '##e', '##v', '##e', '##r', '##a', '##l'): 1, ('t', '##o', '##k', '##e', '##n', '##i', '##z', '##e', '##r'): 1, ('a', '##l', '##g', '##o', '##r', '##i', '##t', '##h', '##m', '##s'): 1, ('H', '##o', '##p', '##e', '##f', '##u', '##l', '##l', '##y'): 1, (',',): 1, ('y', '##o', '##u'): 1, ('w', '##i', '##l', '##l'): 1, ('b', '##e'): 1, ('a', '##b', '##l', '##e'): 1, ('t', '##o'): 1, ('u', '##n', '##d', '##e', '##r', '##s', '##t', '##a', '##n', '##d'): 1, ('h', '##o', '##w'): 1, ('t', '##

 初始化vocab

In [391]:
# Immutable default so calls can never share (and mutate) one list object.
DEFAULT_SPECIAL_TOKENS = ("[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]")


def get_vocab(corpus: List[str], special_tokens=DEFAULT_SPECIAL_TOKENS) -> List[str]:
    """Build the initial WordPiece vocabulary for *corpus*.

    Args:
        corpus: the training texts.
        special_tokens: sequence of special tokens placed at the front of the
            vocabulary, in the given order.

    Returns:
        A new list: the special tokens, followed by every distinct token of the
        initial character-level segmentation (from ``get_word_freqs``) in
        sorted order.
    """
    alphabet = set()  # set membership instead of O(n) list scans
    for word in get_word_freqs(corpus):
        alphabet.update(word)
    return list(special_tokens) + sorted(alphabet)
# Demo: show the initial vocabulary (special tokens + sorted alphabet).
print(get_vocab(corpus))    

['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y']


### 二、 计算分数

In [392]:

def compute_pair_scores(word_freqs):
    """Score every adjacent token pair with the WordPiece merge criterion.

        score(a, b) = freq(a, b) / (freq(a) * freq(b))

    All frequencies are weighted by the word counts in *word_freqs*. Every
    token occurrence counts toward freq(a), including single-token words.

    Args:
        word_freqs: mapping from a word's current segmentation (tuple of
            tokens) to that word's count.

    Returns:
        dict mapping (first, second) pairs to their score, in the order the
        pairs are first encountered. Single-token and empty segmentations
        contribute no pairs (an empty one previously raised IndexError).
    """
    letter_freqs = defaultdict(int)
    pair_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        for token in word:
            letter_freqs[token] += freq
        # zip(word, word[1:]) yields each adjacent pair; empty for len < 2.
        for pair in zip(word, word[1:]):
            pair_freqs[pair] += freq
    return {
        pair: freq / (letter_freqs[pair[0]] * letter_freqs[pair[1]])
        for pair, freq in pair_freqs.items()
    }
# Demo: pair scores for the initial segmentation of the corpus.
compute_pair_scores(get_word_freqs(corpus))

{('T', '##h'): 0.125,
 ('##h', '##i'): 0.03409090909090909,
 ('##i', '##s'): 0.02727272727272727,
 ('i', '##s'): 0.1,
 ('t', '##h'): 0.03571428571428571,
 ('##h', '##e'): 0.011904761904761904,
 ('H', '##u'): 0.1,
 ('##u', '##g'): 0.05,
 ('##g', '##g'): 0.0625,
 ('##g', '##i'): 0.022727272727272728,
 ('##i', '##n'): 0.01652892561983471,
 ('##n', '##g'): 0.022727272727272728,
 ('F', '##a'): 0.14285714285714285,
 ('##a', '##c'): 0.07142857142857142,
 ('##c', '##e'): 0.023809523809523808,
 ('C', '##o'): 0.07692307692307693,
 ('##o', '##u'): 0.046153846153846156,
 ('##u', '##r'): 0.022222222222222223,
 ('##r', '##s'): 0.022222222222222223,
 ('##s', '##e'): 0.004761904761904762,
 ('c', '##h'): 0.125,
 ('##h', '##a'): 0.017857142857142856,
 ('##a', '##p'): 0.07142857142857142,
 ('##p', '##t'): 0.07142857142857142,
 ('##t', '##e'): 0.013605442176870748,
 ('##e', '##r'): 0.026455026455026454,
 ('a', '##b'): 0.2,
 ('##b', '##o'): 0.038461538461538464,
 ('##u', '##t'): 0.02857142857142857,
 ('t',

### 三、合并score最高的pair

In [393]:
def get_bestpair(pair_scores: dict):
    """Return ``(pair, score)`` for the highest-scoring entry of *pair_scores*.

    Ties go to the pair that comes first in iteration order. An empty mapping
    yields ``("", None)``.
    """
    if not pair_scores:
        return "", None
    # max() returns the first maximal item, matching a strict-"<" scan.
    best_pair, max_score = max(pair_scores.items(), key=lambda item: item[1])
    return best_pair, max_score
# Demo: the first pair WordPiece would merge on this corpus.
get_bestpair(compute_pair_scores(get_word_freqs(corpus)))

(('a', '##b'), 0.2)

In [394]:
def merge_pair(a, b, word_freqs):
    """Merge every adjacent occurrence of ``(a, b)`` in each segmentation.

    The merged token is ``a`` followed by ``b`` with its leading "##" removed
    (e.g. "a" + "##b" -> "ab", "##b" + "##c" -> "##bc"). Occurrences are
    merged left to right without overlap.

    Args:
        a: first token of the pair.
        b: second token of the pair.
        word_freqs: mapping segmentation tuple -> count.

    Returns:
        A new ``defaultdict(int)`` with the re-segmented words. Counts of
        segmentations that become identical after merging are summed rather
        than overwritten.
    """
    merged = a + b[2:] if b.startswith("##") else a + b  # loop-invariant
    new_word_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        if len(word) == 1:
            # Nothing to merge in a single-token word.
            new_word_freqs[word] += freq
            continue
        new_word = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == a and word[i + 1] == b:
                new_word.append(merged)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        # "+=" so colliding segmentations keep their combined count.
        new_word_freqs[tuple(new_word)] += freq
    return new_word_freqs
# Demo: apply the best first merge ('a', '##b') to the corpus segmentation.
merge_pair('a',"##b",get_word_freqs(corpus))

defaultdict(int,
            {('T', '##h', '##i', '##s'): 3,
             ('i', '##s'): 2,
             ('t', '##h', '##e'): 1,
             ('H', '##u', '##g', '##g', '##i', '##n', '##g'): 1,
             ('F', '##a', '##c', '##e'): 1,
             ('C', '##o', '##u', '##r', '##s', '##e'): 1,
             ('.',): 4,
             ('c', '##h', '##a', '##p', '##t', '##e', '##r'): 1,
             ('ab', '##o', '##u', '##t'): 1,
             ('t',
              '##o',
              '##k',
              '##e',
              '##n',
              '##i',
              '##z',
              '##a',
              '##t',
              '##i',
              '##o',
              '##n'): 1,
             ('s', '##e', '##c', '##t', '##i', '##o', '##n'): 1,
             ('s', '##h', '##o', '##w', '##s'): 1,
             ('s', '##e', '##v', '##e', '##r', '##a', '##l'): 1,
             ('t', '##o', '##k', '##e', '##n', '##i', '##z', '##e', '##r'): 1,
             ('a',
              '##l',
              '##

### 四、训练：设定 vocab_size，重复步骤二、三，直至词表大小达到 vocab_size

In [395]:
def train(vocab_size, corpus):
    """Train a WordPiece vocabulary on *corpus*.

    Starts from the initial vocabulary of ``get_vocab``. It then repeatedly
    merges the highest-scoring adjacent pair (``compute_pair_scores`` +
    ``get_bestpair``) until the vocabulary has ``vocab_size`` entries or no
    adjacent pair is left to merge.

    Args:
        vocab_size: target vocabulary size (special tokens included).
        corpus: the training texts.

    Returns:
        ``(vocab, merges)``: the vocabulary list, with merged tokens appended
        in creation order, and the merged pairs in the order they were applied.
    """
    merges = []
    vocab = get_vocab(corpus)
    word_freqs = get_word_freqs(corpus)
    while len(vocab) < vocab_size:
        pair_scores = compute_pair_scores(word_freqs)
        if not pair_scores:
            # Every word is a single token: nothing left to merge. Without
            # this check get_bestpair returns ("", None) and unpacking the
            # empty string raised IndexError.
            break
        best_pair, _ = get_bestpair(pair_scores)
        a, b = best_pair
        new_token = a + b[2:] if b.startswith("##") else a + b
        # Different pairs can concatenate to the same string, e.g.
        # ("th", "##e") and ("t", "##he"); keep the vocabulary duplicate-free.
        if new_token not in vocab:
            vocab.append(new_token)
        merges.append(best_pair)
        word_freqs = merge_pair(a, b, word_freqs)
    return vocab, merges
# Train up to 70 vocabulary entries. The module-level ``vocab`` bound here is
# the vocabulary that ``encode_word`` below looks tokens up in.
vocab,merges = train(70,corpus)

### 五、使用训练好的merge,vocab进行encode和decode

In [396]:
def encode_word(word, vocabulary=None):
    """WordPiece-encode one word by greedy longest-prefix matching.

    At each step the longest prefix of the remaining text that is in the
    vocabulary becomes the next token. The rest is prefixed with "##" and
    processed the same way.

    Args:
        word: a single pre-tokenized word.
        vocabulary: iterable of known tokens. Defaults to the module-level
            ``vocab`` produced by ``train``, which keeps old calls working.

    Returns:
        The list of tokens. If some remainder has no matching prefix, the
        whole word is encoded as ``["[UNK]"]``. An empty word gives ``[]``.
    """
    if vocabulary is None:
        vocabulary = vocab  # module-level vocabulary from train()
    known = set(vocabulary)  # O(1) lookups instead of scanning a list per prefix
    tokens = []
    while word:
        end = len(word)
        while end > 0 and word[:end] not in known:
            end -= 1
        if end == 0:
            return ["[UNK]"]
        tokens.append(word[:end])
        word = word[end:]
        if word:
            word = f"##{word}"
    return tokens
def tokenize(text):
    """Pre-tokenize *text*, then WordPiece-encode each word with ``encode_word``.

    Returns the flat list of tokens of all words, in order.
    """
    pieces = []
    for word in pre_tokenize(text):
        pieces.extend(encode_word(word))
    return pieces
# Demo: "!" never occurs in the training corpus, so it is encoded as [UNK].
print(tokenize("This is the Hugging Face course!"))

['Th', '##i', '##s', 'is', 'th', '##e', 'Hugg', '##i', '##n', '##g', 'Fac', '##e', 'c', '##o', '##u', '##r', '##s', '##e', '[UNK]']
