# Lecture 6: Quantization (Part II) - QAT & LLM Quantization

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/efficientml-course/efficientml_course/06_quantization_2/demo.ipynb)

Quantization-Aware Training and 4-bit LLM inference with GPTQ/AWQ.


In [None]:
!pip install torch -q
import torch
import torch.nn as nn
import torch.nn.functional as F

# Quantization-Aware Training (QAT) - Fake Quantization
class FakeQuantize(torch.autograd.Function):
    """Simulate uniform asymmetric integer quantization in the forward pass.

    Forward: map ``x`` onto ``2**num_bits`` unsigned integer levels using a
    per-tensor min/max range plus a zero point, then immediately dequantize
    back to float. The output therefore carries the same rounding error an
    integer kernel would introduce, while staying a float tensor.

    Backward: Straight-Through Estimator (STE) -- the non-differentiable
    round/clamp is treated as the identity, so gradients pass through.

    Usage: ``FakeQuantize.apply(x, num_bits)`` (``apply`` accepts positional
    arguments only).
    """

    @staticmethod
    def forward(ctx, x, num_bits=8):
        qmin, qmax = 0, 2**num_bits - 1

        # Extend the observed range to include 0 so that 0.0 is exactly
        # representable and the zero point always lands inside [qmin, qmax].
        x_min = torch.clamp(x.min(), max=0.0)
        x_max = torch.clamp(x.max(), min=0.0)

        # Step size between adjacent levels; floored so an all-zero input
        # (collapsed range) does not divide by zero.
        scale = torch.clamp((x_max - x_min) / (qmax - qmin), min=1e-8)

        # Integer offset mapping x_min onto qmin. Without it the unsigned
        # levels [0, qmax] cannot represent negative values, and every
        # negative input would be clamped to 0.
        zero_point = qmin - torch.round(x_min / scale)

        # Quantize and dequantize
        x_q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
        x_dq = (x_q - zero_point) * scale
        return x_dq

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-Through Estimator: the gradient w.r.t. x is passed through
        # unchanged; num_bits is a plain int and gets no gradient.
        return grad_output, None

fake_quant = FakeQuantize.apply

# Demo: QAT simulates quantization during training.
# Seed the RNG so the printed numbers are reproducible across runs.
torch.manual_seed(0)
x = torch.randn(10, requires_grad=True)
y = fake_quant(x, 4)  # 4-bit quantization -> at most 2**4 = 16 distinct levels
loss = y.sum()
loss.backward()  # STE: every element of x.grad equals d(loss)/d(y) = 1

print("Fake Quantization Demo (4-bit)")
print(f"Input:  {x[:5].detach()}")
print(f"Output: {y[:5].detach()}")
print(f"Grad:   {x.grad[:5]} (STE passes gradients through)")
print("\n🎯 Model learns to be robust to quantization noise!")
