
Allow narrow_range=1 for FakeQuantWithMinMaxVars #1713

Closed
srihari-humbarwadi opened this issue Sep 13, 2021 · 2 comments

@srihari-humbarwadi

Describe the bug
Allow narrow_range=1 for FakeQuantWithMinMaxVars. From the TF docs, when narrow_range=1:

In case of 8 bits, narrow_range nudges the quantized range to be [-127, 127]
instead of [-128, 127]. This ensures symmetric range has 0 as the centre.

TensorRT accepts QuantizeLinear and DeQuantizeLinear nodes if and only if the quantization scheme is symmetric and the zero_point is 0.
This can be confirmed by setting narrow_range=False while creating the TensorFlow model and then converting it to ONNX format. The ONNX model successfully contains QuantizeLinear and DeQuantizeLinear nodes, but running it with trtexec halts with:

[6] Assertion failed: shiftIsAllZeros(zeroPoint) && "TRT only supports symmetric quantization - zeroPt must be all zeros"
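
For context, a quick numeric sketch (not tf2onnx code; the float range is illustrative) of why narrow_range matters here: with the full [-128, 127] integer range, a symmetric float range cannot produce an integer zero point, while the narrowed [-127, 127] range maps real zero exactly to integer 0.

# Illustrative sketch (not tf2onnx code): the exact zero point implied by an
# affine mapping of a symmetric float range onto an 8-bit integer range.
def exact_zero_point(rmin, rmax, qmin, qmax):
    scale = (rmax - rmin) / (qmax - qmin)
    return qmin - rmin / scale

print(exact_zero_point(-6.0, 6.0, -128, 127))  # -0.5 -> not an integer; the range must be nudged asymmetric
print(exact_zero_point(-6.0, 6.0, -127, 127))  # 0.0  -> exact zero point, TensorRT-compatible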

Urgency
Blocked use case:
QAT TF-2.x -> ONNX model -> TensorRT engine

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
  • TensorFlow version: 2.6.0, 2.7 nightly
  • Python version: 3.7

To Reproduce
Here is a minimal example that reproduces this:

import tensorflow as tf
import tf2onnx
from tensorflow_model_optimization.python.core.quantization.keras import (
    quantize_wrapper,
    quantizers,
)
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
    default_8bit_quantize_configs,
)


class QuantizeConfig(default_8bit_quantize_configs.Default8BitOutputQuantizeConfig):
    # Emit an output quantizer with narrow_range=True so the exported
    # Q/DQ nodes use a symmetric range with zero_point = 0.
    def get_output_quantizers(self, layer):
        return [quantizers.MovingAverageQuantizer(
            num_bits=8, per_axis=False, symmetric=True, narrow_range=True)]


def quantize(layer):
    return quantize_wrapper.QuantizeWrapper(
        layer, quantize_config=QuantizeConfig())


def get_model():
    images = tf.keras.Input(shape=(32, 32, 3))
    x = tf.keras.layers.Conv2D(32, (3, 3))(images)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    x = tf.keras.layers.Conv2D(64, (3, 3))(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    x = quantize(tf.keras.layers.Conv2D(64, (3, 3)))(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(64, activation='relu')(x)
    x = tf.keras.layers.Dense(10)(x)
    
    model = tf.keras.Model(images, x)
    return model

q_model = get_model()
# Run a forward pass once so all layers (and quantizer variables) are built.
q_model(tf.random.normal((1, 32, 32, 3)))


onnx_model, _ = tf2onnx.convert.from_keras(q_model,
                                           opset=13,  # tried 13/14 
                                           inputs_as_nchw=[q_model.input.name])

with open('test_wrapper.onnx', 'wb') as f:
    f.write(onnx_model.SerializeToString())
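
One way to verify what the converter emitted is to inspect the zero points of the QuantizeLinear nodes in the exported file. A small sketch using the onnx package, assuming the test_wrapper.onnx written above:

import onnx
from onnx import numpy_helper

model = onnx.load('test_wrapper.onnx')
initializers = {t.name: numpy_helper.to_array(t) for t in model.graph.initializer}

# QuantizeLinear takes inputs (x, y_scale, y_zero_point); TensorRT requires
# every zero point to be all zeros.
for node in model.graph.node:
    if node.op_type == 'QuantizeLinear' and len(node.input) > 2:
        print(node.name, 'zero_point =', initializers.get(node.input[2]))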
@jasdeep06

@srihari-humbarwadi Were you able to convert the asymmetric ONNX model to TensorRT? How?

@srihari-humbarwadi
Author

@jasdeep06 I created my own quantized weight layers using tf.quantization.quantize_and_dequantize_v2. This issue thread has some related information: #1719
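
For anyone hitting the same wall, a minimal sketch of that workaround, assuming a Dense layer; QDQDense is a hypothetical name, not the exact code from that thread:

import tensorflow as tf

class QDQDense(tf.keras.layers.Dense):
    """Hypothetical Dense layer whose kernel is fake-quantized with a
    symmetric, narrow-range scheme, so exported Q/DQ nodes get zero_point = 0."""

    def call(self, inputs):
        # range_given=False derives min/max from the kernel itself;
        # narrow_range=True keeps the integer range at [-127, 127].
        qdq_kernel = tf.quantization.quantize_and_dequantize_v2(
            self.kernel, input_min=0.0, input_max=0.0,
            num_bits=8, range_given=False, narrow_range=True)
        outputs = tf.matmul(inputs, qdq_kernel)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            outputs = self.activation(outputs)
        return outputs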
