# Quantization

> Good one: https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-quantization

* https://lightning.ai/blog/4-bit-quantization-with-lightning-fabric/
* https://lightning.ai/blog/8-bit-quantization-with-lightning-fabric/
* https://lightning.ai/docs/fabric/latest/fundamentals/precision.html#quantization-via-bitsandbytes
* https://pytorch.org/blog/introduction-to-quantization-on-pytorch/
* https://huggingface.co/blog/hf-bitsandbytes-integration
* https://pytorch.org/blog/quantization-in-practice/


# 1 - Introduction

It's important to make efficient use of both server-side and on-device compute resources when developing machine learning applications. Quantization supports more efficient deployment on servers and edge devices.

Quantization leverages 8-bit integer (int8) instructions to **reduce model size and run inference faster** (reduced latency), and it can decide whether a model fits into the available resources at all. Even when resources aren't quite so constrained, it may enable you to deploy a larger and more accurate model.

By performing both computations and memory accesses with lower-precision data (usually int8 instead of floating point), quantization enables performance gains in several important areas:

* 4x reduction in model size (relative to 32-bit floats);
* 2-4x reduction in memory bandwidth;
* 2-4x faster inference due to savings in memory bandwidth and faster compute with int8 arithmetic (the exact speed-up varies depending on the hardware, the runtime, and the model).
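As a rough back-of-the-envelope check of the size claim, the sketch below computes weight storage at 32 and 8 bits per parameter. The parameter count of 110 million is a hypothetical value chosen for illustration, and the calculation ignores metadata such as scales and zero points.

```python
def weight_storage_mb(num_params, bits_per_param):
    """Approximate weight storage in megabytes (1 MB = 1e6 bytes)."""
    return num_params * bits_per_param / 8 / 1e6


NUM_PARAMS = 110_000_000  # hypothetical parameter count, for illustration only

fp32_mb = weight_storage_mb(NUM_PARAMS, 32)
int8_mb = weight_storage_mb(NUM_PARAMS, 8)
reduction = fp32_mb / int8_mb

print(f"fp32: {fp32_mb:.0f} MB, int8: {int8_mb:.0f} MB, reduction: {reduction:.0f}x")
```

The bandwidth and latency figures cannot be derived this simply, since they depend on the hardware and runtime.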

Quantization does, however, come with some costs. Fundamentally, **quantization means introducing approximations, and the resulting networks have slightly lower accuracy**. Quantization techniques attempt to minimize the gap between the full floating point accuracy and the quantized accuracy.
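To make the approximation concrete, here is a minimal sketch of affine int8 quantization and its round-trip error. The scale, zero point, and sample values are assumptions chosen for illustration; real frameworks derive them from the data.

```python
def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Map a float to an integer code, clamping to the int8 range."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Map an integer code back to an approximate float."""
    return (q - zero_point) * scale


SCALE, ZERO_POINT = 0.05, 0  # assumed parameters, for illustration only
values = [-1.0, -0.333, 0.0, 0.271, 1.0]

errors = [
    abs(v - dequantize(quantize(v, SCALE, ZERO_POINT), SCALE, ZERO_POINT))
    for v in values
]
print("largest round-trip error:", max(errors))
```

For values inside the representable range, the round-trip error is at most half the scale. Values outside the range are clamped and can be off by much more.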

For information about INT4 Quantization: https://arxiv.org/pdf/2301.12017

# 2 - Quantization in PyTorch

Quantization was designed to fit into the PyTorch framework. This means that:

1. PyTorch has data types corresponding to [quantized tensors](https://github.com/pytorch/pytorch/wiki/Introducing-Quantized-Tensor), which share many of the features of "normal" tensors.

2. One can write kernels with quantized tensors, much like kernels for floating point tensors, to customize their implementation. PyTorch supports quantized modules for common operations as part of the `torch.nn.quantized` and `torch.nn.quantized.dynamic` namespaces.

3. Quantization is compatible with the rest of PyTorch: quantized models are traceable and scriptable. The quantization method is virtually identical for both server and mobile backends. One can easily mix quantized and floating point operations in a model.

4. Mapping of floating point tensors to quantized tensors is customizable with user-defined observer/fake-quantization blocks. PyTorch provides default implementations that should work for most use cases.

<table>
    <tr>
        <td><img src="./images_1/torch_quantization.png" width="700"/></td>
    </tr>
</table>

## 2.1 - The Three Modes of Quantization Supported in PyTorch

### 2.1.1 - Dynamic Quantization

The easiest method of quantization PyTorch supports is called **dynamic quantization**. This involves not just converting the weights to `int8`, as happens in all quantization variants, but also converting the activations to `int8` on the fly, just before doing the computation (hence "dynamic").

The computations will thus be performed using efficient `int8` matrix multiplication and convolution implementations, resulting in faster compute. However, the activations are still read from and written to memory in floating point format.
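The idea can be sketched in plain Python. This is a toy illustration of the mechanism, not PyTorch's implementation: `ToyDynamicLinear` and its helpers are hypothetical names, and symmetric absmax scaling is an assumed choice.

```python
def symmetric_scale(values, qmax=127):
    """Scale that maps the largest magnitude in `values` onto qmax."""
    max_abs = max(abs(v) for v in values)
    return max_abs / qmax if max_abs > 0 else 1.0


def quantize(values, scale, qmax=127):
    """Round floats to integer codes in [-qmax, qmax]."""
    return [max(-qmax, min(qmax, round(v / scale))) for v in values]


class ToyDynamicLinear:
    """Bias-free linear layer with weights quantized once, ahead of time."""

    def __init__(self, weight_rows):
        flat = [w for row in weight_rows for w in row]
        self.w_scale = symmetric_scale(flat)
        self.w_q = [quantize(row, self.w_scale) for row in weight_rows]

    def __call__(self, x):
        # Activations are quantized on the fly, using this input's own range.
        x_scale = symmetric_scale(x)
        x_q = quantize(x, x_scale)
        out = []
        for row in self.w_q:
            acc = sum(wq * xq for wq, xq in zip(row, x_q))  # integer accumulate
            out.append(acc * self.w_scale * x_scale)        # float output
        return out


layer = ToyDynamicLinear([[0.5, -1.0], [0.25, 0.75]])
out = layer([2.0, -1.0])  # the float reference result is [2.0, -0.25]
print(out)
```

The output comes back as floats, which mirrors the point above: the arithmetic is done on integers, but activations cross layer boundaries in floating point.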

#### PyTorch API

PyTorch has a simple API for dynamic quantization: `torch.quantization.quantize_dynamic` takes in a model, along with a couple of other arguments, and produces a quantized model.

[PyTorch documentation contains an end-to-end tutorial that illustrates how to do it for a BERT model.](https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html)

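Nevertheless, the part that quantizes the model is simply:

```python
import torch.quantization

# `model` is assumed to be an existing float torch.nn.Module
quantized_model = torch.quantization.quantize_dynamic(
    model,              # float model to quantize
    {torch.nn.Linear},  # layer types to quantize dynamically
    dtype=torch.qint8,  # dtype of the quantized weights
)
```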

### 2.1.2 - Post-Training Static Quantization

One can further improve the performance (latency) by converting networks to use both integer arithmetic and `int8` memory accesses.

Static quantization performs the additional step of first feeding batches of data through the network and computing the resulting distributions of the different activations (specifically, this is done by inserting "observer" modules at different points that record these distributions). **This information is used to determine how specifically the different activations should be quantized at inference time.**

A simple technique for static quantization would be to divide the entire range of activations into 256 levels, but there are more sophisticated methods as well.
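Written out, this min/max scheme for an unsigned 8-bit target can be sketched as follows. Here $x_{\min}$ and $x_{\max}$ are the observed activation bounds, $s$ is the scale, $z$ is the zero point, and $\hat{x}$ is the dequantized value; these symbols are introduced for illustration and are not notation from the PyTorch documentation.

$$
s = \frac{x_{\max} - x_{\min}}{255}, \qquad
z = \operatorname{round}\!\left(-\frac{x_{\min}}{s}\right), \qquad
q = \operatorname{clamp}\!\big(\operatorname{round}(x / s) + z,\ 0,\ 255\big), \qquad
\hat{x} = s\,(q - z)
$$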

Importantly, this additional step allows us to pass quantized values between operations instead of converting them to floats (and then back to ints) between every operation, resulting in a significant speed-up.

**The Process (Simplified)**

* **Calibration:** Feed a calibration dataset through the model. Analyze how values are distributed.

* **Determining Scaling Factors:** Calculate how to transform the range of floating-point values into the narrower range of integers.

* **Quantization:** Convert weights and activations into integers using the scaling factors.
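The three steps above can be sketched in plain Python. `ToyMinMaxObserver` is a hypothetical stand-in for a real observer module, and the calibration batches are made-up numbers.

```python
class ToyMinMaxObserver:
    """Tracks the running min/max of the values it sees (illustrative only)."""

    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch):
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def qparams(self, qmin=0, qmax=255):
        # Extend the range to include 0.0 so that zero is exactly representable.
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = (hi - lo) / (qmax - qmin) or 1.0
        zero_point = max(qmin, min(qmax, qmin - round(lo / scale)))
        return scale, zero_point


def quantize(values, scale, zero_point, qmin=0, qmax=255):
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]


# 1. Calibration: feed representative data through the observer.
observer = ToyMinMaxObserver()
for batch in [[-1.0, 0.5], [0.0, 3.0]]:  # made-up calibration data
    observer.observe(batch)

# 2. Determine the scaling factors from the observed range.
scale, zero_point = observer.qparams()

# 3. Quantize using the fixed scale and zero point.
q = quantize([-1.0, 0.0, 3.0], scale, zero_point)
print(scale, zero_point, q)
```

Because the scale and zero point are fixed after calibration, nothing has to be recomputed at inference time. That is the difference from the dynamic variant.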

#### PyTorch API

PyTorch offers several features that allow users to optimize their static quantization:

1. **Observers** (`torch.quantization.prepare`): we can customize observer modules, which specify how statistics are collected prior to quantization, to try out more advanced ways of quantizing the data.

2. **Operator fusion** (`torch.quantization.fuse_modules`): we can fuse multiple operations into a single operation, saving on memory access while also improving the operation's numerical accuracy.

3. **Per-channel quantization**: we can independently quantize weights for each output channel in a convolution/linear layer, which can lead to higher accuracy with almost the same speed.
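To illustrate why per-channel quantization can help accuracy, the sketch below compares the round-trip error of one shared scale against one scale per output channel. The weight values are made up, and symmetric absmax scaling is an assumed choice.

```python
def max_abs_error(row, scale):
    """Largest round-trip error when quantizing `row` symmetrically with `scale`."""
    return max(abs(w - round(w / scale) * scale) for w in row)


QMAX = 127
weights = [
    [0.010, -0.020, 0.015],  # output channel 0: small magnitudes
    [1.000, -2.000, 1.500],  # output channel 1: large magnitudes
]

# Per-tensor: one scale for the whole matrix, dominated by the largest weight.
per_tensor_scale = max(abs(w) for row in weights for w in row) / QMAX

# Per-channel: one scale per output channel (row).
per_channel_scales = [max(abs(w) for w in row) / QMAX for row in weights]

per_tensor_err = max_abs_error(weights[0], per_tensor_scale)
per_channel_err = max_abs_error(weights[0], per_channel_scales[0])
print(per_tensor_err, per_channel_err)
```

With a shared scale, the small-magnitude channel is squeezed into one or two integer levels. With its own scale it uses the full int8 range.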

[This tutorial shows how to do post-training static quantization, as well as illustrating two more advanced techniques - per-channel quantization and quantization-aware training - to further improve the model’s accuracy.](https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html)

Finally, quantization itself is done using `torch.quantization.convert`.

```python
import torch

# `myModel` is assumed to be an existing float model in eval mode.
# Set the quantization config for server (x86) deployment:
# 'fbgemm' for server, 'qnnpack' for mobile.
myModel.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# Insert observers, then run calibration data through the model
# so that the observers can collect statistics.
torch.quantization.prepare(myModel, inplace=True)

# Convert to the quantized version.
torch.quantization.convert(myModel, inplace=True)
```

### 2.1.3 - Quantization Aware Training

**Quantization-aware training (QAT)** is the third method, and the one that typically results in the highest accuracy of the three.

With QAT, all weights and activations are "fake quantized" during both the forward and backward passes of training: that is, float values are rounded to mimic `int8` values, but all computations are still done with floating point numbers. Thus, **all the weight adjustments during training are made while "aware" of the fact that the model will ultimately be quantized**; after quantizing, therefore, this method usually yields higher accuracy than the other two methods.
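A minimal sketch of what "fake quantized" means: the value is rounded onto the `int8` grid but stays a float. The gradient rule shown is the straight-through estimator, a common choice for training through the rounding step and an assumption here, not something stated above. The function names and the scale are hypothetical.

```python
def fake_quantize(x, scale, zero_point=0, qmin=-128, qmax=127):
    """Round x onto the int8 grid, then return it as a float again."""
    q = max(qmin, min(qmax, round(x / scale) + zero_point))
    return (q - zero_point) * scale


def fake_quantize_grad(x, upstream_grad, scale, zero_point=0, qmin=-128, qmax=127):
    """Straight-through estimator: pass the gradient unless x was clamped."""
    q = round(x / scale) + zero_point
    return upstream_grad if qmin <= q <= qmax else 0.0


SCALE = 0.1  # assumed scale, for illustration only

y = fake_quantize(0.337, SCALE)          # lands on the nearest grid point, about 0.3
y_clamped = fake_quantize(100.0, SCALE)  # clamped to the top of the range, about 12.7
print(y, y_clamped)
```

Because the forward pass already sees the rounded values, the loss reflects the quantization error and the optimizer can compensate for it.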

[The static quantization tutorial linked in the previous section also covers quantization-aware training.](https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html)

#### PyTorch API

* `torch.quantization.prepare_qat` inserts fake quantization modules that simulate quantization during training.

* Mimicking the static quantization API, `torch.quantization.convert` actually quantizes the model once training is complete.

```python
import torch

# `qat_model` is assumed to be an existing float model in train mode.
# Specify the quantization config for QAT:
# 'fbgemm' for server, 'qnnpack' for mobile.
qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# Prepare for QAT (inserts fake quantization modules).
torch.quantization.prepare_qat(qat_model, inplace=True)

# ... run the usual training loop here ...

# Convert to the quantized version, removing dropout,
# to check the accuracy on each epoch.
quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)
```

## 2.2 - Device and Operator Support in PyTorch

Quantization support is restricted to a subset of the available operators, depending on the method being used. For a list of supported operators, please see the documentation at https://pytorch.org/docs/stable/quantization.html

The set of available operators and the quantization numerics also depend on the backend being used to run quantized models. Currently, quantized operators are supported only for CPU inference in the following backends:

* x86 (server).
* ARM (mobile).

Both the quantization configuration (how tensors should be quantized) and the quantized kernels (arithmetic with quantized tensors) are backend dependent. One can specify the backend by doing:

```python
import torch

# 'fbgemm' for server, 'qnnpack' for mobile
backend = 'fbgemm'

# `my_model` is assumed to be an existing float model
my_model.qconfig = torch.quantization.get_default_qconfig(backend)

# prepare and convert model
# Set the backend on which the quantized kernels need to be run
torch.backends.quantized.engine = backend
```


However, quantization aware training occurs in full floating point and can run on either GPU or CPU. Quantization aware training is typically only used in CNN models when post training static or dynamic quantization doesn’t yield sufficient accuracy. This can occur with models that are highly optimized to achieve small size (such as Mobilenet).

## 2.3 - Choosing a Quantization Approach

The choice of which scheme to use depends on multiple factors:

* Model/Target requirements: Some models might be sensitive to quantization, requiring quantization aware training.

* Operator/Backend support: Some backends require fully quantized operators.

Currently, operator coverage is limited and may restrict the choices. The table below provides a guideline.

| Model Type | Preferred Scheme | Why |
|---|---|---|
| LSTM/RNN | Dynamic Quantization | Throughput dominated by compute/memory bandwidth for weights |
| BERT/Transformer | Dynamic Quantization | Throughput dominated by compute/memory bandwidth for weights |
| CNN | Static Quantization | Throughput limited by memory bandwidth for activations |
| CNN | Quantization Aware Training | In the case where accuracy can't be achieved with static quantization |
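The guideline table can also be written as a small lookup helper. This is a sketch that encodes only the rows above: `recommend_scheme` and its `static_accuracy_ok` flag are hypothetical names, and it does not account for operator or backend support.

```python
def recommend_scheme(model_type, static_accuracy_ok=True):
    """Return the preferred quantization scheme for a model family.

    `static_accuracy_ok` says whether static quantization reached the
    accuracy target. It only matters for CNNs.
    """
    kind = model_type.strip().lower()
    if kind in {"lstm", "rnn", "bert", "transformer"}:
        return "dynamic quantization"
    if kind == "cnn":
        if static_accuracy_ok:
            return "static quantization"
        return "quantization aware training"
    raise ValueError(f"no guideline for model type: {model_type!r}")


print(recommend_scheme("LSTM"))
print(recommend_scheme("CNN"))
print(recommend_scheme("CNN", static_accuracy_ok=False))
```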

### 2.3.1 - Performance Results

Quantization provides a 4x reduction in the model size and a speedup of 2x to 3x compared to floating point implementations depending on the hardware platform and the model being benchmarked. Some sample results are:

| Model | Float Latency (ms) | Quantized Latency (ms) | Inference Performance Gain | Device | Notes |
|---|---|---|---|---|---|
| BERT | 581 | 313 | 1.8x | Xeon-D2191 (1.6GHz) | Batch size = 1, Maximum sequence length = 128, Single thread, x86-64, Dynamic quantization |
| Resnet-50 | 214 | 103 | 2x | Xeon-D2191 (1.6GHz) | Single thread, x86-64, Static quantization |
| Mobilenet-v2 | 97 | 17 | 5.7x | Samsung S9 | Static quantization, Floating point numbers are based on Caffe2 run-time and are not optimized |

### 2.3.2 - Accuracy Results

We also compared the accuracy of static quantized models with the floating point models on Imagenet. For dynamic quantization, we compared the F1 score of BERT on the GLUE benchmark for MRPC.

**Computer Vision Model accuracy**

| Model | Top-1 Accuracy (Float) | Top-1 Accuracy (Quantized) | Quantization Scheme |
|---|---|---|---|
| Googlenet | 69.8 | 69.7 | Static post-training quantization |
| Inception-v3 | 77.5 | 77.1 | Static post-training quantization |
| ResNet-18 | 69.8 | 69.4 | Static post-training quantization |
| ResNet-50 | 76.1 | 75.9 | Static post-training quantization |
| ResNeXt-101 32x8d | 79.3 | 79 | Static post-training quantization |
| Mobilenet-v2 | 71.9 | 71.6 | Quantization Aware Training |
| Shufflenet-v2 | 69.4 | 68.4 | Static post-training quantization |

**Speech and NLP Model accuracy**

| Model | F1 (GLUE MRPC) Float | F1 (GLUE MRPC) Quantized | Quantization Scheme |
|---|---|---|---|
| BERT | 0.902 | 0.895 | Dynamic quantization |

# 3 - Quantization with Lightning Fabric

## 3.1 - Quantization via Bitsandbytes

## 3.2 - 8-bit Quantization

8-bit quantization is discussed in the popular paper [8-bit Optimizers via Block-wise Quantization (Dettmers et al., 2022)](https://arxiv.org/abs/2110.02861), and 8-bit floating point (FP8) formats were introduced in [FP8 Formats for Deep Learning (Micikevicius et al., 2022)](https://arxiv.org/pdf/2209.05433.pdf).

As stated in the FP8 paper, 8-bit precision was the natural progression after 16-bit precision. Even so, the implementation was not as simple as moving from FP32 to FP16, because **those two floating point types share the same representation scheme and 8-bit does not**.

8-bit quantization requires a new representation scheme, and this new scheme allows for fewer numbers to be represented than FP16 or FP32. This means model performance may be affected when using quantization, so it is good to be aware of this trade-off. Additionally, model performance should be evaluated in its quantized form if the weights will be used on an edge device that requires quantization.

Lightning Fabric can use 8-bit quantization by setting the `mode` flag to `int8` for inference.

```python
from lightning.fabric import Fabric
from lightning.fabric.plugins import BitsandbytesPrecision

# available 8-bit quantization modes
# ("int8")

mode = "int8"
plugin = BitsandbytesPrecision(mode=mode)
fabric = Fabric(plugins=plugin)

model = CustomModule() # your PyTorch model
model = fabric.setup_module(model) # quantizes the layers
```

## 3.3 - 4-bit Quantization

4-bit quantization is discussed in the popular paper [QLoRA: Efficient Finetuning of Quantized LLMs (Dettmers et al., 2023)](https://arxiv.org/abs/2305.14314). QLoRA is a finetuning method that uses 4-bit quantization. The paper introduces this finetuning technique and demonstrates how it can be used to "finetune a 65B parameter model on a single 48GB GPU while preserving full 16-bit finetuning task performance" by using the NF4 (normal float) format.
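The intuition behind a "normal float" format is that the 16 available levels are placed where normally distributed weights are dense, rather than evenly spaced. The sketch below builds such a codebook from normal quantiles and applies it block-wise with absmax scaling. It is an illustration in the spirit of NF4 only: the quantile spacing is an assumption, the function names are hypothetical, and the resulting values are not the actual NF4 table used by bitsandbytes.

```python
from statistics import NormalDist


def normal_codebook(bits=4):
    """2**bits levels at evenly spaced normal quantiles, rescaled to [-1, 1]."""
    n = 2 ** bits
    probs = [(i + 0.5) / n for i in range(n)]  # strictly inside (0, 1)
    levels = [NormalDist().inv_cdf(p) for p in probs]
    top = max(abs(v) for v in levels)
    return [v / top for v in levels]


def quantize_block(block, codebook):
    """Store one absmax constant per block plus a 4-bit index per weight."""
    absmax = max(abs(v) for v in block) or 1.0
    indices = [
        min(range(len(codebook)), key=lambda i: abs(codebook[i] - v / absmax))
        for v in block
    ]
    return indices, absmax


def dequantize_block(indices, absmax, codebook):
    return [codebook[i] * absmax for i in indices]


codebook = normal_codebook()
block = [0.1, -0.4, 0.2, 0.05]  # made-up weights
indices, absmax = quantize_block(block, codebook)
recon = dequantize_block(indices, absmax, codebook)
print(indices, recon)
```

The per-block `absmax` value is the "quantization constant" that each block has to store alongside its 4-bit indices.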

Lightning Fabric can use 4-bit quantization by setting the `mode` flag to either `nf4` or `fp4`.

```python
from lightning.fabric import Fabric
from lightning.fabric.plugins import BitsandbytesPrecision

# available 4-bit quantization modes
# ("nf4", "fp4")

mode = "nf4"
plugin = BitsandbytesPrecision(mode=mode)
fabric = Fabric(plugins=plugin)

model = CustomModule() # your PyTorch model
model = fabric.setup_module(model) # quantizes the layers
```

## 3.4 - Double Quantization

Double quantization exists as an extra 4-bit quantization setting introduced alongside NF4 in the QLoRA paper. Double quantization works by quantizing the quantization constants that are internal to bitsandbytes’ procedures.
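To get a feel for why this helps, the sketch below works through the memory overhead of the quantization constants. It uses the settings reported in the QLoRA paper: one 32-bit constant per block of 64 weights, and, with double quantization, 8-bit constants plus one 32-bit second-level constant per 256 blocks. These numbers come from the paper rather than from the bitsandbytes source, and the function names are made up for illustration.

```python
def overhead_bits_plain(block_size=64, const_bits=32):
    """Extra bits per weight when each block stores one full-precision constant."""
    return const_bits / block_size


def overhead_bits_double(block_size=64, quantized_const_bits=8,
                         second_block_size=256, second_const_bits=32):
    """Extra bits per weight when the constants themselves are quantized."""
    first_level = quantized_const_bits / block_size
    second_level = second_const_bits / (block_size * second_block_size)
    return first_level + second_level


plain = overhead_bits_plain()
double = overhead_bits_double()
print(f"plain: {plain:.3f} bits/weight, double: {double:.3f} bits/weight")
```

Under these assumptions the overhead drops from 0.5 to roughly 0.127 bits per weight, which adds up for models with billions of parameters.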

Lightning Fabric can use 4-bit double quantization by setting the `mode` flag to either `nf4-dq` or `fp4-dq`.

```python
from lightning.fabric import Fabric
from lightning.fabric.plugins import BitsandbytesPrecision

# available 4-bit double quantization modes
# ("nf4-dq", "fp4-dq")

mode = "nf4-dq"
plugin = BitsandbytesPrecision(mode=mode)
fabric = Fabric(plugins=plugin)

model = CustomModule() # your PyTorch model
model = fabric.setup_module(model) # quantizes the layers
```