In [1]:
import numpy as np
from scipy.special import softmax
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_all_providers
from tokenization_enc_dec import EncDecTokenizer

In [2]:
def create_model_for_provider(model_path: str, provider: str= 'CPUExecutionProvider') -> InferenceSession:
    assert provider in get_all_providers(), f"provider {provider} not found, {get_all_providers()}"
    # Few properties that might have an impact on performances (provided by MS)
    options = SessionOptions()
    options.intra_op_num_threads = 100
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    # Load the model as a graph and prepare the CPU backend
    session = InferenceSession(model_path, options, providers=[provider])
    session.disable_fallback()
    return session

In [29]:
encoder = create_model_for_provider('./onnx_eva_q/encoder.onnx')
decoder = create_model_for_provider('./onnx_eva_q/decoder.onnx')
lm = create_model_for_provider('./onnx_eva_q/lm.onnx')
tokenizer = EncDecTokenizer('./EVA/src/bpe_dialog_new/vocab.txt')

In [49]:
default_ban = [
    5641,  # 京东
    4087,  # 客服
    2184,  # #
    175,   # [
    12539, # 小妹
    724,   # 客户
    1468,  # 商品
    6111,  # vi
    6454,  # 订单
    3748,  # 商家
    1548,  # 咨询
    6391,  # 发票
    681,   # 单
    5942,  # 上门
    4129,  # 售后
    6756,  # 卖家
]

In [50]:
def talk(
    s=['你好'],
    num_returns=1,
    top_k=50,
    top_p=1.0,
    temperature=1.0,
    max_len=64,
    ban=default_ban
):
    input_ids = []
    for ss in s:
        input_ids += tokenizer.encode(ss) + [tokenizer.sep_id]
    input_ids += [tokenizer.get_sentinel_id(0)]
    input_ids = np.array([input_ids])
    mask = np.ones_like(input_ids)
    encoder_last_hidden_state = encoder.run(['last_hidden_state'], {
        "input_ids": input_ids, "attention_mask": mask
    })[0]
    encoder_last_hidden_state = np.repeat(encoder_last_hidden_state, num_returns, axis=0)
    mask = np.repeat(mask, num_returns, axis=0)
    decoder_input_ids = np.repeat(np.array([[tokenizer.get_sentinel_id(0)]]), num_returns, axis=0)
    decoder_mask = np.ones_like(decoder_input_ids)
    choice_inds = list(range(top_k))
    all_probs = None
    finished = []
    for i in range(max_len):
        decoder_last_hidden_state = decoder.run(['last_hidden_state'], {
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_mask,
            'encoder_hidden_states': encoder_last_hidden_state,
            'encoder_attention_mask': mask,
        })[0]
        logits = lm.run(['logits'], { 'decoder_hidden_states': decoder_last_hidden_state, })[0]

        for x in ban:
            logits[:, -1, x] = -9999

        logits = logits / temperature
        scores = softmax(logits[:, -1, :])
        next_probs = np.sort(scores)[:, ::-1]
        if top_p > 0.0 and top_p < 1.0:
            next_probs = next_probs[:, :int(next_probs.shape[1] * (1 - top_p))]
        if top_k > 0 and top_k < next_probs.shape[1]:
            next_probs = next_probs[:, :top_k]
        next_probs_1 = next_probs / next_probs.sum(axis=1).reshape((-1, 1))

        next_tokens = np.argsort(scores)[:, ::-1]
        if top_p > 0.0 and top_p < 1.0:
            next_tokens = next_tokens[:, :int(next_tokens.shape[1] * (1 - top_p))]
        if top_k > 0 and top_k < next_probs.shape[1]:
            next_tokens = next_tokens[:, :top_k]

        inds = np.random.choice(
            choice_inds,
            num_returns,
            p=next_probs_1[0],
            replace=True
        )
        next_tokens = np.array([
            x[i]
            for x, i in zip(next_tokens, inds)
        ])
        next_probs = np.array([
            x[i]
            for x, i in zip(next_probs, inds)
        ])
        if all_probs is None:
            all_probs = np.log(next_probs).reshape(-1, 1)
        else:
            all_probs = np.concatenate([
                all_probs,
                np.log(next_probs).reshape(-1, 1)
            ], axis=1)
        decoder_input_ids = np.concatenate([
            decoder_input_ids,
            np.array(next_tokens).reshape((-1, 1))
        ], axis=1)
        for i in reversed(range(decoder_input_ids.shape[0])):
            if tokenizer.sep_id in decoder_input_ids[i]:
                finished.append((decoder_input_ids[i], all_probs[i]))
                decoder_input_ids = np.concatenate([
                    decoder_input_ids[:i],
                    decoder_input_ids[i + 1:]
                ])
                all_probs = np.concatenate([
                    all_probs[:i],
                    all_probs[i + 1:]
                ])
                encoder_last_hidden_state = np.concatenate([
                    encoder_last_hidden_state[:i],
                    encoder_last_hidden_state[i + 1:]
                ])
                mask = np.concatenate([
                    mask[:i],
                    mask[i + 1:]
                ])
        if len(decoder_input_ids) == 0:
            break
        decoder_mask = np.ones_like(decoder_input_ids)
    rets = [
        tokenizer.decode(x[0])
        for x in finished
    ]
    scores = [
        np.mean(x[1])
        for x in finished
    ]
    return rets, scores

In [51]:
%%time
finished = talk(max_len=64, num_returns=10)

CPU times: user 6min 35s, sys: 14.7 s, total: 6min 49s
Wall time: 6.1 s


In [52]:
finished[0]

['<s_0>请问您是要修改什么信息呢<sep>',
 '<s_0>请问您是要修改什么信息呢<sep>',
 '<s_0>有什么可以帮到您的吗?<sep>',
 '<s_0>您好,请问有什么可以帮您?<sep>',
 '<s_0>您好,请问有什么可以帮您?<sep>',
 '<s_0>您好,请问有什么可以帮您?<sep>',
 '<s_0>您好,请问有什么可以帮您?<sep>',
 '<s_0>请问有什么问题我可以帮您处理或解决呢?<sep>',
 '<s_0>请问有什么问题我可以帮您处理或解决呢?你好<sep>',
 '<s_0>您好,请您提供一下您的姓名和联系方式,我们会在今天下午(周日)下午(周六)为您回电,请您注意接听电话哈<sep>']

In [53]:
finished[1]

[-7.2126923,
 -7.2126923,
 -5.9243674,
 -4.9900384,
 -4.9900384,
 -4.9900384,
 -4.9900384,
 -2.449787,
 -3.0436127,
 -2.7436812]

In [56]:
finished[0][np.argmax(finished[1])]

'<s_0>请问有什么问题我可以帮您处理或解决呢?<sep>'