### 构建一个文本二分类模型，并将其转换为 ONNX 格式

In [43]:
import torch
from transformers import BertModel, BertTokenizer
from time import time
import onnxruntime
import numpy as np
import onnx

In [44]:
class BertTextClassification(torch.nn.Module):
    """BERT encoder followed by a single linear classification head.

    Args:
        model_path: Checkpoint name or path forwarded to
            ``BertModel.from_pretrained``.
        num_labels: Number of output classes. Defaults to 2, i.e. binary
            classification.

    Raises:
        ValueError: If ``num_labels`` is smaller than 1.
    """

    def __init__(self, model_path, num_labels=2):
        super().__init__()
        if num_labels < 1:
            raise ValueError(f'num_labels must be >= 1, got {num_labels}')
        self.bert = BertModel.from_pretrained(model_path)
        # The head is a freshly constructed Linear layer sized from the
        # encoder's configured hidden size; nothing is loaded into it here.
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Compute unnormalised class logits.

        The three arguments are passed to the encoder by keyword; the
        encoder's ``pooler_output`` is then fed to the linear head.

        Returns:
            The output of ``self.classifier`` (no softmax/sigmoid applied).
        """
        encoder_output = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        return self.classifier(encoder_output.pooler_output)
# Single source of truth for the checkpoint so the tokenizer and the encoder
# cannot drift apart.
MODEL_NAME = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
# Misspelled alias kept on purpose: later cells still refer to `tokenzier`.
tokenzier = tokenizer
# Only the encoder weights come from the checkpoint; the classification head
# is a newly constructed Linear layer.
model = BertTextClassification(MODEL_NAME)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [45]:
text = '我是中国人'
torch_inputs = tokenzier(text, return_tensors='pt')

# torch.onnx.export feeds `args` to forward() positionally, so list the inputs
# in forward()'s parameter order instead of relying on whatever order the
# tokenizer happens to use for its output dict.
input_names = ['input_ids', 'token_type_ids', 'attention_mask']

# Every input is dynamic on both batch and sequence length; the logits output
# is declared dynamic on the batch axis as well so the exported graph does not
# advertise a fixed batch size taken from the example input.
dynamic_axes = {name: {0: 'batch', 1: 'seq'} for name in input_names}
dynamic_axes['logits'] = {0: 'batch'}

# Export in eval mode and without autograd tracking.
model.eval()
with torch.no_grad():
    torch.onnx.export(
        model=model,
        args=tuple(torch_inputs[name] for name in input_names),
        f='./model.onnx',
        input_names=input_names,
        output_names=['logits'],
        dynamic_axes=dynamic_axes,
        opset_version=11,
        export_params=True)


In [47]:
# Name the execution provider explicitly: onnxruntime builds that ship several
# providers may refuse to create a session without one, and running on CPU
# keeps the comparison with the CPU PyTorch model like-for-like.
ort_session = onnxruntime.InferenceSession(
    './model.onnx',
    providers=['CPUExecutionProvider'],
)

In [48]:
# List the graph's input names followed by its output names, one per line.
for node in [*ort_session.get_inputs(), *ort_session.get_outputs()]:
    print(node.name)


input_ids
token_type_ids
attention_mask
logits


In [49]:
# Reference logits from the eager PyTorch model for the same tokenized text;
# the bare name on the last line displays them as the cell output.
torch_logits = model(**torch_inputs)
torch_logits

tensor([[-0.2583,  0.2475]], grad_fn=<AddmmBackward0>)

In [50]:
# Build the onnxruntime feed dict by looking up each graph input by name in
# the tokenizer output and converting the tensor to a NumPy array.
ort_inputs = {
    node.name: torch_inputs[node.name].detach().numpy()
    for node in ort_session.get_inputs()
}
# Passing None as the first argument requests every graph output.
ort_outputs = ort_session.run(None, ort_inputs)

# Check parity with the eager PyTorch model numerically instead of by eye.
with torch.no_grad():
    expected_logits = model(**torch_inputs).numpy()
np.testing.assert_allclose(ort_outputs[0], expected_logits, rtol=1e-3, atol=1e-5)

ort_outputs

[array([[-0.25828335,  0.2475062 ]], dtype=float32)]