In [1]:
import numpy as np
from scipy.special import softmax
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_all_providers
from tokenization_enc_dec import EncDecTokenizer

In [2]:
def create_model_for_provider(
    model_path: str,
    provider: str = 'CPUExecutionProvider',
    intra_op_num_threads: int = 100,
) -> InferenceSession:
    """Create an ONNX Runtime inference session bound to one execution provider.

    Args:
        model_path: Path of the ``.onnx`` file to load.
        provider: Name of the execution provider to run on. It must be one of
            the names returned by ``get_all_providers()``.
        intra_op_num_threads: Value assigned to
            ``SessionOptions.intra_op_num_threads``. The default (100) is the
            value this notebook previously hard-coded; lower it on machines
            with fewer cores.

    Returns:
        The configured ``InferenceSession`` with provider fallback disabled.

    Raises:
        AssertionError: If ``provider`` is not in ``get_all_providers()``.
    """
    # NOTE(review): get_all_providers() presumably lists every provider name
    # ONNX Runtime knows about, not only those usable in this build - confirm
    # whether get_available_providers() is the stricter check wanted here.
    known_providers = get_all_providers()
    assert provider in known_providers, f"provider {provider} not found, {known_providers}"

    # Session-level settings that may affect performance.
    options = SessionOptions()
    options.intra_op_num_threads = intra_op_num_threads
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    # Load the model graph on the requested provider only, and turn fallback
    # off so the session does not switch to a different provider.
    session = InferenceSession(model_path, options, providers=[provider])
    session.disable_fallback()
    return session

In [3]:
# Load the three exported ONNX graphs (encoder, decoder, LM head) in that
# order, each through create_model_for_provider with its default provider.
encoder, decoder, lm = (
    create_model_for_provider(onnx_path)
    for onnx_path in (
        './onnx_eva_q/encoder.onnx',
        './onnx_eva_q/decoder.onnx',
        './onnx_eva_q/lm.onnx',
    )
)

In [4]:
# Tokenizer built from the vocabulary file; `talk` uses it to encode the
# dialogue context, to look up the separator / sentinel ids, and to decode
# the generated ids back to text.
tokenizer = EncDecTokenizer('./EVA/src/bpe_dialog_new/vocab.txt')

In [99]:
def talk(s=['你好'], num_returns=10, top_k=50, max_new_tokens=64):
    """Sample ``num_returns`` replies to a dialogue context with top-k sampling.

    Relies on the module-level ``encoder``, ``decoder`` and ``lm`` ONNX
    sessions and on ``tokenizer``.

    Args:
        s: Sequence of context utterances (strings), oldest first. Each one is
            encoded and followed by ``tokenizer.sep_id``. The default list is
            never mutated.
        num_returns: Number of candidate replies sampled in parallel.
        top_k: At every step only the ``top_k`` most probable tokens of each
            row are candidates for sampling.
        max_new_tokens: Upper bound on the number of decoding steps
            (previously a hard-coded 64).

    Returns:
        ``(rets, final_scores)``. ``rets`` is a list of ``num_returns``
        decoded replies, each truncated before its first separator token.
        ``final_scores`` is an array with, per reply, the sum of the log
        probabilities of the kept tokens.
    """
    # Token ids that are never allowed to be sampled. The glosses are English
    # translations of the original per-id annotations.
    # NOTE(review): presumably this steers replies away from e-commerce
    # customer-service phrasing - confirm the intent.
    ban = [
        5641,  # "JD" (JD.com brand name)
        4087,  # "customer service"
        2184,  # "#"
        175,   # "["
        12539, # "little sister" (xiaomei)
        724,   # "customer" / "client"
        1468,  # "product" / "goods"
        6111,  # "vi"
        6454,  # "order" (purchase order)
        3748,  # "merchant"
        1548,  # "inquiry" / "consult"
        6391,  # "invoice"
        681,   # "dan" (order / slip)
        5942,  # "door-to-door" / "home visit"
    ]
    sep_id = tokenizer.sep_id
    start_id = tokenizer.get_sentinel_id(0)

    # Encoder input: every utterance followed by the separator, then sentinel 0.
    input_ids = []
    for utterance in s:
        input_ids += tokenizer.encode(utterance) + [sep_id]
    input_ids += [start_id]
    input_ids = np.array([input_ids])
    mask = np.ones_like(input_ids)

    # The context is encoded once and then tiled so that all candidates are
    # decoded as one batch.
    encoder_last_hidden_state = encoder.run(['last_hidden_state'], {
        "input_ids": input_ids, "attention_mask": mask
    })[0]
    encoder_last_hidden_state = np.repeat(encoder_last_hidden_state, num_returns, axis=0)
    mask = np.repeat(mask, num_returns, axis=0)

    decoder_input_ids = np.repeat(np.array([[start_id]]), num_returns, axis=0)
    # NOTE(review): this mask keeps shape (num_returns, 1) while
    # decoder_input_ids grows by one column per step. The exported graph
    # appears to accept that - confirm whether the mask should grow as well.
    decoder_mask = np.ones_like(decoder_input_ids)

    row_index = np.arange(num_returns)
    step_log_probs = []
    for _ in range(max_new_tokens):
        decoder_last_hidden_state = decoder.run(['last_hidden_state'], {
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_mask,
            'encoder_hidden_states': encoder_last_hidden_state,
            'encoder_attention_mask': mask,
        })[0]
        logits = lm.run(['logits'], {'decoder_hidden_states': decoder_last_hidden_state})[0]

        # Only the last position predicts the next token. Work on a copy so
        # the session's output buffer is left untouched.
        last_logits = logits[:, -1, :].copy()
        last_logits[:, ban] = -9999

        # axis=-1 is essential: scipy's default (axis=None) would normalise
        # over the whole batch and shrink every probability as num_returns
        # grows.
        scores = softmax(last_logits, axis=-1)

        # Top-k candidates per row, most probable first.
        top_tokens = np.argsort(scores, axis=-1)[:, ::-1][:, :top_k]
        top_probs = np.take_along_axis(scores, top_tokens, axis=-1)

        # Renormalise inside the top-k set (in float64 so the probabilities
        # sum to 1 within np.random.choice's tolerance) and sample every row
        # from its OWN distribution.
        sampling_probs = top_probs.astype(np.float64)
        sampling_probs /= sampling_probs.sum(axis=1, keepdims=True)
        picks = np.array([
            np.random.choice(len(row_probs), p=row_probs)
            for row_probs in sampling_probs
        ], dtype=np.int64)

        next_tokens = top_tokens[row_index, picks]
        # The score uses the full-vocabulary probability, not the renormalised one.
        step_log_probs.append(np.log(top_probs[row_index, picks]))

        decoder_input_ids = np.concatenate([
            decoder_input_ids,
            next_tokens.reshape((num_returns, 1))
        ], axis=1)
        # Stop as soon as every candidate has produced a separator.
        if np.all(np.any(decoder_input_ids == sep_id, axis=1)):
            break

    # (steps, num_returns) -> (num_returns, steps), aligned with the generated ids.
    all_probs = np.array(step_log_probs).transpose()
    decoder_input_ids = decoder_input_ids[:, 1:]  # drop the start sentinel

    # First separator per row; a row that never emitted one (length cap hit)
    # is kept whole instead of being erased.
    # The original hard-coded 4 here; tokenizer.sep_id is used for consistency
    # with the stop condition above - presumably sep_id == 4, confirm.
    is_sep = decoder_input_ids == sep_id
    end_positions = np.where(
        is_sep.any(axis=1), is_sep.argmax(axis=1), decoder_input_ids.shape[1]
    )

    rets = []
    for row, end in enumerate(end_positions):
        # Zero the separator and everything sampled after it.
        decoder_input_ids[row, end:] = 0
        rets.append(tokenizer.decode(decoder_input_ids[row, :end]))

    # Sum log-probabilities over kept tokens only. np.where (rather than a
    # multiply) avoids NaN if a masked position holds -inf.
    final_scores = np.sum(np.where(decoder_input_ids > 0, all_probs, 0.0), axis=1)
    return rets, final_scores

In [101]:
%%time
for i in range(10):
    r = talk(['你好啊'], 1)
    print(r[0])

['您好,请问有什么可以帮助您的吗?']
['请问您是要换绑手机号还是签收信息']
['我在吃饭']
['亲爱的,请问有什么问题我可以帮您处理或解决呢?您好']
['您好,您可以看下您的地址和电话吗?']
['请问您是要问问题什么吗']
['您好,亲爱的有什么可以为您效劳的呢']
['好久没有跟大家见面,今天跟大家一起分享我这段时间的所感所想。']
['亲亲您好,有什么问题我可以帮您处理或解决呢?']
['您好,请问有什么可以帮助您?']
CPU times: user 4min 31s, sys: 73.6 ms, total: 4min 31s
Wall time: 6.71 s
