In [None]:
!pip install onnx onnxruntime transformers

In [None]:
import os

import torch

from transformers import AutoModel, AutoTokenizer

from catalyst.utils.quantization import quantize

In [None]:
# Single checkpoint identifier handed to both `from_pretrained` calls, so the
# tokenizer and the model are always loaded from the same checkpoint.
pretrained_name = "google/bert_uncased_L-4_H-512_A-8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
model = AutoModel.from_pretrained(pretrained_name)

In [None]:
# Write the original weights to disk and print the file size divided by 2**20
# (i.e. in MiB), formatted with two decimals.
weights_path = "model.pth"
torch.save(model.state_dict(), weights_path)
weights_mib = os.stat(weights_path).st_size / (1 << 20)
print(f"Model size: {weights_mib:.2f}")

# Run `quantize` with its default arguments, save the resulting state dict and
# measure it the same way so the two numbers can be compared directly.
q_model = quantize(model)
quantized_weights_path = "quantized_model.pth"
torch.save(q_model.state_dict(), quantized_weights_path)
quantized_weights_mib = os.stat(quantized_weights_path).st_size / (1 << 20)
print(f"Quantized model size: {quantized_weights_mib:.2f}")

In [None]:
# Names attached to dimensions 0 and 1 of every input listed in `dynamic_axes`.
symbolic_names = {0: "batch_size", 1: "max_seq_len"}

# Model inputs, in the order they are handed to the ONNX backend.
onnx_input_names = ["input_ids", "attention_mask"]

# One all-ones int64 tensor of shape (1, 128) per input, used as dummy data.
inputs = {
    input_name: torch.ones((1, 128), dtype=torch.long)
    for input_name in onnx_input_names
}

onnx_export_params = {
    # The dummy tensors themselves (not just their shapes), in input order.
    "inp_shape": tuple(inputs[input_name] for input_name in onnx_input_names),
    "file": "model.onnx",
    "opset_version": 11,
    "do_constant_folding": True,
    "input_names": list(onnx_input_names),
    "output_names": ["output"],
    # Both inputs share the same axis-name mapping; no entry is given for the output.
    "dynamic_axes": {
        input_name: symbolic_names for input_name in onnx_input_names
    },
}

# NOTE(review): presumably this writes both model.onnx and quantized_model.onnx
# (the following cell reads both files) — confirm against the catalyst version in use.
quantize(model, backend="onnx", onnx_params=onnx_export_params)

In [None]:
# Print the size of each ONNX file divided by 2**20 (i.e. in MiB), two decimals,
# original model first and quantized model second.
for title, onnx_file in (
    ("Model", "model.onnx"),
    ("Quantized model", "quantized_model.onnx"),
):
    print(f"{title} size: {os.path.getsize(onnx_file)/2**20:.2f}")