In [1]:
from transformers import BertForSequenceClassification
from transformers import BertTokenizer
import torch

In [2]:
# Model card: https://huggingface.co/IDEA-CCNL/Erlangshen-Roberta-330M-Similarity
CHECKPOINT = 'IDEA-CCNL/Erlangshen-Roberta-330M-Similarity'

# The tokenizer and the sequence-classification model are both loaded from
# the same checkpoint id.
tokenizer = BertTokenizer.from_pretrained(CHECKPOINT)
model = BertForSequenceClassification.from_pretrained(CHECKPOINT)

Downloading:   0%|          | 0.00/107k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/752 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.21G [00:00<?, ?B/s]

In [5]:
# Put the module in evaluation mode before it is traced/exported and used for
# scoring. nn.Module.eval() returns the module itself, so rebinding `model`
# keeps the same object.
model = model.eval()

In [76]:
# Dummy batch of token ids (1 sequence x 10 tokens) used only to trace the
# graph; the concrete values are arbitrary because both axes are declared
# dynamic below.
x = torch.randint(0, 10, (1, 10), dtype=torch.int64)

# Sanity forward pass. No gradients are needed here, so skip autograd tracking.
with torch.no_grad():
    torch_out = model(x)

# Export the model
torch.onnx.export(
    model,                     # model being run
    x,                         # model input (or a tuple for multiple inputs)
    "roberta_sts_300m.onnx",   # where to save the model (can be a file or file-like object)
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=12,          # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input'],     # the model's input names
    output_names=['output'],   # the model's output names
    dynamic_axes={
        'input': {
            0: 'batch_size',
            1: 'sequence_length',
        },
        # The output batch dimension follows the input batch dimension, so it
        # is named as a dynamic axis too instead of being left implicit.
        'output': {
            0: 'batch_size',
        },
    },
)

In [77]:
# Shell out to `du` to show the on-disk size of the exported ONNX file.
!du -sh roberta_sts_300m.onnx

1.3G	roberta_sts_300m.onnx


In [90]:
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

In [91]:
# Paths: the FP32 ONNX export to read and the quantized file to write.
model_fp32 = 'roberta_sts_300m.onnx'
model_quant = 'roberta_sts_300m.quant.onnx'

# Dynamic quantization with the library's default settings.
# NOTE(review): quantize_dynamic appears to write its result to `model_quant`
# and return None, so `quantized_model` is probably not a usable model
# object - confirm before relying on it.
quantized_model = quantize_dynamic(
    model_input=model_fp32,
    model_output=model_quant,
)

Ignore MatMul due to non constant B: /[MatMul_94]
Ignore MatMul due to non constant B: /[MatMul_99]
Ignore MatMul due to non constant B: /[MatMul_188]
Ignore MatMul due to non constant B: /[MatMul_193]
Ignore MatMul due to non constant B: /[MatMul_282]
Ignore MatMul due to non constant B: /[MatMul_287]
Ignore MatMul due to non constant B: /[MatMul_376]
Ignore MatMul due to non constant B: /[MatMul_381]
Ignore MatMul due to non constant B: /[MatMul_470]
Ignore MatMul due to non constant B: /[MatMul_475]
Ignore MatMul due to non constant B: /[MatMul_564]
Ignore MatMul due to non constant B: /[MatMul_569]
Ignore MatMul due to non constant B: /[MatMul_658]
Ignore MatMul due to non constant B: /[MatMul_663]
Ignore MatMul due to non constant B: /[MatMul_752]
Ignore MatMul due to non constant B: /[MatMul_757]
Ignore MatMul due to non constant B: /[MatMul_846]
Ignore MatMul due to non constant B: /[MatMul_851]
Ignore MatMul due to non constant B: /[MatMul_940]
Ignore MatMul due to non constant

In [92]:
# Shell out to `du` to show the on-disk size of the quantized model; IPython
# substitutes the Python variable `model_quant` into the command line.
!du -sh '{model_quant}'

312M	roberta_sts_300m.quant.onnx


In [89]:
# !pip install onnx

In [88]:
# !pip install https://files.pythonhosted.org/packages/1a/6b/db83264475b60809cf17647ad3a7fcb5a3b94e233eca2f403e82fa5b5861/ort_nightly-1.11.0.dev20220320001-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

In [50]:
def sim(texta, textb):
    """Score a pair of texts with the notebook-level similarity classifier.

    Uses the module-level ``tokenizer`` and ``model`` objects.

    Args:
        texta: First text of the pair.
        textb: Second text of the pair.

    Returns:
        A 4-tuple ``(logits, softmax, score, label)`` where ``logits`` and
        ``softmax`` are NumPy arrays for the single encoded pair, ``score`` is
        ``softmax[1]`` rounded to 4 decimals, and ``label`` is '相似'
        ("similar") when index 1 holds the largest probability, otherwise
        '不同' ("different").
    """
    # Encode both texts as one sequence and wrap it in a batch of size 1.
    input_ids = torch.tensor([tokenizer.encode(texta, textb)])

    # Inference only: disabling autograd avoids building a graph and makes
    # the .detach() calls unnecessary.
    with torch.no_grad():
        pair_logits = model(input_ids).logits[0]
        pair_probs = torch.nn.functional.softmax(pair_logits, dim=-1)

    logits = pair_logits.numpy()
    softmax = pair_probs.numpy()
    label = '相似' if softmax.argmax(-1) == 1 else '不同'
    return logits, softmax, round(softmax[1], 4), label

In [44]:
# Quick check on a single pair; the tuple returned by sim() is the cell output.
texta, textb = '民航', '客机'
sim(texta, textb)

(array([ 1.9031352, -1.5363016], dtype=float32),
 array([0.9689145 , 0.03108544], dtype=float32),
 '不同')

In [65]:
# Query text that every candidate below is scored against.
text = '群聊无法同步'
# NOTE(review): the three rows below look like pasted retrieval results
# (score | title | answer | timestamps) - confirm where they came from.
# 0.91188097 | Auto-reply feature does not take effect | Check that the following settings are correct: confirm the reply settings (whether the hosted WeChat account must be @-mentioned; the interval for identical messages from the same user/group chat); confirm whether an effective scope is set and, if the default effective scope is used, whether it has been saved; confirm how many times the reply limit may trigger (unlimited or only once); confirm whether the corresponding hosted WeChat account has rest mode enabled | 2022-06-23T16:31:25 | 2022-06-23T16:31:25
# 0.79395235 | Auto-kick did not take effect | Auto-kick is only supported in group chats where the hosted WeChat account is the group owner | 2022-06-23T16:31:21 | 2022-06-23T16:31:21
# 0.77434516 | Mini program sent via quick reply is abnormal or cannot be opened
# Candidate texts; the last three entries repeat the titles of the rows above.
a = [
    '同步不到群聊数据',
    '视频发送失败',
    '自动回复功能不生效',
    '自动踢人没有生效',
    '通过秒回发送出去的小程序异常或无法打开',
]


In [66]:
# Score the query against each candidate and print one result tuple per line.
for candidate in a:
    result = sim(text, candidate)
    print(result)

(array([-1.2135514,  1.4098148], dtype=float32), array([0.06764966, 0.9323504 ], dtype=float32), 0.9324, '相似')
(array([ 4.6886926, -4.6623383], dtype=float32), array([9.9991310e-01, 8.6868306e-05], dtype=float32), 1e-04, '不同')
(array([ 4.746388, -4.763688], dtype=float32), array([9.9992585e-01, 7.4095879e-05], dtype=float32), 1e-04, '不同')
(array([ 4.75281 , -4.810149], dtype=float32), array([9.9992967e-01, 7.0279530e-05], dtype=float32), 1e-04, '不同')
(array([ 4.760275, -4.803789], dtype=float32), array([9.999298e-01, 7.020196e-05], dtype=float32), 1e-04, '不同')
