## Convert models to ONNX
1. query_retrieve (query encoder)
2. context_retrieve (context encoder)
3. rerank (reranker, exported as text-classification)

Requirements:
1. optimum
2. onnx

In [11]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

RERANK_MODEL_DIR = '../mbert-rerank-base'

# Inspect the reranker architecture before exporting it to ONNX.
# NOTE(review): if loading warns that classifier.weight / classifier.bias are "newly initialized"
# (as the recorded output of this cell does), the checkpoint carries no trained classification
# head and the scores of this model (and of its ONNX export) are untrained — TODO confirm the checkpoint.
model = AutoModelForSequenceClassification.from_pretrained(RERANK_MODEL_DIR)
tokenizer = AutoTokenizer.from_pretrained(RERANK_MODEL_DIR)
print(model)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../mbert-rerank-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12

In [8]:
# Export the reranker (sequence-classification head) to ONNX, opset 17, O1 graph optimizations.
# NOTE(review): the export log reports classifier.weight/bias as newly initialized, i.e. the
# exported head is untrained — TODO confirm the source checkpoint.
!optimum-cli export onnx \
    --library transformers \
    --task text-classification \
    --model ../mbert-rerank-base \
    --opset 17 \
    --optimize O1 \
    ../outputs/onnx/mbert-rerank-onnx

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../mbert-rerank-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
# Export the query encoder as a feature-extraction model (opset 17, O1 graph optimizations).
!optimum-cli export onnx \
    --library transformers \
    --task feature-extraction \
    --model ../mbert-retrieve-qry-base \
    --opset 17 \
    --optimize O1 \
    ../outputs/onnx/mbert-retrieve-qry-onnx



In [3]:
# Export the context (passage) encoder as a feature-extraction model (opset 17, O1 graph optimizations).
!optimum-cli export onnx \
    --library transformers \
    --task feature-extraction \
    --model ../mbert-retrieve-ctx-base \
    --opset 17 \
    --optimize O1 \
    ../outputs/onnx/mbert-retrieve-ctx-onnx



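NOTE:
* `--optimize` accepts `O1`, `O2`, `O3` or `O4` (the letter O, not zero). Higher levels apply more aggressive graph optimizations and may reduce accuracy; `O4` additionally uses fp16 mixed precision and requires a GPU.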
 

## Dynamic quantization of the ONNX models
Requirements:
* onnxruntime

In [5]:
import os
from onnxruntime.quantization import quantize_dynamic, QuantType

# Dynamic quantization (UInt8 weights) of the query encoder; the result is saved with its
# tensors in an external-data file next to the model.
model_fp32 = '../outputs/onnx/mbert-retrieve-qry-onnx/model.onnx'
model_quant = '../outputs/onnx/mbert-retrieve-qry-onnx/model.quant.onnx'

# Clear artefacts of a previous run. The .onnx file and its '.data' file are checked
# independently: either may be missing (interrupted run, or an earlier run without external
# data), and removing '.data' unconditionally would then raise FileNotFoundError.
for stale_path in (model_quant, model_quant + '.data'):
    if os.path.exists(stale_path):
        os.remove(stale_path)

quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8, use_external_data_format=True)



In [6]:
import os
from onnxruntime.quantization import quantize_dynamic, QuantType

# Dynamic quantization (UInt8 weights) of the context encoder; the result is saved with its
# tensors in an external-data file next to the model.
model_fp32 = '../outputs/onnx/mbert-retrieve-ctx-onnx/model.onnx'
model_quant = '../outputs/onnx/mbert-retrieve-ctx-onnx/model.quant.onnx'

# Clear artefacts of a previous run. The .onnx file and its '.data' file are checked
# independently: either may be missing (interrupted run, or an earlier run without external
# data), and removing '.data' unconditionally would then raise FileNotFoundError.
for stale_path in (model_quant, model_quant + '.data'):
    if os.path.exists(stale_path):
        os.remove(stale_path)

quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8, use_external_data_format=True)



In [2]:
# NOTE: rerank
import os
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

# Dynamic quantization (UInt8 weights) of the reranker.
# This uses the "_reference" export. The model produced by this notebook's optimum-cli rerank
# export lives in '../outputs/onnx/mbert-rerank-onnx/'; point both paths there to quantize it instead.
model_fp32 = '../outputs/onnx/mbert-rerank-onnx_reference/model.onnx'
model_quant = '../outputs/onnx/mbert-rerank-onnx_reference/model.quant.onnx'

# Clear artefacts of a previous run. The .onnx file and its '.data' file are checked
# independently: either may be missing (interrupted run, or an earlier run without external
# data), and removing '.data' unconditionally would then raise FileNotFoundError.
for stale_path in (model_quant, model_quant + '.data'):
    if os.path.exists(stale_path):
        os.remove(stale_path)

# Element type the quantizer falls back to for tensors whose type it cannot infer.
extra_options = {'DefaultTensorType': onnx.TensorProto.FLOAT}
quantize_dynamic(
    model_fp32,
    model_quant,
    weight_type=QuantType.QUInt8,
    use_external_data_format=True,
    extra_options=extra_options,
)



## Evaluating

### Utils

In [1]:
from torch import nn
import torch

# Shared loss module, built once at import time; averages the loss over the batch.
cross_entropy = nn.CrossEntropyLoss(reduction='mean')

def compute_loss(scores, target):
    """Mean cross-entropy between raw ``scores`` (logits) and integer class indices ``target``."""
    loss = cross_entropy(scores, target)
    return loss

def compute_similarity(q_reps, p_reps):
    """Dot-product similarity matrix ``q_reps @ p_reps^T``.

    Inputs that are not already ``torch.Tensor`` (e.g. numpy arrays returned by
    ONNX Runtime) are converted with ``torch.tensor`` first.
    """
    queries = q_reps if isinstance(q_reps, torch.Tensor) else torch.tensor(q_reps)
    passages = p_reps if isinstance(p_reps, torch.Tensor) else torch.tensor(p_reps)
    return queries @ passages.transpose(0, 1)

In [2]:
import torch
import time

# CLS pooling - take the hidden state of the first token
def cls_pooling(model_output):
    """Return the first-token (CLS) hidden state of a Hugging Face model output, detached on CPU.

    ``model_output.last_hidden_state`` is indexed as ``[:, 0]``; presumably it has shape
    (batch, seq_len, hidden) — TODO confirm for non-BERT models.
    """
    return model_output.last_hidden_state[:, 0].detach().cpu()


def onnx_predict(onnx_model, encoded_input: dict):
    """Run an ONNX Runtime session and return ``(cls_embeddings, run_seconds)``.

    Args:
        onnx_model: object exposing ``run(output_names, input_feed)`` (an ORT InferenceSession).
        encoded_input: feed dict mapping model input names to numpy arrays.

    Returns:
        ``(embeddings, elapsed)``: ``embeddings`` is ``outputs[0][:, 0]`` (first token of the
        first model output) and ``elapsed`` is the wall-clock duration of ``run`` in seconds.
    """
    # perf_counter is monotonic and high-resolution, so it suits interval timing better than time.time().
    start = time.perf_counter()
    model_output = onnx_model.run(None, encoded_input)
    elapsed = time.perf_counter() - start

    # Assumes the first output is the last hidden state (optimum feature-extraction export) — TODO confirm.
    embeddings = model_output[0][:, 0]
    return embeddings, elapsed

In [3]:
from tqdm import tqdm
import torch
import time
from typing import Callable
import inspect

def eval_accuracy(
    data,
    encode_fn: Callable = None,
    num_passages=65,
    model_ctx=None,
    model_qry=None,
    tokenizer_ctx=None,
    tokenizer_query=None,
    device='cpu'
):
    """Top-1 retrieval accuracy of a query/context bi-encoder, plus mean timings.

    For every row, the query is encoded with ``model_qry`` and the positive passage plus the
    first ``num_passages - 1`` negatives are encoded with ``model_ctx``. A row counts as correct
    when the positive (column 0) gets the highest dot-product score. ``argmax`` returns the first
    maximal index, so a tie with the positive counts as correct.

    Args:
        data: dataset with ``query``, ``positive`` and ``negatives`` columns. It must support
            ``len``, integer indexing returning a dict-like row, and ``column_names``
            (e.g. a Hugging Face ``datasets.Dataset``). Must be non-empty.
        encode_fn: required callable ``(texts, model, tokenizer, device) -> (embeddings, run_seconds)``.
        num_passages: passages scored per query (1 positive + ``num_passages - 1`` negatives), >= 1.
        model_ctx, model_qry: context / query encoders, passed through to ``encode_fn``.
        tokenizer_ctx, tokenizer_query: matching tokenizers, passed through to ``encode_fn``.
        device: forwarded to ``encode_fn``. If it is not ``'cpu'``, models that have a ``.to``
            method (e.g. torch modules) are moved there first. Objects without ``.to``, such as
            ONNX Runtime sessions, are left untouched.

    Returns:
        ``(accuracy, query_run_s, passage_run_s, query_total_s, passage_total_s)``, all averaged
        per row. "run" is the time reported by ``encode_fn``; "total" is the wall time of the
        whole ``encode_fn`` call (for ``encode_onnx`` this includes tokenization).
    """
    assert encode_fn is not None and callable(encode_fn), "encode_fn is required"
    assert model_ctx is not None, "model_ctx is required"
    assert model_qry is not None, "model_qry is required"
    assert tokenizer_ctx is not None, "tokenizer_ctx is required"
    assert tokenizer_query is not None, "tokenizer_query is required"
    assert 'query' in data.column_names, "data must have query column"
    assert 'positive' in data.column_names, "data must have positive column"
    assert 'negatives' in data.column_names, "data must have negatives column"
    assert len(data) > 0, "data must not be empty"
    assert num_passages >= 1, "num_passages must be >= 1"
    # inspect.getargspec was removed in Python 3.11; instead check that encode_fn can be
    # called as encode_fn(texts, model, tokenizer, device).
    try:
        inspect.signature(encode_fn).bind(None, None, None, None)
    except TypeError:
        raise AssertionError(
            "encode_fn must accept 4 positional arguments (texts, model, tokenizer, device)"
        ) from None

    if device != "cpu":
        if hasattr(model_ctx, "to"):
            model_ctx = model_ctx.to(device)
        if hasattr(model_qry, "to"):
            model_qry = model_qry.to(device)

    n_rows = len(data)
    correct = 0
    time_query_total = 0.0
    time_query_run = 0.0
    time_passage_total = 0.0
    time_passage_run = 0.0

    for i in tqdm(range(n_rows)):
        row = data[i]  # fetch the row once instead of re-indexing data per column

        start = time.perf_counter()
        query, time_query = encode_fn([row['query']], model_qry, tokenizer_query, device)
        time_query_total += time.perf_counter() - start
        time_query_run += time_query

        # Candidate list: the positive first, then the first num_passages - 1 negatives.
        passages = [row['positive']] + row['negatives'][:num_passages - 1]
        start = time.perf_counter()
        encoded_passages, time_ctx = encode_fn(passages, model_ctx, tokenizer_ctx, device)
        time_passage_total += time.perf_counter() - start
        time_passage_run += time_ctx

        # One query row, so scores is expected to be (1, len(passages)); the positive is column 0.
        scores = compute_similarity(query, encoded_passages)
        if scores.argmax(dim=1).item() == 0:
            correct += 1

    return (
        correct / n_rows,
        time_query_run / n_rows,
        time_passage_run / n_rows,
        time_query_total / n_rows,
        time_passage_total / n_rows,
    )

In [4]:
import time
import numpy as np
# Encode text
def encode_onnx(texts, model, tokenizer, device='cpu'):
    """Encode ``texts`` with an ONNX Runtime session; return ``(cls_embeddings, run_seconds)``.

    ``device`` is not used here. It is accepted so the signature matches the
    ``encode_fn(texts, model, tokenizer, device)`` contract of ``eval_accuracy``.
    """
    # Tokenize to PyTorch tensors, then feed ORT plain numpy arrays keyed by input name.
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    feed = {name: values.numpy() for name, values in batch.items()}

    embeddings, run_seconds = onnx_predict(model, feed)
    return embeddings, run_seconds

### Prepare datasets

In [5]:
import warnings

import datasets
from datasets import concatenate_datasets

# Evaluation data: the last EVAL_ROWS rows of each dataset's train split.
EVAL_ROWS = 500
# FIXME: identical to VI_EVAL_DATASET, so the "en" half just repeats the Vietnamese rows.
# Point this at the English passage dataset.
EN_EVAL_DATASET = 'tiennv/mmarco-passage-vi'
VI_EVAL_DATASET = 'tiennv/mmarco-passage-vi'
eval_split = f'train[-{EVAL_ROWS}:]'

en_eval = datasets.load_dataset(EN_EVAL_DATASET, split=eval_split)
vi_eval = datasets.load_dataset(VI_EVAL_DATASET, split=eval_split)
if EN_EVAL_DATASET == VI_EVAL_DATASET:
    # Concatenating a dataset with itself duplicates every row: no English data is evaluated
    # and the evaluation takes twice as long.
    warnings.warn(
        f"en_eval and vi_eval both load {VI_EVAL_DATASET!r}; dataset_eval contains every row twice."
    )

dataset_eval = concatenate_datasets([en_eval, vi_eval])
dataset_eval

  from .autonotebook import tqdm as notebook_tqdm
Using the latest cached version of the dataset since tiennv/mmarco-passage-vi couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/tiennv/.cache/huggingface/datasets/tiennv___mmarco-passage-vi/default/0.0.0/5ee2171bc2bc0880d2f35c16063096ec1c4dc4da (last modified on Tue Jan 14 15:38:44 2025).
Using the latest cached version of the dataset since tiennv/mmarco-passage-vi couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/tiennv/.cache/huggingface/datasets/tiennv___mmarco-passage-vi/default/0.0.0/5ee2171bc2bc0880d2f35c16063096ec1c4dc4da (last modified on Tue Jan 14 15:38:44 2025).


Dataset({
    features: ['query_id', 'query', 'positive_id', 'positive', 'negatives'],
    num_rows: 1000
})

### Run


In [6]:
import os

# FIXME(review): CUDA_DEVICE is not the variable CUDA uses for device selection
# (that is CUDA_VISIBLE_DEVICES), so this assignment most likely has no effect. The GPU actually
# used below is chosen by "device_id" in the CUDAExecutionProvider options. CUDA_VISIBLE_DEVICES
# must be set before CUDA is initialised, and it renumbers the visible GPUs from 0, so switching
# to it also requires changing that device_id.
os.environ.update({'CUDA_DEVICE': '1'})

In [None]:
import onnxruntime
from transformers import AutoTokenizer
import torch

tokenizer_query = AutoTokenizer.from_pretrained("../mbert-retrieve-qry-base")
tokenizer_ctx = AutoTokenizer.from_pretrained("../mbert-retrieve-ctx-base")

# (query model, context model) ONNX files for each variant evaluated in this notebook.
ONNX_MODEL_VARIANTS = {
    # fp32 export from optimum-cli
    "raw": ("../outputs/onnx/mbert-retrieve-qry-onnx/model.onnx",
            "../outputs/onnx/mbert-retrieve-ctx-onnx/model.onnx"),
    # onnxruntime quantize_dynamic output
    "quant_dynamic": ("../outputs/onnx/mbert-retrieve-qry-onnx/model.quant.onnx",
                      "../outputs/onnx/mbert-retrieve-ctx-onnx/model.quant.onnx"),
    # checkpoint from pytorch_quantization_calib.ipynb
    "calib": ("../outputs/onnx/mbert-retrieve-qry-onnx/qry_quant_percential_calib.onnx",
              "../outputs/onnx/mbert-retrieve-ctx-onnx/ctx_quant_percential_calib.onnx"),
}
MODEL_VARIANT = "calib"
query_path, ctx_path = ONNX_MODEL_VARIANTS[MODEL_VARIANT]

GPU_ID = 1  # CUDA device ordinal used by ONNX Runtime

providers = [("CUDAExecutionProvider", {
    "device_id": GPU_ID,
    # Share PyTorch's current stream on the *same* GPU ORT runs on. Without a device argument,
    # torch returns the stream of its current device (GPU 0 unless changed), which is not the
    # device ORT uses and would also create a CUDA context on GPU 0.
    "user_compute_stream": str(torch.cuda.current_stream(GPU_ID).cuda_stream),
    "cudnn_conv_algo_search": "DEFAULT",
})]
# CPU-only alternative:
# providers = ["CPUExecutionProvider"]

sess_options = onnxruntime.SessionOptions()
query_session = onnxruntime.InferenceSession(query_path, sess_options, providers=providers)
ctx_session = onnxruntime.InferenceSession(ctx_path, sess_options, providers=providers)
query_session, ctx_session

[0;93m2025-02-07 13:44:28.232912653 [W:onnxruntime:, transformer_memcpy.cc:74 ApplyImpl] 12 Memcpy nodes are added to the graph main_graph for CUDAExecutionProvider. It might have negative impact on performance (including unable to run CUDA graph). Set session_options.log_severity_level=1 to see the detail logs before this message.[m
[0;93m2025-02-07 13:44:28.239656076 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.[m
[0;93m2025-02-07 13:44:28.239670515 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.[m
[0;93m2025-02-07 13:44:31.569509976 [W:onnxruntime:, transformer_memcpy.cc:74 ApplyImpl] 12 Memcpy nodes are added to the graph main_graph for CUDAExecutionProvi

(<onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x7f24d6007790>,
 <onnxruntime.capi.onnxruntime_inference_collection.InferenceSession at 0x7f26180bac80>)

In [8]:
# Smoke test: encode a single query and check the CLS embedding shape, expected (1, hidden_size).
query_embedding, _ = encode_onnx([dataset_eval[0]['query']], query_session, tokenizer_query)
query_embedding.shape

(1, 768)

In [9]:
# Evaluate the ONNX encoders: top-1 accuracy over 10 candidate passages per query, plus mean
# per-query timings. Keep device='cpu': the ORT sessions get their GPU from their execution
# providers, not from a .to(device) call.
accuracy, time_query_run, time_passage_run, time_query_total, time_passage_total = eval_accuracy(
    dataset_eval,
    encode_onnx,
    num_passages=10,
    model_ctx=ctx_session,
    model_qry=query_session,
    tokenizer_ctx=tokenizer_ctx,
    tokenizer_query=tokenizer_query,
    device='cpu',
)
for label, value in (
    ("Accuracy", accuracy),
    ("Time Query Run", time_query_run),
    ("Time Passage Run", time_passage_run),
    ("Time Query Total", time_query_total),
    ("Time Passage Total", time_passage_total),
):
    print(f"{label}: {value}")

  assert len(inspect.getargspec(encode_fn).args) == 4, "encode_fn must have 4 arguments"
100%|██████████| 1000/1000 [00:41<00:00, 24.32it/s]

Accuracy: 0.81
Time Query Run: 0.003777590751647949
Time Passage Run: 0.03482776665687561
Time Query Total: 0.004183323621749878
Time Passage Total: 0.036229599475860595





## Appendix

### Convert to TensorRT (the engines are evaluated in `pytorch_quantization_calib.ipynb`)

Run the commands below in a shell inside the NGC `tensorrt` container (tested with release 24.08, TensorRT 10.3.0), not in the Python kernel:
```bash
docker run --gpus "device=1" -it --rm -v /home/tiennv/hungnq/rtvserving/outputs/onnx:/onnx nvcr.io/nvidia/tensorrt:24.08-py3 bash
```

Query and context ONNX models → TensorRT engines (FP32, dynamic shapes)

In [None]:
# Query encoder -> FP32 engine; batch fixed at 1, sequence length 1..512 (tuned for 256).
trtexec \
  --onnx=/onnx/mbert-retrieve-qry-onnx/model.onnx \
  --saveEngine=/onnx/mbert-retrieve-qry-onnx/model_fp32_dynamic_shape.plan \
  --builderOptimizationLevel=4 \
  --minShapes=input_ids:1x1,attention_mask:1x1,token_type_ids:1x1 \
  --optShapes=input_ids:1x256,attention_mask:1x256,token_type_ids:1x256 \
  --maxShapes=input_ids:1x512,attention_mask:1x512,token_type_ids:1x512


In [None]:
# Context encoder -> FP32 engine; batch 1..10 and sequence length 1..512 (tuned for 5x256).
trtexec \
  --onnx=/onnx/mbert-retrieve-ctx-onnx/model.onnx \
  --saveEngine=/onnx/mbert-retrieve-ctx-onnx/model_fp32_dynamic_shape.plan \
  --builderOptimizationLevel=4 \
  --minShapes=input_ids:1x1,attention_mask:1x1,token_type_ids:1x1 \
  --optShapes=input_ids:5x256,attention_mask:5x256,token_type_ids:5x256 \
  --maxShapes=input_ids:10x512,attention_mask:10x512,token_type_ids:10x512

Query and context ONNX models → TensorRT engines (FP32 + INT8, dynamic shapes)

In [None]:
# Query encoder -> FP32+INT8 engine; batch fixed at 1, sequence length 1..512 (tuned for 256).
# NOTE(review): no --calib cache is passed and this ONNX is presumably the plain FP32 export,
# so TensorRT has no real INT8 ranges; engines built this way are generally only meaningful
# for latency measurements, not accuracy — TODO confirm against the trtexec docs.
trtexec \
  --onnx=/onnx/mbert-retrieve-qry-onnx/model.onnx \
  --saveEngine=/onnx/mbert-retrieve-qry-onnx/model_fp32_int8_dynamic_shape.plan \
  --int8 \
  --builderOptimizationLevel=4 \
  --minShapes=input_ids:1x1,attention_mask:1x1,token_type_ids:1x1 \
  --optShapes=input_ids:1x256,attention_mask:1x256,token_type_ids:1x256 \
  --maxShapes=input_ids:1x512,attention_mask:1x512,token_type_ids:1x512

In [None]:
# Context encoder -> FP32+INT8 engine; batch 1..10 and sequence length 1..512 (tuned for 5x256).
# NOTE(review): no --calib cache is passed and this ONNX is presumably the plain FP32 export,
# so TensorRT has no real INT8 ranges; engines built this way are generally only meaningful
# for latency measurements, not accuracy — TODO confirm against the trtexec docs.
trtexec \
  --onnx=/onnx/mbert-retrieve-ctx-onnx/model.onnx \
  --saveEngine=/onnx/mbert-retrieve-ctx-onnx/model_fp32_int8_dynamic_shape.plan \
  --int8 \
  --builderOptimizationLevel=4 \
  --minShapes=input_ids:1x1,attention_mask:1x1,token_type_ids:1x1 \
  --optShapes=input_ids:5x256,attention_mask:5x256,token_type_ids:5x256 \
  --maxShapes=input_ids:10x512,attention_mask:10x512,token_type_ids:10x512

Query and context dynamically quantized (`quantize_dynamic`) ONNX models → TensorRT engines (FP32 + INT8, dynamic shapes)


[STATUS] - FAIL: conversion fails because the dynamically quantized graphs contain operations that TensorRT does not support.

In [None]:
# Query encoder (quantize_dynamic output) -> FP32+INT8 engine. Known to fail: presumably
# TensorRT rejects the integer ops inserted by ORT's dynamic quantization — unconfirmed.
trtexec \
  --onnx=/onnx/mbert-retrieve-qry-onnx/model.quant.onnx \
  --saveEngine=/onnx/mbert-retrieve-qry-onnx/model_quant_fp32_int8_dynamic_shape.engine \
  --int8 \
  --builderOptimizationLevel=4 \
  --minShapes=input_ids:1x1,attention_mask:1x1,token_type_ids:1x1 \
  --optShapes=input_ids:1x256,attention_mask:1x256,token_type_ids:1x256 \
  --maxShapes=input_ids:1x512,attention_mask:1x512,token_type_ids:1x512

In [None]:
# Context encoder (quantize_dynamic output) -> FP32+INT8 engine. Known to fail: presumably
# TensorRT rejects the integer ops inserted by ORT's dynamic quantization — unconfirmed.
trtexec \
  --onnx=/onnx/mbert-retrieve-ctx-onnx/model.quant.onnx \
  --saveEngine=/onnx/mbert-retrieve-ctx-onnx/model_quant_fp32_int8_dynamic_shape.engine \
  --int8 \
  --builderOptimizationLevel=4 \
  --minShapes=input_ids:1x1,attention_mask:1x1,token_type_ids:1x1 \
  --optShapes=input_ids:5x256,attention_mask:5x256,token_type_ids:5x256 \
  --maxShapes=input_ids:10x512,attention_mask:10x512,token_type_ids:10x512
