# Sentence Transformer ONNX model export

This notebook demonstrates how to export a sentence-transformers model to ONNX format and how to quantize the exported model.

See https://huggingface.co/sentence-transformers/msmarco-MiniLM-L-6-v3
    

In [None]:
!pip install transformers torch onnx onnxruntime 

In [None]:
from transformers import AutoModel, AutoTokenizer, BertTokenizer, BertPreTrainedModel, BertModel
import transformers
import torch 
from pathlib import Path

We create a wrapper model so that mean pooling over the token embeddings is computed inside the exported ONNX graph. Almost all sentence-transformer models use mean pooling.


In [None]:
class MeanPoolingEncoderONNX(BertPreTrainedModel):
    """BERT encoder whose forward pass returns mean-pooled embeddings.

    The pooling is done inside ``forward`` so that it becomes part of the
    model itself rather than a separate post-processing step.
    """

    def __init__(self, config):
        """Create the wrapped ``BertModel`` from ``config`` and call ``init_weights``."""
        super().__init__(config)
        self.bert = BertModel(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Encode the inputs and average token embeddings using the mask as weights.

        Args:
            input_ids: Token ids passed to the wrapped BERT model.
            attention_mask: Mask passed to BERT and reused as the pooling
                weights (presumably 1 for real tokens, 0 for padding).
            token_type_ids: Accepted but unused; it is not forwarded to BERT.

        Returns:
            The mask-weighted sum of token embeddings over dimension 1,
            divided by the mask sum (clamped to a minimum of 1e-9).
        """
        hidden_states = self.bert(input_ids, attention_mask=attention_mask)[0]
        # Broadcast the mask to the shape of the hidden states, as floats.
        mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        masked_total = (hidden_states * mask).sum(1)
        # Clamp so the division is safe even if a row's mask sums to zero.
        token_count = mask.sum(1).clamp(min=1e-9)
        return masked_total / token_count

In [None]:
# The same checkpoint id is used for both the model weights and the tokenizer.
MODEL_ID = "sentence-transformers/msmarco-MiniLM-L-6-v3"
encoder = MeanPoolingEncoderONNX.from_pretrained(MODEL_ID)
tokenizer = BertTokenizer.from_pretrained(MODEL_ID)

In [None]:
# Switch the module to evaluation mode before it is wrapped and exported.
encoder = encoder.eval()

In [None]:
# Bundle the model and tokenizer into a Pipeline object; this is the object
# handed to convert_pytorch in the export cell.
# NOTE(review): some transformers releases may not allow instantiating the
# base Pipeline class directly — confirm against the installed version.
pipeline = transformers.Pipeline(model=encoder, tokenizer=tokenizer)

In [None]:
import transformers.convert_graph_to_onnx as onnx_convert

# Export the pipeline's model to "sentence-encoder.onnx" using ONNX opset 11.
onnx_convert.convert_pytorch(
    pipeline,
    opset=11,
    output=Path("sentence-encoder.onnx"),
    use_external_format=False,
)

In [None]:
# Quantize the exported model. The test cells below load
# "sentence-encoder-quantized.onnx", presumably the file written by this call.
onnx_convert.quantize(Path("sentence-encoder.onnx"))

Now we can test the quantized model using ONNX Runtime.


In [None]:
import onnxruntime as rt
import numpy

In [None]:
# Load the quantized model file into an ONNX Runtime inference session.
session = rt.InferenceSession("sentence-encoder-quantized.onnx")

In [None]:
# Show the declared shape of the session's first output.
session.get_outputs()[0].shape

In [None]:
# Show the name of the first output; the inference cell requests an output
# called 'output_0'.
session.get_outputs()[0].name

In [None]:
# Tokenization is done here in Python, outside of the ONNX model;
# return_tensors="np" asks the tokenizer for NumPy arrays.
inputs = tokenizer("this is a test", return_tensors="np")

In [None]:
# Feed only the two tensors used here, keyed by the names given to the session.
onnx_input_dict = {
    name: inputs[name] for name in ("input_ids", "attention_mask")
}
# Request the single output 'output_0'; [0] selects that output and the second
# [0] selects the embedding of the one sentence that was tokenized.
embedding = session.run(["output_0"], onnx_input_dict)[0][0]