Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions quantization/image_classification/cpu/resnet50_data_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import numpy
import onnxruntime
import os
from onnxruntime.quantization import CalibrationDataReader
from PIL import Image


def _preprocess_images(images_folder: str, height: int, width: int, size_limit=0):
"""
Loads a batch of images and preprocess them
parameter images_folder: path to folder storing images
parameter height: image height in pixels
parameter width: image width in pixels
parameter size_limit: number of images to load. Default is 0 which means all images are picked.
return: list of matrices characterizing multiple images
"""
image_names = os.listdir(images_folder)
if size_limit > 0 and len(image_names) >= size_limit:
batch_filenames = [image_names[i] for i in range(size_limit)]
else:
batch_filenames = image_names
unconcatenated_batch_data = []

for image_name in batch_filenames:
image_filepath = images_folder + "/" + image_name
pillow_img = Image.new("RGB", (width, height))
pillow_img.paste(Image.open(image_filepath).resize((width, height)))
input_data = numpy.float32(pillow_img) - numpy.array(
[123.68, 116.78, 103.94], dtype=numpy.float32
)
nhwc_data = numpy.expand_dims(input_data, axis=0)
nchw_data = nhwc_data.transpose(0, 3, 1, 2) # ONNX Runtime standard
unconcatenated_batch_data.append(nchw_data)
batch_data = numpy.concatenate(
numpy.expand_dims(unconcatenated_batch_data, axis=0), axis=0
)
return batch_data


class ResNet50DataReader(CalibrationDataReader):
def __init__(self, calibration_image_folder: str, model_path: str):
self.image_folder = calibration_image_folder
self.model_path = model_path
self.preprocess_flag = True
self.enum_data_dicts = []
self.datasize = 0

def get_next(self):
if self.preprocess_flag:
self.preprocess_flag = False
session = onnxruntime.InferenceSession(self.model_path, None)
(_, _, height, width) = session.get_inputs()[0].shape
nhwc_data_list = _preprocess_images(
self.image_folder, height, width, size_limit=0
)
input_name = session.get_inputs()[0].name
self.datasize = len(nhwc_data_list)
self.enum_data_dicts = iter(
[{input_name: nhwc_data} for nhwc_data in nhwc_data_list]
)
return next(self.enum_data_dicts, None)
107 changes: 29 additions & 78 deletions quantization/image_classification/cpu/run.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,10 @@
import os
import sys
import numpy as np
import re
import abc
import subprocess
import json
import argparse
import time
from PIL import Image

import onnx
import numpy as np
import onnxruntime
from onnx import helper, TensorProto, numpy_helper
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantFormat, QuantType


class ResNet50DataReader(CalibrationDataReader):
def __init__(self, calibration_image_folder, augmented_model_path='augmented_model.onnx'):
self.image_folder = calibration_image_folder
self.augmented_model_path = augmented_model_path
self.preprocess_flag = True
self.enum_data_dicts = []
self.datasize = 0

def get_next(self):
if self.preprocess_flag:
self.preprocess_flag = False
session = onnxruntime.InferenceSession(self.augmented_model_path, None)
(_, _, height, width) = session.get_inputs()[0].shape
nhwc_data_list = preprocess_func(self.image_folder, height, width, size_limit=0)
input_name = session.get_inputs()[0].name
self.datasize = len(nhwc_data_list)
self.enum_data_dicts = iter([{input_name: nhwc_data} for nhwc_data in nhwc_data_list])
return next(self.enum_data_dicts, None)


def preprocess_func(images_folder, height, width, size_limit=0):
'''
Loads a batch of images and preprocess them
parameter images_folder: path to folder storing images
parameter height: image height in pixels
parameter width: image width in pixels
parameter size_limit: number of images to load. Default is 0 which means all images are picked.
return: list of matrices characterizing multiple images
'''
image_names = os.listdir(images_folder)
if size_limit > 0 and len(image_names) >= size_limit:
batch_filenames = [image_names[i] for i in range(size_limit)]
else:
batch_filenames = image_names
unconcatenated_batch_data = []
import time
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

for image_name in batch_filenames:
image_filepath = images_folder + '/' + image_name
pillow_img = Image.new("RGB", (width, height))
pillow_img.paste(Image.open(image_filepath).resize((width, height)))
input_data = np.float32(pillow_img) - \
np.array([123.68, 116.78, 103.94], dtype=np.float32)
nhwc_data = np.expand_dims(input_data, axis=0)
nchw_data = nhwc_data.transpose(0, 3, 1, 2) # ONNX Runtime standard
unconcatenated_batch_data.append(nchw_data)
batch_data = np.concatenate(np.expand_dims(unconcatenated_batch_data, axis=0), axis=0)
return batch_data
import resnet50_data_reader


def benchmark(model_path):
Expand All @@ -87,11 +30,15 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--input_model", required=True, help="input model")
parser.add_argument("--output_model", required=True, help="output model")
parser.add_argument("--calibrate_dataset", default="./test_images", help="calibration data set")
parser.add_argument("--quant_format",
default=QuantFormat.QDQ,
type=QuantFormat.from_string,
choices=list(QuantFormat))
parser.add_argument(
"--calibrate_dataset", default="./test_images", help="calibration data set"
)
parser.add_argument(
"--quant_format",
default=QuantFormat.QDQ,
type=QuantFormat.from_string,
choices=list(QuantFormat),
)
parser.add_argument("--per_channel", default=False, type=bool)
args = parser.parse_args()
return args
Expand All @@ -102,21 +49,25 @@ def main():
input_model_path = args.input_model
output_model_path = args.output_model
calibration_dataset_path = args.calibrate_dataset
dr = ResNet50DataReader(calibration_dataset_path)
quantize_static(input_model_path,
output_model_path,
dr,
quant_format=args.quant_format,
per_channel=args.per_channel,
weight_type=QuantType.QInt8)
print('Calibrated and quantized model saved.')

print('benchmarking fp32 model...')
dr = resnet50_data_reader.ResNet50DataReader(
calibration_dataset_path, input_model_path
)
quantize_static(
input_model_path,
output_model_path,
dr,
quant_format=args.quant_format,
per_channel=args.per_channel,
weight_type=QuantType.QInt8,
)
print("Calibrated and quantized model saved.")

print("benchmarking fp32 model...")
benchmark(input_model_path)

print('benchmarking int8 model...')
print("benchmarking int8 model...")
benchmark(output_model_path)


if __name__ == '__main__':
if __name__ == "__main__":
main()