from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor

In [None]:
# Build an input quantization descriptor that uses the 'histogram' calibration
# method and register it as the default input descriptor for the quantized
# Conv2d and Linear layers.
quant_desc_input = QuantDescriptor(calib_method='histogram')
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

In [None]:
from pytorch_quantization import quant_modules
# NOTE(review): per the pytorch-quantization docs this call is expected to swap
# torch.nn layers for their quantized counterparts, so it presumably has to run
# before the models below are instantiated — confirm.
quant_modules.initialize()

In [None]:
from transformers import AutoModel
import torch

# Select the device shared by every later cell in this notebook.
# NOTE(review): "cuda:1" hard-codes the second GPU; on a host with a single GPU
# torch.cuda.is_available() is still True but that index would not exist —
# confirm the target machine.
device = "cuda:1" if torch.cuda.is_available() else "cpu"
# device = "cpu"
# Load the query encoder from the local checkpoint directory and move it to `device`.
query_model = AutoModel.from_pretrained('mbert-retrieve-qry-base/')
query_model.to(device)

In [None]:
from transformers import AutoModel
import torch

# The device selection is commented out here, so this cell reuses the `device`
# defined in the query-model cell and must run after it.
# device = "cuda:1" if torch.cuda.is_available() else "cpu"
# Load the context (passage) encoder from the local checkpoint directory.
ctx_model = AutoModel.from_pretrained('mbert-retrieve-ctx-base/')
ctx_model.to(device)

## Calibrate

### utils

In [None]:
from transformers import DataCollatorWithPadding, AutoTokenizer

# Module-level tokenizers used by query_collate_fn / ctx_collate_fn below; each
# is loaded from the same directory as the matching encoder checkpoint.
query_tokenizer = AutoTokenizer.from_pretrained('mbert-retrieve-qry-base/')
ctx_tokenizer = AutoTokenizer.from_pretrained('mbert-retrieve-ctx-base/')

def query_collate_fn(examples):
    """Collate a batch of dataset rows into tokenized query tensors.

    Args:
        examples: Iterable of mappings, each with a ``'query'`` entry.

    Returns:
        The output of the module-level ``query_tokenizer`` for all queries,
        padded/truncated to 512 tokens and returned as PyTorch tensors.
    """
    queries = [row['query'] for row in examples]
    tokenizer_options = dict(
        padding='max_length',
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )
    return query_tokenizer(queries, **tokenizer_options)


def ctx_collate_fn(examples):
    """Collate a batch of dataset rows into tokenized passage tensors.

    For every row the positive passage is followed by at most the first nine
    negatives, and the passages of all rows are flattened into one list before
    tokenization.

    Args:
        examples: Iterable of mappings with a ``'positive'`` entry and a
            ``'negatives'`` list.

    Returns:
        The output of the module-level ``ctx_tokenizer`` for the flattened
        passages, padded/truncated to 512 tokens and returned as PyTorch
        tensors.
    """
    passages = [
        passage
        for row in examples
        for passage in [row['positive']] + row['negatives'][:9]
    ]
    tokenizer_options = dict(
        padding='max_length',
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )
    return ctx_tokenizer(passages, **tokenizer_options)


In [None]:
from tqdm import tqdm

def collect_stats(model, data_loader, num_batches):
    """Feed calibration batches through ``model`` and collect quantizer statistics.

    Every ``quant_nn.TensorQuantizer`` that owns a calibrator is switched to
    calibration mode (quantization off, calibration on); quantizers without a
    calibrator are disabled. After the batches have been processed the
    quantizers are switched back, even if the forward pass raises or the run
    is interrupted.

    Args:
        model: Module whose ``named_modules()`` contains the quantizers; it is
            called as ``model(**batch)``.
        data_loader: Iterable yielding mappings of tensors (each value must
            support ``.to(device)``).
        num_batches: Maximum number of batches to run. Exactly this many are
            consumed (the previous implementation ran one extra batch because
            it checked the limit after the forward pass).
    """
    from itertools import islice

    quantizers = [
        module
        for _, module in model.named_modules()
        if isinstance(module, quant_nn.TensorQuantizer)
    ]

    # Enable calibrators
    for quantizer in quantizers:
        if quantizer._calibrator is not None:
            quantizer.disable_quant()
            quantizer.enable_calib()
        else:
            quantizer.disable()

    try:
        # islice stops after num_batches without fetching an extra batch.
        for encode_input in tqdm(islice(data_loader, num_batches), total=num_batches):
            # Uses the notebook-level `device` chosen when the models were loaded.
            batch = {key: value.to(device) for key, value in encode_input.items()}
            model(**batch)
    finally:
        # Disable calibrators (always, so an interrupted run does not leave the
        # model stuck in calibration mode).
        for quantizer in quantizers:
            if quantizer._calibrator is not None:
                quantizer.enable_quant()
                quantizer.disable_calib()
            else:
                quantizer.enable()
            
def compute_amax(model, **kwargs):
    """Load the calibrated amax into every calibrated quantizer of ``model``.

    Quantizers backed by a ``calib.MaxCalibrator`` load their amax without
    arguments; every other calibrator receives ``**kwargs`` (the callers in
    this notebook pass ``method=...`` and ``percentile=...``). Finally the
    model is moved to the notebook-level ``device``.

    Args:
        model: Module whose ``named_modules()`` contains the quantizers.
        **kwargs: Forwarded to ``load_calib_amax`` for non-max calibrators.
    """
    for _, quantizer in model.named_modules():
        if not isinstance(quantizer, quant_nn.TensorQuantizer):
            continue
        if quantizer._calibrator is None:
            continue
        # Max calibrators take no options; the others need the method settings.
        if isinstance(quantizer._calibrator, calib.MaxCalibrator):
            quantizer.load_calib_amax()
        else:
            quantizer.load_calib_amax(**kwargs)
    model.to(device)


### dataset & dataloaders

In [None]:
import datasets
from datasets import concatenate_datasets 

# Number of rows taken from the head of the train split for each load below.
number_samples = 250 
# NOTE(review): `en` and `vi` load the identical dataset id and slice, so the
# concatenation below just contains the same rows twice. Presumably `en` was
# meant to point at an English dataset — confirm the intended id.
en = datasets.load_dataset('tiennv/mmarco-passage-vi', split=f'train[:{number_samples}]')
vi = datasets.load_dataset('tiennv/mmarco-passage-vi', split=f'train[:{number_samples}]')

# Calibration set: 2 * number_samples rows (first the `en` rows, then the `vi` rows).
dataset_calib = concatenate_datasets([en, vi])
dataset_calib

In [None]:
import torch

# One dataset row per batch. Note that ctx_collate_fn expands each row into
# several passages (the positive plus up to nine negatives), so a ctx batch
# holds more than `batch_size` sequences.
batch_size = 1
num_workers = 4

# Loader that tokenizes only the 'query' field of each row.
calib_query_loader = torch.utils.data.DataLoader(
    dataset_calib, 
    batch_size=batch_size,
    collate_fn=query_collate_fn,
    num_workers=num_workers, 
    pin_memory=True
)

# Loader over the same rows that tokenizes the positive/negative passages.
calib_ctx_loader = torch.utils.data.DataLoader(
    dataset_calib, 
    batch_size=batch_size,
    collate_fn=ctx_collate_fn,
    num_workers=num_workers, 
    pin_memory=True
)

In [None]:
# Sanity check: pull a single batch from the ctx loader and print the name and
# shape of its first entry only (the loop breaks after one item).
test = next(iter(calib_ctx_loader))
for k, v in test.items():
    print(k, v.shape)
    break

### run

In [None]:
# It is a bit slow since we collect histograms on CPU

# NOTE(review): dataset_calib holds 2 * number_samples rows, but the batch
# count below is derived from number_samples alone, so calibration only walks
# roughly the first half of the concatenated dataset — confirm this is intended.
calib_batches = number_samples // batch_size

# Gradients are not needed for calibration.
with torch.no_grad():
    collect_stats(query_model, calib_query_loader, num_batches=calib_batches)
    # Derive amax from the collected histograms using the 99.99th percentile.
    compute_amax(query_model, method="percentile", percentile=99.99)
    # Alternative calibration methods left for experimentation:
    # compute_amax(query_model, method="mse")
    # compute_amax(query_model, method="entropy")


In [None]:
# It is a bit slow since we collect histograms on CPU

# NOTE(review): dataset_calib holds 2 * number_samples rows, but the batch
# count below is derived from number_samples alone, so calibration only walks
# roughly the first half of the concatenated dataset — confirm this is intended.
calib_batches = number_samples // batch_size

# Gradients are not needed for calibration.
with torch.no_grad():
    collect_stats(ctx_model, calib_ctx_loader, num_batches=calib_batches)
    # Derive amax from the collected histograms using the 99.99th percentile.
    compute_amax(ctx_model, method="percentile", percentile=99.99)
    # Alternative calibration methods left for experimentation:
    # compute_amax(ctx_model, method="mse")
    # compute_amax(ctx_model, method="entropy")

## Evaluate

### utils

In [None]:
from tqdm import tqdm
import torch
import time
from typing import Callable
import inspect

def eval_accuracy_trt(
    data,
    encode_fn: Callable = None,
    num_passages=65,
    model_qry=None,
    model_ctx=None,
    tokenizer_query=None,
    tokenizer_ctx=None,
    device='cpu',
):
    """Measure top-1 retrieval accuracy and encoding latency over ``data``.

    For every row the query is encoded with ``model_qry`` and the positive
    passage plus the first ``num_passages - 1`` negatives are encoded with
    ``model_ctx``. A row counts as correct when the highest similarity score
    belongs to index 0, i.e. the positive passage.

    Args:
        data: Indexable dataset exposing ``column_names`` and ``len()``, with
            ``'query'``, ``'positive'`` and ``'negatives'`` columns.
        encode_fn: Callable taking exactly four positional arguments
            ``(texts, model, tokenizer, batch_size)`` and returning
            ``(embeddings, model_time_seconds)``.
        num_passages: Total passages scored per query (positive included).
        model_qry: Query encoder passed through to ``encode_fn``.
        model_ctx: Passage encoder passed through to ``encode_fn``.
        tokenizer_query: Tokenizer passed to ``encode_fn`` for queries.
        tokenizer_ctx: Tokenizer passed to ``encode_fn`` for passages.
        device: When not ``"cpu"``, both models are moved there via ``.to()``.

    Returns:
        Tuple ``(accuracy, query_model_time, passage_model_time,
        query_total_time, passage_total_time)``; the four timings are
        per-row averages in seconds. The "model" timings are the values
        reported by ``encode_fn``; the "total" timings are wall-clock around
        the whole ``encode_fn`` call.

    Raises:
        AssertionError: If a required argument is missing, a column is
            absent, or ``encode_fn`` does not take four arguments.
        ValueError: If ``data`` is empty.
    """
    assert model_ctx is not None, "model_ctx is required"
    assert model_qry is not None, "model_qry is required"
    assert tokenizer_ctx is not None, "tokenizer_ctx is required"
    assert tokenizer_query is not None, "tokenizer_query is required"
    assert 'query' in data.column_names, "data must have query column"
    assert 'positive' in data.column_names, "data must have positive column"
    assert 'negatives' in data.column_names, "data must have negatives column"
    assert callable(encode_fn), "encode_fn must be callable"
    # inspect.getargspec was removed in Python 3.11; getfullargspec exposes the
    # same ``args`` list.
    assert len(inspect.getfullargspec(encode_fn).args) == 4, "encode_fn must have 4 arguments"

    num_examples = len(data)
    if num_examples == 0:
        # Fail early with a clear message instead of a ZeroDivisionError at the end.
        raise ValueError("data must contain at least one example")

    if device != "cpu":
        model_ctx = model_ctx.to(device)
        model_qry = model_qry.to(device)

    correct = 0
    time_query_total = 0
    time_query_run = 0
    time_passage_total = 0
    time_passage_run = 0

    for i in tqdm(range(num_examples)):
        # Fetch the row once instead of re-indexing the dataset for each field.
        example = data[i]

        query_batch = [example['query']]
        start_time = time.perf_counter()
        query, time_query = encode_fn(query_batch, model_qry, tokenizer_query, len(query_batch))
        time_query_total += time.perf_counter() - start_time
        time_query_run += time_query

        # Positive passage first, followed by up to num_passages - 1 negatives.
        concate_passage = [example['positive']] + example['negatives'][:num_passages - 1]
        start_time = time.perf_counter()
        encoded_passages, time_ctx = encode_fn(concate_passage, model_ctx, tokenizer_ctx, len(concate_passage))
        time_passage_total += time.perf_counter() - start_time
        time_passage_run += time_ctx

        # NOTE(review): compute_similarity is not defined in the visible part of
        # this file; it must already be in scope when this function runs.
        scores = compute_similarity(query, encoded_passages)
        # .item() works regardless of the tensor's device (unlike .numpy()).
        if scores.argmax(dim=1).item() == 0:
            correct += 1

    return (
        correct / num_examples,
        time_query_run / num_examples,
        time_passage_run / num_examples,
        time_query_total / num_examples,
        time_passage_total / num_examples,
    )

In [None]:
# Benchmark: run inference with a serialized TensorRT engine
import tensorrt as trt
import numpy as np
import os

import pycuda.driver as cuda
import pycuda.autoinit


from transformers import AutoTokenizer

class HostDeviceMem(object):
    """Pair of buffers: one on the host and its counterpart on the device."""

    def __init__(self, host_mem, device_mem):
        self.host, self.device = host_mem, device_mem

    def __str__(self):
        # Two labelled sections, each followed by the buffer's string form.
        parts = ["Host:", str(self.host), "Device:", str(self.device)]
        return "\n".join(parts)

    def __repr__(self):
        return str(self)

class TrtModel:
    """Wrapper that deserializes a TensorRT engine file and runs it via PyCUDA.

    Instances are callable with three token arrays (input ids, attention mask,
    token type ids) and return a list with one host array per engine output.

    NOTE(review): get_binding_index / binding_is_input / set_binding_shape /
    get_binding_shape / get_binding_dtype / execute_async_v2 are the
    binding-index style TensorRT API, while a later cell mentions TensorRT
    10.3.0 — confirm these calls exist in the installed TensorRT version.
    """

    def __init__(self,engine_path,max_batch_size=1,dtype=np.float32):
        """Load the engine at ``engine_path`` and create a stream and execution context.

        ``dtype`` is the NumPy dtype every input is cast to in ``__call__``.
        ``max_batch_size`` is only stored; nothing in this class reads it.
        """

        self.engine_path = engine_path
        self.dtype = dtype
        self.logger = trt.Logger(trt.Logger.WARNING)
        self.runtime = trt.Runtime(self.logger)
        self.engine = self.load_engine(self.runtime, self.engine_path)
        self.max_batch_size = max_batch_size
        # Buffers are not allocated up front; __call__ allocates them per
        # request because their size depends on the input shape.
        # self.inputs, self.outputs, self.bindings = self.allocate_buffers()
        self.stream = cuda.Stream()
        self.context = self.engine.create_execution_context()


    @staticmethod
    def load_engine(trt_runtime, engine_path):
        """Read the serialized engine file and deserialize it with ``trt_runtime``."""
        # Register the TensorRT plugins before deserializing the engine.
        trt.init_libnvinfer_plugins(None, "")             
        with open(engine_path, 'rb') as f:
            engine_data = f.read()
        engine = trt_runtime.deserialize_cuda_engine(engine_data)
        return engine

    def allocate_buffers(self, binding_shape):
        """Allocate a host and a device buffer for every engine binding.

        Every input binding is set to the same ``binding_shape``; buffer sizes
        are then taken from the context's binding shapes.

        Returns:
            ``(inputs, outputs, bindings)`` where ``inputs``/``outputs`` are
            lists of ``HostDeviceMem`` and ``bindings`` holds the device
            addresses as ints, in engine iteration order.
        """
        # Allocate host and device buffers
        inputs, outputs, bindings = [], [], []
        for binding in self.engine:
            binding_idx = self.engine.get_binding_index(binding)
            # Inputs: fix the (dynamic) shape to the shape of the current request.
            if self.engine.binding_is_input(binding):
                self.context.set_binding_shape(binding_idx, binding_shape)

            # Element count of this binding for the shape set above.
            size = trt.volume(self.context.get_binding_shape(binding_idx))
            # print("size: ", size)
            # print("batch_size: ", batch_size)
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))

            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            bindings.append(int(device_mem))

            if self.engine.binding_is_input(binding):
                inputs.append(HostDeviceMem(host_mem, device_mem))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem))

        return inputs, outputs, bindings


    def __call__(self, inputs_id, attention_mask, token_type_ids, batch_size=2):
        """Run one inference and return each output reshaped to ``(batch_size, -1)``.

        All three inputs are cast to ``self.dtype``; the shape of ``inputs_id``
        is used as the shape of every input binding.

        NOTE(review): ``batch_size`` is only used for the final reshape and
        defaults to 2 rather than being derived from the input — callers
        should pass the real batch size.
        NOTE(review): buffers are re-allocated on every call.
        """

        x = np.array(inputs_id).astype(self.dtype)
        y = np.array(attention_mask).astype(self.dtype)
        z = np.array(token_type_ids).astype(self.dtype)


        inputs, outputs, bindings = self.allocate_buffers(x.shape)

        # Fill the host buffers.
        # NOTE(review): assumes the engine's input bindings are ordered as
        # input ids, attention mask, token type ids — confirm against the engine.
        # print(x.shape)
        np.copyto(inputs[0].host,x.ravel())
        np.copyto(inputs[1].host,y.ravel())
        np.copyto(inputs[2].host,z.ravel())

        # Copy host -> device only after the host buffers are filled;
        # transferring first would upload stale values.
        for inp in inputs:
            cuda.memcpy_htod_async(inp.device, inp.host, self.stream)

        # Run inference
        self.context.execute_async_v2(bindings=bindings, stream_handle=self.stream.handle)

        # Transfer prediction output from the GPU.
        for out in outputs:
            cuda.memcpy_dtoh_async(out.host, out.device, self.stream)

        # Synchronize the stream
        self.stream.synchronize()
        return [out.host.reshape(batch_size,-1) for out in outputs]

In [None]:
import time

def encode_trt(texts, model, tokenizer, batch_size):
    """Tokenize ``texts``, run ``model`` on them and return first-token embeddings.

    Args:
        texts: Sequence of strings to encode.
        model: Callable invoked as ``model(input_ids, attention_mask,
            token_type_ids, batch_size)``; element 0 of its result is used.
        tokenizer: Callable returning a mapping with ``'input_ids'``,
            ``'attention_mask'`` and ``'token_type_ids'``. It is asked for
            NumPy output padded/truncated to 128 tokens.
        batch_size: Number of sequences; used to reshape the model output.

    Returns:
        ``(embeddings, seconds)`` where ``embeddings`` is position 0 of the
        output reshaped to ``(batch_size, -1, 768)`` and ``seconds`` is the
        wall-clock duration of the model call only (tokenization excluded).
    """
    tokenizer_options = dict(
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='np',
    )
    tokens = tokenizer(texts, **tokenizer_options)

    # Time only the model invocation.
    started = time.time()
    model_outputs = model(
        tokens['input_ids'],
        tokens['attention_mask'],
        tokens['token_type_ids'],
        batch_size,
    )
    first_output = model_outputs[0]
    elapsed = time.time() - started

    # View the flat output as (batch, tokens, 768) and keep token 0 of each row.
    token_states = first_output.reshape(batch_size, -1, 768)
    return token_states[:, 0], elapsed

### dataset eval

In [None]:
import datasets
from datasets import concatenate_datasets
# Evaluation rows come from the tail of the train split (last 500 rows).
# NOTE(review): `en_eval` and `vi_eval` load the identical dataset id and
# slice, so dataset_eval contains every row twice. Presumably `en_eval` was
# meant to point at an English dataset — confirm the intended id.
en_eval = datasets.load_dataset('tiennv/mmarco-passage-vi', split='train[-500:]')
vi_eval = datasets.load_dataset('tiennv/mmarco-passage-vi', split='train[-500:]')

dataset_eval = concatenate_datasets([en_eval, vi_eval])
dataset_eval

### eval

In [None]:
import os
# NOTE(review): pycuda.autoinit was already imported in an earlier cell. If it
# reads CUDA_DEVICE when it is imported, setting the variable here comes too
# late to affect the CUDA context that was created — confirm and, if so, move
# this assignment before that import.
os.environ['CUDA_DEVICE'] = '0'

In [None]:
!python3 -c "import tensorrt; print(tensorrt.__version__)"

In [None]:
import numpy as np
# Engine variants that were benchmarked; uncomment one pair at a time.
# trt_engine_qry_path = "onnx_convert_outputs/mbert-retrieve-qry-onnx/model_fp32_dynamic_shape.engine"
# trt_engine_ctx_path = "onnx_convert_outputs/mbert-retrieve-ctx-onnx/model_fp32_dynamic_shape.engine"

# trt_engine_qry_path = "onnx_convert_outputs/mbert-retrieve-qry-onnx/model_fp32_int8_dynamic_shape.engine"
# trt_engine_ctx_path = "onnx_convert_outputs/mbert-retrieve-ctx-onnx/model_fp32_int8_dynamic_shape.engine"

# trt_engine_qry_path = "onnx_convert_outputs/mbert-retrieve-qry-onnx/model_calib_percential_fp32_dynamic_shape.engine"
# trt_engine_ctx_path = "onnx_convert_outputs/mbert-retrieve-ctx-onnx/model_calib_percential_fp32_dynamic_shape.engine"

# trt_engine_qry_path = "onnx_convert_outputs/mbert-retrieve-qry-onnx/model_calib_percential_fp32_int8_dynamic_shape.engine"
# trt_engine_ctx_path = "onnx_convert_outputs/mbert-retrieve-ctx-onnx/model_calib_percential_fp32_int8_dynamic_shape.engine"


# The engines below were built with TensorRT version 10.3.0.

# Currently selected pair: float16 engines.
trt_engine_qry_path = "./mbert-retrieve-qry-base_float16_tllm_checkpoint/BertModel_float16_tp1_rank0.engine"
trt_engine_ctx_path = "./mbert-retrieve-ctx-base_float16_tllm_checkpoint/BertModel_float16_tp1_rank0.engine"

# trt_engine_qry_path = "./mbert-retrieve-qry-base_float32_tllm_checkpoint/BertModel_float32_tp1_rank0.engine"
# trt_engine_ctx_path = "./mbert-retrieve-ctx-base_float32_tllm_checkpoint/BertModel_float32_tp1_rank0.engine"

# dtype=np.int32: TrtModel casts all three token inputs to int32 before upload.
# max_batch_size is stored by TrtModel but not used by it.
model_query = TrtModel(trt_engine_qry_path, max_batch_size=1, dtype=np.int32)
model_ctx = TrtModel(trt_engine_ctx_path, max_batch_size=10, dtype=np.int32)
# Tokenizers are loaded from the ONNX export directories, not from the engine directories.
tokenizer_qry = AutoTokenizer.from_pretrained("onnx_convert_outputs/mbert-retrieve-qry-onnx/")
tokenizer_ctx = AutoTokenizer.from_pretrained("onnx_convert_outputs/mbert-retrieve-ctx-onnx/")

In [None]:
# device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
# device="cpu" makes eval_accuracy_trt skip its `.to(device)` calls on the models.
# Each query is scored against 10 passages: the positive plus the first 9 negatives.
accuracy, time_query_run, time_passage_run, time_query_total, time_passage_total = eval_accuracy_trt(
    dataset_eval, 
    encode_trt,
    num_passages=10, 
    model_ctx=model_ctx,
    model_qry=model_query, 
    tokenizer_ctx=tokenizer_ctx,
    tokenizer_query=tokenizer_qry,
    device="cpu"
)
# "Run" timings are those reported by encode_trt (model call only); "Total"
# timings are measured around the whole encode call. All are per-row averages.
print(f"Accuracy: {accuracy}")
print(f"Time Query Run: {time_query_run}")
print(f"Time Passage Run: {time_passage_run}")
print(f"Time Query Total: {time_query_total}")
print(f"Time Passage Total: {time_passage_total}")