In [1]:
import numpy as np
import cv2
import onnx
import onnxruntime as ort
import numpy as np
from PIL import Image
from pathlib import Path
import vai_q_onnx

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10

import onnx
import onnxruntime as ort
from onnxruntime.quantization import CalibrationDataReader, QuantType, QuantFormat, CalibrationMethod, quantize_static

In [2]:
import tarfile
import urllib.request
from resnet_utils import get_directories

_, models_dir, data_dir, _ = get_directories()

# CIFAR-10 archives (python and binary versions), keyed by the directory each
# one extracts to. Checking each archive separately makes this cell safe to
# re-run: only archives whose extracted directory is missing are fetched.
CIFAR10_ARCHIVES = {
    "cifar-10-batches-py": "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz",
    "cifar-10-batches-bin": "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
}

# urlretrieve does not create missing directories, so data_dir must exist
# before the first archive is written into it.
data_dir.mkdir(parents=True, exist_ok=True)
for extracted_name, url in CIFAR10_ARCHIVES.items():
    if (data_dir / extracted_name).exists():
        continue
    archive_path = data_dir / url.rsplit("/", 1)[-1]
    urllib.request.urlretrieve(url, archive_path)
    # NOTE(review): extractall trusts member paths inside the archive. That is
    # acceptable for this well-known source, but consider filter="data" on
    # Python versions that support tarfile extraction filters.
    with tarfile.open(archive_path) as archive:
        archive.extractall(data_dir)

In [3]:
class CIFAR10DataSet:
    """Holds the CIFAR-10 ``train_dataset`` and ``val_dataset`` torchvision datasets.

    Both datasets are read from ``data_dir`` with ``download=False``, so the
    extracted CIFAR-10 python-version data must already be present there.

    Args:
        data_dir: root directory passed to ``torchvision.datasets.CIFAR10``.
        **kwargs: accepted and ignored (kept for call-site compatibility).
    """

    def __init__(
        self,
        data_dir,
        **kwargs,
    ):
        super().__init__()
        self.train_path = data_dir
        self.vld_path = data_dir
        self.setup("fit")

    def setup(self, stage: str):
        """Build ``train_dataset`` and ``val_dataset``. ``stage`` is currently unused."""
        # Training-time augmentation: pad-and-crop jitter plus horizontal flips.
        train_transform = transforms.Compose(
            [transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()]
        )
        # val_dataset feeds quantization calibration, so it must see the same
        # deterministic preprocessing as inference (float CHW in [0, 1]).
        # Random zero-padded crops and flips would skew the activation ranges
        # the calibrator collects.
        val_transform = transforms.ToTensor()

        self.train_dataset = CIFAR10(root=self.train_path, train=True, transform=train_transform, download=False)
        # NOTE: train=True is deliberate. ResnetCalibrationDataReader carves a
        # 1,000-image subset out of this 50,000-image split. These images
        # therefore overlap the training data; this is not a held-out set.
        self.val_dataset = CIFAR10(root=self.vld_path, train=True, transform=val_transform, download=False)


class PytorchResNetDataset(Dataset):
    """Minimal ``torch.utils.data.Dataset`` adapter over any indexable dataset.

    Each wrapped item must expose the model input at position 0 and the label
    at position 1. Any further fields are dropped and the pair is returned as
    a tuple.
    """

    def __init__(self, dataset):
        # Anything supporting len() and integer indexing is accepted,
        # e.g. a torchvision dataset or a torch.utils.data.Subset.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        item = self.dataset[index]
        return item[0], item[1]


class ResnetCalibrationDataReader(CalibrationDataReader):
    """Serves CIFAR-10 batches to ONNX static-quantization calibration.

    A random subset of ``CIFAR10DataSet(data_dir).val_dataset`` is split off and
    yielded batch by batch as ``{"input": ndarray}``. The ``"input"`` key must
    match the ONNX model's input name.

    Args:
        data_dir: CIFAR-10 root directory.
        batch_size: images per calibration batch; a trailing partial batch is dropped.
        num_samples: size of the calibration subset (default 1000).
        seed: seed for the subset split, so calibration is reproducible.
    """

    def __init__(self, data_dir: str, batch_size: int = 16, num_samples: int = 1000, seed: int = 0):
        super().__init__()
        cifar10_dataset = CIFAR10DataSet(data_dir)
        source = cifar10_dataset.val_dataset
        # Seeded generator, so repeated runs calibrate on the same images.
        generator = torch.Generator().manual_seed(seed)
        _, val_set = torch.utils.data.random_split(
            source, [len(source) - num_samples, num_samples], generator=generator
        )
        self.iterator = iter(DataLoader(PytorchResNetDataset(val_set), batch_size=batch_size, drop_last=True))

    def get_next(self) -> dict:
        """Return the next ``{"input": batch}`` dict, or ``None`` once all batches are consumed.

        Errors raised while loading data propagate to the caller instead of
        being mistaken for the end of the calibration data.
        """
        batch = next(self.iterator, None)
        if batch is None:
            return None
        images, _labels = batch
        return {"input": images.numpy()}

In [4]:
input_model_path = "models/onnx/ResNet10_1111_4.onnx"
output_model_path = "models/quant/ResNet10_1111_4.U8S8.onnx"
calibration_dataset_path = "data/"

# Create the output directory up front, in case vai_q_onnx does not create it
# when saving the quantized model.
Path(output_model_path).parent.mkdir(parents=True, exist_ok=True)

dr = ResnetCalibrationDataReader(calibration_dataset_path, batch_size=16)

# QDQ-format static quantization with power-of-two scales chosen by MinMSE:
# uint8 activations and int8 weights (hence the "U8S8" file suffix).
vai_q_onnx.quantize_static(
    input_model_path,
    output_model_path,
    dr,
    quant_format=vai_q_onnx.QuantFormat.QDQ,
    calibrate_method=vai_q_onnx.PowerOfTwoMethod.MinMSE,
    activation_type=vai_q_onnx.QuantType.QUInt8,
    weight_type=vai_q_onnx.QuantType.QInt8,
    enable_dpu=True,
    extra_options={'ActivationSymmetric': True},
)

print('Calibrated and quantized model saved at:', output_model_path)

INFO:vai_q_onnx.quant_utils:The input ONNX model models/onnx/ResNet10_1111_4.onnx can create InferenceSession successfully
INFO:vai_q_onnx.quant_utils:Obtained calibration data with 62 iters
INFO:vai_q_onnx.quant_utils:The input ONNX model models/onnx/ResNet10_1111_4.onnx can run inference successfully
INFO:vai_q_onnx.quantize:Removed initializers from input
INFO:vai_q_onnx.quantize:Loading model...
INFO:vai_q_onnx.quantize:enable_ipu_cnn is True, optimize the model for better hardware compatibility.


[VAI_Q_ONNX_INFO]: Time information:
2024-04-14 16:56:17.554112
[VAI_Q_ONNX_INFO]: OS and CPU information:
                                        system --- Windows
                                          node --- GEEKOM-AMD
                                       release --- 10
                                       version --- 10.0.22631
                                       machine --- AMD64
                                     processor --- AMD64 Family 25 Model 116 Stepping 1, AuthenticAMD
[VAI_Q_ONNX_INFO]: Tools version information:
                                        python --- 3.9.19
                                          onnx --- 1.16.0
                                   onnxruntime --- 1.15.1
                                    vai_q_onnx --- 1.16.0+69bc4f2
[VAI_Q_ONNX_INFO]: Quantized Configuration information:
                                   model_input --- models/onnx/ResNet10_1111_4.onnx
                                  model_output --- models/quant/ResNet1

INFO:vai_q_onnx.quantize:Start calibration...
INFO:vai_q_onnx.quantize:Start collecting data, runtime depends on your model size and the number of calibration dataset.
INFO:vai_q_onnx.calibrate:Finding optimal threshold for each tensor using PowerOfTwoMethod.MinMSE algorithm ...
INFO:vai_q_onnx.calibrate:Use all calibration data to calculate min mse
Computing range: 100%|██████████| 30/30 [00:05<00:00,  5.63tensor/s]
INFO:vai_q_onnx.qdq_quantizer:Remove QuantizeLinear & DequantizeLinear on certain operations(such as conv-relu).
INFO:vai_q_onnx.simulate_dpu:Rescale AveragePool /AveragePool with factor 1.0 to simulate DPU behavior.


Calibrated and quantized model saved at: models/quant/ResNet10_1111_4.U8S8.onnx


In [5]:
quantized_model_path = r'models/quant/ResNet10_1111_4.U8S8.onnx'
model = onnx.load(quantized_model_path)

# Toggle between the Vitis AI execution provider (AIE) and plain CPU execution.
use_aie = True
if use_aie:
    cache_dir = './'
    providers = ['VitisAIExecutionProvider']
    provider_options = [
        {
            'config_file': 'vaip_config.json',
            'cacheDir': str(cache_dir),
            'cacheKey': 'modelcachekey',
        }
    ]
else:
    providers = ['CPUExecutionProvider']
    provider_options = [{}]

# The in-memory model is serialized and handed to ONNX Runtime as bytes.
serialized_model = model.SerializeToString()
session = ort.InferenceSession(
    serialized_model,
    providers=providers,
    provider_options=provider_options,
)


def unpickle(file):
    """Load and return the object stored in a pickle file.

    Used here for the CIFAR-10 python-version batch and metadata files.
    ``encoding='latin1'`` lets Python 3 read byte strings from Python 2-era
    pickles such as these.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only call this on
    trusted files, such as the archive downloaded from the dataset authors.

    Args:
        file: path (str or path-like) of the pickle file.

    Returns:
        The unpickled object (a dict for the CIFAR-10 files).
    """
    import pickle
    with open(file, 'rb') as fo:
        contents = pickle.load(fo, encoding='latin1')
    return contents

datafile = r'./data/cifar-10-batches-py/test_batch'
metafile = r'./data/cifar-10-batches-py/batches.meta'

data_batch_1 = unpickle(datafile)
metadata = unpickle(metafile)

images = data_batch_1['data']
labels = data_batch_1['labels']
# Each row is reshaped to (3, 32, 32). Use -1 to infer the image count
# rather than hard-coding it.
images = np.reshape(images, (-1, 3, 32, 32))

num_images = 10  # how many test images to dump and classify
# Use the model's declared input name instead of hard-coding it.
input_name = session.get_inputs()[0].name

import os
dirname = 'images'
# makedirs(exist_ok=True) is idempotent across re-runs.
os.makedirs(dirname, exist_ok=True)

# Extract and dump the first `num_images` images as PNG files.
for i in range(num_images):
    # CHW -> HWC, then RGB -> BGR, because cv2.imwrite expects BGR channel order.
    im = cv2.cvtColor(images[i].transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
    im_name = os.path.join(dirname, f'image_{i}.png')
    if not cv2.imwrite(im_name, im):
        raise OSError(f'Failed to write {im_name}')

# Read the dumped images back and classify them.
for i in range(num_images):
    image_name = os.path.join(dirname, f'image_{i}.png')
    image = Image.open(image_name).convert('RGB')
    # Resize to the 32x32 input the model expects (a no-op for CIFAR images).
    image = image.resize((32, 32))
    # float32 in [0, 1], HWC -> CHW, then add a leading batch dimension.
    image_array = np.asarray(image, dtype=np.float32) / 255
    input_data = np.expand_dims(np.transpose(image_array, (2, 0, 1)), axis=0)

    outputs = session.run(None, {input_name: input_data})

    predicted_class = int(np.argmax(outputs[0]))
    predicted_label = metadata['label_names'][predicted_class]
    label = metadata['label_names'][labels[i]]
    print(f'Image {i}: Actual Label {label}, Predicted Label {predicted_label}')

Image 0: Actual Label cat, Predicted Label cat
Image 1: Actual Label ship, Predicted Label cat
Image 2: Actual Label ship, Predicted Label cat
Image 3: Actual Label airplane, Predicted Label cat
Image 4: Actual Label frog, Predicted Label cat
Image 5: Actual Label frog, Predicted Label cat
Image 6: Actual Label automobile, Predicted Label horse
Image 7: Actual Label frog, Predicted Label cat
Image 8: Actual Label cat, Predicted Label cat
Image 9: Actual Label automobile, Predicted Label cat
