# Module 4.3: Pruning & Distillation

**Goal**: Compress models through pruning and knowledge distillation

**Time**: 90 minutes

**Concepts Covered**:
- Magnitude pruning implementation
- Structured pruning (heads, neurons)
- Knowledge distillation with temperature
- Distill Phi-3-Mini → SmolLM-1.7B
- Compare pruned vs distilled models

## Setup

In [None]:
!pip install torch transformers accelerate matplotlib seaborn numpy -q

In [None]:
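import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

def magnitude_pruning(model, sparsity=0.5):
    """Magnitude-based unstructured pruning, applied per weight tensor.

    Returns a pruned copy; the original model is left unchanged.
    """
    # Deep copy so the original stays available for comparison (doubles memory use)
    pruned_model = copy.deepcopy(model)
    with torch.no_grad():
        for name, param in pruned_model.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                # Threshold = k-th smallest magnitude. kthvalue is used because
                # torch.quantile raises an error on very large tensors.
                flat_param = param.abs().flatten().float()
                k = int(sparsity * flat_param.numel())
                if k == 0:
                    continue
                threshold = torch.kthvalue(flat_param, k).values
                # Keep only weights whose magnitude exceeds the threshold
                mask = param.abs() > threshold
                param.mul_(mask.to(param.dtype))
    return pruned_model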

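def knowledge_distillation_loss(student_logits, teacher_logits, temperature=3.0, alpha=0.7, labels=None):
    """Knowledge distillation loss with temperature scaling.

    Returns the soft-target KD loss alone when labels is None; otherwise
    alpha * kd_loss + (1 - alpha) * hard_loss.
    """
    # Flatten to (N, vocab) so 'batchmean' averages over every position, not just the batch
    vocab_size = student_logits.size(-1)
    student_logits = student_logits.reshape(-1, vocab_size)
    teacher_logits = teacher_logits.detach().reshape(-1, vocab_size)

    # Soft targets from teacher
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)

    # KL divergence, scaled by T^2 so gradient magnitudes stay comparable across temperatures
    kd_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)
    if labels is None:
        return kd_loss

    # Hard targets (ground truth)
    hard_loss = F.cross_entropy(student_logits, labels.reshape(-1))

    # Combined
    return alpha * kd_loss + (1 - alpha) * hard_loss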

print("Pruning removes less important weights")
print("Distillation transfers knowledge from teacher to student model")
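
Structured pruning of neurons is listed among the concepts but is not implemented in the cell above, so here is a minimal sketch of one way it could look. It assumes a plain two-layer MLP block and scores each hidden neuron by the L2 norm of its incoming weights, which is one common heuristic rather than the only option. The names `prune_hidden_neurons` and `keep_ratio` are hypothetical and chosen for illustration.

```python
import torch
import torch.nn as nn


def prune_hidden_neurons(fc_in: nn.Linear, fc_out: nn.Linear, keep_ratio: float = 0.5):
    """Return smaller copies of two stacked Linear layers with low-norm hidden neurons removed.

    fc_in maps input -> hidden and fc_out maps hidden -> output. A hidden neuron
    is scored by the L2 norm of its incoming weights (one row of fc_in.weight).
    This scoring rule is an assumption made for the sketch.
    """
    hidden = fc_in.out_features
    n_keep = max(1, int(round(hidden * keep_ratio)))

    # Score every hidden neuron, then keep the top n_keep indices in their original order
    scores = fc_in.weight.detach().norm(p=2, dim=1)
    keep = torch.topk(scores, n_keep).indices.sort().values

    new_in = nn.Linear(fc_in.in_features, n_keep, bias=fc_in.bias is not None)
    new_out = nn.Linear(n_keep, fc_out.out_features, bias=fc_out.bias is not None)
    with torch.no_grad():
        # Drop rows of the first layer and the matching columns of the second
        new_in.weight.copy_(fc_in.weight[keep])
        if fc_in.bias is not None:
            new_in.bias.copy_(fc_in.bias[keep])
        new_out.weight.copy_(fc_out.weight[:, keep])
        if fc_out.bias is not None:
            new_out.bias.copy_(fc_out.bias)
    return new_in, new_out


torch.manual_seed(0)
fc1, fc2 = nn.Linear(16, 64), nn.Linear(64, 8)
small_fc1, small_fc2 = prune_hidden_neurons(fc1, fc2, keep_ratio=0.25)

x = torch.randn(4, 16)
dense_out = fc2(torch.relu(fc1(x)))
pruned_out = small_fc2(torch.relu(small_fc1(x)))
print("pruned shapes:", tuple(small_fc1.weight.shape), tuple(small_fc2.weight.shape))
print("mean abs output change:", (dense_out - pruned_out).abs().mean().item())
```

Unlike the masking approach in `magnitude_pruning`, this version physically shrinks the weight matrices, so the parameter count and the matrix multiply cost both drop without needing sparse kernels. Attention heads can be pruned in the same spirit by removing the slices of the projection matrices that belong to a head.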

## Key Takeaways

✅ **Module Complete**

## Next Steps

Continue to the next module in the course.