In [71]:
import numpy as np
from tqdm import tqdm
from scipy.special import softmax
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_all_providers
from gpt2_tokenizer import GPT2Tokenizer

In [73]:
def create_model_for_provider(model_path: str, provider: str = 'CPUExecutionProvider',
                              intra_op_num_threads: int = 4) -> InferenceSession:
    """Build an ONNX Runtime inference session pinned to a single execution provider.

    Args:
        model_path: path to the ``.onnx`` file to load.
        provider: execution provider name, e.g. ``'CPUExecutionProvider'``. It must be
            available in the installed onnxruntime build.
        intra_op_num_threads: thread count ORT may use inside a single operator.
            Defaults to 4, the value that used to be hard-coded here.

    Returns:
        An ``InferenceSession`` with all graph optimizations enabled and the
        provider fallback mechanism disabled.

    Raises:
        AssertionError: if ``provider`` is not available in this installation.
    """
    # get_all_providers() also lists providers this onnxruntime build was NOT compiled
    # with, so checking against it lets e.g. CUDAExecutionProvider through on a CPU-only
    # install. get_available_providers() reflects what can actually run here.
    from onnxruntime import get_available_providers

    available = get_available_providers()
    assert provider in available, f"provider {provider} not found, {available}"
    # Few properties that might have an impact on performance (provided by MS)
    options = SessionOptions()
    options.intra_op_num_threads = intra_op_num_threads
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    # Load the model as a graph and prepare the requested backend
    session = InferenceSession(model_path, options, providers=[provider])
    # Turn off ORT's run() fallback so failures surface instead of silently
    # switching to another provider.
    session.disable_fallback()
    return session

In [74]:
# Load the quantized CPM ONNX graph on the default (CPU) provider.
onnx_model_path = './onnx_q/cpm.onnx'
cpm = create_model_for_provider(onnx_model_path)

# BPE tokenizer: vocab + merges positionally, sentencepiece model by keyword.
bpe_vocab_and_merges = (
    'CPM-Generate/bpe_3w_new/vocab.json',
    'CPM-Generate/bpe_3w_new/merges.txt',
)
tokenizer = GPT2Tokenizer(
    *bpe_vocab_and_merges,
    model_file='CPM-Generate/bpe_3w_new/chinese_vocab.model',
)

In [94]:
def sample_next_token(next_logits, temperature=1.0, top_k=0, top_p=1.0, banned=(), rng=None):
    """Sample one token id from the logits of a single position.

    Steps, in order:
      1. banned ids get probability exactly 0;
      2. temperature scaling;
      3. top-k and nucleus (top-p) truncation of the sorted distribution;
      4. sampling from the renormalized surviving tokens.

    Args:
        next_logits: 1-D array of unnormalized scores indexed by token id.
        temperature: divisor for the logits; values <= 0 select the argmax (greedy).
        top_k: keep only the ``top_k`` most likely tokens; disabled when <= 0 or
            >= vocabulary size.
        top_p: keep the smallest set of most-likely tokens whose cumulative
            probability reaches ``top_p``; disabled unless 0 < top_p < 1.
        banned: token ids that must never be sampled.
        rng: ``np.random.Generator`` to draw from; a fresh unseeded one if None.

    Returns:
        The sampled token id as a Python ``int``.
    """
    if rng is None:
        rng = np.random.default_rng()
    logits = np.array(next_logits, dtype=np.float64)  # copy: never mutate the model output
    banned = list(banned)
    if banned:
        # -inf (not a large negative constant) guarantees exactly zero probability.
        logits[banned] = -np.inf
    if temperature <= 0:
        return int(np.argmax(logits))

    probs = softmax(logits / temperature, axis=-1)
    order = np.argsort(probs)[::-1]  # token ids, most likely first
    sorted_probs = probs[order]

    keep = len(order)
    if 0 < top_k < keep:
        keep = top_k
    if 0.0 < top_p < 1.0:
        # Nucleus: first index where the running total reaches top_p, +1 so that at
        # least one token always survives. Truncating to a fixed (1 - top_p) fraction
        # of the vocabulary would ignore the probabilities entirely and can even be empty.
        nucleus_size = int(np.searchsorted(np.cumsum(sorted_probs), top_p)) + 1
        keep = min(keep, nucleus_size)

    kept_probs = sorted_probs[:keep]
    return int(rng.choice(order[:keep], p=kept_probs / kept_probs.sum()))


# Alternative prompt (couplet completion): tokenizer.encode('上联：天下太平\n下联：')
ids = tokenizer.encode('励志金句：\n')
max_len = 50       # maximum number of new tokens
temperature = 1.0
top_p = 0.95
top_k = 50
ban = [
    8,  # a whitespace character
]
SEED = None  # set to an int for reproducible samples
rng = np.random.default_rng(SEED)

for _ in tqdm(range(max_len)):
    # Each step feeds the whole prefix (no cached past state is passed to the graph).
    logits = cpm.run(None, {
        "input_ids": np.array([ids], dtype=np.int32)
    })[0]
    # Assumes output is (batch, seq, vocab) with batch == 1, as the [:, -1, token_id]
    # indexing used for banning implies; only the last position is sampled.
    next_token = sample_next_token(
        logits[0, -1, :],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        banned=ban,
        rng=rng,
    )
    if tokenizer.eod_id == next_token:
        break
    ids.append(next_token)
print(tokenizer.decode(ids).replace(' ', ''))

100%|██████████| 50/50 [00:11<00:00,  4.25it/s]

励志金句:
我想到一句英语中的俗语来形容我和我的朋友的友谊(我们都是学英语的孩子所以用这个来形容对于学英语的孩子来说是再合适不过了吧hahiahia



