# Quantization with Hugging Face, ONNX and Catalyst

In [None]:
!pip install onnx onnxruntime transformers datasets

In [None]:
import os

import torch

from transformers import BertForSequenceClassification, AutoTokenizer
import datasets
from datasets import load_dataset, load_metric

from catalyst.utils import (
    quantize_model, 
    quantize_onnx_model, 
    convert_to_onnx
)
from catalyst.metrics import ICallbackLoaderMetric
from catalyst.runners import Runner

### Hyperparams

In [None]:
seed = 0xDEAD  # forwarded to the dataset shuffle for reproducibility
task = "rte"  # GLUE task name; must be one of the keys of `task_to_keys`
model_name = "google/bert_uncased_L-4_H-512_A-8"  # checkpoint id used for both tokenizer and model
batch_size = 32  # batch size of every DataLoader

## Data

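GLUE (General Language Understanding Evaluation) is a collection of sentence-level and sentence-pair language-understanding tasks. Each task stores its input text in differently named columns, so the mapping below tells the tokenizer which column(s) to read for the selected task.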

In [None]:
# Text column name(s) holding the model input for each GLUE task.
# The second entry is None for single-sentence tasks.
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}


def preprocess_dataset(dataset, tokenizer, seed, task):
    """Tokenize a dataset split, select the model columns and shuffle it.

    Args:
        dataset: object providing ``map``, ``set_format`` and ``shuffle``
            (presumably a Hugging Face ``datasets.Dataset`` - confirm
            against the caller).
        tokenizer: callable invoked on batches of one or two text columns.
        seed: seed forwarded to ``shuffle``.
        task: key of ``task_to_keys`` selecting the text column(s).

    Returns:
        Whatever ``shuffle(seed=seed)`` returns for the encoded dataset.
    """

    def tokenize_batch(examples):
        # Column names are resolved per call, exactly when a batch is encoded.
        first_key, second_key = task_to_keys[task]
        texts = [examples[first_key]]
        if second_key is not None:
            texts.append(examples[second_key])
        # Every example is padded/truncated to a fixed length of 128 tokens.
        return tokenizer(
            *texts,
            max_length=128,
            truncation=True,
            padding="max_length",
        )

    def copy_label(example):
        # Duplicate the "label" column under the name "labels".
        return {"labels": example["label"]}

    tokenized = dataset.map(tokenize_batch, batched=True)
    with_labels = tokenized.map(copy_label)
    with_labels.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "labels"],
    )
    return with_labels.shuffle(seed=seed)

In [None]:
# STS-B is the GLUE task that predicts a continuous score, so it is the one
# handled as regression. SST-2 is a classification task and must take the
# label-list branch below.
is_regression = task == "stsb"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# NOTE: this rebinds `datasets`, shadowing the imported `datasets` module for
# the rest of the notebook. The name is kept so later cells that may refer to
# it keep working.
datasets = load_dataset("glue", task)
for k, v in datasets.items():
    # Re-assigning existing keys does not change the dict size, so this is
    # safe while iterating.
    datasets[k] = preprocess_dataset(v, tokenizer, seed, task=task)

# One DataLoader per split; only the training split is reshuffled each epoch.
loaders = {}
for key, dataset in datasets.items():
    loaders[key] = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(key == "train"),
    )

if is_regression:
    # A single output unit for the regression head.
    num_labels = 1
else:
    # Classification: one output unit per class name of the "label" feature.
    label_list = datasets["train"].features["label"].names
    num_labels = len(label_list)

## Metric

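The `datasets` library provides a ready-made metric for each GLUE task. The class below wraps such a metric in Catalyst's loader-metric interface: predictions and references are accumulated batch by batch, and the final score is computed once per loader.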

In [None]:
class HFMetricCallback(ICallbackLoaderMetric):
    """Adapter that exposes a Hugging Face metric as a Catalyst loader metric.

    Args:
        metric: object providing ``add_batch(predictions=..., references=...)``
            and ``compute()`` (e.g. the result of ``load_metric``).
        regression: when True, logits are passed to the metric unchanged;
            otherwise the ``argmax`` over the last axis is used as prediction.
        compute_on_call: forwarded to ``ICallbackLoaderMetric``.
        prefix: forwarded to ``ICallbackLoaderMetric``.
        suffix: forwarded to ``ICallbackLoaderMetric``.
    """

    def __init__(
        self,
        metric,
        regression: bool = False,
        compute_on_call: bool = True,
        prefix: str = None,
        suffix: str = None,
    ):
        super().__init__(
            compute_on_call=compute_on_call,
            prefix=prefix,
            suffix=suffix,
        )

        self.metric = metric
        self.regression = regression

    def reset(self, num_batches=None, num_samples=None):
        """Discard accumulated state before a new loader starts.

        Catalyst's loader-metric interface calls ``reset`` with the loader's
        batch and sample counts; they are accepted (and ignored) so the call
        does not fail with a ``TypeError``. Both default to ``None`` so a
        bare ``reset()`` keeps working.
        """
        # NOTE(review): compute() is called only for its side effect,
        # presumably clearing the metric's internal buffer - confirm that the
        # wrapped metric tolerates compute() when nothing has been added yet.
        self.metric.compute()

    def update(self, logits, labels):
        """Add one batch of model outputs and reference labels to the metric."""
        predictions = logits if self.regression else logits.argmax(-1)
        self.metric.add_batch(predictions=predictions, references=labels)

    def compute_key_value(self):
        """Return the metric values; identical to ``compute``."""
        return self.compute()

    def compute(self):
        """Return whatever the wrapped metric's ``compute()`` returns."""
        return self.metric.compute()

In [None]:
# Load the GLUE metric that matches the selected task.
metric_fn = load_metric("glue", task, keep_in_memory=True)

# Wrap it so it follows Catalyst's loader-metric interface.
# NOTE(review): `catalyst_metric` is not passed to `runner.train` in the cells
# visible here - confirm it is used further down.
catalyst_metric = HFMetricCallback(metric=metric_fn, regression=is_regression)

## Model

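We fine-tune a compact BERT checkpoint (4 layers, hidden size 512, 8 attention heads, as encoded in the model name) with a sequence-classification head whose output size is `num_labels`.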

In [None]:
# Load the pretrained checkpoint with a classification head of `num_labels`
# outputs (1 when the task is treated as regression).
model = BertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
)

## Runner

In [None]:
class GlueRunner(Runner):
    """Runner that feeds each batch to its model as keyword arguments."""

    def handle_batch(self, batch):
        """Run the forward pass and store the loss and logits on the runner.

        Uses ``self.model`` and ``self`` rather than the notebook-level
        ``model`` and ``runner`` globals, so the class works for any instance
        and does not depend on cell execution order.

        Args:
            batch: mapping whose keys match the model's keyword arguments
                (here ``input_ids``, ``attention_mask`` and ``labels``).
        """
        outputs = self.model(**batch)

        # Index explicitly instead of unpacking: the first two entries of the
        # model output are used as loss and logits.
        loss = outputs[0]
        logits = outputs[1]
        self.batch_output = {"loss": loss, "logits": logits}

In [None]:
# Adam over all model parameters with a learning rate of 5e-5.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
runner = GlueRunner()

## Training

In [None]:
# Fine-tune for 5 epochs over every loader built above.
# NOTE(review): `main_metric="accuracy"` presumably needs a metric with that
# name to be logged - no metric callback is passed here; confirm.
runner.train(
    model=model,
    optimizer=optimizer,
    main_metric="accuracy",
    loaders=loaders,
    num_epochs=5,
    verbose=True
)

## Quantization

### PyTorch

In [None]:
# Save the original weights so the on-disk size can be measured.
torch.save(model.state_dict(), "model.pth")

# Sizes are printed in MiB (bytes / 2**20).
print(f"Model size: {os.path.getsize('model.pth')/2**20:.2f}")
# NOTE(review): confirm whether `quantize_model` expects the model on CPU.
q_model = quantize_model(model)
torch.save(q_model.state_dict(), "quantized_model.pth")
print(f"Quantized model size: {os.path.getsize('quantized_model.pth')/2**20:.2f}")

### ONNX

In [None]:
# Dummy inputs for the export: batch of 1, sequence length 128 (the same
# length used as `max_length` during tokenization).
inputs = {
    "input_ids": torch.ones(1, 128, dtype=torch.long),
    "attention_mask": torch.ones(1, 128, dtype=torch.long)
}

# Axes 0 and 1 of both inputs are declared dynamic under these names.
symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
# NOTE(review): `inp_shape` receives tensors rather than shapes - confirm
# against the signature of `convert_to_onnx`.
# NOTE(review): "output" has no entry in `dynamic_axes` - confirm that the
# exported output keeps a dynamic batch axis.
convert_to_onnx(
    model=model,
    inp_shape=(
        inputs["input_ids"],
        inputs["attention_mask"]
    ),
    file="model.onnx",
    opset_version=11,
    do_constant_folding=True,
    input_names=["input_ids", "attention_mask"],
    output_names=["output"],
    dynamic_axes={
        "input_ids": symbolic_names,
        "attention_mask": symbolic_names
    }
)

In [None]:
# Quantize the exported ONNX file and write the result to a new file.
quantize_onnx_model("model.onnx", "quantized_model.onnx", verbose=True)

## Inference and results