In [3]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import copy
import os

# --- UTILS: Memory Footprint Calculation ---
def get_model_size_mb(model, bits=None):
    """Calculates the approximate memory footprint of the model parameters."""
    if bits is None:
        total_bits = 0
        for name, param in model.named_parameters():
            # Handle cases where weights might be stored as different dtypes
            bits_per_param = torch.finfo(param.dtype).bits if param.is_floating_point() else torch.iinfo(param.dtype).bits
            total_bits += param.numel() * bits_per_param
    else:
        total_bits = sum(param.numel() for name, param in model.named_parameters()) * bits
    return total_bits / (8 * 1024 * 1024)

# --- CORE MATH: GPTQ (Hessian-based Correction) ---
def gptq_quantize_weights(W, X, bits=4):
    """
    W: [Out_features, In_features] (Original Weight)
    X: [Batch, Seq, In_features] (Calibration Activations)
    """
    # 1. Flatten and prepare X
    X = X.reshape(-1, X.shape[-1]).t().float() 
    in_features = W.shape[1]
    
    # 2. Compute Inverse Hessian (H = XX^T)
    H = torch.matmul(X, X.t())
    damp = 0.01 * torch.mean(torch.diag(H))
    H_inv = torch.inverse(H + damp * torch.eye(in_features))
    
    W_quant = W.clone().float()
    
    # 3. Step through columns and compensate error
    for i in range(in_features):
        w_col = W_quant[:, i]
        
        # Symmetric Quantization (Signed)
        scale = w_col.abs().max() / (2**(bits-1) - 1)
        w_q = torch.round(w_col / scale).clamp(-2**(bits-1), 2**(bits-1) - 1) * scale
        
        # Update neighbors using Hessian to 'absorb' quantization error
        error = w_col - w_q
        update_factor = H_inv[i, i+1:] / H_inv[i, i]
        W_quant[:, i+1:] -= error.unsqueeze(1) * update_factor.unsqueeze(0)
        W_quant[:, i] = w_q
        
    return W_quant

# --- CORE MATH: AWQ (Activation-aware Scaling) ---
def awq_quantize_weights(W, act_means, bits=4):
    """
    W: [Out_features, In_features]
    act_means: [In_features] (Average magnitude of activations)
    """
    # Salience heuristic: prioritize weights connected to high-magnitude inputs
    # Scaling protects these 'salient' weights from rounding error
    scales = act_means.pow(0.5) / (W.abs().mean(dim=0).pow(0.5) + 1e-8)
    scales = scales / scales.max() 
    
    # Shield -> Quantize -> Unshield
    W_scaled = W * scales.view(1, -1)
    q_scale = W_scaled.abs().max() / (2**(bits-1) - 1)
    W_q = torch.round(W_scaled / q_scale).clamp(-2**(bits-1), 2**(bits-1) - 1) * q_scale
    
    return W_q / scales.view(1, -1)

# --- Quantize Whole Model ---
def quantize_model_gptq(model, tokenizer, bits=4):
    model_q = copy.deepcopy(model)
    activations = {}
    hooks = []
    
    def hook_fn(name):
        def hook(m, i, o): activations[name] = i[0].detach()
        return hook
    
    # Register hooks for all Conv1D layers
    for name, module in model_q.named_modules():
        if 'Conv1D' in str(type(module)):
            hooks.append(module.register_forward_hook(hook_fn(name)))
    
    # Run calibration
    inputs = tokenizer("Quantization is the process of reducing precision to make models smaller and faster.", return_tensors="pt")
    model_q(inputs.input_ids)
    for h in hooks: h.remove()
    
    # Quantize each layer
    for name, module in model_q.named_modules():
        if 'Conv1D' in str(type(module)):
            X = activations[name]
            W_orig = module.weight.data.t()  # [out, in]
            W_q = gptq_quantize_weights(W_orig, X, bits)
            module.weight.data = W_q.t()  # back to [in, out]
    
    return model_q

def quantize_model_awq(model, tokenizer, bits=4):
    model_q = copy.deepcopy(model)
    activations = {}
    hooks = []
    
    def hook_fn(name):
        def hook(m, i, o): activations[name] = i[0].detach()
        return hook
    
    # Register hooks for all Conv1D layers
    for name, module in model_q.named_modules():
        if 'Conv1D' in str(type(module)):
            hooks.append(module.register_forward_hook(hook_fn(name)))
    
    # Run calibration
    inputs = tokenizer("Quantization is the process of reducing precision to make models smaller and faster.", return_tensors="pt")
    model_q(inputs.input_ids)
    for h in hooks: h.remove()
    
    # Quantize each layer
    for name, module in model_q.named_modules():
        if 'Conv1D' in str(type(module)):
            X = activations[name]
            act_means = X.abs().mean(dim=(0, 1))
            W_orig = module.weight.data.t()  # [out, in]
            W_q = awq_quantize_weights(W_orig, act_means, bits)
            module.weight.data = W_q.t()  # back to [in, out]
    
    return model_q

# --- EXECUTION PIPELINE ---
def main():
    print("ðŸš€ Initializing GPT-2...")
    model_id = "gpt2"
    tokenizer = GPT2Tokenizer.from_pretrained(model_id)
    model = GPT2LMHeadModel.from_pretrained(model_id)
    
    orig_size = get_model_size_mb(model)
    print(f"ðŸ“¦ Original Model Footprint: {orig_size:.2f} MB")
    
    # Quantize with GPTQ
    print("\n--- Quantizing with GPTQ ---")
    model_gptq = quantize_model_gptq(model, tokenizer, bits=4)
    gptq_size = get_model_size_mb(model_gptq, bits=4)
    
    # Quantize with AWQ
    print("--- Quantizing with AWQ ---")
    model_awq = quantize_model_awq(model, tokenizer, bits=4)
    awq_size = get_model_size_mb(model_awq, bits=4)
    
    print("\n" + "="*50)
    print("MODEL FOOTPRINTS:")
    print(f"Original (FP32):         {orig_size:7.2f} MB")
    print(f"GPTQ Quantized (INT4):   {gptq_size:7.2f} MB")
    print(f"AWQ Quantized (INT4):    {awq_size:7.2f} MB")
    print("="*50)

if __name__ == "__main__":
    main()

ðŸš€ Initializing GPT-2...


Loading weights: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 148/148 [00:00<00:00, 1548.70it/s, Materializing param=transformer.wte.weight]             


ðŸ“¦ Original Model Footprint: 474.70 MB

--- Quantizing with GPTQ ---
--- Quantizing with AWQ ---

MODEL FOOTPRINTS:
Original (FP32):          474.70 MB
GPTQ Quantized (INT4):     59.34 MB
AWQ Quantized (INT4):      59.34 MB


In [6]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import math
import copy

# --- CORE MATH UTILS ---

def quantize_group(x, bits, group_size=128):
    if bits >= 16: return x.half().float()
    orig_shape = x.shape
    x = x.reshape(-1, group_size)
    q_max = 2**(bits - 1) - 1
    scales = x.abs().max(dim=1, keepdim=True)[0] / q_max
    x_q = torch.round(x / (scales + 1e-8)).clamp(-q_max, q_max) * scales
    return x_q.reshape(orig_shape)

def apply_gptq(W, X, bits, group_size=128):
    """Hessian-based Error Compensation"""
    if bits >= 16: return W.half().float()
    X = X.reshape(-1, X.shape[-1]).t().float()
    in_features = W.shape[1]
    H = torch.matmul(X, X.t())
    damp = 0.1 * torch.mean(torch.diag(H))
    H_inv = torch.inverse(H + damp * torch.eye(in_features))
    W_q = W.clone().float()
    for i in range(in_features):
        w_col = W_q[:, i]
        w_rounded = quantize_group(w_col.unsqueeze(0), bits, len(w_col)).squeeze(0)
        error = w_col - w_rounded
        W_q[:, i+1:] -= error.unsqueeze(1) * (H_inv[i, i+1:] / H_inv[i, i]).unsqueeze(0)
        W_q[:, i] = w_rounded
    return W_q

def apply_awq(W, X, bits, group_size=128):
    """Activation-aware Scaling"""
    if bits >= 16: return W.half().float()
    act_means = X.abs().mean(dim=(0, 1))
    # Alpha search for best scaling
    best_error, best_W = float('inf'), W
    for alpha in [0.5]: # Standard AWQ alpha
        scales = act_means.pow(alpha) / (W.abs().mean(dim=0).pow(1-alpha) + 1e-8)
        scales = scales / scales.max()
        W_scaled = W * scales.view(1, -1)
        W_q = quantize_group(W_scaled, bits, group_size) / scales.view(1, -1)
        error = torch.nn.functional.mse_loss(W_q @ X[0].t(), W @ X[0].t())
        if error < best_error:
            best_error, best_W = error, W_q
    return best_W

# --- MODEL PROCESSING ENGINE ---

def get_ppl(model, tokenizer, text):
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        loss = model(inputs.input_ids, labels=inputs.input_ids).loss
    return math.exp(loss.item())

def run_experiment(model_id, method, bits):
    tokenizer = GPT2Tokenizer.from_pretrained(model_id)
    model = GPT2LMHeadModel.from_pretrained(model_id)
    test_text = "The artificial intelligence revolution is driven by efficient algorithms."
    
    # Calibration
    activations = {}
    def hook_fn(name):
        def hook(m, i, o): activations[name] = i[0].detach()
        return hook
    
    hooks = [m.register_forward_hook(hook_fn(n)) for n, m in model.named_modules() if "Conv1D" in str(type(m))]
    model(tokenizer("Quantization improves inference speed.", return_tensors="pt").input_ids)
    for h in hooks: h.remove()

    # Quantize
    with torch.no_grad():
        for name, m in model.named_modules():
            if "Conv1D" in str(type(m)) and name in activations:
                W = m.weight.data.t()
                X = activations[name]
                W_q = apply_gptq(W, X, bits) if method == "GPTQ" else apply_awq(W, X, bits)
                m.weight.copy_(W_q.t())
                
    ppl = get_ppl(model, tokenizer, test_text)
    size = (sum(p.numel() for p in model.parameters()) * bits) / (8 * 1024 * 1024)
    return ppl, size

# --- MAIN RUN ---

if __name__ == "__main__":
    m_id = "gpt2"
    tok = GPT2Tokenizer.from_pretrained(m_id)
    base_model = GPT2LMHeadModel.from_pretrained(m_id)
    test_p = "The artificial intelligence revolution is driven by efficient algorithms."
    
    print(f"{'Method':<12} | {'Bits':<5} | {'Size (MB)':<10} | {'Perplexity':<10}")
    print("-" * 50)
    print(f"{'Baseline':<12} | {'32':<5} | {500.0:<10.1f} | {get_ppl(base_model, tok, test_p):.2f}")

    for b in [16, 8, 4]:
        for meth in ["GPTQ", "AWQ"]:
            p, s = run_experiment(m_id, meth, b)
            print(f"{meth:<12} | {b:<5} | {s:<10.1f} | {p:.2f}")

Loading weights: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 148/148 [00:00<00:00, 1550.87it/s, Materializing param=transformer.wte.weight]             


Method       | Bits  | Size (MB)  | Perplexity
--------------------------------------------------
Baseline     | 32    | 500.0      | 84.74


Loading weights: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 148/148 [00:00<00:00, 1583.63it/s, Materializing param=transformer.wte.weight]             


GPTQ         | 16    | 237.4      | 84.72


Loading weights: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 148/148 [00:00<00:00, 1591.15it/s, Materializing param=transformer.wte.weight]             


AWQ          | 16    | 237.4      | 84.72


Loading weights: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 148/148 [00:00<00:00, 1452.04it/s, Materializing param=transformer.wte.weight]             


GPTQ         | 8     | 118.7      | 86.27


Loading weights: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 148/148 [00:00<00:00, 1307.33it/s, Materializing param=transformer.wte.weight]             


AWQ          | 8     | 118.7      | 83.43


Loading weights: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 148/148 [00:00<00:00, 1310.68it/s, Materializing param=transformer.wte.weight]             


GPTQ         | 4     | 59.3       | 652.34


Loading weights: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 148/148 [00:00<00:00, 1312.23it/s, Materializing param=transformer.wte.weight]             


AWQ          | 4     | 59.3       | 102.95
