# Quantization in Depth
The goal of this notebook is to provide an in-depth understanding of quantization: its theory, use cases, and implementation. The concepts and instructions in this notebook come from the [DeepLearning.AI](https://www.deeplearning.ai) course [Quantization in Depth](https://learn.deeplearning.ai/courses/quantization-in-depth/lesson/1/introduction). I highly recommend watching and completing the course on your own time. This notebook is meant as an all-in-one companion that also includes my own insights from taking the course, for anyone who may be interested.

***
## Outline:
### Linear Quantization
This notebook aims to build an understanding of linear quantization. We dive into its internals and implement its variants from scratch (per-tensor, per-channel, and per-group quantization). That lets us study the advantages and drawbacks of each method and their impact on some example tensors.
<br><br>
### BYO 8-Bit Quantizer
Build our own quantizer that can quantize any model to 8-bit precision using one of the quantization schemes presented earlier. The quantizer is agnostic to modality: it can be applied to any model that contains linear layers. Technically, it can therefore quantize a vision, text, audio, or even a multimodal model.

### Quantization Packages
Learn about the challenges that arise in extreme (very low-bit) quantization, such as weight packing, and in quantizing LLMs.

# <font color=orange>Linear Quantization I-A: Quantize and De-quantize a Tensor</font>
In this lesson, you will learn the fundamentals of linear quantization.

In [None]:
import torch
from helper import plot_quantization_errors, plot_results

## Quantization with Random Scale and Zero Point
* Implement linear quantization when the "scale" and "zero point" are already known (here, chosen arbitrarily). <br>

***Linear Quantization Formula:***
**<font color=orange>r</font>** = original value (input, high precision), **<font color=purple>s</font>** = scale (stored in high precision)
**<font color=red>q</font>** = quantized value (output, low precision), **<font color=olive>z</font>** = zero point
De-quantization: $ r = s(q - z) $; quantization (the inverse, clamped to $[q_{min}, q_{max}]$): $ q = int(round(r/s + z)) $
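As a quick sanity check of the formula, the pure-Python sketch below quantizes and de-quantizes a few single values. It uses the same arbitrary scale (3.5) and zero point (-70) as the cells below. The helper names `quantize_value` and `dequantize_value` are hypothetical and exist only for this illustration. The int8 range [-128, 127] is hard-coded.

```python
# Minimal sketch of the linear quantization formula for scalar values.
# quantize_value / dequantize_value are illustrative helpers, not from the course.
Q_MIN, Q_MAX = -128, 127  # int8 range

def quantize_value(r, s, z):
    """q = clamp(round(r / s + z), Q_MIN, Q_MAX)"""
    q = round(r / s + z)  # Python's round() rounds half to even, like torch.round
    return max(Q_MIN, min(Q_MAX, q))

def dequantize_value(q, s, z):
    """r is approximately s * (q - z)"""
    return s * (q - z)

scale, zero_point = 3.5, -70
for r in [191.6, -184.0, 728.6]:
    q = quantize_value(r, scale, zero_point)
    print(r, "->", q, "->", dequantize_value(q, scale, zero_point))
```

The last value shows why these arbitrary parameters are lossy. 728.6 maps past $q_{max}$ and is clamped to 127, so it de-quantizes to only 689.5.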

In [None]:
def linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype = torch.int8):

    scaled_and_shifted_tensor = tensor / scale + zero_point

    rounded_tensor = torch.round(scaled_and_shifted_tensor)

    q_min = torch.iinfo(dtype).min
    q_max = torch.iinfo(dtype).max

    q_tensor = rounded_tensor.clamp(q_min,q_max).to(dtype)
    
    return q_tensor

In [None]:
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

In [None]:
### these are random values for "scale" and "zero_point"
### to test the implementation
scale = 3.5
zero_point = -70

In [None]:
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor, scale, zero_point)
quantized_tensor

In [None]:
# What happens if we do not cast the quantized tensor to float? (int8 arithmetic can overflow and wrap around)
dequantized_tensor = scale * (quantized_tensor - zero_point)
dequantized_tensor

In [None]:
# CORRECT implementation: casting quantized tensor to float
def linear_dequantization(quantized_tensor, scale, zero_point):
    return scale * (quantized_tensor.float() - zero_point)

dequantized_tensor = linear_dequantization(quantized_tensor, scale, zero_point)
dequantized_tensor

In [None]:
plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)

In [None]:
# Quantization error: compute an "overall" error as the mean squared error (MSE).
(dequantized_tensor - test_tensor).square().mean()

In [None]:
def quantization_mse(dequantized_tensor, tensor):
    # Print and return the MSE between the de-quantized and original tensors
    mse = (dequantized_tensor - tensor).square().mean()
    print(f"Quantization Mean Squared Error: {mse}")
    return mse

In [None]:
q_error = quantization_mse(dequantized_tensor, test_tensor)
q_error

***Quantization Error: <font color=red>170.8753</font>***
This error is quite high, mainly because the scale and zero point were chosen arbitrarily. The next sections show how to derive them from the tensor itself.
<br>
#### Advantages of Quantization
* Smaller model size
* Speed increase:
  * Lower memory bandwidth (fewer bytes moved per weight)
  * Faster operations:
      * GEMM: General Matrix Multiply (matrix-matrix multiplication)
      * GEMV: General Matrix-Vector multiply (matrix-vector multiplication)

# <font color=orange>Linear Quantization I-B: Get the Scale and Zero Point</font>
In this lesson, you will continue learning the fundamentals of linear quantization and implement your own linear quantizer.

#### ***Scale and Zero-Point***
If we look at the extreme values, we get:
$ r_{min} = s(q_{min}- z) $ 
$ r_{max} = s(q_{max}- z) $ 

Subtracting the first equation from the second gives the scale $s$:
$ s = (r_{max} - r_{min}) / (q_{max} - q_{min})$

Solving the first equation for the zero point gives $z = q_{min} - r_{min}/s$. We round this value because $z$ must be an integer:
$ z = int(round(q_{min} - (r_{min}/s))) $
*The goal is to represent 0 in the original 'r' range exactly with an integer in the quantized 'q' range.*

Therefore, for our previous example:
$ s = (728.6 - (-184)) / (127 - (-128)) = 912.6/255 \approx 3.58 $
$ z = int(round((-128) - (-184)/3.58)) = int(round((-128) - (-51.4))) = int(round(-76.6)) = -77 $

What do you do if the zero point is out of range?
* Case 1: $z < q_{min}$ >> set $z = q_{min}$
* Case 2: $z > q_{max}$ >> set $z = q_{max}$

*Clamping keeps $z$ representable in the quantized data type, which prevents overflow and underflow.*
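The derivation above can be checked numerically. The pure-Python sketch below computes $s$ and $z$ from a tensor's minimum and maximum and clamps $z$ into the int8 range. It reproduces the hand-computed values for the example tensor. `asymmetric_params` is a hypothetical helper name used only here.

```python
# Sketch: derive scale and zero point from the extreme values (asymmetric mode).
# asymmetric_params is a hypothetical helper name used only for this example.
def asymmetric_params(r_min, r_max, q_min=-128, q_max=127):
    s = (r_max - r_min) / (q_max - q_min)   # s = (r_max - r_min) / (q_max - q_min)
    z = round(q_min - r_min / s)            # z = round(q_min - r_min / s)
    z = max(q_min, min(q_max, z))           # clamp z so it is representable in int8
    return s, z

values = [191.6, -13.5, 728.6, 92.14, 295.5, -184.0, 0.0, 684.6, 245.5]
s, z = asymmetric_params(min(values), max(values))
print(f"s = {s:.4f}, z = {z}")

# A range that excludes zero pushes z below q_min, so case 1 applies
print(asymmetric_params(10.0, 20.0))
```

In the second call, the unclamped zero point would be -383. Clamping brings it back to -128.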

In [None]:
q_min = torch.iinfo(torch.int8).min
q_max = torch.iinfo(torch.int8).max
print(f"q_min: {q_min}, q_max: {q_max}")

In [None]:
r_min = test_tensor.min().item()
r_max = test_tensor.max().item()
print(f"r_min: {r_min}, r_max: {r_max}")

In [None]:
scale = (r_max - r_min) / (q_max - q_min)
print(f"scale: {scale}")

In [None]:
zero_point = int(round(q_min - (r_min/scale)))
print(f"zero_point: {zero_point}")

In [None]:
def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    r_min, r_max = tensor.min().item(), tensor.max().item()
    scale = (r_max - r_min) / (q_max - q_min)
    zero_point = (q_min - (r_min/scale))
    if zero_point < q_min:
        zero_point = q_min
    elif zero_point > q_max:
        zero_point = q_max
    else:
        zero_point = int(round(zero_point))
        
    return scale, zero_point

In [None]:
new_scale, new_zero_point = get_q_scale_and_zero_point(test_tensor)
print(f"new_scale: {new_scale}, new_zero_point: {new_zero_point}")

In [None]:
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor, new_scale, new_zero_point)
dequantized_tensor = linear_dequantization(quantized_tensor, new_scale, new_zero_point)
plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)

In [None]:
(dequantized_tensor - test_tensor).square().mean()

In [None]:
def linear_quantization(tensor, dtype=torch.int8):
    scale, zero_point = get_q_scale_and_zero_point(tensor, dtype=dtype)
    quantized_tensor = linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype=dtype)
    return quantized_tensor, scale, zero_point

In [None]:
r_tensor = torch.randn((4,4))
r_tensor

In [None]:
quantized_tensor, scale, zero_point = linear_quantization(r_tensor)
quantized_tensor

In [None]:
scale

In [None]:
zero_point

In [None]:
dequantized_tensor = linear_dequantization(quantized_tensor, scale, zero_point)

In [None]:
plot_quantization_errors(r_tensor, quantized_tensor, dequantized_tensor)

In [None]:
quantization_mse(dequantized_tensor, r_tensor)

### Custom Attempt at Per-Row ("Batch") Quantization
As it turns out, this class implements per-channel quantization along the rows: each row gets its own scale and zero point (asymmetric mode).

In [None]:
class BatchTensor:
    def __init__(self, tensor, dtype):
        # TENSORS
        self.tensor = tensor
        self.quantized_tensor = None
        self.dequantized_tensor = None
        
        # VARIABLES
        self.dtype = dtype
        self.rmin = self.tensor.min().item()
        self.rmax = self.tensor.max().item()
        self.qmin = torch.iinfo(self.dtype).min
        self.qmax = torch.iinfo(self.dtype).max
        self.scale = []
        self.zero_point = []
        
        # MISCELLANEOUS
        self.rows_size = tensor.size(0)
        self.cols_size = tensor.size(1)
        self.mse_tensor = None
        
    def get_batch_scales_and_zeropoints(self):
        if self.scale or self.zero_point:
            self.clear()
        for i in range(self.rows_size):
            # one (scale, zero point) pair per row
            scale, zero_point = get_q_scale_and_zero_point(self.tensor[i], dtype=self.dtype)
            self.scale.append(scale)
            self.zero_point.append(zero_point)
            
    def linear_batch_quantization(self):
        # store the quantized values in the target integer dtype
        self.quantized_tensor = torch.zeros(self.rows_size, self.cols_size, dtype=self.dtype)
        for i in range(self.rows_size):
            self.quantized_tensor[i] = linear_q_with_scale_and_zero_point(
                self.tensor[i], self.scale[i], self.zero_point[i], dtype=self.dtype)
    
    def linear_batch_dequantization(self):
        self.dequantized_tensor = torch.zeros(self.rows_size, self.cols_size)
        for i in range(self.rows_size):
            self.dequantized_tensor[i] = linear_dequantization(self.quantized_tensor[i], self.scale[i], self.zero_point[i])
            
    def batch_quantization_mse(self):
        # use this instance's tensors, not a global variable
        return (self.dequantized_tensor - self.tensor).square().mean()
            
    def clear(self):
        self.scale.clear()
        self.zero_point.clear()

In [None]:
t = torch.randn((4,4))
bt = BatchTensor(t, torch.int8)
bt.tensor

In [None]:
bt.get_batch_scales_and_zeropoints()
bt.linear_batch_quantization()
bt.linear_batch_dequantization()
plot_quantization_errors(bt.tensor, bt.quantized_tensor, bt.dequantized_tensor)
bt.batch_quantization_mse()

In [None]:
new_scale, new_zero_point = get_q_scale_and_zero_point(bt.tensor)
quantized_tensor = linear_q_with_scale_and_zero_point(bt.tensor, new_scale, new_zero_point)
dequantized_tensor = linear_dequantization(quantized_tensor, new_scale, new_zero_point)
plot_quantization_errors(bt.tensor, quantized_tensor, dequantized_tensor)
quantization_mse(dequantized_tensor, bt.tensor)
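Why does the per-row version usually do better? Each row gets a scale sized to its own range. A row of small values is therefore not forced onto the coarse grid of a row with large values. The pure-Python sketch below compares the two strategies (asymmetric int8) on a made-up 2×3 matrix whose rows have very different magnitudes. The helper names are hypothetical.

```python
# Sketch comparing per-tensor vs per-row asymmetric int8 quantization error.
# All helper names here are hypothetical; the matrix is made up for illustration.
Q_MIN, Q_MAX = -128, 127

def params(values):
    r_min, r_max = min(values), max(values)
    s = (r_max - r_min) / (Q_MAX - Q_MIN)
    z = max(Q_MIN, min(Q_MAX, round(Q_MIN - r_min / s)))
    return s, z

def roundtrip(values, s, z):
    # quantize (round + clamp), then de-quantize back to floats
    return [s * (max(Q_MIN, min(Q_MAX, round(r / s + z))) - z) for r in values]

def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

matrix = [[0.1, -0.2, 0.3],        # small-magnitude row
          [100.0, -50.0, 25.0]]    # large-magnitude row
flat = [v for row in matrix for v in row]

# Per-tensor: one (scale, zero point) shared by every element
s, z = params(flat)
mse_tensor = mse(flat, roundtrip(flat, s, z))

# Per-row: one (scale, zero point) per row
deq_rows = []
for row in matrix:
    s_row, z_row = params(row)
    deq_rows.extend(roundtrip(row, s_row, z_row))
mse_row = mse(flat, deq_rows)

print(f"per-tensor MSE: {mse_tensor:.6f}, per-row MSE: {mse_row:.6f}")
```

With a single shared scale of about 0.59, every value in the first row rounds to 0 or to one quantization step. Its own scale is about 0.002, so the first row is reproduced almost exactly.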

# <font color=orange>Linear Quantization II-A: Symmetric vs. Asymmetric Mode</font>
In this lesson, you will learn a different way of performing linear quantization, Symmetric Mode.

There are **two** modes in linear quantization
* **Asymmetric**: We map [$r_{min}, r_{max}$] to [$q_{min}, q_{max}$] (*What was implemented previously*)
* **Symmetric**: We map [$-r_{max}, r_{max}$] to [$-q_{max}, q_{max}$], where $r_{max} = max(|r_{tensor}|)$
    * We don't need a zero point ($z=0$), because the floating-point range and the quantized range are both symmetric about zero
    * Hence, we can simplify the equations to:
    * $q = int(round(r/s))$
    * $s = r_{max}/q_{max}$
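The pure-Python sketch below plugs the earlier example tensor into these symmetric formulas. $r_{max}$ is the largest absolute value, 728.6, so $s = 728.6/127 \approx 5.737$ and $z$ is fixed at 0. `symmetric_scale` is a hypothetical helper name.

```python
# Sketch of symmetric int8 quantization: z = 0, s = max|r| / q_max.
def symmetric_scale(values, q_max=127):
    return max(abs(v) for v in values) / q_max

values = [191.6, -13.5, 728.6, 92.14, 295.5, -184.0, 0.0, 684.6, 245.5]
s = symmetric_scale(values)
q = [round(v / s) for v in values]   # |v / s| <= 127 by construction, so no clamping
deq = [s * qi for qi in q]           # r is approximately s * q
print(f"s = {s:.4f}")
print(q)
```

No zero point is needed, and 0.0 maps to the code 0 exactly.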

In [None]:
import torch
from helper import plot_quantization_errors

In [None]:
def get_q_scale_symmetric(tensor, dtype=torch.int8):
    r_max = tensor.abs().max().item()
    q_max = torch.iinfo(dtype).max
    return r_max / q_max

In [None]:
test_tensor = torch.rand((4,4))
test_tensor

In [None]:
get_q_scale_symmetric(test_tensor)

In [None]:
def linear_q_symmetric(tensor, dtype=torch.int8):
    scale = get_q_scale_symmetric(tensor, dtype=dtype)
    quantized_tensor = linear_q_with_scale_and_zero_point(tensor, scale=scale, zero_point=0, dtype=dtype)
    return quantized_tensor, scale

In [None]:
quantized_tensor, scale = linear_q_symmetric(test_tensor)

In [None]:
dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)

In [None]:
plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)
quantization_mse(dequantized_tensor, test_tensor)

Trade-Offs:
* **Utilization of the Quantized Range**
    * Asymmetric quantization fully utilizes the quantized range.
    * Symmetric mode wastes part of the quantized range when the float range is skewed toward one side. For example, ReLU outputs are never negative, so they only ever use the non-negative half of the codes.
*  **Simplicity**
    * Symmetric mode is simpler and more straightforward than asymmetric mode.
*  **Memory**
    * Zero points do not need to be stored, which saves memory.
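The utilization trade-off is easy to see on non-negative data such as ReLU outputs. The pure-Python sketch below uses made-up values in [0, 1]. Symmetric mode never produces a negative code, so about half of the int8 grid goes unused. Asymmetric mode spreads the same values over all 256 levels, with a step size roughly half as large.

```python
# Sketch: how much of the int8 range each mode uses on non-negative data.
Q_MIN, Q_MAX = -128, 127
values = [i / 100 for i in range(101)]   # made-up ReLU-like data in [0, 1]

# Symmetric: z = 0, s = max|r| / q_max
s_sym = max(abs(v) for v in values) / Q_MAX
q_sym = [round(v / s_sym) for v in values]

# Asymmetric: map [r_min, r_max] onto [q_min, q_max]
r_min, r_max = min(values), max(values)
s_asym = (r_max - r_min) / (Q_MAX - Q_MIN)
z_asym = max(Q_MIN, min(Q_MAX, round(Q_MIN - r_min / s_asym)))
q_asym = [max(Q_MIN, min(Q_MAX, round(v / s_asym + z_asym))) for v in values]

print("symmetric codes span ", min(q_sym), "to", max(q_sym))
print("asymmetric codes span", min(q_asym), "to", max(q_asym))
print(f"step size: symmetric {s_sym:.5f}, asymmetric {s_asym:.5f}")
```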

# <font color=orange>Linear Quantization II-B: Finer Granularity for more Precision</font>
In this lesson, you will learn about different granularities of performing linear quantization.

### Per Tensor Quantization:

In [None]:
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

quantized_tensor, scale = linear_q_symmetric(test_tensor)
dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)
plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)
quantization_mse(dequantized_tensor, test_tensor)

### Per Channel Quantization:

In [None]:
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

In [None]:
dim=0 # dim=0: one scale per row; dim=1: one scale per column
output_dim = test_tensor.shape[dim]
output_dim

In [None]:
scale = torch.zeros(output_dim)
scale

In [None]:
for i in range(output_dim):
    sub_tensor = test_tensor.select(dim, i)
    scale[i] = get_q_scale_symmetric(sub_tensor)
scale

In [None]:
scale_shape = [1] * test_tensor.dim()
scale_shape

In [None]:
scale_shape[dim]=-1
scale_shape

In [None]:
scale = scale.view(scale_shape)
scale

In [None]:
scale.shape

In [None]:
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor, scale=scale, zero_point=0)
quantized_tensor

In [None]:
def linear_q_symmetric_per_channel(r_tensor, dim, dtype=torch.int8):
    
    output_dim = r_tensor.shape[dim]
    # store the scales
    scale = torch.zeros(output_dim)

    for index in range(output_dim):
        sub_tensor = r_tensor.select(dim, index)
        scale[index] = get_q_scale_symmetric(sub_tensor, dtype=dtype)

    # reshape the scale
    scale_shape = [1] * r_tensor.dim()
    scale_shape[dim] = -1
    scale = scale.view(scale_shape)
    quantized_tensor = linear_q_with_scale_and_zero_point(
        r_tensor, scale=scale, zero_point=0, dtype=dtype)
   
    return quantized_tensor, scale
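To see what `dim` changes, the pure-Python sketch below computes the symmetric per-channel scales for the example tensor by hand. With `dim=0`, each row gets its own scale (max |r| in that row divided by 127). With `dim=1`, each column does. `per_channel_scales` is a hypothetical helper name.

```python
# Sketch: symmetric per-channel scales along rows (dim=0) vs columns (dim=1).
matrix = [[191.6, -13.5, 728.6],
          [92.14, 295.5, -184.0],
          [0.0,   684.6, 245.5]]

def per_channel_scales(m, dim, q_max=127):
    # dim=0: one scale per row; dim=1: one scale per column
    channels = m if dim == 0 else [list(col) for col in zip(*m)]
    return [max(abs(v) for v in ch) / q_max for ch in channels]

row_scales = per_channel_scales(matrix, dim=0)
col_scales = per_channel_scales(matrix, dim=1)
print("row maxima:   ", [round(s * 127, 2) for s in row_scales])
print("column maxima:", [round(s * 127, 2) for s in col_scales])
```

For example, the second row's scale (295.5/127 ≈ 2.33) is less than half of the per-tensor scale (728.6/127 ≈ 5.74). That row is therefore represented on a noticeably finer grid.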

In [None]:
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

In [None]:
### along the rows (dim = 0)
quantized_tensor_0, scale_0 = linear_q_symmetric_per_channel(test_tensor, dim=0)

### along the columns (dim = 1)
quantized_tensor_1, scale_1 = linear_q_symmetric_per_channel(test_tensor, dim=1)

In [None]:
dequantized_tensor_0 = linear_dequantization(quantized_tensor_0, scale_0, 0)
plot_quantization_errors(test_tensor, quantized_tensor_0, dequantized_tensor_0)
quantization_mse(dequantized_tensor_0, test_tensor)

In [None]:
dequantized_tensor_1 = linear_dequantization(quantized_tensor_1, scale_1, 0)
plot_quantization_errors(test_tensor, quantized_tensor_1, dequantized_tensor_1, n_bits=8)
quantization_mse(dequantized_tensor_1, test_tensor)

### Per Group Quantization:

In [None]:
def linear_q_symmetric_per_group(tensor, group_size, dtype=torch.int8):
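    t_shape = tensor.shape
    assert tensor.dim() == 2                # check the rank first
    assert t_shape[1] % group_size == 0     # each row must split evenly into groups
    # each row of the reshaped tensor is one group, quantized with its own scale
    tensor = tensor.view(-1, group_size)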
    quantized_tensor, scale = linear_q_symmetric_per_channel(tensor, dim=0, dtype=dtype)
    quantized_tensor = quantized_tensor.view(t_shape)
    return quantized_tensor, scale

In [None]:
def linear_dequantization_per_group(quantized_tensor, scale, group_size):
    q_shape = quantized_tensor.shape
    quantized_tensor = quantized_tensor.view(-1, group_size)
    dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)
    dequantized_tensor = dequantized_tensor.view(q_shape)
    return dequantized_tensor

In [None]:
test_tensor = torch.rand((6,6))

In [None]:
group_size = 6  # equals the row length here, so this matches per-channel (per-row) quantization
quantized_tensor, scale = linear_q_symmetric_per_group(test_tensor, group_size=group_size)
dequantized_tensor = linear_dequantization_per_group(quantized_tensor, scale, group_size=group_size)
plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)
quantization_mse(dequantized_tensor, test_tensor)
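Per-group quantization splits each row into contiguous groups of `group_size` elements and gives each group its own scale. The cost is one extra scale per group. The pure-Python sketch below shows the grouping on made-up data. As a rough storage estimate, it also computes the extra bits per weight contributed by the scales. The helper names, the 16-bit scale size, and the group size of 32 are assumptions chosen only for illustration.

```python
# Sketch: symmetric per-group int8 quantization of a 2D list (groups along each row).
# Assumes no group is all zeros (its scale would be 0).
def per_group_quantize(matrix, group_size, q_max=127):
    assert all(len(row) % group_size == 0 for row in matrix)
    q_rows, scales = [], []
    for row in matrix:
        q_row = []
        for start in range(0, len(row), group_size):
            group = row[start:start + group_size]
            s = max(abs(v) for v in group) / q_max   # one scale per group
            scales.append(s)
            q_row.extend(round(v / s) for v in group)
        q_rows.append(q_row)
    return q_rows, scales

matrix = [[0.01, -0.02, 0.03, 5.0, -4.0, 3.0],
          [1.0, 2.0, -3.0, 0.5, 0.25, -0.125]]
q, scales = per_group_quantize(matrix, group_size=3)
print(len(scales), "scales:", q[0])   # 2 rows x (6 / 3) groups

# Storage overhead of the scales, assuming each scale takes 16 bits
def overhead_bits_per_weight(group_size, scale_bits=16):
    return scale_bits / group_size

print(overhead_bits_per_weight(32), "extra bits per weight")
```

With `group_size=6` (one scale per row), the first row's scale would be 5.0/127 ≈ 0.039, and 0.01 would round to 0.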

# <font color=orange>Linear Quantization II-C: Quantizing Weights & Activations for Inference</font>

In a neural network, we can quantize the **weights** *and* the **activations**. Depending on what is quantized, **storage** and **computation** differ:
* **Storage** = quantized weights + floating-point activations (e.g. W8A32) >> **Computation** = floating-point arithmetic (FP32, FP16, BF16...), since the weights are de-quantized before use
* **Storage** = quantized weights + quantized activations (e.g. W8A8) >> **Computation** = integer-based arithmetic (INT8, INT4...)
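With W8A8, both operands are integers. The dot product can therefore be accumulated in integer arithmetic and rescaled once at the end. The pure-Python sketch below assumes symmetric per-tensor quantization ($z = 0$) for both the activation and one weight row, using made-up values. Python ints stand in for the wider integer accumulator (typically int32) used by real kernels.

```python
# Sketch of a W8A8 dot product: integer multiply-accumulate, then one float rescale.
# Assumes symmetric quantization (z = 0) for both activations and weights.
def sym_quantize(values, q_max=127):
    s = max(abs(v) for v in values) / q_max
    return [round(v / s) for v in values], s

x = [1.0, 2.0, 3.0]          # activations
w = [-2.0, -1.13, 0.42]      # one weight row

q_x, s_x = sym_quantize(x)
q_w, s_w = sym_quantize(w)

acc = sum(a * b for a, b in zip(q_x, q_w))   # integer accumulation
y_int = acc * s_x * s_w                      # (s_x*q_x)·(s_w*q_w) = s_x*s_w*(q_x·q_w)
y_fp = sum(a * b for a, b in zip(x, w))      # float reference

print(acc, y_int, y_fp)
```

The integer result lands within a few hundredths of the float reference (about -2.985 vs -3.0).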

In [None]:
def quantized_linear_W8A32_without_bias(input, q_w, s_w, z_w):
    assert input.dtype == torch.float32
    assert q_w.dtype == torch.int8

    dequantized_weight = s_w * (q_w.to(torch.float32) - z_w)  # r = s(q - z)
    output = torch.nn.functional.linear(input, dequantized_weight)
    
    return output

In [None]:
input = torch.tensor([1,2,3], dtype=torch.float32)
weight = torch.tensor([[-2,   -1.13, 0.42],
                       [-1.51, 0.25, 1.62],
                       [0.23,  1.35, 2.15]])

In [None]:
q_w, s_w  = linear_q_symmetric(weight)
q_w

In [None]:
s_w

In [None]:
output = quantized_linear_W8A32_without_bias(input, q_w, s_w, 0)
print(f"This is the W8A32 output: {output}")

In [None]:
fp32_output = torch.nn.functional.linear(input, weight)
print(f"This is the output if we don't quantize: {fp32_output}")

# <font color=orange>Building your own Quantizer: Custom Build an 8-Bit Quantizer</font>
In this lesson, you will learn how to compress any model to 8-bit precision using the tools built so far. This quantizer is model-agnostic.
* Create a `W8A16LinearLayer` class to store 8-bit weights and scales
* Replace all `torch.nn.Linear` layers with `W8A16LinearLayer`
* Build the quantizer and quantize a model end-to-end
* Test the naive absmax quantization in several scenarios and study its impact

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### 1.1 - `w8_a16_forward` Function
- `W8A16LinearLayer` stores 8-bit weights and computes with 16-bit activations.
- `w8_a16_forward(weight, input, scales, bias=None)` takes the 8-bit `weight`, the 16-bit `input`, the `scales`, and an optional `bias`.
- Cast the 8-bit weights to the same data type as the input ("casted weights"). The values stay in the same range as before, [-128, 127].
- Next, compute:

$ ((\text{input} \cdot \text{casted weights}^T) \times \text{scales}) + \text{bias} $
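Applying `scales` after `F.linear` works because there is one scale per output row. Each output element is a dot product with a single weight row, so that row's scale can be pulled out of the sum. The pure-Python check below uses made-up numbers and hypothetical variable names.

```python
# Sketch: per-row scales can be applied after the matmul instead of before it.
inputs = [0.5, -1.0, 2.0]                 # one input vector (in_features = 3)
int8_w = [[12, -7, 100], [-128, 64, 3]]   # out_features = 2 rows of int8 codes
scales = [0.01, 0.02]                     # one scale per output row

# Option A: de-quantize the weights first, then multiply
deq_w = [[q * s for q in row] for row, s in zip(int8_w, scales)]
out_a = [sum(x * w for x, w in zip(inputs, row)) for row in deq_w]

# Option B (the order used in w8_a16_forward): multiply by the codes, then scale
out_b = [sum(x * q for x, q in zip(inputs, row)) * s for row, s in zip(int8_w, scales)]

print(out_a, out_b)
```

Option B never materializes a de-quantized weight matrix. It only scales the (much smaller) output.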

In [None]:
random_int8 = torch.randint(-128, 128, (32, 16)).to(torch.int8)  # upper bound is exclusive, so 128 includes 127
random_hs = torch.randn((1, 16), dtype=torch.bfloat16)
scales = torch.randn((1, 32), dtype=torch.bfloat16)
bias = torch.randn((1, 32), dtype=torch.bfloat16)

In [None]:
F.linear(random_hs, random_int8.to(random_hs.dtype))

In [None]:
F.linear(random_hs, random_int8.to(random_hs.dtype)) * scales

In [None]:
(F.linear(random_hs, random_int8.to(random_hs.dtype)) * scales) + bias

In [None]:
def w8_a16_forward(weight, input, scales, bias=None):
    
    casted_weights = weight.to(input.dtype)
    output = F.linear(input, casted_weights) * scales
    
    if bias is not None:
        output = output + bias
      
    return output

In [None]:
print("With bias:\n\n", 
      w8_a16_forward(random_int8, random_hs, scales, bias))

print("\nWithout bias:\n\n", 
      w8_a16_forward(random_int8, random_hs, scales))

### 1.2 - `__init__` Function of class `W8A16LinearLayer`

- This is the `__init__` signature of the [PyTorch Linear layer](https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear):
```python
def __init__(self, in_features, out_features, bias=True,
             device=None, dtype=None): ...
```

In [None]:
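### Running this raises an error: nn.Parameter defaults to requires_grad=True,
### and only floating-point (or complex) tensors can require gradients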
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        super().__init__()
        
        self.int8_weights = nn.Parameter(torch.Tensor([0, 1]
                                     ).to(dtype=torch.int8))

try:
    
    W8A16LinearLayer(1, 1)
    
except Exception as error:
    print("\033[91m", type(error).__name__, ": ", error, "\033[0m")

In [None]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        super().__init__()
        
        
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 128, (out_features, in_features), dtype=torch.int8
            )
        )
        
        self.register_buffer("scales", 
                             torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", 
                                 torch.randn((1, out_features), 
                                             dtype=dtype))
        
        else:
            self.bias = None

In [None]:
dummy_instance = W8A16LinearLayer(16,32)
print(dummy_instance.int8_weights.shape)
print(dummy_instance.scales.shape)

### 1.3 - `forward` Function of class `W8A16LinearLayer`

- Use the `w8_a16_forward` defined earlier (Step 1.1) to define the `forward` function.

In [None]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        super().__init__()
        
        
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 128, (out_features, in_features), dtype=torch.int8
            )
        )
        
        self.register_buffer("scales", 
                             torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", 
                                 torch.randn((1, out_features), 
                                             dtype=dtype))
        
        else:
            self.bias = None

    def forward(self, input):
        return w8_a16_forward(self.int8_weights, 
                              input, self.scales, self.bias)

In [None]:
module = W8A16LinearLayer(16, 32)
dummy_hidden_states = torch.randn(1, 6, 16)

In [None]:
module(dummy_hidden_states).shape

In [None]:
module(dummy_hidden_states).dtype

### 1.4 - `quantize` Function of class `W8A16LinearLayer`

- The `quantize` method quantizes (half-precision) weights to `torch.int8` with per-row symmetric absmax scaling: each output row gets the scale $s = max(|w_{row}|)/127$.

In [None]:
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        super().__init__()
        
        
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 128, (out_features, in_features), dtype=torch.int8
            )
        )
        
        self.register_buffer("scales", 
                             torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", 
                                 torch.randn((1, out_features), 
                                             dtype=dtype))
        
        else:
            self.bias = None

    def quantize(self, weights):
        w_fp32 = weights.clone().to(torch.float32)

        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)

        int8_weights = torch.round(weights
                        /scales.unsqueeze(1)).to(torch.int8)

        self.int8_weights = int8_weights
        self.scales = scales
    
    def forward(self, input):
        return w8_a16_forward(self.int8_weights, 
                              input, self.scales, self.bias)      

In [None]:
module = W8A16LinearLayer(4,8)
print('Weights before:\n' , module.int8_weights)

In [None]:
random_matrix = torch.randn((8, 4), dtype=torch.bfloat16)  # (out_features, in_features) for W8A16LinearLayer(4, 8)
module.quantize(random_matrix)
print("Weights After:\n" , module.int8_weights)

In [None]:
module.scales

In [None]:
module.scales.shape

In [None]:
module.int8_weights.shape

In [None]:
### dequantized weights
module.int8_weights * module.scales.unsqueeze(1)

In [None]:
### original weights
random_matrix

In [None]:
# Quantization error (mean absolute error between original and de-quantized weights)
(random_matrix - module.int8_weights 
 * module.scales.unsqueeze(1)).abs().mean()

# <font color=orange>Building your own Quantizer: Replace PyTorch layers with Quantized Layers</font>


In [None]:
import torch
import torch.nn as nn

### 2.1 - Model In-place Linear Layer Replacement
- Implement `replace_linear_with_target`

In [None]:
def replace_linear_with_target(module, target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and name not in module_name_to_exclude:
            old_bias = child.bias
            new_module = target_class(child.in_features,
                                      child.out_features,
                                      old_bias is not None,
                                      child.weight.dtype)
            # Replace the child module (looked up by name on the parent) with new_module
            setattr(module, name, new_module)

            if old_bias is not None:
                getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target(child, target_class, module_name_to_exclude)

In [None]:
class DummyModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.emb = torch.nn.Embedding(1, 1)
    # Try with bias
    self.linear_1 = nn.Linear(1, 1)
    # Try without bias
    self.linear_2 = nn.Linear(1, 1, bias=False)
    # Lm prediction head
    self.lm_head = nn.Linear(1, 1, bias=False)

In [None]:
model_1 = DummyModel()
model_2 = DummyModel()

In [None]:
replace_linear_with_target(model_1, W8A16LinearLayer, ["lm_head"])
print(model_1)

In [None]:
replace_linear_with_target(model_2, W8A16LinearLayer, [])
print(model_2)

### 2.2 - Linear Layer Replacement + Quantization
- Modify the `replace_linear_with_target` function to also perform quantization.
- Implement `replace_linear_with_target_and_quantize`.

In [None]:
def replace_linear_with_target_and_quantize(module, 
                               target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not \
        any([x == name for x in module_name_to_exclude]):
            old_bias = child.bias
            old_weight = child.weight

            new_module = target_class(child.in_features, 
                                      child.out_features, 
                                      old_bias is not None, 
                                      child.weight.dtype)
            setattr(module, name, new_module)

            getattr(module, name).quantize(old_weight)
            
            if old_bias is not None:
              getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target_and_quantize(child, 
                     target_class, module_name_to_exclude)

In [None]:
model_3 = DummyModel()
replace_linear_with_target_and_quantize(model_3, W8A16LinearLayer, ["lm_head"])
print(model_3)

# <font color=orange>Building your own Quantizer: Quantize any Open Source PyTorch Model</font>
In this lesson, you will look at the results of open source models compressed using the custom quantizer you built.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Step 3: Test the Implementation on Various LLMs
#### 3.1 - [Salesforce/codegen-350M-mono](https://huggingface.co/Salesforce/codegen-350M-mono)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = 'Salesforce/codegen-350M-mono'

model = AutoModelForCausalLM.from_pretrained(model_id,
                                             torch_dtype=torch.bfloat16,
                                             low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe("def hello_world():", max_new_tokens=20, do_sample=False))

In [None]:
print("Model before:\n\n", model)

In [None]:
replace_linear_with_target_and_quantize(model, W8A16LinearLayer, ["lm_head"])
pipe.model

In [None]:
print(pipe("def hello_world():", max_new_tokens=20, do_sample=False)[0]["generated_text"])

### 3.2 - [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50)

In [None]:
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image
import requests

# you can specify the revision tag if you don't want the timm dependency
processor = DetrImageProcessor.from_pretrained(
    "facebook/detr-resnet-50", revision="no_timm")
model = DetrForObjectDetection.from_pretrained(
    "facebook/detr-resnet-50", revision="no_timm")

In [None]:
previous_memory_footprint = model.get_memory_footprint()
previous_memory_footprint

In [None]:
img_path = "dinner_with_friends.png"
image = Image.open(img_path).convert("RGB")
image

In [None]:
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
  outputs = model(**inputs)

# convert outputs (bounding boxes and class logits) to COCO API
# let's only keep detections with score > 0.9
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
plot_results(model, image, results)

In [None]:
model

In [None]:
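# Keep the bounding-box prediction head (its linear layers are named "0", "1", "2")
# and the class_labels_classifier head in full precision
replace_linear_with_target_and_quantize(model, W8A16LinearLayer, ["0", "1", "2", "class_labels_classifier"])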
model

- Visualize results after quantization.

In [None]:
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
  outputs = model(**inputs)

# convert outputs (bounding boxes and class logits) to COCO API
# let's only keep detections with score > 0.9
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
plot_results(model, image, results)

In [None]:
new_footprint = model.get_memory_footprint()
print("Footprint of the model in MBs: ", new_footprint/1e+6)
### Memory saved
print("Memory saved in MBs: ", (previous_memory_footprint - new_footprint)/1e+6)
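The saving above can be sanity-checked roughly. Each replaced `nn.Linear` keeps one int8 code per weight plus one scale per output row, stored in the original weight dtype (fp32 here). Biases are reused unchanged. The pure-Python sketch below estimates the per-layer saving. The layer sizes are made up and the helper name is hypothetical. Layers that were not replaced (convolutions, embeddings, excluded heads) keep their full size. The overall percentage saving is therefore smaller than the roughly 75% achieved on the linear weights themselves.

```python
# Sketch: bytes saved by swapping one fp32 nn.Linear weight for int8 codes + per-row scales.
def linear_weight_bytes_saved(in_features, out_features, weight_bytes=4, scale_bytes=4):
    original = in_features * out_features * weight_bytes          # fp32 weight matrix
    quantized = (in_features * out_features * 1                  # int8 codes
                 + out_features * scale_bytes)                    # one fp32 scale per row
    return original - quantized

# Made-up layer sizes, for illustration only
print(linear_weight_bytes_saved(256, 256))    # 262144 - (65536 + 1024) = 195584
print(linear_weight_bytes_saved(256, 2048) / 1e6, "MB")
```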