# ⚡ Deployment & Optimization

**Topics:** Quantization, Pruning, Knowledge Distillation, ONNX Export

In [None]:
# Setup
# NOTE: `!pip ...` is IPython/Jupyter shell syntax, not plain Python. Outside a
# notebook, run `pip install torch torchvision` from a shell instead.
!pip install torch torchvision -q
import torch
import torch.nn as nn
import numpy as np
print('✅ Setup complete!')

In [None]:
# Quantization Demo
def quantize_tensor(x, bits=8):
    """Simulate asymmetric uniform quantization of ``x`` to ``bits`` bits.

    The range ``[x.min(), x.max()]`` is mapped linearly onto the integer
    codes ``[0, 2**bits - 1]``. The codes are then mapped back to floats, so
    the reconstruction error can be measured.

    Args:
        x: Non-empty tensor to quantize.
        bits: Bit-width of the integer codes; must be >= 1.

    Returns:
        ``(q, x_dequant, scale, zero_point)``:
        - ``q`` holds the integer codes. Its dtype is ``torch.uint8`` for
          ``bits <= 8`` and ``torch.int64`` otherwise.
        - ``x_dequant`` is the float reconstruction, with the same shape as ``x``.
        - ``scale`` is the step size between adjacent codes (0-dim tensor).
        - ``zero_point`` is the code offset (0-dim tensor). It is kept as a
          real number, not rounded to an integer as hardware quantizers do.

    Raises:
        ValueError: If ``x`` is empty or ``bits < 1``.
    """
    if x.numel() == 0:
        raise ValueError("quantize_tensor requires a non-empty tensor")
    if bits < 1:
        raise ValueError(f"bits must be >= 1, got {bits}")

    qmin, qmax = 0, 2**bits - 1
    x_min, x_max = x.min(), x.max()
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0:
        # Constant input: any non-zero scale reproduces it exactly, and this
        # avoids the 0/0 = NaN that a zero scale would produce below.
        scale = torch.ones_like(scale)
    zero_point = qmin - x_min / scale

    # Codes are unsigned and reach 2**bits - 1 (255 for 8 bits). int8 tops out
    # at 127 and would wrap the upper half negative, so use uint8, or a wide
    # signed type when the codes do not fit in 8 bits.
    code_dtype = torch.uint8 if bits <= 8 else torch.int64
    q = torch.round(x / scale + zero_point).clamp(qmin, qmax).to(code_dtype)

    # Dequantize: invert the affine map to get back to the float domain.
    x_dequant = (q.float() - zero_point) * scale

    return q, x_dequant, scale, zero_point

# Demo: compare reconstruction error across bit-widths on the same input
x = torch.randn(1000).mul(2)
for bits in (8, 4, 2):
    q, dq, s, z = quantize_tensor(x, bits)
    error = (x - dq).pow(2).mean().item()
    print(f'INT{bits}: MSE={error:.6f}, Scale={s:.4f}')

In [None]:
# Pruning Demo
class SimpleNet(nn.Module):
    """Two-layer MLP: Linear(784 -> 256), ReLU, Linear(256 -> 10)."""

    def __init__(self):
        super(SimpleNet, self).__init__()
        # Layers are created in this order (fc1, then fc2). The attribute names
        # determine the state_dict / named_parameters() keys.
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        hidden = self.fc1(x).relu()
        logits = self.fc2(hidden)
        return logits

# Freshly initialised (untrained) network; its weights are counted and pruned below.
model = SimpleNet()

def count_params(model):
    """Return ``(total, nonzero)`` scalar counts over ``model.parameters()``.

    ``total`` counts every scalar element. ``nonzero`` counts the elements
    whose value is not exactly zero, e.g. weights that survived pruning.
    """
    n_total = 0
    n_nonzero = 0
    for tensor in model.parameters():
        n_total += tensor.numel()
        n_nonzero += (tensor != 0).sum().item()
    return n_total, n_nonzero

total, nonzero = count_params(model)
pct_nonzero = 100 * nonzero / total  # share of parameters that are non-zero
print(f'Before pruning: {nonzero}/{total} params ({pct_nonzero:.1f}%)')

# Magnitude pruning
def prune_by_magnitude(model, ratio=0.5):
    """Zero the smallest-magnitude entries of every weight tensor, in place.

    For each parameter whose name contains ``'weight'``, the
    ``round(ratio * numel)`` entries with the smallest absolute value are set
    to zero. Entries tied with the cut-off magnitude are zeroed too. Other
    parameters (e.g. biases) are left untouched. The pruning mask is not
    stored, so later training can regrow pruned weights.

    Args:
        model: Module whose ``named_parameters()`` are pruned.
        ratio: Fraction of each weight tensor to zero, in ``[0, 1]``.
            ``0`` leaves the weights unchanged; ``1`` zeroes them all.

    Raises:
        ValueError: If ``ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f"ratio must be in [0, 1], got {ratio}")
    # In-place edits to leaf parameters must not be recorded by autograd.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' not in name:
                continue
            magnitudes = param.abs().flatten()
            k = int(round(ratio * magnitudes.numel()))
            if k == 0:
                # Nothing to prune. A strict '>' against the minimum would
                # otherwise still zero the smallest entry.
                continue
            # kthvalue (unlike torch.quantile) has no input-size limit.
            threshold = magnitudes.kthvalue(k).values
            param.mul_(param.abs() > threshold)

prune_by_magnitude(model, ratio=0.5)
total, nonzero = count_params(model)
pct_nonzero = 100 * nonzero / total  # share of parameters left non-zero
print(f'After 50% pruning: {nonzero}/{total} params ({pct_nonzero:.1f}%)')

In [None]:
# Knowledge Distillation Concept
def distillation_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.5):
    """Knowledge-distillation loss: blend soft (teacher) and hard (label) targets.

    Args:
        student_logits: Raw student outputs, with classes on the last dim.
            Assumes 2-D ``(batch, classes)``: softmax uses ``dim=-1`` while
            cross_entropy uses dim 1, and they agree only for 2-D input.
        teacher_logits: Raw teacher outputs, same shape as ``student_logits``.
            Treated as a fixed target, so no gradient flows into it.
        labels: Ground-truth class indices for the hard cross-entropy term.
        temperature: Softening temperature ``T`` applied to both logit sets.
        alpha: Weight of the soft (KL) term; the hard term gets ``1 - alpha``.

    Returns:
        Scalar tensor
        ``alpha * T**2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student, labels)``,
        with the KL term averaged over the batch (``reduction='batchmean'``).
    """
    # The teacher is a constant target. Detach it so backward() neither
    # computes nor accumulates gradients in the teacher's graph.
    teacher_log_probs = nn.functional.log_softmax(teacher_logits.detach() / temperature, dim=-1)
    student_log_probs = nn.functional.log_softmax(student_logits / temperature, dim=-1)

    # log_target=True compares log-probabilities directly, which stays finite
    # even when a teacher probability underflows to 0. The T**2 factor keeps the
    # soft-term gradient magnitude comparable to the hard term's.
    soft_loss = nn.functional.kl_div(
        student_log_probs,
        teacher_log_probs,
        reduction='batchmean',
        log_target=True,
    ) * (temperature ** 2)

    hard_loss = nn.functional.cross_entropy(student_logits, labels)

    return alpha * soft_loss + (1 - alpha) * hard_loss

# Demo on random logits (neither "teacher" nor "student" is trained here)
batch_size = 32
num_classes = 10
logits_shape = (batch_size, num_classes)
teacher_logits = torch.randn(logits_shape)
student_logits = torch.randn(logits_shape)
labels = torch.randint(0, num_classes, (batch_size,))

loss = distillation_loss(student_logits, teacher_logits, labels)
print(f'Distillation Loss: {loss.item():.4f}')