<a href="https://colab.research.google.com/github/ketanp23/sit-neuralnetworks-class/blob/main/Network_Pruning_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import ssl

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.nn.utils import prune
# HTTPS certificate verification stays ON by default. Only when the
# environment variable MNIST_ALLOW_INSECURE_SSL is set to "1" (e.g. on a local
# Python install that lacks root certificates) is the default HTTPS context
# swapped for an unverified one so the dataset download can proceed.
# NOTE: this override is process-wide and is never reverted once applied,
# so enable it only when you understand the risk.
if os.environ.get("MNIST_ALLOW_INSECURE_SSL") == "1":
    ssl._create_default_https_context = ssl._create_unverified_context

# --- 1. Define the Simple CNN Model ---
class PrunableNet(nn.Module):
    """Small CNN (two conv blocks + one linear layer) for 1x28x28 inputs, 10 classes.

    Each layer is stored as its own attribute (``conv1``, ``conv2``, ``fc``)
    so it can be targeted individually by ``torch.nn.utils.prune``.
    """

    def __init__(self):
        super().__init__()
        # Two conv blocks (conv -> ReLU -> 2x2 max-pool), then one linear layer.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)

        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)

        # Spatial size for a 28x28 input: conv(5) -> 24, pool -> 12,
        # conv(5) -> 8, pool -> 4; flattened features = 20 * 4 * 4 = 320.
        self.fc = nn.Linear(320, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10) for ``x`` of shape (batch, 1, 28, 28)."""
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        # Flatten per sample while keeping the batch dimension. Unlike
        # view(-1, 320), this never silently turns a wrong-sized input into a
        # different batch size; a mismatch is reported by self.fc instead.
        x = torch.flatten(x, 1)
        return self.fc(x)

# --- Helper function to calculate the sparsity of the model ---
def calculate_sparsity(model):
    """Return the percentage of exactly-zero entries in Conv2d/Linear parameters.

    Both ``weight`` and ``bias`` of every ``nn.Conv2d`` and ``nn.Linear``
    module in ``model`` (including ``model`` itself) are counted, so unpruned
    biases pull the figure slightly below the fraction of pruned weights.
    Layers built with ``bias=False`` (``bias`` is None) are skipped for that
    parameter. For a module pruned with ``torch.nn.utils.prune``, ``weight``
    is the masked tensor, so pruned entries are counted as zeros.

    Args:
        model: the ``nn.Module`` to inspect.

    Returns:
        Sparsity as a float percentage in [0, 100]; 0.0 if no such
        parameters exist.
    """
    total_params = 0
    zero_params = 0
    for module in model.modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        for param_name in ('weight', 'bias'):
            # bias is None (not absent) when the layer was built with bias=False.
            param = getattr(module, param_name, None)
            if param is None:
                continue
            total_params += param.numel()
            zero_params += int((param == 0).sum().item())

    if total_params == 0:
        return 0.0
    return 100.0 * zero_params / total_params

# --- 2. Setup and Initial Training ---

# MNIST preprocessing: tensor conversion followed by normalization with
# fixed mean/std constants.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Download the training split into ./data if needed and wrap it in a
# shuffled mini-batch loader.
trainset = torchvision.datasets.MNIST(
    './data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Model, loss and optimizer; the same optimizer is reused later for fine-tuning.
model = PrunableNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

print("Starting initial training (1 Epoch)...")
# One short epoch gives the weights non-trivial magnitudes for the
# magnitude-based pruning step to rank.
for epoch in range(1):
    for i, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        # Report progress on every 500th batch.
        if (i + 1) % 500 == 0:
            print(f'  [Epoch {epoch + 1}, Batch {i + 1:5d}] Loss: {loss.item():.4f}')

print("Initial training complete.")
initial_sparsity = calculate_sparsity(model)
print(f"Sparsity before pruning: {initial_sparsity:.2f}%")
print("-" * 50)


# --- 3. Apply Network Pruning ---

# Fraction of the combined conv/FC weights to remove. It is used by both the
# pruning call and the log message so the two cannot drift apart.
PRUNE_AMOUNT = 0.40

# Layers to prune: weights only, biases are left untouched.
parameters_to_prune = [
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc, 'weight'),
]

# Global L1 unstructured pruning: all listed weights are ranked together by
# absolute value, and the smallest PRUNE_AMOUNT fraction is masked to zero.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=PRUNE_AMOUNT,
)

print(f"Pruning applied: Removed {PRUNE_AMOUNT:.0%} of the lowest magnitude weights across Conv and FC layers.")

# --- 4. Evaluate Pruning Result ---

# calculate_sparsity also counts the (unpruned) biases, so the reported figure
# is slightly below PRUNE_AMOUNT.
pruned_sparsity = calculate_sparsity(model)
print(f"Sparsity after pruning: {pruned_sparsity:.2f}%")
print("-" * 50)


# --- 5. Verify Model Status (Optional Fine-tuning) ---

# Pruning re-registers the parameter as 'weight_orig' and adds a 'weight_mask'
# buffer; check that both now exist on conv1.
print("Verifying pruning status for model.conv1.weight:")
print(f"  Has a 'weight_orig' attribute: {hasattr(model.conv1, 'weight_orig')}")
print(f"  Has a 'weight_mask' attribute: {hasattr(model.conv1, 'weight_mask')}")

# To make the pruning permanent (drop weight_orig/weight_mask and keep the
# masked values as a plain parameter), use remove:
# prune.remove(model.conv1, 'weight')
# print("\nMask removed from conv1.weight.")

# --- Demonstration of Fine-tuning (Optional, but shows the next step) ---
# Fine-tuning is typically done after pruning to recover lost accuracy.
print("\nStarting Fine-tuning (1 Epoch) on the pruned model...")
for epoch in range(1):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        # The forward pass uses the masked/pruned weights automatically
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if i % 500 == 499:
            print(f'  [Epoch {epoch + 1}, Batch {i + 1:5d}] Loss: {loss.item():.4f}')

# Derive the retained fraction from the pruning masks instead of hard-coding
# it, so the message stays correct if the pruning amount changes.
kept_weights = 0
masked_total = 0
for module in model.modules():
    mask = getattr(module, 'weight_mask', None)
    if mask is not None:
        kept_weights += int(mask.sum().item())
        masked_total += mask.numel()
kept_fraction = kept_weights / masked_total if masked_total else 1.0
print(f"Fine-tuning complete. The pruned model is now trained using only {kept_fraction:.0%} of its original weights.")

100%|██████████| 9.91M/9.91M [00:00<00:00, 24.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 594kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 5.47MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.79MB/s]


Starting initial training (1 Epoch)...
  [Epoch 1, Batch   500] Loss: 0.1497
Initial training complete.
Sparsity before pruning: 0.00%
--------------------------------------------------
Pruning applied: Removed 40% of the lowest magnitude weights across Conv and FC layers.
Sparsity after pruning: 39.81%
--------------------------------------------------
Verifying pruning status for model.conv1.weight:
  Has a 'weight_orig' attribute: True
  Has a 'weight_mask' attribute: True

Starting Fine-tuning (1 Epoch) on the pruned model...
  [Epoch 1, Batch   500] Loss: 0.0799
Fine-tuning complete. The pruned model is now trained using only 60% of its original weights.
