# Lesson 4 - Quantization

In this lesson, we'll discuss the concept of "quantization". This technique helps reduce the memory overhead of a model and enables running inference with larger LLMs.

### Import required packages and load the LLM

In [1]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

from utils import generate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)



In [3]:
model.config

GPT2Config {
  "_name_or_path": "gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.35.2",
  "use_cache": true,
  "vocab_size": 50257
}

In [4]:
# Define PAD Token = EOS Token = 50256
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# pad on the left so we can append new tokens on the right
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

### Define a Float 32 type

In [5]:
# fix dtype post quantization to "pretend" to be fp32
def get_float32_dtype(self):
    return torch.float32

In [6]:
GPT2Model.type = property(get_float32_dtype)

Check memory footprint of non-quantized model

In [7]:
model.get_memory_footprint()

510342192

### Define a quantization function

In [8]:
def quantize(t):
    # obtain range of values in the tensor to map between 0 and 255
    min_val, max_val = t.min(), t.max()

    # determine the "zero-point", or value in the tensor to map to 0
    scale = (max_val - min_val)/255
    zero_point = min_val

    # quantize and clamp to ensure we're in [0, 255]
    t_quant = (t - zero_point)/scale
    t_quant = torch.clamp(t_quant, min=0, max=255)

    # keep track of scale and zero_point for reversing quantization
    state = (scale, zero_point)

    # cast to uint8 and return
    t_quant = t_quant.type(torch.uint8)

    return t_quant, state

In [9]:
model.transformer

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [10]:
t = model.transformer.h[0].attn.c_attn.weight.data
print(t, t.shape)

tensor([[-0.4738, -0.2614, -0.0978,  ...,  0.0513, -0.0584,  0.0250],
        [ 0.0874,  0.1473,  0.2387,  ..., -0.0525, -0.0113, -0.0156],
        [ 0.0039,  0.0695,  0.3668,  ...,  0.1143,  0.0363, -0.0318],
        ...,
        [-0.2592, -0.0164,  0.1991,  ...,  0.0095, -0.0516,  0.0319],
        [ 0.1517,  0.2170,  0.1043,  ...,  0.0293, -0.0429, -0.0475],
        [-0.4100, -0.1924, -0.2400,  ..., -0.0046,  0.0070,  0.0198]]) torch.Size([768, 2304])


Let's quantize and check the values

In [11]:
t_q, state = quantize(t)
print(t_q, t_q.min(), t_q.max())

tensor([[107, 116, 124,  ..., 130, 125, 129],
        [132, 135, 139,  ..., 126, 128, 127],
        [128, 131, 145,  ..., 133, 130, 127],
        ...,
        [116, 127, 137,  ..., 129, 126, 130],
        [135, 138, 133,  ..., 129, 126, 126],
        [110, 119, 117,  ..., 128, 128, 129]], dtype=torch.uint8) tensor(0, dtype=torch.uint8) tensor(255, dtype=torch.uint8)


### Define a dequantization function

In [12]:
def dequantize(t, state):
    scale, zero_point = state
    return t.to(torch.float32)*scale + zero_point

In [13]:
t_rev = dequantize(t_q, state=state)
print(t_rev)

tensor([[-0.4774, -0.2783, -0.1014,  ...,  0.0313, -0.0793,  0.0092],
        [ 0.0755,  0.1419,  0.2303,  ..., -0.0572, -0.0129, -0.0351],
        [-0.0129,  0.0534,  0.3630,  ...,  0.0976,  0.0313, -0.0351],
        ...,
        [-0.2783, -0.0351,  0.1861,  ...,  0.0092, -0.0572,  0.0313],
        [ 0.1419,  0.2082,  0.0976,  ...,  0.0092, -0.0572, -0.0572],
        [-0.4110, -0.2120, -0.2562,  ..., -0.0129, -0.0129,  0.0092]])


Let's check the quantization error

In [14]:
err = torch.abs(t_rev - t)
print(err, err.min(), err.max(), err.median())

tensor([[0.0035, 0.0170, 0.0036,  ..., 0.0200, 0.0209, 0.0158],
        [0.0119, 0.0055, 0.0084,  ..., 0.0046, 0.0017, 0.0195],
        [0.0168, 0.0161, 0.0038,  ..., 0.0167, 0.0050, 0.0032],
        ...,
        [0.0191, 0.0187, 0.0131,  ..., 0.0004, 0.0056, 0.0006],
        [0.0098, 0.0088, 0.0067,  ..., 0.0202, 0.0143, 0.0097],
        [0.0010, 0.0196, 0.0162,  ..., 0.0084, 0.0199, 0.0107]]) tensor(0.) tensor(0.0221) tensor(0.0111)


Response generated by original model

In [15]:
response_generated = generate(
    model=model,
    tokenizer=tokenizer,
    requests= [
        ("The quick brown fox jumped over the", 10),
        ("The rain in Spain falls", 10),
        ("What comes up must",10)]
)

response_generated

['The quick brown fox jumped over the fence and ran to the other side of the fence',
 'The rain in Spain falls on the first day of the month, and the',
 'What comes up must be a good idea.\n\n"I think']

### Let's apply the quantization technique to the entire model

In [16]:
def quantize_model(model):
    states = {}
    for name, param in model.named_parameters():
        param.requires_grad = False
        param.data, state = quantize(param)
        states[name] = state
    
    return model, states

In [17]:
quant_model, states = quantize_model(model=model)

Memory Footprint of Quantized model

In [18]:
quant_model.get_memory_footprint()

137022768

Memory Footprint of `states`

In [20]:
def size_in_bytes(t):
    return t.numel() * t.element_size()

In [21]:
sum([size_in_bytes(v[0]) + size_in_bytes(v[1]) for v in states.values()])

1184

Observation: `states` consume low memory compared to the memory consumed by the model parameters

In [22]:
def dequantize_model(model, states):
    for name, param in model.named_parameters():
        param.data = dequantize(param.data, state=states[name])
    
    return model

In [23]:
dequant_model = dequantize_model(model=quant_model, states=states)

In [24]:
dequant_model.get_memory_footprint()

510342192

Observation: Memory footprint of dequantized model is same as the original model

Response generated by the quantized model

In [25]:
response_generated = generate(
    model=dequant_model,
    tokenizer=tokenizer,
    requests= [
        ("The quick brown fox jumped over the", 10),
        ("The rain in Spain falls", 10),
        ("What comes up must",10)]
)

response_generated

['The quick brown fox jumped over the fence.\n\nThe fox jumped over the fence',
 'The rain in Spain falls on Saturday night.\n\nSpainSpainSpainSpain',
 'What comes up must be what is what is what is what is what']