In [1]:
import os
import jieba
import numpy as np
from scipy.special import softmax
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_all_providers
from tokenization_jieba import JIEBATokenizer

In [2]:
def create_model_for_provider(model_path: str, provider: str = 'CPUExecutionProvider') -> InferenceSession:
    """Open ``model_path`` as an ONNX Runtime session bound to exactly one execution provider.

    Args:
        model_path: Path to the ``.onnx`` file to load.
        provider: Name of the ONNX Runtime execution provider to use.
            Defaults to ``'CPUExecutionProvider'``.

    Returns:
        An ``InferenceSession`` with full graph optimisation enabled and
        ``disable_fallback()`` applied.

    Raises:
        AssertionError: If ``provider`` is not among ``get_all_providers()``.
    """
    # NOTE(review): get_all_providers() lists every provider ONNX Runtime knows
    # of, not only those built into this install; get_available_providers()
    # would be the stricter check.
    known_providers = get_all_providers()
    assert provider in known_providers, f"provider {provider} not found, {known_providers}"

    # Session settings that can affect inference speed.
    session_opts = SessionOptions()
    # Intra-op thread count is tunable through the NUM_THREADS env var (default 4).
    session_opts.intra_op_num_threads = int(os.environ.get('NUM_THREADS', 4))
    session_opts.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    # Build the graph for the single requested provider (not necessarily CPU).
    ort_session = InferenceSession(model_path, session_opts, providers=[provider])
    # Turn off ONNX Runtime's session.run() fallback mechanism.
    ort_session.disable_fallback()
    return ort_session


# Load everything generate() depends on: the tokenizer, the ONNX model session,
# jieba's dictionary and the initial KV cache fed to the model on the first step.
print('model loading...')
tokenizer = JIEBATokenizer('tokenizer/vocab.vocab', 'tokenizer/vocab.model')
pangu_kv = create_model_for_provider('./onnx_kv_q/pangu.onnx')
# Initialise jieba eagerly (it otherwise builds its prefix dictionary lazily on
# first use), so the dictionary-loading log appears in this cell.
jieba.initialize()
kv_cache_start = np.load('kv_cache.npy')
print('model green')

model loading...


Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.528 seconds.
Prefix dict has been built successfully.


model green


In [3]:
def _sample_next_token(last_logits, temperature, top_p, top_k, ban, sampler):
    """Draw one token id from a 1-D logits vector using temperature, top-k and nucleus (top-p) filtering.

    Args:
        last_logits: 1-D array of unnormalised scores, one per vocabulary entry.
        temperature: Positive divisor applied to the logits before softmax.
        top_p: Nucleus threshold; active only when ``0 < top_p < 1``.
        top_k: Candidate cap; active only when ``0 < top_k < vocab size``.
        ban: Iterable of token ids that must never be drawn.
        sampler: Object with a NumPy-style ``choice(a, p=...)`` method
            (``numpy.random`` or a ``numpy.random.Generator``).

    Returns:
        The sampled token id as a Python ``int``.
    """
    # Copy (and upcast) so the model's output array is never mutated in place.
    scaled = np.array(last_logits, dtype=np.float64) / temperature
    for token_id in ban:
        scaled[token_id] = -np.inf  # probability exactly 0 after softmax

    probs = softmax(scaled)
    order = np.argsort(probs)[::-1]  # token ids, most likely first
    sorted_probs = probs[order]

    if 0 < top_k < sorted_probs.size:
        order = order[:top_k]
        sorted_probs = sorted_probs[:top_k]

    if 0.0 < top_p < 1.0:
        # Nucleus: keep the smallest prefix whose (renormalised) mass reaches
        # top_p. searchsorted finds the first index where the cumulative sum
        # is >= top_p; +1 turns it into a length, so at least one token stays.
        cumulative = np.cumsum(sorted_probs) / sorted_probs.sum()
        keep = min(int(np.searchsorted(cumulative, top_p)) + 1, sorted_probs.size)
        order = order[:keep]
        sorted_probs = sorted_probs[:keep]

    return int(sampler.choice(order, p=sorted_probs / sorted_probs.sum()))


def generate(
    text,
    max_len = 100,
    temperature = 1.0,
    top_p = 0.95,
    top_k = 50,
    eod=None,
    additional_eod=(),
    ban=(),
    rng=None,
):
    """Sample a continuation of ``text`` from the ONNX model, one token at a time.

    The first step runs the whole encoded prompt through ``pangu_kv`` together
    with ``kv_cache_start``. Every later step feeds only the previously sampled
    token plus the running cache. The slice the model returns is appended to
    the cache.

    Args:
        text: Prompt string, encoded with the module-level ``tokenizer``.
        max_len: Maximum number of new tokens to sample.
        temperature: Logits are divided by this before softmax. It must be
            > 0, and lower values sharpen the distribution.
        top_p: Nucleus-sampling threshold. When ``0 < top_p < 1`` only the
            smallest set of most-likely tokens whose cumulative probability
            reaches ``top_p`` is kept. Other values disable the filter.
        top_k: Keep only the ``top_k`` most likely tokens when
            ``0 < top_k < vocab size``. Other values disable the filter.
        eod: Token ids that stop generation. Defaults to
            ``[tokenizer.eod_id, tokenizer.eot_id]``.
        additional_eod: Extra stop token ids.
        ban: Token ids that are never sampled.
        rng: Optional ``numpy.random.Generator`` for reproducible sampling.
            Defaults to NumPy's global random state.

    Returns:
        The decoded prompt plus continuation with all ``' '`` characters
        removed. The stop token itself is not included.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if eod is None:
        eod = [tokenizer.eod_id, tokenizer.eot_id]
    stop_ids = {int(t) for t in eod} | {int(t) for t in additional_eod}
    sampler = np.random if rng is None else rng

    # Copy so a list cached/owned by the tokenizer is never mutated.
    ids = list(tokenizer.encode(text))
    kv_cache = None
    next_token = None

    for step in range(max_len):
        if step == 0:
            # Prime the model with the full prompt and the initial cache.
            # NOTE(review): later steps concatenate the returned slice onto the
            # cache, but here kv_cache_start is not kept; this is only
            # equivalent if kv_cache_start is empty along axis -2. TODO confirm.
            logits, kv_cache = pangu_kv.run(None, {
                "input_ids": np.array([ids], dtype=np.int64),
                'kv_cache': kv_cache_start,
            })
        else:
            # Only the newest token is fed; the cache carries the context.
            logits, new_kv = pangu_kv.run(None, {
                "input_ids": np.array([[next_token]], dtype=np.int64),
                'kv_cache': kv_cache,
            })
            # Cache grows along axis -2 (presumably the sequence axis of the
            # exported model — TODO confirm).
            kv_cache = np.concatenate([kv_cache, new_kv], axis=-2)

        # Only the distribution at the last position matters for sampling.
        next_token = _sample_next_token(
            logits[0, -1, :], temperature, top_p, top_k, ban, sampler)
        if next_token in stop_ids:
            break
        ids.append(next_token)

    # NOTE(review): this strips every space, including any the model itself
    # produced, not only separators the decoder may insert — confirm intended.
    return tokenizer.decode([int(x) for x in ids]).replace(' ', '')

In [4]:
# Open-ended continuation of a recipe prompt, allowing up to 200 new tokens.
print(generate(text='西红柿炒鸡蛋的做法：\n', max_len=200))

西红柿炒鸡蛋的做法:
1.鸡蛋打散,西红柿切成小块。2.炒锅加油烧热,加入鸡蛋,翻炒。3.炒至鸡蛋全部变成小块,然后盛出备用。4.然后加入西红柿碎。5.翻炒均匀后再加盐。6.放入黑胡椒粉。7.大火炒至西红柿7成熟就可以出锅了。


In [5]:
# Couplet completion: give the first line, let the model write the second.
print(generate(text='上联：天地在我心中\n下联：', max_len=50))

上联:天地在我心中
下联:明月我心灯


In [6]:
# Few-shot arithmetic with a single sampled token. The recorded output below
# ("16") is wrong, so the model should not be relied on for arithmetic.
print(generate(text='1+1=2;3+5=8;2+4=', max_len=1))

1+1=2;3+5=8;2+4=16


In [7]:
# Recall of a well-known poem line from its opening, capped at 5 tokens.
print(generate(text='默写古诗：\n白日依山尽，黄河入海流。\n床前明月光，', max_len=5))

默写古诗:
白日依山尽,黄河入海流。
床前明月光,疑是地上霜。


In [8]:
# Dialogue continuation: the second speaker's reply, up to 20 tokens.
print(generate(text='李大嘴：“各回各家，各找各妈！” \n佟掌柜：', max_len=20))

李大嘴:“各回各家,各找各妈!”
佟掌柜:“没您说的那么严重。”
二宝和小翠倒是


In [9]:
# Few-shot fact completion using the default max_len.
print(generate(text='中国的首都是北京\n日本的首都是东京\n美国的首都是'))

中国的首都是北京
日本的首都是东京
美国的首都是华盛顿


In [10]:
# Open question answered by free continuation, up to 50 tokens.
print(generate(text='中国的四大发明有哪些？', max_len=50))

中国的四大发明有哪些?
造纸术


In [11]:
# Quote completion after an opening quotation mark, up to 50 tokens.
print(generate(text='乔布斯曾经说过：“', max_len=50))

乔布斯曾经说过:“当我们还没有出生的时候,我们不知道自己还能做什么,也不知道未来将从哪里开始,所以,必须去找到自己的位置,并且做自己擅长的事。”而在这个时代,


In [12]:
# Quote completion after an opening quotation mark, up to 50 tokens.
print(generate(text='老子曾经说过：“', max_len=50))

老子曾经说过:“大难不死必有后福”,是的,当灾难降临到自己身上时,一定会有贵人相助,而这种贵人相助可以让自己更快进入好的命运之中,获得最好的生活。一


In [13]:
# Same prompt as the previous cell, re-run on purpose: sampling is stochastic
# (no fixed seed), so the continuation differs from the one above.
print(generate(text='老子曾经说过：“', max_len=50))

老子曾经说过:“大道至简,知易行难。”要知道天下事复杂得多,在我们看来,要走成功之道,要从政的人多,不走成功之道的人少。古往今来无数聪明
