# 🔬 Lecture 3: Pruning & Sparsity - Complete Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/efficientml_course/blob/main/03_pruning_sparsity_1/demo.ipynb)

## What You'll Learn
- Why neural networks are over-parameterized
- How magnitude pruning works step-by-step
- Implementing pruning from scratch
- Visualizing sparse weight matrices
- Measuring accuracy before/after pruning
- Iterative pruning for better results


In [None]:
# Setup and Imports
!pip install torch torchvision matplotlib -q

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import copy

# Pick the compute device: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"üñ•Ô∏è Using device: {device}")

# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(42)


## Part 1: Create and Train a Model

We'll use a simple MLP on MNIST to demonstrate pruning concepts.


In [None]:
# Define a simple MLP for MNIST
class SimpleMLP(nn.Module):
    """Four-layer fully connected classifier (784 -> 512 -> 256 -> 128 -> 10).

    The three hidden layers are followed by ReLU; the last layer returns raw
    scores (no activation is applied to its output).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in fc1..fc4 order and exposed under those names.
        self.fc1 = nn.Linear(in_features=784, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=256)
        self.fc3 = nn.Linear(in_features=256, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Flatten the input to rows of 784 values and return 10 scores per row."""
        hidden = x.view(-1, 784)
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(hidden_layer(hidden))
        return self.fc4(hidden)

# Build the network and place it on the selected device.
model = SimpleMLP()
model = model.to(device)

# Tally every parameter element (weights and biases alike).
total_params = 0
for p in model.parameters():
    total_params += p.numel()

print("üìä Model Architecture:")
print(f"   Total parameters: {total_params:,}")
# 4 bytes per parameter (FP32), reported in MiB.
print(f"   Model size: {total_params * 4 / 1024 / 1024:.2f} MB (FP32)")
print("\n   Layer sizes:")
for name, param in model.named_parameters():
    # Only weight tensors are listed; other parameters (e.g. biases) are skipped.
    if 'weight' not in name:
        continue
    print(f"   {name}: {param.shape} = {param.numel():,} params")


In [None]:
# Load MNIST dataset
# Preprocessing pipeline: convert each image to a tensor, then normalize it
# with mean 0.1307 and std 0.3081.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]
transform = transforms.Compose(preprocessing_steps)

# Only the training split requests a download; both splits share ./data.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Training batches are shuffled; the test loader keeps the default order.
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000)

print("üì¶ Dataset loaded:")
print(f"   Training samples: {len(train_dataset):,}")
print(f"   Test samples: {len(test_dataset):,}")


In [None]:
# Training and evaluation functions
def train_epoch(model, train_loader, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Batches are moved to the module-level ``device`` before the forward pass.

    Args:
        model: Network to train (switched to training mode here).
        train_loader: Iterable of ``(data, target)`` batches that supports ``len()``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        The summed per-batch cross-entropy loss divided by ``len(train_loader)``.
    """
    model.train()
    batch_losses = []
    for data, target in train_loader:
        inputs = data.to(device)
        labels = target.to(device)

        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    return sum(batch_losses) / len(train_loader)

def evaluate(model, test_loader):
    """Return the model's top-1 accuracy over ``test_loader`` as a percentage.

    The model is put in eval mode and gradients are disabled for the pass.
    Batches are moved to the module-level ``device`` before the forward pass.

    Args:
        model: Network to score.
        test_loader: Iterable of ``(data, target)`` batches.

    Returns:
        ``100 * correct / total`` as a float.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            inputs = data.to(device)
            labels = target.to(device)
            # The predicted class is the index of the largest score in each row.
            predictions = model(inputs).argmax(dim=1)
            hits += (predictions == labels).sum().item()
            seen += labels.size(0)
    return 100. * hits / seen

# Train the model
print("üèãÔ∏è Training the model...")
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(1, 6):
    loss = train_epoch(model, train_loader, optimizer)
    acc = evaluate(model, test_loader)
    print(f"   Epoch {epoch}: Loss={loss:.4f}, Accuracy={acc:.2f}%")

original_accuracy = evaluate(model, test_loader)
print(f"\n‚úÖ Original Model Accuracy: {original_accuracy:.2f}%")


## Part 2: Analyzing Weight Distribution

**Key Insight**: Most weights in trained neural networks are very small!

This is why pruning works - we can remove small weights without hurting accuracy much.


In [None]:
# Visualize weight distribution
# One histogram panel per linear layer, arranged in a 2x2 grid.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

layers = [(layer_name, getattr(model, layer_name))
          for layer_name in ('fc1', 'fc2', 'fc3', 'fc4')]

for ax, (name, layer) in zip(axes.flat, layers):
    # Flatten the layer's weight matrix into a 1-D NumPy array on the CPU.
    weights = layer.weight.data.cpu().numpy().flatten()
    num_weights = len(weights)

    ax.hist(weights, bins=100, color='steelblue', alpha=0.7, edgecolor='black')
    ax.axvline(x=0, color='red', linestyle='--', linewidth=2)
    ax.set_title(f'{name}: {num_weights:,} weights', fontsize=12)
    ax.set_xlabel('Weight Value')
    ax.set_ylabel('Count')

    # Share of weights whose magnitude is below 0.1, shown in the corner box.
    small_weights = np.abs(weights) < 0.1
    tiny_pct = small_weights.sum() / num_weights * 100
    ax.text(
        0.95, 0.95,
        f'{tiny_pct:.1f}% are tiny\n(|w| < 0.1)',
        transform=ax.transAxes, ha='right', va='top',
        bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.8),
    )

plt.suptitle('üìä Weight Distribution by Layer\n(Notice: Most weights are clustered near zero!)', fontsize=14)
plt.tight_layout()
plt.show()


## Part 3: Implementing Magnitude Pruning

**Algorithm**: Remove weights with the smallest absolute values

```
1. Collect all weights from the model
2. Sort by absolute value
3. Find threshold at desired sparsity (e.g., 90th percentile)
4. Set weights below threshold to zero
5. Fine-tune to recover accuracy
```


In [None]:
# Implement magnitude pruning from scratch
def magnitude_prune(model, sparsity):
    """
    Prune weights by magnitude (global pruning).

    All parameters whose name contains 'weight' are ranked together by
    absolute value and the smallest ``round(sparsity * N)`` of them are set
    to zero in place (weights tied with the threshold are pruned as well).
    Parameters whose name does not contain 'weight' are left untouched.

    Args:
        model: Neural network
        sparsity: Fraction to remove (0.9 = remove 90%), in [0, 1]

    Returns:
        masks: Dictionary of binary masks keyed by parameter name
               (1 = kept, 0 = pruned); empty if the model has no weights.

    Raises:
        ValueError: If ``sparsity`` is outside [0, 1].
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

    # Step 1: Collect all weights
    weight_params = [(name, param) for name, param in model.named_parameters()
                     if 'weight' in name]
    if not weight_params:
        return {}

    masks = {}
    # no_grad lets us edit the parameters in place without touching autograd.
    with torch.no_grad():
        all_weights = torch.cat([param.abs().flatten() for _, param in weight_params])

        # Step 2: Find threshold
        # Use the exact k-th smallest magnitude rather than an interpolated
        # quantile: the pruned count is then exactly k (barring ties), and
        # sparsity=0 prunes nothing.
        num_to_prune = round(sparsity * all_weights.numel())
        if num_to_prune > 0:
            threshold = torch.kthvalue(all_weights, num_to_prune).values
            print(f"   Threshold for {sparsity*100:.0f}% sparsity: {threshold:.6f}")
        else:
            threshold = None
            print(f"   Threshold for {sparsity*100:.0f}% sparsity: none (nothing pruned)")

        # Step 3: Create masks and apply
        for name, param in weight_params:
            if threshold is None:
                mask = torch.ones_like(param.detach())
            else:
                # Keep only weights strictly larger than the threshold.
                mask = (param.abs() > threshold).to(param.dtype)
            masks[name] = mask
            param.mul_(mask)  # Zero out pruned weights

    return masks

def count_sparsity(model):
    """Calculate model sparsity percentage.

    Only parameters whose name contains 'weight' are counted.

    Args:
        model: Neural network

    Returns:
        Percentage (0-100) of counted weight entries that are exactly zero,
        or 0.0 when the model has no such parameters.
    """
    total = zeros = 0
    for name, param in model.named_parameters():
        if 'weight' in name:
            total += param.numel()
            zeros += (param == 0).sum().item()
    # A model without weight parameters would otherwise divide by zero.
    if total == 0:
        return 0.0
    return zeros / total * 100

# Create a copy for pruning
# Prune an independent copy so the dense model stays available for comparison.
pruned_model = copy.deepcopy(model)

print("üî™ Applying 90% magnitude pruning...")
masks = magnitude_prune(pruned_model, sparsity=0.9)

# Measure the pruned copy before any fine-tuning.
measured_sparsity = count_sparsity(pruned_model)
pruned_acc = evaluate(pruned_model, test_loader)
accuracy_drop = original_accuracy - pruned_acc

print("\nüìä Results:")
print(f"   Sparsity: {measured_sparsity:.1f}%")
print(f"   Original accuracy: {original_accuracy:.2f}%")
print(f"   Pruned accuracy (before fine-tuning): {pruned_acc:.2f}%")
print(f"   Accuracy drop: {accuracy_drop:.2f}%")


## Part 4: Fine-tuning to Recover Accuracy

After pruning, we fine-tune while **keeping pruned weights at zero**.


In [None]:
# Fine-tune while keeping pruned weights at zero
def fine_tune_pruned(model, train_loader, test_loader, masks, epochs=3, lr=0.0001):
    """Fine-tune a pruned model while keeping pruned weights at zero.

    Batches are moved to the module-level ``device``; accuracy is measured
    with the module-level ``evaluate`` after every epoch and printed.

    Args:
        model: Pruned network; trained in place.
        train_loader: Iterable of ``(data, target)`` training batches.
        test_loader: Loader passed to ``evaluate`` after each epoch.
        masks: Dict mapping parameter names to binary masks (1 = keep,
            0 = pruned). Parameters not named in ``masks`` train freely.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate for the Adam optimizer created here.

    Returns:
        The same ``model`` object, after fine-tuning.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Pair each masked parameter with its mask once, instead of walking
    # named_parameters() twice on every batch.
    masked_params = [(param, masks[name])
                     for name, param in model.named_parameters()
                     if name in masks]

    for epoch in range(epochs):
        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()

            # Zero out gradients for pruned weights
            for param, mask in masked_params:
                # A parameter that received no gradient has grad=None; skip it.
                if param.grad is not None:
                    param.grad.mul_(mask)

            optimizer.step()

            # Re-apply masks (keep zeros at zero)
            with torch.no_grad():
                for param, mask in masked_params:
                    param.mul_(mask)

        acc = evaluate(model, test_loader)
        print(f"   Fine-tune epoch {epoch+1}: Accuracy = {acc:.2f}%")

    return model

print("üîß Fine-tuning pruned model...")
pruned_model = fine_tune_pruned(pruned_model, train_loader, test_loader, masks, epochs=3)

final_accuracy = evaluate(pruned_model, test_loader)
print(f"\n" + "="*50)
print(f"üìä FINAL RESULTS")
print(f"="*50)
print(f"   Original accuracy:     {original_accuracy:.2f}%")
print(f"   Pruned + fine-tuned:   {final_accuracy:.2f}%")
print(f"   Accuracy drop:         {original_accuracy - final_accuracy:.2f}%")
print(f"   Weights removed:       90%")
print(f"   Compression ratio:     10x")


## Part 5: Visualizing Sparse Weight Matrices


In [None]:
# Visualize sparse weight matrices
# Top row: dense (original) layers; bottom row: the same layers after pruning.
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

orig_layers = [model.fc1, model.fc2, model.fc3, model.fc4]
prune_layers = [pruned_model.fc1, pruned_model.fc2, pruned_model.fc3, pruned_model.fc4]
names = ['fc1', 'fc2', 'fc3', 'fc4']

for i, (orig, pruned, name) in enumerate(zip(orig_layers, prune_layers, names)):
    # Original (top row): show at most the top-left 64x64 corner of the matrix.
    w = orig.weight.data.cpu().numpy()[:64, :64]
    # Fix the colour range to [0, 1]. Without it an all-non-zero patch is
    # autoscaled to vmin == vmax and drawn in the lightest colour, which
    # would read as "pruned".
    axes[0, i].imshow(w != 0, cmap='Blues', aspect='auto', vmin=0, vmax=1)
    axes[0, i].set_title(f'Original {name}')
    axes[0, i].axis('off')

    # Pruned (bottom row): same corner; sparsity is computed on that corner only.
    w = pruned.weight.data.cpu().numpy()[:64, :64]
    sparsity = (w == 0).sum() / w.size * 100
    axes[1, i].imshow(w != 0, cmap='Reds', aspect='auto', vmin=0, vmax=1)
    axes[1, i].set_title(f'Pruned {name}\n({sparsity:.0f}% sparse)')
    axes[1, i].axis('off')

plt.suptitle('üîç Weight Matrices: Colored = Non-zero, White = Pruned', fontsize=14)
plt.tight_layout()
plt.show()

# Closing summary: a banner followed by one bullet line per takeaway.
banner = "=" * 50
takeaways = [
    "‚Ä¢ Neural networks are over-parameterized",
    "‚Ä¢ 90% of weights can be removed with minimal accuracy loss",
    "‚Ä¢ Magnitude pruning: remove smallest |weights|",
    "‚Ä¢ Fine-tuning is essential to recover accuracy",
    "‚Ä¢ Sparse models can be 10x smaller!",
]

print("\n" + banner)
print("üéØ KEY TAKEAWAYS")
print(banner)
for takeaway in takeaways:
    print(takeaway)
