# Quantization

---

In this lesson, we'll discuss the concept of "quantization". This technique helps reduce the memory overhead of a model and enables running inference with larger LLMs.

## Import required packages & load LLM

In [1]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

In [2]:
model_name = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

## Reuse KV-cache text generation function

In [3]:
from scripts.helper import generate

response = generate(model, tokenizer, "The quick brown fox jumped over the", 10)
print(response)

 fence and ran to the other side of the fence


## Add padding tokens to the model to prepare batches of prompts

In [4]:
# Define PAD Token = EOS Token = 50256
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

In [5]:
# pad on the left so we can append new tokens on the right
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

## Define a Float 32 type

In [6]:
# fix dtype post quantization to "pretend" to be fp32
def get_float32_dtype(self):
    return torch.float32
GPT2Model.dtype = property(get_float32_dtype)

In [7]:
model.get_memory_footprint()

510342192

## Floating point representation

<img src="../../../images/floating-point-repr.png" alt="Floating point precision" style="width: 70%; height: auto;"/>

## Define a quantization function

In [8]:
def quantize(t):
    """
    Take the input tensor of floating point values and quantize it to uint8 that lies between 0 and 255.
    """
    # obtain range of values in the tensor to map between 0 and 255
    min_val, max_val = t.min(), t.max()

    # determine the "zero-point", or value in the tensor to map to 0
    scale = (max_val - min_val) / 255
    zero_point = min_val

    # quantize and clamp to ensure we're in [0, 255]
    t_quant = (t - zero_point) / scale
    t_quant = torch.clamp(t_quant, min=0, max=255)

    # keep track of scale and zero_point for reversing quantization
    state = (scale, zero_point)

    # cast to uint8 and return
    t_quant = t_quant.type(torch.uint8)
    return t_quant, state

In [9]:
t = model.transformer.h[0].attn.c_attn.weight.data
print(t, t.shape)

tensor([[-0.4738, -0.2614, -0.0978,  ...,  0.0513, -0.0584,  0.0250],
        [ 0.0874,  0.1473,  0.2387,  ..., -0.0525, -0.0113, -0.0156],
        [ 0.0039,  0.0695,  0.3668,  ...,  0.1143,  0.0363, -0.0318],
        ...,
        [-0.2592, -0.0164,  0.1991,  ...,  0.0095, -0.0516,  0.0319],
        [ 0.1517,  0.2170,  0.1043,  ...,  0.0293, -0.0429, -0.0475],
        [-0.4100, -0.1924, -0.2400,  ..., -0.0046,  0.0070,  0.0198]]) torch.Size([768, 2304])


In [10]:
t_q, state = quantize(t)
print(t_q, t_q.min(), t_q.max())

tensor([[107, 116, 124,  ..., 130, 125, 129],
        [132, 135, 139,  ..., 126, 128, 127],
        [128, 131, 145,  ..., 133, 130, 127],
        ...,
        [116, 127, 137,  ..., 129, 126, 130],
        [135, 138, 133,  ..., 129, 126, 126],
        [110, 119, 117,  ..., 128, 128, 129]], dtype=torch.uint8) tensor(0, dtype=torch.uint8) tensor(255, dtype=torch.uint8)


## Define a dequantization function

In [11]:
def dequantize(t, state):
    """
    Take the quantized tensor and dequantize it back to floating point values.
    """
    scale, zero_point = state
    return t.to(torch.float32) * scale + zero_point

In [12]:
t_rev = dequantize(t_q, state)
print(t_rev)

tensor([[-0.4774, -0.2783, -0.1014,  ...,  0.0313, -0.0793,  0.0092],
        [ 0.0755,  0.1419,  0.2303,  ..., -0.0572, -0.0129, -0.0351],
        [-0.0129,  0.0534,  0.3630,  ...,  0.0976,  0.0313, -0.0351],
        ...,
        [-0.2783, -0.0351,  0.1861,  ...,  0.0092, -0.0572,  0.0313],
        [ 0.1419,  0.2082,  0.0976,  ...,  0.0092, -0.0572, -0.0572],
        [-0.4110, -0.2120, -0.2562,  ..., -0.0129, -0.0129,  0.0092]])


Please note that there will be a loss of precision when quantizing and dequantizing the model as this is a lossy compression. This is a trade-off between memory and precision.

In [13]:
torch.abs(t - t_rev)

tensor([[0.0035, 0.0170, 0.0036,  ..., 0.0200, 0.0209, 0.0158],
        [0.0119, 0.0055, 0.0084,  ..., 0.0046, 0.0017, 0.0195],
        [0.0168, 0.0161, 0.0038,  ..., 0.0167, 0.0050, 0.0032],
        ...,
        [0.0191, 0.0187, 0.0131,  ..., 0.0004, 0.0056, 0.0006],
        [0.0098, 0.0088, 0.0067,  ..., 0.0202, 0.0143, 0.0097],
        [0.0010, 0.0196, 0.0162,  ..., 0.0084, 0.0199, 0.0107]])

In [16]:
from scripts.helper import generate

prompt = "The quick brown fox jumped over the"
max_tokens = 10
response = generate(model, tokenizer, prompt, max_tokens)
print(prompt)
print(response)

The quick brown fox jumped over the
 fence and ran to the other side of the fence


## Let's apply the quantization technique to the entire model

In [17]:
def quantize_model(model):
    states = {}
    for name, param in model.named_parameters():
        param.requires_grad = False
        param.data, state = quantize(param.data)
        states[name] = state
    return model, states

In [18]:
quant_model, states = quantize_model(model)

In [19]:
quant_model.get_memory_footprint()

137022768

Let's also get the memory footprint of state dictionary.

In [20]:
def size_in_bytes(t):
    return t.numel() * t.element_size()

In [21]:
sum([
    size_in_bytes(v[0]) + size_in_bytes(v[1])
    for v in states.values()
])

1184

In [22]:
def dequantize_model(model, states):
    for name, param in model.named_parameters():
        state = states[name]
        param.data = dequantize(param.data, state)
    return model

In [23]:
dequant_model = dequantize_model(quant_model, states)

In [24]:
dequant_model.get_memory_footprint()

510342192

In [25]:
from scripts.helper import generate

prompt = "The quick brown fox jumped over the"
max_tokens = 10
response = generate(dequant_model, tokenizer, prompt, max_tokens)
print(prompt)
print(response)

The quick brown fox jumped over the
 fence.

The fox jumped over the fence


We can see that the output from dequantized model is slightly different from what we originally obtained from the model. This is because of the loss of precision during quantization and dequantization.

Hence, quantizing and dequantizing does have an imapct on overall model output and quality, and different quantization techniques exist to maximize the gains from quantization while minimizing the impact on quality.