In [1]:
import os
from pathlib import Path

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [2]:
# Locate the project root so every path below is independent of where the kernel
# was started. We take the nearest directory (the cwd itself or an ancestor) whose
# name ends with PROJECT_DIR_NAME.
PROJECT_DIR_NAME = "xlm-roberta-base-cls-depression"

_start_dir = Path().resolve()
current_dir = next(
    (candidate for candidate in (_start_dir, *_start_dir.parents)
     if candidate.name.endswith(PROJECT_DIR_NAME)),
    None,
)
if current_dir is None:
    # Walking `.parent` would loop forever at the filesystem root
    # (`Path("/").parent == Path("/")`), so fail loudly instead.
    raise FileNotFoundError(
        f"No directory ending with {PROJECT_DIR_NAME!r} found at or above {_start_dir}"
    )

# Changes the kernel's cwd; re-running this cell is harmless because the root
# then matches on the first candidate.
os.chdir(current_dir)

input_val_data = current_dir / "data/clean/val.csv"
input_model_dir = current_dir / "data/models/xlm-roberta-base-cls-depression"
output_model_dir = current_dir / "data/dist/xlm-roberta-base-cls-depression"
output_model_base_filename = output_model_dir / "model.onnx"
# Only referenced by the (currently disabled) onnxruntime optimizer path.
output_model_optimized_filename = output_model_dir / "model.opt.onnx"
output_model_quantized_filename = output_model_dir / "model.opt.quant.onnx"

output_model_dir.mkdir(parents=True, exist_ok=True)

In [3]:
# Load the fine-tuned classifier in inference mode (disables dropout) before tracing.
model = AutoModelForSequenceClassification.from_pretrained(input_model_dir)
model.eval()

# The tokenizer comes from the base Hub checkpoint, not from `input_model_dir`.
# NOTE(review): presumably the fine-tuned dir ships no tokenizer files — confirm.
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")

# Dummy input used only to trace the graph: its content is irrelevant, but padding
# to max_length=512 gives tensors of shape (1, 512).
text = "Sample text"
encoding = tokenizer(text, return_tensors="pt", padding="max_length", max_length=512, truncation=True)

# Inputs are (batch, seq): both axes are made dynamic.
symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
# Logits are (batch, num_labels): only the batch axis varies. Reusing the input
# mapping here would wrongly label the class axis as `max_seq_len`.
logits_dynamic_axes = {0: 'batch_size'}

torch.onnx.export(
    model,
    (encoding["input_ids"], encoding["attention_mask"]),
    output_model_base_filename,
    input_names=['input_ids', 'attention_mask'],
    output_names=['logits'],
    dynamic_axes={
        'input_ids': symbolic_names,
        'attention_mask': symbolic_names,
        'logits': logits_dynamic_axes,
    },
    opset_version=16,
    do_constant_folding=True
)

In [4]:
# from onnxruntime.transformers import optimizer
# from onnxruntime.quantization import quantize_static, QuantType, CalibrationDataReader
# from onnxruntime.quantization.preprocess import quant_pre_process

In [5]:
# from random import SystemRandom
# from torch.utils.data import Dataset

In [6]:
# tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")

# class CustomDataset(Dataset):
#     def __init__(self, dataframe, tokenizer):
#         encoded = tokenizer(
#             dataframe['text'].tolist(),
#             padding="max_length",
#             truncation=True,
#             max_length=512,
#             return_tensors="pt"
#         )

#         self.input_ids = encoded['input_ids']
#         self.attention_mask = encoded['attention_mask']
#         self.labels = torch.tensor(dataframe['label'].tolist())

#     def __len__(self):
#         return len(self.labels)

#     def __getitem__(self, idx):
#         return {
#             'input_ids': self.input_ids[idx],
#             'attention_mask': self.attention_mask[idx],
#             'labels': self.labels[idx]
#         }
    
#     def select(self, indices):
#         """Create a new dataset with only the selected indices"""
#         new_input_ids = self.input_ids[indices]
#         new_attention_mask = self.attention_mask[indices]
#         new_labels = self.labels[indices]
        
#         new_dataset = CustomDataset.__new__(CustomDataset)
#         new_dataset.input_ids = new_input_ids
#         new_dataset.attention_mask = new_attention_mask
#         new_dataset.labels = new_labels
#         return new_dataset
    
#     def shuffle(self, seed=None):
#         """Shuffle the dataset securely and return a new shuffled dataset"""
#         indices = list(range(len(self)))
#         SystemRandom().shuffle(indices)
#         return self.select(indices)

In [7]:
# class CustomCalibrationDataReader(CalibrationDataReader):
#     def __init__(self, data_loader):
#         """
#         Initialize with a DataLoader that provides the input data.
#         :param data_loader: A DataLoader instance that yields input data.
#         """
#         self.data_loader = data_loader
#         self.iter = iter(data_loader)

#     def get_next(self):
#         try:
#             batch = next(self.iter)
#             return {
#                 'input_ids': batch['input_ids'].numpy(),
#                 'attention_mask': batch['attention_mask'].numpy()
#             }
#         except StopIteration:
#             return None

In [8]:
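# NOTE: `pd` and `DataLoader` are not imported anywhere in this notebook; re-enabling this
# cell also requires `import pandas as pd` and `from torch.utils.data import DataLoader`.
# validation_df = pd.read_csv(input_val_data, encoding='utf-8', sep='|')
# validation_dataset = CustomDataset(validation_df, tokenizer)
# data_loader = DataLoader(validation_dataset, batch_size=16, shuffle=False)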

In [9]:
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# Re-export the fine-tuned checkpoint through optimum (export=True) so the
# quantizer below gets an ORT model object; this is independent of the
# `model.onnx` file written by the manual torch.onnx.export cell.
onnx_model = ORTModelForSequenceClassification.from_pretrained(
    input_model_dir,
    export=True,
)

In [10]:
# Build a quantizer around the ORT model exported by optimum in the previous cell.
quantizer = ORTQuantizer.from_pretrained(onnx_model)

In [11]:
# Dynamic quantization (is_static=False, so no calibration data reader is needed)
# with per-tensor rather than per-channel scales, using optimum's AVX-512 VNNI preset.
dqconfig = AutoQuantizationConfig.avx512_vnni(
    per_channel=False,
    is_static=False,
)

In [12]:
# `quantize` treats `save_dir` as a *directory* and writes `<stem>_<file_suffix>.onnx`
# into it (along with optimum's config files, e.g. `ort_config.json`). Passing the
# `.onnx` filename directly would create a directory literally named
# `model.opt.quant.onnx`. Instead, quantize into the output directory and then move
# the model file to the intended name.
quantized_dir = Path(
    quantizer.quantize(
        save_dir=output_model_dir,
        quantization_config=dqconfig,
        file_suffix="quantized",
    )
)
quantized_model_file = quantized_dir / f"{Path(quantizer.onnx_model_path).stem}_quantized.onnx"
model_quantized_path = quantized_model_file.replace(output_model_quantized_filename)
model_quantized_path