<a href="https://colab.research.google.com/github/fornitroll/Object-Detection-with-PyTorch-Kyiv-/blob/master/PyTorch_Quantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTorch Quantization

Quantization is the process of mapping values from a larger set to a smaller one.
In ML this is done for two reasons:
1. Smaller models. In most cases the model shrinks by a factor of up to about 4 (an int8 value takes a quarter of the space of a float32 value).
2. Faster inference on CPU, because less data has to be moved and int8 operations are simpler and faster than float32 ones.

In our case we will map float32 values (range about -3.4E+38 to +3.4E+38) to int8 (-128 to +127). The difference is huge, which means we should expect some accuracy drop (though not always: small models sometimes even gain accuracy, because quantization acts as a kind of regularization).

There are two ways of model quantization:
1. Post-training quantization
2. Quantization-aware training

Today we will cover both of them.

## Post-training quantization

In [0]:
# Fetch NVIDIA apex sources (apex.amp is used in the mixed-precision section).
!git clone https://github.com/NVIDIA/apex

In [0]:
# IPython magic: change into the cloned repo so that `./` below refers to it.
cd apex

In [0]:
# This plain pip install builds apex without its CUDA/C++ extensions.
# For real use, install apex with the CUDA and C++ extensions enabled.
!pip install -v --no-cache-dir ./

In [0]:
# Go back to the directory the repo was cloned from.
cd ..

In [0]:
import torchvision
import torchvision.transforms as transforms
import os
import time
import torch
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
from torch.quantization import QuantStub, DeQuantStub, fuse_modules

In [0]:
def make_divisible(v, divisor, min_value=None):
    """Round ``v`` to the nearest multiple of ``divisor`` (channel-count helper).

    Adapted from the TensorFlow MobileNet reference implementation:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    Args:
        v: Value to round (typically a number of channels).
        divisor: The result is a multiple of this number.
        min_value: Lower bound for the result; ``None`` means ``divisor``.

    Returns:
        The rounded value, clamped to at least ``min_value``, and bumped up by
        one ``divisor`` if rounding would have removed more than 10% of ``v``.
    """
    floor = divisor if min_value is None else min_value
    rounded = max(floor, int(v + divisor / 2) // divisor * divisor)
    # Rounding to the nearest multiple may go down; never lose more than 10%.
    return rounded + divisor if rounded < 0.9 * v else rounded


class QuantizableBasicBlock(nn.Module):
    """ResNet basic block (two 3x3 convs + residual) prepared for eager-mode quantization.

    Differences from a plain float block:
    * every ReLU module is non-inplace, so it can be fused with its conv/bn;
    * the residual ``add`` + ReLU goes through ``FloatFunctional`` so it can be
      swapped for a quantized op by ``torch.quantization.convert``;
    * ``fuse_model`` fuses conv/bn(/relu) pairs in place.

    ``groups``, ``base_width``, ``previous_dilation``, ``norm_layer`` and
    ``dilation`` are accepted only so that torchvision's ``ResNet`` layer
    builder can construct this block; they are ignored.
    """
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=None, base_width=None, previous_dilation=None, norm_layer=None, dilation=None):
        super(QuantizableBasicBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels, momentum=0.1)
        self.relu = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.1)

        # Optional projection applied to the identity path; fuse_model expects
        # it to be a Sequential of (conv, bn) at indices '0' and '1'.
        self.downsample = downsample
        self.stride = stride
        # Wraps the residual add + relu so it can be quantized.
        self.skip_add_relu = torch.nn.quantized.FloatFunctional()

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.skip_add_relu.add_relu(out, identity)

        return out

    def fuse_model(self):
        """Fuse conv1+bn1+relu, conv2+bn2 and the downsample conv+bn in place."""
        torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'],
                                               ['conv2', 'bn2']], inplace=True)
        # Same presence test as forward(): fuse whenever a projection exists.
        if self.downsample is not None:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)

class QuantizableBottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1 convs + residual) for eager-mode quantization.

    The spatial stride is applied only by the 3x3 conv (the ResNet v1.5 layout
    used by torchvision). Every ReLU is non-inplace and the residual add + ReLU
    goes through ``FloatFunctional`` so the block can be fused and quantized.

    Args:
        in_channels: Input channels.
        channels: Bottleneck ("planes") channels; the output has
            ``channels * expansion`` channels (rounded to a multiple of 8).
        stride: Stride of the 3x3 conv.
        downsample: Optional projection for the identity path, expected to be
            a Sequential of (conv, bn) at indices '0' and '1'.
        groups, base_width: Grouped / wide ResNet settings for the 3x3 conv.
        dilation: Dilation of the 3x3 conv.
        norm_layer: Accepted for ResNet compatibility; ignored (BatchNorm2d is used).
    """
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, in_channels, channels, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, **kwargs):
        super(QuantizableBottleneck, self).__init__()

        width = make_divisible(int(channels * (base_width / 64.)) * groups, 8)
        # 1x1 reduction without stride: striding here as well as in conv2 would
        # downsample twice and the residual add would see mismatched shapes.
        self.conv1 = nn.Conv2d(in_channels, width, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(width, momentum=0.1)

        # The 3x3 conv carries the stride; padding=dilation preserves the
        # spatial size (before striding) for any dilation, not only 1.
        self.conv2 = nn.Conv2d(width, width, 3, stride=stride, padding=dilation,
                               groups=groups, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(width, momentum=0.1)

        out_channels = make_divisible(channels * self.expansion, 8)
        # 1x1 expansion (kernel_size is a required Conv2d argument).
        self.conv3 = nn.Conv2d(width, out_channels, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels, momentum=0.1)

        # Separate non-inplace ReLUs so each one is fused with its own conv/bn.
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)

        # Quantization-friendly residual add + relu.
        self.skip_add_relu = nn.quantized.FloatFunctional()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.skip_add_relu.add_relu(out, identity)

        return out

    def fuse_model(self):
        """Fuse the three conv/bn(/relu) groups and the downsample conv+bn in place."""
        torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1'],
                                               ['conv2', 'bn2', 'relu2'],
                                               ['conv3', 'bn3']], inplace=True)
        # Same presence test as forward().
        if self.downsample is not None:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)

class QuantizableResNet(ResNet):
    """torchvision ``ResNet`` wrapped with quant/dequant stubs for eager-mode quantization.

    ``forward`` quantizes the input with ``self.quant``, runs the network, and
    dequantizes the logits with ``self.dequant``. The stem ReLU is replaced
    with a non-inplace one (as for every ReLU in this notebook) so it can be
    fused with ``conv1``/``bn1``.
    """

    def __init__(self, *args, **kwargs):
        super(QuantizableResNet, self).__init__(*args, **kwargs)

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        # Overrides the ReLU created by ResNet.__init__.
        self.relu = nn.ReLU(inplace=False)

    def _forward_impl(self, x):
        """Network body: stem, four residual stages, global pool, classifier."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        x = self.quant(x)
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        """Fuse the stem and every quantizable residual block, in place."""
        fuse_modules(self, ['conv1', 'bn1', 'relu'], inplace=True)
        # Snapshot the module list first: block.fuse_model() replaces child
        # modules while we would otherwise still be walking the tree.
        for m in list(self.modules()):
            if isinstance(m, (QuantizableBottleneck, QuantizableBasicBlock)):
                m.fuse_model()


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``QuantizableResNet`` and optionally load pretrained weights for ``arch``.

    Args:
        arch: Architecture key into torchvision's ``model_urls`` (e.g.
            ``'resnet18'``); selects which checkpoint is downloaded.
        block: Residual block class passed to ``QuantizableResNet``.
        layers: Number of blocks per stage, passed to ``QuantizableResNet``.
        pretrained: If True, download and load the checkpoint for ``arch``.
        progress: If True, show a download progress bar.
        **kwargs: Extra ``ResNet`` constructor arguments (e.g. ``num_classes``).

    Note:
        ``load_state_dict`` is strict, so the pretrained checkpoint (presumably
        1000-class ImageNet weights — confirm) only loads when the classifier
        shape matches, i.e. with the default ``num_classes``.
    """
    model = QuantizableResNet(block, layers, **kwargs)

    if pretrained:
        # Look up the checkpoint for the requested architecture instead of
        # always downloading the resnet18 weights.
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)

        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""Quantizable ResNet-18 from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, load weights pre-trained on ImageNet
        progress (bool): If True, show a download progress bar on stderr
        **kwargs: forwarded to the ResNet constructor (e.g. ``num_classes``)
    """
    blocks_per_stage = [2, 2, 2, 2]
    return _resnet('resnet18', QuantizableBasicBlock, blocks_per_stage,
                   pretrained, progress, **kwargs)
    

#### Fusing
Fusing replaces a sequence of operations with a single, faster operation whose parameters are fixed (frozen) or approximated. Here we fuse each convolution with the batch-norm layer that follows it (and with the ReLU, where one follows).

You can read more about it here:
http://learnml.today/speeding-up-model-with-fusing-batch-normalization-and-convolution-3

So above we added a `fuse_model` method that fuses layers, and we replaced the residual addition with `nn.quantized.FloatFunctional()`.

We also use `ReLU(inplace=False)` everywhere, which the quantization module needs.


Now let's try quantization

In [0]:
class MeanMetric(object):
    """Running mean of a scalar metric (accuracy, loss, ...) with a display name.

    ``str(metric)`` renders ``"<name>=<mean rounded to 2 decimals>"``. When no
    value has been recorded yet the mean is ``nan``, so printing an empty
    metric no longer raises ``ZeroDivisionError``.
    """
    def __init__(self, name):
        self.name = name
        self.clear()

    def clear(self):
        """Reset the accumulated sum (``val``) and number of observations."""
        self.val = 0
        self.count = 0

    def update(self, val):
        """Record one observation."""
        self.val += val
        self.count += 1

    @property
    def avg(self):
        """Mean of the observations so far, or ``nan`` if there are none."""
        if self.count == 0:
            return float('nan')
        return self.val / self.count

    def __str__(self):
        return f'{self.name}={round(self.avg, 2)}'

def print_size_of_model(model):
    """Return the serialized size of ``model``'s state_dict in megabytes.

    Despite the name, nothing is printed: the function returns the tuple
    ``('Size (MB):', size_in_mb)`` and callers print it. The state_dict is
    serialized into an in-memory buffer, so no temporary file is created in
    (or left behind in, on error) the working directory, and an existing
    ``temp.p`` is never overwritten.
    """
    import io

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return ('Size (MB):', buffer.getbuffer().nbytes / 1e6)

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, of ``output`` scores against ``target`` labels.

    Args:
        output: Scores of shape (batch, num_classes); top-k is taken over dim 1.
        target: Class indices of shape (batch,).
        topk: The k values to evaluate; each must be <= num_classes.

    Returns:
        A list with one 1-element tensor per k, holding the percentage of
        samples whose target is among the k highest scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` derives from the transposed ``pred`` and may be
            # non-contiguous, where ``view(-1)`` can raise; ``reshape`` copies
            # only when needed.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def evaluate(model, criterion, data_loader, device, eval_steps=10):
    """Measure top-1 / top-5 accuracy of ``model`` on (a prefix of) ``data_loader``.

    The model is switched to eval mode and moved to ``device`` in place.

    Args:
        model: Module to evaluate.
        criterion: Accepted for signature compatibility with the training
            helpers; not used, because only accuracy is reported.
        data_loader: Iterable of ``(image, target)`` batches.
        device: Device to run on.
        eval_steps: Maximum number of batches to evaluate (default 10).
            ``None`` evaluates the whole loader.

    Returns:
        ``(top1, top5)`` MeanMetric objects holding the mean of the per-batch
        accuracies (in percent).
    """
    model.eval().to(device)
    top1 = MeanMetric('Acc@1')
    top5 = MeanMetric('Acc@5')
    with torch.no_grad():
        for step, (image, target) in enumerate(data_loader, start=1):
            image, target = image.to(device), target.to(device)
            output = model(image)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0].item())
            top5.update(acc5[0].item())
            if eval_steps is not None and step >= eval_steps:
                break

    return top1, top5

def load_model(model_file, classes=100):
    """Build a quantizable ResNet-18 with ``classes`` outputs and load ``model_file`` into it.

    ``map_location='cpu'`` lets a checkpoint saved from a CUDA model load on a
    CPU-only machine; ``load_state_dict`` copies the tensors into the freshly
    built model either way, so the result is the same on GPU hosts.

    Note: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    model = resnet18(num_classes=classes)
    state_dict = torch.load(model_file, map_location='cpu')
    model.load_state_dict(state_dict)
    return model

In [0]:
def get_loaders(data_path, batch_size=8, root='./cifar100'):
    """Create CIFAR-100 train and test DataLoaders.

    Public access to ImageNet is restricted, so this notebook uses CIFAR-100
    (downloaded on first use) as a stand-in dataset.

    Args:
        data_path: Kept for backward compatibility; it is not read.
        batch_size: Batch size of both loaders.
        root: Directory where CIFAR-100 is stored / downloaded to.

    Returns:
        ``(train_loader, test_loader)``. The train loader samples randomly and
        applies random horizontal flips; the test loader iterates in order.
    """
    # Standard ImageNet channel mean / std.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    dataset = torchvision.datasets.CIFAR100(
        root,
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
    )

    dataset_test = torchvision.datasets.CIFAR100(
        root,
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])
    )

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size,
        sampler=torch.utils.data.RandomSampler(dataset))

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=batch_size,
        sampler=torch.utils.data.SequentialSampler(dataset_test))

    return data_loader, data_loader_test

# Loads CIFAR-100; get_loaders does not read the path argument.
train, test = get_loaders(data_path='imagenet_1k')

0it [00:00, ?it/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./cifar100/cifar-100-python.tar.gz


169009152it [00:04, 37868417.69it/s]                               


Extracting ./cifar100/cifar-100-python.tar.gz to ./cifar100
Files already downloaded and verified


In [0]:
def train_one_epoch(model, criterion, optimizer, data_loader, device, log_steps=30):
    """Train ``model`` for one pass over ``data_loader``.

    Every ``log_steps`` batches it prints the loss / top-1 / top-5 averaged over
    those batches. After the epoch it calls ``evaluate`` on the *training*
    loader with evaluate's default number of steps, so the last line is a quick
    sanity check on a few batches, not a full-dataset accuracy.

    Returns:
        The same ``model`` object (left in eval mode by ``evaluate``).
    """
    model.train().to(device)
    top1 = MeanMetric('Acc@1')
    top5 = MeanMetric('Acc@5')
    avgloss = MeanMetric('Loss')

    for step, (image, target) in enumerate(data_loader, start=1):
        image, target = image.to(device), target.to(device)
        output = model(image)

        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        top1.update(acc1[0].item())
        top5.update(acc5[0].item())
        avgloss.update(loss.item())

        if step % log_steps == 0:
            print(f'Training {step}: {avgloss} {top1} {top5}')
            # Windowed metrics: each log line covers only the last log_steps batches.
            avgloss.clear()
            top1.clear()
            top5.clear()

    top1, top5 = evaluate(model, criterion, data_loader, device)
    print(f'Train subset eval: {top1} {top5}')
    return model

Let's train our model on the CIFAR-100 dataset for a bit to get some baseline accuracy.

In [0]:
# Train a randomly initialised quantizable ResNet-18 on CIFAR-100 on the GPU.
model = resnet18(pretrained=False, num_classes=100).to('cuda')
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=5e-4, momentum=0.9)

num_epochs = 20
for epoch in range(num_epochs):
    print(f'Epoch: {epoch} ######################################################')
    model = train_one_epoch(model, criterion, optimizer, train, device='cuda')

Epoch: 0 ######################################################
Training 30: Loss=4.9 Acc@1=0.83 Acc@5=5.0
Training 60: Loss=4.82 Acc@1=2.08 Acc@5=7.92
Training 90: Loss=4.68 Acc@1=2.5 Acc@5=11.67
Training 120: Loss=4.68 Acc@1=0.83 Acc@5=9.58
Training 150: Loss=4.64 Acc@1=0.42 Acc@5=9.17
Training 180: Loss=4.63 Acc@1=3.75 Acc@5=10.83
Training 210: Loss=4.67 Acc@1=2.92 Acc@5=11.25
Training 240: Loss=4.62 Acc@1=3.75 Acc@5=10.42
Training 270: Loss=4.53 Acc@1=3.75 Acc@5=11.25
Training 300: Loss=4.39 Acc@1=5.42 Acc@5=16.25
Training 330: Loss=4.51 Acc@1=3.33 Acc@5=15.42
Training 360: Loss=4.43 Acc@1=4.17 Acc@5=17.08
Training 390: Loss=4.55 Acc@1=4.17 Acc@5=12.92
Training 420: Loss=4.47 Acc@1=5.42 Acc@5=17.08
Training 450: Loss=4.41 Acc@1=4.58 Acc@5=16.67
Training 480: Loss=4.43 Acc@1=4.17 Acc@5=15.83
Training 510: Loss=4.32 Acc@1=5.42 Acc@5=21.25
Training 540: Loss=4.38 Acc@1=4.58 Acc@5=18.33
Training 570: Loss=4.35 Acc@1=4.17 Acc@5=18.33
Training 600: Loss=4.41 Acc@1=3.33 Acc@5=13.75
Traini

In [0]:
# Persist the trained float weights; later cells rebuild models from this file.
float_state = model.state_dict()
torch.save(float_state, 'resnet18.pth')


In [0]:
# Float baseline: accuracy and latency on CPU (evaluate moves the model to CPU).
torch.manual_seed(42)
train, test = get_loaders('imagenet_1k')
started = time.time()
top1, top5 = evaluate(model, criterion, test, 'cpu')
print(f'Finall accuracy: {top1} {top5} in {time.time() - started} sec')

### Dynamic quantization
Currently supported only for `nn.Linear` and `nn.LSTM` modules.

In [0]:
# Serialized size of the float32 model: the reference point for the quantized models.
base_model_size = print_size_of_model(model)
print(base_model_size)

In [0]:
import torch.quantization

base = load_model('resnet18.pth').eval().to('cpu')
# quantize_dynamic swaps nn.Linear modules for dynamically quantized versions
# with qint8 weights; the last modules printed below show the replaced Linear.
dyn_quantized_model = torch.quantization.quantize_dynamic(base, {nn.Linear}, dtype=torch.qint8)
print(list(dyn_quantized_model.modules())[-3:])

# Separate name, so the float baseline stored in base_model_size is kept.
dyn_model_size = print_size_of_model(dyn_quantized_model)
print(dyn_model_size)

torch.manual_seed(42)
train, test = get_loaders('imagenet_1k')
st = time.time()
top1, top5 = evaluate(dyn_quantized_model, criterion, test, 'cpu')
print(f'Finall accuracy: {top1} {top5} in {time.time() - st} sec')


### Static quantization

Static quantization relies on Observer classes, whose job is to find the best parameters (scale and zero point) for mapping float32 values into int8. Weight ranges are known ahead of time, but activation ranges depend on the input, so after calling the `prepare` method we run some data through the model (calibration) to let the observers collect activation statistics.

In [0]:
torch.manual_seed(42)
stat_quant = load_model('resnet18.pth').eval().to('cpu')

# Fuse conv, bn and relu (the model is in eval mode here).
stat_quant.fuse_model()
# BatchNorm layers no longer appear: after fusion they are replaced by Identity.
print(stat_quant)

# Default post-training quantization config for the 'fbgemm' backend.
stat_quant.qconfig = torch.quantization.get_default_qconfig('fbgemm')
print(stat_quant.qconfig)

# prepare() inserts observers; running a few batches through the model
# (calibration) lets them record the ranges used to pick int8 parameters.
torch.quantization.prepare(stat_quant, inplace=True)
train, test = get_loaders('imagenet_1k')
# Calibrate on training batches, not on the test batches evaluated below,
# so the evaluation data does not leak into the quantization parameters.
evaluate(stat_quant, criterion, train, 'cpu')
# Now convert the observed model to int8.
torch.quantization.convert(stat_quant, inplace=True)

stat_model_size = print_size_of_model(stat_quant)
print(stat_model_size)

st = time.time()
top1, top5 = evaluate(stat_quant, criterion, test, 'cpu')
print(f'Finall accuracy: {top1} {top5} in {time.time() - st} sec')

So we get a small accuracy drop, but the model is 4 times smaller and inference is about twice as fast. Not bad for a few lines of code :)

## Quantization-aware training
This works pretty simply. QAT makes all weights and activations *fake quantized*: in the forward pass, float values are rounded to mimic int8 values, while gradients are still computed in floating point. During training the weights therefore adapt to the int8 rounding, and after training we convert the model to real int8.

In [0]:
torch.manual_seed(42)
qat_model = load_model('resnet18.pth')
# NOTE(review): load_model leaves the model in training mode, so fusion runs in
# train mode here; newer PyTorch versions require fuse_modules_qat for this — confirm.
qat_model.fuse_model()

qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(qat_model, inplace=True)
# Build the optimizer from the parameters of the final (fused, fake-quantized) model.
optimizer = torch.optim.SGD(qat_model.parameters(), momentum=0.9, lr=5e-4)

train, test = get_loaders('imagenet_1k')
# Train and check accuracy after each epoch
for epoch in range(5):
    # train_one_epoch moves the model back to cuda and into train mode.
    train_one_epoch(qat_model, criterion, optimizer, train, 'cuda')
    if epoch > 3:
        # Freeze quantizer parameters
        qat_model.apply(torch.quantization.disable_observer)
    if epoch > 2:
        # Freeze batch norm mean and variance estimates
        qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    # Convert a copy (inplace=False) to int8 and check its accuracy.
    quantized_model = torch.quantization.convert(qat_model.eval().to('cpu'), inplace=False)
    quantized_model.eval()
    top1, top5 = evaluate(quantized_model, criterion, test, 'cpu')
    print(f'Epoch {epoch}: {top1} {top5}')

# Report size and accuracy of the QAT result.
qat_model_size = print_size_of_model(quantized_model)
print(qat_model_size)

st = time.time()
top1, top5 = evaluate(quantized_model, criterion, test, 'cpu')
print(f'Finall accuracy: {top1} {top5} in {time.time() - st} sec')

## Nvidia APEX

This library gives you an additional speed-up with just a few lines of code (or more, with more advanced techniques).
It is not quantization-aware training but a related idea: it lets you use mixed float32/float16 precision, or pure float16 precision.

This is useful because many edge devices have much better float16 support than float32 support, and recent Nvidia cards are also well optimized for float16. Here is an example for the Nvidia T4 card:

<img src='https://github.com/learnml-today/object-detection-with-pytorch/blob/master/imgs/nvidiat4.png?raw=true' />

As you can see, the difference between float32 and float16 is more than 8x.

In [0]:
# apex provides `amp` (automatic mixed precision) used by the cells below.
try:
    from apex import amp
except ImportError as err:
    # Chain the original error so the underlying import failure stays visible.
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") from err



In [0]:
def train_one_epoch_with_apex(model, criterion, optimizer, data_loader, device, log_steps=30):
    """Train for one epoch using apex AMP loss scaling for the backward pass.

    Same loop as the plain training helper, except that the loss is passed
    through ``amp.scale_loss`` before ``backward()`` to preserve small gradient
    values. ``model`` and ``optimizer`` are expected to come from
    ``amp.initialize``.

    Every ``log_steps`` batches it prints metrics averaged over those batches.
    After the epoch it calls ``evaluate`` on the *training* loader with
    evaluate's default number of steps, so the last line is a quick sanity
    check, not a full-dataset accuracy.

    Returns:
        The same ``model`` object (left in eval mode by ``evaluate``).
    """
    model.train().to(device)
    top1 = MeanMetric('Acc@1')
    top5 = MeanMetric('Acc@5')
    avgloss = MeanMetric('Loss')

    for step, (image, target) in enumerate(data_loader, start=1):
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()

        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        optimizer.step()

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        top1.update(acc1[0].item())
        top5.update(acc5[0].item())
        avgloss.update(loss.item())

        if step % log_steps == 0:
            print(f'Training {step}: {avgloss} {top1} {top5}')
            # Windowed metrics: each log line covers only the last log_steps batches.
            avgloss.clear()
            top1.clear()
            top5.clear()

    top1, top5 = evaluate(model, criterion, data_loader, device)
    print(f'Train subset eval: {top1} {top5}')
    return model

In [0]:
# apex opt level O1: mixed float32 and float16 precision.
opt_level = 'O1'
model = resnet18(pretrained=False, num_classes=100).to('cuda')
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=5e-4, momentum=0.9)
# amp.initialize can also take a loss_scale to preserve small gradient values;
# none is passed here, so dynamic loss scaling is used.
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

started = time.time()
for epoch in range(1):
    print(f'Epoch: {epoch} ######################################################')
    model = train_one_epoch_with_apex(model, criterion, optimizer, train, device='cuda')
print('Total time:', time.time() - started)

In [0]:
# Evaluate the AMP-trained model on the GPU, then report its serialized size.
torch.manual_seed(42)
train, test = get_loaders('imagenet_1k')
started = time.time()
top1, top5 = evaluate(model, criterion, test, 'cuda')
print(f'Finall accuracy: {top1} {top5} in {time.time() - started} sec')

base_model_size = print_size_of_model(model)
print(base_model_size)

As we can see, the model size barely changes, but inference becomes faster.