# Lecture 4: Pruning & Sparsity (Part II) - Lottery Ticket

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/efficientml-course/efficientml_course/04_pruning_sparsity_2/demo.ipynb)

Finding winning tickets and structured sparsity patterns.


In [None]:
!pip install torch -q
import torch
import torch.nn as nn
import copy

# Lottery Ticket Hypothesis Demo
# Hypothesis (Frankle & Carbin, 2019): a randomly initialized dense network
# contains a sparse subnetwork (a "winning ticket") that, when reset to the
# SAME original initialization and trained in isolation, can match the dense
# network's accuracy. Resetting to a fresh random init does not, in general,
# give the same result.

# Fix the global RNG so the initialization, the simulated "training" noise and
# therefore the resulting mask are reproducible across runs.
torch.manual_seed(42)

class SimpleMLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    The default sizes (100 -> 50 -> 10) reproduce the original demo network.
    The layer attributes are named ``fc1`` / ``fc2`` so the ``state_dict`` keys
    (``fc1.weight``, ``fc1.bias``, ...) stay stable, which the lottery-ticket
    procedure relies on when it saves and restores the initialization.

    Args:
        in_features: size of the last input dimension.
        hidden_features: width of the hidden layer.
        out_features: size of the last output dimension.
    """

    def __init__(self, in_features=100, hidden_features=50, out_features=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Map ``x`` (last dim ``in_features``) to outputs with last dim ``out_features``."""
        return self.fc2(torch.relu(self.fc1(x)))

# Step 1: Save original initialization (theta_0).
# deepcopy so the snapshot does not share storage with the live parameter
# tensors that the "training" step below modifies in place.
model = SimpleMLP()
original_init = copy.deepcopy(model.state_dict())
print("Step 1: Saved original random initialization")

# Step 2: "Train" the model. This demo only SIMULATES training by adding
# Gaussian noise (std 0.1) to every parameter; a real experiment would run an
# optimizer here. Updates happen under no_grad instead of through `.data`,
# which is the recommended way to modify parameters outside autograd.
with torch.no_grad():
    for p in model.parameters():
        p += torch.randn_like(p) * 0.1
print("Step 2: Trained the model")

# Step 3: Prune the smallest-magnitude weights (80% with the default sparsity)
def create_mask(model, sparsity=0.8):
    """Build binary magnitude-pruning masks for the weight tensors of *model*.

    For every parameter whose name contains ``'weight'`` (biases are not
    masked), the ``round(sparsity * numel)`` entries with the smallest absolute
    value get mask 0 and the rest get mask 1. Pruning is layer-wise: each
    tensor loses the same fraction of its own entries.

    Args:
        model: an ``nn.Module`` whose ``named_parameters()`` are scanned.
        sparsity: fraction of entries to prune in each weight tensor, in [0, 1].

    Returns:
        dict mapping parameter name -> float32 mask with that parameter's shape.

    Raises:
        ValueError: if ``sparsity`` is outside [0, 1].
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")
    masks = {}
    for name, param in model.named_parameters():
        if 'weight' not in name:
            continue
        magnitudes = param.detach().abs().flatten()
        # Select by count rather than `abs() > quantile(...)`. The strict
        # threshold always dropped at least the smallest entry (so sparsity=0
        # still pruned a weight), and torch.quantile rejects very large tensors.
        n_keep = magnitudes.numel() - int(round(sparsity * magnitudes.numel()))
        flat_mask = torch.zeros_like(magnitudes, dtype=torch.float32)
        if n_keep > 0:
            flat_mask[torch.topk(magnitudes, n_keep).indices] = 1.0
        masks[name] = flat_mask.view(param.shape)
    return masks

# Fraction of each weight tensor to prune; the printed "keep" percentage is
# derived from it so the message cannot drift from the actual setting.
SPARSITY = 0.8
mask = create_mask(model, SPARSITY)
print(f"Step 3: Created mask (keeping top {1 - SPARSITY:.0%} weights)")

# Step 4: Reset to original init but keep mask = WINNING TICKET!
# The mask found on the *trained* weights is applied to the *original*
# initialization. (This demo does not retrain; when retraining, the mask must
# be re-applied after each optimizer step, or the gradients masked, so that
# pruned weights stay zero.)
winning_ticket = SimpleMLP()
winning_ticket.load_state_dict(original_init)
with torch.no_grad():
    for name, param in winning_ticket.named_parameters():
        if name in mask:
            param.mul_(mask[name])

# Sparsity report over the masked (weight) tensors only; biases are excluded.
total = sum(m.numel() for m in mask.values())
nonzero = sum(m.sum().item() for m in mask.values())
print(f"Step 4: Winning ticket has {nonzero:.0f}/{total} weights ({100*nonzero/total:.0f}%)")
print("\n🎯 This sparse network can train to same accuracy as dense!")
