<a href="https://colab.research.google.com/github/Jeremy26/neural_optimization_course/blob/main/static_quant_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports and useful variables

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from tqdm import tqdm
from copy import deepcopy

In [None]:
# Device the model copy is moved to before it is quantized.
cpu_device = torch.device("cpu")
# One random batch laid out as (batch=1, channels=3, height=9, width=9);
# the 3 channels match the first Conv2d's in_channels.
randomInput = torch.rand((1, 3, 9, 9))

## Model definition

In [None]:
class demoModule(nn.Module):
    """Small CNN used as the float model for the quantization walkthrough.

    Layout: two Conv2d -> BatchNorm2d -> ReLU stages (``fe`` then ``clf``),
    adaptive average pooling down to 1x1, and a linear layer producing 10
    outputs per sample.
    """

    def __init__(self) -> None:
        super().__init__()

        # feature extractor: 3x3 conv, 3 -> 2 channels
        fe_layers = [
            nn.Conv2d(3, 2, kernel_size=3),
            nn.BatchNorm2d(num_features=2),
            nn.ReLU(inplace=True),
        ]
        self.fe = nn.Sequential(*fe_layers)

        # classifier: 1x1 conv, 2 -> 4 channels
        clf_layers = [
            nn.Conv2d(2, 4, kernel_size=1),
            nn.BatchNorm2d(num_features=4),
            nn.ReLU(inplace=True),
        ]
        self.clf = nn.Sequential(*clf_layers)

        # Collapse the spatial dimensions to 1x1, then map 4 features to 10.
        self.avgPool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(in_features=4, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through fe, clf, pooling and the linear layer.

        Args:
            x: input tensor with 3 channels in dimension 1.

        Returns:
            Tensor with one row of 10 values per input sample.
        """
        pooled = self.avgPool(self.clf(self.fe(x)))
        # Keep the batch dimension and flatten everything after it.
        flat = torch.flatten(pooled, start_dim=1)
        return self.fc(flat)

In [None]:
# Smoke test: build the float model, run it once on the random batch and
# report the shape of the result.
fp32_model = demoModule()
out = fp32_model(randomInput)
print(out.size())

torch.Size([1, 10])


## Static Quantization Steps

### Make a copy, move to CPU, set to inference mode

In [None]:
# Work on an independent copy so fp32_model itself is left unmodified; the
# copy is placed on the CPU and switched to inference mode in one chain
# (both .to() and .eval() return the module).
model_to_quantize = deepcopy(fp32_model).to(cpu_device).eval()

### Fuse modules

In [None]:
# Print the module tree to read off the submodule names (e.g. 'fe.0')
# that the fusion step refers to.
print(model_to_quantize)

demoModule(
  (fe): Sequential(
    (0): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (clf): Sequential(
    (0): Conv2d(2, 4, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (avgPool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=4, out_features=10, bias=True)
)


In [None]:
# Each stage stores its Conv2d, BatchNorm2d and ReLU at indices 0, 1, 2, so
# the fusion groups are ['fe.0', 'fe.1', 'fe.2'] and
# ['clf.0', 'clf.1', 'clf.2'].
modules_to_fuse = [
    [f'{stage}.{idx}' for idx in range(3)]
    for stage in ('fe', 'clf')
]

# inplace=True fuses model_to_quantize itself rather than a copy.
fused_model = torch.quantization.fuse_modules(
    model_to_quantize,
    modules_to_fuse,
    inplace=True,
)
print(fused_model)

demoModule(
  (fe): Sequential(
    (0): ConvReLU2d(
      (0): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU(inplace=True)
    )
    (1): Identity()
    (2): Identity()
  )
  (clf): Sequential(
    (0): ConvReLU2d(
      (0): Conv2d(2, 4, kernel_size=(1, 1), stride=(1, 1))
      (1): ReLU(inplace=True)
    )
    (1): Identity()
    (2): Identity()
  )
  (avgPool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=4, out_features=10, bias=True)
)


### Create stubs for model input and output

In [None]:
class quantStubModel(nn.Module):
    """Wrap a float model between a QuantStub and a DeQuantStub.

    Args:
        model_fp32: the module to run between the two stubs.
    """

    def __init__(self, model_fp32):
        super().__init__()
        # Registration order (quant, dequant, model_fp32) determines the
        # order in which the children appear when the module is printed.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.model_fp32 = model_fp32

    def forward(self, x):
        """Return ``dequant(model_fp32(quant(x)))``."""
        stub_in = self.quant(x)
        model_out = self.model_fp32(stub_in)
        return self.dequant(model_out)

# Wrap the fused model in an nn.Module that adds the quant/dequant stubs
# around its input and output.
quant_stubbed_model = quantStubModel(fused_model)
print(quant_stubbed_model)

quantStubModel(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (model_fp32): demoModule(
    (fe): Sequential(
      (0): ConvReLU2d(
        (0): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU(inplace=True)
      )
      (1): Identity()
      (2): Identity()
    )
    (clf): Sequential(
      (0): ConvReLU2d(
        (0): Conv2d(2, 4, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
      )
      (1): Identity()
      (2): Identity()
    )
    (avgPool): AdaptiveAvgPool2d(output_size=(1, 1))
    (fc): Linear(in_features=4, out_features=10, bias=True)
  )
)


### Quantization config & quantization.prepare() function

In [None]:
# colab requires fbgemm backend
use_fbgemm = True

# Choose the backend name once so the qconfig and the quantized engine
# cannot drift apart.
if use_fbgemm:
    # for fbgemm, histogram observer is default config
    backend = 'fbgemm'
    quantization_config = torch.quantization.get_default_qconfig(backend)
else:
    # default is minmax observer
    backend = 'qnnpack'
    quantization_config = torch.quantization.default_qconfig

# Select the quantized kernel engine that matches the chosen backend.
torch.backends.quantized.engine = backend

# set the quantization configuration for the model
print('### Preparing for quantization, inserting observers ...')
quant_stubbed_model.qconfig = quantization_config
# inplace=True inserts the observers into quant_stubbed_model itself; the
# trailing ';' suppresses the notebook echo of the returned module.
torch.quantization.prepare(quant_stubbed_model, inplace=True);

### Preparing for quantization, inserting observers ...


  reduce_range will be deprecated in a future release of PyTorch."


In [None]:
# Display the QConfig selected above (its activation and weight observers).
quantization_config

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})

In [None]:
# For comparison, display the generic default QConfig.
torch.quantization.default_qconfig

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})

### Calibrate observer parameters on a sample dataset

In [None]:
# Calibration: forward passes let the inserted observers record activation
# statistics. Each step draws a fresh random batch with the same shape as
# randomInput; replaying one identical tensor would give the observers no
# new information after the first pass.
# In real use, replace the random batches with representative samples.
num_calibration_batches = 5
with torch.no_grad():
    for _ in range(num_calibration_batches):
        quant_stubbed_model(torch.rand_like(randomInput))

### Call quantization.convert()

In [None]:
# Convert the calibrated model: observed modules are replaced by their
# quantized counterparts (see the printed tree).
# inplace=True converts quant_stubbed_model itself instead of a copy.
quantized_model = torch.quantization.convert(quant_stubbed_model, inplace=True)
print(quantized_model)

  src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1
  src_bin_end // dst_bin_width, 0, self.dst_nbins - 1


quantStubModel(
  (quant): Quantize(scale=tensor([0.0078]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
  (model_fp32): demoModule(
    (fe): Sequential(
      (0): QuantizedConvReLU2d(3, 2, kernel_size=(3, 3), stride=(1, 1), scale=0.0052048638463020325, zero_point=0)
      (1): Identity()
      (2): Identity()
    )
    (clf): Sequential(
      (0): QuantizedConvReLU2d(2, 4, kernel_size=(1, 1), stride=(1, 1), scale=0.007241956889629364, zero_point=0)
      (1): Identity()
      (2): Identity()
    )
    (avgPool): AdaptiveAvgPool2d(output_size=(1, 1))
    (fc): QuantizedLinear(in_features=4, out_features=10, scale=0.010852521285414696, zero_point=74, qscheme=torch.per_channel_affine)
  )
)
