# Manipulate SentencePiece Vocabulary

## References

* [add new vocab](https://github.com/google/sentencepiece/blob/9cf136582d9cce492ba5a0cfb775f9e777fe07ea/python/add_new_vocab.ipynb) from google/sentencepiece
* [reduce vocab](https://github.com/bojone/t5_in_bert4keras/blob/6cf50dbf3ffd3b4e9f36a59ee9f98356cf686de0/tokenizer/reduce.py) from bojone/t5_in_bert4keras

## Get a pretrained tokenizer (mT5-small)

In [None]:
from pathlib import Path
import shutil
Path("cache/").mkdir(exist_ok=True)
if Path("cache/mt5-small").exists():
    shutil.rmtree("cache/mt5-small")

In [2]:
from transformers import MT5Tokenizer

tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")

In [3]:
tokenizer.save_pretrained("cache/mt5-small")

('cache/mt5-small/tokenizer_config.json',
 'cache/mt5-small/special_tokens_map.json',
 'cache/mt5-small/spiece.model',
 'cache/mt5-small/added_tokens.json')

### Get a Dataset (XNLI)

We want to retain only the pieces that the tokenizer actually produces on this dataset (premises and hypotheses from all three splits).

In [4]:
from datasets import load_dataset

dataset = load_dataset("xnli", "zh")

Reusing dataset xnli (/home/ceshine/.cache/huggingface/datasets/xnli/zh/1.1.0/51ba3a1091acf33fd7c2a54bcbeeee1b1df3ecb127fdca003d31968fa3a1e6a8)


In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 392702
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 5010
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 2490
    })
})

In [6]:
dataset["train"]["hypothesis"][:10]

['产品 和 地理 是 什么 使 奶油 抹 霜 工作 .',
 '如果 人们 记得 的 话 , 你 就 会 把 事情 弄 丢 了 .',
 '我 团队 的 一个 成员 将 非常 精确 地 执行 你 的 命令',
 '这些 信息 属于 他们 .',
 '网球鞋 有 一 系列 的 价格 .',
 '我 很 难 过 我 的 随身听 坏 了 现在 我 得 把 音响 调 大 一点',
 '大多数 基督教 马赛克 都 被 穆斯林 摧毁 .',
 '石板 对 杰克逊 的 调查 结果 有 意见',
 '异性恋者',
 '孚日 广场 完全 是 用 灰色 大理石 建造 的 .']

In [7]:
tokenizer.batch_encode_plus(dataset["train"]["hypothesis"][:8], return_attention_mask=False)

{'input_ids': [[259, 15104, 259, 1107, 259, 148479, 259, 1543, 259, 16892, 259, 12561, 259, 63749, 10920, 259, 126767, 259, 155228, 259, 6573, 259, 260, 1], [259, 21304, 259, 79316, 259, 177378, 259, 493, 259, 30253, 259, 261, 259, 4235, 259, 3981, 259, 2219, 259, 9803, 259, 52597, 259, 91253, 259, 142089, 259, 1322, 259, 260, 1], [259, 3003, 259, 61105, 259, 493, 259, 8149, 259, 98581, 259, 3661, 259, 25265, 259, 12348, 107310, 259, 2524, 259, 55958, 259, 4235, 259, 493, 259, 129300, 1], [259, 20155, 259, 12359, 259, 80922, 259, 16171, 259, 260, 1], [259, 1758, 8320, 62043, 259, 1637, 259, 1374, 259, 27858, 259, 493, 259, 21919, 259, 260, 1], [259, 3003, 259, 10559, 259, 20481, 259, 6994, 259, 3003, 259, 493, 259, 24470, 7431, 24762, 259, 90707, 259, 1322, 259, 24150, 259, 3003, 259, 5880, 259, 9803, 259, 7647, 42797, 259, 19477, 259, 1146, 259, 39200, 1], [259, 155598, 259, 16746, 55627, 11072, 259, 6890, 16003, 9636, 259, 4794, 259, 3916, 259, 131330, 9684, 6892, 259, 234510, 113286

In [8]:
from itertools import chain
from tqdm import tqdm

def tokenize_data(data, batch_size=1024):
    # Encode `data` in batches and add every token id produced to the global `seen` set
    for i in tqdm(range(0, len(data), batch_size)):
        encoded = tokenizer.batch_encode_plus(
            data[i:(i + batch_size)], return_attention_mask=False
        )["input_ids"]
        seen.update(chain.from_iterable(encoded))

In [9]:
seen = set()
for subset in ("train", "test", "validation"):
    print(subset)
    tokenize_data(dataset[subset]["hypothesis"])
    tokenize_data(dataset[subset]["premise"])

train
100%|██████████| 384/384 [00:25<00:00, 14.95it/s]
100%|██████████| 384/384 [00:43<00:00,  8.85it/s]
test
100%|██████████| 5/5 [00:00<00:00, 18.68it/s]
100%|██████████| 5/5 [00:00<00:00, 13.53it/s]
validation
100%|██████████| 3/3 [00:00<00:00, 23.82it/s]
100%|██████████| 3/3 [00:00<00:00, 16.88it/s]


In [10]:
len(seen)

30314
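The collection step doesn't depend on `MT5Tokenizer` specifically. Here is a minimal, library-free sketch of the same idea, assuming any batch encoder that maps a list of strings to a list of id lists. `collect_ids`, `encode_batch` and `toy_encode_batch` are hypothetical names used only for illustration:

```python
from itertools import chain
from typing import Callable, List, Sequence, Set


def collect_ids(
    texts: Sequence[str],
    encode_batch: Callable[[List[str]], List[List[int]]],
    batch_size: int = 1024,
) -> Set[int]:
    """Return the set of every token id that `encode_batch` produces for `texts`."""
    seen: Set[int] = set()
    for start in range(0, len(texts), batch_size):
        batch = list(texts[start:start + batch_size])
        # Flatten the per-text id lists and add them to the set
        seen.update(chain.from_iterable(encode_batch(batch)))
    return seen


# Hypothetical stand-in for a tokenizer: one id per character, plus an EOS id of 1
def toy_encode_batch(batch: List[str]) -> List[List[int]]:
    return [[ord(ch) for ch in text] + [1] for text in batch]


ids = collect_ids(["ab", "bc", "ca"], toy_encode_batch, batch_size=2)
print(sorted(ids))  # [1, 97, 98, 99]
```

With the real tokenizer, `encode_batch` could be `lambda b: tokenizer.batch_encode_plus(b, return_attention_mask=False)["input_ids"]`, the same call used above. Returning the set instead of mutating a global makes it easy to compute per-split vocabularies separately.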

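You can also add tokens that the model will see in its inputs but that may not occur in the dataset text itself, such as the task prefix (`mnli premise: hypothesis:`) and `<unk>`: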

In [11]:
seen = seen.union(set(tokenizer.encode("mnli premise: hypothesis: <unk>")))

In [12]:
len(seen)

30316

### Load the SentencePiece Model

In [13]:
from sentencepiece import sentencepiece_model_pb2 as model

m = model.ModelProto()
with open("cache/mt5-small/spiece.model", "rb") as f:
    m.ParseFromString(f.read())
# The first 259 pieces are reserved: the special tokens (<pad>, ...)
# followed by the 256 byte-fallback pieces <0x00> .. <0xFF>
for i, piece in enumerate(m.pieces[:320]):
    if i % 20 == 0:
        print(i, piece.piece)

0 <pad>
20 <0x11>
40 <0x25>
60 <0x39>
80 <0x4D>
100 <0x61>
120 <0x75>
140 <0x89>
160 <0x9D>
180 <0xB1>
200 <0xC5>
220 <0xD9>
240 <0xED>
260 .
280 l
300 ▁v


In [14]:
m.pieces[258].piece, m.pieces[259].piece

('<0xFF>', '▁')
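Judging from these printouts (`<pad>` at 0, `<0x11>` at 20, `<0xFF>` at 258, and the first regular piece `▁` at 259), the reserved block appears to be three special tokens followed by the 256 byte-fallback pieces. Assuming that layout (with `<pad>`, `</s>`, `<unk>` at ids 0 to 2, which is an assumption here), the id of a byte piece can be computed directly. The helper names below are hypothetical:

```python
NUM_SPECIAL = 3  # assumed: <pad>=0, </s>=1, <unk>=2
NUM_BYTES = 256  # byte-fallback pieces <0x00> .. <0xFF>
NUM_RESERVED = NUM_SPECIAL + NUM_BYTES  # 259, the cutoff used when shrinking


def byte_piece(byte_value: int) -> str:
    """Name of the byte-fallback piece for a byte value, e.g. 17 -> '<0x11>'."""
    return f"<0x{byte_value:02X}>"


def byte_piece_id(byte_value: int) -> int:
    """Id of the byte-fallback piece under the assumed layout."""
    if not 0 <= byte_value < NUM_BYTES:
        raise ValueError(f"not a byte: {byte_value}")
    return NUM_SPECIAL + byte_value


for b in (0x11, 0x25, 0xFF):
    print(byte_piece_id(b), byte_piece(b))
# 20 <0x11>
# 40 <0x25>
# 258 <0xFF>
```

These ids match the printout above, so keeping every id below 259 preserves both the special tokens and all byte pieces.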

In [15]:
len(m.pieces)

250100

## Shrink the SentencePiece Model

In [16]:
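# m.pieces is a protobuf repeated field, which can't be reassigned directly, so we
# empty it from the end, keeping the reserved pieces (ids < 259) and every seen id
kept_pieces, i = [], len(m.pieces) - 1
while len(m.pieces):
    piece = m.pieces.pop()  # `i` is the original id of `piece`
    if i < 259 or i in seen:
        kept_pieces.append(piece)
    i -= 1
# The pieces were collected back to front; restore the original order
kept_pieces = list(reversed(kept_pieces))
len(kept_pieces)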

30513

In [17]:
m.pieces.extend(kept_pieces)
len(m.pieces)

30513
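On a plain Python list, the same filtering reduces to a single comprehension. The sketch below uses a hypothetical toy vocabulary and a hypothetical `shrink` helper; it also shows why the result has exactly `len(seen | set(range(259)))` entries:

```python
from typing import List, Set


def shrink(pieces: List[str], seen: Set[int], num_reserved: int = 259) -> List[str]:
    """Keep reserved pieces and every piece whose id was seen, preserving the original order."""
    return [p for i, p in enumerate(pieces) if i < num_reserved or i in seen]


# Hypothetical toy vocabulary: 2 reserved pieces followed by 4 regular ones
vocab = ["<pad>", "</s>", "▁a", "▁b", "▁c", "▁d"]
seen_ids = {1, 3, 5}  # id 1 is reserved anyway, so it is not counted twice

kept = shrink(vocab, seen_ids, num_reserved=2)
print(kept)  # ['<pad>', '</s>', '▁b', '▁d']
assert len(kept) == len(seen_ids | set(range(2)))
```

In the notebook, 30,316 seen ids plus 259 reserved ids would give 30,575, yet only 30,513 pieces remain, so 62 of the seen ids already fall in the reserved range.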

Back up the old model and save the new one:

In [18]:
Path("cache/mt5-small/spiece.model").rename("cache/mt5-small/spiece.model.old")
with open("cache/mt5-small/spiece.model", 'wb') as f:
    f.write(m.SerializeToString())

We also save the list of kept ids, which we'll need later to trim the embedding matrix:

In [19]:
import json

kept_ids = sorted(list(seen.union(set(range(259)))))
print(len(kept_ids))
with open("cache/mt5-small/kept_ids.json", 'w') as f:
    json.dump(kept_ids, f)

30513
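The trimming itself isn't shown in this notebook. Because the kept pieces keep their relative order and `kept_ids` is sorted, row `j` of the new embedding matrix should be row `kept_ids[j]` of the old one. A minimal sketch with plain nested lists and a hypothetical `trim_rows` helper (with a PyTorch tensor `weight`, the equivalent would be the indexing expression `weight[kept_ids]`):

```python
from typing import List, Sequence


def trim_rows(weight: Sequence[Sequence[float]], kept_ids: Sequence[int]) -> List[List[float]]:
    """Select the rows of an embedding matrix (indexed by old token id) that survive the shrink."""
    return [list(weight[old_id]) for old_id in kept_ids]


# Hypothetical 5-token vocabulary with 2-dimensional embeddings
old_weight = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
new_weight = trim_rows(old_weight, kept_ids=[0, 2, 4])
print(new_weight)  # [[0.0, 0.0], [2.0, 2.0], [4.0, 4.0]]
```

If the output projection is not tied to the input embeddings (as in T5 v1.1-style models such as mT5), it would need the same row selection. Assuming `d_model = 512` for mT5-small, each vocabulary-sized matrix would shrink from 250,100 × 512 = 128,051,200 to 30,513 × 512 = 15,622,656 parameters.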


### Test

First, check the dumped `kept_ids`:

In [20]:
with open("cache/mt5-small/kept_ids.json") as f:
    tmp = json.load(f)
len(tmp)

30513

In [21]:
tmp[:5], tmp[-5:]

([0, 1, 2, 3, 4], [249716, 249738, 249740, 249753, 249834])
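Since the shrunk model renumbers its pieces from 0, an id produced by the new tokenizer is a position in `kept_ids`, and `kept_ids[new_id]` recovers the original mT5 id. A small sketch of the mapping in both directions; `build_maps` is a hypothetical helper, and the list below is a hypothetical short version that ends with the last ids shown above:

```python
from typing import Dict, List, Sequence, Tuple


def build_maps(kept_ids: Sequence[int]) -> Tuple[Dict[int, int], List[int]]:
    """Return (old -> new, new -> old) id mappings for the shrunk vocabulary."""
    old_to_new: Dict[int, int] = {old: new for new, old in enumerate(kept_ids)}
    new_to_old: List[int] = list(kept_ids)
    return old_to_new, new_to_old


kept = [0, 1, 2, 249716, 249738, 249740, 249753, 249834]
old_to_new, new_to_old = build_maps(kept)
print(old_to_new[249834])  # 7
print([new_to_old[i] for i in (3, 7)])  # [249716, 249834]
```

With the full 30,513-entry list, the last kept id 249834 would map to new id 30512.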

In [22]:
tokenizer = MT5Tokenizer.from_pretrained("cache/mt5-small")

Try one example:

In [23]:
tokenizer.decode(
    tokenizer.encode(dataset["train"]["hypothesis"][0]), skip_special_tokens=True
)

'产品 和 地理 是 什么 使 奶油 抹 霜 工作.'

Try a few more, just to be sure:

In [24]:
import random
for i in random.sample(range(100), k=10):
    converted = tokenizer.decode(
        tokenizer.encode(dataset["train"]["hypothesis"][i]), skip_special_tokens=True
    ).replace(" ", "") # the space placements are slightly different from the original
    assert converted == dataset["train"]["hypothesis"][i].replace(" ", ""), f'{converted}\n{dataset["train"]["hypothesis"][i]}'
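The check above compares decoded text while ignoring spaces. A stricter check compares ids directly: mapping each new id through `kept_ids` should reproduce the original tokenizer's ids. We'd expect this to hold for text from the dataset, since every piece its best segmentation used was kept with its score unchanged (barring ties). Below is a sketch under that assumption; `find_id_mismatches`, `old_encode` and `new_encode` are hypothetical names, and the toy encoders stand in for the `encode` methods of the original and shrunk tokenizers:

```python
from typing import Callable, List, Sequence, Tuple


def find_id_mismatches(
    texts: Sequence[str],
    old_encode: Callable[[str], List[int]],
    new_encode: Callable[[str], List[int]],
    kept_ids: Sequence[int],
) -> List[Tuple[str, List[int], List[int]]]:
    """Return (text, old ids, remapped new ids) for every text where the two disagree."""
    mismatches = []
    for text in texts:
        old_ids = old_encode(text)
        remapped = [kept_ids[i] for i in new_encode(text)]
        if remapped != old_ids:
            mismatches.append((text, old_ids, remapped))
    return mismatches


# Hypothetical toy encoders: the shrunk vocabulary keeps old ids [0, 1, 4, 7]
# and renumbers them 0..3
kept = [0, 1, 4, 7]
old_vocab = {"a": 4, "b": 7}
new_vocab = {"a": 2, "b": 3}


def old_encode(text: str) -> List[int]:
    return [old_vocab[ch] for ch in text] + [1]  # 1 = EOS


def new_encode(text: str) -> List[int]:
    return [new_vocab[ch] for ch in text] + [1]


print(find_id_mismatches(["ab", "ba", "aab"], old_encode, new_encode, kept))  # []
```

Against the real models, one would reload the original tokenizer under a separate (hypothetical) name such as `old_tokenizer`, since `tokenizer` now points at the shrunk model, and call `find_id_mismatches(texts, old_tokenizer.encode, tokenizer.encode, kept_ids)`.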