## convert

In [1]:
import os 

os.environ['CUDA_DEVICE'] = '0'

In [2]:
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor

In [3]:
quant_desc_input = QuantDescriptor(calib_method='histogram')
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

In [4]:
from pytorch_quantization import quant_modules
quant_modules.initialize()

In [5]:
from transformers import BertModel
import torch

if torch.cuda.is_available():
    device = "cuda:0"
else:
    raise ValueError("No GPU available")
# device = "cpu"
query_model = BertModel.from_pretrained('../mbert-retrieve-qry-base/', add_pooling_layer=False)
query_model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(119547, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): QuantLinear(
              in_features=768, out_features=768, bias=True
              (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
              (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
            )
            (key): QuantLinear(
              in_features=768, out_features=768, bias=True
              (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax

In [6]:
from transformers import BertModel
import torch

# device = "cuda:1" if torch.cuda.is_available() else "cpu"
ctx_model = BertModel.from_pretrained('../mbert-retrieve-ctx-base/', add_pooling_layer=False)
ctx_model.to(device)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(119547, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): QuantLinear(
              in_features=768, out_features=768, bias=True
              (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
              (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
            )
            (key): QuantLinear(
              in_features=768, out_features=768, bias=True
              (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax

## Calibrate

### utils

In [7]:
from transformers import DataCollatorWithPadding, AutoTokenizer

query_tokenizer = AutoTokenizer.from_pretrained('../mbert-retrieve-qry-base/')
ctx_tokenizer = AutoTokenizer.from_pretrained('../mbert-retrieve-ctx-base/')

def query_collate_fn(examples):
    query = [example['query'] for example in examples]
    encoded_input = query_tokenizer(
        query, 
        padding='max_length', 
        truncation=True, 
        max_length=512, 
        return_tensors='pt'
    )
    return encoded_input


def ctx_collate_fn(examples):

    concate_passage = []
    for example in examples:
        concate_passage.extend(
            [example['positive']] + example['negatives'][:9]
        )

    # concate_passage = [examples['positive']] + examples['negatives'][:9]
    encoded_input = ctx_tokenizer(
        concate_passage, 
        padding='max_length', 
        truncation=True, 
        max_length=512, 
        return_tensors='pt'
    )
    return encoded_input


In [8]:
from tqdm import tqdm

def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect statistic"""

    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, (encode_input) in tqdm(enumerate(data_loader), total=num_batches):
        for k, v in encode_input.items():
            encode_input[k] = v.to(device)
            # print(k, v.shape)
        model(**encode_input)
        if i >= num_batches:
            break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()
            
def compute_amax(model, **kwargs):
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
#             print(F"{name:40}: {module}")
    model.to(device)


### dataset & dataloaders

In [9]:
import datasets
from datasets import concatenate_datasets 

number_samples = 250 
en = datasets.load_dataset('tiennv/mmarco-passage-vi', split=f'train[:{number_samples}]')
vi = datasets.load_dataset('tiennv/mmarco-passage-vi', split=f'train[:{number_samples}]')

dataset_calib = concatenate_datasets([en, vi])
dataset_calib

Using the latest cached version of the dataset since tiennv/mmarco-passage-vi couldn't be found on the Hugging Face Hub
W0202 08:26:12.571879 139961657057664 load.py:1444] Using the latest cached version of the dataset since tiennv/mmarco-passage-vi couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/tiennv/.cache/huggingface/datasets/tiennv___mmarco-passage-vi/default/0.0.0/5ee2171bc2bc0880d2f35c16063096ec1c4dc4da (last modified on Tue Jan 14 15:38:44 2025).
W0202 08:26:12.579268 139961657057664 cache.py:94] Found the latest cached dataset configuration 'default' at /home/tiennv/.cache/huggingface/datasets/tiennv___mmarco-passage-vi/default/0.0.0/5ee2171bc2bc0880d2f35c16063096ec1c4dc4da (last modified on Tue Jan 14 15:38:44 2025).
Using the latest cached version of the dataset since tiennv/mmarco-passage-vi couldn't be found on the Hugging Face Hub
W0202 08:26:13.274114 139961657057664 load.py:1444] Using the latest cached version

Dataset({
    features: ['query_id', 'query', 'positive_id', 'positive', 'negatives'],
    num_rows: 500
})

In [10]:
import torch

batch_size = 1
num_workers = 4

calib_query_loader = torch.utils.data.DataLoader(
    dataset_calib, 
    batch_size=batch_size,
    collate_fn=query_collate_fn,
    num_workers=num_workers, 
    pin_memory=True
)

calib_ctx_loader = torch.utils.data.DataLoader(
    dataset_calib, 
    batch_size=batch_size,
    collate_fn=ctx_collate_fn,
    num_workers=num_workers, 
    pin_memory=True
)

In [11]:
test = next(iter(calib_ctx_loader))
for k, v in test.items():
    print(k, v.shape)
    break

input_ids torch.Size([10, 512])


### run

In [12]:
# It is a bit slow since we collect histograms on CPU

calib_batches = number_samples // batch_size

with torch.no_grad():
    collect_stats(query_model, calib_query_loader, num_batches=calib_batches)
    compute_amax(query_model, method="percentile", percentile=99.99)
    # compute_amax(query_model, method="mse")
    # compute_amax(query_model, method="entropy")


100%|██████████| 250/250 [00:10<00:00, 24.32it/s]
W0202 08:26:24.218169 139961657057664 tensor_quantizer.py:174] Disable HistogramCalibrator
W0202 08:26:24.218636 139961657057664 tensor_quantizer.py:174] Disable MaxCalibrator
W0202 08:26:24.218971 139961657057664 tensor_quantizer.py:174] Disable HistogramCalibrator
W0202 08:26:24.219309 139961657057664 tensor_quantizer.py:174] Disable MaxCalibrator
W0202 08:26:24.219666 139961657057664 tensor_quantizer.py:174] Disable HistogramCalibrator
W0202 08:26:24.219982 139961657057664 tensor_quantizer.py:174] Disable MaxCalibrator
W0202 08:26:24.220240 139961657057664 tensor_quantizer.py:174] Disable HistogramCalibrator
W0202 08:26:24.220518 139961657057664 tensor_quantizer.py:174] Disable MaxCalibrator
W0202 08:26:24.220786 139961657057664 tensor_quantizer.py:174] Disable HistogramCalibrator
W0202 08:26:24.221031 139961657057664 tensor_quantizer.py:174] Disable MaxCalibrator
W0202 08:26:24.221309 139961657057664 tensor_quantizer.py:174] Disable

In [13]:
# It is a bit slow since we collect histograms on CPU

calib_batches = number_samples // batch_size

with torch.no_grad():
    collect_stats(ctx_model, calib_ctx_loader, num_batches=calib_batches)
    compute_amax(ctx_model, method="percentile", percentile=99.99)
    # compute_amax(ctx_model, method="mse")
    # compute_amax(ctx_model, method="entropy")

100%|██████████| 250/250 [00:41<00:00,  6.09it/s]
W0202 08:27:15.113869 139961657057664 tensor_quantizer.py:174] Disable HistogramCalibrator
W0202 08:27:15.114361 139961657057664 tensor_quantizer.py:174] Disable MaxCalibrator
W0202 08:27:15.114764 139961657057664 tensor_quantizer.py:174] Disable HistogramCalibrator
W0202 08:27:15.115107 139961657057664 tensor_quantizer.py:174] Disable MaxCalibrator
W0202 08:27:15.115456 139961657057664 tensor_quantizer.py:174] Disable HistogramCalibrator
W0202 08:27:15.115826 139961657057664 tensor_quantizer.py:174] Disable MaxCalibrator
W0202 08:27:15.116194 139961657057664 tensor_quantizer.py:174] Disable HistogramCalibrator
W0202 08:27:15.116595 139961657057664 tensor_quantizer.py:174] Disable MaxCalibrator
W0202 08:27:15.116959 139961657057664 tensor_quantizer.py:174] Disable HistogramCalibrator
W0202 08:27:15.117311 139961657057664 tensor_quantizer.py:174] Disable MaxCalibrator
W0202 08:27:15.117656 139961657057664 tensor_quantizer.py:174] Disable

In [14]:
import os

qry_model_path = "../outputs/onnx/mbert-retrieve-qry-onnx/qry_model_calib_percentile.pth"
ctx_model_path = "../outputs/onnx/mbert-retrieve-ctx-onnx/ctx_model_calib_percentile.pth"
if os.path.exists(qry_model_path):
    os.remove(qry_model_path)
if os.path.exists(ctx_model_path):
    os.remove(ctx_model_path)

In [15]:
torch.save(query_model.state_dict(), qry_model_path)
torch.save(ctx_model.state_dict(), ctx_model_path)

### Convert onnx

In [20]:
qry_model_path = "../outputs/onnx/mbert-retrieve-qry-onnx/qry_model_calib_percentile.pth"
ctx_model_path = "../outputs/onnx/mbert-retrieve-ctx-onnx/ctx_model_calib_percentile.pth"

In [20]:
import torch
from transformers import BertModel
import pytorch_quantization.utils
# export onnx
# load the calibrated model
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
device = "cpu"
query_dummy_model = BertModel.from_pretrained('../mbert-retrieve-qry-base/', add_pooling_layer=False)
ctx_dummy_model = BertModel.from_pretrained('../mbert-retrieve-ctx-base/', add_pooling_layer=False)
query_dummy_model.load_state_dict(torch.load(qry_model_path, map_location="cpu"))
ctx_dummy_model.load_state_dict(torch.load(ctx_model_path, map_location="cpu"))
query_dummy_model.to(device)
ctx_dummy_model.to(device)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(119547, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): QuantLinear(
              in_features=768, out_features=768, bias=True
              (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=6.4014 calibrator=HistogramCalibrator scale=1.0 quant)
              (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=[0.1138, 0.5235](768) calibrator=MaxCalibrator scale=1.0 quant)
            )
            (key): QuantLinear(
              in_features=768, out_features=768, bias=True
              (_input_quantizer): TensorQuantizer(8bit fake per-tensor

In [19]:
# dummy_input = {
#     'input_ids': torch.randint(0, 100, (1, 512)).to(device),
#     'attention_mask': torch.randint(0, 1, (1, 512)).to(device),
#     'token_type_ids': torch.randint(0, 1, (1, 512)).to(device),
# }

# traced_model = torch.jit.trace(query_dummy_model.to(device), (
#     dummy_input['input_ids'].to(device), 
#     dummy_input['attention_mask'].to(device), 
#     dummy_input['token_type_ids'].to(device)
# ), strict=False)
# traced_model

In [21]:
dummy_input = {
    'input_ids': torch.randint(0, 100, (1, 512)).to(device),
    'attention_mask': torch.randint(0, 1, (1, 512)).to(device),
    'token_type_ids': torch.randint(0, 1, (1, 512)).to(device),
}

input_names = [ "input_ids", "attention_mask", "token_type_ids" ]
output_names = [ "last_hidden_state" ]

In [22]:
quant_nn.TensorQuantizer._enable_onnx_export = True
# enable_onnx_checker needs to be disabled. See notes below.
torch.onnx.export(
    ctx_dummy_model, 
    dummy_input, 
    "../outputs/onnx/mbert-retrieve-ctx-onnx/ctx_quant_percential_calib.onnx",
    verbose=True, 
    opset_version=17,
    input_names=input_names,
    output_names=output_names,
    do_constant_folding=True,
    dynamic_axes={
        'input_ids': {0: 'batch', 1: 'sequence'},
        'attention_mask': {0: 'batch', 1: 'sequence'},
        'token_type_ids': {0: 'batch', 1: 'sequence'},
        'last_hidden_state': {0: 'batch', 1: 'sequence'}
    },
)
quant_nn.TensorQuantizer._enable_onnx_export = False

  if min_amax < 0:
  max_bound = torch.tensor((2.0**(num_bits - 1 + int(unsigned))) - 1.0, device=amax.device)
  if min_amax <= epsilon:  # Treat amax smaller than minimum representable of fp16 0
  if min_amax <= epsilon:


In [23]:
quant_nn.TensorQuantizer._enable_onnx_export = True
torch.onnx.export(
    query_dummy_model, 
    dummy_input, 
    "../outputs/onnx/mbert-retrieve-qry-onnx/qry_quant_percential_calib.onnx",
    verbose=True, 
    opset_version=17,
    input_names=input_names,
    output_names=output_names,
    do_constant_folding=True,
    dynamic_axes={
        'input_ids': {0: 'batch', 1: 'sequence'},
        'attention_mask': {0: 'batch', 1: 'sequence'},
        'token_type_ids': {0: 'batch', 1: 'sequence'},
        'last_hidden_state': {0: 'batch', 1: 'sequence'}
    },
)
quant_nn.TensorQuantizer._enable_onnx_export = False

In [16]:
query_model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(119547, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): QuantLinear(
              in_features=768, out_features=768, bias=True
              (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=6.3772 calibrator=HistogramCalibrator scale=1.0 quant)
              (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=[0.1111, 0.5229](768) calibrator=MaxCalibrator scale=1.0 quant)
            )
            (key): QuantLinear(
              in_features=768, out_features=768, bias=True
              (_input_quantizer): TensorQuantizer(8bit fake per-tensor

### convert tensorrt
use `tritonserver` image (experiment in version 24.01 or 24.08)
```bash
docker run --gpus all -it --rm -v /home/tiennv/hungnq/rtvserving/outputs/onnx:/onnx nvcr.io/nvidia/tensorrt:24.08-py3
```

`quantize calibration` for query & ctx onnx -> trt with flags fp32 & int8 & dynamic_shape 

In [None]:
trtexec \
  --onnx=/onnx/mbert-retrieve-qry-onnx/qry_quant_percential_calib.onnx \
  --builderOptimizationLevel=4 \
  --saveEngine=/onnx/mbert-retrieve-qry-onnx/model_calib_percential_fp32_int8_dynamic_shape.plan \
  --shapes=input_ids:1x512,attention_mask:1x512,token_type_ids:1x512 \
  --minShapes=input_ids:1x128,attention_mask:1x128,token_type_ids:1x128 \
  --optShapes=input_ids:1x512,attention_mask:1x512,token_type_ids:1x512 \
  --maxShapes=input_ids:1x512,attention_mask:1x512,token_type_ids:1x512 \
  --int8


In [None]:

trtexec \
  --onnx=/onnx/mbert-retrieve-ctx-onnx/ctx_quant_percential_calib.onnx \
  --builderOptimizationLevel=4 \
  --saveEngine=/onnx/mbert-retrieve-ctx-onnx/model_calib_percential_fp32_int8_dynamic_shape.plan \
  --shapes=input_ids:10x512,attention_mask:10x512,token_type_ids:10x512 \
  --minShapes=input_ids:1x128,attention_mask:1x128,token_type_ids:1x128 \
  --optShapes=input_ids:10x512,attention_mask:10x512,token_type_ids:10x512 \
  --maxShapes=input_ids:10x512,attention_mask:10x512,token_type_ids:10x512 \
  --int8

`quantize calibration` for query & ctx onnx -> trt with flags fp32 & dynamic_shape 

In [None]:
trtexec \
  --onnx=/onnx/mbert-retrieve-qry-onnx/qry_quant_percential_calib.onnx \
  --builderOptimizationLevel=4 \
  --saveEngine=/onnx/mbert-retrieve-qry-onnx/model_calib_percential_fp32_dynamic_shape.plan \
  --shapes=input_ids:1x512,attention_mask:1x512,token_type_ids:1x512 \
  --minShapes=input_ids:1x128,attention_mask:1x128,token_type_ids:1x128 \
  --optShapes=input_ids:1x512,attention_mask:1x512,token_type_ids:1x512 \
  --maxShapes=input_ids:1x512,attention_mask:1x512,token_type_ids:1x512 


In [None]:
trtexec \
  --onnx=/onnx/mbert-retrieve-ctx-onnx/ctx_quant_percential_calib.onnx \
  --builderOptimizationLevel=4 \
  --saveEngine=/onnx/mbert-retrieve-ctx-onnx/model_calib_percential_fp32_dynamic_shape.plan \
  --shapes=input_ids:10x512,attention_mask:10x512,token_type_ids:10x512 \
  --minShapes=input_ids:1x128,attention_mask:1x128,token_type_ids:1x128 \
  --optShapes=input_ids:10x512,attention_mask:10x512,token_type_ids:10x512 \
  --maxShapes=input_ids:10x512,attention_mask:10x512,token_type_ids:10x512 


## Evaluate

### utils

In [1]:
from torch import nn
import torch

cross_entropy = nn.CrossEntropyLoss(reduction='mean')

def compute_loss(scores, target):
    return cross_entropy(scores, target)

def compute_similarity(q_reps, p_reps):
    if not isinstance(q_reps, torch.Tensor):
        q_reps = torch.tensor(q_reps)
    if not isinstance(p_reps, torch.Tensor):
        p_reps = torch.tensor(p_reps)
    return torch.matmul(q_reps, p_reps.transpose(0,1))

In [2]:
from tqdm import tqdm
import torch
import time
from typing import Callable
import inspect

def eval_accuracy_trt(
    data, 
    encode_fn = Callable, 
    num_passages=65, 
    model_qry=None, 
    model_ctx=None, 
    tokenizer_query=None,
    tokenizer_ctx=None, 
    device='cpu',
):

    assert model_ctx is not None, "model_ctx is required"
    assert model_qry is not None, "model_qry is required"
    assert tokenizer_ctx is not None, "tokenizer_ctx is required"
    assert tokenizer_query is not None, "tokenizer_query is required"
    assert 'query' in data.column_names, "data must have query column"
    assert 'positive' in data.column_names, "data must have positive column"
    assert 'negatives' in data.column_names, "data must have negatives column"
    # len of arguemtn of encode_fn must be 4
    # print(inspect.getargspec(encode_fn).args)
    assert len(inspect.getargspec(encode_fn).args) == 4, "encode_fn must have 4 arguments"

    accuracy = 0

    if device != "cpu":
        model_ctx = model_ctx.to(device)
        model_qry = model_qry.to(device)

    time_query_total = 0
    time_query_run = 0
    time_passage_total = 0
    time_passage_run = 0

    for i in tqdm(range(len(data))):

        start_time = time.time()
        #! CHANGE HERE
        query_batch = [data[i]['query']]
        query, time_query = encode_fn(query_batch, model_qry, tokenizer_query, len(query_batch))
        end_time = time.time() - start_time
        time_query_total += end_time
        time_query_run += time_query

        # concate 10 passages
        concate_passage = [data[i]['positive']] + data[i]['negatives'][:num_passages-1]
        start_time = time.time()
        #! CHANGE HERE
        encoded_passages, time_ctx = encode_fn(concate_passage, model_ctx, tokenizer_ctx, len(concate_passage))
        end_time = time.time() - start_time
        time_passage_total += end_time
        time_passage_run += time_ctx

        # accuracy
        scores = compute_similarity(query, encoded_passages)
        if scores.argmax(dim=1).detach().numpy() != 0:
            continue
        accuracy += 1

    return accuracy / len(data), time_query_run/ len(data), time_passage_run/ len(data), time_query_total/ len(data), time_passage_total/ len(data)

In [3]:
# becnhmark run onnx model
import tensorrt as trt
import numpy as np
import os

import pycuda.driver as cuda
import pycuda.autoinit


from transformers import AutoTokenizer

class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()

class TrtModel:
    
    def __init__(self,engine_path,max_batch_size=1,dtype=np.float32):
        
        self.engine_path = engine_path
        self.dtype = dtype
        self.logger = trt.Logger(trt.Logger.WARNING)
        self.runtime = trt.Runtime(self.logger)
        self.engine = self.load_engine(self.runtime, self.engine_path)
        self.max_batch_size = max_batch_size
        # self.inputs, self.outputs, self.bindings = self.allocate_buffers()
        self.stream = cuda.Stream()
        self.context = self.engine.create_execution_context()

                
    @staticmethod
    def load_engine(trt_runtime, engine_path):
        trt.init_libnvinfer_plugins(None, "")             
        with open(engine_path, 'rb') as f:
            engine_data = f.read()
        engine = trt_runtime.deserialize_cuda_engine(engine_data)
        return engine
    
    def allocate_buffers(self, binding_shape):
        # Allocate host and device buffers
        inputs, outputs, bindings = [], [], []
        for binding in self.engine:
            # 
            if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
                self.context.set_input_shape(binding, binding_shape)
                
            # print("binding: ", binding)
            size = trt.volume(self.context.get_tensor_shape(binding))
            # print("size: ", size)
            # print("batch_size: ", self.context.get_tensor_shape(binding))
            dtype = trt.nptype(self.engine.get_tensor_dtype(binding))
            # print("dtype: ", dtype)

            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            bindings.append(int(device_mem))

            # if self.engine.binding_is_input(binding):
            if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
                inputs.append(HostDeviceMem(host_mem, device_mem))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem))

        return inputs, outputs, bindings
       
            
    def __call__(self, inputs_id, attention_mask, token_type_ids, batch_size=2):

        
        x = np.array(inputs_id).astype(self.dtype)
        y = np.array(attention_mask).astype(self.dtype)
        z = np.array(token_type_ids).astype(self.dtype)


        inputs, outputs, bindings = self.allocate_buffers(x.shape)
    
        # Transfer input data to the GPU.
        # print(x.shape)
        np.copyto(inputs[0].host,x.ravel())
        np.copyto(inputs[1].host,y.ravel())
        np.copyto(inputs[2].host,z.ravel())
        
        # after copy -> transfer to device, transer first will error duo to hold old value
        for inp in inputs:
            cuda.memcpy_htod_async(inp.device, inp.host, self.stream)

        # Run inference
        self.context.execute_v2(bindings=bindings)
        
        # Transfer prediction output from the GPU.
        for out in outputs:
            cuda.memcpy_dtoh_async(out.host, out.device, self.stream)
        
        # Synchronize the stream
        self.stream.synchronize()
        return [out.host.reshape(batch_size,-1) for out in outputs]

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import time

def encode_trt(texts, model, tokenizer, batch_size):
    # check if tokenize length is min 128
    encoded_input = tokenizer(
        texts, 
        padding='max_length', 
        truncation=True,
        max_length=128,
        return_tensors='np'
    )

    # encoded_input = tokenizer(
    #     texts, 
    #     padding=True, 
    #     truncation=True,
    #     return_tensors='np'
    # )

    start_time = time.time()
    embeddings = model(
        encoded_input['input_ids'],
        encoded_input['attention_mask'],
        encoded_input['token_type_ids'],
        batch_size
    )[0]
    end_time = time.time() - start_time

    # print(embeddings.reshape(batch_size, -1, 768))
    return embeddings.reshape(batch_size, -1, 768)[:, 0], end_time

### dataset eval

In [5]:
import datasets
from datasets import concatenate_datasets
en_eval = datasets.load_dataset('tiennv/mmarco-passage-vi', split='train[-500:]')
vi_eval = datasets.load_dataset('tiennv/mmarco-passage-vi', split='train[-500:]')

dataset_eval = concatenate_datasets([en_eval, vi_eval])
dataset_eval

Using the latest cached version of the dataset since tiennv/mmarco-passage-vi couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/tiennv/.cache/huggingface/datasets/tiennv___mmarco-passage-vi/default/0.0.0/5ee2171bc2bc0880d2f35c16063096ec1c4dc4da (last modified on Tue Jan 14 15:38:44 2025).
Using the latest cached version of the dataset since tiennv/mmarco-passage-vi couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/tiennv/.cache/huggingface/datasets/tiennv___mmarco-passage-vi/default/0.0.0/5ee2171bc2bc0880d2f35c16063096ec1c4dc4da (last modified on Tue Jan 14 15:38:44 2025).


Dataset({
    features: ['query_id', 'query', 'positive_id', 'positive', 'negatives'],
    num_rows: 1000
})

### eval

In [6]:
import os
os.environ['CUDA_DEVICE'] = '0'

In [7]:
!python3 -c "import tensorrt; print(tensorrt.__version__)"

10.3.0


In [8]:
import numpy as np

## from `onnx_dynamic_quantization.ipynb`
# trt_engine_qry_path = "../outputs/onnx/mbert-retrieve-qry-onnx/model_fp32_dynamic_shape.plan"
# trt_engine_ctx_path = "../outputs/onnx/mbert-retrieve-ctx-onnx/model_fp32_dynamic_shape.plan"

# trt_engine_qry_path = "../outputs/onnx/mbert-retrieve-qry-onnx/model_fp32_int8_dynamic_shape.plan"
# trt_engine_ctx_path = "../outputs/onnx/mbert-retrieve-ctx-onnx/model_fp32_int8_dynamic_shape.plan"

# ## above convert
trt_engine_qry_path = "../outputs/onnx/mbert-retrieve-qry-onnx/model_calib_percential_fp32_dynamic_shape.plan"
trt_engine_ctx_path = "../outputs/onnx/mbert-retrieve-ctx-onnx/model_calib_percential_fp32_dynamic_shape.plan"

# trt_engine_qry_path = "../outputs/onnx/mbert-retrieve-qry-onnx/model_calib_percential_fp32_int8_dynamic_shape.plan"
# trt_engine_ctx_path = "../outputs/onnx/mbert-retrieve-ctx-onnx/model_calib_percential_fp32_int8_dynamic_shape.plan"


model_query = TrtModel(trt_engine_qry_path, max_batch_size=1, dtype=np.int32)
model_ctx = TrtModel(trt_engine_ctx_path, max_batch_size=10, dtype=np.int32)
tokenizer_qry = AutoTokenizer.from_pretrained("../outputs/onnx/mbert-retrieve-qry-onnx/")
tokenizer_ctx = AutoTokenizer.from_pretrained("../outputs/onnx/mbert-retrieve-ctx-onnx/")

In [9]:
# device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
accuracy, time_query_run, time_passage_run, time_query_total, time_passage_total = eval_accuracy_trt(
    dataset_eval, 
    encode_trt,
    num_passages=10, 
    model_ctx=model_ctx,
    model_qry=model_query, 
    tokenizer_ctx=tokenizer_ctx,
    tokenizer_query=tokenizer_qry,
    device="cpu"
)
print(f"Accuracy: {accuracy}")
print(f"Time Query Run: {time_query_run}")
print(f"Time Passage Run: {time_passage_run}")
print(f"Time Query Total: {time_query_total}")
print(f"Time Passage Total: {time_passage_total}")

  assert len(inspect.getargspec(encode_fn).args) == 4, "encode_fn must have 4 arguments"
100%|██████████| 1000/1000 [00:12<00:00, 83.15it/s]

Accuracy: 0.798
Time Query Run: 0.0018478615283966065
Time Passage Run: 0.007641497373580932
Time Query Total: 0.0022516303062438967
Time Passage Total: 0.009141062021255492



