
Quantizing Yolov5 model from ultralytics is giving poor mAP #10787

Closed

Hrayo712 opened this issue Mar 6, 2022 · 14 comments

Labels: quantization (issues related to quantization)

Hrayo712 commented Mar 6, 2022

Describe the bug
I am trying to quantize the yolov5s model (from the Ultralytics repo) using ONNX Runtime quantization. Quantization runs successfully. However, when evaluating the quantized model with Ultralytics' val.py, I obtain 0 mAP (the FP32 ONNX model gives correct results).

The process I am following is as follows:

  1. Export the PyTorch model to ONNX using Ultralytics' export.py script
  2. Quantize the FP32 model using quantize_static
  3. Evaluate the mAP using the Ultralytics implementation

Note that I am reusing Ultralytics' normalization strategy to ensure the data is preprocessed correctly.
I have also tried quantizing both per-tensor and per-channel, using MinMax calibration in both cases, and obtained the same results. Calibration is done with 128 images from the COCO train2017 dataset.
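
For reference, a minimal sketch of the calibration setup described here (not the attached script itself; it assumes the exported model's input is named "images" at 640x640 and uses a plain resize instead of Ultralytics' letterbox):

import glob
import cv2
import numpy as np
from onnxruntime.quantization import (
    CalibrationDataReader, CalibrationMethod, QuantType, quantize_static,
)

class YoloCalibrationReader(CalibrationDataReader):
    """Feeds ~128 preprocessed COCO images to the MinMax calibrator."""
    def __init__(self, image_dir, input_name="images", size=640, limit=128):
        self.input_name = input_name
        self.size = size
        self.files = sorted(glob.glob(f"{image_dir}/*.jpg"))[:limit]
        self.index = 0

    def get_next(self):
        if self.index >= len(self.files):
            return None
        img = cv2.imread(self.files[self.index])
        self.index += 1
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)     # BGR -> RGB, as in Ultralytics
        img = cv2.resize(img, (self.size, self.size))  # simplified; Ultralytics uses letterbox
        img = img.astype(np.float32) / 255.0           # normalize to [0, 1]
        img = np.transpose(img, (2, 0, 1))[None]       # HWC -> NCHW, add batch dim
        return {self.input_name: img}

data_reader = YoloCalibrationReader("datasets/coco/images/train2017")
quantize_static(
    "yolov5s.onnx",
    "yolov5s_ort_quant.onnx",
    calibration_data_reader=data_reader,
    per_channel=True,                        # per-tensor was also tried
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    calibrate_method=CalibrationMethod.MinMax,
)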

Urgency
I need to know ASAP whether this is a limitation of ORT or whether I am doing something wrong.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
  • ONNX Runtime installed from (source or binary): Source
  • ONNX Runtime version: 1.10.0
  • Python version: 3.6.9
  • Visual Studio version (if applicable): N/A
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory:

To Reproduce

  • Clone the yolov5 git repo and install the requirements:
    git clone https://github.com/ultralytics/yolov5.git
    cd yolov5
    pip install -r requirements.txt
  • Export the PyTorch model to ONNX:
    python3 export.py --weights yolov5s.pt --include onnx
  • Install onnxruntime and execute the provided script to quantize the FP32 model:
    pip install onnxruntime-gpu==1.10.0
    python quantize_yolov5.py
  • Afterwards, a quick evaluation can be made by running inference on COCO128 (the first 128 images of train2017):
    python val.py --weights yolov5s_ort_quant.onnx

The result of this evaluation is stored in the runs folder, where a comparison of the ground truth images vs the inference can be seen.

A more complete evaluation can be made by running validation on the full val2017 set:
python val.py --weights yolov5s_ort_quant.onnx --data coco.yaml

Note that this script downloads the required data by default, so the full evaluation will trigger a download of the complete COCO dataset.

Attaching the script to reproduce this: quantize_yolov5.zip

Expected behavior
I expect the model to still retain some accuracy. Ultralytics' implementation supports quantization via TFLite (PTQ), and the resulting model retains most of its accuracy (FP32: 37.4 mAP, INT8: 32.2 mAP) when evaluated with the same process as described above.

https://github.com/ultralytics/yolov5/blob/8a66ebad44e8ecf90c7d27757c832579398d4baf/export.py#L313-L342

harshithapv added the quantization label on Mar 7, 2022
yufenglee (Member) commented Mar 8, 2022

Thanks for reporting the issue and the thorough repro steps. Tensor values are scaled to a range of roughly [-8.7, 735.8] at the end of the model by multiplication with large integers, and 8 bits cannot express such a wide range. I get reasonable results by excluding the nodes that take in those large tensors from quantization. We will add warning messages for large ranges.

           Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 5000/5000 [08:27<00:00,  9.84it/s]
             all       5000      36335      0.624      0.506      0.536      0.316

Here is the code snippet:

# Quantize the exported model, excluding the nodes with large value ranges
from onnxruntime.quantization import CalibrationMethod, QuantType, quantize_static

quantize_static(
    f'{model_path}.onnx',
    f'{model_path}_ort_quant.u8s8.exclude.bigscale.onnx',
    calibration_data_reader=data_reader,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    nodes_to_exclude=['Mul_214', 'Mul_225', 'Mul_249', 'Mul_260', 'Mul_284', 'Mul_295', 'Concat_231', 'Concat_266', 'Concat_301', 'Concat_303'],
    per_channel=True,
    reduce_range=True,
    calibrate_method=CalibrationMethod.MinMax,
)
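
As a quick sanity check on the large output range mentioned above, the FP32 model's output min/max can be printed directly. A rough sketch (a random tensor stands in for a real preprocessed image; with real data the box outputs span roughly the 640-pixel range):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("yolov5s.onnx")
inp = sess.get_inputs()[0]
img = np.random.rand(1, 3, 640, 640).astype(np.float32)   # stand-in for a real preprocessed image
for out_info, out in zip(sess.get_outputs(), sess.run(None, {inp.name: img})):
    print(out_info.name, float(out.min()), float(out.max()))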

yufenglee (Member) commented Mar 8, 2022

You can get better accuracy with weight_type=QuantType.QUInt8 but worse performance.
U8U8 result:

               Class     Images     Labels          P          R     mAP@.5 mAP@.5:.95: 100%|██████████| 5000/5000 [10:45<00:00,  7.74it/s]
                 all       5000      36335      0.637      0.514      0.547      0.327

Hrayo712 (Author) commented Mar 9, 2022

Hi @yufenglee

I just quantized the model excluding the nodes you suggested, and it works. Thanks for the investigation.

However, I wonder how TFLite PTQ deals with this. As I mentioned above, the YOLOv5 repo provides a way to quantize the model through TFLite.

Code where this happens in their repo: https://github.com/ultralytics/yolov5/blob/8a66ebad44e8ecf90c7d27757c832579398d4baf/export.py#L313-L342

Inspecting the model, it seems that all the operations we are excluding from quantization are still computed in INT8 in the TFLite implementation. Do you have any intuition/idea/hypothesis as to how they deal with this?

The base FP32 model is in theory the same, so I suppose they should experience similar issues (?)

TFLite INT8 route:
PyTorch -> TF -> TFLite INT8 (through TFLite's PTQ)
ONNX INT8 route:
PyTorch -> ONNX -> ONNX INT8 (through ONNX Runtime's PTQ)

yufenglee (Member) commented:

@Hrayo712, the TF implementation has special processing to normalize the pixel coordinates to 0-1: https://github.com/ultralytics/yolov5/blob/cba4303d323352fc6d76730e15459b14498a9e34/models/tf.py#L231.

If you do a similar thing for PyTorch/ONNX, ORT quantization should also work.
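
A hypothetical sketch of that idea on the PyTorch side (not code from either repo): wrap the model so the box outputs are divided by the image size before export, keeping the final tensor roughly within [0, 1]; the evaluation code would then have to scale the boxes back up after inference.

import torch

class NormalizedOutput(torch.nn.Module):
    # Hypothetical wrapper around a YOLOv5 model for export
    def __init__(self, model, imgsz=640):
        super().__init__()
        self.model = model
        self.imgsz = imgsz

    def forward(self, x):
        pred = self.model(x)[0]             # (batch, num_boxes, 4 + 1 + num_classes)
        boxes = pred[..., :4] / self.imgsz  # xywh in pixels -> normalized xywh
        return torch.cat((boxes, pred[..., 4:]), dim=-1)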

BTW, starting with 1.11, which will be released later this month, we support converting a quantized TFLite model to ONNX and running it directly with ORT.

Hrayo712 (Author) commented Mar 9, 2022

Hi @yufenglee, thanks for the quick and informative answer. Really appreciate it.

What would the workflow for this new feature look like? Do you mean that in v1.11, ORT will support deploying a quantized TFLite model directly? I.e., you just pass the TFLite INT8 model and ORT takes care of it, e.g. by internally converting the TFLite INT8 model to quantized ONNX (QOps?) via tf2onnx?

yufenglee (Member) commented Mar 9, 2022

The scenario is for the case where you already have a quantized TFLite model. Then the workflow is:

  • use tf2onnx to convert the quantized TFLite model to ONNX
  • run the converted ONNX model with ORT directly

This way, you don't need to quantize the model again with the ORT quantization tool.
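
For example (a sketch; the file names are placeholders, and it assumes the converted model keeps a float NHWC input):

# pip install tf2onnx
# python -m tf2onnx.convert --tflite yolov5s-int8.tflite --output yolov5s-int8.onnx --opset 13

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("yolov5s-int8.onnx")
inp = sess.get_inputs()[0]
dummy = np.zeros((1, 640, 640, 3), dtype=np.float32)   # TFLite exports are usually NHWC; adjust to the actual input
print([o.shape for o in sess.run(None, {inp.name: dummy})])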

Hrayo712 (Author) commented Mar 9, 2022

I see. But if TFLite (INT8) to ONNX conversion is already available via tf2onnx, what is the added functionality in ORT 1.11? I.e., what is currently missing in 1.10 that makes running these converted models unfeasible? Also, is the opposite conversion possible, i.e. quantized ONNX to quantized TFLite, via onnx2tf?

yufenglee (Member) commented:

ORT 1.10 can run the converted model too, but 1.11 has better performance. The opposite conversion is not possible.

zishui-wu commented:

I followed the steps under To Reproduce, but the INT8 ONNX model can't produce correct results when I test on the COCO 2017 val images. Can you provide the YOLOv5 ONNX INT8 model? Thank you!

Bombex commented Oct 20, 2022

> I followed the steps under To Reproduce, but the INT8 ONNX model can't produce correct results when I test on the COCO 2017 val images. Can you provide the YOLOv5 ONNX INT8 model? Thank you!

Did you find out how to get correct results?

ExSogazu commented Oct 28, 2022

> The scenario is for the case where you already have a quantized TFLite model. Then the workflow is:
>
>   • use tf2onnx to convert the quantized TFLite model to ONNX
>   • run the converted ONNX model with ORT directly
>
> This way, you don't need to quantize the model again with the ORT quantization tool.

So, I tried to follow this workflow and hit a setback. In the case of TFLite models, there is an input detail that specifies quantization parameters such as the scale factor and zero point. How do you find those values in the converted ONNX models?

Edit: I managed to do it using onnx.graph, looping through all the initializer instances and looking for the ones whose names match the 2nd and 3rd inputs of the first and last nodes.
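
For anyone else looking for those values, a rough sketch of that approach (the model file name is a placeholder):

import onnx
from onnx import numpy_helper

model = onnx.load("yolov5s-int8.onnx")
inits = {t.name: numpy_helper.to_array(t) for t in model.graph.initializer}

# The first QuantizeLinear / last DequantizeLinear carry the input/output
# scale and zero point as their 2nd and 3rd inputs, stored as initializers.
for node in (model.graph.node[0], model.graph.node[-1]):
    if len(node.input) >= 3 and node.input[1] in inits and node.input[2] in inits:
        print(node.op_type, "scale:", inits[node.input[1]], "zero_point:", inits[node.input[2]])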

snehashis1997 commented:

@yufenglee how did you find those layers' names?

katia-katkat commented:

@snehashis1997 Inspect the .onnx model in Netron and check the last layers' nodes. This blog explains it at the end: https://medium.com/@abdelsalam.h.a.a/boosting-yolov5-performance-on-cpu-with-quantization-techniques-for-raspberry-pi4-too-dc2e24f68269
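
The same inspection can also be scripted. A rough sketch that walks a few steps back from the graph outputs and prints the Mul/Concat node names found there (the exact names, e.g. Mul_214 or Concat_231, depend on the exported graph):

import onnx

model = onnx.load("yolov5s.onnx")
graph = model.graph

frontier = {o.name for o in graph.output}        # tensors we trace back from
candidates = []
for _ in range(4):                               # arbitrary backward depth
    producers = [n for n in graph.node if any(o in frontier for o in n.output)]
    frontier = {i for n in producers for i in n.input}
    for n in producers:
        if n.op_type in ("Mul", "Concat") and n.name not in candidates:
            candidates.append(n.name)
print(candidates)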

katia-katkat commented:

@Hrayo712, did you do full INT8 quantization with TFLite or just export to INT8 using export.py? When I do full quantization I get the error "Unexpected input data type. Actual: (tensor(float)), expected: (tensor(int8))"; it expects the input image to be INT8 as well. If you faced a similar issue, please let me know how you fixed it.
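
If the model really expects an INT8 input tensor, one workaround is to quantize the image manually before feeding it. A sketch with placeholder values (the real scale and zero point have to be read from the model, as in the snippet a few comments above):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("yolov5s-int8.onnx")           # placeholder file name
inp = sess.get_inputs()[0]

img = np.random.rand(1, 3, 640, 640).astype(np.float32)    # stands in for a preprocessed image in [0, 1]
scale, zero_point = 1.0 / 255, -128                        # placeholder quantization parameters
img_q = np.clip(np.round(img / scale) + zero_point, -128, 127).astype(np.int8)

print([o.shape for o in sess.run(None, {inp.name: img_q})])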
