In [1]:
from typing import List
from collections import defaultdict, OrderedDict
import json

class BPETokenizer:
    """BPE (Byte-Pair-Encoding)分词器

    根据传入参数初始化分词器，设置最大迭代次数，并将原始语料库初始化为词表字典格式，如{'l o w </w>': 5, '}

    Args:
        max_iter (int): 最大迭代合并次数，也直接决定最终词表大小
        max_size (int): 最大词表大小
        corpus (List[str]): 原始语料库，可以是文章等。以字符串列表形式输入，当为None时表示已经有预训练的词表了。
        merge_file (str): txt格式的合并规则表路径，每行是一个合并字符对，优先级递减。
        vocab_file (str): json格式的词表路径
        unknown (str): 未知词的解码形式
        cased (bool): 是否区分大小写，默认为cased=False即不区分大小写
    """
    def __init__(
        self, 
        max_iter:int=None, 
        max_size:int=None,
        merge_file:str="merges.txt",
        vocab_file:str="vocab.json",
        corpus:List[str]=None, 
        unknown:str="<unk>",
        cased:bool=False
    ):
        self.max_iter = max_iter
        self.max_size = max_size
        self.merge_file = merge_file
        self.vocab_file = vocab_file
        self.unknown = unknown
        self.cased = cased
        
        if corpus != None:  # 如果存在原始语料库，说明需要进行训练
            self.vocab = {unknown: 0}
            self.bpe_ranks = None
            self.vocab_freqs = self.init_vocab_from_corpus(corpus)
            print(f"BPE分词器初始化完成，max_iter为{max_iter}，corpus长度为{len(corpus)}，暂无可用词表和合并规则表，请进行训练！")
        else:   # 否则直接加载
            with open(merge_file, "r", encoding="utf-8") as merges_hander:
                merges = merges_hander.read().split("\n")
            merges = [tuple(merge.split()) for merge in merges]
            self.bpe_ranks = OrderedDict(zip(merges, range(len(merges))))  # 有序的合并规则字典如{('sc', 'ssax'): 1}
            with open(vocab_file, "r", encoding="utf-8") as vocab_hander:
                self.vocab = json.load(vocab_hander)    # 字符到整数编码字典如{"like": 0}
            self.get_reverse_vocab()
            print(f"BPE分词器初始化完成，max_iter为{max_iter}，corpus为None，已有可用词表和合并规则表，请直接进行编解码应用！")
        
    def init_vocab_from_corpus(self, corpus):
        """从原始语料库初始化词表，包括空格分词、统计、添加尾部符号并进行内部字符划分四个步骤
        """
        # 1. 空格分词
        word_list = []
        for line in corpus:
            line = line.lower() if not self.cased else line
            for word in line.split():
                word_list.append(word)  
        
        # 2. 统计，添加尾部符号，并内部字符划分
        vocab_freqs = {}
        for item in set(word_list):
            new_item = " ".join(list(item)) + " </w>"
            vocab_freqs[new_item] = word_list.count(item)
        
        return vocab_freqs

    def get_stats(self):
        """统计全局词元对频数"""
        pairs = defaultdict(int)
        for word in self.vocab_freqs:
            subwords = word.split()
            for i in range(len(subwords)-1):
                pair = (subwords[i], subwords[i+1])
                pairs[pair] += self.vocab_freqs[word]   # 单词本身出现的次数
        return pairs

    def merges(self, pair):
        "合并最高频词元对并修改词频字典"
        new_dict = {}
        for word in self.vocab_freqs:
            if " ".join(pair) in word:
                new_word = word.replace( " ".join(pair), "".join(pair))
                new_dict[new_word] = self.vocab_freqs[word]
            else:
                new_dict[word] = self.vocab_freqs[word]
        self.vocab_freqs = new_dict
    
    def get_tokens(self):
        """从现有词频表中获取词表"""
        subword_list = []
        for word in self.vocab_freqs:
            subwords = word.split()
            for i in range(len(subwords)):
                subword_list.append(subwords[i])
        subword_list = set(subword_list)
        id = 1
        new_vocab = {self.unknown:0}
        for subword in subword_list:
            new_vocab[subword] = id
            id += 1
        self.vocab = new_vocab
    
    def get_reverse_vocab(self):
        """获取self.vocab的反向索引{0:'<unk>'}"""
        self.reverse_vocab = {}
        for token, id in self.vocab.items():
            self.reverse_vocab[id] = token
    
    def save(self):
        """保存两个文件"""
        with open(self.merge_file, "w", encoding="utf-8") as merge_hander:
            for pair in self.bpe_ranks.keys():
                merge_hander.write(" ".join(pair) + "\n")
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_hander:
            json.dump(self.vocab, vocab_hander, indent=4)

    def train(self):
        """训练器
        
        从初始化的词频表中迭代，每次合并最高频词元对并更新词频字典。
        词表的长度通常会经历先增大后减小的规律，即最开始pair出现的数量大于单个词元被替换的速度，
        例如sub和word组合为subword后sub和word还是单独在词表中；但后期有些固定组合出现例如be和st就会组合为best，且两者不单独出现。
        """
        merges= []

        iter_count = 0
        while True:
            # 0. 检查退出条件：迭代次数
            if self.max_iter and iter_count == self.max_iter:
                print(f"达到最大迭代次数{self.max_iter}，退出")
                break
            print('=' * 50)
            print('Iter: {}'.format(iter_count))
            
            # 1. 统计词元对频数
            pairs = self.get_stats()
            try:
                best_pair = max(pairs, key = pairs.get)
            except:
                print(f"不可再划分词元对，退出！")
                break
            
            # 2. 检查退出条件：是否最大频数为1
            if pairs[best_pair] == 1:
                print(f"最大频数为1，退出！")
                break

            # 3. 合并最高频词元对并修改词频字典
            merges.append(best_pair)
            self.merges(best_pair)
            self.get_tokens()
            
            # 4. 检查退出条件：最大词表大小
            vocab_size = len(self.vocab)
            if self.max_size and vocab_size == self.max_size:
                print(f"达到最大词表size{self.max_size}，退出")
                break
            
            iter_count += 1
            print('当前Iter出现频率最大的连续字符对为：{}'.format(best_pair))
            print('当前Iter新生成的词表为：{}'.format(self.vocab))
            print('当前Iter新生成的词表长度为: {}'.format(len(self.vocab)))


        self.bpe_ranks = OrderedDict(zip(merges, range(len(merges))))
        self.get_tokens()
        self.get_reverse_vocab()
        self.save()

    def encode(self, input):
        """编码器：基于合并规则进行分词"""
        # 1. 将输入先按空格分隔（单词末尾符号），再分隔为单字符
        input = input.lower() if not self.cased else input
        word_list = [" ".join(list(word)) + " </w>" for word in input.split()]

        # 2. 基于最高频合并规则的迭代分词 
        for i in range(len(word_list)):
            word = word_list[i]
            while True: 
                subword = word.split()
                pair_appear = {}
                for j in range(len(subword)-1):
                    pair = (subword[j], subword[j+1])
                    if pair in self.bpe_ranks:
                        pair_appear[pair] = self.bpe_ranks[pair]
                try:
                    best = max(pair_appear, key=pair_appear.get)
                    word = word.replace(" ".join(best), "".join(best))
                except:
                    break
            word_list[i] = word

        # 3. 将词元映射为id
        ret_list = []
        for word in word_list:
            subwords = word.split()
            for subword in subwords:
                if subword in self.vocab:
                    ret_list.append(self.vocab[subword])
                else:
                    ret_list.append(self.vocab["<unk>"])
        return ret_list

    def decode(self, token_ids):
        """解码器：将id映射为token并组装，如果token中含有</w>表示是单词末尾，需替换为空格，最后一个空格也要抹除"""
        ret_list = [self.reverse_vocab[id] for id in token_ids]
        ret_str = ""
        for word in ret_list:
            if "</w>" in word:
                ret_str += word.replace("</w>", " ")
            else:
                ret_str += word
        return ret_str[:-1]
        

In [2]:
if __name__ == "__main__":
    corpus = [
        "Machine learning has become very popular in recent years and many researchers are focusing on machine learning techniques.",
        "Artificial intelligence is a field that overlaps with machine learning and is often used in various machine learning applications.",
        "Deep learning, a subset of machine learning, has been shown to outperform traditional machine learning algorithms in many tasks.",
        "Data science relies heavily on machine learning algorithms to analyze and interpret complex data for making decisions.",
        "Reinforcement learning is a type of machine learning where an agent learns by interacting with an environment to maximize rewards.",
        "Supervised learning is one of the most common types of machine learning, where models are trained on labeled data.",
        "Unsupervised learning, unlike supervised learning, deals with data that is not labeled, often finding hidden patterns in data.",
        "In machine learning, feature engineering plays a crucial role in improving the performance of learning models and algorithms.",
        "The application of machine learning in healthcare has the potential to revolutionize the industry by improving diagnostic accuracy.",
        "Natural language processing, a subfield of artificial intelligence, often uses machine learning techniques to analyze and generate human language."
    ]
    tokenizer = BPETokenizer(max_iter=100, max_size=70, corpus=corpus, cased=True)

    tokenizer.train()

    input = "Machine learning has become very popular in recent years and many researchers are focusing on machine learning techniques."
    input_ids = tokenizer.encode(input)
    output = tokenizer.decode(input_ids)
    
    # 不传入corpus参数，直接从本地加载
    tokenizer_pre = BPETokenizer()

BPE分词器初始化完成，max_iter为100，corpus长度为10，暂无可用词表和合并规则表，请进行训练！
Iter: 0
当前Iter出现频率最大的连续字符对为：('i', 'n')
当前Iter新生成的词表为：{'<unk>': 0, 'm': 1, 'h': 2, 'n': 3, 'z': 4, 'w': 5, 'N': 6, 'v': 7, 'e': 8, 'g': 9, 'f': 10, 'A': 11, 'R': 12, 'c': 13, 'p': 14, 'o': 15, 'q': 16, 'r': 17, '.': 18, 'M': 19, 'b': 20, '</w>': 21, 'd': 22, 'T': 23, 'a': 24, 'l': 25, 't': 26, 'u': 27, 'k': 28, 'S': 29, 'i': 30, 'U': 31, 'I': 32, 'x': 33, 's': 34, 'D': 35, 'y': 36, ',': 37, 'in': 38}
当前Iter新生成的词表长度为: 39
Iter: 1
当前Iter出现频率最大的连续字符对为：('e', '</w>')
当前Iter新生成的词表为：{'<unk>': 0, 'm': 1, 'h': 2, 'n': 3, 'z': 4, 'w': 5, 'N': 6, 'e</w>': 7, 'v': 8, 'e': 9, 'g': 10, 'f': 11, 'A': 12, 'R': 13, 'c': 14, 'p': 15, 'o': 16, 'q': 17, 'r': 18, '.': 19, 'M': 20, 'b': 21, '</w>': 22, 'd': 23, 'T': 24, 'a': 25, 'l': 26, 't': 27, 'u': 28, 'k': 29, 'S': 30, 'i': 31, 'U': 32, 'I': 33, 'x': 34, 's': 35, 'D': 36, 'y': 37, ',': 38, 'in': 39}
当前Iter新生成的词表长度为: 40
Iter: 2
当前Iter出现频率最大的连续字符对为：('a', 'r')
当前Iter新生成的词表为：{'<unk>': 0, 'm': 1, 'h': 2,

## BBPE：Byte-level BPE

In [5]:
initial_vocab = [bytes([byte]) for byte in range(256)]
initial_vocab

[b'\x00',
 b'\x01',
 b'\x02',
 b'\x03',
 b'\x04',
 b'\x05',
 b'\x06',
 b'\x07',
 b'\x08',
 b'\t',
 b'\n',
 b'\x0b',
 b'\x0c',
 b'\r',
 b'\x0e',
 b'\x0f',
 b'\x10',
 b'\x11',
 b'\x12',
 b'\x13',
 b'\x14',
 b'\x15',
 b'\x16',
 b'\x17',
 b'\x18',
 b'\x19',
 b'\x1a',
 b'\x1b',
 b'\x1c',
 b'\x1d',
 b'\x1e',
 b'\x1f',
 b' ',
 b'!',
 b'"',
 b'#',
 b'$',
 b'%',
 b'&',
 b"'",
 b'(',
 b')',
 b'*',
 b'+',
 b',',
 b'-',
 b'.',
 b'/',
 b'0',
 b'1',
 b'2',
 b'3',
 b'4',
 b'5',
 b'6',
 b'7',
 b'8',
 b'9',
 b':',
 b';',
 b'<',
 b'=',
 b'>',
 b'?',
 b'@',
 b'A',
 b'B',
 b'C',
 b'D',
 b'E',
 b'F',
 b'G',
 b'H',
 b'I',
 b'J',
 b'K',
 b'L',
 b'M',
 b'N',
 b'O',
 b'P',
 b'Q',
 b'R',
 b'S',
 b'T',
 b'U',
 b'V',
 b'W',
 b'X',
 b'Y',
 b'Z',
 b'[',
 b'\\',
 b']',
 b'^',
 b'_',
 b'`',
 b'a',
 b'b',
 b'c',
 b'd',
 b'e',
 b'f',
 b'g',
 b'h',
 b'i',
 b'j',
 b'k',
 b'l',
 b'm',
 b'n',
 b'o',
 b'p',
 b'q',
 b'r',
 b's',
 b't',
 b'u',
 b'v',
 b'w',
 b'x',
 b'y',
 b'z',
 b'{',
 b'|',
 b'}',
 b'~',
 b'\x7f',
 b'\x80',
