In [1]:
import numpy as np
import cv2
import onnx
import onnxruntime as ort
import numpy as np
from PIL import Image
from pathlib import Path
import vai_q_onnx

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10

import onnx
import onnxruntime as ort
from onnxruntime.quantization import CalibrationDataReader, QuantType, QuantFormat, CalibrationMethod, quantize_static

In [2]:
import tarfile
import urllib.request
from resnet_utils import get_directories

_, models_dir, data_dir, _ = get_directories()
data_download_path_python = data_dir / "cifar-10-python.tar.gz"
data_download_path_bin = data_dir / "cifar-10-binary.tar.gz"


def _download_and_extract(url, archive_path, extract_dir):
    """Download ``url`` to ``archive_path`` and unpack it into ``extract_dir``.

    Does nothing when ``archive_path`` already exists. The archive is first
    written to a ``.part`` file and only renamed to its final name after the
    extraction succeeded, so an interrupted run is retried next time instead
    of leaving a half-written archive that looks "done".

    Assumes ``archive_path`` is a ``pathlib.Path`` (as produced by
    ``data_dir / ...`` above) — TODO confirm ``get_directories`` returns Paths.
    """
    if archive_path.exists():
        return
    partial_path = archive_path.parent / (archive_path.name + ".part")
    urllib.request.urlretrieve(url, partial_path)
    # The context manager closes the archive even if extraction fails part-way.
    with tarfile.open(partial_path) as archive:
        archive.extractall(extract_dir)
    partial_path.replace(archive_path)


# Each archive is handled independently: a missing one no longer forces
# re-downloading the other.
_download_and_extract("https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz", data_download_path_python, data_dir)
_download_and_extract("https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz", data_download_path_bin, data_dir)

In [3]:
class CIFAR10DataSet:
    """CIFAR-10 training split wrapped as a training set and a calibration pool.

    Both ``train_dataset`` and ``val_dataset`` read the CIFAR-10 *training*
    split from ``data_dir`` (``download=False``, so the data must already be on
    disk). ``train_dataset`` applies random augmentation. ``val_dataset`` applies
    only ``ToTensor()``, so its samples are preprocessed the same way as the
    images fed to the model at inference time (scaled to [0, 1], no padding,
    cropping or flipping).

    Args:
        data_dir: Root directory containing the extracted CIFAR-10 data.
        **kwargs: Accepted for interface compatibility; ignored.
    """

    def __init__(
        self,
        data_dir,
        **kwargs,
    ):
        super().__init__()
        self.train_path = data_dir
        self.vld_path = data_dir
        self.setup("fit")

    def setup(self, stage: str):
        """Build ``self.train_dataset`` and ``self.val_dataset``.

        Args:
            stage: Accepted for interface compatibility; currently unused.
        """
        # Random augmentation is only appropriate for training samples.
        train_transform = transforms.Compose(
            [transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()]
        )
        # Calibration samples must look like inference inputs. Random zero-padding
        # and crops would skew the activation ranges the quantizer measures and
        # make calibration non-deterministic.
        eval_transform = transforms.ToTensor()
        self.train_dataset = CIFAR10(root=self.train_path, train=True, transform=train_transform, download=False)
        # Deliberately the *training* split: callers carve a calibration subset
        # out of it, keeping the test split untouched for evaluation.
        self.val_dataset = CIFAR10(root=self.vld_path, train=True, transform=eval_transform, download=False)


class PytorchResNetDataset(Dataset):
    """Expose an indexable collection of samples as a torch ``Dataset``.

    Each item of the wrapped collection must be indexable with at least two
    positions. Position 0 is returned as the input and position 1 as the label.
    Any further positions are discarded.
    """

    def __init__(self, dataset):
        # Underlying sample source; only len() and integer indexing are used.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        record = self.dataset[index]
        return record[0], record[1]


class ResnetCalibrationDataReader(CalibrationDataReader):
    """Calibration data reader for static quantization of the CIFAR-10 ResNet.

    Draws a random subset of ``CIFAR10DataSet(data_dir).val_dataset`` and serves
    it batch by batch as ``{"input": ndarray}`` feeds. ``get_next`` returns
    ``None`` once every full batch has been served.

    Args:
        data_dir: Root directory passed to ``CIFAR10DataSet``.
        batch_size: Images per calibration batch; a trailing partial batch is dropped.
        calibration_size: Number of images in the calibration subset. The default
            of 1000 matches the previously hard-coded split.
        seed: Optional integer seed for the subset selection. ``None`` uses torch's
            global RNG, which matches the previous behaviour.

    Raises:
        ValueError: If ``calibration_size`` is not in ``[1, len(pool)]``.
    """

    def __init__(self, data_dir: str, batch_size: int = 16, calibration_size: int = 1000, seed=None):
        super().__init__()
        cifar10_dataset = CIFAR10DataSet(data_dir)
        pool = cifar10_dataset.val_dataset
        if not 0 < calibration_size <= len(pool):
            raise ValueError(f"calibration_size must be in [1, {len(pool)}], got {calibration_size}")
        split_kwargs = {}
        if seed is not None:
            # A dedicated generator makes the subset reproducible without
            # touching torch's global RNG state.
            split_kwargs["generator"] = torch.Generator().manual_seed(seed)
        # Derive the complementary length from the pool size instead of
        # assuming exactly 50000 images.
        _, val_set = torch.utils.data.random_split(
            pool, [len(pool) - calibration_size, calibration_size], **split_kwargs
        )
        self.iterator = iter(DataLoader(PytorchResNetDataset(val_set), batch_size=batch_size, drop_last=True))

    def get_next(self) -> dict:
        """Return the next ``{"input": ndarray}`` feed, or ``None`` when exhausted."""
        try:
            images, _labels = next(self.iterator)
        except StopIteration:
            # End of calibration data. Any other error (corrupt sample, I/O
            # failure) now propagates instead of silently truncating calibration.
            return None
        return {"input": images.numpy()}

In [4]:
input_model_dir = 'models/onnx'
quant_model_dir = 'models/quant'

calibration_dataset_path = "data/"

import os

# Create the output directory up front so listing and writing into it cannot fail.
os.makedirs(quant_model_dir, exist_ok=True)

# Only ONNX graphs are quantized. Other files in the folder are ignored.
models = sorted(name for name in os.listdir(input_model_dir) if name.endswith('.onnx'))
quantized_models = os.listdir(quant_model_dir)

for model in models:
    # Strip only the final extension so dotted names such as "resnet.qdq.onnx"
    # keep their full stem ("resnet.qdq") instead of collapsing to "resnet".
    model_name = os.path.splitext(model)[0]
    print(model_name)

    # Use the real file name, not a rebuilt one, so the input path always exists.
    input_model_path = os.path.join(input_model_dir, model)
    output_model_path = os.path.join(quant_model_dir, model_name + '.U8S8.onnx')

    # Per-model skip replaces the old "same file count" guard, which either
    # skipped everything or re-quantized every model.
    if os.path.exists(output_model_path):
        print('Already quantized, skipping:', output_model_path)
        continue

    dr = ResnetCalibrationDataReader(calibration_dataset_path, batch_size=16)

    vai_q_onnx.quantize_static(
        input_model_path,
        output_model_path,
        dr,
        quant_format=vai_q_onnx.QuantFormat.QDQ,
        calibrate_method=vai_q_onnx.PowerOfTwoMethod.MinMSE,
        activation_type=vai_q_onnx.QuantType.QUInt8,
        weight_type=vai_q_onnx.QuantType.QInt8,
        enable_dpu=True,
        extra_options={'ActivationSymmetric': True}
    )

    print('Calibrated and quantized model saved at:', output_model_path)

In [5]:
def unpickle(file):
    """Load and return the object pickled in ``file``.

    ``encoding='latin1'`` lets byte strings written by Python 2 pickles (the
    format the CIFAR-10 batch files use) be decoded under Python 3.

    SECURITY: unpickling can execute arbitrary code. Only call this on trusted
    files such as the official CIFAR-10 archives.
    """
    import pickle

    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='latin1')

import os

# Evaluation data: CIFAR-10's test batch plus the label-name metadata.
datafile = r'./data/cifar-10-batches-py/test_batch'
metafile = r'./data/cifar-10-batches-py/batches.meta'

# NOTE: despite its name, this holds the *test* batch, not training batch 1.
data_batch_1 = unpickle(datafile)
metadata = unpickle(metafile)

images = data_batch_1['data']
labels = data_batch_1['labels']
# Each row is reshaped to a 3x32x32 image. -1 infers the count instead of
# hard-coding 10000.
images = np.reshape(images, (-1, 3, 32, 32))
assert len(images) == len(labels), "image/label count mismatch"

# Output directory for exported images. exist_ok avoids the exists/mkdir race.
dirname = 'images'
os.makedirs(dirname, exist_ok=True)

In [6]:
# NOTE(review): this path is models/, while the quantization cell writes into
# models/quant — confirm the file is copied here or adjust the path.
quantized_model_path = r'./models/resnet.qdq.U8S8.onnx'
model = onnx.load(quantized_model_path)

# True: run through the Vitis AI execution provider; False: plain CPU provider.
use_aie = True

if use_aie:
    cache_dir = './'
    providers = ['VitisAIExecutionProvider']
    provider_options = [
        {
            'config_file': 'vaip_config.json',
            'cacheDir': str(cache_dir),
            'cacheKey': 'modelcachekey',
        }
    ]
else:
    providers = ['CPUExecutionProvider']
    provider_options = [{}]

session = ort.InferenceSession(
    model.SerializeToString(),
    providers=providers,
    provider_options=provider_options,
)


: 

In [None]:

# Top-1 accuracy of the quantized model on the first NUM_EVAL_IMAGES test images.
NUM_EVAL_IMAGES = 100


def evaluate_top1(session, images, labels, num_images, input_name='input'):
    """Count correct top-1 predictions of ``session`` on the leading samples.

    Args:
        session: Object with an ONNX Runtime style ``run(output_names, feeds)``
            method whose first output holds the class scores.
        images: Array of images, each shaped (3, 32, 32) by the loading cell.
        labels: Integer class labels aligned with ``images``.
        num_images: Number of leading samples to evaluate. Clamped to ``len(images)``.
        input_name: Graph input name the images are fed under.

    Returns:
        ``(correct, evaluated)``: number of hits and number of samples scored.
    """
    evaluated = min(num_images, len(images))
    correct = 0
    for image, label in zip(images[:evaluated], labels[:evaluated]):
        # Scale to [0, 1] and add a batch axis. This matches the ToTensor()
        # scaling used for the calibration data.
        input_data = np.expand_dims(np.asarray(image).astype(np.float32) / 255, axis=0)
        scores = session.run(None, {input_name: input_data})[0]
        if np.argmax(scores) == label:
            correct += 1
    return correct, evaluated


correct, num_evaluated = evaluate_top1(session, images, labels, NUM_EVAL_IMAGES)
accuracy = correct / num_evaluated if num_evaluated else 0.0
print(f'{correct}/{num_evaluated} correct (top-1 accuracy {accuracy:.1%})')