### Prerequisites

In [2]:
import torch
import torch.nn.functional as F
from torchvision import models, transforms, datasets
from copy import deepcopy
import requests
from PIL import Image
from resnet_cifar import Trainer, cifar_dataloader


def load_img(url):
    """Load an image from a URL or local path as a normalized single-image batch.

    The image is resized to 256 (bicubic), center-cropped to 224x224, converted
    to a tensor and normalized with the ImageNet mean/std defined below.

    Args:
        url: An ``http://`` / ``https://`` address, or a filesystem path.

    Returns:
        A float tensor of shape (1, 3, 224, 224).

    Raises:
        requests.HTTPError: If the download answers with an error status.
    """
    import io

    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    transform = transforms.Compose([
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        ])
    if url.startswith(("http://", "https://")):
        response = requests.get(url, timeout=30)
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        # `.content` is fully downloaded and content-decoded, unlike `.raw`.
        img = Image.open(io.BytesIO(response.content))
    else:
        img = Image.open(url)
    # Normalize() above expects exactly three channels, so force RGB
    # (grayscale, palette and RGBA images would otherwise fail).
    img = transform(img.convert("RGB")).unsqueeze(0)
    return img


def get_predictions(outp):
    """Print the top-1 class name and its softmax probability for each sample.

    Args:
        outp: Logits tensor with samples along dim 0 and the ten classes
            listed in ``cls_idx`` along dim 1.

    Prints one ``<label> ( <probability> )`` line per sample; returns None.
    """
    cls_idx = {
        0: 'airplane',
        1: 'automobile',
        2: 'bird',
        3: 'cat',
        4: 'deer',
        5: 'dog',
        6: 'frog',
        7: 'horse',
        8: 'ship',
        9: 'truck'}
    probs = F.softmax(outp, dim=1)
    scores, indices = torch.topk(probs, 1)
    # Flatten the (batch, 1) top-k results so any batch size is handled;
    # a single-sample batch prints exactly one line, as before.
    for score, idx in zip(scores.flatten().tolist(), indices.flatten().tolist()):
        print(cls_idx[idx], '(', score, ')')


def print_sizeof(model):
    """Print the total size of the model's parameters in megabytes (1e6 bytes).

    Only tensors yielded by ``model.parameters()`` are counted.
    """
    n_bytes = sum(param.numel() * param.element_size() for param in model.parameters())
    print("Model size: ", n_bytes / 1e6, " MB")


In [None]:
# Show which PyTorch build this notebook is running against.
print(torch.__version__)

## Flowchart for using Quantization in PyTorch

<img src="./img/quantization-flowchart.png" width="700" />

## 10M+ Parameters?

<img src="./img/flowchart-check1.png" width="300" />

Quantization works best on models with 10M+ parameters. [[1](https://arxiv.org/pdf/1806.08342.pdf)]

Large models are more robust to quantization error. Overparameterized models generally have more degrees of freedom and can afford the precision drops with quantization.

As with most rules of thumb, your mileage may vary. Quantization is an active area of research, and this threshold might become more permissive.

In [None]:
def is_large_enough(model):
    """Check the model against the 10M-parameter rule of thumb.

    Returns:
        A tuple ``(exceeds_10m, millions)``: whether the parameter count is
        strictly greater than 1e7, and the count floor-divided by 1e6 (a float).
    """
    param_count = sum(tensor.numel() for tensor in model.parameters())
    millions = param_count // 1e6
    return param_count > 1e7, millions

# Apply the parameter-count check to a few torchvision architectures.
for label, build_model in (
    ("resnet18: ", models.resnet18),
    ("resnet50: ", models.resnet50),
    ("mobilenet_large: ", models.mobilenet_v3_large),
):
    print(label, is_large_enough(build_model()))
print()

## FP32-pretrained checkpoint?

<img src="./img/flowchart-check2.png" width="300" />

Quantized inference works best on models that were originally trained in FP32 (like all pretrained models in PyTorch (vision, audio and text)).

This allows the model to learn many fine-grained parameters that can later be quantized for inference.

Even Quantization-Aware Training (more on this below) uses FP32 arithmetic to train the parameters.

In [None]:
# Load a 10-class ResNet-18 from a hosted checkpoint (the file name indicates
# CIFAR weights); the state dict is mapped to CPU.

# Alternative kept for reference: the ResNet-50 variant of the same checkpoint.
# weights = torch.hub.load_state_dict_from_url("https://quantization-workshop.s3.amazonaws.com/resnet50_cifar_weights.pth", map_location="cpu")
# resnet = models.resnet50(pretrained=False, num_classes=10)

weights = torch.hub.load_state_dict_from_url("https://quantization-workshop.s3.amazonaws.com/resnet18_cifar_weights.pth", map_location="cpu")
resnet = models.resnet18(pretrained=False, num_classes=10)
resnet.load_state_dict(weights)

# Inspect the dtype of the first parameter to confirm the checkpoint's precision.
param_0 = next(iter(resnet.parameters()))
print("Model precision: ", param_0.dtype)

In [None]:
# Put the model in evaluation mode before it is profiled and quantized below.
resnet.eval()

## Using a supported backend?

<img src="./img/flowchart-check3.png" width="300" />

Backend refers to the hardware-specific kernels that support quantization. This controls the numerics engine that does the integer arithmetic.

`torch.backends.quantized.engine` specifies the backend to be used.

Using an incorrect backend engine for your hardware will result in (much) slower inference.

In [None]:
import platform

# Pick the quantized kernel library that matches the host CPU architecture.
# platform.machine() is preferred: platform.processor() may be an empty string
# or a vendor description on some systems. Lower-casing lets 'AMD64' match.
chip = (platform.machine() or platform.processor()).lower()

if chip.startswith('arm') or chip == 'aarch64':
    backend = 'qnnpack'
elif chip in ['x86_64', 'amd64', 'i386', 'i686', 'x86']:
    backend = 'fbgemm'
else:
    raise SystemError(f"Backend is not supported for CPU architecture {chip!r}")

# Give a clear message if this PyTorch build was compiled without the engine.
if backend not in torch.backends.quantized.supported_engines:
    raise SystemError(
        f"Backend {backend!r} is not available in this PyTorch build "
        f"(supported: {torch.backends.quantized.supported_engines})"
    )

print(f"Using {backend} backend engine for {chip} CPU")

torch.backends.quantized.engine = backend

## Profile FP32 model inference

<img src="./img/flowchart-check4.png" width="300" />

Let's establish a baseline for model size, inference latency and accuracy

In [45]:
import os
    
def print_size_of_model(model):
    """Print the on-disk size, in MB (1e6 bytes), of the model's serialized state_dict.

    The checkpoint is written into a private temporary directory, so nothing
    in the working directory is overwritten and the file is cleaned up even
    if saving fails.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original file name so the serialized archive is unchanged.
        checkpoint_path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), checkpoint_path)
        size_mb = os.path.getsize(checkpoint_path) / 1e6
    print('Size (MB):', size_mb)

def profile(model):
    """Print the model's serialized size, a separator, then run the Trainer evaluation.

    Evaluation is capped via ``max_batch=30``; what it reports is determined by
    ``Trainer.evaluate`` in ``resnet_cifar``.
    """
    print_size_of_model(model)
    print("=" * 20)
    # NOTE(review): -1 is Trainer's second positional argument, presumably
    # `epochs` (unused when only evaluating) - confirm in resnet_cifar.
    evaluator = Trainer(model, -1)
    evaluator.evaluate(max_batch=30)  # latency + accuracy on CIFAR test set

In [None]:
# FP32 baseline: serialized size plus the Trainer evaluation output.
print("Resnet FP32 profile:")
profile(resnet)

## What are the predominant layers in the model?

<img src="./img/flowchart-check5.png" width="300" />

While dynamic quantization has more overhead than static quantization, some operators (like recurrent layers) aren't supported by static quantization. (See [Operator coverage](https://pytorch.org/docs/stable/quantization.html#:~:text=these%20quantization%20types.-,Operator%20coverage,-varies%20between%20dynamic)).

Knowing which layers are in our model can inform our quantization strategy.


#### Thumb rule

* For recurrent and transformer layers, use Dynamic quantization.
* For linear layers, you can use either Dynamic or Static quantization.
* For everything else, use Static quantization.

In [None]:
def optimal_quant_strategy(model):
    """Print how many layers and weights each quantization mode could cover.

    Recurrent layers and ``Linear`` count towards dynamic quantization;
    convolutions (any class whose name contains ``Conv``) and ``Linear`` count
    towards static quantization. Only weight tensors owned directly by a
    module (parameters whose name starts with ``weight``) are tallied, so
    biases are excluded and container modules are not counted as layers.

    Args:
        model: Module whose ``modules()`` are inspected.
    """
    from collections import Counter

    recurrent_types = {'RNN', 'LSTM', 'GRU', 'LSTMCell', 'RNNCell', 'GRUCell'}

    layer_counts = Counter([type(x).__name__ for x in model.modules()])
    print("Model consists of: ", layer_counts)

    dyn = [0, 0]
    stat = [0, 0]

    for m in model.modules():
        name = type(m).__name__
        is_linear = name == 'Linear'
        use_dynamic = is_linear or name in recurrent_types
        use_static = is_linear or 'Conv' in name
        if not (use_dynamic or use_static):
            continue

        # Recurrent modules have no `.weight` attribute (their tensors are
        # named weight_ih_*/weight_hh_*), so sum every directly-owned weight
        # parameter instead of reading `m.weight`.
        params = sum(
            p.numel()
            for p_name, p in m.named_parameters(recurse=False)
            if p_name.startswith('weight')
        )
        if params == 0:
            # Wrapper/container whose class name matched but which owns no weights.
            continue

        if use_dynamic:
            dyn[0] += 1
            dyn[1] += params
        if use_static:
            stat[0] += 1
            stat[1] += params
    print()
    print("Dynamic quantization")
    print("====================")
    print(f"Layers: {dyn[0]} || Parameters: {format(dyn[1], 'g')}")
    print()
    print("Static quantization")
    print("====================")
    print(f"Layers: {stat[0]} || Parameters: {format(stat[1], 'g')}")
    

# Tally the loaded ResNet's layer types to guide the choice of quantization mode.
optimal_quant_strategy(resnet)

## Try Dynamic Quantization

<img src="./img/flowchart-check6.png" width="300" />

[Dynamic Quantization API](https://pytorch.org/docs/stable/generated/torch.quantization.quantize_dynamic.html?highlight=quantize_dynamic#torch.quantization.quantize_dynamic)

[Dynamic Quantization Tutorial](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html)

In [None]:
from torch.quantization.quantize_fx import prepare_fx, convert_fx

# Apply PyTorch's default dynamic qconfig to the whole model via the "" (global) key.
dynamic_qconfig = torch.quantization.default_dynamic_qconfig
qconfig_dict = {
    # Global Config
    "": dynamic_qconfig
}

# Quantize a deep copy so the FP32 `resnet` stays an untouched baseline for
# the later cells that reuse it (static quantization, weight comparison, QAT).
model_prepared = prepare_fx(deepcopy(resnet), qconfig_dict)
dynamic_resnet = convert_fx(model_prepared)

### Evaluate performance of dynamic-quantized Resnet model

In [None]:
# Same profile as the FP32 baseline, now on the dynamically quantized model.
print("Resnet Dynamic-Quant Profile:")
profile(dynamic_resnet)

## Try Static Quantization

<img src="./img/flowchart-check7_1.png" width="300" />
<br>
<img src="./img/flowchart-check7_2.png" width="300" />

Static quantization has faster inference than dynamic quantization because it eliminates the float<->int conversion costs between layers

### Manual approach - using Eager Mode

Explicitly perform the following steps:

<img src="./img/ptq-flowchart.png" width="300" />

* Manually identify sequence of fusable modules
* Manually insert stubs to quantize and dequantize activations
* Functional ops (e.g. `torch.nn.functional.linear`) aren't supported

[Module Fusion Tutorial](https://pytorch.org/tutorials/recipes/fuse.html)

[Static Quantization (Eager Mode) Tutorial](https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html)

### Easier approach - using FX Graph Mode

<img src="./img/ptq-fx-flowchart.png" width="300" />

* Just 2 function calls: `prepare_fx` and `convert_fx`
* Automates all the above steps under the hood using `torch.fx`

[`prepare_fx` API](https://pytorch.org/docs/stable/generated/torch.quantization.quantize_fx.prepare_fx.html#torch.quantization.quantize_fx.prepare_fx)

[`convert_fx` API](https://pytorch.org/docs/stable/generated/torch.quantization.quantize_fx.convert_fx.html#torch.quantization.quantize_fx.convert_fx)

[Static Quantization (FX Graph Mode) Tutorial](https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html)

#### QConfig

In FX Quantization, the `qconfig_dict` offers fine-grained control of the model's quantization process.
Setting a `qconfig=None` skips quantization for that module.

```python
qconfig_dict = {
    # Global Config
    "": qconfig,

    # Module-specific config (by class)
    "object_type": [
        (torch.nn.Conv2d, qconfig),
        (torch.nn.functional.add, None),  # skips quantization for this module
        ...,
        ],
    
    # Module-specific config (by name)
    "module_name": [
        ("foo.bar", qconfig),
        ...,
    ],
}
```

In [32]:
# Default static-quantization qconfig for the backend chosen earlier, applied
# to the whole model through the "" (global) key.
# Note: this rebinds `qconfig_dict`, which previously held the dynamic config.
static_qconfig = torch.quantization.get_default_qconfig(backend)
qconfig_dict = {
    # Global Config
    "": static_qconfig,
}

In [None]:
from torchvision import datasets, transforms
from torch.quantization.quantize_fx import prepare_fx, convert_fx


def static_quantize_vision_model(model, qconfig_dict, n_calib_batches=30):
    """Post-training static quantization of ``model`` with FX graph mode.

    A deep copy of ``model`` is prepared, calibrated on batches from
    ``cifar_dataloader()`` and converted; the model passed in is not modified.

    Args:
        model: FP32 module to quantize.
        qconfig_dict: FX quantization configuration passed to ``prepare_fx``.
        n_calib_batches: Maximum number of batches fed through the prepared
            model for calibration (default 30, the previous fixed value).

    Returns:
        The converted (quantized) model.
    """
    # NOTE(review): only the second loader returned by cifar_dataloader() is
    # used; if that is the test split, calibration and evaluation share data.
    _, data = cifar_dataloader()

    # Copy first so the caller's FP32 model stays a clean baseline, and make
    # sure the copy is in eval mode before post-training quantization.
    model_fp32 = deepcopy(model).eval()
    mp = prepare_fx(model_fp32, qconfig_dict)

    # Calibration only needs forward passes; skip autograd bookkeeping.
    with torch.no_grad():
        for c, (x, _) in enumerate(data):
            if c == n_calib_batches:
                break
            mp(x)

    mc = convert_fx(mp)
    return mc

# Calibrate and convert `resnet` using the global static qconfig defined above.
static_resnet = static_quantize_vision_model(resnet, qconfig_dict)

### Evaluate performance of static-quantized model

In [None]:
# Same profile as the FP32 baseline, now on the statically quantized model.
print("Resnet Static-Quant Profile:")
profile(static_resnet)

### Sensitivity Analysis - Which quantized layers affect accuracy the most?

<img src="./img/flowchart-check8.png" width="300" />
<br>

Some layers are more sensitive to precision drops than others. PyTorch provides tools to help with this analysis under the Numeric Suite.

[Numeric Suite Tutorial](https://pytorch.org/tutorials/prototype/numeric_suite_tutorial.html)


In [35]:
import torch.quantization._numeric_suite as ns

def SNR(x, y):
    """Signal-to-noise ratio in dB of ``y`` as an approximation of ``x``.

    Computed as 20 * log10(||x|| / ||x - y||); higher is better.
    """
    return 20 * torch.log10(torch.norm(x) / torch.norm(x - y))

def compare_model_weights(float_model, quant_model):
    """Compute the SNR of every weight paired up by Numeric Suite.

    Args:
        float_model: The FP32 reference model.
        quant_model: Its quantized counterpart.

    Returns:
        Dict mapping each key returned by ``ns.compare_weights`` to the SNR
        between the float weight and the dequantized quantized weight.
    """
    paired_weights = ns.compare_weights(float_model.state_dict(), quant_model.state_dict())
    return {
        name: SNR(pair['float'], pair['quantized'].dequantize())
        for name, pair in paired_weights.items()
    }

Layer-by-layer comparison of model weights 

<img src="./img/ns.png" width="300" />

In [None]:
# Per-weight SNR between the FP32 model and its statically quantized version.
snrd = compare_model_weights(resnet, static_resnet)
print(snrd)

In [None]:
def topk_sensitive_layers(snr_dict, k):
    """Return the ``k`` lowest-SNR entries, ordered from lowest SNR upwards.

    Every occurrence of ``'.weight'`` is removed from each key.

    Args:
        snr_dict: Mapping of parameter name to SNR value.
        k: Number of entries to keep.
    """
    ranked = sorted(snr_dict.items(), key=lambda item: item[1])
    worst = ranked[:k]
    return {param_name.replace('.weight', ''): snr for param_name, snr in worst}
    
# Names of the five lowest-SNR entries; .keys() yields a dict view, not a list.
sensitive_layers = topk_sensitive_layers(snrd, 5).keys()
print(sensitive_layers)

## Selective Static Quantization

In [None]:
# Recompute the five lowest-SNR layer names from the weight comparison.
sensitive_layers = topk_sensitive_layers(snrd, 5).keys()

# Quantize everything with the static qconfig except the sensitive modules,
# which are given a None qconfig.
# Note: this rebinds the shared `qconfig_dict` name used by earlier cells.
qconfig_dict = {
    # Global Config
    "": static_qconfig,

    # Disable for sensitive modules
    "module_name": [(m, None) for m in sensitive_layers],
}

sel_static_resnet = static_quantize_vision_model(resnet, qconfig_dict)

### Evaluate performance of selective static-quantized model

In [None]:
# Same profile as the FP32 baseline, now on the selectively quantized model.
print("Resnet Selective-Static-Quant Profile:")
profile(sel_static_resnet)

## Quantization-Aware Training

<img src="./img/flowchart-check9.png" width="300" />

In [40]:
from torch.quantization.quantize_fx import prepare_qat_fx
from resnet_cifar import Trainer, cifar_dataloader

# NOTE(review): `sensitive_layers` is recomputed here but not referenced by
# the rest of this cell; the QAT config below is global only.
sensitive_layers = topk_sensitive_layers(snrd, 5).keys()

# Default QAT qconfig for the backend chosen earlier, applied to the whole
# model through the "" (global) key. This rebinds the shared `qconfig_dict`.
qat_qconfig = torch.quantization.get_default_qat_qconfig(backend)
qconfig_dict = {
    # Global Config
    "": qat_qconfig,
}

def qat_vision_model(model, qconfig):
    """Run quantization-aware training on a copy of ``model`` and convert it.

    Args:
        model: FP32 module to start from; it is deep-copied and left unmodified.
        qconfig: FX quantization configuration passed to ``prepare_qat_fx``
            (the notebook passes its ``qconfig_dict``).

    Returns:
        The converted (quantized) model after training.
    """
    # Train a copy: calling .train() on the caller's model would flip its mode,
    # and the prepared graph may share parameters with it, so training could
    # otherwise alter the FP32 baseline used by other cells.
    model_to_train = deepcopy(model)
    model_to_train.train()
    mp = prepare_qat_fx(model_to_train, qconfig)

    # training loop
    trainer = Trainer(mp, epochs=20)
    trainer.run_epoch()

    mc = convert_fx(mp)
    return mc

# Run the 20-epoch QAT loop defined above starting from `resnet`, keeping the converted model.
qat_resnet = qat_vision_model(resnet, qconfig_dict)

Files already downloaded and verified
Files already downloaded and verified
Epoch: 0
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Epoch: 6
Epoch: 7
Epoch: 8
Epoch: 9
Epoch: 10
Epoch: 11
Epoch: 12
Epoch: 13
Epoch: 14
Epoch: 15
Epoch: 16
Epoch: 17
Epoch: 18
Epoch: 19


### Evaluate performance of QAT model

In [42]:
# Same profile as the FP32 baseline, now on the QAT-converted model.
print("Resnet QAT Profile:")
profile(qat_resnet)

Resnet QAT Profile:
Size (MB): 11.310369
Files already downloaded and verified
Files already downloaded and verified
Loss: 0.4549180895090103 
Accuracy: 0.8427083333333333
Time taken (1920 CIFAR test samples): 0.3876371383666992
