In [1]:
import numpy as np
from tqdm import tqdm
from scipy.special import softmax
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_all_providers
from tokenization_jieba import JIEBATokenizer

In [2]:
def create_model_for_provider(
    model_path: str,
    provider: str = 'CPUExecutionProvider',
    intra_op_num_threads: int = 4,
) -> InferenceSession:
    """Build an ONNX Runtime inference session restricted to one execution provider.

    Args:
        model_path: path to the ``.onnx`` model file to load.
        provider: execution-provider name; must appear in ``get_all_providers()``.
        intra_op_num_threads: number of threads ORT may use inside a single
            operator (formerly hard-coded; 4 remains the default).

    Returns:
        An ``InferenceSession`` created with ``ORT_ENABLE_ALL`` graph optimisation,
        ``providers=[provider]``, and ``disable_fallback()`` applied.

    Raises:
        AssertionError: if ``provider`` is not a known provider name.
    """
    known_providers = get_all_providers()
    # NOTE: `assert` is stripped under `python -O`; kept so callers still get AssertionError.
    # get_all_providers() lists every provider name ORT knows about, not just those usable
    # in this build; get_available_providers() would be the stricter check.
    assert provider in known_providers, f"provider {provider} not found, {known_providers}"

    # Session settings that can affect performance (suggested by Microsoft).
    options = SessionOptions()
    options.intra_op_num_threads = intra_op_num_threads
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    # Load the model graph on the requested provider only (not necessarily CPU).
    session = InferenceSession(model_path, options, providers=[provider])
    session.disable_fallback()
    return session

In [3]:
# Directory holding the tokenizer's vocabulary files inside the PanGu-Alpha-GPU source
# tree; the path is relative to the notebook's working directory.
TOKENIZER_DIR = 'PanGu-Alpha-GPU/panguAlpha_pytorch/megatron/tokenizer/bpe_4w_pcl'

tokenizer = JIEBATokenizer(
    f'{TOKENIZER_DIR}/vocab.vocab',
    f'{TOKENIZER_DIR}/vocab.model')

In [4]:
# Two exports of the model, loaded in this order:
#   pangu    - called by `generate` with `input_ids` only, for the prompt pass;
#   pangu_kv - called by `generate` with `input_ids` plus `kv_cache`, one token at a time.
# The `_q` directory suffix presumably denotes quantized exports — TODO confirm.
pangu, pangu_kv = (
    create_model_for_provider(onnx_path)
    for onnx_path in ('./onnx_q/pangu.onnx', './onnx_kv_q/pangu.onnx')
)

In [5]:
# %%time
# ids = tokenizer.encode('什么鬼')

# input_ids = np.array([ids], dtype=np.int64)

# logits, new_kv = pangu.run(None, {
#     "input_ids": input_ids,
# })
# kv_cache = new_kv

# next_token = logits[0, -1, :].argmax()

# outputs = [next_token]
# for i in range(10):
#     input_ids = np.array([[next_token]], dtype=np.int64)
#     logits, new_kv = pangu_kv.run(None, {
#         "input_ids": np.array([[next_token]], dtype=np.int64),
#         'kv_cache': kv_cache,
#     })
#     next_token = logits[0, -1, :].argmax()
#     kv_cache = np.concatenate([kv_cache, new_kv], axis=-2)
#     # kv_cache = kv_cache[:, :, :, :, -50:, :]
#     outputs.append(next_token)
# print(len(outputs), tokenizer.decode([int(x) for x in outputs]))

In [15]:
def _sample_next_token(last_logits, temperature=1.0, top_p=0.95, top_k=50, ban=(), rng=None):
    """Draw one token id from the next-token logits of a single position.

    Args:
        last_logits: 1-D array of logits indexed by token id.
        temperature: logits are divided by this before the softmax. Values <= 0
            select the arg-max token deterministically (greedy) instead of
            dividing by zero.
        top_p: nucleus threshold. When 0 < top_p < 1, keep only the smallest set
            of most-probable tokens whose cumulative probability reaches top_p
            (at least one token always survives).
        top_k: when 0 < top_k < number of surviving candidates, keep only the
            top_k most probable of them.
        ban: token ids that must never be produced. They are masked to -inf
            after temperature scaling, so no temperature can revive them.
        rng: optional numpy ``Generator``/``RandomState``; ``None`` uses the
            global ``np.random`` state.

    Returns:
        The sampled token id as a Python ``int``.
    """
    # Work on a float64 copy: never mutate the caller's array, and keep enough
    # precision for `choice`'s check that the probabilities sum to 1.
    scaled = np.array(last_logits, dtype=np.float64)

    if temperature <= 0:
        for token_id in ban:
            scaled[token_id] = -np.inf
        return int(np.argmax(scaled))

    scaled /= temperature
    for token_id in ban:
        scaled[token_id] = -np.inf
    probs = softmax(scaled)

    order = np.argsort(probs)[::-1]  # token ids, most probable first
    sorted_probs = probs[order]

    keep = order.size
    if 0.0 < top_p < 1.0:
        # First index where the running total reaches top_p, +1 to make it a count.
        # min() guards against float round-off leaving the total just below top_p.
        cumulative = np.cumsum(sorted_probs)
        keep = min(int(np.searchsorted(cumulative, top_p)) + 1, order.size)
    if 0 < top_k < keep:
        keep = top_k

    weights = sorted_probs[:keep] / sorted_probs[:keep].sum()
    sampler = np.random if rng is None else rng
    return int(sampler.choice(order[:keep], p=weights))


def generate(
    text,
    max_len=100,
    temperature=1.0,
    top_p=0.95,
    top_k=50,
    eod=(tokenizer.eod_id, tokenizer.eot_id),
    ban=(
        8,  # token id 8: a whitespace character
    ),
    rng=None):
    """Continue ``text`` with up to ``max_len`` sampled tokens and return the decoded string.

    The prompt is run once through ``pangu``, which returns logits and an initial
    ``kv_cache``. Every later step feeds only the last sampled token plus the
    accumulated cache to ``pangu_kv`` and appends the returned ``new_kv`` to the cache.

    Args:
        text: prompt string, encoded with the module-level ``tokenizer``.
        max_len: maximum number of new tokens to generate.
        temperature, top_p, top_k, ban: sampling controls; see ``_sample_next_token``.
        eod: token ids that stop generation (the stop token itself is not
            appended). ``None`` disables early stopping. The default is read from
            ``tokenizer`` once, when this cell is executed.
        rng: optional numpy ``Generator``/``RandomState`` for reproducible
            sampling. ``None`` uses the global ``np.random`` state.

    Returns:
        The decoded prompt plus continuation, with every ASCII space removed
        (presumably because ``tokenizer.decode`` joins pieces with spaces — confirm).
    """
    ids = list(tokenizer.encode(text))
    kv_cache = None
    next_token = None

    for step in range(max_len):
        if step == 0:
            # Prompt pass: the whole prompt in one call; also yields the initial cache.
            logits, kv_cache = pangu.run(None, {
                "input_ids": np.array([ids], dtype=np.int64),
            })
        else:
            # Incremental pass: only the newest token plus the cache so far.
            logits, new_kv = pangu_kv.run(None, {
                "input_ids": np.array([[next_token]], dtype=np.int64),
                'kv_cache': kv_cache,
            })
            # NOTE(review): axis -2 is presumably the cache's sequence axis — confirm.
            kv_cache = np.concatenate([kv_cache, new_kv], axis=-2)

        # logits are indexed [batch, position, token id]; take batch 0, last position.
        next_token = _sample_next_token(
            logits[0, -1, :], temperature, top_p, top_k, ban, rng)
        if eod is not None and next_token in eod:
            break
        ids.append(next_token)

    return tokenizer.decode([int(x) for x in ids]).replace(' ', '')

In [22]:
# Recipe continuation with a longer generation budget.
print(generate(text='西红柿炒鸡蛋的做法：\n', max_len=200))

西红柿炒鸡蛋的做法:
1.鸡蛋打入碗中打散,加入1/2小勺盐和3克料酒搅拌均匀备用。
2.锅中倒入少许油,6成热后放入打散的鸡蛋炒熟出锅,盛出备用。
3.锅中留底油,7成热放入鸡尾,用锅内的余温将鸡尾炸制微黄捞出。(如图)
4.锅中留少量底油,8成热倒入西红柿块,用筷子将西红柿炒熟至软烂出锅。
5.锅中留底油,9成热放入葱姜煸出香味。
6.倒入处理好的鸡尾,用锅内的余温将鸡尾炸制微黄捞出。
7.锅中留底油,10成热倒入葱姜煸出香味。
8.倒入处理好的西


In [24]:
# Couplet completion: the prompt gives the first line (上联) and asks for the second (下联).
print(generate(text='上联：天地在我心中\n下联：', max_len=50))

上联:天地在我心中
下联:地水风山花为证。


In [45]:
# Few-shot arithmetic; max_len=1 samples at most one new token.
print(generate(text='1+1=2;3+5=8;2+4=', max_len=1))

1+1=2;3+5=8;2+4=?


In [59]:
# Classical-poem recall: continue a well-known line after a one-couplet example.
print(generate(text='默写古诗：\n白日依山尽，黄河入海流。\n床前明月光，', max_len=5))

默写古诗:
白日依山尽,黄河入海流。
床前明月光,疑是地上霜。


In [60]:
# Dialogue continuation between two named characters.
print(generate(text='李大嘴：“各回各家，各找各妈！” \n佟掌柜：', max_len=20))

李大嘴:“各回各家,各找各妈!”
佟掌柜:“你就是这样的人!”
李大嘴又笑着说:“我不


In [61]:
# Few-shot factual completion (capital cities) using the default max_len.
print(generate(text='中国的首都是北京\n日本的首都是东京\n美国的首都是'))

中国的首都是北京
日本的首都是东京
美国的首都是华盛顿


In [65]:
# Open question answering.
print(generate(text='中国的四大发明有哪些？', max_len=50))

中国的四大发明有哪些?
四大发明指我国古代四大发明(造纸术、指南针、火药、印刷术)。


In [74]:
# Quote completion: the prompt ends with an opening quotation mark.
print(generate(text='乔布斯曾经说过：“', max_len=50))

乔布斯曾经说过:“我的理想是创造一个伟大的公司,能够创造出具有影响力的产品。”这的确是在乔布斯的理想中实现的理想,而这个伟大的产品要靠怎样的团队领导才能领导,就靠乔布斯


In [76]:
# Quote completion with a different speaker.
print(generate(text='老子曾经说过：“', max_len=50))

老子曾经说过:“士农工商,天下之通义也”。在这些传统的经济模式中,商人最初被看成是商人的同义词,而“商”字和“人”字在历史上都曾被用来


In [77]:
# Same prompt again: sampling is stochastic, so the continuation differs between runs.
print(generate(text='老子曾经说过：“', max_len=50))

老子曾经说过:“天下大道,其犹张弓与?高者抑之,下者举之。天下莫柔弱于水,而攻坚强者莫之能胜。”如果你想让自己的生命变得强大,那么请在你的日常
