From c3421de5e8c04813a33b232b753daca130a3a2fd Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Fri, 19 Aug 2022 09:12:54 -0700 Subject: [PATCH 1/3] Add pre-processing and turn off model optimization during quant --- .../image_classification/cpu/ReadMe.md | 78 ++++++++++++++++ .../image_classification/cpu/ReadMe.txt | 2 - .../cpu/resnet50_data_reader.py | 36 ++++---- quantization/image_classification/cpu/run.py | 4 + .../image_classification/cpu/run_qdq_debug.py | 90 +++++++++++++++++++ 5 files changed, 191 insertions(+), 19 deletions(-) create mode 100644 quantization/image_classification/cpu/ReadMe.md delete mode 100644 quantization/image_classification/cpu/ReadMe.txt create mode 100644 quantization/image_classification/cpu/run_qdq_debug.py diff --git a/quantization/image_classification/cpu/ReadMe.md b/quantization/image_classification/cpu/ReadMe.md new file mode 100644 index 000000000..8b4614c10 --- /dev/null +++ b/quantization/image_classification/cpu/ReadMe.md @@ -0,0 +1,78 @@ +# ONNX Runtime Quantization Example + +This folder contains example code for quantizing Resnet50 or mobilenetv2 models, which consists of 3 steps: + +- Pre-processing +- Quantization +- Debugging + + +## Pre-processing + +Quantization works best with shape inferencing, as not knowing a tensor's shape makes +it harder to quantize it. On the other hand, ONNX shape inferencing works best with +optimized models. So, it is recommended to pre-process the original 32 bit floating +point model with optimization and shape inferencing, before quantization. + +```console +python -m onnxruntime.quantization.shape_inference --input mobilenetv2-7.onnx --output mobilenetv2-7-infer.onnx +``` + +The pre-processing consists of 3 optional sub-steps +- Symbolic Shape Inference. It works best with transformer models. +- ONNX Runtime Model Optimization. +- ONNX Shape Inference + +To learn more about these pre-processing steps and how to skip some of them, run: +```console +python -m onnxruntime.quantization.shape_inference --help +``` + +## Quantization + +Quantization tool takes the pre-processed float32 model and produce a quantized model. +It's recommended to use Tensor-oriented quantization (QDQ; Quantize and DeQuantize). + +```console +python run.py --input_model mobilenetv2-7-infer.onnx --output_model mobilenetv2-7.quant.onnx --calibrate_dataset ./test_images/ +``` +This will generate quantized model mobilenetv2-7.quant.onnx + +The code in ```run.py``` creates an input data reader for the model, uses these input data to run +the model to calibrate quantization parameters for each tensor, and then produces quantized +model. Last, it runs the quantized model. Of these step, the only part that is specific to +the model is the input data reader, as each model requires different shapes of input data. +All other code can be easily generalized for other models. + +For historical reasons, the quantization API performs model optimization by default. +It's highly recommended to turn off model optimization using parameter +```optimize_model=False```. This way, it is easier for the quantization debugger to match +tensors of the float32 model and its quantized model, facilitating the triaging of quantization +loss. + +## Debugging + +Quantization is not a loss-less process. Sometime it results in significiant loss in accuracy. +To help locate the source of these losses, our quantization debugging tool matches and +compare weights of the float32 model vs those of the quantized model. If a input data reader +is provided, our debugger can also run both models with the same input and compare their +corresponding tensors: + +'''console +python run_qdq_debug.py --float_model mobilenetv2-7-infer.onnx --qdq_model mobilenetv2-7.quant.onnx --calibrate_dataset ./test_images/ +''' + +For historical reasons, the quantization API performs model optimization by default. If you +have a quantized model with optimization turned on, and found the debugging tool can not match +certain float32 model tensors with their quantized counterparts, you can try running the +debugger again, comparing the optimized float32 model with the quantized model. + +For instance, you have a model ```abc_float32_model.onnx```, and a quantized model +```abc_quantized.onnx```. During quantization process, you had optimization turned on +by default. You can run the following code to produce an optimized float32 model: + +```console +python -m onnxruntime.quantization.shape_inference --input abc_float32_model.onnx --output abc_optimized.onnx --skip_symbolic_shape True +``` + +Then run the debugger comparing ```abc_optimized.onnx``` with ```abc_quantized.onnx```. diff --git a/quantization/image_classification/cpu/ReadMe.txt b/quantization/image_classification/cpu/ReadMe.txt deleted file mode 100644 index 47fd8858f..000000000 --- a/quantization/image_classification/cpu/ReadMe.txt +++ /dev/null @@ -1,2 +0,0 @@ -call run.py to calibrate, quantize and run the quantized model, e.g.: -python run.py --input_model mobilenetv2-7.onnx --output_model mobilenetv2-7.quant.onnx --calibrate_dataset ./test_images/ \ No newline at end of file diff --git a/quantization/image_classification/cpu/resnet50_data_reader.py b/quantization/image_classification/cpu/resnet50_data_reader.py index bf929bce6..d07198f11 100644 --- a/quantization/image_classification/cpu/resnet50_data_reader.py +++ b/quantization/image_classification/cpu/resnet50_data_reader.py @@ -39,23 +39,25 @@ def _preprocess_images(images_folder: str, height: int, width: int, size_limit=0 class ResNet50DataReader(CalibrationDataReader): def __init__(self, calibration_image_folder: str, model_path: str): - self.image_folder = calibration_image_folder - self.model_path = model_path - self.preprocess_flag = True - self.enum_data_dicts = [] - self.datasize = 0 + self.enum_data = None + + # Use inference session to get input shape. + session = onnxruntime.InferenceSession(model_path, None) + (_, _, height, width) = session.get_inputs()[0].shape + + # Convert image to input data + self.nhwc_data_list = _preprocess_images( + calibration_image_folder, height, width, size_limit=0 + ) + self.input_name = session.get_inputs()[0].name + self.datasize = len(self.nhwc_data_list) def get_next(self): - if self.preprocess_flag: - self.preprocess_flag = False - session = onnxruntime.InferenceSession(self.model_path, None) - (_, _, height, width) = session.get_inputs()[0].shape - nhwc_data_list = _preprocess_images( - self.image_folder, height, width, size_limit=0 + if self.enum_data is None: + self.enum_data = iter( + [{self.input_name: nhwc_data} for nhwc_data in self.nhwc_data_list] ) - input_name = session.get_inputs()[0].name - self.datasize = len(nhwc_data_list) - self.enum_data_dicts = iter( - [{input_name: nhwc_data} for nhwc_data in nhwc_data_list] - ) - return next(self.enum_data_dicts, None) + return next(self.enum_data, None) + + def rewind(self): + self.enum_data = None diff --git a/quantization/image_classification/cpu/run.py b/quantization/image_classification/cpu/run.py index 82088b9da..b7a357b35 100644 --- a/quantization/image_classification/cpu/run.py +++ b/quantization/image_classification/cpu/run.py @@ -52,6 +52,9 @@ def main(): dr = resnet50_data_reader.ResNet50DataReader( calibration_dataset_path, input_model_path ) + + # Calibrate and quantize model + # Turn off model optimization during quantization quantize_static( input_model_path, output_model_path, @@ -59,6 +62,7 @@ def main(): quant_format=args.quant_format, per_channel=args.per_channel, weight_type=QuantType.QInt8, + optimize_model=False, ) print("Calibrated and quantized model saved.") diff --git a/quantization/image_classification/cpu/run_qdq_debug.py b/quantization/image_classification/cpu/run_qdq_debug.py new file mode 100644 index 000000000..6c549f761 --- /dev/null +++ b/quantization/image_classification/cpu/run_qdq_debug.py @@ -0,0 +1,90 @@ +import argparse +import onnx +from onnxruntime.quantization.qdq_loss_debug import ( + collect_activations, compute_activation_error, compute_weight_error, + create_activation_matching, create_weight_matching, + modify_model_output_intermediate_tensors) + +import resnet50_data_reader + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--float_model", required=True, help="Path to original floating point model" + ) + parser.add_argument("--qdq_model", required=True, help="Path to qdq model") + parser.add_argument( + "--calibrate_dataset", default="./test_images", help="calibration data set" + ) + args = parser.parse_args() + return args + + +def _generate_aug_model_path(model_path: str) -> str: + aug_model_path = ( + model_path[: -len(".onnx")] if model_path.endswith(".onnx") else model_path + ) + return aug_model_path + ".save_tensors.onnx" + + +def main(): + # Process input parameters and setup model input data reader + args = get_args() + float_model_path = args.float_model + qdq_model_path = args.qdq_model + calibration_dataset_path = args.calibrate_dataset + + print("------------------------------------------------\n") + print("Comparing weights of float model vs qdq model.....") + + matched_weights = create_weight_matching(float_model_path, qdq_model_path) + weights_error = compute_weight_error(matched_weights) + for weight_name, err in weights_error.items(): + print(f"Cross model error of '{weight_name}': {err}\n") + + print("------------------------------------------------\n") + print("Augmenting models to save intermediate activations......") + + aug_float_model = modify_model_output_intermediate_tensors(float_model_path) + aug_float_model_path = _generate_aug_model_path(float_model_path) + onnx.save( + aug_float_model, + aug_float_model_path, + save_as_external_data=False, + ) + del aug_float_model + + aug_qdq_model = modify_model_output_intermediate_tensors(qdq_model_path) + aug_qdq_model_path = _generate_aug_model_path(qdq_model_path) + onnx.save( + aug_qdq_model, + aug_qdq_model_path, + save_as_external_data=False, + ) + del aug_qdq_model + + print("------------------------------------------------\n") + print("Running the augmented floating point model to collect activations......") + input_data_reader = resnet50_data_reader.ResNet50DataReader( + calibration_dataset_path, float_model_path + ) + float_activations = collect_activations(aug_float_model_path, input_data_reader) + + print("------------------------------------------------\n") + print("Running the augmented qdq model to collect activations......") + input_data_reader.rewind() + qdq_activations = collect_activations(aug_qdq_model_path, input_data_reader) + + print("------------------------------------------------\n") + print("Comparing activations of float model vs qdq model......") + + act_matching = create_activation_matching(qdq_activations, float_activations) + act_error = compute_activation_error(act_matching) + for act_name, err in act_error.items(): + print(f"Cross model error of '{act_name}': {err['xmodel_err']} \n") + print(f"QDQ error of '{act_name}': {err['qdq_err']} \n") + + +if __name__ == "__main__": + main() From b0b2db4c765ec51b1c978b68d140866c45b1f3b5 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Fri, 26 Aug 2022 10:34:50 -0700 Subject: [PATCH 2/3] refine docs --- .../image_classification/cpu/ReadMe.md | 44 +++++++++---------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/quantization/image_classification/cpu/ReadMe.md b/quantization/image_classification/cpu/ReadMe.md index 8b4614c10..4518c5102 100644 --- a/quantization/image_classification/cpu/ReadMe.md +++ b/quantization/image_classification/cpu/ReadMe.md @@ -1,29 +1,28 @@ # ONNX Runtime Quantization Example -This folder contains example code for quantizing Resnet50 or mobilenetv2 models, which consists of 3 steps: - -- Pre-processing -- Quantization -- Debugging +This folder contains example code for quantizing Resnet50 or MobilenetV2 models. The example has +three parts: +1. Pre-processing +2. Quantization +3. Debugging ## Pre-processing -Quantization works best with shape inferencing, as not knowing a tensor's shape makes -it harder to quantize it. On the other hand, ONNX shape inferencing works best with -optimized models. So, it is recommended to pre-process the original 32 bit floating -point model with optimization and shape inferencing, before quantization. +Pre-processing prepares a float32 model for quantization. Run the following command to pre-process +model `mobilenetv2-7.onnx`. ```console python -m onnxruntime.quantization.shape_inference --input mobilenetv2-7.onnx --output mobilenetv2-7-infer.onnx ``` -The pre-processing consists of 3 optional sub-steps +The pre-processing consists of the following optional steps - Symbolic Shape Inference. It works best with transformer models. - ONNX Runtime Model Optimization. - ONNX Shape Inference -To learn more about these pre-processing steps and how to skip some of them, run: +It is highly recommended to run model optimization in pre-processing instead of in quantization. +To learn more about each of these steps and finer controls, run: ```console python -m onnxruntime.quantization.shape_inference --help ``` @@ -38,7 +37,7 @@ python run.py --input_model mobilenetv2-7-infer.onnx --output_model mobilenetv2- ``` This will generate quantized model mobilenetv2-7.quant.onnx -The code in ```run.py``` creates an input data reader for the model, uses these input data to run +The code in `run.py` creates an input data reader for the model, uses these input data to run the model to calibrate quantization parameters for each tensor, and then produces quantized model. Last, it runs the quantized model. Of these step, the only part that is specific to the model is the input data reader, as each model requires different shapes of input data. @@ -46,15 +45,15 @@ All other code can be easily generalized for other models. For historical reasons, the quantization API performs model optimization by default. It's highly recommended to turn off model optimization using parameter -```optimize_model=False```. This way, it is easier for the quantization debugger to match +`optimize_model=False`. This way, it is easier for the quantization debugger to match tensors of the float32 model and its quantized model, facilitating the triaging of quantization loss. ## Debugging -Quantization is not a loss-less process. Sometime it results in significiant loss in accuracy. -To help locate the source of these losses, our quantization debugging tool matches and -compare weights of the float32 model vs those of the quantized model. If a input data reader +Quantization is not a loss-less process. Sometime it results in significant loss in accuracy. +To help locate the source of these losses, our quantization debugging tool matches up +weight tensors of the float32 model vs those of the quantized model. If a input data reader is provided, our debugger can also run both models with the same input and compare their corresponding tensors: @@ -62,17 +61,16 @@ corresponding tensors: python run_qdq_debug.py --float_model mobilenetv2-7-infer.onnx --qdq_model mobilenetv2-7.quant.onnx --calibrate_dataset ./test_images/ ''' -For historical reasons, the quantization API performs model optimization by default. If you -have a quantized model with optimization turned on, and found the debugging tool can not match -certain float32 model tensors with their quantized counterparts, you can try running the -debugger again, comparing the optimized float32 model with the quantized model. +If you have quantized a model with optimization turned on, and found the debugging tool can not +match certain float32 model tensors with their quantized counterparts, you can try to run the +pre-processor to produce the optimized model, then compare the optimized model with the quantized model. -For instance, you have a model ```abc_float32_model.onnx```, and a quantized model -```abc_quantized.onnx```. During quantization process, you had optimization turned on +For instance, you have a model `abc_float32_model.onnx`, and a quantized model +`abc_quantized.onnx`. During quantization process, you had optimization turned on by default. You can run the following code to produce an optimized float32 model: ```console python -m onnxruntime.quantization.shape_inference --input abc_float32_model.onnx --output abc_optimized.onnx --skip_symbolic_shape True ``` -Then run the debugger comparing ```abc_optimized.onnx``` with ```abc_quantized.onnx```. +Then run the debugger comparing `abc_optimized.onnx` with `abc_quantized.onnx`. From 66317831df3f6df8daf69fd6881a13f24e052d8a Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Mon, 29 Aug 2022 12:16:57 -0700 Subject: [PATCH 3/3] add doc about optimization --- quantization/image_classification/cpu/ReadMe.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/quantization/image_classification/cpu/ReadMe.md b/quantization/image_classification/cpu/ReadMe.md index 4518c5102..fe22b94ba 100644 --- a/quantization/image_classification/cpu/ReadMe.md +++ b/quantization/image_classification/cpu/ReadMe.md @@ -19,7 +19,13 @@ python -m onnxruntime.quantization.shape_inference --input mobilenetv2-7.onnx -- The pre-processing consists of the following optional steps - Symbolic Shape Inference. It works best with transformer models. - ONNX Runtime Model Optimization. -- ONNX Shape Inference +- ONNX Shape Inference. + +Quantization requires tensor shape information to perform its best. Model optimization +also improve the performance of quantization. For instance, a Convolution node followed +by a BatchNormalization node can be merged into a single node during optimization. +Currently we can not quantize BatchNormalization by itself, but we can quantize the +merged Convolution + BatchNormalization node. It is highly recommended to run model optimization in pre-processing instead of in quantization. To learn more about each of these steps and finer controls, run: