In [1]:
# References on static quantization:
# https://rachitsingh.com/deep-learning-model-compression/
# https://chowdera.com/2021/02/20210203170434983m.html
# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/

In [2]:
import torch
import numpy as np
from torch import nn
from torch.quantization import QuantStub, DeQuantStub

In [3]:
class SampleNet(nn.Module):
    """Toy conv -> linear -> ReLU network used to compare quantization modes.

    The module always owns a ``QuantStub``/``DeQuantStub`` pair, but they are
    only applied in ``forward`` when ``quantize_statically`` is true.

    Args:
        quantize_statically: When true, ``forward`` passes the input through
            ``self.quant`` first and the result through ``self.dequant`` last.
        in_channels: Input channels of the 1x1 convolution (default 112).
        out_channels: Output channels of the 1x1 convolution (default 112).
        in_features: Size of the last input dimension consumed by ``fc``
            (default 3).
        out_features: Size of the last output dimension produced by ``fc``
            (default 2).
    """

    def __init__(self, quantize_statically=False, in_channels=112, out_channels=112,
                 in_features=3, out_features=2):
        super().__init__()
        self.quantize_statically = quantize_statically
        # Sub-modules are registered in the same order as before so the module
        # repr and the order of random weight initialisation are unchanged.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, groups=1, bias=False)
        self.fc = nn.Linear(in_features, out_features, bias=False)
        self.relu = nn.ReLU(inplace=False)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Apply conv, fc and relu, wrapped in quant/dequant when enabled.

        Args:
            x: Tensor accepted by ``self.conv`` whose last dimension matches
                ``self.fc.in_features``, e.g. ``(N, in_channels, H, in_features)``.

        Returns:
            The activation after ``relu`` (and after ``dequant`` when
            ``quantize_statically`` is true).
        """
        if self.quantize_statically:
            x = self.quant(x)
        x = self.conv(x)
        x = self.fc(x)
        x = self.relu(x)

        if self.quantize_statically:
            x = self.dequant(x)
        return x

In [4]:
# Dimensions of the random input tensors, which are built as (b, w, h, c).
b, c = 1, 3
w, h = 112, 112

## Dynamic Quantization

In [5]:
# Seed torch's global RNG before construction: the layer weights are randomly
# initialised, so without a seed the printed outputs change on every kernel restart.
torch.manual_seed(100)
model = SampleNet(quantize_statically=False)

In [6]:
model

SampleNet(
  (conv): Conv2d(112, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (fc): Linear(in_features=3, out_features=2, bias=False)
  (relu): ReLU()
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

In [7]:
# Dynamic quantization with default settings (no explicit qconfig_spec or
# mapping) and int8 weights, returning a converted copy of `model`.
model_int8_dynamic = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec=None,
    dtype=torch.qint8,
    mapping=None,
    inplace=False,
)

In [8]:
model_int8_dynamic

SampleNet(
  (conv): Conv2d(112, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (fc): DynamicQuantizedLinear(in_features=3, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

In [9]:
# Fixed NumPy seed so the same float32 input of shape (b, w, h, c) is drawn every run.
np.random.seed(100)
x = torch.from_numpy(
    np.random.random(size=(b, w, h, c))
).to(torch.float32)

In [10]:
# Full forward pass through the dynamically quantized model; the result is
# bound to `o` and shown as the cell output.
o = model_int8_dynamic(x)
o

tensor([[[[0.0231, 0.0000],
          [0.0197, 0.0119],
          [0.0000, 0.1756],
          ...,
          [0.0000, 0.0000],
          [0.1062, 0.0000],
          [0.0000, 0.0000]],

         [[0.0000, 0.0675],
          [0.1083, 0.0508],
          [0.0409, 0.0000],
          ...,
          [0.0000, 0.0000],
          [0.0667, 0.0000],
          [0.0000, 0.2212]],

         [[0.0000, 0.0000],
          [0.0000, 0.0029],
          [0.0000, 0.0000],
          ...,
          [0.0000, 0.0000],
          [0.0000, 0.0000],
          [0.0000, 0.0000]],

         ...,

         [[0.0000, 0.0000],
          [0.0000, 0.0000],
          [0.0000, 0.0000],
          ...,
          [0.0000, 0.0000],
          [0.0000, 0.0000],
          [0.0000, 0.0000]],

         [[0.2278, 0.0000],
          [0.0030, 0.0321],
          [0.0000, 0.0170],
          ...,
          [0.0880, 0.0000],
          [0.1663, 0.1369],
          [0.0858, 0.0644]],

         [[0.0000, 0.0000],
          [0.0000, 0.0000],
    

In [11]:
# Replay the forward pass one sub-module at a time (conv -> fc -> relu, the same
# order as SampleNet.forward) so each intermediate activation can be inspected.
o1_dynamic = model_int8_dynamic.conv(x)
o2_dynamic = model_int8_dynamic.fc(o1_dynamic)
o3_dynamic = model_int8_dynamic.relu(o2_dynamic)

In [12]:
o1_dynamic

tensor([[[[-2.7936e-01,  1.3857e-01,  3.3904e-02],
          [-6.0530e-02, -3.7147e-02,  1.1386e-01],
          [ 1.5183e-01, -2.6337e-01, -5.2935e-02],
          ...,
          [-4.2800e-01,  1.5016e-01, -2.2047e-01],
          [-3.1539e-02,  3.9916e-01, -1.0503e-01],
          [-4.6136e-02,  2.1197e-01, -2.8246e-01]],

         [[ 9.8018e-02, -9.7160e-02, -9.3736e-02],
          [ 1.0046e-01,  6.6256e-02,  1.8406e-01],
          [-1.7397e-01,  1.8249e-01, -1.6140e-02],
          ...,
          [-1.4145e-01,  2.7818e-02, -2.2079e-01],
          [-9.5429e-02,  1.8721e-01,  1.9807e-02],
          [ 2.8730e-01, -2.2815e-01, -4.2962e-02]],

         [[-5.8040e-01, -5.6354e-01, -5.6660e-01],
          [-4.5702e-01, -7.6515e-01, -7.3578e-01],
          [-8.6858e-01, -4.6947e-01, -6.8603e-01],
          ...,
          [-5.3551e-01, -4.0491e-01, -8.7254e-01],
          [-1.0092e+00, -1.0788e+00, -4.9155e-01],
          [-6.6138e-01, -7.2263e-01, -5.5642e-01]],

         ...,

         [[-3.74

In [13]:
o2_dynamic

tensor([[[[ 0.0231, -0.1790],
          [ 0.0197,  0.0119],
          [-0.1020,  0.1756],
          ...,
          [-0.0900, -0.2941],
          [ 0.1062, -0.2146],
          [-0.0323, -0.1607]],

         [[-0.0585,  0.0675],
          [ 0.1083,  0.0508],
          [ 0.0409, -0.1639],
          ...,
          [-0.0988, -0.1083],
          [ 0.0667, -0.1233],
          [-0.0600,  0.2212]],

         [[-0.5227, -0.1017],
          [-0.6440,  0.0029],
          [-0.5746, -0.2994],
          ...,
          [-0.5654, -0.2151],
          [-0.7590, -0.0544],
          [-0.5888, -0.0730]],

         ...,

         [[-0.0921, -0.2450],
          [-0.0027, -0.1251],
          [-0.2206, -0.1084],
          ...,
          [-0.1166, -0.0880],
          [-0.2057, -0.0914],
          [-0.1192, -0.0389]],

         [[ 0.2278, -0.0978],
          [ 0.0030,  0.0321],
          [-0.0563,  0.0170],
          ...,
          [ 0.0880, -0.0068],
          [ 0.1663,  0.1369],
          [ 0.0858,  0.0644]],



In [14]:
o3_dynamic

tensor([[[[0.0231, 0.0000],
          [0.0197, 0.0119],
          [0.0000, 0.1756],
          ...,
          [0.0000, 0.0000],
          [0.1062, 0.0000],
          [0.0000, 0.0000]],

         [[0.0000, 0.0675],
          [0.1083, 0.0508],
          [0.0409, 0.0000],
          ...,
          [0.0000, 0.0000],
          [0.0667, 0.0000],
          [0.0000, 0.2212]],

         [[0.0000, 0.0000],
          [0.0000, 0.0029],
          [0.0000, 0.0000],
          ...,
          [0.0000, 0.0000],
          [0.0000, 0.0000],
          [0.0000, 0.0000]],

         ...,

         [[0.0000, 0.0000],
          [0.0000, 0.0000],
          [0.0000, 0.0000],
          ...,
          [0.0000, 0.0000],
          [0.0000, 0.0000],
          [0.0000, 0.0000]],

         [[0.2278, 0.0000],
          [0.0030, 0.0321],
          [0.0000, 0.0170],
          ...,
          [0.0880, 0.0000],
          [0.1663, 0.1369],
          [0.0858, 0.0644]],

         [[0.0000, 0.0000],
          [0.0000, 0.0000],
    

## Static Quantization

In [15]:
# Seed torch's global RNG so the randomly initialised weights are reproducible,
# and switch to eval mode, which the PyTorch post-training static quantization
# workflow expects before fusing and preparing the model.
torch.manual_seed(100)
model = SampleNet(quantize_statically=True).eval()

In [16]:
model

SampleNet(
  (conv): Conv2d(112, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (fc): Linear(in_features=3, out_features=2, bias=False)
  (relu): ReLU()
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

#### STEP 1. layer fusion

In [17]:
# Fuse the Linear 'fc' with the following 'relu' into a single module; work on
# a copy (inplace=False) so `model` itself is left unfused.
fused_model = torch.quantization.fuse_modules(
    model,
    modules_to_fuse=[['fc', 'relu']],
    inplace=False,
)

#### STEP 2. setup config

In [18]:
# Attach the default quantization config for the 'fbgemm' backend to the whole
# model; the prepare step below uses it to decide which observers to insert.
fused_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

In [19]:
fused_model

SampleNet(
  (conv): Conv2d(112, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (fc): LinearReLU(
    (0): Linear(in_features=3, out_features=2, bias=False)
    (1): ReLU()
  )
  (relu): Identity()
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

#### STEP 3. Insert Observer

In [20]:
# Insert observers according to `fused_model.qconfig`; a new model is returned
# and `fused_model` is not modified.
fused_model_with_observer = torch.quantization.prepare(
    fused_model,
    inplace=False,
)



In [21]:
fused_model_with_observer

SampleNet(
  (conv): Conv2d(
    112, 112, kernel_size=(1, 1), stride=(1, 1), bias=False
    (activation_post_process): HistogramObserver()
  )
  (fc): LinearReLU(
    (0): Linear(in_features=3, out_features=2, bias=False)
    (1): ReLU()
    (activation_post_process): HistogramObserver()
  )
  (relu): Identity()
  (quant): QuantStub(
    (activation_post_process): HistogramObserver()
  )
  (dequant): DeQuantStub()
)

In [22]:
fused_model_with_observer.fc.activation_post_process.min_val

tensor(inf)

In [23]:
fused_model_with_observer.fc.activation_post_process.max_val

tensor(-inf)

#### STEP 4. calibration

In [24]:
# Calibration: run representative inputs through the observed model so the
# observers can record activation ranges. Only forward passes are needed, so
# autograd is disabled, and a dedicated seeded generator keeps the recorded
# statistics reproducible without touching the global RNG state.
calibration_rng = torch.Generator().manual_seed(100)
with torch.no_grad():
    for _ in range(100):
        inputs = torch.rand(b, w, h, c, generator=calibration_rng)
        fused_model_with_observer(inputs)

In [25]:
fused_model_with_observer.fc.activation_post_process.min_val

tensor(0.)

In [26]:
fused_model_with_observer.fc.activation_post_process.max_val

tensor(0.6285)

#### STEP 5. convert

In [27]:
# Swap the observed modules for their quantized counterparts using the
# statistics gathered during calibration; a converted copy is returned.
model_int8_static = torch.quantization.convert(
    fused_model_with_observer,
    inplace=False,
)

  src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1
  src_bin_end // dst_bin_width, 0, self.dst_nbins - 1


In [28]:
model_int8_static

SampleNet(
  (conv): QuantizedConv2d(112, 112, kernel_size=(1, 1), stride=(1, 1), scale=0.02259795553982258, zero_point=56, bias=False)
  (fc): QuantizedLinearReLU(in_features=3, out_features=2, scale=0.004586605820804834, zero_point=0, qscheme=torch.per_channel_affine)
  (relu): Identity()
  (quant): Quantize(scale=tensor([0.0084]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
)

In [29]:
# Re-seed NumPy and rebuild the float32 input of shape (b, w, h, c) so this
# section starts from a deterministic tensor.
np.random.seed(100)
x = torch.from_numpy(
    np.random.random(size=(b, w, h, c))
).to(torch.float32)

In [30]:
# Full forward pass through the statically quantized model; the result is
# bound to `o` and shown as the cell output.
o = model_int8_static(x)
o

tensor([[[[0.0596, 0.0000],
          [0.0000, 0.0780],
          [0.1376, 0.0550],
          ...,
          [0.0780, 0.0734],
          [0.0734, 0.0367],
          [0.0183, 0.0000]],

         [[0.2293, 0.1514],
          [0.2431, 0.2706],
          [0.1514, 0.2614],
          ...,
          [0.0826, 0.1238],
          [0.0871, 0.0917],
          [0.1147, 0.0871]],

         [[0.0000, 0.0688],
          [0.1651, 0.1330],
          [0.0000, 0.0183],
          ...,
          [0.0046, 0.1468],
          [0.0413, 0.0321],
          [0.0550, 0.0138]],

         ...,

         [[0.0642, 0.0000],
          [0.1193, 0.0505],
          [0.0917, 0.1514],
          ...,
          [0.1055, 0.0596],
          [0.0963, 0.0092],
          [0.1009, 0.0550]],

         [[0.0963, 0.1514],
          [0.1055, 0.0688],
          [0.1881, 0.1514],
          ...,
          [0.1743, 0.2018],
          [0.2752, 0.1651],
          [0.0550, 0.0000]],

         [[0.0000, 0.0000],
          [0.0000, 0.0000],
    

In [31]:
# Replay the static forward pass one sub-module at a time
# (quant -> conv -> fc -> relu -> dequant) to expose every intermediate tensor.
o1_static = model_int8_static.quant(x)  # float input -> quantized tensor
o2_static = model_int8_static.conv(o1_static)
o3_static = model_int8_static.fc(o2_static)  # fc is the fused Linear+ReLU module
o4_static = model_int8_static.relu(o3_static)  # relu became Identity during fusion
o5_static = model_int8_static.dequant(o4_static)  # quantized tensor -> float tensor

In [32]:
o1_static.int_repr()

tensor([[[[ 65,  33,  51],
          [101,   1,  15],
          [ 80,  99,  16],
          ...,
          [104,  13,  20],
          [ 56,  93, 102],
          [ 25,   9,  94]],

         [[ 65,  94, 110],
          [ 57,  55,  72],
          [ 72,  60,  37],
          ...,
          [119,  30,  88],
          [111,  70, 105],
          [ 83,  42,  98]],

         [[ 28, 110, 110],
          [ 75,  61,  36],
          [ 54,  86,  35],
          ...,
          [ 71,   3,  46],
          [114,  27,  36],
          [ 98,  69,   3]],

         ...,

         [[ 37,  51,  17],
          [  5,   5, 107],
          [ 62,  70,  80],
          ...,
          [ 63,  95,  18],
          [ 30,  80,  19],
          [118,  20, 113]],

         [[111,  82,  26],
          [ 80,  25,  85],
          [ 34,  30,  13],
          ...,
          [ 87,  92,  95],
          [ 88,  46, 104],
          [  8, 117,  10]],

         [[ 71,  50, 119],
          [ 25,  25,  87],
          [ 93,  43,  25],
         

In [33]:
o2_static.int_repr()

tensor([[[[61, 50, 65],
          [69, 68, 58],
          [67, 62, 75],
          ...,
          [60, 63, 66],
          [72, 63, 68],
          [64, 54, 60]],

         [[67, 70, 86],
          [73, 84, 88],
          [82, 88, 78],
          ...,
          [85, 76, 71],
          [76, 70, 70],
          [84, 71, 75]],

         [[65, 67, 49],
          [46, 63, 74],
          [62, 61, 50],
          ...,
          [55, 71, 55],
          [57, 59, 61],
          [46, 53, 61]],

         ...,

         [[67, 58, 66],
          [47, 56, 69],
          [53, 69, 66],
          ...,
          [53, 59, 68],
          [62, 57, 69],
          [60, 61, 69]],

         [[67, 73, 69],
          [73, 66, 72],
          [73, 73, 82],
          ...,
          [81, 81, 81],
          [75, 73, 93],
          [93, 63, 70]],

         [[51, 50, 51],
          [45, 50, 50],
          [54, 36, 53],
          ...,
          [50, 45, 54],
          [61, 46, 48],
          [46, 61, 45]]]], dtype=torch.uint8)

In [34]:
o3_static.int_repr()

tensor([[[[13,  0],
          [ 0, 17],
          [30, 12],
          ...,
          [17, 16],
          [16,  8],
          [ 4,  0]],

         [[50, 33],
          [53, 59],
          [33, 57],
          ...,
          [18, 27],
          [19, 20],
          [25, 19]],

         [[ 0, 15],
          [36, 29],
          [ 0,  4],
          ...,
          [ 1, 32],
          [ 9,  7],
          [12,  3]],

         ...,

         [[14,  0],
          [26, 11],
          [20, 33],
          ...,
          [23, 13],
          [21,  2],
          [22, 12]],

         [[21, 33],
          [23, 15],
          [41, 33],
          ...,
          [38, 44],
          [60, 36],
          [12,  0]],

         [[ 0,  0],
          [ 0,  0],
          [ 0,  0],
          ...,
          [ 0,  0],
          [ 0,  0],
          [ 0, 14]]]], dtype=torch.uint8)

In [35]:
o4_static.int_repr()

tensor([[[[13,  0],
          [ 0, 17],
          [30, 12],
          ...,
          [17, 16],
          [16,  8],
          [ 4,  0]],

         [[50, 33],
          [53, 59],
          [33, 57],
          ...,
          [18, 27],
          [19, 20],
          [25, 19]],

         [[ 0, 15],
          [36, 29],
          [ 0,  4],
          ...,
          [ 1, 32],
          [ 9,  7],
          [12,  3]],

         ...,

         [[14,  0],
          [26, 11],
          [20, 33],
          ...,
          [23, 13],
          [21,  2],
          [22, 12]],

         [[21, 33],
          [23, 15],
          [41, 33],
          ...,
          [38, 44],
          [60, 36],
          [12,  0]],

         [[ 0,  0],
          [ 0,  0],
          [ 0,  0],
          ...,
          [ 0,  0],
          [ 0,  0],
          [ 0, 14]]]], dtype=torch.uint8)

In [36]:
o5_static

tensor([[[[0.0596, 0.0000],
          [0.0000, 0.0780],
          [0.1376, 0.0550],
          ...,
          [0.0780, 0.0734],
          [0.0734, 0.0367],
          [0.0183, 0.0000]],

         [[0.2293, 0.1514],
          [0.2431, 0.2706],
          [0.1514, 0.2614],
          ...,
          [0.0826, 0.1238],
          [0.0871, 0.0917],
          [0.1147, 0.0871]],

         [[0.0000, 0.0688],
          [0.1651, 0.1330],
          [0.0000, 0.0183],
          ...,
          [0.0046, 0.1468],
          [0.0413, 0.0321],
          [0.0550, 0.0138]],

         ...,

         [[0.0642, 0.0000],
          [0.1193, 0.0505],
          [0.0917, 0.1514],
          ...,
          [0.1055, 0.0596],
          [0.0963, 0.0092],
          [0.1009, 0.0550]],

         [[0.0963, 0.1514],
          [0.1055, 0.0688],
          [0.1881, 0.1514],
          ...,
          [0.1743, 0.2018],
          [0.2752, 0.1651],
          [0.0550, 0.0000]],

         [[0.0000, 0.0000],
          [0.0000, 0.0000],
    

In [37]:
# torch.quantize_per_tensor(o2_static.data.cpu(), scale=0.02094164490699768, zero_point=61, dtype=torch.quint8).int_repr()