Export PyTorch model to ONNX format for serving with ONNX Runtime Web 

In [9]:
text1 = "How is Rupee values against Dollar right now?"
text2 = "What is the per month overall cost per subs for FY 2023"

In [10]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./saved_model/query_intent_model/best_model")
inputs = tokenizer(text1, return_tensors="pt")

# Pass your inputs to the model and return the `logits`
from transformers import AutoModelForSequenceClassification
import torch

model = AutoModelForSequenceClassification.from_pretrained("./saved_model/query_intent_model/best_model")
with torch.no_grad():
    logits = model(**inputs).logits

# Get the class with the highest probability, and use the model's `id2label` mapping to convert it to a text label
import torch.nn.functional as F

probabilities = F.softmax(logits, dim=1)
predicted_class_id = logits.argmax().item()
model.config.id2label[predicted_class_id], probabilities.max().item()

('irrelevant', 0.8557552695274353)

In [11]:
import transformers
import transformers.convert_graph_to_onnx as onnx_convert
from pathlib import Path

In [12]:
pipeline = transformers.pipeline("text-classification",model=model,tokenizer=tokenizer)

In [13]:
model = model.to("cpu")

In [15]:
onnx_convert.convert_pytorch(pipeline, opset=11, output=Path("./public/saved_onnx/classifier.onnx"), use_external_format=False)

Using framework PyTorch: 2.4.0
Found input input_ids with shape: {0: 'batch', 1: 'sequence'}
Found input token_type_ids with shape: {0: 'batch', 1: 'sequence'}
Found input attention_mask with shape: {0: 'batch', 1: 'sequence'}
Found output output_0 with shape: {0: 'batch'}
Ensuring inputs are in correct order
position_ids is not present in the generated input list.
Generated inputs order: ['input_ids', 'attention_mask', 'token_type_ids']


In [17]:
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic("./public/saved_onnx/classifier.onnx", "./public/saved_onnx/classifier_int8.onnx", 
                 weight_type=QuantType.QUInt8)



Evaluate accuracy using ONNX-Runtime inference - validate PyTorch inference versus ONNX-Runtime 

In [18]:
import onnxruntime as ort

In [20]:
session = ort.InferenceSession("./public/saved_onnx/classifier.onnx")
session_int8 = ort.InferenceSession("./public/saved_onnx/classifier_int8.onnx")

In [21]:
import numpy as np

In [23]:
from datasets import load_dataset, DatasetDict
intent_data = DatasetDict.load_from_disk('./data/intent_data')

def map_labels(example):
    if example["label"] == "irrelevant":
        example["label"] = 0
    elif example["label"] == "relevant":
        example["label"] = 1
    return example

intent_data = intent_data.map(map_labels)

def preprocess_function(examples):
    return tokenizer(examples["text"], padding=True, max_length=512, truncation=True)

tokenized_intent_data = intent_data.map(preprocess_function, batched=True)

full_train_dataset = tokenized_intent_data["train"]
full_eval_dataset = tokenized_intent_data["test"]
# reduced_eval_dataset = full_eval_dataset.select(range(500))

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

In [24]:
input_feed = {
    "input_ids": np.array(full_eval_dataset['input_ids']),
    "attention_mask": np.array(full_eval_dataset['attention_mask']),
    "token_type_ids": np.array(full_eval_dataset['token_type_ids'])
}

In [25]:
out = session.run(input_feed=input_feed,output_names=['output_0'])[0]
out_int8 = session_int8.run(input_feed=input_feed,output_names=['output_0'])[0]

In [26]:
predictions = np.argmax(out, axis=-1)
predictions_int8 = np.argmax(out_int8, axis=-1)

In [30]:
from datasets import load_metric
metric = load_metric("accuracy")

  metric = load_metric("accuracy")


Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

In [31]:
metric.compute(predictions=predictions, references=full_eval_dataset['label'])

{'accuracy': 1.0}

In [32]:
metric.compute(predictions=predictions_int8, references=full_eval_dataset['label'])

{'accuracy': 1.0}