# tokenization_chatglm6b.py


In [5]:
from typing import List, Optional, Union
import os

from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import logging, PaddingStrategy
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
from typing import Dict
import sentencepiece as spm
import numpy as np

In [7]:

class TextTokenizer:
    """Thin adapter around a SentencePiece processor loaded from ``model_path``.

    Every method forwards to the processor stored in ``self.sp``; the
    vocabulary size reported by the processor is cached in ``self.num_tokens``.
    """

    def __init__(self, model_path):
        self.sp = processor = spm.SentencePieceProcessor()
        processor.Load(model_path)
        # Cached once so that len() does not have to query the processor again.
        self.num_tokens = processor.vocab_size()

    def __len__(self):
        """Return the vocabulary size captured at construction time."""
        return self.num_tokens

    # --- text <-> ids -----------------------------------------------------

    def encode(self, text):
        """Return the ids the processor produces for ``text``."""
        ids = self.sp.EncodeAsIds(text)
        return ids

    def decode(self, ids: List[int]):
        """Return the text the processor produces for ``ids``."""
        text = self.sp.DecodeIds(ids)
        return text

    # --- text <-> pieces --------------------------------------------------

    def tokenize(self, text):
        """Return the pieces the processor produces for ``text``."""
        pieces = self.sp.EncodeAsPieces(text)
        return pieces

    def convert_tokens_to_string(self, tokens):
        """Return the text the processor produces for the pieces in ``tokens``."""
        text = self.sp.DecodePieces(tokens)
        return text

    # --- pieces <-> ids ---------------------------------------------------

    def convert_tokens_to_ids(self, tokens):
        """Return a list holding the id of every piece in ``tokens``."""
        return list(map(self.sp.PieceToId, tokens))

    def convert_token_to_id(self, token):
        """Return the id of a single piece."""
        piece_id = self.sp.PieceToId(token)
        return piece_id

    def convert_id_to_token(self, idx):
        """Return the piece stored under id ``idx``."""
        piece = self.sp.IdToPiece(idx)
        return piece

In [8]:
class SPTokenizer:
    """Text tokenizer front-end that keeps the low id range for image tokens.

    Text is pre-processed (newline -> ``<n>``, tab -> ``<|tab|>``, runs of
    spaces -> ``<|blank_N|>``) before it reaches the wrapped ``TextTokenizer``.
    Every text id is shifted up by ``num_image_tokens`` on the way out and
    shifted back down on the way in.
    """

    def __init__(
            self,
            vocab_file,
            num_image_tokens=20000,
            max_blank_length=80,
            byte_fallback=True,
    ):
        assert vocab_file is not None
        self.vocab_file = vocab_file
        self.num_image_tokens = num_image_tokens
        self.max_blank_length = max_blank_length
        # special_tokens and byte_fallback are stored but not read by any method of this class.
        self.byte_fallback = byte_fallback
        self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
        self.text_tokenizer = TextTokenizer(vocab_file)

    def _get_text_tokenizer(self):
        """Return the wrapped text tokenizer."""
        return self.text_tokenizer

    @staticmethod
    def get_blank_token(length: int):
        """Return the placeholder token for a run of ``length`` spaces (``length`` >= 2)."""
        assert length >= 2
        return "<|blank_{}|>".format(length)

    @staticmethod
    def get_tab_token():
        """Return the placeholder token for a tab character."""
        return "<|tab|>"

    @property
    def num_text_tokens(self):
        """Number of tokens reported by the wrapped text tokenizer."""
        return self.text_tokenizer.num_tokens

    @property
    def num_tokens(self):
        """Image-token count plus text-token count."""
        return self.num_text_tokens + self.num_image_tokens

    @staticmethod
    def _encode_whitespaces(text: str, max_len: int = 80):
        """Replace tabs and runs of 2..``max_len`` spaces with placeholder tokens.

        Longer runs are substituted first, so a run longer than ``max_len``
        is split into a ``max_len`` placeholder followed by shorter ones.
        """
        encoded = text.replace("\t", SPTokenizer.get_tab_token())
        for run_length in reversed(range(2, max_len + 1)):
            placeholder = SPTokenizer.get_blank_token(run_length)
            encoded = encoded.replace(" " * run_length, placeholder)
        return encoded

    def _preprocess(self, text: str, linebreak=True, whitespaces=True):
        """Apply the optional newline and whitespace substitutions to ``text``."""
        processed = text.replace("\n", "<n>") if linebreak else text
        if not whitespaces:
            return processed
        return self._encode_whitespaces(processed, max_len=self.max_blank_length)

    def encode(
            self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
    ) -> List[int]:
        """
        Convert text to ids shifted past the image-token range.

        @param text: Text to encode.
        @param linebreak: Whether to replace newlines in text with "<n>".
        @param whitespaces: Whether to replace tabs and runs of spaces with placeholder tokens.
        @param add_dummy_prefix: When False, "<n>" is prepended before encoding and the
            first two resulting ids are dropped.
        """
        prepared = self._preprocess(text, linebreak, whitespaces)
        offset_source = self
        if add_dummy_prefix:
            raw_ids = self._get_text_tokenizer().encode(prepared)
            return [raw + offset_source.num_image_tokens for raw in raw_ids]
        raw_ids = self._get_text_tokenizer().encode("<n>" + prepared)
        shifted = [raw + offset_source.num_image_tokens for raw in raw_ids]
        return shifted[2:]

    def postprocess(self, text):
        """Turn ``<n>``, tab and blank placeholders in ``text`` back into raw characters."""
        restored = text.replace("<n>", "\n").replace(SPTokenizer.get_tab_token(), "\t")
        run_length = 2
        while run_length <= self.max_blank_length:
            restored = restored.replace(self.get_blank_token(run_length), " " * run_length)
            run_length += 1
        return restored

    def decode(self, text_ids: List[int]) -> str:
        """Convert shifted ids back to text; ids below the image-token range are dropped."""
        shifted_down = (int(token_id) - self.num_image_tokens for token_id in text_ids)
        text_only = [value for value in shifted_down if value >= 0]
        return self.postprocess(self._get_text_tokenizer().decode(text_only))

    def decode_tokens(self, tokens: List[str]) -> str:
        """Join pieces into text and undo the placeholder substitutions."""
        joined = self._get_text_tokenizer().convert_tokens_to_string(tokens)
        return self.postprocess(joined)

    def tokenize(
            self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
    ) -> List[str]:
        """
        Split text into pieces after the same pre-processing used by encode().

        @param text: Text to tokenize.
        @param linebreak: Whether to replace newlines in text with "<n>".
        @param whitespaces: Whether to replace tabs and runs of spaces with placeholder tokens.
        @param add_dummy_prefix: When False, "<n>" is prepended before tokenizing and the
            first two resulting pieces are dropped.
        """
        prepared = self._preprocess(text, linebreak, whitespaces)
        if add_dummy_prefix:
            return self._get_text_tokenizer().tokenize(prepared)
        pieces = self._get_text_tokenizer().tokenize("<n>" + prepared)
        return pieces[2:]

    def __getitem__(self, x: Union[int, str]):
        """Map an id to its token string, or a token string to its id.

        Ints below ``num_image_tokens`` map to ``<image_N>``; strings of the
        form ``<image_N>`` map back to ``N``. Everything else goes through the
        wrapped text tokenizer with the image-token offset applied.
        """
        if isinstance(x, int):
            if x >= self.num_image_tokens:
                return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens)
            return f"<image_{x}>"
        if isinstance(x, str):
            looks_like_image = x.startswith("<image_") and x.endswith(">") and x[7:-1].isdigit()
            if looks_like_image:
                return int(x[7:-1])
            return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens
        raise ValueError("The key should be str or int.")



In [9]:
def _encode_whitespaces(text: str, max_len: int = 80):
    """Replace tabs and runs of 2..``max_len`` spaces with placeholder tokens.

    Stand-alone copy of ``SPTokenizer._encode_whitespaces`` for experimenting
    in the notebook. Longer runs are substituted first, so a run longer than
    ``max_len`` becomes a ``max_len`` placeholder followed by shorter ones.
    """
    encoded = text.replace("\t", SPTokenizer.get_tab_token())
    for run_length in reversed(range(2, max_len + 1)):
        placeholder = SPTokenizer.get_blank_token(run_length)
        encoded = encoded.replace(" " * run_length, placeholder)
    return encoded

In [21]:
# Demo: with max_len=5, a run of eight spaces is split into a 5-blank and a 3-blank token;
# the second sample additionally shows the tab placeholder.
for sample in ("An        apple", "An\t        apple"):
    print(_encode_whitespaces(sample, 5))

An<|blank_5|><|blank_3|>apple
An<|tab|><|blank_5|><|blank_3|>apple


In [16]:
# Demo: show the literal placeholder strings for a tab and for a run of three spaces.
print(SPTokenizer.get_tab_token(), SPTokenizer.get_blank_token(3), sep="\n")

<|tab|>
<|blank_3|>


In [18]:
# Demo: range(5, 1, -1) counts down from 5 and stops before 1, i.e. 5, 4, 3, 2.
# This is the iteration order _encode_whitespaces uses for its run lengths.
for i in range(5, 1, -1):
    print(i)

5
4
3
2


## BPE algorithm


In [41]:

import re, collections

def get_stats(vocab):
    """Count adjacent symbol pairs across a BPE vocabulary.

    ``vocab`` maps a word written as space-separated symbols to its frequency.
    Returns a ``defaultdict(int)`` keyed by ``(left, right)`` symbol tuples,
    where each value is the summed frequency of the words containing that
    adjacent pair (counted once per occurrence).
    """
    pair_counts = collections.defaultdict(int)
    for entry, count in vocab.items():
        pieces = entry.split()
        # Walk neighbouring symbols; a single-symbol word contributes nothing.
        for left, right in zip(pieces, pieces[1:]):
            pair_counts[left, right] += count
    return pair_counts

def merge_vocab(pair, v_in):
    """Merge every occurrence of the symbol pair ``pair`` in a BPE vocabulary.

    ``pair`` is a 2-tuple of symbols and ``v_in`` maps space-separated symbol
    sequences to frequencies. Returns a new dict in which each whole-symbol
    occurrence of ``"left right"`` has been replaced by ``"leftright"``;
    ``v_in`` is not modified.

    If two input words become identical after merging, their frequencies are
    added together instead of one silently overwriting the other.
    """
    merged_symbol = ''.join(pair)
    # The look-arounds make the pair match only on whole symbols (delimited by
    # whitespace or the string boundaries), never inside a longer symbol.
    pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')

    def _replacement(_match):
        # A callable replacement is inserted verbatim; a plain string would be
        # treated as a template and mangle symbols that contain backslashes.
        return merged_symbol

    v_out = {}
    for word, freq in v_in.items():
        w_out = pattern.sub(_replacement, word)
        v_out[w_out] = v_out.get(w_out, 0) + freq
    return v_out

def get_vocab(text):
    """Build the initial BPE vocabulary from an iterable of words.

    Each distinct item of ``text`` becomes a key made of its characters joined
    by single spaces and terminated with the ``</w>`` end-of-word marker; the
    value is how many times the item occurs. The result is a
    ``defaultdict(int)``.
    """
    word_counts = collections.Counter(text)
    vocab = collections.defaultdict(int)
    vocab.update(
        (' '.join(word) + ' </w>', count) for word, count in word_counts.items()
    )
    return vocab

def bpe(text, num_merges):
    """Run up to ``num_merges`` byte-pair-encoding merges over the words in ``text``.

    Starts from the character-level vocabulary built by ``get_vocab`` and, on
    each round, merges the most frequent adjacent symbol pair. Stops early
    once no adjacent pairs remain. Returns the resulting vocabulary.
    """
    merged = get_vocab(text)
    for _ in range(num_merges):
        stats = get_stats(merged)
        if stats:
            # On ties, max() keeps the first pair encountered in iteration order.
            most_frequent = max(stats, key=stats.get)
            merged = merge_vocab(most_frequent, merged)
            continue
        break
    return merged

# Toy corpus for trying out bpe(): 5 x 'low', 2 x 'lower', 6 x 'newest', 3 x 'widest'.
text = ['low'] * 5 + ['lower'] * 2 + ['newest'] * 6 + ['widest'] * 3
print(bpe(text, 10))



{'low</w>': 5, 'low e r </w>': 2, 'newest</w>': 6, 'widest</w>': 3}


In [42]:
# Show the vocabulary after one, two and three merges to see each merge step in isolation.
for merges in (1, 2, 3):
    print(bpe(text, merges))


{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}
{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3}
{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}


In [27]:

def get_stats(vocab):
    """Count adjacent symbol pairs across a BPE vocabulary.

    ``vocab`` maps a word written as space-separated symbols to its frequency.
    Returns a ``defaultdict(int)`` keyed by ``(left, right)`` symbol tuples,
    where each value is the summed frequency of the words containing that
    adjacent pair (counted once per occurrence).
    """
    # Imported locally: the file only has ``import re, collections`` at module
    # level, so the bare name ``defaultdict`` would otherwise be undefined.
    from collections import defaultdict

    pairs = defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        # Pair each symbol with its right-hand neighbour.
        for left, right in zip(symbols, symbols[1:]):
            pairs[left, right] += freq
    return pairs
    



In [29]:
# Hand-written vocabulary: space-separated symbols ending in '</w>' mapped to word frequency.
vocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e r </w>': 6, 'w i d e r </w>': 3}
# Count how often each adjacent symbol pair occurs, weighted by word frequency.
pairs = get_stats(vocab)
print(pairs)
# Pick the most frequent pair; on ties max() returns the first one in iteration order.
# `vocab` and `best` are reused by the following cells.
best = max(pairs, key=pairs.get)
print(best)

defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 8, ('e', 'r'): 11, ('r', '</w>'): 11, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3})
('e', 'r')


In [30]:
def merge_vocab(pair, in_vocab):
    """Merge every occurrence of the symbol pair ``pair`` in a BPE vocabulary.

    ``pair`` is a 2-tuple of symbols and ``in_vocab`` maps space-separated
    symbol sequences to frequencies. Returns a new dict in which each
    whole-symbol occurrence of ``"left right"`` has been replaced by
    ``"leftright"``; ``in_vocab`` is not modified.

    If two input words become identical after merging, their frequencies are
    added together instead of one silently overwriting the other.
    """
    merged_symbol = ''.join(pair)
    # The look-arounds make the pair match only on whole symbols (delimited by
    # whitespace or the string boundaries), never inside a longer symbol.
    pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')

    def _replacement(_match):
        # A callable replacement is inserted verbatim; a plain string would be
        # treated as a template and mangle symbols that contain backslashes.
        return merged_symbol

    out_vocab = {}
    for word, freq in in_vocab.items():
        out_word = pattern.sub(_replacement, word)
        out_vocab[out_word] = out_vocab.get(out_word, 0) + freq
    return out_vocab

In [31]:
# Apply the merge for the most frequent pair found above and rebind `vocab` to the result.
vocab = merge_vocab(best, vocab)
print(vocab)

{'l o w </w>': 5, 'l o w er </w>': 2, 'n e w er </w>': 6, 'w i d er </w>': 3}


In [34]:
# Show the regex-escaped form of the pair joined by a space, as built inside merge_vocab.
bigram = re.escape(' '.join(best))
print(bigram)

e\ r


# HF tokenizer

In [44]:
from transformers import AutoModel, AutoTokenizer
# Load the tokenizer from a local ChatGLM-6B checkpoint directory (machine-specific path).
# trust_remote_code=True is passed through to from_pretrained; presumably needed because the
# checkpoint ships its own tokenizer class - confirm against the transformers documentation.
tokenizer = AutoTokenizer.from_pretrained("F:\\learn\\AI\\carr001\\learn_ai\\third_party\\ChatGLM-6b\\THUDM\\chatglm-6b", trust_remote_code=True)
# Last expression of the cell: the notebook displays the resulting list of pieces.
tokenizer.tokenize("你好，我是惠成煊，请问你是谁？")

# More tokenization methods can be found in SentencePiece's Python implementation.
# See https://github.com/google/sentencepiece/blob/master/python/README.md

['▁', '你好', ',', '我是', '惠', '成', '煊', ',', '请问', '你是谁', '?']

In [69]:
# NOTE: this rebinds the notebook-level name `vocab` (previously the toy BPE vocabulary).
vocab = tokenizer.get_vocab()
# First two vocabulary entries.
print(dict(list(vocab.items())[:2]))
# Five entries taken from 400 positions before the end of the vocabulary.
print(dict(list(vocab.items())[-400:-395]))
print(tokenizer.num_special_tokens_to_add())
# The second sample contains two consecutive spaces ("你是  谁").
print(tokenizer.prepare_for_tokenization("你好，我是惠成煊，请问你是  谁？"))
print(tokenizer.prepare_for_tokenization("hello, my name is carr, how are you ?"))
print(tokenizer.gmask_token_id)
# NOTE(review): add_tokens is referenced without being called, so this prints the bound
# method's repr rather than any token list - confirm whether that was intended.
print(tokenizer.add_tokens)

{'<unk>': 0, '<s>': 1}
{'ŋ': 129944, 'ක': 129945, '작': 129946, '\x98': 129947, 'ය': 129948}
2
('你好，我是惠成煊，请问你是  谁？', {})
('hello, my name is carr, how are you ?', {})
130001
<bound method SpecialTokensMixin.add_tokens of ChatGLMTokenizer(name_or_path='F:\learn\AI\carr001\learn_ai\third_party\ChatGLM-6b\THUDM\chatglm-6b', vocab_size=130344, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<sop>', 'eos_token': '<eop>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	130000: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	130004: AddedToken("<sop>", rstrip=False

- bos_token（Begin Of Sentence token）：句子开头的标记，用来指示一个句子的开始。
- eos_token（End Of Sentence token）：句子结束的标记，用来指示一个句子的结束。
- end_token：这通常与eos_token相同，用来指示一个序列的结束，比如句子的结尾。
- gmask_token：这是 GLM 系列模型（如 ChatGLM-6B）特有的令牌 `[gMASK]`（并非 T5 的令牌），用来标记序列中需要由模型自回归生成的位置；上面输出的 `tokenizer.gmask_token_id`（130001）就是它的 id。
- mask_token：在屏蔽语言模型（如BERT）中使用的令牌，用来替换文本中的某些词，以训练模型对被屏蔽词的上下文进行理解。
- pad_token：填充令牌，用来将不同长度的文本序列填充到相同的长度，以便可以批量处理。在许多模型中，序列需要被填充到批处理中最长序列的长度。
- unk_token（Unknown token）：未知令牌，用来替换模型词汇表之外的词。当模型遇到训练期间未见过的词时，会用unk_token来表示。

# chat

In [70]:
from transformers import AutoTokenizer, AutoModel
# Load tokenizer and model from the same local ChatGLM-6B checkpoint directory
# (machine-specific path); trust_remote_code=True is passed to both loaders.
tokenizer = AutoTokenizer.from_pretrained("F:\\learn\\AI\\carr001\\learn_ai\\third_party\\ChatGLM-6b\\THUDM\\chatglm-6b", trust_remote_code=True)
# .half().cuda() presumably casts the weights to fp16 and moves them to the GPU
# (torch module methods) - requires a CUDA device; confirm for this model class.
model = AutoModel.from_pretrained("F:\\learn\\AI\\carr001\\learn_ai\\third_party\\ChatGLM-6b\\THUDM\\chatglm-6b", trust_remote_code=True).half().cuda()


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [71]:
# Rebind `model` to the result of eval() before generating.
model = model.eval()
# history=[] starts a conversation with no prior turns; chat() returns the reply and a history value.
response, history = model.chat(tokenizer, "你好", history=[])
# Last expression of the cell: the notebook displays the reply.
response

The dtype of attention mask (torch.int64) is not bool


'你好👋！我是人工智能助手 ChatGLM-6B，很高兴见到你，欢迎问我任何问题。'

In [73]:
# history=[] is passed again, so the history returned by the previous cell is not fed back in.
response, history = model.chat(tokenizer, "你喜欢做什么", history=[])
response

'作为一个人工智能助手，我没有个人喜好或情感，因为我只是由计算机程序驱动的。我的目的是尽可能准确地回答你的问题和提供帮助，所以请随时问我问题。'

In [74]:
# Again called with an empty history; the prompt asks for a sentence in English.
response, history = model.chat(tokenizer, "你说句英文", history=[])
response

'Sure, I\'d be happy to say a sentence in English. How about "Hello, world!" as a starting point?'