## Quantization

**Tutorial objective**:  
The tutorial shows how to do 
- Post-training static quantization
- Two more advanced techniques
    - Per-channel quantization
    - Quantization-aware training

**Task**
- Classify MNIST digits with a simple LeNet architecture.

## Table of Contents

- [Initial Setup](#intial-setup)
- [Train CNN](#train-cnn)
- [Post-training quantization](#post-training-quantization)
- [Quantization aware training](#quantization-aware-training)

## Intial Setup

Import the MNIST dataset, and train a simple convolutional neural network (CNN) to classify it.

In [1]:
import torch
import torchvision
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
import torch.quantization
from torch.quantization import QuantStub, DeQuantStub
from collections import OrderedDict
import os

  from .autonotebook import tqdm as notebook_tqdm


Data load functions

In [2]:
batch_size = 16
num_workers = 2

def load_data(train=True):
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))]
    )
    dataset = torchvision.datasets.MNIST(
        root="~/data",
        train=train,
        download=True, 
        transform=transform
    )
    # Shuffle train dataset only
    shuffle = train
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True
    )

    return dataloader

Define some helper functions and classes that help us to track the statistics and accuracy wrt to the train/test data.

In [3]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=":f") -> None:
        self.name = name
        self.fmt = fmt
        self.reset()
    
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
    
    def __str__(self) -> str:
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

In [4]:
def accuracy(output, target):
    """Computes the top 1 accuracy"""
    with torch.no_grad():
        batch_size = target.size(0)

        _, pred = output.topk(1, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        correct_one = correct[:1].view(-1).float().sum(0, keepdim=True)

        return correct_one.mul_(100.0/batch_size).item()

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print("Size (MB): ", os.path.getsize("temp.p")/1e6)
    os.remove("temp.p")

def load_model(quantized_model, model):
    """Loads in the weights into an object meant for quantization"""
    state_dict = model.state_dict()
    model = model.to("cpu")
    '''
    new_state_dict = OrderedDict()
    for key, value in state_dict:
        key_split = key.split(".")
        if len(key_split) == 3:
            key = ".".join([key_split[0], key_split[2]])
        new_state_dict[key] = value
    quantized_model.load_state_dict(new_state_dict)
    '''
    quantized_model.load_state_dict(state_dict)

def fuse_modules(model):
    """Fuse together convolutions/linear layers and ReLU"""
    torch.quantization.fuse_modules(
        model=model,
        modules_to_fuse=[["conv1", "relu1"],
                         ["conv2", "relu2"],
                         ["fc1", "relu3"],
                         ["fc2", "relu4"]],
        inplace=True)

Define a simple CNN to classify MNIST images

In [5]:
class Net(nn.Module):
    def __init__(self, q=False) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=256, out_features=120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(in_features=84, out_features=10)
        self.q = q
        if q:
            self.quant = QuantStub()
            self.dequant = DeQuantStub()
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.q:
            x = self.quant(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        # Flatten image tensor
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        x = self.relu4(x)
        x = self.fc3(x)
        if self.q:
            x = self.dequant(x)
        return x

Check for availability of GPU

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

Check the model size without quantization.

In [7]:
net = Net(q=False)
if torch.cuda.is_available():
    net = net.to(device)
print_size_of_model(model=net)

Size (MB):  0.181134


## Train CNN

In [8]:
lr = 0.001
momentum = 0.9
n_epoch = 20

def train(model: nn.Module, dataloader: DataLoader, cuda: bool=False, q: bool=False):
    # Define a loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(params=model.parameters(), lr=lr, momentum=momentum)

    # Set train mode
    model.train()

    # Loop over the dataset
    for epoch in range(n_epoch):
        running_loss = AverageMeter(name="loss")
        acc = AverageMeter(name="train_acc")

        for i, data in enumerate(dataloader, 0):
            inputs, labels = data

            if cuda:
                inputs = inputs.to(device)
                labels = labels.to(device)
            
            optimizer.zero_grad()

            if q and epoch >= 3:
                # Freeze quantizer parameters
                model.apply(torch.quantization.disable_observer)
            
            outputs = model(inputs)
            # Compute loss
            loss = criterion(outputs, labels)
            # Compute gradient and update weights
            loss.backward()
            optimizer.step()

            # Update statistics
            running_loss.update(val=loss.item(), n=outputs.shape[0])
            acc.update(accuracy(output=outputs, target=labels), n=outputs.shape[0])

            # Print every 100th mini-batch
            if i % 100 == 99:
                print("[%d, %5d] " % (epoch+1, i+1), running_loss, acc)
    
    print("Finished training")

def test(model: nn.Module, dataloader: DataLoader, cuda: bool=False) -> float:
    correct = 0
    total = 0

    # Set eval mode
    model.eval()

    with torch.no_grad():
        for data in dataloader:
            inputs, labels = data

            if cuda:
                inputs = inputs.to(device)
                labels = labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            # update count
            total += outputs.shape[0]
            correct += torch.sum(predicted == labels).item()
    
    return 100*correct/total       


Train and test without quantization

In [9]:
trainloader = load_data(train=True)

train(model=net, dataloader=trainloader, cuda=torch.cuda.is_available(), q=False)

[1,   100]  loss 2.301452 (2.300698) train_acc 6.250000 (10.312500)
[1,   200]  loss 2.287138 (2.297951) train_acc 18.750000 (10.281250)
[1,   300]  loss 2.307468 (2.294516) train_acc 6.250000 (14.854167)
[1,   400]  loss 2.276615 (2.290724) train_acc 37.500000 (19.703125)
[1,   500]  loss 2.220406 (2.284420) train_acc 56.250000 (24.037500)
[1,   600]  loss 2.199768 (2.273662) train_acc 43.750000 (26.791667)
[1,   700]  loss 2.012336 (2.249737) train_acc 37.500000 (29.392857)
[1,   800]  loss 1.320231 (2.182185) train_acc 56.250000 (32.304688)
[1,   900]  loss 0.652150 (2.059625) train_acc 81.250000 (36.381944)
[1,  1000]  loss 0.620984 (1.924953) train_acc 68.750000 (40.487500)
[1,  1100]  loss 0.763521 (1.799508) train_acc 81.250000 (44.471591)
[1,  1200]  loss 0.261930 (1.686431) train_acc 87.500000 (47.979167)
[1,  1300]  loss 0.247878 (1.591323) train_acc 87.500000 (50.908654)
[1,  1400]  loss 0.285040 (1.501963) train_acc 81.250000 (53.683036)
[1,  1500]  loss 0.270009 (1.423879)

Now test on the test dataset

In [10]:
testloader = load_data(train=False)

score = test(model=net, dataloader=testloader, cuda=torch.cuda.is_available())
print(f"Accuracy of the network on the test images: {score}% - FP32")

Accuracy of the network on the test images: 98.88% - FP32


## Post-training quantization

Define a new quantized network architecture. Next we'll "fuse models", this can both make the model faster by saving on memory access while also improving numercial accuracy.

In [11]:
net.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[ 0.4428,  0.2118,  0.2768,  0.0157,  0.1945],
                        [ 0.1524,  0.5733,  0.5056,  0.5796,  0.4472],
                        [ 0.2082,  0.2719,  0.3475,  0.1671,  0.1293],
                        [-0.6373, -0.5048, -0.3163, -0.5148, -0.1139],
                        [-0.5653, -0.6353, -0.5980, -0.4194, -0.1856]]],
              
              
                      [[[-0.1060, -0.1101, -0.2690, -0.2815, -0.3092],
                        [-0.0744, -0.3842, -0.5379, -0.5686, -0.1313],
                        [-0.2139, -0.4757, -0.2699, -0.2235, -0.0287],
                        [-0.0538, -0.2168, -0.3401, -0.0597,  0.3398],
                        [ 0.0278,  0.1061,  0.4902,  0.4629,  0.4011]]],
              
              
                      [[[-0.2204, -0.2499, -0.2926,  0.0900, -0.0302],
                        [ 0.0470,  0.0654,  0.1472,  0.1214, -0.1123],
                        [ 0.3897,  0.1482,  0.4610,  0

In [12]:
net.state_dict().keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

In [13]:
qnet = Net(q=True)
load_model(quantized_model=qnet, model=net)

Compare qnet state_dict before and after fuse modules.

In [14]:
qnet.state_dict().keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

In [15]:
fuse_modules(model=qnet)

In [16]:
qnet.state_dict().keys()

odict_keys(['conv1.0.weight', 'conv1.0.bias', 'conv2.0.weight', 'conv2.0.bias', 'fc1.0.weight', 'fc1.0.bias', 'fc2.0.weight', 'fc2.0.bias', 'fc3.weight', 'fc3.bias'])

In [17]:
print_size_of_model(model=qnet)
score = test(model=qnet, dataloader=testloader, cuda=False)
print("Accuracy of the fused network on the test images: {score}% - FP32")

Size (MB):  0.181262
Accuracy of the fused network on the test images: {score}% - FP32


In [18]:
net.state_dict().keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

In [19]:
qnet.qconfig = torch.quantization.default_qconfig
print(qnet.qconfig)
torch.quantization.prepare(model=qnet, inplace=True)
print("Post training Quantization Prepare: Inserting Observers")
print("\nConv1: After observer insertion\n\n", qnet.conv1)

# Calibrate with the training set
test(model=qnet, dataloader=trainloader, cuda=False)
print("Post Training Quantization: Calibration done")
torch.quantization.convert(qnet, inplace=True)
print("Post Training Quantization: Convert done")
print("Conv1: After fusion and quantization\n\n", qnet.conv1)
print("Size of model after quantization")
print_size_of_model(model=qnet)


QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})
Post training Quantization Prepare: Inserting Observers

Conv1: After observer insertion

 ConvReLU2d(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
)
Post Training Quantization: Calibration done
Post Training Quantization: Convert done
Conv1: After fusion and quantization

 QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.0730319693684578, zero_point=0)
Size of model after quantization
Size (MB):  0.052102


Check the accuracy of the quantized model

In [20]:
score = test(model=qnet, dataloader=testloader, cuda=False)
print(f"Accuracy of the fused and quantized network on the test images: {score}% - INT8")

Accuracy of the fused and quantized network on the test images: 98.88% - INT8


Define a custom quantization configuration.
Replace the default observers and instead of quantizing wrt to max/min we take average of the observed max/min.

In [21]:
from torch.quantization.observer import MovingAverageMinMaxObserver

qnet = Net(q=True)
load_model(quantized_model=qnet, model=net)
fuse_modules(model=qnet)

qnet.qconfig = torch.quantization.QConfig(
    activation=MovingAverageMinMaxObserver.with_args(reduce_range=True),
    weight=MovingAverageMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
)
print(qnet.qconfig)
torch.quantization.prepare(model=qnet, inplace=True)
print("Post Training Quantization Prepare: Inserting Observers")
print("\nConv1: After observer insertion\n\n", qnet.conv1)

test(model=qnet, dataloader=trainloader, cuda=False)
print("Post Training Quantization: Calibration done")
torch.quantization.convert(module=qnet, inplace=True)
print("Post Training Quantization: Convert done")
print("\nConv1: After fusion and quantization\n\n", qnet.conv1)
print("Size of model after quantization")
print_size_of_model(model=qnet)
score = test(model=qnet, dataloader=testloader, cuda=False)
print(f"Accuracy of the fused and quantized network on the test images: {score}% - INT8")

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})
Post Training Quantization Prepare: Inserting Observers

Conv1: After observer insertion

 ConvReLU2d(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
)




Post Training Quantization: Calibration done
Post Training Quantization: Convert done

Conv1: After fusion and quantization

 QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.07102137058973312, zero_point=0)
Size of model after quantization
Size (MB):  0.052102
Accuracy of the fused and quantized network on the test images: 98.88% - INT8


## Quantization aware training

In [22]:
qnet = Net(q=True)
fuse_modules(model=qnet)

# Specify quantization config for QAT
qnet.qconfig = torch.quantization.get_default_qat_qconfig(backend="fbgemm")

# Prepare QAT
torch.quantization.prepare_qat(model=qnet, inplace=True)

print("\nConv1: After fusion and quantization\n\n", qnet.conv1)

qnet = qnet.to(device=device)


Conv1: After fusion and quantization

 ConvReLU2d(
  1, 6, kernel_size=(5, 5), stride=(1, 1)
  (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
    fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False
    (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (activation_post_process): FusedMovingAvgObsFakeQuantize(
    fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  )
)


In [23]:
train(model=qnet, dataloader=trainloader, cuda=torch.cuda.is_available())

  return torch.fused_moving_avg_obs_fake_quant(
  return torch.fused_moving_avg_obs_fake_quant(


[1,   100]  loss 2.277035 (2.302786) train_acc 12.500000 (9.250000)
[1,   200]  loss 2.261745 (2.296432) train_acc 37.500000 (10.625000)
[1,   300]  loss 2.268169 (2.288824) train_acc 37.500000 (16.270833)
[1,   400]  loss 2.255547 (2.280266) train_acc 25.000000 (21.531250)
[1,   500]  loss 2.210626 (2.268720) train_acc 50.000000 (28.175000)
[1,   600]  loss 2.121386 (2.251543) train_acc 68.750000 (33.604167)
[1,   700]  loss 2.051449 (2.227287) train_acc 68.750000 (37.571429)
[1,   800]  loss 1.730208 (2.190522) train_acc 81.250000 (41.171875)
[1,   900]  loss 1.581596 (2.135772) train_acc 56.250000 (44.229167)
[1,  1000]  loss 1.472379 (2.062680) train_acc 56.250000 (46.881250)
[1,  1100]  loss 0.928114 (1.973150) train_acc 75.000000 (49.323864)
[1,  1200]  loss 0.857227 (1.871526) train_acc 56.250000 (51.812500)
[1,  1300]  loss 0.911844 (1.776487) train_acc 68.750000 (53.942308)
[1,  1400]  loss 0.374884 (1.684849) train_acc 93.750000 (56.174107)
[1,  1500]  loss 0.233549 (1.601560

[2,  3100]  loss 0.119964 (0.112254) train_acc 93.750000 (96.622984)
[2,  3200]  loss 0.001713 (0.110740) train_acc 100.000000 (96.664062)
[2,  3300]  loss 0.086472 (0.110087) train_acc 100.000000 (96.685606)
[2,  3400]  loss 0.022003 (0.109245) train_acc 100.000000 (96.704044)
[2,  3500]  loss 0.068793 (0.109076) train_acc 93.750000 (96.700000)
[2,  3600]  loss 0.033700 (0.108560) train_acc 100.000000 (96.722222)
[2,  3700]  loss 0.029272 (0.107796) train_acc 100.000000 (96.743243)
[3,   100]  loss 0.108049 (0.076380) train_acc 93.750000 (97.375000)
[3,   200]  loss 0.229120 (0.078433) train_acc 93.750000 (97.531250)
[3,   300]  loss 0.079726 (0.075637) train_acc 93.750000 (97.562500)
[3,   400]  loss 0.013925 (0.075444) train_acc 100.000000 (97.546875)
[3,   500]  loss 0.041924 (0.074523) train_acc 100.000000 (97.575000)
[3,   600]  loss 0.007033 (0.074922) train_acc 100.000000 (97.604167)
[3,   700]  loss 0.023965 (0.074106) train_acc 100.000000 (97.598214)
[3,   800]  loss 0.005594

In [24]:
qnet = qnet.to(device=torch.device("cpu"))
torch.quantization.convert(module=qnet, inplace=True)
print("Size of model after quantization")
print_size_of_model(model=qnet)

score = test(model=qnet, dataloader=testloader, cuda=False)
print(f"Accuracy of the fused and quantized network (trianed quantized) on the test images: {score}% - INT8")

Size of model after quantization
Size (MB):  0.05793
Accuracy of the fused and quantized network (trianed quantized) on the test images: 99.11% - INT8


Now run the training loop by freezing the quantizer parameters (scale and zero-point) and finetune the weights.

In [25]:
qnet = Net(q=True)

fuse_modules(model=qnet)

# Specify quantization config for QAT
qnet.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")

# Prepare QAT
torch.quantization.prepare_qat(model=qnet, inplace=True)

qnet = qnet.to(device)

train(model=qnet, dataloader=trainloader, cuda=torch.cuda.is_available(), q=True)

qnet = qnet.to(torch.device("cpu"))

torch.quantization.convert(module=qnet, inplace=True)

print("Size of model after quantization")
print_size_of_model(model=qnet)

score = test(model=qnet, dataloader=testloader, cuda=False)
print(f"Accuracy of the fused and quantized network (trianed quantized) on the test images: {score}% - INT8")

[1,   100]  loss 2.288820 (2.298613) train_acc 6.250000 (10.187500)
[1,   200]  loss 2.292001 (2.294692) train_acc 6.250000 (11.687500)
[1,   300]  loss 2.261161 (2.286984) train_acc 25.000000 (16.229167)
[1,   400]  loss 2.262180 (2.277621) train_acc 37.500000 (20.750000)
[1,   500]  loss 2.228750 (2.265902) train_acc 18.750000 (26.200000)
[1,   600]  loss 2.139659 (2.249977) train_acc 56.250000 (31.125000)
[1,   700]  loss 1.966706 (2.226303) train_acc 62.500000 (35.821429)
[1,   800]  loss 1.803692 (2.191116) train_acc 87.500000 (39.820312)
[1,   900]  loss 1.603228 (2.137844) train_acc 62.500000 (43.208333)
[1,  1000]  loss 1.280798 (2.060221) train_acc 68.750000 (46.431250)
[1,  1100]  loss 0.863921 (1.963603) train_acc 68.750000 (49.386364)
[1,  1200]  loss 0.851023 (1.858339) train_acc 81.250000 (52.031250)
[1,  1300]  loss 0.779770 (1.759009) train_acc 81.250000 (54.418269)
[1,  1400]  loss 0.899344 (1.669670) train_acc 62.500000 (56.517857)
[1,  1500]  loss 0.260186 (1.590217)

[1,  2400]  loss 0.304696 (1.115528) train_acc 87.500000 (70.190104)
[1,  2500]  loss 0.574006 (1.080056) train_acc 87.500000 (71.087500)
[1,  2600]  loss 0.097433 (1.047876) train_acc 100.000000 (71.923077)
[1,  2700]  loss 0.127464 (1.016685) train_acc 93.750000 (72.712963)
[1,  2800]  loss 0.264521 (0.987400) train_acc 87.500000 (73.470982)
[1,  2900]  loss 0.455620 (0.960537) train_acc 87.500000 (74.146552)
[1,  3000]  loss 0.318743 (0.935052) train_acc 87.500000 (74.802083)
[1,  3100]  loss 0.050227 (0.911382) train_acc 100.000000 (75.411290)
[1,  3200]  loss 0.116966 (0.889287) train_acc 100.000000 (75.982422)
[1,  3300]  loss 0.046975 (0.867586) train_acc 100.000000 (76.534091)
[1,  3400]  loss 0.057824 (0.847340) train_acc 100.000000 (77.051471)
[1,  3500]  loss 0.341893 (0.828908) train_acc 87.500000 (77.539286)
[1,  3600]  loss 0.059921 (0.810712) train_acc 100.000000 (78.024306)
[1,  3700]  loss 0.096222 (0.793593) train_acc 100.000000 (78.471284)
[2,   100]  loss 0.253829 (