The purpose of the notebook is to demonstrate quantization of a deep learning model (ResNet in this example). Quantization is a method to reduce the number of bits used to represent each parameter in the model. There are three main purposes of quantization:
1. Reduce Model Size
Memory Efficiency: Instead of storing each parameter as a 32-bit floating-point number (FP32), quantization stores it in a lower-precision format such as 16-bit floating point (FP16) or 8-bit integers (INT8). This leads to significant reductions in the model's memory usage.
Storage Savings: Smaller models require less storage space, which is beneficial for deploying models on devices with limited memory, such as embedded systems.
2. Improve Computational Efficiency
Faster Inference: Operations involving lower-bit integers (e.g., 8-bit integers) are typically faster to execute than those involving floating-point numbers. Processors such as CPUs, GPUs, and specialized AI accelerators often have optimized instructions for integer arithmetic, making quantized models more efficient in terms of computation.
Reduced Bandwidth: Lower precision data requires less bandwidth, which can be advantageous for data transfer and network communication in distributed systems or edge devices.
3. Lower Power Consumption
Energy Efficiency: Quantized operations consume less power compared to their floating-point counterparts.
Hardware Utilization: Many modern processors are designed to handle lower-precision arithmetic more efficiently, leading to lower overall power usage during model inference.

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init
import torch.optim as optim
from torch.utils.data import random_split
import warnings
import matplotlib.pyplot as plt
import numpy as np
from torchsummary import summary

In [13]:
class ResidualBlock(nn.Module):
  """Two-convolution residual block with an optional 1x1 projection shortcut.

  The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its output is added
  to the shortcut and passed through a final ReLU.

  Args:
    in_channels: Number of channels of the input feature map.
    out_channels: Number of channels produced by both convolutions.
    stride: Stride of the first convolution (and of the projection shortcut).
  """

  def __init__(self, in_channels, out_channels, stride):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
    self.bn2 = nn.BatchNorm2d(out_channels)

    # The identity shortcut can only be added to the main path when both have
    # the same shape. Any change in spatial size (stride != 1) or in channel
    # count therefore needs a 1x1 projection, not just the stride == 2 case.
    self.downsample = None
    if stride != 1 or in_channels != out_channels:
      self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)

  def forward(self, x):
    """Apply the block to `x` and return the activated sum of main path and shortcut."""
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))

    # Use the input itself as the shortcut unless a projection was configured.
    shortcut = x if self.downsample is None else self.downsample(x)

    out += shortcut
    return F.relu(out)

class ResNet20(nn.Module):
  """Three-stage residual network for 3-channel image input.

  A 3x3 stem convolution (16 channels) is followed by three stages of `n`
  residual blocks with 16, 32 and 64 channels; the second and third stages
  start with a stride-2 block. The final feature map is globally averaged
  and fed to a linear classifier.

  Args:
    resblock: Callable `resblock(in_channels, out_channels, stride)` that
      returns the module used for each residual block.
    n: Number of blocks per stage.
    num_classes: Number of outputs of the final linear layer.
  """

  def __init__(self, resblock, n=3, num_classes=10):
    super().__init__()
    self.n = n
    self.conv1 = nn.Conv2d(3, 16, 3, stride=1, padding=1)
    self.bn1 = nn.BatchNorm2d(16)
    # Global average pooling: collapses whatever spatial size reaches this
    # point to 1x1, so the classifier always sees 64 features. This module has
    # no parameters, so checkpoints saved with a fixed-size pool still load.
    self.avgpool = nn.AdaptiveAvgPool2d(1)
    self.fc1 = nn.Linear(64, num_classes)

    self.layer1 = self.create_layer(resblock, 16, 16, stride=1)
    self.layer2 = self.create_layer(resblock, 16, 32, stride=2)
    self.layer3 = self.create_layer(resblock, 32, 64, stride=2)

  def create_layer(self, resblock, in_channels, out_channels, stride):
    """Build one stage of `self.n` blocks as an `nn.Sequential`.

    Only the first block receives `stride` and the channel change; the
    remaining blocks map `out_channels` to `out_channels` with stride 1.
    """
    blocks = [resblock(in_channels, out_channels, stride)]
    blocks.extend(
      resblock(out_channels, out_channels, stride=1) for _ in range(self.n - 1)
    )
    return nn.Sequential(*blocks)

  def forward(self, x):
    """Return the classifier outputs (one row per sample) for the batch `x`."""
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.avgpool(out)
    # Flatten everything except the batch dimension before the linear layer.
    out = out.view(out.size(0), -1)
    return self.fc1(out)

In [14]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

# Build the network with its default depth, move it to the selected device,
# and print a per-layer summary for a single 3x32x32 input.
model = ResNet20(ResidualBlock)
model = model.to(device)
summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             448
       BatchNorm2d-2           [-1, 16, 32, 32]              32
            Conv2d-3           [-1, 16, 32, 32]           2,320
       BatchNorm2d-4           [-1, 16, 32, 32]              32
            Conv2d-5           [-1, 16, 32, 32]           2,320
       BatchNorm2d-6           [-1, 16, 32, 32]              32
     ResidualBlock-7           [-1, 16, 32, 32]               0
            Conv2d-8           [-1, 16, 32, 32]           2,320
       BatchNorm2d-9           [-1, 16, 32, 32]              32
           Conv2d-10           [-1, 16, 32, 32]           2,320
      BatchNorm2d-11           [-1, 16, 32, 32]              32
    ResidualBlock-12           [-1, 16, 32, 32]               0
           Conv2d-13           [-1, 16, 32, 32]           2,320
      BatchNorm2d-14           [-1, 16,

In [26]:
# Restore the trained weights stored under the checkpoint's 'state_dict' key.
# map_location remaps the saved tensors onto the device selected earlier, so a
# checkpoint written from a GPU session still loads when only the CPU is available.
model.load_state_dict(torch.load("resnet.pth", map_location=device)['state_dict'])

<All keys matched successfully>

Fixed-point quantization

Quantize pruned model

Symmetric quantization