# Workshop 3.1: Fine-Grained Pruning for Efficient Inference - Hands-On Practice

## Learning Objectives
By the end of this notebook, you will:
- Understand what fine-grained (unstructured) pruning is and why it's important for edge AI
- Learn about magnitude-based pruning (remove smallest weights)
- **PRACTICE** implementing post-training pruning with PyTorch
- Compare model sparsity, speed, and accuracy trade-offs
- Complete hands-on coding exercises

## What is Fine-Grained Pruning?
Fine-grained pruning (also called unstructured pruning) removes individual weights from a neural network; in this workshop, the weights to remove are chosen by their magnitude. Unlike structured pruning, which removes entire channels or filters, fine-grained pruning can zero out any weight regardless of its position, producing sparse weight tensors with irregular zero patterns.
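To make the contrast concrete, here is a toy sketch on a 3×3 matrix using plain PyTorch tensors rather than `torch.nn.utils.prune`. The helper names `fine_grained_mask` and `structured_row_mask` are hypothetical, chosen for this illustration. The first zeroes the individual entries with the smallest |w| anywhere in the matrix. The second, for contrast, zeroes whole rows (entire output units) with the smallest L1 norm.

```python
import torch

def fine_grained_mask(weight: torch.Tensor, amount: float) -> torch.Tensor:
    """Hypothetical helper: zero the `amount` fraction of entries with smallest |w|."""
    k = round(amount * weight.numel())  # number of individual entries to remove
    mask = torch.ones(weight.numel(), dtype=weight.dtype)
    if k > 0:
        # Indices of the k smallest magnitudes, anywhere in the tensor
        idx = torch.topk(weight.abs().flatten(), k, largest=False).indices
        mask[idx] = 0.0
    return mask.view_as(weight)

def structured_row_mask(weight: torch.Tensor, n_rows: int) -> torch.Tensor:
    """Hypothetical helper (for contrast): zero the n_rows rows with smallest L1 norm."""
    mask = torch.ones_like(weight)
    row_norms = weight.abs().sum(dim=1)
    idx = torch.topk(row_norms, n_rows, largest=False).indices
    mask[idx, :] = 0.0
    return mask

w = torch.tensor([[0.50, -0.02, 0.30],
                  [-0.01, 0.80, -0.40],
                  [0.05, -0.60, 0.03]])

fg = fine_grained_mask(w, amount=0.5)   # round(4.5) -> 4 scattered entries removed
st = structured_row_mask(w, n_rows=1)   # the last row has the smallest L1 norm
print("fine-grained mask:\n", fg)
print("structured (row) mask:\n", st)
```

The fine-grained zeros land wherever the small values happen to be. The structured mask removes a whole row, which a dense kernel can actually skip. `prune.l1_unstructured` selects entries the same way (smallest magnitudes, count rounded from `amount`), but stores the result as a mask buffer on the module.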

This workshop focuses specifically on magnitude-based pruning, which provides:
- Significant model compression (up to 90% sparsity)
- Recoverable accuracy with fine-tuning
- Simple implementation using PyTorch
- Clear understanding of pruning trade-offs

Paper: Han et al., "Learning both Weights and Connections for Efficient Neural Networks" (2015), https://arxiv.org/pdf/1506.02626


---
**🔥 HANDS-ON PRACTICE**: This notebook contains code completion exercises marked with `# TODO:` comments. Fill in the missing code to complete the pruning workflow!

In [None]:
# Import required libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune
from torchvision import models
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import time
import copy
import os

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## Step 1: Load and Prepare CIFAR-10 Dataset

### Understanding CIFAR-10: Why This Dataset?

CIFAR-10 is well suited to learning pruning because:
- **Small images (32×32)** - Fast training and testing
- **Realistic challenge** - 10 different object classes
- **Edge AI relevant** - Similar to mobile camera applications
- **Resource-friendly** - Doesn't require powerful hardware

The small image size makes it ideal for:
- Mobile device deployment
- Edge computing scenarios
- Real-time inference applications
- Learning optimization techniques like pruning

In [None]:
# Define transforms
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=256, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=200, shuffle=False, num_workers=2)

# CIFAR-10 classes
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

print(f"Training samples: {len(trainset)}")
print(f"Test samples: {len(testset)}")
print(f"Classes: {classes}")

## Step 2: Load Pre-trained MobileNetV2 Model
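
Why does the next cell change the first convolution's stride? A 3×3 convolution with padding 1 maps a side length `n` to `floor((n + 2*1 - 3) / s) + 1` for stride `s`. The sketch below pushes a 32×32 input through the stride pattern of MobileNetV2's first convolution plus its seven inverted-residual stages (first-block strides 1, 2, 2, 2, 1, 2, 1, as in the MobileNetV2 architecture table). It is plain arithmetic, not a call into torchvision, and `final_feature_size` is a hypothetical helper.

```python
def conv_out(n: int, stride: int, kernel: int = 3, padding: int = 1) -> int:
    """Output side length of a square convolution (PyTorch's floor convention)."""
    return (n + 2 * padding - kernel) // stride + 1

# First-block stride of each of MobileNetV2's seven inverted-residual stages
# (the remaining blocks in a stage use stride 1 and keep the size unchanged)
STAGE_STRIDES = [1, 2, 2, 2, 1, 2, 1]

def final_feature_size(n: int, first_stride: int) -> int:
    """Hypothetical helper: side length of the last feature map."""
    n = conv_out(n, first_stride)          # the stem 3x3 convolution
    for s in STAGE_STRIDES:
        n = conv_out(n, s)                 # depthwise 3x3 conv sets the stride
    return n

for first_stride in (2, 1):
    size = final_feature_size(32, first_stride)
    print(f"first stride {first_stride}: 32x32 input -> {size}x{size} final feature map")
```

With the original stride, a 32×32 image collapses to a single pixel before global average pooling (the same network maps 224×224 ImageNet images to 7×7). Stride 1 keeps a 2×2 map instead.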

In [None]:
# TODO: Complete the MobileNetV2 adaptation function
def create_mobilenetv2_cifar10(num_classes=10, pretrained=True):
    """
    Create MobileNetV2 adapted for CIFAR-10
    CIFAR-10 images are 32x32, smaller than ImageNet's 224x224
    """
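    # TODO: Load pre-trained MobileNetV2 using models.mobilenet_v2()
    # HINT: torchvision >= 0.13 takes weights=models.MobileNet_V2_Weights.DEFAULT
    #       (weights=None when pretrained is False); older versions take pretrained=True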
    model = # Your code here
    
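    # Modify the first convolution layer for the smaller input size
    # Original: stride=2, now stride=1 to preserve spatial dimensions
    # Note: the new layer is randomly initialized, so the pretrained first-layer weights are discarded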
    model.features[0][0] = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
    
    # TODO: Modify classifier for CIFAR-10 (10 classes instead of 1000)
    # HINT: model.classifier[1] should be a Linear layer with model.last_channel input features
    model.classifier[1] = # Your code here
    
    return model

# Create model instance
model = create_mobilenetv2_cifar10(num_classes=10, pretrained=True).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"MobileNetV2 loaded with {total_params:,} total parameters")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model adapted for CIFAR-10 (32x32 images, 10 classes)")

## Step 3: Train the Original Model

In [None]:
# TODO: Complete the training function
def train_model(model, trainloader, testloader, epochs=10, learning_rate=0.001):
    # TODO: Define criterion (loss function) - use CrossEntropyLoss
    criterion = # Your code here
    
    # TODO: Define optimizer - use Adam with the given learning rate
    optimizer = # Your code here
    
    train_losses = []
    train_accuracies = []
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        # Training loop with progress bar
        pbar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{epochs}')
        for batch_idx, (inputs, labels) in enumerate(pbar):
            inputs, labels = inputs.to(device), labels.to(device)
            
            # TODO: Zero gradients
            # Your code here
            
            # TODO: Forward pass
            outputs = # Your code here
            
            # TODO: Calculate loss
            loss = # Your code here
            
            # TODO: Backward pass
            # Your code here
            
            # TODO: Update weights
            # Your code here
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            # Update progress bar
            pbar.set_postfix({
                'Loss': f'{running_loss/(batch_idx+1):.3f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })
        
        epoch_loss = running_loss / len(trainloader)
        epoch_acc = 100. * correct / total
        
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)
        
        print(f'Epoch {epoch+1}: Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')
    
    return train_losses, train_accuracies

# Train the model
print("Training original model...")
train_losses, train_accuracies = train_model(model, trainloader, testloader, epochs=10)

## Step 4: Evaluate Original Model Performance
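
`evaluate_model` (next cell) times the whole test loop, so its number includes data loading and host-to-device copies. Below is a sketch of a more isolated latency measurement, assuming a fixed input batch and a few warm-up runs. `measure_latency` is a hypothetical helper, shown on a small stand-in `nn.Sequential` so it runs anywhere. On CUDA, `torch.cuda.synchronize()` is needed because kernels launch asynchronously.

```python
import time
import torch
import torch.nn as nn

def measure_latency(model: nn.Module, example: torch.Tensor,
                    warmup: int = 5, iters: int = 20) -> float:
    """Hypothetical helper: average seconds per forward pass on a fixed batch."""
    model.eval()
    sync = torch.cuda.synchronize if example.is_cuda else (lambda: None)
    with torch.no_grad():
        for _ in range(warmup):          # let caches and allocators settle
            model(example)
        sync()
        start = time.perf_counter()
        for _ in range(iters):
            model(example)
        sync()                           # wait for queued GPU work to finish
    return (time.perf_counter() - start) / iters

# Stand-in model; in the notebook you would pass `model` and a batch on `device`
toy = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
batch = torch.randn(64, 512)
dense_t = measure_latency(toy, batch)

# Zero ~70% of the first layer's weights in place; the matmul stays dense
with torch.no_grad():
    w = toy[0].weight
    k = int(0.7 * w.numel())
    threshold = w.abs().flatten().kthvalue(k).values
    w.mul_((w.abs() > threshold).to(w.dtype))
masked_t = measure_latency(toy, batch)

print(f"dense: {dense_t * 1e3:.3f} ms, ~70% zeroed: {masked_t * 1e3:.3f} ms")
```

Expect the two timings to be close: dense kernels multiply zeros just like any other value.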

In [None]:
# TODO: Complete the evaluation function
def evaluate_model(model, testloader, model_name="Model"):
    """
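    Evaluate model accuracy and measure wall-clock time over the whole test set
    (includes data loading, so treat it as a rough comparison, not pure model latency)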
    """
    model.eval()
    correct = 0
    total = 0
    
    # Measure inference time
    start_time = time.time()
    
    with torch.no_grad():
        for inputs, labels in tqdm(testloader, desc=f'Evaluating {model_name}'):
            inputs, labels = inputs.to(device), labels.to(device)
            
            # TODO: Forward pass
            outputs = # Your code here
            
            # TODO: Get predictions
            _, predicted = # Your code here
            
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    end_time = time.time()
    inference_time = end_time - start_time
    
    accuracy = 100. * correct / total
    print(f'{model_name} - Accuracy: {accuracy:.2f}%, Inference Time: {inference_time:.2f}s')
    
    return accuracy, inference_time

# TODO: Complete the sparsity calculation function
def calculate_sparsity(model):
    """
    Calculate the percentage of zero weights in Conv2d and Linear layers
    (biases and BatchNorm parameters are not counted)
    """
    total_params = 0
    zero_params = 0
    
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            # Check if the module has been pruned (has weight_mask)
            if hasattr(module, 'weight_mask'):
                # TODO: Calculate effective weight (module.weight_orig * module.weight_mask)
                effective_weight = # Your code here
                total_params += effective_weight.numel()
                zero_params += (effective_weight == 0).sum().item()
            else:
                # If not pruned, count normally
                total_params += module.weight.numel()
                zero_params += (module.weight == 0).sum().item()
    
    sparsity = 100.0 * zero_params / total_params if total_params > 0 else 0.0
    return sparsity

# Evaluate original model performance (baseline)
print("Evaluating original model performance...")
original_accuracy, original_time = evaluate_model(model, testloader, "Original")
original_sparsity = calculate_sparsity(model)

print(f"\nOriginal Model Summary:")
print(f"  Accuracy: {original_accuracy:.2f}%")
print(f"  Inference time: {original_time:.2f}s")
print(f"  Sparsity: {original_sparsity:.2f}%")
print(f"  This will be our baseline for comparison")

## Step 5: Analyze Weight Distributions

Before pruning, let's analyze the weight distributions to understand which weights are candidates for removal.
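For layer-wise pruning at ratio p, the cut-off for a layer is the p-quantile of that layer's |w|. The sketch below uses NumPy on synthetic Gaussian weights; these are an assumption standing in for one layer of `collect_weights` output, and real layers are not exactly Gaussian.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for one layer's flattened weights
weights = rng.normal(loc=0.0, scale=0.05, size=100_000)

for ratio in (0.5, 0.7, 0.9):
    # Magnitude threshold below which `ratio` of the weights fall
    threshold = np.quantile(np.abs(weights), ratio)
    pruned_fraction = np.mean(np.abs(weights) <= threshold)
    print(f"ratio {ratio:.0%}: prune |w| <= {threshold:.4f} "
          f"({pruned_fraction:.1%} of weights)")
```

The `|w| < 0.01` cut-off used in the next cell is a fixed, arbitrary value. At a fixed pruning ratio, by contrast, each layer gets its own threshold that depends on its weight scale.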

In [None]:
# TODO: Complete the weight collection function
def collect_weights(model):
    """
    Collect all weights from the model for analysis
    """
    all_weights = []
    
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            # TODO: Get weights and flatten them
            # HINT: use module.weight.data.cpu().flatten().numpy()
            weights = # Your code here
            all_weights.append(weights)
    
    # Join the per-layer arrays into one flat array
    return np.concatenate(all_weights)

# Collect weights from the trained model
all_weights = collect_weights(model)

print(f"Weight Distribution Analysis:")
print(f"  Total weights: {len(all_weights):,}")
print(f"  Mean: {np.mean(all_weights):.6f}")
print(f"  Std: {np.std(all_weights):.6f}")
print(f"  Min: {np.min(all_weights):.6f}")
print(f"  Max: {np.max(all_weights):.6f}")

# Analyze small weights (candidates for pruning)
small_weights = np.abs(all_weights) < 0.01
print(f"  Weights with |w| < 0.01: {small_weights.sum():,} ({small_weights.mean()*100:.1f}%)")
print(f"  These small weights are good candidates for pruning!")

In [None]:
# Visualize weight distribution
plt.figure(figsize=(15, 5))

# Plot 1: Full weight distribution
plt.subplot(1, 2, 1)
plt.hist(all_weights, bins=100, alpha=0.7, color='blue', edgecolor='black')
plt.title('Weight Distribution - Full Range')
plt.xlabel('Weight Value')
plt.ylabel('Frequency')
plt.grid(True, alpha=0.3)

# Plot 2: Zoomed-in distribution around zero
plt.subplot(1, 2, 2)
close_to_zero = all_weights[np.abs(all_weights) < 0.1]
plt.hist(close_to_zero, bins=50, alpha=0.7, color='red', edgecolor='black')
plt.title('Weight Distribution - Near Zero (|w| < 0.1)')
plt.xlabel('Weight Value')
plt.ylabel('Frequency')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nKey Insight:")
print(f"  If many weights sit close to zero, the model likely has substantial pruning headroom.")
print(f"  Magnitude-based pruning will remove the smallest weights first.")

## Step 6: Apply Magnitude-Based Pruning

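Magnitude-based pruning removes the weights with the smallest absolute values, on the assumption that small weights contribute least to the model's output. The function below prunes layer-wise: each Conv2d and Linear layer loses the same fraction of its own weights. PyTorch implements pruning with a 0/1 mask, so the zeroed weights are still stored and multiplied in dense tensors; by itself, pruning does not shrink the saved model or speed up inference.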
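Here is what `prune.l1_unstructured` does to a single layer, shown on a small `nn.Linear`; the layer size and the 0.7 amount are arbitrary choices for this sketch. The original tensor moves to the parameter `weight_orig`, and a 0/1 buffer `weight_mask` is registered. `weight` becomes a plain attribute, recomputed as `weight_orig * weight_mask` before each forward pass. `prune.remove` folds the mask in permanently.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(20, 10)          # 200 weights

prune.l1_unstructured(layer, name="weight", amount=0.7)

param_names = [n for n, _ in layer.named_parameters()]
buffer_names = [n for n, _ in layer.named_buffers()]
print("parameters:", param_names)   # weight_orig and bias
print("buffers:", buffer_names)     # weight_mask
zeros = int((layer.weight == 0).sum())
print(f"zeros in weight: {zeros} / {layer.weight.numel()}")  # round(0.7 * 200) = 140

# Make the pruning permanent: weight becomes an ordinary Parameter again
prune.remove(layer, "weight")
after = [n for n, _ in layer.named_parameters()]
print("after remove:", after)
```

This reparametrization is why `calculate_sparsity` looks for a `weight_mask` attribute. It is also why a pruned model's `state_dict` holds both `weight_orig` and `weight_mask` until `prune.remove` is called.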

In [None]:
# TODO: Complete the magnitude-based pruning function
def apply_magnitude_pruning(model, pruning_ratio):
    """
    Apply magnitude-based pruning to all Conv2d and Linear layers
    """
    print(f"Applying {pruning_ratio:.0%} magnitude-based pruning...")
    
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            # TODO: Apply L1 (magnitude-based) unstructured pruning
            # HINT: use prune.l1_unstructured(module, name='weight', amount=pruning_ratio)
            # Your code here
            
            print(f"  Pruned layer: {name}")
    
    print(f"Magnitude-based pruning complete!")
    return model

# Set pruning ratio (start with 70% pruning)
PRUNING_RATIO = 0.7

# Create a copy of the model for pruning
pruned_model = copy.deepcopy(model)

# Apply pruning
pruned_model = apply_magnitude_pruning(pruned_model, PRUNING_RATIO)

print(f"\nPruning Applied:")
print(f"  Pruning ratio: {PRUNING_RATIO:.0%}")
print(f"  In every Conv2d and Linear layer, {PRUNING_RATIO:.0%} of the weights are now zero (biases and BatchNorm untouched)")

## Step 7: Evaluate Pruned Model
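
A masked weight tensor is still stored densely, so its memory footprint does not change until the zeros are dropped. The sketch below assumes PyTorch's CSR sparse layout (`Tensor.to_sparse_csr`) as the target format, applied to a random 256×256 matrix pruned to 70% sparsity. CSR keeps the nonzero values plus integer index arrays, so the saving at moderate sparsity is smaller than the sparsity percentage suggests.

```python
import torch

torch.manual_seed(0)
dense = torch.randn(256, 256)

# Zero the 70% smallest-magnitude entries (same selection rule as l1_unstructured)
k = round(0.7 * dense.numel())
idx = torch.topk(dense.abs().flatten(), k, largest=False).indices
dense.view(-1)[idx] = 0.0

def nbytes(t: torch.Tensor) -> int:
    """Bytes occupied by a strided tensor's elements."""
    return t.numel() * t.element_size()

csr = dense.to_sparse_csr()
dense_bytes = nbytes(dense)
csr_bytes = nbytes(csr.values()) + nbytes(csr.col_indices()) + nbytes(csr.crow_indices())

print(f"nonzeros: {csr.values().numel()} / {dense.numel()}")
print(f"dense: {dense_bytes} bytes, CSR: {csr_bytes} bytes "
      f"({csr_bytes / dense_bytes:.0%} of dense)")
```

With float32 values and 64-bit indices, each stored nonzero costs 12 bytes instead of 4. CSR therefore only pays off clearly at high sparsity or with narrower index types.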

In [None]:
# Evaluate the pruned model
print("Evaluating pruned model...")
pruned_accuracy, pruned_time = evaluate_model(pruned_model, testloader, f"{PRUNING_RATIO:.0%} Pruned")
pruned_sparsity = calculate_sparsity(pruned_model)

# Compare with original model
accuracy_drop = original_accuracy - pruned_accuracy

print(f"\nPruned Model Results:")
print(f"  Accuracy: {pruned_accuracy:.2f}% (drop: {accuracy_drop:.2f}%)")
print(f"  Inference time: {pruned_time:.2f}s")
print(f"  Sparsity: {pruned_sparsity:.1f}%")

print(f"\nComparison:")
print(f"  Original accuracy: {original_accuracy:.2f}%")
print(f"  Pruned accuracy: {pruned_accuracy:.2f}%")
print(f"  Accuracy drop: {accuracy_drop:.2f}%")
print(f"  Sparsity achieved: {pruned_sparsity:.1f}%")
print(f"  --> Fine-tuning can help recover some lost accuracy!")

## Step 8: Fine-tune the Pruned Model

Fine-tuning helps the remaining weights adapt to compensate for the pruned weights, often recovering much of the lost accuracy.
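Fine-tuning does not undo the pruning because `weight = weight_orig * weight_mask`. The gradient reaching `weight_orig` is therefore zero at masked positions, and even if `weight_orig` changed there, the mask would zero it again on the next forward pass. The sketch below shows this on a stand-in `nn.Linear` with one Adam step; the layer size, data, and learning rate are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(8, 4)                       # 32 weights
prune.l1_unstructured(layer, name="weight", amount=0.5)
mask = layer.weight_mask.clone()

optimizer = torch.optim.Adam(layer.parameters(), lr=1e-2)
x, target = torch.randn(16, 8), torch.randn(16, 4)

optimizer.zero_grad()
loss = nn.functional.mse_loss(layer(x), target)
loss.backward()

# The gradient w.r.t. weight_orig is exactly zero wherever the mask is zero
print("max |grad| at pruned positions:",
      layer.weight_orig.grad[mask == 0].abs().max().item())
optimizer.step()

# `layer.weight` is recomputed from weight_orig * weight_mask on the next forward
layer(x)
print("pruned positions still zero:", bool((layer.weight[mask == 0] == 0).all()))
```

One subtlety: `layer.weight` is only refreshed by the forward pre-hook. Inspect it after a forward pass (as above) rather than immediately after `optimizer.step()`.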

In [None]:
# TODO: Complete the fine-tuning function
def fine_tune_pruned_model(model, trainloader, testloader, epochs=5, learning_rate=0.0001):
    """
    Fine-tune a pruned model to recover accuracy
    """
    print(f"Fine-tuning pruned model for {epochs} epochs...")
    
    # TODO: Set up optimizer and loss function
    # HINT: Use a smaller learning rate for fine-tuning
    optimizer = # Your code here
    criterion = # Your code here
    
    model.train()
    
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        
        progress_bar = tqdm(trainloader, desc=f'Fine-tune Epoch {epoch+1}/{epochs}')
        
        for batch_idx, (inputs, labels) in enumerate(progress_bar):
            inputs, labels = inputs.to(device), labels.to(device)
            
            # TODO: Forward pass (gradients are already zeroed for you)
            optimizer.zero_grad()
            outputs = # Your code here
            loss = # Your code here
            
            # TODO: Backward pass, then update the weights with the optimizer
            # Your code here
            # Your code here
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            # Update progress bar (running average over the batches seen so far)
            progress_bar.set_postfix({
                'Loss': f'{running_loss/(batch_idx+1):.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })
        
        epoch_accuracy = 100. * correct / total
        print(f'Epoch {epoch+1}: Training accuracy: {epoch_accuracy:.2f}%')
    
    print(f"Fine-tuning complete!")
    return model

# Fine-tune the pruned model
finetuned_model = copy.deepcopy(pruned_model)
finetuned_model = fine_tune_pruned_model(finetuned_model, trainloader, testloader, epochs=5)

## Step 9: Evaluate Fine-tuned Model

In [None]:
# Evaluate the fine-tuned model
print("Evaluating fine-tuned model...")
finetuned_accuracy, finetuned_time = evaluate_model(finetuned_model, testloader, f"{PRUNING_RATIO:.0%} Pruned + Fine-tuned")
finetuned_sparsity = calculate_sparsity(finetuned_model)

# Calculate recovery
recovery = finetuned_accuracy - pruned_accuracy
final_accuracy_drop = original_accuracy - finetuned_accuracy

print(f"\nFine-tuned Model Results:")
print(f"  Accuracy: {finetuned_accuracy:.2f}%")
print(f"  Inference time: {finetuned_time:.2f}s")
print(f"  Sparsity: {finetuned_sparsity:.1f}%")

print(f"\nFinal Comparison:")
print(f"  Original accuracy: {original_accuracy:.2f}%")
print(f"  Pruned accuracy: {pruned_accuracy:.2f}%")
print(f"  Fine-tuned accuracy: {finetuned_accuracy:.2f}%")
print(f"  Recovery from fine-tuning: +{recovery:.2f}%")
print(f"  Final accuracy drop: {final_accuracy_drop:.2f}%")
print(f"  Sparsity achieved: {finetuned_sparsity:.1f}%")

# Check if we achieved good compression with acceptable accuracy
target_accuracy = 85.0
if finetuned_accuracy >= target_accuracy:
    print(f"\nSUCCESS: Achieved {finetuned_accuracy:.2f}% accuracy (target: ≥{target_accuracy}%)")
    print(f"   Only {100-finetuned_sparsity:.1f}% of Conv2d/Linear weights remain nonzero")
else:
    print(f"\nTarget not met: {finetuned_accuracy:.2f}% vs {target_accuracy}% target")
    print(f"   Try a lower pruning ratio or more fine-tuning epochs")

## Step 10: Visualize Results

In [None]:
# Create comprehensive comparison visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Plot 1: Accuracy comparison
# Named model_names to avoid shadowing the torchvision `models` import
model_names = ['Original', f'{PRUNING_RATIO:.0%} Pruned', f'{PRUNING_RATIO:.0%} Pruned\n+ Fine-tuned']
accuracies = [original_accuracy, pruned_accuracy, finetuned_accuracy]
colors = ['blue', 'red', 'green']

bars1 = ax1.bar(model_names, accuracies, color=colors, alpha=0.7)
ax1.axhline(y=target_accuracy, color='red', linestyle='--', alpha=0.8, label=f'Target ({target_accuracy:.0f}%)')
ax1.set_ylabel('Accuracy (%)')
ax1.set_title('Model Accuracy Comparison')
ax1.set_ylim(0, 100)
ax1.legend()
ax1.grid(True, alpha=0.3)

# Add value labels on bars
for bar, acc in zip(bars1, accuracies):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height + 1,
            f'{acc:.1f}%', ha='center', va='bottom', fontweight='bold')

# Plot 2: Sparsity comparison
sparsities = [original_sparsity, pruned_sparsity, finetuned_sparsity]

bars2 = ax2.bar(model_names, sparsities, color=colors, alpha=0.7)
ax2.set_ylabel('Sparsity (%)')
ax2.set_title('Model Sparsity Comparison')
ax2.set_ylim(0, 100)
ax2.grid(True, alpha=0.3)

# Add value labels on bars
for bar, sparsity in zip(bars2, sparsities):
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height + 1,
            f'{sparsity:.1f}%', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

print(f"\nVisualization Summary:")
print(f"  Left chart: Shows accuracy impact and recovery from fine-tuning")
print(f"  Right chart: Shows sparsity achieved (higher = more compressed)")
print(f"  Goal: High sparsity with acceptable accuracy!")

## Step 11: Experiment with Different Pruning Ratios

Try different pruning ratios to find the best trade-off between compression and accuracy!
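`apply_magnitude_pruning` gives every layer the same ratio. One variation worth trying is global magnitude pruning with `prune.global_unstructured`. It ranks the weights of all listed layers together, so layers with small weights end up sparser than others. The sketch below uses two hand-built 10×10 layers whose weight scales are contrived to make the effect obvious.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def make_layers():
    small, large = nn.Linear(10, 10), nn.Linear(10, 10)
    with torch.no_grad():
        # Contrived weights: every |w| in `small` is below every |w| in `large`
        small.weight.copy_(torch.linspace(0.001, 0.1, 100).view(10, 10))
        large.weight.copy_(torch.linspace(1.0, 2.0, 100).view(10, 10))
    return small, large

def sparsity(layer: nn.Linear) -> float:
    return int((layer.weight == 0).sum()) / layer.weight.numel()

# Layer-wise (as in apply_magnitude_pruning): each layer loses 30% of its own weights
small, large = make_layers()
for layer in (small, large):
    prune.l1_unstructured(layer, name="weight", amount=0.3)
layerwise = (sparsity(small), sparsity(large))

# Global: the 30% smallest weights across both layers are removed
small, large = make_layers()
prune.global_unstructured(
    [(small, "weight"), (large, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.3,
)
global_ = (sparsity(small), sparsity(large))

print(f"layer-wise sparsity (small, large): {layerwise}")
print(f"global sparsity     (small, large): {global_}")
```

Global pruning can push a layer with small weights to very high sparsity; here, `small` would be pruned entirely at `amount=0.5`. When you try it, check the per-layer sparsity as well as the total.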

In [None]:
# TODO: Experiment with different pruning ratios
# Change the value below to try different compression levels
EXPERIMENT_RATIO = 0.5  # Try: 0.5, 0.6, 0.8, 0.9

print(f"\nExperimenting with {EXPERIMENT_RATIO:.0%} pruning...")

# Create experimental model
experimental_model = copy.deepcopy(model)
experimental_model = apply_magnitude_pruning(experimental_model, EXPERIMENT_RATIO)

# Evaluate experimental model
exp_accuracy, exp_time = evaluate_model(experimental_model, testloader, f"{EXPERIMENT_RATIO:.0%} Experimental")
exp_sparsity = calculate_sparsity(experimental_model)

# Compare with the PRUNING_RATIO model from Step 6 (neither model is fine-tuned here)
print(f"\nExperimental Results ({EXPERIMENT_RATIO:.0%} pruning):")
print(f"  Accuracy: {exp_accuracy:.2f}%")
print(f"  Sparsity: {exp_sparsity:.1f}%")
print(f"  Accuracy drop: {original_accuracy - exp_accuracy:.2f}%")

print(f"\nComparison with {PRUNING_RATIO:.0%} pruning:")
print(f"  {EXPERIMENT_RATIO:.0%} pruning: {exp_accuracy:.2f}% accuracy, {exp_sparsity:.1f}% sparsity")
print(f"  {PRUNING_RATIO:.0%} pruning: {pruned_accuracy:.2f}% accuracy, {pruned_sparsity:.1f}% sparsity")

if exp_accuracy > pruned_accuracy:
    print(f"  → {EXPERIMENT_RATIO:.0%} pruning gives better accuracy!")
elif exp_sparsity > pruned_sparsity:
    print(f"  → {EXPERIMENT_RATIO:.0%} pruning gives better compression!")
else:
    print(f"  → {PRUNING_RATIO:.0%} pruning seems to be a good balance")

print(f"\nTry different ratios to find your optimal trade-off!")

## Step 12: Summary and Key Insights

Congratulations! You've successfully implemented magnitude-based pruning with PyTorch.
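A starting point for the "gradual pruning" next step listed in the summary: calling `prune.l1_unstructured` repeatedly on the same layer stacks the masks, and each new `amount` applies only to the weights that are still unpruned. To follow a cumulative sparsity schedule, each round therefore prunes `(target - current) / (1 - current)` of the remaining weights. The sketch below uses a stand-in `nn.Linear`; the schedule `[0.3, 0.5, 0.7]` is an arbitrary example, and the fine-tuning between rounds is left as a comment.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(100, 100)          # stand-in layer with 10,000 weights
schedule = [0.3, 0.5, 0.7]           # cumulative target sparsity after each round

current = 0.0
for target in schedule:
    # Fraction of the *remaining* weights to prune in this round
    step_amount = (target - current) / (1.0 - current)
    prune.l1_unstructured(layer, name="weight", amount=step_amount)
    current = int((layer.weight == 0).sum()) / layer.weight.numel()
    print(f"target {target:.0%} -> actual sparsity {current:.1%}")
    # ... fine-tune for a few epochs here before the next pruning round ...

prune.remove(layer, "weight")        # fold the accumulated mask into the weights
```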

In [None]:
print("\n" + "="*80)
print("MAGNITUDE-BASED PRUNING WORKSHOP COMPLETE!")
print("="*80)

print(f"\nFINAL RESULTS SUMMARY:")
print(f"  Original Model: {original_accuracy:.2f}% accuracy, {original_sparsity:.1f}% sparsity")
print(f"  {PRUNING_RATIO:.0%} Pruned: {pruned_accuracy:.2f}% accuracy, {pruned_sparsity:.1f}% sparsity")
print(f"  Fine-tuned: {finetuned_accuracy:.2f}% accuracy, {finetuned_sparsity:.1f}% sparsity")
print(f"  Recovery: +{recovery:.2f}% accuracy from fine-tuning")
print(f"  Final compression: {finetuned_sparsity:.1f}% sparsity achieved")

print(f"\nKEY INSIGHTS:")
print(f"  • Magnitude-based pruning removes the smallest weights first")
print(f"  • {PRUNING_RATIO:.0%} pruning achieved {finetuned_sparsity:.1f}% sparsity with {final_accuracy_drop:.2f}% accuracy drop")
print(f"  • Fine-tuning is crucial for recovering accuracy after pruning")
print(f"  • There's a trade-off between compression and accuracy")
print(f"  • Different pruning ratios offer different compression-accuracy trade-offs")

print(f"\nPRACTICAL APPLICATIONS:")
print(f"  • Mobile and edge device deployment")
print(f"  • Reducing memory bandwidth requirements (once zeros are dropped via a sparse or compressed format)")
print(f"  • Enabling larger models on resource-constrained devices")
print(f"  • Battery life improvement in mobile applications")

print(f"\nNEXT STEPS:")
print(f"  • Try structured pruning for actual inference speedup")
print(f"  • Combine pruning with quantization for maximum compression")
print(f"  • Explore gradual pruning during training")
print(f"  • Test on different model architectures")

print(f"\n" + "="*80)
print("Thank you for completing the Magnitude-Based Pruning workshop!")
print("Continue exploring model compression techniques for efficient AI deployment.")
print("="*80)