diff --git a/quantization/image_classification/cpu/ReadMe.md b/quantization/image_classification/cpu/ReadMe.md
new file mode 100644
index 000000000..fe22b94ba
--- /dev/null
+++ b/quantization/image_classification/cpu/ReadMe.md
@@ -0,0 +1,82 @@
+# ONNX Runtime Quantization Example
+
+This folder contains example code for quantizing the Resnet50 and MobilenetV2 models. The example
+has three parts:
+
+1. Pre-processing
+2. Quantization
+3. Debugging
+
+## Pre-processing
+
+Pre-processing prepares a float32 model for quantization. Run the following command to pre-process
+the model `mobilenetv2-7.onnx`:
+
+```console
+python -m onnxruntime.quantization.shape_inference --input mobilenetv2-7.onnx --output mobilenetv2-7-infer.onnx
+```
+
+Pre-processing consists of the following optional steps:
+
+- Symbolic Shape Inference. This works best with transformer models.
+- ONNX Runtime Model Optimization.
+- ONNX Shape Inference.
+
+Quantization requires tensor shape information to perform at its best, and model optimization
+also improves the result of quantization. For instance, a Convolution node followed by a
+BatchNormalization node can be merged into a single node during optimization. We currently cannot
+quantize a BatchNormalization node by itself, but we can quantize the merged
+Convolution + BatchNormalization node.
+
+It is highly recommended to run model optimization during pre-processing rather than during
+quantization. To learn more about each of these steps and the finer controls, run:
+
+```console
+python -m onnxruntime.quantization.shape_inference --help
+```
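+
+The same step can also be scripted. Below is a minimal sketch, assuming your installed
+onnxruntime exposes a `quant_pre_process` helper behind the command line above (treat the
+import path and argument order as assumptions to verify against your version):
+
+```python
+# Hypothetical scripted equivalent of the pre-processing command above.
+from onnxruntime.quantization.shape_inference import quant_pre_process
+
+quant_pre_process(
+    "mobilenetv2-7.onnx",        # input float32 model
+    "mobilenetv2-7-infer.onnx",  # output pre-processed model
+)
+```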
+## Quantization
+
+The quantization tool takes the pre-processed float32 model and produces a quantized model.
+It is recommended to use Tensor-oriented quantization (QDQ; Quantize and DeQuantize):
+
+```console
+python run.py --input_model mobilenetv2-7-infer.onnx --output_model mobilenetv2-7.quant.onnx --calibrate_dataset ./test_images/
+```
+
+This generates the quantized model `mobilenetv2-7.quant.onnx`.
+
+The code in `run.py` creates an input data reader for the model, uses that input data to run the
+model and calibrate quantization parameters for each tensor, and then produces the quantized
+model. Finally, it runs the quantized model. Of these steps, the only part that is specific to
+the model is the input data reader, since each model requires input data of different shapes. All
+the other code can easily be generalized to other models.
+
+For historical reasons, the quantization API performs model optimization by default. It is highly
+recommended to turn off model optimization using the parameter `optimize_model=False`. This makes
+it easier for the quantization debugger to match tensors between the float32 model and its
+quantized counterpart, facilitating the triaging of quantization loss.
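+
+In code, the quantization call made by `run.py` looks roughly like the sketch below (the data
+reader class comes from `resnet50_data_reader.py` in this folder; `quant_format` and
+`per_channel` correspond to command line options of `run.py`, shown here with typical values):
+
+```python
+from onnxruntime.quantization import QuantFormat, QuantType, quantize_static
+
+import resnet50_data_reader
+
+# The data reader feeds calibration images through the float32 model.
+dr = resnet50_data_reader.ResNet50DataReader(
+    "./test_images/", "mobilenetv2-7-infer.onnx"
+)
+
+quantize_static(
+    "mobilenetv2-7-infer.onnx",
+    "mobilenetv2-7.quant.onnx",
+    dr,
+    quant_format=QuantFormat.QDQ,
+    per_channel=False,
+    weight_type=QuantType.QInt8,
+    optimize_model=False,  # keep tensors matchable for the debugger
+)
+```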
+## Debugging
+
+Quantization is not a lossless process. Sometimes it results in a significant loss of accuracy.
+To help locate the source of such losses, our quantization debugging tool matches the weight
+tensors of the float32 model with those of the quantized model. If an input data reader is
+provided, the debugger can also run both models with the same input and compare their
+corresponding tensors:
+
+```console
+python run_qdq_debug.py --float_model mobilenetv2-7-infer.onnx --qdq_model mobilenetv2-7.quant.onnx --calibrate_dataset ./test_images/
+```
+
+If you quantized a model with optimization turned on and find that the debugging tool cannot
+match certain float32 model tensors with their quantized counterparts, you can run the
+pre-processor to produce an optimized float32 model, then compare that optimized model with the
+quantized model.
+
+For instance, suppose you have a model `abc_float32_model.onnx` and a quantized model
+`abc_quantized.onnx`, and optimization was turned on (the default) during quantization. You can
+run the following command to produce an optimized float32 model:
+
+```console
+python -m onnxruntime.quantization.shape_inference --input abc_float32_model.onnx --output abc_optimized.onnx --skip_symbolic_shape True
+```
+
+Then run the debugger to compare `abc_optimized.onnx` with `abc_quantized.onnx`.
diff --git a/quantization/image_classification/cpu/ReadMe.txt b/quantization/image_classification/cpu/ReadMe.txt
deleted file mode 100644
index 47fd8858f..000000000
--- a/quantization/image_classification/cpu/ReadMe.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-call run.py to calibrate, quantize and run the quantized model, e.g.:
-python run.py --input_model mobilenetv2-7.onnx --output_model mobilenetv2-7.quant.onnx --calibrate_dataset ./test_images/
\ No newline at end of file
diff --git a/quantization/image_classification/cpu/resnet50_data_reader.py b/quantization/image_classification/cpu/resnet50_data_reader.py
index bf929bce6..d07198f11 100644
--- a/quantization/image_classification/cpu/resnet50_data_reader.py
+++ b/quantization/image_classification/cpu/resnet50_data_reader.py
@@ -39,23 +39,25 @@ def _preprocess_images(images_folder: str, height: int, width: int, size_limit=0
 class ResNet50DataReader(CalibrationDataReader):
     def __init__(self, calibration_image_folder: str, model_path: str):
-        self.image_folder = calibration_image_folder
-        self.model_path = model_path
-        self.preprocess_flag = True
-        self.enum_data_dicts = []
-        self.datasize = 0
+        self.enum_data = None
+
+        # Use an inference session to get the model's input shape.
+        session = onnxruntime.InferenceSession(model_path, None)
+        (_, _, height, width) = session.get_inputs()[0].shape
+
+        # Convert the calibration images to model input data.
+        self.nhwc_data_list = _preprocess_images(
+            calibration_image_folder, height, width, size_limit=0
+        )
+        self.input_name = session.get_inputs()[0].name
+        self.datasize = len(self.nhwc_data_list)
 
     def get_next(self):
-        if self.preprocess_flag:
-            self.preprocess_flag = False
-            session = onnxruntime.InferenceSession(self.model_path, None)
-            (_, _, height, width) = session.get_inputs()[0].shape
-            nhwc_data_list = _preprocess_images(
-                self.image_folder, height, width, size_limit=0
+        if self.enum_data is None:
+            self.enum_data = iter(
+                [{self.input_name: nhwc_data} for nhwc_data in self.nhwc_data_list]
             )
-            input_name = session.get_inputs()[0].name
-            self.datasize = len(nhwc_data_list)
-            self.enum_data_dicts = iter(
-                [{input_name: nhwc_data} for nhwc_data in nhwc_data_list]
-            )
-            return next(self.enum_data_dicts, None)
+        return next(self.enum_data, None)
+
+    def rewind(self):
+        # Restart iteration so the same reader can be reused, e.g. by the debugger.
+        self.enum_data = None
diff --git a/quantization/image_classification/cpu/run.py b/quantization/image_classification/cpu/run.py
index 82088b9da..b7a357b35 100644
--- a/quantization/image_classification/cpu/run.py
+++ b/quantization/image_classification/cpu/run.py
@@ -52,6 +52,9 @@ def main():
     dr = resnet50_data_reader.ResNet50DataReader(
         calibration_dataset_path, input_model_path
     )
+
+    # Calibrate and quantize the model.
+    # Turn off model optimization during quantization.
     quantize_static(
         input_model_path,
         output_model_path,
@@ -59,6 +62,7 @@
         quant_format=args.quant_format,
         per_channel=args.per_channel,
         weight_type=QuantType.QInt8,
+        optimize_model=False,
     )
     print("Calibrated and quantized model saved.")
diff --git a/quantization/image_classification/cpu/run_qdq_debug.py b/quantization/image_classification/cpu/run_qdq_debug.py
new file mode 100644
index 000000000..6c549f761
--- /dev/null
+++ b/quantization/image_classification/cpu/run_qdq_debug.py
@@ -0,0 +1,90 @@
+import argparse
+import onnx
+from onnxruntime.quantization.qdq_loss_debug import (
+    collect_activations, compute_activation_error, compute_weight_error,
+    create_activation_matching, create_weight_matching,
+    modify_model_output_intermediate_tensors)
+
+import resnet50_data_reader
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--float_model", required=True, help="Path to the original floating point model"
+    )
+    parser.add_argument("--qdq_model", required=True, help="Path to the QDQ model")
+    parser.add_argument(
+        "--calibrate_dataset", default="./test_images", help="Calibration data set"
+    )
+    args = parser.parse_args()
+    return args
+
+
+def _generate_aug_model_path(model_path: str) -> str:
+    aug_model_path = (
+        model_path[: -len(".onnx")] if model_path.endswith(".onnx") else model_path
+    )
+    return aug_model_path + ".save_tensors.onnx"
+
+
+def main():
+    # Process input parameters and set up the model input data reader.
+    args = get_args()
+    float_model_path = args.float_model
+    qdq_model_path = args.qdq_model
+    calibration_dataset_path = args.calibrate_dataset
+
+    print("------------------------------------------------\n")
+    print("Comparing weights of float model vs qdq model.....")
+
+    matched_weights = create_weight_matching(float_model_path, qdq_model_path)
+    weights_error = compute_weight_error(matched_weights)
+    for weight_name, err in weights_error.items():
+        print(f"Cross model error of '{weight_name}': {err}\n")
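+
+    # modify_model_output_intermediate_tensors() augments a model so that its
+    # intermediate activation tensors are exposed as graph outputs; running the
+    # augmented models below is what lets us collect those activations.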
print("------------------------------------------------\n") + print("Augmenting models to save intermediate activations......") + + aug_float_model = modify_model_output_intermediate_tensors(float_model_path) + aug_float_model_path = _generate_aug_model_path(float_model_path) + onnx.save( + aug_float_model, + aug_float_model_path, + save_as_external_data=False, + ) + del aug_float_model + + aug_qdq_model = modify_model_output_intermediate_tensors(qdq_model_path) + aug_qdq_model_path = _generate_aug_model_path(qdq_model_path) + onnx.save( + aug_qdq_model, + aug_qdq_model_path, + save_as_external_data=False, + ) + del aug_qdq_model + + print("------------------------------------------------\n") + print("Running the augmented floating point model to collect activations......") + input_data_reader = resnet50_data_reader.ResNet50DataReader( + calibration_dataset_path, float_model_path + ) + float_activations = collect_activations(aug_float_model_path, input_data_reader) + + print("------------------------------------------------\n") + print("Running the augmented qdq model to collect activations......") + input_data_reader.rewind() + qdq_activations = collect_activations(aug_qdq_model_path, input_data_reader) + + print("------------------------------------------------\n") + print("Comparing activations of float model vs qdq model......") + + act_matching = create_activation_matching(qdq_activations, float_activations) + act_error = compute_activation_error(act_matching) + for act_name, err in act_error.items(): + print(f"Cross model error of '{act_name}': {err['xmodel_err']} \n") + print(f"QDQ error of '{act_name}': {err['qdq_err']} \n") + + +if __name__ == "__main__": + main()