82 changes: 82 additions & 0 deletions quantization/image_classification/cpu/ReadMe.md
@@ -0,0 +1,82 @@
# ONNX Runtime Quantization Example

This folder contains example code for quantizing ResNet50 or MobileNetV2 models. The example has
three parts:

1. Pre-processing
2. Quantization
3. Debugging

## Pre-processing

Pre-processing prepares a float32 model for quantization. Run the following command to pre-process
the model `mobilenetv2-7.onnx`.

```console
python -m onnxruntime.quantization.shape_inference --input mobilenetv2-7.onnx --output mobilenetv2-7-infer.onnx
```

The pre-processing consists of the following optional steps:
- Symbolic Shape Inference. It works best with transformer models.
- ONNX Runtime Model Optimization.
- ONNX Shape Inference.

Quantization requires tensor shape information to perform at its best. Model optimization
also improves the performance of quantization. For instance, a Convolution node followed
by a BatchNormalization node can be merged into a single node during optimization.
Currently we cannot quantize BatchNormalization by itself, but we can quantize the
merged Convolution + BatchNormalization node.

It is highly recommended to run model optimization during pre-processing rather than during quantization.
To learn more about each of these steps and finer controls, run:
```console
python -m onnxruntime.quantization.shape_inference --help
```
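The same pre-processing is also exposed as a Python function. Below is a minimal sketch, assuming the `quant_pre_process` helper available in recent onnxruntime releases; check the `--help` output above to confirm the options your version supports:

```python
# Sketch: programmatic pre-processing, equivalent to the command above.
from onnxruntime.quantization.shape_inference import quant_pre_process

quant_pre_process(
    "mobilenetv2-7.onnx",        # input float32 model
    "mobilenetv2-7-infer.onnx",  # output pre-processed model
    skip_optimization=False,     # run ONNX Runtime model optimization
    skip_onnx_shape=False,       # run ONNX shape inference
    skip_symbolic_shape=True,    # symbolic shape inference helps mostly on transformers
)
```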

## Quantization

The quantization tool takes the pre-processed float32 model and produces a quantized model.
It is recommended to use Tensor-oriented quantization (QDQ; Quantize and DeQuantize).

```console
python run.py --input_model mobilenetv2-7-infer.onnx --output_model mobilenetv2-7.quant.onnx --calibrate_dataset ./test_images/
```
This will generate the quantized model `mobilenetv2-7.quant.onnx`.

The code in `run.py` creates an input data reader for the model, uses this input data to run
the model and calibrate the quantization parameters for each tensor, and then produces the
quantized model. Finally, it runs the quantized model. Of these steps, the only part that is
specific to the model is the input data reader, as each model requires input data of different
shapes. All other code can be easily generalized to other models, as sketched below.
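To adapt the example to another model, only the data reader needs to change. Here is a minimal skeleton of a custom reader, with hypothetical names (`MyModelDataReader`, `batches`) and the model-specific preprocessing left to you:

```python
from onnxruntime.quantization import CalibrationDataReader


class MyModelDataReader(CalibrationDataReader):
    """Feeds pre-shaped numpy batches to the calibrator, one dict per call."""

    def __init__(self, batches, input_name):
        self.batches = batches        # list of numpy arrays, already model-shaped
        self.input_name = input_name  # name of the model's input tensor
        self.enum_data = None

    def get_next(self):
        # Lazily build the iterator on first call; return None when exhausted.
        if self.enum_data is None:
            self.enum_data = iter({self.input_name: b} for b in self.batches)
        return next(self.enum_data, None)

    def rewind(self):
        # Allow the same reader to be replayed (used by the debugger below).
        self.enum_data = None
```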

For historical reasons, the quantization API performs model optimization by default.
It is highly recommended to turn off model optimization using the parameter
`optimize_model=False`. This way, it is easier for the quantization debugger to match
tensors between the float32 model and its quantized counterpart, facilitating the triaging
of quantization loss.
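In code, the call at the heart of `run.py` looks like the following sketch (QDQ format, optimization off; see the `run.py` diff below for the exact version used in this example):

```python
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

import resnet50_data_reader

# Build the calibration data reader against the pre-processed model.
dr = resnet50_data_reader.ResNet50DataReader(
    "./test_images/", "mobilenetv2-7-infer.onnx"
)

quantize_static(
    "mobilenetv2-7-infer.onnx",    # pre-processed float32 model
    "mobilenetv2-7.quant.onnx",    # quantized output
    dr,
    quant_format=QuantFormat.QDQ,  # tensor-oriented (QDQ) quantization
    per_channel=False,
    weight_type=QuantType.QInt8,
    optimize_model=False,          # keep tensors matchable for the debugger
)
```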

## Debugging

Quantization is not a lossless process. Sometimes it results in a significant loss of accuracy.
To help locate the source of these losses, our quantization debugging tool matches up the
weight tensors of the float32 model with those of the quantized model. If an input data reader
is provided, our debugger can also run both models with the same input and compare their
corresponding tensors:

```console
python run_qdq_debug.py --float_model mobilenetv2-7-infer.onnx --qdq_model mobilenetv2-7.quant.onnx --calibrate_dataset ./test_images/
```
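Under the hood, the debugger is a thin wrapper over helpers in `onnxruntime.quantization.qdq_loss_debug`. A condensed sketch of the weight comparison (the same calls appear in `run_qdq_debug.py` later in this PR):

```python
from onnxruntime.quantization.qdq_loss_debug import (
    compute_weight_error, create_weight_matching)

# Pair up each float32 initializer with its quantized counterpart...
matched_weights = create_weight_matching(
    "mobilenetv2-7-infer.onnx", "mobilenetv2-7.quant.onnx"
)
# ...and report a per-weight quantization error metric.
weights_error = compute_weight_error(matched_weights)
for weight_name, err in weights_error.items():
    print(f"Cross model error of '{weight_name}': {err}")
```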

If you have quantized a model with optimization turned on and find that the debugging tool cannot
match certain float32 model tensors with their quantized counterparts, you can run the
pre-processor to produce the optimized model, then compare the optimized model with the quantized model.

For instance, suppose you have a model `abc_float32_model.onnx` and a quantized model
`abc_quantized.onnx`, and optimization was turned on (as it is by default) during the
quantization process. You can run the following command to produce an optimized float32 model:

```console
python -m onnxruntime.quantization.shape_inference --input abc_float32_model.onnx --output abc_optimized.onnx --skip_symbolic_shape True
```

Then run the debugger comparing `abc_optimized.onnx` with `abc_quantized.onnx`.
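For example, using the same flags as before (assuming the example's data reader still matches your model's input):

```console
python run_qdq_debug.py --float_model abc_optimized.onnx --qdq_model abc_quantized.onnx --calibrate_dataset ./test_images/
```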
2 changes: 0 additions & 2 deletions quantization/image_classification/cpu/ReadMe.txt

This file was deleted.

36 changes: 19 additions & 17 deletions quantization/image_classification/cpu/resnet50_data_reader.py
@@ -39,23 +39,25 @@ def _preprocess_images(images_folder: str, height: int, width: int, size_limit=0

 class ResNet50DataReader(CalibrationDataReader):
     def __init__(self, calibration_image_folder: str, model_path: str):
-        self.image_folder = calibration_image_folder
-        self.model_path = model_path
-        self.preprocess_flag = True
-        self.enum_data_dicts = []
-        self.datasize = 0
+        self.enum_data = None
+
+        # Use inference session to get input shape.
+        session = onnxruntime.InferenceSession(model_path, None)
+        (_, _, height, width) = session.get_inputs()[0].shape
+
+        # Convert image to input data
+        self.nhwc_data_list = _preprocess_images(
+            calibration_image_folder, height, width, size_limit=0
+        )
+        self.input_name = session.get_inputs()[0].name
+        self.datasize = len(self.nhwc_data_list)

     def get_next(self):
-        if self.preprocess_flag:
-            self.preprocess_flag = False
-            session = onnxruntime.InferenceSession(self.model_path, None)
-            (_, _, height, width) = session.get_inputs()[0].shape
-            nhwc_data_list = _preprocess_images(
-                self.image_folder, height, width, size_limit=0
-            )
-            input_name = session.get_inputs()[0].name
-            self.datasize = len(nhwc_data_list)
-            self.enum_data_dicts = iter(
-                [{input_name: nhwc_data} for nhwc_data in nhwc_data_list]
-            )
-        return next(self.enum_data_dicts, None)
+        if self.enum_data is None:
+            self.enum_data = iter(
+                [{self.input_name: nhwc_data} for nhwc_data in self.nhwc_data_list]
+            )
+        return next(self.enum_data, None)

+    def rewind(self):
+        self.enum_data = None
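This refactoring moves image preprocessing and shape discovery into `__init__`, so the reader can be replayed: `rewind()` simply resets the iterator over the cached batches, which is exactly what the debugger needs to feed identical inputs to both models. A quick sanity check (hypothetical snippet, not part of this PR):

```python
# Run through the reader twice; both passes must yield the same batch count.
dr = ResNet50DataReader("./test_images/", "mobilenetv2-7-infer.onnx")
first = sum(1 for _ in iter(dr.get_next, None))
dr.rewind()  # reset the cached-batch iterator instead of re-reading images
second = sum(1 for _ in iter(dr.get_next, None))
assert first == second == dr.datasize
```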
4 changes: 4 additions & 0 deletions quantization/image_classification/cpu/run.py
Original file line number Diff line number Diff line change
@@ -52,13 +52,17 @@ def main():
dr = resnet50_data_reader.ResNet50DataReader(
calibration_dataset_path, input_model_path
)

> **Member:** call the pre-process?

> **Contributor Author:** It would break the flow described in Readme.md if we called pre-processing here.

> **Contributor Author:** I think it might be better to make the pre-processing front and center, highlighting it as a separate step in the example. If I called it here, it would be somewhat hidden: the user would have to dig into the Python code to locate it. In the current layout, this step lives at the command-line level instead of the Python-code level, which might be more obvious to the user. What do you think?

# Calibrate and quantize model
# Turn off model optimization during quantization
quantize_static(
input_model_path,
output_model_path,
dr,
quant_format=args.quant_format,
per_channel=args.per_channel,
weight_type=QuantType.QInt8,
optimize_model=False,
)
print("Calibrated and quantized model saved.")

90 changes: 90 additions & 0 deletions quantization/image_classification/cpu/run_qdq_debug.py
@@ -0,0 +1,90 @@
import argparse
import onnx
from onnxruntime.quantization.qdq_loss_debug import (
collect_activations, compute_activation_error, compute_weight_error,
create_activation_matching, create_weight_matching,
modify_model_output_intermediate_tensors)

import resnet50_data_reader


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--float_model", required=True, help="Path to original floating point model"
)
parser.add_argument("--qdq_model", required=True, help="Path to qdq model")
parser.add_argument(
"--calibrate_dataset", default="./test_images", help="calibration data set"
)
args = parser.parse_args()
return args


def _generate_aug_model_path(model_path: str) -> str:
aug_model_path = (
model_path[: -len(".onnx")] if model_path.endswith(".onnx") else model_path
)
return aug_model_path + ".save_tensors.onnx"


def main():
# Process input parameters and setup model input data reader
args = get_args()
float_model_path = args.float_model
qdq_model_path = args.qdq_model
calibration_dataset_path = args.calibrate_dataset

print("------------------------------------------------\n")
print("Comparing weights of float model vs qdq model.....")

matched_weights = create_weight_matching(float_model_path, qdq_model_path)
weights_error = compute_weight_error(matched_weights)
for weight_name, err in weights_error.items():
print(f"Cross model error of '{weight_name}': {err}\n")

print("------------------------------------------------\n")
print("Augmenting models to save intermediate activations......")

aug_float_model = modify_model_output_intermediate_tensors(float_model_path)
aug_float_model_path = _generate_aug_model_path(float_model_path)
onnx.save(
aug_float_model,
aug_float_model_path,
save_as_external_data=False,
)
del aug_float_model

aug_qdq_model = modify_model_output_intermediate_tensors(qdq_model_path)
aug_qdq_model_path = _generate_aug_model_path(qdq_model_path)
onnx.save(
aug_qdq_model,
aug_qdq_model_path,
save_as_external_data=False,
)
del aug_qdq_model

print("------------------------------------------------\n")
print("Running the augmented floating point model to collect activations......")
input_data_reader = resnet50_data_reader.ResNet50DataReader(
calibration_dataset_path, float_model_path
)
float_activations = collect_activations(aug_float_model_path, input_data_reader)

print("------------------------------------------------\n")
print("Running the augmented qdq model to collect activations......")
input_data_reader.rewind()
qdq_activations = collect_activations(aug_qdq_model_path, input_data_reader)

print("------------------------------------------------\n")
print("Comparing activations of float model vs qdq model......")

act_matching = create_activation_matching(qdq_activations, float_activations)
act_error = compute_activation_error(act_matching)
for act_name, err in act_error.items():
print(f"Cross model error of '{act_name}': {err['xmodel_err']} \n")
print(f"QDQ error of '{act_name}': {err['qdq_err']} \n")


if __name__ == "__main__":
main()