In [53]:
# References for static quantization and module fusion:
# https://rachitsingh.com/deep-learning-model-compression/
# https://chowdera.com/2021/02/20210203170434983m.html
# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/

In [54]:
import torch
import numpy as np
from torch import nn
from torch.quantization import QuantStub, DeQuantStub

In [67]:
class SampleNet(nn.Module):
    """Minimal 1x1-conv -> linear -> ReLU network used to compare quantization modes.

    The layer sizes default to the values the notebook was written with, but
    can be overridden to experiment with other input shapes.

    Args:
        quantize_statically: If True, ``forward`` passes the input through
            ``self.quant`` first and the result through ``self.dequant`` last.
            Both stubs are always registered as submodules; the flag only
            controls whether they are called.
        in_channels: Input channels of the 1x1 convolution (size of axis 1 of
            the input tensor).
        out_channels: Output channels of the 1x1 convolution.
        in_features: Input features of the linear layer (size of the last axis
            of the input tensor; the 1x1, stride-1, unpadded convolution does
            not change it).
        out_features: Output features of the linear layer.
    """

    def __init__(self, quantize_statically=False, in_channels=112, out_channels=112,
                 in_features=3, out_features=2):
        super().__init__()
        self.quantize_statically = quantize_statically
        # Submodules are created in a fixed order (conv, fc, relu, quant,
        # dequant) so the printed module tree and the order in which random
        # initial weights are drawn stay the same.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, groups=1, bias=False)
        self.fc = nn.Linear(in_features, out_features, bias=False)
        # ReLU is kept as its own named submodule so it can be fused with `fc`
        # by name, as the notebook does with [['fc', 'relu']].
        self.relu = nn.ReLU(inplace=False)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Run conv -> fc -> relu, wrapped in quant/dequant when enabled.

        Args:
            x: Input tensor; expected to be 4-D with axis 1 of size
                ``in_channels`` and last axis of size ``in_features``.

        Returns:
            Tensor with axis 1 of size ``out_channels`` and last axis of size
            ``out_features``.
        """
        if self.quantize_statically:
            x = self.quant(x)
        x = self.conv(x)
        x = self.fc(x)
        x = self.relu(x)
        if self.quantize_statically:
            x = self.dequant(x)
        return x

In [68]:
# Sizes used to build the random input tensors of shape (b, w, h, c).
# NOTE(review): presumably batch, channels, width, height - confirm.
b, c = 1, 3
w, h = 112, 112

## Dynamic Quantization

In [69]:
# Seed torch before building the model so its randomly initialised weights
# (and therefore every output printed in this section) are reproducible.
torch.manual_seed(100)
model = SampleNet(quantize_statically=False)

In [113]:
# Inspect the module tree of `model` (all float modules in the output shown).
model

SampleNet(
  (conv): Conv2d(112, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (fc): Linear(in_features=3, out_features=2, bias=False)
  (relu): ReLU()
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

In [71]:
# Dynamic quantization with every option spelled out; a new model is
# returned and `model` itself is left untouched (inplace=False).
model_int8_dynamic = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec=None,
    dtype=torch.qint8,
    mapping=None,
    inplace=False,
)

In [72]:
# Inspect the result: in the output shown only `fc` became a
# DynamicQuantizedLinear; `conv` is still a float Conv2d.
model_int8_dynamic

SampleNet(
  (conv): Conv2d(112, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (fc): DynamicQuantizedLinear(in_features=3, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

In [73]:
# Fixed numpy seed so the same input is drawn every time this cell runs.
np.random.seed(100)
# Uniform [0, 1) samples of shape (b, w, h, c), converted from float64 to float32.
# NOTE(review): Conv2d reads axis 1 as channels, and that axis has size w here - confirm this layout is intended.
x = torch.from_numpy(np.random.random_sample((b, w, h, c))).to(torch.float32)

In [74]:
# End-to-end forward pass through the dynamically quantized model;
# the bare `o` on the last line displays the result.
o = model_int8_dynamic(x)
o

tensor([[[[0.0000, 0.0000],
          [0.0000, 0.0877],
          [0.0000, 0.0000],
          ...,
          [0.0006, 0.0115],
          [0.0000, 0.0000],
          [0.0256, 0.0000]],

         [[0.0000, 0.0000],
          [0.0000, 0.0000],
          [0.1690, 0.0965],
          ...,
          [0.0682, 0.1329],
          [0.0277, 0.0000],
          [0.0479, 0.0000]],

         [[0.0216, 0.0000],
          [0.0691, 0.1039],
          [0.0000, 0.0412],
          ...,
          [0.0613, 0.1398],
          [0.3052, 0.2448],
          [0.0796, 0.0810]],

         ...,

         [[0.1627, 0.0929],
          [0.0718, 0.1530],
          [0.0535, 0.0000],
          ...,
          [0.0332, 0.0236],
          [0.1541, 0.1157],
          [0.0699, 0.0733]],

         [[0.1018, 0.0143],
          [0.0479, 0.0000],
          [0.0000, 0.0000],
          ...,
          [0.0000, 0.0000],
          [0.0000, 0.0000],
          [0.0464, 0.0498]],

         [[0.1849, 0.1061],
          [0.2382, 0.1566],
    

In [108]:
# Repeat the forward pass one layer at a time to expose the intermediate
# tensors. The quant/dequant stubs are not called in this manual pass.
o1_dynamic = model_int8_dynamic.conv(x)  # conv output
o2_dynamic = model_int8_dynamic.fc(o1_dynamic)  # linear output, before ReLU
o3_dynamic = model_int8_dynamic.relu(o2_dynamic)  # after ReLU

In [109]:
# Output of the conv layer (last axis still has 3 entries in the output shown).
o1_dynamic

tensor([[[[ 0.2791,  0.3040,  0.0697],
          [ 0.1874,  0.1361, -0.1421],
          [ 0.2033,  0.1973,  0.0379],
          ...,
          [-0.1547,  0.1216,  0.0254],
          [-0.0715,  0.2331,  0.1403],
          [-0.1369, -0.0557,  0.0515]],

         [[ 0.0974, -0.0910,  0.1341],
          [ 0.2933,  0.1376,  0.0847],
          [-0.3050, -0.1590, -0.1934],
          ...,
          [-0.3086,  0.2196, -0.1547],
          [-0.2770, -0.0225,  0.1100],
          [ 0.1098, -0.3398,  0.0224]],

         [[-0.0608, -0.2606,  0.1479],
          [-0.0619,  0.0180, -0.1743],
          [ 0.0579,  0.0359, -0.0711],
          ...,
          [ 0.0391,  0.0548, -0.2481],
          [-0.5632, -0.1450, -0.4495],
          [ 0.1083, -0.1608, -0.2039]],

         ...,

         [[-0.1689, -0.2217, -0.2255],
          [-0.1602,  0.1495, -0.2251],
          [ 0.0772, -0.2318, -0.0521],
          ...,
          [-0.2581,  0.0943,  0.0039],
          [-0.1571, -0.1597, -0.2587],
          [-0.0871, -0

In [110]:
# Output of the dynamically quantized linear layer, before ReLU
# (negative values are still present in the output shown).
o2_dynamic

tensor([[[[-0.1515, -0.0011],
          [-0.0271,  0.0877],
          [-0.1077, -0.0004],
          ...,
          [ 0.0006,  0.0115],
          [-0.0833, -0.0354],
          [ 0.0256, -0.0264]],

         [[-0.0432, -0.0909],
          [-0.1264, -0.0389],
          [ 0.1690,  0.0965],
          ...,
          [ 0.0682,  0.1329],
          [ 0.0277, -0.0543],
          [ 0.0479, -0.0695]],

         [[ 0.0216, -0.1219],
          [ 0.0691,  0.1039],
          [-0.0032,  0.0412],
          ...,
          [ 0.0613,  0.1398],
          [ 0.3052,  0.2448],
          [ 0.0796,  0.0810]],

         ...,

         [[ 0.1627,  0.0929],
          [ 0.0718,  0.1530],
          [ 0.0535, -0.0135],
          ...,
          [ 0.0332,  0.0236],
          [ 0.1541,  0.1157],
          [ 0.0699,  0.0733]],

         [[ 0.1018,  0.0143],
          [ 0.0479, -0.0334],
          [-0.0816, -0.0951],
          ...,
          [-0.0134, -0.0858],
          [-0.0544, -0.0953],
          [ 0.0464,  0.0498]],



In [111]:
# After ReLU: negatives are clamped to zero; the displayed values match
# those printed for `o` from the end-to-end call.
o3_dynamic

tensor([[[[0.0000, 0.0000],
          [0.0000, 0.0877],
          [0.0000, 0.0000],
          ...,
          [0.0006, 0.0115],
          [0.0000, 0.0000],
          [0.0256, 0.0000]],

         [[0.0000, 0.0000],
          [0.0000, 0.0000],
          [0.1690, 0.0965],
          ...,
          [0.0682, 0.1329],
          [0.0277, 0.0000],
          [0.0479, 0.0000]],

         [[0.0216, 0.0000],
          [0.0691, 0.1039],
          [0.0000, 0.0412],
          ...,
          [0.0613, 0.1398],
          [0.3052, 0.2448],
          [0.0796, 0.0810]],

         ...,

         [[0.1627, 0.0929],
          [0.0718, 0.1530],
          [0.0535, 0.0000],
          ...,
          [0.0332, 0.0236],
          [0.1541, 0.1157],
          [0.0699, 0.0733]],

         [[0.1018, 0.0143],
          [0.0479, 0.0000],
          [0.0000, 0.0000],
          ...,
          [0.0000, 0.0000],
          [0.0000, 0.0000],
          [0.0464, 0.0498]],

         [[0.1849, 0.1061],
          [0.2382, 0.1566],
    

## Static Quantization

In [120]:
# Seed torch before building the model so its randomly initialised weights
# (and the calibration batches drawn later with torch.rand) are reproducible.
torch.manual_seed(100)
# Post-training static quantization expects the model in eval mode before
# fusing / preparing / converting; nn.Module.eval() returns the module itself.
model = SampleNet(quantize_statically=True).eval()

In [121]:
# Inspect the float model that is about to be statically quantized.
model

SampleNet(
  (conv): Conv2d(112, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (fc): Linear(in_features=3, out_features=2, bias=False)
  (relu): ReLU()
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

#### STEP 1. layer fusion

In [82]:
# Fuse the linear layer with the ReLU that follows it. Modules are addressed
# by attribute name, and a fused copy is returned (inplace=False).
fused_model = torch.quantization.fuse_modules(
    model,
    modules_to_fuse=[['fc', 'relu']],
    inplace=False,
)

#### STEP 2. Set up the quantization config

In [83]:
# Attach the default 'fbgemm' quantization config to the root module.
# In the outputs shown later this yields HistogramObserver on activations
# and a per-channel-affine quantized `fc`.
fused_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

In [84]:
# Inspect the fused model: in the output shown `fc` is now a LinearReLU
# and the standalone `relu` was replaced by Identity.
fused_model

SampleNet(
  (conv): Conv2d(112, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (fc): LinearReLU(
    (0): Linear(in_features=3, out_features=2, bias=False)
    (1): ReLU()
  )
  (relu): Identity()
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

#### STEP 3. Insert observers

In [85]:
# Insert observers according to the attached qconfig; `fused_model` is left
# as-is and an instrumented copy is returned.
fused_model_with_observer = torch.quantization.prepare(
    fused_model,
    inplace=False,
)

In [86]:
# Inspect where observers were attached (conv, fc and quant in the output shown).
fused_model_with_observer

SampleNet(
  (conv): Conv2d(
    112, 112, kernel_size=(1, 1), stride=(1, 1), bias=False
    (activation_post_process): HistogramObserver()
  )
  (fc): LinearReLU(
    (0): Linear(in_features=3, out_features=2, bias=False)
    (1): ReLU()
    (activation_post_process): HistogramObserver()
  )
  (relu): Identity()
  (quant): QuantStub(
    (activation_post_process): HistogramObserver()
  )
  (dequant): DeQuantStub()
)

In [87]:
# Observed minimum on the fused fc output before calibration; it is still
# +inf in the output shown, as no batches have been run through it yet.
fused_model_with_observer.fc.activation_post_process.min_val

tensor(inf)

In [88]:
# Observed maximum on the fused fc output before calibration; it is still
# -inf in the output shown, as no batches have been run through it yet.
fused_model_with_observer.fc.activation_post_process.max_val

tensor(-inf)

#### STEP 4. calibration

In [89]:
# Calibration: feed random batches through the observed model so each
# observer records the activation range it sees.
# A dedicated, seeded generator makes the calibration data (and hence the
# resulting scale / zero_point values) repeatable without touching the
# global torch RNG.
calib_rng = torch.Generator().manual_seed(100)
# Only the observers' statistics are needed, so skip autograd bookkeeping.
with torch.no_grad():
    for _ in range(100):
        inputs = torch.rand(b, w, h, c, generator=calib_rng)
        fused_model_with_observer(inputs)

In [90]:
# Same observer after calibration: the recorded minimum is 0 in the output
# shown (the observer sits on the fused Linear+ReLU output).
fused_model_with_observer.fc.activation_post_process.min_val

tensor(0.)

In [91]:
# Same observer after calibration: the largest fused fc output recorded
# across the calibration batches.
fused_model_with_observer.fc.activation_post_process.max_val

tensor(1.1778)

#### STEP 5. convert

In [92]:
# Swap the observed float modules for quantized ones using the calibrated
# statistics; a converted copy is returned (inplace=False).
model_int8_static = torch.quantization.convert(
    fused_model_with_observer,
    inplace=False,
)

In [93]:
# Inspect the converted model: in the output shown conv and fc are quantized
# modules carrying a scale / zero_point, and the stubs became
# Quantize / DeQuantize.
model_int8_static

SampleNet(
  (conv): QuantizedConv2d(112, 112, kernel_size=(1, 1), stride=(1, 1), scale=0.020792240276932716, zero_point=68, bias=False)
  (fc): QuantizedLinearReLU(in_features=3, out_features=2, scale=0.008626477792859077, zero_point=0, qscheme=torch.per_channel_affine)
  (relu): Identity()
  (quant): Quantize(scale=tensor([0.0081]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
)

In [94]:
# Re-seed numpy with the value used in the dynamic section so the same
# input values are drawn for the static model.
np.random.seed(100)
# Uniform [0, 1) samples of shape (b, w, h, c), converted from float64 to float32.
x = torch.from_numpy(np.random.random_sample((b, w, h, c))).to(torch.float32)

In [95]:
# End-to-end forward pass through the statically quantized model;
# the bare `o` on the last line displays the (dequantized, float) result.
o = model_int8_static(x)
o

tensor([[[[0.0604, 0.0000],
          [0.0000, 0.1208],
          [0.0000, 0.0000],
          ...,
          [0.1639, 0.0259],
          [0.0431, 0.0000],
          [0.0000, 0.0000]],

         [[0.1812, 0.0000],
          [0.0431, 0.0776],
          [0.1553, 0.0000],
          ...,
          [0.1208, 0.0000],
          [0.0000, 0.0776],
          [0.1380, 0.0000]],

         [[0.0000, 0.0000],
          [0.0604, 0.1121],
          [0.0000, 0.0949],
          ...,
          [0.0000, 0.0000],
          [0.0000, 0.0000],
          [0.0000, 0.1121]],

         ...,

         [[0.0518, 0.1035],
          [0.0086, 0.0000],
          [0.0000, 0.0000],
          ...,
          [0.0086, 0.0604],
          [0.2070, 0.0259],
          [0.1639, 0.1553]],

         [[0.0000, 0.0086],
          [0.0000, 0.0000],
          [0.0000, 0.0259],
          ...,
          [0.0000, 0.0000],
          [0.0000, 0.1121],
          [0.0000, 0.0000]],

         [[0.6470, 0.0000],
          [0.2502, 0.0863],
    

In [96]:
# Step through the static model by hand in the same order as forward():
# quant -> conv -> fc -> relu -> dequant, keeping every intermediate tensor.
o1_static = model_int8_static.quant(x)  # float input -> quantized tensor
o2_static = model_int8_static.conv(o1_static)  # quantized conv
o3_static = model_int8_static.fc(o2_static)  # fused quantized linear + ReLU
o4_static = model_int8_static.relu(o3_static)  # `relu` is Identity after fusion
o5_static = model_int8_static.dequant(o4_static)  # back to a float tensor

In [97]:
# Raw integer values stored in the quantized input tensor.
o1_static.int_repr()

tensor([[[[ 67,  34,  52],
          [104,   1,  15],
          [ 83, 102,  17],
          ...,
          [108,  14,  21],
          [ 58,  96, 105],
          [ 26,   9,  97]],

         [[ 67,  97, 113],
          [ 59,  57,  74],
          [ 74,  62,  38],
          ...,
          [123,  31,  91],
          [115,  72, 109],
          [ 85,  43, 101]],

         [[ 29, 114, 113],
          [ 78,  63,  37],
          [ 56,  88,  36],
          ...,
          [ 74,   3,  47],
          [117,  27,  38],
          [101,  71,   3]],

         ...,

         [[ 38,  53,  18],
          [  5,   5, 111],
          [ 64,  72,  82],
          ...,
          [ 65,  98,  19],
          [ 31,  83,  20],
          [122,  20, 117]],

         [[114,  84,  27],
          [ 82,  26,  87],
          [ 35,  31,  13],
          ...,
          [ 90,  95,  98],
          [ 91,  48, 107],
          [  8, 120,  10]],

         [[ 73,  52, 123],
          [ 26,  26,  89],
          [ 95,  44,  25],
         

In [98]:
# Raw integer values stored in the quantized conv output.
o2_static.int_repr()

tensor([[[[ 70,  60,  65],
          [ 61,  73,  64],
          [ 69,  76,  86],
          ...,
          [ 53,  69,  75],
          [ 67,  70,  85],
          [ 73,  75,  77]],

         [[ 62,  57,  68],
          [ 63,  67,  61],
          [ 74,  50,  73],
          ...,
          [ 68,  53,  52],
          [ 78,  76,  63],
          [ 68,  58,  78]],

         [[ 75,  72,  74],
          [ 56,  74,  69],
          [ 80,  70,  53],
          ...,
          [ 89,  64,  65],
          [ 77,  66,  71],
          [ 62,  80,  72]],

         ...,

         [[ 62,  66,  58],
          [ 72,  61,  58],
          [ 82,  64,  74],
          ...,
          [ 66,  68,  63],
          [ 61,  51,  51],
          [ 47,  72,  67]],

         [[ 80,  84,  78],
          [ 85,  81, 100],
          [ 73,  89,  85],
          ...,
          [ 95,  81,  83],
          [ 73,  86,  73],
          [ 85,  80,  88]],

         [[ 43,  27,  48],
          [ 57,  49,  44],
          [ 51,  35,  57],
         

In [99]:
# Raw integer values after the fused linear + ReLU (uint8 in the output shown).
o3_static.int_repr()

tensor([[[[ 7,  0],
          [ 0, 14],
          [ 0,  0],
          ...,
          [19,  3],
          [ 5,  0],
          [ 0,  0]],

         [[21,  0],
          [ 5,  9],
          [18,  0],
          ...,
          [14,  0],
          [ 0,  9],
          [16,  0]],

         [[ 0,  0],
          [ 7, 13],
          [ 0, 11],
          ...,
          [ 0,  0],
          [ 0,  0],
          [ 0, 13]],

         ...,

         [[ 6, 12],
          [ 1,  0],
          [ 0,  0],
          ...,
          [ 1,  7],
          [24,  3],
          [19, 18]],

         [[ 0,  1],
          [ 0,  0],
          [ 0,  3],
          ...,
          [ 0,  0],
          [ 0, 13],
          [ 0,  0]],

         [[75,  0],
          [29, 10],
          [59,  0],
          ...,
          [37, 28],
          [53,  5],
          [42, 32]]]], dtype=torch.uint8)

In [100]:
# `relu` is Identity after fusion, so the integers displayed here are the
# same as those displayed for o3_static.
o4_static.int_repr()

tensor([[[[ 7,  0],
          [ 0, 14],
          [ 0,  0],
          ...,
          [19,  3],
          [ 5,  0],
          [ 0,  0]],

         [[21,  0],
          [ 5,  9],
          [18,  0],
          ...,
          [14,  0],
          [ 0,  9],
          [16,  0]],

         [[ 0,  0],
          [ 7, 13],
          [ 0, 11],
          ...,
          [ 0,  0],
          [ 0,  0],
          [ 0, 13]],

         ...,

         [[ 6, 12],
          [ 1,  0],
          [ 0,  0],
          ...,
          [ 1,  7],
          [24,  3],
          [19, 18]],

         [[ 0,  1],
          [ 0,  0],
          [ 0,  3],
          ...,
          [ 0,  0],
          [ 0, 13],
          [ 0,  0]],

         [[75,  0],
          [29, 10],
          [59,  0],
          ...,
          [37, 28],
          [53,  5],
          [42, 32]]]], dtype=torch.uint8)

In [101]:
# Dequantized float result of the manual pass; the displayed values match
# those printed for `o` from the end-to-end call.
o5_static

tensor([[[[0.0604, 0.0000],
          [0.0000, 0.1208],
          [0.0000, 0.0000],
          ...,
          [0.1639, 0.0259],
          [0.0431, 0.0000],
          [0.0000, 0.0000]],

         [[0.1812, 0.0000],
          [0.0431, 0.0776],
          [0.1553, 0.0000],
          ...,
          [0.1208, 0.0000],
          [0.0000, 0.0776],
          [0.1380, 0.0000]],

         [[0.0000, 0.0000],
          [0.0604, 0.1121],
          [0.0000, 0.0949],
          ...,
          [0.0000, 0.0000],
          [0.0000, 0.0000],
          [0.0000, 0.1121]],

         ...,

         [[0.0518, 0.1035],
          [0.0086, 0.0000],
          [0.0000, 0.0000],
          ...,
          [0.0086, 0.0604],
          [0.2070, 0.0259],
          [0.1639, 0.1553]],

         [[0.0000, 0.0086],
          [0.0000, 0.0000],
          [0.0000, 0.0259],
          ...,
          [0.0000, 0.0000],
          [0.0000, 0.1121],
          [0.0000, 0.0000]],

         [[0.6470, 0.0000],
          [0.2502, 0.0863],
    

In [102]:
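# Disabled sanity check: re-quantize the conv output by hand.
# NOTE: the scale / zero_point below differ from the values displayed for
# QuantizedConv2d in this notebook (scale=0.020792240276932716, zero_point=68),
# so they presumably come from an earlier run - update them before re-enabling.
# torch.quantize_per_tensor(o2_static.data.cpu(), scale=0.02094164490699768, zero_point=61, dtype=torch.quint8).int_repr()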