# Part 4, Lab 3: INT8 and INT4 Weight Quantization

**Time:** ~45 minutes

Weight-only quantization (W8A16, W4A16 — 8- or 4-bit weights with 16-bit activations) is the most practical approach for memory-bound LLM inference. This lab covers implementing and evaluating INT8 and INT4 weight quantization.

## Learning Objectives

1. Implement weight-only INT8 quantization
2. Implement INT4 with group scaling
3. Measure accuracy impact on a toy model
4. Understand GPTQ and AWQ concepts

In [None]:
import numpy as np
import torch
import torch.nn as nn

# Fix both random number generators so every run of the lab is reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Use the GPU when one is visible, otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

---
## 1. INT8 Weight-Only Quantization

Weights are stored as INT8 with one floating-point scale per output channel, and are dequantized on the fly just before the matrix multiplication.

In [None]:
class QuantizedLinearINT8(nn.Module):
    """Linear layer with INT8 weight-only quantization.

    Weights are stored as symmetric INT8 values with one FP16 scale per output
    channel (row) and are dequantized on the fly in ``forward``. The bias, if
    present, stays in floating point.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Quantized weights (INT8) and scales (FP16, one per output channel).
        # Registered as buffers: saved in state_dict and moved by .to(), but
        # not trainable parameters.
        self.register_buffer('weight_int8', torch.zeros(out_features, in_features, dtype=torch.int8))
        self.register_buffer('scales', torch.ones(out_features, dtype=torch.float16))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

    @classmethod
    def from_float(cls, linear_layer):
        """Convert a regular linear layer to an INT8-quantized one.

        Args:
            linear_layer: an ``nn.Linear`` (anything exposing ``in_features``,
                ``out_features``, ``weight`` and ``bias``).

        Returns:
            A new quantized layer on the same device as ``linear_layer.weight``.
        """
        quant_layer = cls(
            linear_layer.in_features,
            linear_layer.out_features,
            bias=linear_layer.bias is not None,
        ).to(linear_layer.weight.device)

        # Compute the quantization in FP32 regardless of the source dtype.
        weight = linear_layer.weight.detach().float()

        # Per-channel symmetric scale: the largest |w| in each row maps to 127.
        abs_max = weight.abs().amax(dim=1)
        # Avoid division by zero for all-zero rows. NOTE: 1e-8 underflows to 0
        # once stored as FP16, which is harmless because such rows quantize to
        # all zeros anyway.
        scales = (abs_max / 127.0).clamp(min=1e-8)

        # |w / scale| <= 127 by construction, so -128 is never produced.
        weight_int8 = (weight / scales.unsqueeze(1)).round().clamp(-128, 127).to(torch.int8)

        with torch.no_grad():
            quant_layer.weight_int8.copy_(weight_int8)
            quant_layer.scales.copy_(scales.half())
            if linear_layer.bias is not None:
                quant_layer.bias.copy_(linear_layer.bias.detach())

        return quant_layer

    def forward(self, x):
        """Dequantize the weights and compute ``x @ W.T + b`` in ``x``'s dtype."""
        # Dequantize in FP32, then cast to the activation dtype so FP16/BF16
        # (or FP64) inputs work; for FP32 inputs the casts are no-ops.
        weight_fp = (self.weight_int8.float() * self.scales.float().unsqueeze(1)).to(x.dtype)
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return nn.functional.linear(x, weight_fp, bias)

# Test INT8 quantization
linear = nn.Linear(1024, 512).to(device)
linear_int8 = QuantizedLinearINT8.from_float(linear).to(device)

# Compare outputs (an accuracy check needs no autograd graph)
x = torch.randn(32, 1024, device=device)
with torch.no_grad():
    y_fp32 = linear(x)
    y_int8 = linear_int8(x)

error = (y_fp32 - y_int8).abs().mean().item()
print(f"INT8 vs FP32 output difference: {error:.6f}")

# Count every stored byte: the FP16 per-channel scales are part of the
# quantized footprint, not just the INT8 weights.
fp32_bytes = linear.weight.numel() * linear.weight.element_size()
int8_bytes = (linear_int8.weight_int8.numel() * linear_int8.weight_int8.element_size()
              + linear_int8.scales.numel() * linear_int8.scales.element_size())
print(f"Memory: {fp32_bytes / 1024:.1f} KB → {int8_bytes / 1024:.1f} KB (INT8 weights + FP16 scales)")

---
## 2. INT4 with Group Quantization

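INT4 has only 16 levels, so a single scale per output channel is usually too coarse for acceptable accuracy. Instead, each contiguous group of input weights (typically 128) gets its own scale.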

In [None]:
class QuantizedLinearINT4(nn.Module):
    """Linear layer with INT4 weight quantization (group-wise scaling).

    Each row of the weight matrix is split into contiguous groups of
    ``group_size`` input features. Each group has its own FP16 scale, and
    values are quantized symmetrically to the INT4 range [-8, 7].
    NOTE: for simplicity each 4-bit value occupies a full int8 element (no
    packing), so this class does not actually save memory over INT8.
    """

    def __init__(self, in_features, out_features, group_size=128, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size

        # Calculate number of groups
        assert in_features % group_size == 0, (
            f"in_features ({in_features}) must be divisible by group_size ({group_size})"
        )
        num_groups = in_features // group_size

        # Store as INT8 (2 INT4 values packed per byte in practice)
        # For simplicity, we store one value per int8 element here
        self.register_buffer('weight_int4', torch.zeros(out_features, in_features, dtype=torch.int8))
        self.register_buffer('scales', torch.ones(out_features, num_groups, dtype=torch.float16))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

    @classmethod
    def from_float(cls, linear_layer, group_size=128):
        """Convert a regular linear layer to an INT4 group-quantized one.

        Args:
            linear_layer: an ``nn.Linear`` whose ``in_features`` is divisible
                by ``group_size``.
            group_size: number of consecutive input features sharing a scale.

        Returns:
            A new quantized layer on the same device as ``linear_layer.weight``.

        Raises:
            AssertionError: if ``in_features`` is not divisible by ``group_size``
                (checked by the constructor).
        """
        # Constructing first validates divisibility before any reshaping.
        quant_layer = cls(
            linear_layer.in_features,
            linear_layer.out_features,
            group_size=group_size,
            bias=linear_layer.bias is not None,
        ).to(linear_layer.weight.device)

        weight = linear_layer.weight.detach().float()
        out_features, in_features = weight.shape

        # (out, in) -> (out, num_groups, group_size); reshape also handles
        # non-contiguous weights, where view would fail.
        weight_grouped = weight.reshape(out_features, -1, group_size)

        # Per-group symmetric scale: the largest |w| in each group maps to 7.
        abs_max = weight_grouped.abs().amax(dim=2)
        scales = (abs_max / 7.0).clamp(min=1e-8)  # guard all-zero groups

        # Quantize to [-8, 7]; -8 is never produced with this scaling.
        weight_int4 = (weight_grouped / scales.unsqueeze(2)).round().clamp(-8, 7).to(torch.int8)
        weight_int4 = weight_int4.reshape(out_features, in_features)

        with torch.no_grad():
            quant_layer.weight_int4.copy_(weight_int4)
            quant_layer.scales.copy_(scales.half())
            if linear_layer.bias is not None:
                quant_layer.bias.copy_(linear_layer.bias.detach())

        return quant_layer

    def forward(self, x):
        """Dequantize the weights group-wise and compute ``x @ W.T + b`` in ``x``'s dtype."""
        out_features = self.out_features
        in_features = self.in_features

        # Dequantize: reshape into groups, multiply by per-group scales, reshape back.
        weight_grouped = self.weight_int4.view(out_features, -1, self.group_size).float()
        weight_fp = weight_grouped * self.scales.float().unsqueeze(2)
        # Cast to the activation dtype so FP16/BF16/FP64 inputs work (no-op for FP32).
        weight_fp = weight_fp.reshape(out_features, in_features).to(x.dtype)
        bias = self.bias.to(x.dtype) if self.bias is not None else None

        return nn.functional.linear(x, weight_fp, bias)

# Same experiment for INT4: quantize a fresh layer using 128-wide groups.
linear = nn.Linear(1024, 512).to(device)
linear_int4 = QuantizedLinearINT4.from_float(linear, group_size=128).to(device)

# Run the float and quantized layers on one shared batch.
x = torch.randn(32, 1024, device=device)
y_fp32, y_int4 = linear(x), linear_int4(x)

error = (y_fp32 - y_int4).abs().mean().item()
print(f"INT4 vs FP32 output difference: {error:.6f}")

# Extra bits per weight spent on the FP16 per-group scales.
scale_bits_per_weight = linear_int4.scales.numel() * 16 / linear_int4.weight_int4.numel()
print(f"Effective bits: 4 + {scale_bits_per_weight:.2f} (scale overhead)")

---
## 3. Accuracy Comparison on MLP

Let's compare INT8 vs INT4 on a simple MLP.

In [None]:
class SimpleMLP(nn.Module):
    """Two-layer feed-forward block: Linear(h -> 4h), GELU, Linear(4h -> h)."""

    def __init__(self, hidden_size=1024):
        super().__init__()
        inner_size = 4 * hidden_size
        # Keep the registration order fc1, fc2, act: the two Linear layers draw
        # their random init from the RNG in this sequence.
        self.fc1 = nn.Linear(hidden_size, inner_size)
        self.fc2 = nn.Linear(inner_size, hidden_size)
        self.act = nn.GELU()

    def forward(self, x):
        """Expand, apply GELU, then project back to ``hidden_size``."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)

def quantize_mlp(mlp, quant_class, **kwargs):
    """Return a new SimpleMLP whose linear layers are quantized copies of *mlp*'s.

    Args:
        mlp: a ``SimpleMLP`` whose ``fc1``/``fc2`` are float linear layers.
        quant_class: a quantized-linear class exposing
            ``from_float(linear, **kwargs)``, e.g. ``QuantizedLinearINT8`` or
            ``QuantizedLinearINT4``.
        **kwargs: extra options forwarded to ``quant_class.from_float``
            (e.g. ``group_size=128`` for INT4).

    Returns:
        A new ``SimpleMLP``; *mlp* itself is not modified. The activation
        module is shared with *mlp*, not copied.
    """
    # Bypass SimpleMLP.__init__ so no float Linear layers are allocated (and
    # randomly initialised, consuming RNG state) only to be replaced.
    quant_mlp = SimpleMLP.__new__(SimpleMLP)
    nn.Module.__init__(quant_mlp)
    # An empty **kwargs adds no arguments, so it can be forwarded unconditionally.
    quant_mlp.fc1 = quant_class.from_float(mlp.fc1, **kwargs)
    quant_mlp.fc2 = quant_class.from_float(mlp.fc2, **kwargs)
    quant_mlp.act = mlp.act
    # nn.Module.__init__ leaves the new module in training mode; mirror the source.
    quant_mlp.train(mlp.training)
    return quant_mlp

# Build one float MLP and two quantized copies of it.
mlp = SimpleMLP().to(device)
mlp_int8 = quantize_mlp(mlp, QuantizedLinearINT8).to(device)
mlp_int4 = quantize_mlp(mlp, QuantizedLinearINT4, group_size=128).to(device)

# Feed the same batch through all three models.
x = torch.randn(32, 1024, device=device)
y_fp32 = mlp(x)
y_int8 = mlp_int8(x)
y_int4 = mlp_int4(x)

# Mean absolute deviation from the float output, per quantization scheme.
mean_errors = {
    label: (y_fp32 - out).abs().mean().item()
    for label, out in (("INT8", y_int8), ("INT4", y_int4))
}

print("MLP Output Comparison:")
for label, err in mean_errors.items():
    print(f"  {label} error: {err:.6f}")
print(f"  INT4/INT8 error ratio: {mean_errors['INT4'] / mean_errors['INT8']:.1f}x")

---
## 4. Understanding GPTQ and AWQ

Production INT4 methods like GPTQ and AWQ improve on naive quantization:

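**GPTQ**: Quantizes a layer's weights one column at a time. Using second-order statistics (an approximate Hessian) computed from calibration data, it updates the not-yet-quantized weights to compensate for the error introduced so far.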

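**AWQ**: Uses calibration activations to identify "salient" weight channels (those that multiply large activations). It protects them by scaling them up before quantization, with the matching activations divided by the same factor, so they effectively get higher precision.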

In [None]:
# Simplified AWQ-style importance detection
def compute_weight_importance(weight, activations):
    """Score each weight by |w| times the mean |activation| of its input channel.

    AWQ insight: weights that multiply consistently large activations have a
    bigger effect on the layer output, so quantization error there hurts more.

    Args:
        weight: weight matrix, expected as (out_features, in_features) like
            ``nn.Linear.weight``.
        activations: calibration activations whose last dimension is
            in_features. Any leading dimensions (samples, or batch x sequence)
            are pooled together; a single 1-D vector is also accepted.

    Returns:
        Tensor of ``weight``'s shape:
        ``|weight[o, i]| * mean_n |activations[n, i]|``.
    """
    # Flatten all leading dims so (N, in), (B, T, in) and (in,) are handled alike.
    flat_acts = activations.reshape(-1, activations.shape[-1])

    # Activation statistics (mean abs value per input dimension)
    act_scale = flat_acts.abs().mean(dim=0)

    # Weight importance = |weight| x activation scale, broadcast over output rows
    return weight.abs() * act_scale.unsqueeze(0)

# Demo: random weights scored against random calibration activations.
weight = torch.randn(512, 1024)
activations = torch.randn(1000, 1024)  # Calibration data

importance = compute_weight_importance(weight, activations)

# Keep only scores strictly above the 99th percentile (the top ~1%).
threshold = importance.quantile(0.99)
important_mask = importance > threshold

n_important = important_mask.sum().item()
pct_important = important_mask.float().mean() * 100
print(f"Important weights: {n_important} / {importance.numel()} ({pct_important:.1f}%)")
print("These weights could be kept at higher precision or protected via scaling.")

---
## Exercises

1. **Bit Packing**: Implement actual INT4 bit packing (2 values per byte)
2. **Calibration**: Add calibration to find optimal scales using real activations
3. **Mixed Precision**: Keep first/last layers in INT8, middle layers in INT4

## Key Takeaways

- INT8 weight-only quantization usually has minimal accuracy impact (often on the order of 0.1% or less on benchmarks)
- INT4 requires group quantization (group_size=128 is common)
- Advanced methods (GPTQ, AWQ) use calibration to minimize quantization error
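- Weight-only quantization helps memory-bound workloads, e.g. small-batch decoding, where reading weights from memory dominates. At very long context the KV cache, which weight-only quantization does not shrink, becomes the larger cost.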