In [104]:
from collections import defaultdict
from typing import List
from transformers import AutoTokenizer
# Pretrained cased BERT tokenizer loaded via transformers' AutoTokenizer.
# NOTE(review): in this notebook it is referenced only by the disabled alternative
# pre-tokenization path of WordPieceTokenizer.pre_states, and the name `tokenizer`
# is rebound to a WordPieceTokenizer instance in a later cell. Consider dropping this
# load if that alternative is not needed (it presumably hits the Hub when not cached).
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# Toy training corpus (four English sentences) for the WordPiece tokenizer below.
corpus = [
    "This is the Hugging Face Course.",
    "This chapter is about tokenization.",
    "This section shows several tokenizer algorithms.",
    "Hopefully, you will be able to understand how they are trained and generate tokens.",
]



In [138]:
class WordPieceTokenizer:
    """WordPiece tokenizer trained from scratch.

    Google never open-sourced its WordPiece training code; this implementation
    follows the published description and is adapted from
    https://huggingface.co/learn/nlp-course/zh-CN/chapter6/6

    Training repeatedly merges the adjacent token pair with the highest score
    ``freq(ab) / (freq(a) * freq(b))`` until a stop condition is met. Encoding
    uses greedy longest-prefix-first (maximum forward) matching.

    Args:
        max_iters (int): stop after this many merges; ``None`` means no limit.
        vocab_size (int): stop once the vocabulary has this many entries;
            ``None`` means no limit. With both limits ``None``, training runs
            until every corpus word is a single token.
        corpus (List[str]): raw training sentences; ``None`` means an empty corpus.
        punctuations (List[str]): strings split off the END of whitespace-separated
            words and treated as standalone tokens. Defaults to comma, period,
            double quote and single quote.
    """

    _DEFAULT_PUNCTUATIONS = (",", ".", "\"", "'")
    _SPECIAL_TOKENS = ("[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]")

    def __init__(
        self,
        max_iters:int=None,
        vocab_size:int=None,
        corpus:List[str]=None,
        punctuations:List[str]=None,
    ):
        self.max_iters = max_iters
        self.vocab_size = vocab_size
        # token -> id mapping; built by get_vocab_dict() after training (or lazily by encode()).
        self.vocab_dict = None
        # Fresh list per instance: never share a mutable default or alias the caller's list.
        self.punctuations = list(
            punctuations if punctuations is not None else self._DEFAULT_PUNCTUATIONS
        )

        # Word -> frequency over the whole corpus. Never updated afterwards; it is the
        # weight applied to every token/pair count during training.
        self.word_freqs = self.pre_states(corpus if corpus is not None else [])
        alphabet = self.get_alphabet(self.word_freqs)
        # id -> token list. Training only appends, so existing ids stay stable.
        self.vocab = list(self._SPECIAL_TOKENS) + alphabet
        # Word -> current segmentation; rewritten after every merge.
        self.splits = self.get_split_dict(self.word_freqs)

    def _split_trailing_punctuation(self, word):
        """Split trailing punctuation off ``word``.

        Returns ``(core, trailing)``: ``trailing`` lists the stripped punctuation
        strings in their original left-to-right order. Only the end of the word is
        touched, so interior marks (the first "." of "e.g.") are kept. ``core`` is
        empty when the word consists solely of punctuation.
        """
        trailing = []
        while word:
            for punc in self.punctuations:
                # Skip "" : word.endswith("") is always True and would never make progress.
                if punc and word.endswith(punc):
                    trailing.append(punc)
                    word = word[:-len(punc)]
                    break
            else:
                break
        trailing.reverse()
        return word, trailing

    def pre_states(self, corpus):
        """Pre-tokenize the corpus and count how often each word occurs.

        Words are whitespace-separated; trailing punctuation is counted as its own
        word. Punctuation is counted before the word it followed, which fixes the
        dict's insertion order (used for tie-breaking during training).
        Alternative: use the HF pre-tokenizer,
        ``tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(sentence)``,
        which returns ``(word, offsets)`` pairs.
        """
        word_freqs = defaultdict(int)
        for sentence in corpus:
            for word in sentence.split():
                core, trailing = self._split_trailing_punctuation(word)
                for punc in trailing:
                    word_freqs[punc] += 1
                if core:  # a bare punctuation word leaves nothing else to count
                    word_freqs[core] += 1
        return word_freqs

    def get_alphabet(self, word_freqs):
        """Return the initial symbol set: word-initial characters plus ``##``-prefixed continuation characters."""
        symbols = set()
        for word in word_freqs:
            if not word:
                continue
            symbols.add(word[0])
            symbols.update(f"##{ch}" for ch in word[1:])
        # Ascending sort only makes the initial vocab order deterministic; matching does not depend on it.
        return sorted(symbols)

    def get_split_dict(self, word_freqs):
        """Split every word into characters, prefixing all but the first with ``##``."""
        return {
            word: [ch if i == 0 else f"##{ch}" for i, ch in enumerate(word)]
            for word in word_freqs
        }

    def compute_PMI_scores(self):
        """Score every adjacent token pair across the corpus.

        Score = ``freq(pair) / (freq(first) * freq(second))``, with every count
        weighted by the frequency of the word it occurs in (a PMI-style ratio
        without the log). Favours pairs whose parts rarely appear apart.

        Returns:
            dict: ``(first, second) -> score`` in first-seen order; empty when no
            word has more than one token left.
        """
        token_freqs = defaultdict(int)
        pair_freqs = defaultdict(int)
        for word, freq in self.word_freqs.items():
            split = self.splits[word]
            for token in split:
                token_freqs[token] += freq
            for pair in zip(split, split[1:]):
                pair_freqs[pair] += freq
        return {
            pair: freq / (token_freqs[pair[0]] * token_freqs[pair[1]])
            for pair, freq in pair_freqs.items()
        }

    @staticmethod
    def _merge_tokens(first, second):
        """Concatenate two tokens, dropping the ``##`` continuation marker of the second."""
        return first + second[2:] if second.startswith("##") else first + second

    def merge_pair(self, best_pair):
        """Replace every adjacent occurrence of ``best_pair`` in ``self.splits`` with the merged token."""
        first, second = best_pair
        merged = self._merge_tokens(first, second)
        for word, split in self.splits.items():
            if len(split) < 2:
                continue
            i = 0
            while i < len(split) - 1:
                if split[i] == first and split[i + 1] == second:
                    # Do not advance: the merged token may pair with the next one.
                    split = split[:i] + [merged] + split[i + 2:]
                else:
                    i += 1
            self.splits[word] = split

    def train(self, verbose: bool = True):
        """Greedily merge the best-scoring pair until a stop condition is met.

        Stops when ``vocab_size`` is reached, after ``max_iters`` merges, or when no
        mergeable pair remains, whichever comes first. Then rebuilds ``vocab_dict``.

        Args:
            verbose: print progress (the current vocab) after every merge.
        """
        iters = 0
        while True:
            if self.vocab_size is not None and len(self.vocab) >= self.vocab_size:
                break
            if self.max_iters is not None and iters >= self.max_iters:
                break
            pair_scores = self.compute_PMI_scores()
            if not pair_scores:
                break  # every word is already a single token
            if verbose:
                print("#"*20)
                print(f"当前词表长度：{len(self.vocab)}")
            # max() keeps the first maximum, so ties go to the earliest-seen pair.
            best_pair = max(pair_scores, key=pair_scores.get)
            self.merge_pair(best_pair)
            new_token = self._merge_tokens(*best_pair)
            # Different pairs can spell the same string; keep vocab entries unique.
            if new_token not in self.vocab:
                self.vocab.append(new_token)
            iters += 1
            if verbose:
                print(f"当前新合并pair：{best_pair}")
                print(f"当前词表：{self.vocab}")
        self.get_vocab_dict()

    def get_vocab_dict(self):
        """Build the token -> id mapping (``vocab_dict``) from the id -> token list ``vocab``."""
        self.vocab_dict = {token: idx for idx, token in enumerate(self.vocab)}

    def encode_word(self, word):
        """Split one word by greedy longest-prefix-first (maximum forward) matching.

        Returns the token list, ``[]`` for an empty word, or ``["[UNK]"]`` if some
        remainder of the word has no matching prefix in the vocabulary.
        """
        known = self.vocab_dict if self.vocab_dict is not None else set(self.vocab)
        tokens = []
        while word:
            # After the first piece the remainder carries a "##" prefix; never match
            # that bare marker on its own, or the loop would make no progress.
            min_len = 2 if tokens else 0
            end = len(word)
            # Keep the left edge fixed and shrink the candidate from the right.
            while end > min_len and word[:end] not in known:
                end -= 1
            if end == min_len:
                return ["[UNK]"]
            tokens.append(word[:end])
            word = word[end:]
            if word:
                word = f"##{word}"
        return tokens

    def encode(self, sentence):
        """Tokenize a sentence and return token ids.

        Trailing punctuation is split off each word and emitted after it. Tokens
        missing from the vocabulary (e.g. punctuation never seen in training) map
        to the ``[UNK]`` id. Builds ``vocab_dict`` if it does not exist yet.
        """
        if self.vocab_dict is None:
            self.get_vocab_dict()
        unk_id = self.vocab_dict["[UNK]"]
        tokens = []
        for word in sentence.split():
            core, trailing = self._split_trailing_punctuation(word)
            tokens.extend(self.encode_word(core))
            tokens.extend(trailing)
        return [self.vocab_dict.get(token, unk_id) for token in tokens]

    def decode(self, input_ids):
        """Convert ids back into text.

        ``##`` continuation pieces are glued to the preceding piece. Configured
        punctuation tokens are attached to the preceding word, because encode() only
        emits trailing punctuation. Words are joined with single spaces.
        """
        words = []
        for token_id in input_ids:
            token = self.vocab[token_id]
            if token.startswith("##"):
                piece = token[2:]
                if words:
                    words[-1] += piece
                else:
                    words.append(piece)
            elif words and token in self.punctuations:
                words[-1] += token
            else:
                words.append(token)
        return " ".join(words)


In [139]:
# Train until the vocabulary reaches 70 entries (train() prints progress after each merge),
# then round-trip the first corpus sentence through encode/decode.
tokenizer = WordPieceTokenizer(vocab_size=70, corpus=corpus)
tokenizer.train()

# Named sample_text rather than `input` so the builtin input() is not shadowed.
sample_text = "This is the Hugging Face Course."
input_ids = tokenizer.encode(sample_text)
output = tokenizer.decode(input_ids)
# sample_text is corpus[0], so every character is in the trained alphabet and decoding must reproduce it.
assert output == sample_text, f"round-trip mismatch: {output!r}"
output

####################
当前词表长度：45
当前新合并pair：('a', '##b')
当前词表：['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y', 'ab']
####################
当前词表长度：46
当前新合并pair：('##f', '##u')
当前词表：['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y', 'ab', '##fu']
####################
当前词表长度：47
当前新合并pair：('F', '##a')
当前词表：['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.

'This is the Hugging Face Course.'

'This is the Hugging Face Course.'