# Post-Training Quantization

## Import the required modules

In [1]:
import numpy as np
import onnx
import onnxruntime as ort
import os
import random
import torch

from onnxruntime.quantization import (
    CalibrationDataReader,
    CalibrationMethod,
    QuantFormat,
    QuantType,
    StaticQuantConfig,
    quantize,
)

from msc_dataset_lab3 import MSCDataset

## Set Deterministic Behaviour

In [2]:
# Seed PyTorch, NumPy and Python's `random` (in that order) with the same
# fixed value, 0.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(0)

## Create Datasets for Calibration/Test

In [3]:
# Labels passed to both dataset splits.
CLASSES = ['stop', 'up']

# Build the calibration split first, then the test split. Each gets its own
# torch.nn.Identity() instance as the second MSCDataset argument
# (presumably the per-sample transform; confirm against msc_dataset_lab3).
calibration_ds, test_ds = (
    MSCDataset(root_dir, torch.nn.Identity(), CLASSES)
    for root_dir in ('/tmp/msc-val', '/tmp/msc-test')
)

## Define the Model Name

In [4]:
# Identifier of the exported model. It is interpolated into the
# './saved_models/{MODEL_NAME}_*.onnx' paths built by the cells below.
MODEL_NAME = '1753860010'

## Evaluate the Float32 ONNX Model

In [5]:
# Locations of the two float32 ONNX graphs: the frontend, whose output is fed
# to the model, and the model itself.
frontend_float32_file = f'./saved_models/{MODEL_NAME}_frontend.onnx'
model_float32_file = f'./saved_models/{MODEL_NAME}_model.onnx'

# One inference session per graph.
ort_frontend = ort.InferenceSession(frontend_float32_file)
ort_model = ort.InferenceSession(model_float32_file)

# Run every test sample through frontend -> model and count the matches.
n_correct = 0.0
for example in test_ds:
    # Convert the sample to NumPy and add a leading axis of size 1.
    batch = np.expand_dims(example['x'].numpy(), 0)
    features = ort_frontend.run(None, {'input': batch})[0]
    model_out = ort_model.run(None, {'input': features})[0]
    # The predicted class is the index of the largest value on the last axis.
    n_correct += model_out.argmax(axis=-1).item() == example['label']

float32_accuracy = n_correct / len(test_ds) * 100

# On-disk footprint of each graph, in bytes.
frontend_size = os.path.getsize(frontend_float32_file)
model_float32_size = os.path.getsize(model_float32_file)
total_float32_size = frontend_size + model_float32_size

print(f'Float32 Accuracy: {float32_accuracy:.2f}%')
print(f'Float32 Frontend Size: {frontend_size / 2**10:.1f}KB')
print(f'Float32 Model Size: {model_float32_size / 2**10:.1f}KB')
print(f'Float32 Total Size: {total_float32_size / 2**10:.1f}KB')


## Create the Calibration Class

In [6]:
class DataReader(CalibrationDataReader):
    """Serve calibration inputs to the quantizer, one dataset sample at a time.

    Each sample's 'x' entry is passed through the module-level
    `ort_frontend` session, and the frontend output is handed back under
    the key 'input'.
    """

    def __init__(self, dataset):
        """Keep a reference to `dataset` and record its length."""
        self.dataset = dataset
        self.datasize = len(self.dataset)
        # Created lazily on the first get_next() call, and again after rewind().
        self.enum_data = None

    def get_next(self):
        """Return the next {'input': features} dict, or None when exhausted."""
        if self.enum_data is None:
            self.enum_data = iter(self.dataset)

        sample = next(self.enum_data, None)
        if sample is None:
            return None

        # Convert to NumPy and add a leading axis of size 1 before running
        # the frontend.
        batch = np.expand_dims(sample['x'].numpy(), 0)
        features = ort_frontend.run(None, {'input': batch})[0]
        return {'input': features}

    def rewind(self):
        """Drop the current iterator so the next get_next() starts over."""
        self.enum_data = None


# Reader over the calibration split; it is handed to StaticQuantConfig below.
data_reader = DataReader(calibration_ds)

## Quantize the Model to INT8

In [7]:
# Static post-training quantization settings: QDQ format, MinMax calibration,
# signed 8-bit activations and weights, one scale per tensor
# (per_channel=False).
conf = StaticQuantConfig(
    calibration_data_reader=data_reader,
    quant_format=QuantFormat.QDQ,
    calibrate_method=CalibrationMethod.MinMax,
    activation_type=QuantType.QInt8,
    weight_type=QuantType.QInt8,
    per_channel=False,
)

# Destination of the quantized model.
model_int8_file = f'./saved_models/{MODEL_NAME}_INT8.onnx'

# Calibration drains the reader. Rewind it first so that re-running this cell
# starts from the first calibration sample instead of reusing an exhausted
# iterator. On the first run this is a no-op.
data_reader.rewind()
quantize(model_float32_file, model_int8_file, conf)

## Evaluate the INT8 Model

In [8]:
# Session for the quantized model. The float32 frontend session
# (`ort_frontend`) is reused unchanged.
ort_model_int8 = ort.InferenceSession(model_int8_file)

# Run every test sample through frontend -> INT8 model and count the matches.
n_correct_int8 = 0.0
for example in test_ds:
    # Convert the sample to NumPy and add a leading axis of size 1.
    batch = np.expand_dims(example['x'].numpy(), 0)
    features = ort_frontend.run(None, {'input': batch})[0]
    model_out = ort_model_int8.run(None, {'input': features})[0]
    # The predicted class is the index of the largest value on the last axis.
    n_correct_int8 += model_out.argmax(axis=-1).item() == example['label']

int8_accuracy = n_correct_int8 / len(test_ds) * 100

# On-disk footprint in bytes. The frontend file is the float32 one.
frontend_size = os.path.getsize(frontend_float32_file)
model_int8_size = os.path.getsize(model_int8_file)
total_int8_size = frontend_size + model_int8_size

print(f'INT8 Accuracy: {int8_accuracy:.2f}%')
print(f'Float32 Frontend Size: {frontend_size / 2**10:.1f}KB')
print(f'INT8 Model Size: {model_int8_size / 2**10:.1f}KB')
print(f'INT8 Total Size: {total_int8_size / 2**10:.1f}KB')

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=3880e510-b64c-4bb5-b488-c2122d5d9e2d' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>