# Module 11.1: BitNet - 1.58-bit Quantization

**Goal**: Implement BitNet quantization with ternary weights (-1, 0, +1)

**Time**: 120 minutes

**Concepts Covered**:
- Ternary weight training (-1, 0, +1)
- BitLinear layer implementation
- Straight-through estimator
- Training from scratch with BitNet
- Performance comparison (memory, speed, quality)
- Energy efficiency analysis
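
For reference, the BitNet b1.58 paper (Ma et al., 2024, "The Era of 1-bit LLMs") describes an *absmean* weight quantizer. It is restated below as we understand it. $W$ is an $n \times m$ weight matrix and $\epsilon$ is a small constant for numerical stability. The $\log_2 3$ term is where the name "1.58-bit" comes from:

$$
\gamma = \frac{1}{nm}\sum_{i,j} |W_{ij}|, \qquad
\widetilde{W} = \operatorname{RoundClip}\!\left(\frac{W}{\gamma + \epsilon},\,-1,\,1\right),
$$

$$
\operatorname{RoundClip}(x, a, b) = \max\bigl(a, \min(b, \operatorname{round}(x))\bigr), \qquad
\log_2 3 \approx 1.585 \text{ bits per weight}.
$$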

## Setup

In [None]:
!pip install torch transformers accelerate matplotlib seaborn numpy -q

In [None]:
# BitNet: 1.58-bit Quantization
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear(nn.Module):
    """Linear layer that runs its forward pass with ternary weights (-1, 0, +1) times a scale.

    Weights use the absmean scheme from BitNet b1.58: gamma = mean(|W|),
    W_ternary = clip(round(W / gamma), -1, 1). BitNet also quantizes activations
    to 8 bits; that step is omitted here to keep the layer minimal.
    """
    def __init__(self, in_features, out_features, eps=1e-5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.eps = eps

        # Full-precision latent weights: the optimizer updates these
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # nn.Linear's default init

    def quantize_weights(self):
        """Return gamma * W_ternary, where W_ternary has entries in {-1, 0, +1}."""
        # Per-tensor scale; the clamp guards against division by zero
        gamma = self.weight.abs().mean().clamp(min=self.eps)
        ternary = (self.weight / gamma).round().clamp(-1, 1)
        return ternary * gamma

    def forward(self, x):
        """Forward pass with a straight-through estimator (STE).

        round() has zero gradient almost everywhere, so the quantized weights
        alone would give the latent weights no useful gradient. Adding the
        detached difference makes the forward value equal the quantized weights,
        while the backward pass treats quantization as the identity.
        The same path serves training and inference; for deployment, the
        ternary weights and gamma can be precomputed once.
        """
        w_quant = self.quantize_weights()
        w_ste = self.weight + (w_quant - self.weight).detach()
        return F.linear(x, w_ste)

# Test BitLinear
bit_linear = BitLinear(128, 64)
x = torch.randn(2, 10, 128)

output = bit_linear(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
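print("\nBitNet benefits:")
print("- log2(3) ≈ 1.58 bits of information per ternary weight (-1, 0, +1)")
print("- ~10x smaller weights than FP16 in theory (16 / 1.58); 8x with simple 2-bit packing")
print("- Matrix multiplies need only additions/subtractions, which specialized kernels can exploit for speed")
print("- Fewer multiplications, hence lower energy per inference")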
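
The memory and arithmetic claims above can be checked directly. This is a minimal sketch that assumes the weights are already ternary. It estimates storage for a single weight matrix, ignoring scales, activations, and embeddings. It then packs ternary values four per byte with a hypothetical 2-bit encoding. Finally, it shows that a ternary matrix-vector product needs only additions and subtractions. `pack_ternary`, `unpack_ternary`, and `ternary_matvec` are illustrative helpers, not a library API, and production kernels may use different layouts.

```python
import math
import torch

def pack_ternary(t):
    """Pack values in {-1, 0, +1} into uint8, four 2-bit codes per byte.

    Hypothetical encoding chosen for this sketch: code = value + 1
    (-1 -> 0, 0 -> 1, +1 -> 2); padding uses code 1 (the value 0).
    """
    codes = t.flatten().to(torch.int64) + 1
    pad = (-codes.numel()) % 4
    codes = torch.cat([codes, torch.ones(pad, dtype=torch.int64)]).view(-1, 4)
    packed = codes[:, 0] | (codes[:, 1] << 2) | (codes[:, 2] << 4) | (codes[:, 3] << 6)
    return packed.to(torch.uint8)

def unpack_ternary(packed, n):
    """Inverse of pack_ternary: recover the first n ternary values as int8."""
    p = packed.to(torch.int64)
    codes = torch.stack([(p >> s) & 0b11 for s in (0, 2, 4, 6)], dim=1).reshape(-1)
    return (codes[:n] - 1).to(torch.int8)

def ternary_matvec(w, x):
    """Compute y = W @ x for W in {-1, 0, +1} using only additions/subtractions."""
    zeros = torch.zeros_like(x)
    plus = torch.where(w == 1, x, zeros).sum(dim=1)    # add inputs where W = +1
    minus = torch.where(w == -1, x, zeros).sum(dim=1)  # subtract inputs where W = -1
    return plus - minus                                # W = 0 entries are skipped

# Weight-storage estimate for one 4096 x 4096 matrix (scales/activations ignored)
n = 4096 * 4096
fp16_bytes = n * 2
ideal_bytes = n * math.log2(3) / 8      # information-theoretic 1.58 bits/weight
packed_bytes = math.ceil(n / 4)         # 2 bits/weight, as in pack_ternary
print(f"FP16 / ideal ternary: {fp16_bytes / ideal_bytes:.2f}x")
print(f"FP16 / 2-bit packed:  {fp16_bytes / packed_bytes:.2f}x")

# Round-trip and add-only matmul check on a random ternary matrix
torch.manual_seed(0)
w = torch.randint(-1, 2, (64, 128))
restored = unpack_ternary(pack_ternary(w), w.numel()).view_as(w)
x = torch.randn(128)
print("pack/unpack exact:", torch.equal(restored.to(torch.int64), w))
print("matches w @ x:", torch.allclose(ternary_matvec(w, x), w.float() @ x, atol=1e-4))
```

Two bits encode four states, so one code (3) goes unused. That slack is why the practical ratio is 8x rather than the theoretical ~10x.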

In [None]:
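# Straight-Through Estimator (STE)
def ste_quantize(weights, eps=1e-5):
    """Ternary (absmean) quantization with a straight-through estimator."""
    # Forward: scaled ternary values gamma * {-1, 0, +1}
    gamma = weights.abs().mean().clamp(min=eps)
    quantized = (weights / gamma).round().clamp(-1, 1) * gamma

    # Backward: (quantized - weights) is detached, so d(output)/d(weights)
    # is the identity and gradients pass straight through the rounding
    return weights + (quantized - weights).detach()

# Example STE usage: forward values are quantized, but gradients are not blocked
x = torch.randn(4, 4, requires_grad=True)
x_quantized = ste_quantize(x)
x_quantized.sum().backward()

print("Straight-Through Estimator:")
print("- Forward (quantized values):\n", x_quantized.detach())
print("- Backward: x.grad is all ones:", torch.all(x.grad == 1).item())
print("- Latent full-precision weights keep receiving gradients, so training works")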
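
To make "training from scratch" concrete, here is a small self-contained sketch. `TernaryLinear` is a hypothetical stand-in for `BitLinear`: it uses absmean ternary weights with an STE, no bias, and no activation quantization. A two-layer MLP built from it is trained on a synthetic toy task. The task, sizes, learning rate, and step count are arbitrary illustrative choices, not settings from BitNet.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TernaryLinear(nn.Linear):
    """nn.Linear whose forward pass uses absmean-ternary weights via an STE (sketch)."""
    def forward(self, x):
        w = self.weight
        gamma = w.abs().mean().clamp(min=1e-5)
        w_q = (w / gamma).round().clamp(-1, 1) * gamma  # gamma * {-1, 0, +1}
        w_ste = w + (w_q - w).detach()                  # STE: identity gradient
        return F.linear(x, w_ste, self.bias)

def train_toy(steps=300, seed=0):
    """Train a tiny ternary MLP; return (per-step losses, final train accuracy)."""
    torch.manual_seed(seed)
    # Hypothetical toy task: label is 1 when the first two features sum to > 0
    X = torch.randn(512, 16)
    y = (X[:, 0] + X[:, 1] > 0).long()

    model = nn.Sequential(
        TernaryLinear(16, 64, bias=False),
        nn.ReLU(),
        TernaryLinear(64, 2, bias=False),
    )
    # The optimizer updates the full-precision latent weights
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)

    losses = []
    for _ in range(steps):
        loss = F.cross_entropy(model(X), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

    with torch.no_grad():
        acc = (model(X).argmax(dim=1) == y).float().mean().item()
    return losses, acc

losses, acc = train_toy()
print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}, train accuracy: {acc:.2%}")
```

The forward pass only ever sees ternary weights, yet the loss falls, because the STE routes gradients to the latent weights. Real BitNet training also quantizes activations and operates at far larger scale.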

## Key Takeaways

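- BitLinear keeps full-precision latent weights for the optimizer but runs its forward pass with ternary weights (-1, 0, +1) times a scale.
- Quantizing to ternary values has zero gradient almost everywhere, so training relies on the straight-through estimator to pass gradients to the latent weights.
- A ternary weight carries log2(3) ≈ 1.58 bits; with 2-bit packing, weights take 8x less memory than FP16, and matrix multiplies reduce to additions and subtractions.

✅ **Module Complete**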

## Next Steps

Continue to the next module in the course.