This notebook demonstrates quantization of a deep learning model (a ResNet in this example). Quantization reduces the number of bits used to represent each parameter of the model. It has three main benefits:
1. Reduce Model Size
Memory Efficiency: Instead of storing each parameter as a 32-bit floating-point number (FP32), quantization uses a lower-precision format such as 16-bit floating point (FP16) or low-bit integers (e.g., 8-bit). This leads to significant reductions in the model's memory usage.
Storage Savings: Smaller models require less storage space, which is beneficial for deploying models on devices with limited memory, such as embedded systems.
2. Improve Computational Efficiency
Faster Inference: Operations on low-bit integers (e.g., 8-bit integers) are typically faster to execute than floating-point operations. CPUs, GPUs, and specialized AI accelerators often provide optimized instructions for integer arithmetic, making quantized models more computationally efficient.
Reduced Bandwidth: Lower precision data requires less bandwidth, which can be advantageous for data transfer and network communication in distributed systems or edge devices.
3. Lower Power Consumption
Energy Efficiency: Quantized operations consume less power compared to their floating-point counterparts.
Hardware Utilization: Many modern processors are designed to handle lower-precision arithmetic more efficiently, leading to lower overall power usage during model inference.

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init
import torch.optim as optim
from torch.utils.data import random_split
import warnings
import matplotlib.pyplot as plt
import math
import numpy as np
from torchsummary import summary

In [28]:
class STE(torch.autograd.Function):
    """Linear weight quantizer with a straight-through gradient.

    The forward pass snaps a weight tensor onto a uniform grid; the backward
    pass hands the incoming gradient to the weights unchanged, as if the
    rounding were the identity.
    """

    @staticmethod
    def forward(ctx, w, bit, symmetric=False):
        """Quantize ``w`` to ``bit`` bits.

        Args:
            ctx: autograd context (unused; nothing is saved for backward).
            w: weight tensor to quantize.
            bit: number of bits. ``None`` returns ``w`` untouched and ``0``
                returns an all-zero tensor.
            symmetric: when falsy, scale by ``max(w) - min(w)`` and shift by
                ``min(w)`` (``2**bit - 1`` steps across the range); when
                truthy, scale by ``max(|w|)`` only (``2**(bit-1) - 1`` steps
                on each side of zero).

        Returns:
            A tensor shaped like ``w``. Entries of ``w`` that are exactly zero
            are exactly zero in the result.
        """
        if bit is None:
            return w
        if bit == 0:
            return w * 0
        if w.numel() == 0:
            # torch.max / torch.min raise on empty tensors; nothing to quantize.
            return w

        # Build a mask to record position of zero weights
        weight_mask = (w != 0).int()

        if symmetric:
            alpha = torch.max(torch.abs(w))
            step = 2 ** (bit - 1) - 1
            if step == 0:
                # bit == 1 leaves zero as the only level on the signed grid;
                # dividing by step would produce NaN instead.
                wq = torch.zeros_like(w)
            else:
                # alpha == 0 means every weight is zero. Dividing by 1 instead
                # keeps ws at 0 rather than 0/0 = NaN, and alpha * R is still 0.
                safe_alpha = torch.where(alpha > 0, alpha, torch.ones_like(alpha))
                ws = w / safe_alpha
                R = torch.round(step * ws) / step
                wq = alpha * R
        else:
            # Compute beta (bias) for dynamic scaling
            beta = torch.min(w)
            # Compute alpha (scale) for dynamic scaling
            alpha = torch.max(w) - beta
            # alpha == 0 means all weights are equal. Dividing by 1 instead
            # gives ws == 0, so wq collapses to beta (== w) rather than NaN.
            safe_alpha = torch.where(alpha > 0, alpha, torch.ones_like(alpha))
            # Scale w with alpha and beta so that all elements in ws are between 0 and 1
            ws = (w - beta) / safe_alpha
            step = 2 ** bit - 1
            # Quantize ws with a linear quantizer to "bit" bits
            R = torch.round(step * ws) / step
            # Scale the quantized weight R back with alpha and beta
            wq = alpha * R + beta

        # Restore zero elements in wq
        return wq * weight_mask

    @staticmethod
    def backward(ctx, g):
        """Pass ``g`` straight through to ``w``; ``bit`` and ``symmetric`` get no gradient."""
        return g, None, None

class FP_Linear(nn.Module):
    """Fully connected layer whose weight goes through ``STE`` on every forward pass.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        Nbits: bit width handed to ``STE``; ``None`` leaves the weight as is.
        symmetric: quantization mode flag handed to ``STE``.
    """

    def __init__(self, in_features, out_features, Nbits=None, symmetric=False):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.linear = nn.Linear(in_features, out_features)
        self.Nbits, self.symmetric = Nbits, symmetric

        # Re-draw the weight from N(0, 2 / (in_features + out_features));
        # the bias keeps whatever nn.Linear initialised it to.
        fan_sum = self.in_features + self.out_features
        weight_std = math.sqrt(2. / fan_sum)
        self.linear.weight.data.normal_(0, weight_std)

    def forward(self, x):
        """Apply the linear map using the STE-processed weight and the stored bias."""
        quantized_weight = STE.apply(self.linear.weight, self.Nbits, self.symmetric)
        return F.linear(x, quantized_weight, self.linear.bias)

class FP_Conv(nn.Module):
    """2-D convolution whose weight goes through ``STE`` on every forward pass.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size, either an int or an ``(h, w)`` pair.
        stride: convolution stride.
        padding: padding forwarded to ``nn.Conv2d``.
        bias: whether the wrapped ``nn.Conv2d`` has a bias term.
        Nbits: bit width handed to ``STE``; ``None`` leaves the weight as is.
        symmetric: quantization mode flag handed to ``STE``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False, Nbits=None, symmetric=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        self.Nbits = Nbits
        self.symmetric = symmetric

        # nn.Conv2d stores kernel_size as a pair, so reading it back here
        # handles both int and (h, w) arguments; squaring the raw argument
        # would raise TypeError for a tuple.
        kernel_h, kernel_w = self.conv.kernel_size
        kernel_area = kernel_h * kernel_w
        n = kernel_area * self.out_channels
        m = kernel_area * self.in_channels
        # Draw the weight from N(0, 2 / (n + m)).
        self.conv.weight.data.normal_(0, math.sqrt(2. / (n + m)))
        # Not read anywhere in the visible code; kept for external users.
        self.sparsity = 1.0

    def forward(self, x):
        """Convolve ``x`` with the STE-processed weight, reusing the wrapped conv's settings."""
        quantized_weight = STE.apply(self.conv.weight, self.Nbits, self.symmetric)
        return F.conv2d(x, quantized_weight, self.conv.bias, self.conv.stride,
                        self.conv.padding, self.conv.dilation, self.conv.groups)

In [31]:
class ResidualBlock(nn.Module):
  """Two 3x3 ``FP_Conv`` + BatchNorm stages with a skip connection.

  Args:
    in_channels: channels of the block input.
    out_channels: channels produced by both convolutions.
    stride: stride of the first convolution (the second always uses 1).
    Nbits: bit width forwarded to the two 3x3 convolutions.
    symmetric: quantization mode forwarded to the two 3x3 convolutions.
  """

  def __init__(self, in_channels, out_channels, stride, Nbits=None, symmetric=False):
    super().__init__()
    self.conv1 = FP_Conv(in_channels, out_channels, 3, stride=stride, padding=1, bias=False, Nbits=Nbits, symmetric=symmetric)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.conv2 = FP_Conv(out_channels, out_channels, 3, stride=1, padding=1, bias=False, Nbits=Nbits, symmetric=symmetric)
    self.bn2 = nn.BatchNorm2d(out_channels)
    # The skip path needs a 1x1 projection whenever the main path changes the
    # tensor shape: a non-unit stride or a different channel count. Checking
    # only stride == 2 would leave the other cases failing at the addition.
    # NOTE(review): the projection is built without Nbits/symmetric, so it is
    # never quantized - confirm this is intended.
    self.downsample = None
    if stride != 1 or in_channels != out_channels:
      self.downsample = FP_Conv(in_channels, out_channels, kernel_size=1, stride=stride)

  def forward(self, x):
    """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
    res = x
    out = self.conv1(x)
    out = self.bn1(out)
    out = F.relu(out)
    out = self.conv2(out)
    out = self.bn2(out)
    if self.downsample is not None:
      res = self.downsample(x)
    out += res
    out = F.relu(out)
    return out

class ResNet20(nn.Module):
  """Three-stage residual network with a 10-way linear classifier.

  Args:
    resblock: residual block class, called as
      ``resblock(in_channels, out_channels, stride, Nbits=..., symmetric=...)``.
    n: number of residual blocks per stage.
    Nbits: bit width for the stem convolution and every residual block;
      ``None`` disables quantization.
    symmetric: quantization mode for the same layers.
  """

  def __init__(self, resblock, n=3,  Nbits=None, symmetric=False):
    super().__init__()
    self.n = n
    # Stored before the stages are built so create_layer can forward them;
    # otherwise every residual block would silently fall back to Nbits=None.
    self.Nbits = Nbits
    self.symmetric = symmetric
    self.conv1 = FP_Conv(3, 16, 3, stride=1, padding=1, bias=False, Nbits=Nbits, symmetric=symmetric)
    self.bn1 = nn.BatchNorm2d(16)
    self.avgpool = nn.AvgPool2d(8)
    # The classifier is explicitly built with Nbits=None (no quantization).
    self.fc1   = FP_Linear(64, 10, Nbits=None)
    self.layer1 = self.create_layer(resblock, 16, 16, stride=1)
    self.layer2 = self.create_layer(resblock, 16, 32, stride=2)
    self.layer3 = self.create_layer(resblock, 32, 64, stride=2)

  def create_layer(self, resblock, in_channels, out_channels, stride):
    """Build one stage of ``self.n`` blocks.

    Only the first block receives ``stride`` and the channel change; the
    remaining blocks map ``out_channels`` to ``out_channels`` with stride 1.
    Every block gets this network's ``Nbits`` and ``symmetric`` settings.
    """
    quant_kwargs = dict(Nbits=self.Nbits, symmetric=self.symmetric)
    layers = [resblock(in_channels, out_channels, stride, **quant_kwargs)]
    for _ in range(self.n - 1):
      layers.append(resblock(out_channels, out_channels, stride=1, **quant_kwargs))
    return nn.Sequential(*layers)

  def forward(self, x):
    """Run stem, three stages, average pooling, flatten and classifier."""
    out = self.conv1(x)
    out = self.bn1(out)
    out = F.relu(out)
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.avgpool(out)
    # Flatten everything except the batch dimension for the linear layer.
    out = out.view(out.size(0), -1)
    out = self.fc1(out)
    return out

In [32]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the network with default (unquantized) settings and print a
# per-layer summary for a 3x32x32 input.
model = ResNet20(ResidualBlock)
model = model.to(device)
summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           FP_Conv-1           [-1, 16, 32, 32]               0
       BatchNorm2d-2           [-1, 16, 32, 32]              32
           FP_Conv-3           [-1, 16, 32, 32]               0
       BatchNorm2d-4           [-1, 16, 32, 32]              32
           FP_Conv-5           [-1, 16, 32, 32]               0
       BatchNorm2d-6           [-1, 16, 32, 32]              32
     ResidualBlock-7           [-1, 16, 32, 32]               0
           FP_Conv-8           [-1, 16, 32, 32]               0
       BatchNorm2d-9           [-1, 16, 32, 32]              32
          FP_Conv-10           [-1, 16, 32, 32]               0
      BatchNorm2d-11           [-1, 16, 32, 32]              32
    ResidualBlock-12           [-1, 16, 32, 32]               0
          FP_Conv-13           [-1, 16, 32, 32]               0
      BatchNorm2d-14           [-1, 16,

In [26]:
# The checkpoint keeps the weights under its 'state_dict' key.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads when the notebook runs on the CPU.
checkpoint = torch.load("resnet_quantization.pth", map_location=device)
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [None]:
import torch
import torch.nn as nn
import torchvision  # needed for torchvision.datasets.CIFAR10 below
import torchvision.transforms as transforms

# Normalisation statistics shared by the training and test pipelines.
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]

# Training images get random crop (after 4-pixel padding) and random
# horizontal flip as augmentation, then normalisation.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(size=(32, 32), padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Test images are only converted to tensors and normalised.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Both splits are stored under CIFAR10_data/ and downloaded if missing.
CIFAR10_train = torchvision.datasets.CIFAR10(root='CIFAR10_data/',
                                             train=True,
                                             transform=train_transform,
                                             download=True)

CIFAR10_test = torchvision.datasets.CIFAR10(root='CIFAR10_data/',
                                            train=False,
                                            transform=test_transform,
                                            download=True)
BATCH_SIZE = 100

# Only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(
    dataset=CIFAR10_train, batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    dataset=CIFAR10_test, batch_size=BATCH_SIZE, shuffle=False)

def test_CIFAR10(model, test_loader, device, verbose=True):
    criterion = torch.nn.CrossEntropyLoss().to(device)
    model.eval()
    total_examples = 0
    correct_examples = 0
    total_test_loss = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            out = model(inputs)
            loss = criterion(out, targets)
            total_test_loss += loss.item()
            _, predicted = torch.max(out, 1)
            total_examples += targets.size(0)
            correct_examples += (predicted == targets).sum().item()
    test_avg_acc = correct_examples / total_examples
    test_avg_loss = total_test_loss / len(test_loader)
    if verbose:
        print("Test accuracy: %.4f" % (test_avg_acc))
        print("Test loss: %.4f" % (test_avg_loss))
    return test_avg_acc, test_avg_loss
    
def plot_acc(x, acc, x_label, y_label, title):
    """Plot ``acc`` against ``x`` as a solid blue line on an 8x6 figure and show it.

    Args:
        x: values for the horizontal axis.
        acc: values for the vertical axis.
        x_label: label of the horizontal axis.
        y_label: label of the vertical axis.
        title: figure title.
    """
    _, axes = plt.subplots(figsize=(8, 6))
    axes.plot(x, acc, 'b-')
    axes.set_xlabel(x_label)
    axes.set_ylabel(y_label)
    axes.set_title(title)
    axes.grid(True)
    plt.show()

## Fixed-point quantization

In [None]:
Nbits = 4
# Rebuild the network defined in this notebook with 4-bit, non-symmetric
# weight quantization, then load the pretrained weights into it.
model = ResNet20(ResidualBlock, Nbits=Nbits, symmetric=False)
model = model.to(device)
# The weights sit under the checkpoint's 'state_dict' key; map_location
# lets a GPU-saved checkpoint load on a CPU-only machine.
checkpoint = torch.load("resnet_quantization.pth", map_location=device)
model.load_state_dict(checkpoint['state_dict'])
print('Test Accuracy: ', test_CIFAR10(model, test_loader, device)[0])

In [None]:
# Fine-tune the quantized model built in the previous cell, then re-evaluate
# that same object. The original call passed `net`, which is never defined.
# NOTE(review): `finetune` is not defined in the visible part of this
# notebook - it must be provided before this cell runs.
finetune(model, epochs=20, batch_size=256, lr=0.002, reg=1e-4)
print('Test Accuracy: ', test_CIFAR10(model, test_loader, device)[0])

## Symmetric quantization

In [None]:
Nbits = 4

# Same experiment as the fixed-point cell, but with symmetric=True so the
# quantizer scales by max(|w|) without an offset.
model = ResNet20(ResidualBlock, Nbits=Nbits, symmetric=True)
model = model.to(device)
# The weights sit under the checkpoint's 'state_dict' key; map_location
# lets a GPU-saved checkpoint load on a CPU-only machine.
checkpoint = torch.load("resnet_quantization.pth", map_location=device)
model.load_state_dict(checkpoint['state_dict'])
print('Test Accuracy: ', test_CIFAR10(model, test_loader, device)[0])