## Convert models to ONNX
1. query_retrieve
2. context_retrieve
3. rerank

Requirements:
1. optimum
2. onnx

In [None]:
# Export the rerank model (text-classification task) to ONNX with O1 optimizations, opset 17.
!optimum-cli export onnx \
    --model ./mbert-rerank-base \
    --library transformers \
    --task text-classification \
    --optimize O1 \
    --opset 17 \
    ./onnx_convert_outputs/mbert-rerank-onnx

In [None]:
# Export the query encoder (feature-extraction task) to ONNX with O1 optimizations, opset 17.
!optimum-cli export onnx \
    --model ./mbert-retrieve-qry-base \
    --library transformers \
    --task feature-extraction \
    --optimize O1 \
    --opset 17 \
    ./onnx_convert_outputs/mbert-retrieve-qry-onnx

In [None]:
# Export the context (passage) encoder (feature-extraction task) to ONNX with O1 optimizations, opset 17.
!optimum-cli export onnx \
    --model ./mbert-retrieve-ctx-base \
    --library transformers \
    --task feature-extraction \
    --optimize O1 \
    --opset 17 \
    ./onnx_convert_outputs/mbert-retrieve-ctx-onnx

NOTE:
* `--optimize` accepts `O1`, `O2`, `O3` or `O4` (the letter O, not zero). Higher levels apply more aggressive optimizations and can reduce accuracy.
 

## Dynamic quantization of the ONNX models
Requirements:
* onnxruntime

In [None]:
import os

import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

# Dynamically quantize the query encoder's weights to uint8 (QuantType.QUInt8).
model_fp32 = './onnx_convert_outputs/mbert-retrieve-qry-onnx/model.onnx'
model_quant = './onnx_convert_outputs/mbert-retrieve-qry-onnx/model.quant.onnx'

# Remove outputs of a previous run. The model file and its '.data' sidecar
# (written because use_external_data_format=True) are checked independently,
# since either one may be missing.
for stale_path in (model_quant, model_quant + '.data'):
    if os.path.exists(stale_path):
        os.remove(stale_path)

quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8, use_external_data_format=True)

In [None]:
import os

import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

# Dynamically quantize the context (passage) encoder's weights to uint8 (QuantType.QUInt8).
model_fp32 = './onnx_convert_outputs/mbert-retrieve-ctx-onnx/model.onnx'
model_quant = './onnx_convert_outputs/mbert-retrieve-ctx-onnx/model.quant.onnx'

# Remove outputs of a previous run. The model file and its '.data' sidecar
# (written because use_external_data_format=True) are checked independently,
# since either one may be missing.
for stale_path in (model_quant, model_quant + '.data'):
    if os.path.exists(stale_path):
        os.remove(stale_path)

quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8, use_external_data_format=True)

In [None]:
# TODO: the rerank model is not quantized yet — add a quantize_dynamic cell for
# ./onnx_convert_outputs/mbert-rerank-onnx/model.onnx.

## Evaluating

### Utils

In [None]:
from torch import nn
import torch

# Shared loss object: cross-entropy averaged over the batch.
cross_entropy = nn.CrossEntropyLoss(reduction='mean')


def compute_loss(scores, target):
    """Return the mean cross-entropy between `scores` (logits) and class indices `target`."""
    loss = cross_entropy(scores, target)
    return loss

def compute_similarity(q_reps, p_reps):
    """Dot-product score of every query embedding against every passage embedding.

    Non-tensor inputs (e.g. lists or numpy arrays) are converted with
    ``torch.tensor``. The result is ``q_reps @ p_reps.transpose(0, 1)``.
    """
    queries = q_reps if isinstance(q_reps, torch.Tensor) else torch.tensor(q_reps)
    passages = p_reps if isinstance(p_reps, torch.Tensor) else torch.tensor(p_reps)
    return queries @ passages.transpose(0, 1)

In [None]:
import torch
import time

# CLS pooling: use the first token's hidden state as the sequence embedding.
def cls_pooling(model_output):
    """Return ``last_hidden_state[:, 0]`` as a detached CPU tensor."""
    hidden = model_output.last_hidden_state
    first_token = hidden[:, 0]
    return first_token.detach().cpu()

def onnx_predict(onnx_model, encoded_input: dict):
    """Run an ONNX Runtime-style session and return first-token (CLS) embeddings.

    Args:
        onnx_model: object exposing ``run(output_names, input_feed)``, such as
            ``onnxruntime.InferenceSession``.
        encoded_input: mapping of input name -> value, e.g. the output of a
            tokenizer called with ``return_tensors='pt'``. Torch tensors are
            detached, moved to CPU and converted to numpy; other values are
            passed through unchanged.

    Returns:
        ``(embeddings, run_seconds)``. ``embeddings`` is ``outputs[0][:, 0]``, the
        first-token slice of the first model output. This assumes that output
        is (batch, seq, hidden) — TODO confirm per exported model.
        ``run_seconds`` covers only the ``onnx_model.run`` call.
    """
    feed = {}
    for name, value in encoded_input.items():
        if isinstance(value, torch.Tensor):
            # ONNX Runtime consumes numpy arrays; .numpy() requires a CPU tensor.
            value = value.detach().cpu().numpy()
        feed[name] = value

    # Time only the session call. perf_counter is monotonic and high-resolution,
    # unlike time.time().
    start = time.perf_counter()
    model_output = onnx_model.run(None, feed)
    run_seconds = time.perf_counter() - start

    embeddings = model_output[0][:, 0]  # CLS (first-token) embeddings
    return embeddings, run_seconds

In [None]:
from tqdm import tqdm
import torch
import time
from typing import Callable
import inspect

def eval_accuracy(
    data,
    encode_fn: Callable = None,
    num_passages=65,
    model_ctx=None,
    model_qry=None,
    tokenizer_ctx=None,
    tokenizer_query=None,
    device='cpu'
):
    """Measure top-1 retrieval accuracy and encoding latency of a bi-encoder.

    For each row, the query is encoded with ``model_qry``. The passage list
    ``[positive] + negatives[:num_passages - 1]`` is encoded with ``model_ctx``.
    A row counts as correct when the positive passage (index 0) gets the
    highest similarity score from ``compute_similarity``.

    Args:
        data: dataset with ``column_names``, ``len()`` and row access by index,
            containing ``query``, ``positive`` and ``negatives`` columns.
        encode_fn: ``encode_fn(texts, model, tokenizer, device)`` returning
            ``(embeddings, run_seconds)``.
        num_passages: passages scored per query (1 positive + up to
            ``num_passages - 1`` negatives). Must be >= 1.
        model_ctx, model_qry: passage and query encoders. When ``device != 'cpu'``
            they are moved with ``.to(device)``.
        tokenizer_ctx, tokenizer_query: tokenizers passed to ``encode_fn``.
        device: forwarded to ``encode_fn``.

    Returns:
        ``(accuracy, avg_query_run, avg_passage_run, avg_query_total, avg_passage_total)``.
        "run" times are the ones reported by ``encode_fn``. "total" times are
        measured around each ``encode_fn`` call. All times are seconds averaged
        per row.

    Raises:
        AssertionError: missing argument or column, empty ``data``,
            ``num_passages < 1``, or an ``encode_fn`` that cannot be called with
            4 positional arguments.
    """
    assert encode_fn is not None and callable(encode_fn), "encode_fn is required"
    assert model_ctx is not None, "model_ctx is required"
    assert model_qry is not None, "model_qry is required"
    assert tokenizer_ctx is not None, "tokenizer_ctx is required"
    assert tokenizer_query is not None, "tokenizer_query is required"
    assert 'query' in data.column_names, "data must have query column"
    assert 'positive' in data.column_names, "data must have positive column"
    assert 'negatives' in data.column_names, "data must have negatives column"
    assert len(data) > 0, "data must not be empty"
    assert num_passages >= 1, "num_passages must be >= 1"
    # inspect.getargspec was removed in Python 3.11. Binding 4 dummy positional
    # arguments checks that encode_fn accepts (texts, model, tokenizer, device).
    try:
        inspect.signature(encode_fn).bind(None, None, None, None)
    except TypeError as exc:
        raise AssertionError("encode_fn must have 4 arguments") from exc

    if device != "cpu":
        model_ctx = model_ctx.to(device)
        model_qry = model_qry.to(device)

    n_rows = len(data)
    n_correct = 0
    time_query_total = 0.0
    time_query_run = 0.0
    time_passage_total = 0.0
    time_passage_run = 0.0

    for i in tqdm(range(n_rows)):
        row = data[i]  # fetch the row once instead of re-indexing per column

        start = time.perf_counter()
        query_emb, query_run = encode_fn([row['query']], model_qry, tokenizer_query, device)
        time_query_total += time.perf_counter() - start
        time_query_run += query_run

        # Positive passage first, so a correct retrieval means argmax == 0.
        passages = [row['positive']] + list(row['negatives'][:num_passages - 1])
        start = time.perf_counter()
        passage_emb, passage_run = encode_fn(passages, model_ctx, tokenizer_ctx, device)
        time_passage_total += time.perf_counter() - start
        time_passage_run += passage_run

        # One query, so scores has a single row (assumes embeddings are (batch, dim)).
        scores = compute_similarity(query_emb, passage_emb)
        if scores.argmax(dim=1).item() == 0:
            n_correct += 1

    return (
        n_correct / n_rows,
        time_query_run / n_rows,
        time_passage_run / n_rows,
        time_query_total / n_rows,
        time_passage_total / n_rows,
    )

In [None]:
import time
# Encode text with an ONNX Runtime session
def encode_onnx(texts, model, tokenizer, device='cpu'):
    """Tokenize ``texts`` (padded, truncated, as torch tensors) and run ``onnx_predict``.

    ``device`` is accepted so the signature matches the 4-argument ``encode_fn``
    contract; it is not used here.

    Returns ``(embeddings, run_seconds)`` as produced by ``onnx_predict``.
    """
    embeddings, run_seconds = onnx_predict(
        model,
        tokenizer(texts, padding=True, truncation=True, return_tensors='pt'),
    )
    return embeddings, run_seconds

### prepare datasets

In [None]:
import datasets
from datasets import concatenate_datasets

# Evaluate on the last 500 rows of the training split.
EVAL_SPLIT = 'train[-500:]'
VI_EVAL_DATASET = 'tiennv/mmarco-passage-vi'
# FIXME: this points at the same Vietnamese dataset as VI_EVAL_DATASET, so the
# "en" slice duplicates the "vi" slice. Replace it with the English dataset id.
EN_EVAL_DATASET = 'tiennv/mmarco-passage-vi'

en_eval = datasets.load_dataset(EN_EVAL_DATASET, split=EVAL_SPLIT)
vi_eval = datasets.load_dataset(VI_EVAL_DATASET, split=EVAL_SPLIT)

dataset_eval = concatenate_datasets([en_eval, vi_eval])
dataset_eval

### Run


In [None]:
import onnxruntime
import torch
from transformers import AutoTokenizer

tokenizer_query = AutoTokenizer.from_pretrained("mbert-retrieve-qry-base")
tokenizer_ctx = AutoTokenizer.from_pretrained("mbert-retrieve-ctx-base")

# (query model path, context model path) for each exported variant.
ONNX_VARIANTS = {
    # raw export (fp32)
    "raw": (
        "onnx_convert_outputs/mbert-retrieve-qry-onnx/model.onnx",
        "onnx_convert_outputs/mbert-retrieve-ctx-onnx/model.onnx",
    ),
    # dynamic quantization
    "quant_dynamic": (
        "onnx_convert_outputs/mbert-retrieve-qry-onnx/model.quant.onnx",
        "onnx_convert_outputs/mbert-retrieve-ctx-onnx/model.quant.onnx",
    ),
    # calibrated quantization
    "quant_calib": (
        "onnx_convert_outputs/mbert-retrieve-qry-onnx/qry_quant_percential_calib.onnx",
        "onnx_convert_outputs/mbert-retrieve-ctx-onnx/ctx_quant_percential_calib.onnx",
    ),
}
MODEL_VARIANT = "quant_calib"
query_path, ctx_path = ONNX_VARIANTS[MODEL_VARIANT]

# Use the CUDA provider only when both torch and onnxruntime can see a GPU;
# torch.cuda.current_stream() fails on CPU-only machines.
if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers():
    providers = [("CUDAExecutionProvider", {
        "device_id": 0,
        # Have ONNX Runtime run on torch's current CUDA stream.
        "user_compute_stream": str(torch.cuda.current_stream().cuda_stream),
        "cudnn_conv_algo_search": "DEFAULT",
    })]
else:
    providers = ["CPUExecutionProvider"]

sess_options = onnxruntime.SessionOptions()
query_session = onnxruntime.InferenceSession(query_path, sess_options, providers=providers)
ctx_session = onnxruntime.InferenceSession(ctx_path, sess_options, providers=providers)
query_session, ctx_session

In [None]:
# Smoke test: encode one query and inspect the embedding shape.
sample_embeddings, _ = encode_onnx([dataset_eval[0]['query']], query_session, tokenizer_query)
sample_embeddings.shape

In [None]:
# device stays 'cpu': eval_accuracy only uses it to call `.to(device)` on the
# models, and ONNX Runtime sessions have no `.to()`. Where inference runs is
# decided by the providers the sessions were created with.
eval_results = eval_accuracy(
    dataset_eval,
    encode_onnx,
    num_passages=10,
    model_ctx=ctx_session,
    model_qry=query_session,
    tokenizer_ctx=tokenizer_ctx,
    tokenizer_query=tokenizer_query,
    device='cpu',
)
accuracy, time_query_run, time_passage_run, time_query_total, time_passage_total = eval_results

metric_labels = ("Accuracy", "Time Query Run", "Time Passage Run", "Time Query Total", "Time Passage Total")
for label, value in zip(metric_labels, eval_results):
    print(f"{label}: {value}")