# Lecture 3: Pruning & Sparsity (Part I)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/efficientml-course/efficientml_course/03_pruning_sparsity_1/demo.ipynb)

Implementing magnitude (L1) pruning to remove 90% of the weights in every linear layer!


In [None]:
!pip install torch torchvision -q
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Create a simple model: three fully-connected layers (784 -> 256 -> 128 -> 10)
# with a ReLU between each consecutive pair of Linear layers.
model = nn.Sequential(
    nn.Linear(784, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 10),
)

# Count parameters before pruning
def count_nonzero(model):
    total = 0
    nonzero = 0
    for p in model.parameters():
        total += p.numel()
        nonzero += (p != 0).sum().item()
    return nonzero, total

before_nz, before_total = count_nonzero(model)
print(f"Before pruning: {before_nz:,} / {before_total:,} non-zero ({100*before_nz/before_total:.1f}%)")


In [None]:
# Apply magnitude pruning - remove smallest 90% of weights
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.9)
        prune.remove(module, 'weight')  # Make pruning permanent

after_nz, after_total = count_nonzero(model)
print(f"After 90% pruning: {after_nz:,} / {after_total:,} non-zero ({100*after_nz/after_total:.1f}%)")
print(f"\nðŸŽ¯ Removed {100*(1-after_nz/before_nz):.1f}% of weights!")

# Visualize sparsity
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
for i, (name, module) in enumerate([(n, m) for n, m in model.named_modules() if isinstance(m, nn.Linear)]):
    w = module.weight.data.cpu().numpy()
    axes[i].imshow(w != 0, cmap='binary', aspect='auto')
    axes[i].set_title(f'Layer {i+1}: {(w != 0).sum()}/{w.size} non-zero')
plt.suptitle('Sparse Weight Matrices (white = zero)')
plt.tight_layout()
plt.show()
