In [1]:
import torch
import torch.nn as nn
import torch.quantization
import brevitas.nn as qnn
import brevitas.onnx as bo
from brevitas.nn import QuantConv2d, QuantLinear, QuantReLU, QuantAvgPool2d
from brevitas.quant import IntBias

from brevitas.core.restrict_val import RestrictValueType
from brevitas.quant import Uint8ActPerTensorFloatMaxInit, Int8ActPerTensorFloatMinMaxInit
from brevitas.quant import Int8WeightPerTensorFloat


class CommonIntWeightPerTensorQuant(Int8WeightPerTensorFloat):
    """
    Common per-tensor weight quantizer with bit-width set to None so that it's forced to be
    specified by each layer.

    Derives from ``Int8WeightPerTensorFloat`` and overrides only the two attributes below.
    """
    # Presumably the lower bound applied to the learned/derived scale -- TODO confirm.
    # NOTE(review): 2e-16 is 2 * 10**-16; check whether 2**-16 was intended.
    scaling_min_val = 2e-16
    # Deliberately left unset so every layer using this quantizer must pass its own bit-width.
    bit_width = None


class CommonIntWeightPerChannelQuant(CommonIntWeightPerTensorQuant):
    """
    Common per-channel weight quantizer with bit-width set to None so that it's forced to be
    specified by each layer.

    Inherits every setting from ``CommonIntWeightPerTensorQuant`` and only flips the
    per-output-channel scaling flag.
    """
    # Switches the inherited per-tensor configuration to one scale per output channel.
    scaling_per_output_channel = True


class CommonIntActQuant(Int8ActPerTensorFloatMinMaxInit):
    """
    Common signed act quantizer with bit-width set to None so that it's forced to be specified by
    each layer.

    Derives from ``Int8ActPerTensorFloatMinMaxInit`` and overrides the attributes below.
    """
    # Presumably the lower bound applied to the scale -- TODO confirm.
    # NOTE(review): 2e-16 is 2 * 10**-16; check whether 2**-16 was intended.
    scaling_min_val = 2e-16
    # Deliberately left unset so every layer using this quantizer must pass its own bit-width.
    bit_width = None
    # Symmetric [-10, 10] range; presumably used to initialise the activation scale
    # (the base class is a "MinMaxInit" quantizer) -- TODO confirm.
    min_val = -10.0
    max_val = 10.0
    # Selects the LOG_FP restriction for the scale parameter.
    restrict_scaling_type = RestrictValueType.LOG_FP


class CommonUintActQuant(Uint8ActPerTensorFloatMaxInit):
    """
    Common unsigned act quantizer with bit-width set to None so that it's forced to be specified by
    each layer.

    Derives from ``Uint8ActPerTensorFloatMaxInit`` and overrides the attributes below.
    """
    # Presumably the lower bound applied to the scale -- TODO confirm.
    # NOTE(review): 2e-16 is 2 * 10**-16; check whether 2**-16 was intended.
    scaling_min_val = 2e-16
    # Deliberately left unset so every layer using this quantizer must pass its own bit-width.
    bit_width = None
    # Upper bound of 6.0; presumably used to initialise the activation scale
    # (the base class is a "MaxInit" quantizer) -- TODO confirm.
    max_val = 6.0
    # Selects the LOG_FP restriction for the scale parameter.
    restrict_scaling_type = RestrictValueType.LOG_FP


### Conv

```python
class QuantizedConvNdFn(Function):
    """Export-time stand-in for a quantized convolution.

    ``symbolic`` describes the ONNX nodes emitted for the layer, while ``forward`` only
    allocates an output tensor of the requested shape (no convolution is computed).
    """

    @staticmethod
    def symbolic(
            g, x, W, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_shape, pads, strides,
            bias, kernel_shape, groups, dilations):
        """Emit the ONNX subgraph for the quantized convolution.

        The custom-domain ``Conv`` node is always emitted. It is followed by:

        * a ``Mul`` by ``w_qnt_scale`` when a weight scale is given;
        * ``Div`` -> custom-domain ``Add`` -> ``Mul`` (all using ``b_qnt_scale``) when a
          bias with a quantization type is given;
        * a plain ``Add`` when the bias has no quantization type.

        Raises:
            AssertionError: if ``bias`` and ``b_qnt_type`` are set but ``b_qnt_scale`` is None.
        """
        ret = g.op(
            f'{DOMAIN_STRING}::Conv', x, W,
            weight_qnt_s=w_qnt_type,
            kernel_shape_i=kernel_shape,
            pads_i=pads,
            strides_i=strides,
            group_i=groups,
            dilations_i=dilations)
        if w_qnt_scale is not None:
            ret = g.op('Mul', ret, w_qnt_scale)
        if bias is not None:
            if b_qnt_type is not None:
                assert b_qnt_scale is not None
                ret = g.op('Div', ret, b_qnt_scale)
                # Must be an f-string so the op lands in the same custom domain as Conv;
                # a plain literal would emit the domain name "{DOMAIN_STRING}" verbatim.
                ret = g.op(f'{DOMAIN_STRING}::Add', ret, bias, bias_qnt_s=b_qnt_type)
                ret = g.op('Mul', ret, b_qnt_scale)
            else:
                ret = g.op('Add', ret, bias)
        return ret

    @staticmethod
    def forward(
            ctx, x, W, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_shape, pads, strides,
            bias, kernel_shape, groups, dilations):
        """Return an uninitialised float tensor of ``out_shape`` on ``x``'s device."""
        return torch.empty(out_shape, dtype=torch.float, device=x.device)
```
---
- CASE 1: if(weight_quant):              QuantizedConv =  Conv + Mul
- CASE 2: if(bias and bias_quant):       QuantizedConv =  Conv + Mul + Div + Add + Mul
- CASE 3: if(bias):                      QuantizedConv =  Conv + Mul + Add


In [2]:
class ConvBlock(nn.Module):
    """Minimal module holding one 3x3 ``QuantConv2d`` (padding 1, no bias)."""

    def __init__(self, inp, outp, stride=1):
        """Build the convolution mapping ``inp`` channels to ``outp`` channels."""
        super().__init__()
        self.conv = QuantConv2d(
            inp,
            outp,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
            return_quant_tensor=True,
        )

    def forward(self, x):
        """Apply the quantized convolution to ``x`` and return its output."""
        return self.conv(x)


# File names for the saved state dict and for the ONNX file written by the exporter.
model_for_export = "conv.pth"
ready_model_filename = "conv_finn.onnx"
# Input shape handed to the exporter.
input_shape = (1, 3, 224, 224)

model = ConvBlock(3, 64)
print(model)
torch.save(model.state_dict(), model_for_export)
bo.export_finn_onnx(model, input_shape, export_path=ready_model_filename)
print("Model saved to %s" % ready_model_filename)

ConvBlock(
  (conv): QuantConv2d(
    3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
    (input_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (output_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (weight_quant): WeightQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
      (tensor_quant): RescalingIntQuant(
        (int_quant): IntQuant(
          (float_to_int_impl): RoundSte()
          (tensor_clamp_impl): TensorClampSte()
          (delay_wrapper): DelayWrapper(
            (delay_impl): _NoDelay()
          )
        )
        (scaling_impl): StatsFromParameterScaling(
          (parameter_list_stats): _ParameterListStats(
            (first_tracked_param): _ViewParameterWrapper(
              (view_shape_impl): OverTensorView()
            )
            (stats): _Stats(
              (stats_impl): AbsMax()
            )
          )
          (stats_scal



In [3]:
# Serve the ONNX file exported by the previous cell through FINN's Netron helper.
from finn.util.visualization import showInNetron

showInNetron(ready_model_filename)

Serving 'conv_finn.onnx' at http://0.0.0.0:8081


### QuantReLU

```python
class QuantReLUFn(Function):
    """Export-time stand-in for a quantized ReLU.

    ``symbolic`` describes the emitted ONNX nodes; ``forward`` clamps the input at zero.
    """

    @staticmethod
    def symbolic(g, input, qnt_type, thres, bias, scale):
        """Emit a custom-domain ``MultiThreshold`` node, then a ``Mul`` when ``scale`` is set.

        ``bias`` is accepted but not used when building the graph.
        """
        thresholded = g.op(
            f'{DOMAIN_STRING}::MultiThreshold', input, thres, out_dtype_s=qnt_type)
        if scale is None:
            return thresholded
        return g.op('Mul', thresholded, scale)

    @staticmethod
    def forward(ctx, input, qnt_type, thres, bias, scale):
        """Return ``input`` with all negative entries clamped to zero."""
        return input.clamp(min=0)
```
---

- CASE 1: QuantReLU = MultiThreshold 
- CASE 2: if(scale): QuantReLU = MultiThreshold + Mul

In [12]:
class ReLUBlock(nn.Module):
    """Minimal module holding one default-configured ``QuantReLU``."""

    def __init__(self):
        super().__init__()
        self.relu = QuantReLU()

    def forward(self, x):
        """Apply the quantized ReLU to ``x`` and return its output."""
        return self.relu(x)


# File names for the saved state dict and for the ONNX file written by the exporter.
model_for_export = "relu.pth"
ready_model_filename = "relu_finn.onnx"
# Input shape handed to the exporter.
input_shape = (1, 64, 224, 224)

model = ReLUBlock()
print(model)
torch.save(model.state_dict(), model_for_export)
bo.export_finn_onnx(model, input_shape, export_path=ready_model_filename)
print("Model saved to %s" % ready_model_filename)

ReLUBlock(
  (relu): QuantReLU(
    (input_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (act_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
      (fused_activation_quant_proxy): FusedActivationQuantProxy(
        (activation_impl): ReLU()
        (tensor_quant): RescalingIntQuant(
          (int_quant): IntQuant(
            (float_to_int_impl): RoundSte()
            (tensor_clamp_impl): TensorClamp()
            (delay_wrapper): DelayWrapper(
              (delay_impl): _NoDelay()
            )
          )
          (scaling_impl): ParameterFromRuntimeStatsScaling(
            (stats_input_view_shape_impl): OverTensorView()
            (stats): _Stats(
              (stats_impl): AbsPercentile()
            )
            (restrict_clamp_scaling): _RestrictClampValue(
              (clamp_min_ste): Identity()
              (restrict_value_impl): FloatRestrictValue()
            )
            (restrict_inplac



In [13]:
# Serve the ONNX file exported by the previous cell through FINN's Netron helper.
from finn.util.visualization import showInNetron

showInNetron(ready_model_filename)

Stopping http://0.0.0.0:8081
Serving 'relu_finn.onnx' at http://0.0.0.0:8081


### QuantizedLinear

```python
class QuantizedLinearFn(Function):
    """Export-time stand-in for a quantized linear layer.

    ``symbolic`` describes the ONNX nodes emitted for the layer, while ``forward`` only
    allocates an output tensor of the requested shape (no matrix product is computed).
    """

    @staticmethod
    def symbolic(g, x, Wt, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_shape, bias):
        """Emit the ONNX subgraph for the quantized linear layer.

        The custom-domain ``MatMul`` node is always emitted. It is followed by:

        * a ``Mul`` by ``w_qnt_scale`` when a weight scale is given;
        * ``Div`` -> custom-domain ``Add`` -> ``Mul`` (all using ``b_qnt_scale``) when a
          bias with a quantization type is given;
        * a plain ``Add`` when the bias has no quantization type.

        Raises:
            AssertionError: if ``bias`` and ``b_qnt_type`` are set but ``b_qnt_scale`` is None.
        """
        ret = g.op(f'{DOMAIN_STRING}::MatMul', x, Wt, weight_qnt_s=w_qnt_type)
        if w_qnt_scale is not None:
            ret = g.op('Mul', ret, w_qnt_scale)
        if bias is not None:
            if b_qnt_type is not None:
                assert b_qnt_scale is not None
                ret = g.op('Div', ret, b_qnt_scale)
                # Must be an f-string so the op lands in the same custom domain as MatMul;
                # a plain literal would emit the domain name "{DOMAIN_STRING}" verbatim.
                ret = g.op(f'{DOMAIN_STRING}::Add', ret, bias, bias_qnt_s=b_qnt_type)
                ret = g.op('Mul', ret, b_qnt_scale)
            else:
                ret = g.op('Add', ret, bias)
        return ret

    @staticmethod
    def forward(ctx, x, Wt, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_shape, bias):
        """Return an uninitialised float tensor of ``out_shape`` on ``x``'s device."""
        return torch.empty(out_shape, dtype=torch.float, device=x.device)
```
---

- CASE 1:                                        QuantizedLinear = MatMul 
- CASE 2: if(weight_quant):                      QuantizedLinear = MatMul +  Mul
- CASE 3: if(bias):                              QuantizedLinear = MatMul +  Add
- CASE 4: if(bias and bias_quant):               QuantizedLinear = MatMul +  Div  +  Add  + Mul
- CASE 5: if(weight_quant and bias):                   QuantizedLinear = MatMul +  Mul  +  Add
- CASE 6: if(weight_quant and bias and bias_quant):    QuantizedLinear = MatMul +  Mul  +  Div  +  Add  + Mul

In [19]:
class QuantizedLinearBlock(nn.Module):
    """Minimal module holding one ``QuantLinear`` layer with bias enabled."""

    def __init__(self, inp, outp):
        """Build the linear layer mapping ``inp`` features to ``outp`` features."""
        super().__init__()
        self.linear = QuantLinear(
            inp,
            outp,
            bias=True,
            return_quant_tensor=True,
        )

    def forward(self, x):
        """Apply the quantized linear layer to ``x`` and return its output."""
        return self.linear(x)


# File names for the saved state dict and for the ONNX file written by the exporter.
model_for_export = "linear.pth"
ready_model_filename = "linear_finn.onnx"
# Input shape handed to the exporter.
input_shape = (1, 1024)

model = QuantizedLinearBlock(1024, 10)
print(model)
torch.save(model.state_dict(), model_for_export)
bo.export_finn_onnx(model, input_shape, export_path=ready_model_filename)
print("Model saved to %s" % ready_model_filename)

QuantizedLinearBlock(
  (linear): QuantLinear(
    in_features=1024, out_features=10, bias=True
    (input_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (output_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (weight_quant): WeightQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
      (tensor_quant): RescalingIntQuant(
        (int_quant): IntQuant(
          (float_to_int_impl): RoundSte()
          (tensor_clamp_impl): TensorClampSte()
          (delay_wrapper): DelayWrapper(
            (delay_impl): _NoDelay()
          )
        )
        (scaling_impl): StatsFromParameterScaling(
          (parameter_list_stats): _ParameterListStats(
            (first_tracked_param): _ViewParameterWrapper(
              (view_shape_impl): OverTensorView()
            )
            (stats): _Stats(
              (stats_impl): AbsMax()
            )
          )
          (stats_scaling_impl): 



In [20]:
# Serve the ONNX file exported by the previous cell through FINN's Netron helper.
from finn.util.visualization import showInNetron

showInNetron(ready_model_filename)

Stopping http://0.0.0.0:8081
Serving 'linear_finn.onnx' at http://0.0.0.0:8081


### QuantIdentity

- QuantIdentity = MultiThreshold + Add + Mul

In [33]:
class QuantIdentityBlock(nn.Module):
    """Minimal module holding one 8-bit ``QuantIdentity`` layer."""

    def __init__(self):
        super().__init__()
        self.quant_inp = qnn.QuantIdentity(
            bit_width=8,
            return_quant_tensor=True,
        )

    def forward(self, x):
        """Pass ``x`` through the quantized identity and return its output."""
        return self.quant_inp(x)


# File names for the saved state dict and for the ONNX file written by the exporter.
model_for_export = "identity.pth"
ready_model_filename = "identity_finn.onnx"
# Input shape handed to the exporter.
input_shape = (1, 3, 320, 320)

model = QuantIdentityBlock()
print(model)
torch.save(model.state_dict(), model_for_export)
bo.export_finn_onnx(model, input_shape, export_path=ready_model_filename)
print("Model saved to %s" % ready_model_filename)

QuantIdentityBlock(
  (quant_inp): QuantIdentity(
    (input_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (act_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
      (fused_activation_quant_proxy): FusedActivationQuantProxy(
        (activation_impl): Identity()
        (tensor_quant): RescalingIntQuant(
          (int_quant): IntQuant(
            (float_to_int_impl): RoundSte()
            (tensor_clamp_impl): TensorClamp()
            (delay_wrapper): DelayWrapper(
              (delay_impl): _NoDelay()
            )
          )
          (scaling_impl): ParameterFromRuntimeStatsScaling(
            (stats_input_view_shape_impl): OverTensorView()
            (stats): _Stats(
              (stats_impl): AbsPercentile()
            )
            (restrict_clamp_scaling): _RestrictClampValue(
              (clamp_min_ste): Identity()
              (restrict_value_impl): FloatRestrictValue()
            )
      



In [34]:
# Serve the ONNX file exported by the previous cell through FINN's Netron helper.
from finn.util.visualization import showInNetron

showInNetron(ready_model_filename)

Stopping http://0.0.0.0:8081
Serving 'identity_finn.onnx' at http://0.0.0.0:8081


### QuantAvgPool2d

```python
class QuantAvgPool2dFn(Function):
    """Export-time stand-in for a quantized 2-d average pooling layer.

    ``symbolic`` describes the ONNX nodes emitted for the layer, while ``forward`` only
    allocates an output tensor of the requested shape (no pooling is computed).
    """

    @staticmethod
    def symbolic(g, x, out_shape, kernel, stride, signed, ibits, obits, scale, qnt_type):
        """Emit the ONNX subgraph for the quantized average pooling.

        The custom-domain ``QuantAvgPool2d`` node is always emitted. When ``scale`` is
        given, the input is first divided by it (custom-domain ``Div``) and the pooled
        result is multiplied by it again (``Mul``).
        """
        if scale is not None:
            # Must be an f-string so the op lands in the same custom domain as
            # QuantAvgPool2d; a plain literal would emit "{DOMAIN_STRING}" verbatim.
            x = g.op(f'{DOMAIN_STRING}::Div', x, scale, activation_qnt_s=qnt_type)
        ret = g.op(
            f'{DOMAIN_STRING}::QuantAvgPool2d', x,
            kernel_i=kernel,
            stride_i=stride,
            signed_i=signed,
            ibits_i=ibits,
            obits_i=obits)
        if scale is not None:
            ret = g.op('Mul', ret, scale)
        return ret

    @staticmethod
    def forward(ctx, x, out_shape, kernel, stride, signed, ibits, obits, scale, qnt_type):
        """Return an uninitialised float tensor of ``out_shape`` on ``x``'s device."""
        return torch.empty(out_shape, dtype=torch.float, device=x.device)
    
```
---

- CASE 1: if(scale):                             QuantAvgPool2d = Div + QuantAvgPool2d + Mul
- CASE 2:                                        QuantAvgPool2d = QuantAvgPool2d

In [29]:
class QuantAvgPoolBlock(nn.Module):
    """Minimal module chaining an 8-bit ``QuantIdentity`` into a 3x3 ``QuantAvgPool2d``."""

    def __init__(self):
        super().__init__()
        self.quant_inp = qnn.QuantIdentity(
            bit_width=8,
            return_quant_tensor=True,
        )
        self.pool = QuantAvgPool2d(
            kernel_size=3,
            stride=1,
        )

    def forward(self, x):
        """Pass ``x`` through the quantized identity, then through the average pooling."""
        quantized = self.quant_inp(x)
        return self.pool(quantized)


# File names for the saved state dict and for the ONNX file written by the exporter.
model_for_export = "pool.pth"
ready_model_filename = "pool_finn.onnx"
# Input shape handed to the exporter.
input_shape = (1, 3, 320, 320)

model = QuantAvgPoolBlock()
print(model)
torch.save(model.state_dict(), model_for_export)
bo.export_finn_onnx(model, input_shape, export_path=ready_model_filename)
print("Model saved to %s" % ready_model_filename)

QuantAvgPoolBlock(
  (quant_inp): QuantIdentity(
    (input_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (act_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
      (fused_activation_quant_proxy): FusedActivationQuantProxy(
        (activation_impl): Identity()
        (tensor_quant): RescalingIntQuant(
          (int_quant): IntQuant(
            (float_to_int_impl): RoundSte()
            (tensor_clamp_impl): TensorClamp()
            (delay_wrapper): DelayWrapper(
              (delay_impl): _NoDelay()
            )
          )
          (scaling_impl): ParameterFromRuntimeStatsScaling(
            (stats_input_view_shape_impl): OverTensorView()
            (stats): _Stats(
              (stats_impl): AbsPercentile()
            )
            (restrict_clamp_scaling): _RestrictClampValue(
              (clamp_min_ste): Identity()
              (restrict_value_impl): FloatRestrictValue()
            )
       



In [30]:
# Serve the ONNX file exported by the previous cell through FINN's Netron helper.
from finn.util.visualization import showInNetron

showInNetron(ready_model_filename)

Stopping http://0.0.0.0:8081
Serving 'pool_finn.onnx' at http://0.0.0.0:8081
