<a href="https://colab.research.google.com/github/jevliu/2022-Machine-Learning-Specialization/blob/main/quantization_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np

In [2]:
from os import terminal_size
# suppres scitntific notation
np.set_printoptions(suppress=True)

In [5]:
# generate random distributed parameters
params = np.random.uniform(low=-50, high=150, size=20)

# make sure importment values are at the begining for better debugging
params[0] = params.max() + 1
params[1] = params.min() - 1
params[2] = 0

# round each number to the second decimal place
params = np.round(params, 2)

# print the parameters
print(params)

[149.28 -17.66   0.    61.84  -0.25 134.6  148.28  68.34  74.51  51.24
  43.13 140.25 142.08 114.53 -16.66  81.33 135.17  -8.6  112.6  103.86]


In [21]:
# define several function for quantization and dequantize according to the
# mathmetical formular
def clamp(param_q:np.array, lower_bound:int, upper_bound:int)->np.array:
  param_q[param_q < lower_bound] = lower_bound
  param_q[param_q > upper_bound] = upper_bound
  return param_q

def asymmetric_quantization(params:np.array,bits:int)->tuple[np.array,float,int]:
  # calulate the scale and zero point
  alpha = np.max(params)
  beta = np.min(params)
  scale = (alpha-beta) / (2**bits-1)
  zero = -1*np.round(beta/scale)
  # unsigned integer
  lower_bound, upper_bound = 0, 2**bits-1
  # quantize the parameters
  quantized = clamp(np.round(params/scale+zero),lower_bound,upper_bound).astype(np.int32)
  return quantized,scale,zero

def symmetric_quantization(params:np.array,bits:int)->tuple[np.array,float]:
  # calculate the scale
  alpha = np.max(np.abs(params))
  scale = alpha / (2**(bits-1)-1)
  lower_bound,upper_bound = -1*(2**(bits-1)),2**(bits-1)-1
  quantized = clamp(np.round(params/scale),lower_bound,upper_bound).astype(np.int32)
  return quantized,scale

def asymmetric_dequantize(params:np.array,scale:float,zero:int)->np.array:
  return scale * (params-zero)

def symmetric_dequantize(params:np.array,scale:float)->np.array:
  return scale * params

def quantization_error(params:np.array, params_q:np.array):
  # calculate the MSE
  return np.mean((params-params_q)**2)


In [22]:
(asymmetric_q, asymmetric_s, asymmetric_z) = asymmetric_quantization(params, 8)
(symmetric_q, symmetric_s) = symmetric_quantization(params, 8)
as_deq_params = asymmetric_dequantize(asymmetric_q, asymmetric_s, asymmetric_z)
sy_deq_params = symmetric_dequantize(symmetric_q, symmetric_s)

print('original parameters:\n',np.round(params,2))
print('parameters after asymmetric quantitation:\n',np.round(asymmetric_q))
print(f'asymmetric_scale: {np.round(asymmetric_s,2)}, asymmetric_zero: {asymmetric_z.round(2)}')
print('parameters after symmetric quantitation:\n',np.round(symmetric_q))
print(f'symmetric_scale: {symmetric_s.round(2)}')
print(f'quantitation error with asymmetric: {quantization_error(params,as_deq_params).round(2)}')
print(f'quantitation error with symmetric: {quantization_error(params,sy_deq_params).round(2)}')

original parameters:
 [149.28 -17.66   0.    61.84  -0.25 134.6  148.28  68.34  74.51  51.24
  43.13 140.25 142.08 114.53 -16.66  81.33 135.17  -8.6  112.6  103.86]
parameters after asymmetric quantitation:
 [255   0  27 121  27 233 253 131 141 105  93 241 244 202   2 151 233  14
 199 186]
asymmetric_scale: 0.65, asymmetric_zero: 27.0
parameters after symmetric quantitation:
 [127 -15   0  53   0 115 126  58  63  44  37 119 121  97 -14  69 115  -7
  96  88]
symmetric_scale: 1.18
quantitation error with asymmetric: 0.04
quantitation error with symmetric: 0.11


## Quantization range:how to choose alpha&beta

### Quantization strategy

**Min-Max:** sensitive to outlier numbers

**Percntile:** only the outlier has big error

**Mean-Square-Error:**It is usually solved using Grid-Search

**Cross-Entropy:**used when the values in the tensor being quantized are not equally importan.
to keep the order in the softmax layer

### Quantization granularity

## Post Tranining Quantization (PTQ)

### PTQ process:

pre-trained model --> attatch observers(calculate the s and z parameter using the observed data) --> calibrate --> quantized model

In [36]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

#### Load the MNIST dataset

In [26]:
# make torch deterministic
_ = torch.manual_seed(0)

In [30]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])

# load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data',train=True,download=True,transform=transform)
# create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset,batch_size=10,shuffle=True)

# load the MNIST test dataset
mnist_testset = datasets.MNIST(root='./data',train=False,download=True,transform=transform)
# create a dataloader for the testing
test_loader = torch.utils.data.DataLoader(mnist_testset,batch_size=10,shuffle=True)

# define the device
device = 'cpu'

### Define the model

In [41]:
class VerySimpleNet(nn.Module):
  def __init__(self,hidden_size_1=100,hidden_size_2=100):
    super(VerySimpleNet,self).__init__()
    self.linear1 = nn.Linear(28*28,hidden_size_1)
    self.linear2= nn.Linear(hidden_size_1,hidden_size_2)
    self.linear3 = nn.Linear(hidden_size_2,10)
    self.relu = nn.ReLU()

  def forward(self,img):
    x = img.view(-1,28*28)
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.linear3(x)
    return x

In [42]:
net = VerySimpleNet().to(device)

### Train the model

In [45]:
def train(train_loader,net,epochs=5,total_iterations_limit=None):
  cross_el = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

  total_iterations = 0

  for epoch in range(epochs):
    net.train()

    loss_sum = 0
    num_iterations = 0

    data_iterator = tqdm(train_loader,desc=f'Epoch {epoch+1}')
    for data in data_iterator:
      num_iterations += 1
      total_iterations += 1
      x, y = data
      x, y = x.to(device), y.to(device)
      optimizer.zero_grad()
      output = net(x.view(-1,28*28))
      loss = cross_el(output,y)
      loss_sum += loss
      avg_loss = loss_sum / num_iterations
      data_iterator.set_postfix(loss=avg_loss)
      loss.backward()
      optimizer.step()

def print_size_of_model(model):
  torch.save(model.state_dict(),"temp_delme.p")
  print("Size (KB):", os.path.getsize("temp_delme.p")/1e3)
  os.remove('temp_delme.p')

MODEL_FILENAME = 'simplenet_ptq_pt'

if Path(MODEL_FILENAME).exists():
  net.load_state_dict(torch.load(MODEL_FILENAME))
  print('Loaded model from disk')
else:
  train(train_loader,net,epochs=1)
  # save the model tp disk
  torch.save(net.state_dict(), MODEL_FILENAME)


Loaded model from disk


### Define the testing loop

In [46]:
def test(model:nn.Module,total_iterations:int=None):
  correct = 0
  total = 0

  iterations = 0

  model.eval()

  with torch.no_grad():
    for data in tqdm(test_loader,desc='Testing'):
      x, y = data
      x, y = x.to(device), y.to(device)
      output = model(x.view(-1,784))
      for idx,i in enumerate(output):
        if torch.argmax(i) == y[idx]:
          correct += 1
        total += 1
      iterations += 1
      if total_iterations is not None and iterations >= total_iterations:
        break
  print(f'Accuracy: {round(correct/total, 3)}')

### Print weights and size of the model before quantization

In [48]:
print('Weights before quantization')
print(net.linear1.weight)
print(net.linear1.weight.dtype)

Weights before quantization
Parameter containing:
tensor([[ 0.0476,  0.0066,  0.0505,  ...,  0.0176,  0.0215,  0.0393],
        [ 0.0074, -0.0103, -0.0235,  ...,  0.0176,  0.0246, -0.0121],
        [ 0.0327,  0.0203,  0.0466,  ...,  0.0046,  0.0356,  0.0435],
        ...,
        [-0.0280, -0.0063, -0.0346,  ..., -0.0222, -0.0336, -0.0100],
        [ 0.0332, -0.0112, -0.0090,  ..., -0.0258,  0.0404, -0.0072],
        [ 0.0512,  0.0453,  0.0502,  ..., -0.0042,  0.0484,  0.0278]],
       requires_grad=True)
torch.float32


In [49]:
print('Size of the model before quantization')
print_size_of_model(net)

Size of the model before quantization
Size (KB): 360.998


In [50]:
print(f'Accuracy of the model before quantization')
test(net)

Accuracy of the model before quantization


Testing: 100%|██████████| 1000/1000 [00:03<00:00, 307.16it/s]

Accuracy: 0.95





### Insert Min-Max observers in the model

In [51]:
class QuantizedVerySimpleNet(nn.Module):
  def __init__(self,hidden_size_1=100,hidden_size_2=100):
    super(QuantizedVerySimpleNet,self).__init__()
    self.quant = torch.quantization.QuantStub()
    self.linear1 = nn.Linear(28*28,hidden_size_1)
    self.linear2= nn.Linear(hidden_size_1,hidden_size_2)
    self.linear3 = nn.Linear(hidden_size_2,10)
    self.relu = nn.ReLU()
    self.dequant = torch.quantization.DeQuantStub()

  def forward(self,img):
    x = img.view(-1,28*28)
    x = self.quant(x)
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.linear3(x)
    x = self.dequant(x)

    return x

In [52]:
net_quantized = QuantizedVerySimpleNet().to(device)
# copy weights from unquantized model
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()

net_quantized.qconfig = torch.ao.quantization.default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized) # insert observers
net_quantized


QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

### Calibrate the model using the test set

In [53]:
test(net_quantized)

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 272.06it/s]

Accuracy: 0.95





In [54]:
print(f'Check statistics of the various layers')
net_quantized

Check statistics of the various layers


QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-46.66990280151367, max_val=30.96890640258789)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-24.895591735839844, max_val=23.977750778198242)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-30.03553581237793, max_val=20.859119415283203)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

### Quantize the model using the statistics collected

In [55]:
net_quantized = torch.ao.quantization.convert(net_quantized)

In [56]:
print(f'Check statistics of the various layers')
net_quantized

Check statistics of the various layers


QuantizedVerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.611329197883606, zero_point=76, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.38482949137687683, zero_point=65, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.40074530243873596, zero_point=75, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

### Print the weights matrix of the model after quantizaion

In [57]:
print('Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))

Weights after quantization
tensor([[12,  2, 13,  ...,  4,  5, 10],
        [ 2, -3, -6,  ...,  4,  6, -3],
        [ 8,  5, 12,  ...,  1,  9, 11],
        ...,
        [-7, -2, -9,  ..., -6, -8, -2],
        [ 8, -3, -2,  ..., -6, 10, -2],
        [13, 11, 12,  ..., -1, 12,  7]], dtype=torch.int8)


### Compare the dequantized weights and the original weights

In [58]:
print('Original weights')
print(net.linear1.weight)
print('')
print('Dequantized weights')
print(torch.dequantize(net_quantized.linear1.weight()))

Original weights
Parameter containing:
tensor([[ 0.0476,  0.0066,  0.0505,  ...,  0.0176,  0.0215,  0.0393],
        [ 0.0074, -0.0103, -0.0235,  ...,  0.0176,  0.0246, -0.0121],
        [ 0.0327,  0.0203,  0.0466,  ...,  0.0046,  0.0356,  0.0435],
        ...,
        [-0.0280, -0.0063, -0.0346,  ..., -0.0222, -0.0336, -0.0100],
        [ 0.0332, -0.0112, -0.0090,  ..., -0.0258,  0.0404, -0.0072],
        [ 0.0512,  0.0453,  0.0502,  ..., -0.0042,  0.0484,  0.0278]],
       requires_grad=True)

Dequantized weights
tensor([[ 0.0483,  0.0081,  0.0523,  ...,  0.0161,  0.0201,  0.0403],
        [ 0.0081, -0.0121, -0.0242,  ...,  0.0161,  0.0242, -0.0121],
        [ 0.0322,  0.0201,  0.0483,  ...,  0.0040,  0.0362,  0.0443],
        ...,
        [-0.0282, -0.0081, -0.0362,  ..., -0.0242, -0.0322, -0.0081],
        [ 0.0322, -0.0121, -0.0081,  ..., -0.0242,  0.0403, -0.0081],
        [ 0.0523,  0.0443,  0.0483,  ..., -0.0040,  0.0483,  0.0282]])


### Print the size and accuracy of the quantized model

In [59]:
print('Size of the model after quantization')
print_size_of_model(net_quantized)

Size of the model after quantization
Size (KB): 95.394


In [61]:
print('Testing the model after quantization')
test(net_quantized)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:03<00:00, 325.17it/s]

Accuracy: 0.947





## Quantization Aware Training (QAT)

We insert some fake modules in the computational graph of the model to simulate the effect of the quantization during training

This way,the loss function gets used to update the weights that constantly suffer from the effect of quantizaion,and it usually leads to a more robust model

In [64]:
class VerySimpleNet2(nn.Module):
  def __init__(self,hidden_size_1=100,hidden_size_2=100):
    super(VerySimpleNet2,self).__init__()
    self.quant = torch.quantization.QuantStub()
    self.linear1 = nn.Linear(28*28,hidden_size_1)
    self.linear2= nn.Linear(hidden_size_1,hidden_size_2)
    self.linear3 = nn.Linear(hidden_size_2,10)
    self.relu = nn.ReLU()
    self.dequant = torch.quantization.DeQuantStub()

  def forward(self,img):
    x = img.view(-1,28*28)
    x = self.quant(x)
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.linear3(x)
    x = self.dequant(x)

    return x

### Insert min-max observers in the model

In [65]:
net2 = VerySimpleNet2().to(device)
net2.qconfig = torch.ao.quantization.default_qconfig
net2.train()
net_quantized_qwt = torch.ao.quantization.prepare_qat(net2) # insert observers
net_quantized_qwt

VerySimpleNet2(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

### Train the model

In [66]:
train(train_loader,net_quantized_qwt,epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:56<00:00, 106.18it/s, loss=tensor(0.2204, grad_fn=<DivBackward0>)]


### Check the statistics collected during training

In [67]:
print(f'Check statistics of the various layers')
net_quantized_qwt

Check statistics of the various layers


VerySimpleNet2(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.559889018535614, max_val=0.3510452210903168)
    (activation_post_process): MinMaxObserver(min_val=-49.90431213378906, max_val=32.529666900634766)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.3996035158634186, max_val=0.4016411602497101)
    (activation_post_process): MinMaxObserver(min_val=-28.026037216186523, max_val=19.46849822998047)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.5091035962104797, max_val=0.22812820971012115)
    (activation_post_process): MinMaxObserver(min_val=-33.177005767822266, max_val=21.5955867767334)
  )
  (relu): ReLU()
  (dequant): DeQuantSt

### Quantize the model using the statistics collected

In [68]:
net_quantized_qwt.eval()
net_quantized_qwt = torch.ao.quantization.convert(net_quantized_qwt)

In [70]:
print(f'Check statistics of the various layers')
net_quantized_qwt

Check statistics of the various layers


VerySimpleNet2(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.6490864157676697, zero_point=77, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.37397274374961853, zero_point=75, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.431280255317688, zero_point=77, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

### Print weights and size of the model after quantization

In [71]:
print('Weights after quantization')
print(torch.int_repr(net_quantized_qwt.linear1.weight()))

Weights after quantization
tensor([[-2, -4,  6,  ..., -3, -8,  0],
        [ 1,  4,  1,  ..., -1,  0, -3],
        [ 8, 10,  1,  ...,  5,  1,  8],
        ...,
        [-2,  6, -2,  ...,  6, -1,  3],
        [10,  0, 10,  ...,  3,  9,  4],
        [ 9,  9, -6,  ..., -5,  9, -5]], dtype=torch.int8)


In [72]:
print('Testing the model after quantization')
test(net_quantized_qwt)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:03<00:00, 315.34it/s]

Accuracy: 0.961



