# Byte Pair Encoding (BPE) Tokenizer

BPE is a common subword tokenization algorithm that is widely used in modern language models such as the GPT series.

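## Counting adjacent token pair frequencies

- Purpose: count how often each pair of adjacent tokens occurs in a token sequence
- Example input: `[1, 2, 3, 1, 2]`
- Output: `{(1, 2): 2, (2, 3): 1, (3, 1): 1}`
- Note: the pair (1, 2) occurs twice; every other pair occurs once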

In [1]:
def get_stats(ids, counts=None):
    # pass an existing dict as `counts` to accumulate over several sequences
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]): # iterate over consecutive token pairs
        counts[pair] = counts.get(pair, 0) + 1
    return counts

example = [1, 2, 3, 1, 2] # token id sequence
counts = get_stats(example)
print('token id array:')
print(example)
print('get stats:')
print(counts) # frequency of each adjacent token pair

token id array:
[1, 2, 3, 1, 2]
get stats:
{(1, 2): 2, (2, 3): 1, (3, 1): 1}
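
The same pair counting can also be sketched with the standard library's `collections.Counter`, which additionally gives the most frequent pair directly through `most_common`. The helper name `get_stats_counter` is made up for this illustration and is not part of the notebook.

```python
from collections import Counter

def get_stats_counter(ids):
    # Hypothetical helper: count adjacent token pairs with a Counter
    return Counter(zip(ids, ids[1:]))

example = [1, 2, 3, 1, 2]
stats = get_stats_counter(example)
print(dict(stats))           # {(1, 2): 2, (2, 3): 1, (3, 1): 1}
print(stats.most_common(1))  # [((1, 2), 2)]
```

A plain dict, as used in `get_stats`, keeps the dependency footprint at zero; the `Counter` variant is mainly a convenience when the top-k pairs are needed.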


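## Merging a token pair

- Purpose: replace every occurrence of a given token pair in the sequence with a single new token
- Example: replace (1, 2) with 4 in `[1, 2, 3, 1, 2]`
- Output: `[4, 3, 4]`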

In [2]:
def merge(ids, pair, idx):
    newids = []
    i = 0
    while i < len(ids):
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)  # pair matched: emit the new token instead
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

ids = [1, 2, 3, 1, 2]
pair = (1, 2)
# replace every occurrence of pair in ids with the new token id 4
newids = merge(ids, pair, 4)
print(newids)

[4, 3, 4]
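
One edge case worth checking is a run of identical tokens. The sketch below repeats the counting and merging logic so that it runs on its own (the new token id 9 is arbitrary). It suggests that overlapping windows are each counted, while the left-to-right scan can only merge non-overlapping occurrences.

```python
def get_stats(ids):
    # Same counting idea as the notebook's get_stats(), repeated so this runs alone
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    # Same left-to-right scan as the notebook's merge()
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2  # skip both members of the matched pair
        else:
            newids.append(ids[i])
            i += 1
    return newids

run = [1, 1, 1]
print(get_stats(run))                   # {(1, 1): 2}  both overlapping windows are counted
print(merge(run, (1, 1), 9))            # [9, 1]       only one non-overlapping match fits
print(merge([1, 1, 1, 1], (1, 1), 9))   # [9, 9]
print(merge([7], (1, 1), 9))            # [7]          too short to contain any pair
```

So the count reported for a pair can be higher than the number of replacements a merge actually performs.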


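## BPE Tokenizer Implementation

### 1 Training phase (building the vocabulary)
**Input**: raw text + target vocabulary size  
**Output**: a vocabulary containing frequent subwords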

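**Steps**:
1. **Initialization**:
   - Split the text into its smallest units (e.g. ASCII characters or bytes)
   - Initial vocabulary = all base symbols

2. **Iterative merging**:
   - **Count frequencies**: count how often each adjacent symbol pair occurs (byte pairs at the start)
   - **Merge the most frequent pair**: turn the most frequent pair into a new symbol
   - **Update the vocabulary**: add the new symbol to the vocabulary
   - **Rewrite the text**: replace every occurrence of that pair with the new symbol

3. **Termination**:
   - The target vocabulary size is reached
   - Or no pair is left worth merging (every pair occurs only once)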

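### 2 Decoding phase

Decoding reverses training: starting from the highest ID, each merged symbol is replaced by the pair it was built from until only raw bytes remain, and those bytes are then decoded as UTF-8.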
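
The reverse procedure described above can be sketched as follows. The function name `decode_by_unmerging` and the two-entry merge table are made up for this illustration. The final check compares the result with a plain vocabulary lookup that concatenates each token's byte string.

```python
def decode_by_unmerging(ids, merges):
    # Undo merges from the newest (highest id) to the oldest
    ids = list(ids)
    for (p0, p1), idx in sorted(merges.items(), key=lambda kv: kv[1], reverse=True):
        expanded = []
        for t in ids:
            if t == idx:
                expanded.extend((p0, p1))  # split the merged token back into its pair
            else:
                expanded.append(t)
        ids = expanded
    # every id is now a raw byte value in 0..255
    return bytes(ids).decode("utf-8", errors="replace")

# Toy merge table (an assumption for this example): 'a'+'t' -> 256, 'c'+256 -> 257
toy_merges = {(97, 116): 256, (99, 256): 257}
tokens = [257, 115]  # "cat" + "s"
print(decode_by_unmerging(tokens, toy_merges))  # cats

# Equivalent lookup-based decoding: id -> bytes table, then concatenate
vocab = {i: bytes([i]) for i in range(256)}
for (p0, p1), idx in toy_merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]
assert b"".join(vocab[t] for t in tokens).decode("utf-8") == "cats"
```

The lookup form does the expansion once, when the vocabulary is built, so decoding itself is a single pass over the token ids.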

In [3]:
INITIAL_VOCAB_SIZE = 256

class BasicTokenizer():
    def __init__(self):
        self.merges = {}  # merge rules: (token1, token2) -> new_token
        self.vocab = self.build_vocab()  # token id -> bytes

    def build_vocab(self):
        # the initial vocabulary holds every single byte (0-255)
        vocab = {idx: bytes([idx]) for idx in range(INITIAL_VOCAB_SIZE)}
        # add the merged tokens
        for (p0, p1), idx in self.merges.items():
            # bytes concatenation, analogous to string concatenation
            vocab[idx] = vocab[p0] + vocab[p1]
        return vocab

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= INITIAL_VOCAB_SIZE
        num_merges = vocab_size - INITIAL_VOCAB_SIZE

        text_bytes = text.encode("utf-8")
        ids = list(text_bytes)  # start from byte-level tokens

        merges = {}
        # int -> bytes (initial vocabulary: each id maps directly to its byte)
        vocab = {idx: bytes([idx]) for idx in range(INITIAL_VOCAB_SIZE)}

        for i in range(num_merges):
            stats = get_stats(ids)  # count adjacent token pairs
            if not stats:
                break  # fewer than two tokens left, nothing to merge
            pair = max(stats, key=stats.get)  # pick the most frequent pair
            new_idx = INITIAL_VOCAB_SIZE + i  # new token ids start at 256
            ids = merge(ids, pair, new_idx)  # replace the pair in the sequence
            merges[pair] = new_idx  # record the merge rule
            # existing entries are never removed; the vocabulary only grows
            vocab[new_idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i + 1}/{num_merges}: {pair} -> {new_idx} ({vocab[new_idx]})")

        self.merges = merges # used in encode()
        self.vocab = vocab   # used in decode()

In [4]:
text = '''   
Cats never fail to fascinate human beings.
They can be friendly and affectionate towards humans, but they lead mysterious lives of their own as well.
They never become submissive like dogs and horses. As a result, humans have learned to respect feline independence.
Most cats remain suspicious of humans all their lives.
One of the things that fascinates us most about cats is the popular belief that they have nine lives.
Apparently, there is a good deal of truth in this idea. A cat's ability to survive falls is based on fact.
'''

text_bytes = text.encode("utf-8") # raw bytes
ids = list(text_bytes) # list of integers in range 0..255
print(len(ids))

bpe = BasicTokenizer()
bpe.train(text, vocab_size=266)
for i in range(256, 266):
    print(bpe.vocab[i])

print(bpe.merges)

534
b's '
b'e '
b' t'
b'at'
b'in'
b' th'
b'an'
b'.\n'
b'li'
b've'
{(115, 32): 256, (101, 32): 257, (32, 116): 258, (97, 116): 259, (105, 110): 260, (258, 104): 261, (97, 110): 262, (46, 10): 263, (108, 105): 264, (118, 101): 265}


In [5]:
# Encode
# utf-8 token ids
text = 'do you like cats'
text_bytes = text.encode("utf-8") # raw bytes
# first convert the text into byte-level token ids
# then merge those raw token ids according to the merges table -> BPE token ids

# bpe token ids
ids = list(text_bytes) # list of integers in range 0..255
while len(ids) >= 2:
    stats = get_stats(ids)
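    # take the min: a smaller merge idx means the pair was learned earlier (it was the most frequent at that point), so it is applied first; pairs not in merges rank as inf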
    pair = min(stats, key=lambda p: bpe.merges.get(p, float("inf"))) 
    print('pair:')
    print(pair)
    print(bpe.vocab[pair[0]], bpe.vocab[pair[1]])
    if pair not in bpe.merges:
        break 
    idx = bpe.merges[pair] # illustrative: (3, 4) -> 268
    ids = merge(ids, pair, idx) # illustrative: [2, 3, 4, 5] -> [2, 268, 5]
print(len(text))
print(len(ids))
print(ids)

pair:
(101, 32)
b'e' b' '
pair:
(97, 116)
b'a' b't'
pair:
(108, 105)
b'l' b'i'
pair:
(100, 111)
b'd' b'o'
16
13
[100, 111, 32, 121, 111, 117, 32, 264, 107, 257, 99, 259, 115]


In [6]:
# Decode
text_bytes = b"".join(bpe.vocab[idx] for idx in ids)
decode_text = text_bytes.decode("utf-8", errors="replace")
print(decode_text)

do you like cats
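
The encode and decode cells above can be wrapped into reusable functions. The sketch below is one possible arrangement rather than part of the original class: the free functions `build_vocab`, `encode` and `decode` are names chosen for this example, and the merge table is copied from the training output printed earlier so that the block runs without retraining.

```python
# Merge table copied from the training output shown earlier in this notebook
MERGES = {(115, 32): 256, (101, 32): 257, (32, 116): 258, (97, 116): 259,
          (105, 110): 260, (258, 104): 261, (97, 110): 262, (46, 10): 263,
          (108, 105): 264, (118, 101): 265}

def build_vocab(merges):
    # id -> bytes; dict insertion order matches the order the merges were learned
    vocab = {idx: bytes([idx]) for idx in range(256)}
    for (p0, p1), idx in merges.items():
        vocab[idx] = vocab[p0] + vocab[p1]
    return vocab

def merge(ids, pair, idx):
    # Replace each non-overlapping occurrence of pair with idx, scanning left to right
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

def encode(text, merges):
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        pairs = set(zip(ids, ids[1:]))
        # apply the earliest-learned merge first; unknown pairs rank as infinity
        pair = min(pairs, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # no learned merge applies any more
        ids = merge(ids, pair, merges[pair])
    return ids

def decode(ids, vocab):
    return b"".join(vocab[idx] for idx in ids).decode("utf-8", errors="replace")

vocab = build_vocab(MERGES)
ids = encode("do you like cats", MERGES)
print(ids)                 # [100, 111, 32, 121, 111, 117, 32, 264, 107, 257, 99, 259, 115]
print(decode(ids, vocab))  # do you like cats
print(encode(" the", MERGES))  # [261, 101]
```

Because the base vocabulary covers all 256 byte values, any UTF-8 text can be encoded, including characters that never appeared in the training text; they simply stay as raw bytes.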
