### Model Compression Techniques - All Compressions

This notebook implements three compression techniques:
1. **Pruning**: Remove unimportant weights
2. **Quantization**: Reduce numerical precision  
3. **Knowledge Distillation**: Train smaller model from larger one

---

The block below imports PyTorch and its common submodules along with supporting libraries, executes `helper_functions.py` (shared with the baseline notebook) to load the helper routines, and selects the compute device.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import torch.nn.utils.prune as prune

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import copy
from datetime import datetime

exec(open('helper_functions.py').read())

device = get_device()
print(f"Using device: {device}")

✓ Helper functions loaded successfully!
⚠ Using CPU (this will be slower)
Using device: cpu


The cell below sets the batch size and normalization transforms, loads the MNIST and CIFAR-10 datasets with those transforms, wraps them in `DataLoader`s, then restores the saved baseline models and the earlier baseline results for comparison.

In [None]:
BATCH_SIZE = 128

mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

mnist_train = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)
mnist_test = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=mnist_transform)
cifar_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=cifar_transform)
cifar_test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=cifar_transform)

mnist_train_loader = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)
mnist_test_loader = DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)
cifar_train_loader = DataLoader(cifar_train, batch_size=BATCH_SIZE, shuffle=True)
cifar_test_loader = DataLoader(cifar_test, batch_size=BATCH_SIZE, shuffle=False)

print("Datasets loaded")

from helper_functions import SimpleMNIST, SimpleCIFAR

mnist_baseline = SimpleMNIST().to(device)
# map_location lets checkpoints saved on a GPU load on a CPU-only machine
mnist_baseline.load_state_dict(torch.load('mnist_baseline.pth', map_location=device))
print("MNIST baseline loaded")

cifar_baseline = SimpleCIFAR().to(device)
cifar_baseline.load_state_dict(torch.load('cifar_baseline.pth', map_location=device))
print("CIFAR baseline loaded")

all_results = pd.read_csv('baseline_results.csv').to_dict('records')
print(f"Loaded {len(all_results)} baseline results")


✓ Datasets loaded
✓ Helper functions loaded successfully!
✓ MNIST baseline loaded
✓ CIFAR baseline loaded
✓ Loaded 2 baseline results
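
To make the preprocessing concrete, the following plain-Python sketch reproduces what the MNIST transform does to a single pixel value and how many batches the training loader yields. It assumes `ToTensor` has already scaled pixels to [0, 1] and that the MNIST training split holds 60,000 images; the `normalize` helper is a hypothetical stand-in, not part of `helper_functions.py`.

```python
import math

# Mean/std copied from the MNIST transform in the cell above.
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081

def normalize(pixel, mean=MNIST_MEAN, std=MNIST_STD):
    """Mirror transforms.Normalize for one value already scaled to [0, 1]."""
    return (pixel - mean) / std

black = normalize(0.0)   # darkest pixel after ToTensor
white = normalize(1.0)   # brightest pixel after ToTensor
print(f"black -> {black:.4f}, white -> {white:.4f}")

# Batches per epoch for the 60,000-image MNIST training split.
# DataLoader keeps the final partial batch by default (drop_last=False).
num_batches = math.ceil(60_000 / 128)
print(f"batches per epoch: {num_batches}")
```

The batch count of 469 agrees with the `Batch 0/469` counters in the distillation training log later in this notebook.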


## Pruning

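These helper routines collect all conv/linear layers in a model, prune a specified fraction of the smallest-magnitude weights globally (then make that pruning permanent with `prune.remove`), and compute the resulting percentage of zeroed-out parameters. Because `get_sparsity` counts every parameter, biases included, while only weights are pruned, the reported sparsity can land slightly below the requested amount.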

In [None]:
def apply_unstructured_pruning(model, amount=0.3):
    """
    Apply magnitude-based unstructured pruning to all Conv2d and Linear layers.
    
    Args:
        model: Neural network to prune
        amount: Fraction of weights to prune (0.3 = 30%)
    
    Returns:
        Pruned model
    """
    parameters_to_prune = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            parameters_to_prune.append((module, 'weight'))
    

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )
    
    for module, param_name in parameters_to_prune:
        prune.remove(module, param_name)
    
    return model


def get_sparsity(model):
    """
    Calculate the percentage of zero weights in the model.
    
    Returns:
        sparsity: Percentage of weights that are zero (0-100)
    """
    total_params = 0
    zero_params = 0
    
    for param in model.parameters():
        total_params += param.numel()
        zero_params += (param == 0).sum().item()
    
    sparsity = 100. * zero_params / total_params
    return sparsity

print("Pruning functions defined")

✓ Pruning functions defined
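
To illustrate what "global" means here, the sketch below ranks weights from all layers together and zeroes the smallest ones, using plain Python lists as a toy stand-in for tensors. The function names and the toy weights are made up for illustration; this is a simplified picture of the idea, not PyTorch's implementation.

```python
def global_magnitude_prune(layers, amount):
    """Zero the `amount` fraction of weights with the smallest |w| across all layers.

    `layers` maps a layer name to a flat list of floats (toy stand-in for tensors).
    Ties at the threshold are all pruned, which a real implementation may handle differently.
    """
    all_magnitudes = sorted(abs(w) for ws in layers.values() for w in ws)
    n_prune = round(amount * len(all_magnitudes))
    if n_prune == 0:
        return {name: list(ws) for name, ws in layers.items()}
    threshold = all_magnitudes[n_prune - 1]  # largest magnitude that still gets pruned
    return {name: [0.0 if abs(w) <= threshold else w for w in ws]
            for name, ws in layers.items()}

def sparsity(ws):
    """Percentage of exactly-zero entries."""
    return 100.0 * sum(w == 0 for w in ws) / len(ws)

toy = {
    "conv": [0.9, -0.8, 0.7, -0.6],                 # large weights
    "fc": [0.05, -0.02, 0.4, 0.01, -0.03, 0.5],     # mostly small weights
}
pruned = global_magnitude_prune(toy, amount=0.4)
for name, ws in pruned.items():
    print(name, ws, f"{sparsity(ws):.0f}% sparse")

overall = sparsity([w for ws in pruned.values() for w in ws])
print(f"overall: {overall:.0f}% sparse")
```

In this toy case the overall sparsity matches the requested 40%, but all of the pruned weights come from the `fc` layer. A single global threshold can therefore prune layers very unevenly.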


### Run Pruning 

Test different pruning levels: 30%, 50%, 70%, 90%
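
In the output of this experiment, the reported model size stays at 1.61 MB (MNIST) and 9.44 MB (CIFAR) regardless of the pruning level. That is expected: unstructured pruning sets weights to zero but keeps the dense tensor shapes, so every zero is still stored. The plain-Python sketch below shows the effect on a toy weight list; the `sparse_bytes` layout is a hypothetical encoding chosen for illustration, not something the notebook uses.

```python
import struct

def dense_bytes(weights):
    """Serialize every weight as float32, zeros included (dense layout)."""
    return len(struct.pack(f"<{len(weights)}f", *weights))

def sparse_bytes(weights):
    """Hypothetical sparse layout: (int32 index, float32 value) per nonzero weight."""
    nonzero = [(i, w) for i, w in enumerate(weights) if w != 0.0]
    return sum(len(struct.pack("<if", i, w)) for i, w in nonzero)

weights = [0.01 * (i + 1) for i in range(1000)]                   # 1000 nonzero weights
pruned = [0.0 if i < 700 else w for i, w in enumerate(weights)]   # 70% zeroed

print("dense, unpruned:", dense_bytes(weights), "bytes")
print("dense, pruned:  ", dense_bytes(pruned), "bytes")
print("sparse, pruned: ", sparse_bytes(pruned), "bytes")
```

The dense size is identical before and after pruning. Savings only appear once a sparse encoding is used, and because each stored value also needs an index, that encoding pays off only at sufficiently high sparsity.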

In [None]:
print("PRUNING EXPERIMENTS")
print("-" * 60)

pruning_amounts = [0.3, 0.5, 0.7, 0.9]

for prune_amount in pruning_amounts:
    print(f"\n{'='*60}")
    print(f"Testing {int(prune_amount*100)}% Pruning")
    print(f"{'='*60}\n")
    
    print(f"[1/2] Pruning MNIST model ({int(prune_amount*100)}%)...")
    mnist_pruned = copy.deepcopy(mnist_baseline)
    mnist_pruned = apply_unstructured_pruning(mnist_pruned, amount=prune_amount)
    
    
    sparsity = get_sparsity(mnist_pruned)
    print(f"Sparsity: {sparsity:.2f}%")
    
    metrics = collect_metrics(mnist_pruned, mnist_test_loader, f"MNIST_Pruned_{int(prune_amount*100)}%")
    metrics['compression_type'] = 'pruning'
    metrics['pruning_amount'] = prune_amount
    metrics['sparsity'] = sparsity
    all_results.append(metrics)
    
    print(f"\n[2/2] Pruning CIFAR model ({int(prune_amount*100)}%)...")
    cifar_pruned = copy.deepcopy(cifar_baseline)
    cifar_pruned = apply_unstructured_pruning(cifar_pruned, amount=prune_amount)
    
    sparsity = get_sparsity(cifar_pruned)
    print(f"Sparsity: {sparsity:.2f}%")
    
    metrics = collect_metrics(cifar_pruned, cifar_test_loader, f"CIFAR_Pruned_{int(prune_amount*100)}%")
    metrics['compression_type'] = 'pruning'
    metrics['pruning_amount'] = prune_amount
    metrics['sparsity'] = sparsity
    all_results.append(metrics)

results_df = pd.DataFrame(all_results)
results_df.to_csv('pruning_results.csv', index=False)
print("\nPruning results saved to pruning_results.csv")

print("\nPRUNING COMPLETE!")
print("-"*60)

PRUNING EXPERIMENTS

Testing 30% Pruning

[1/2] Pruning MNIST model (30%)...
Sparsity: 29.98%
Evaluating MNIST_Pruned_30%...
  Accuracy: 99.33%
  Size: 1.61 MB
  Latency: 0.2099 ms/image


[2/2] Pruning CIFAR model (30%)...
Sparsity: 29.99%
Evaluating CIFAR_Pruned_30%...
  Accuracy: 78.41%
  Size: 9.44 MB
  Latency: 0.8395 ms/image


Testing 50% Pruning

[1/2] Pruning MNIST model (50%)...
Sparsity: 49.97%
Evaluating MNIST_Pruned_50%...
  Accuracy: 99.34%
  Size: 1.61 MB
  Latency: 0.2276 ms/image


[2/2] Pruning CIFAR model (50%)...
Sparsity: 49.98%
Evaluating CIFAR_Pruned_50%...
  Accuracy: 78.13%
  Size: 9.44 MB
  Latency: 0.8348 ms/image


Testing 70% Pruning

[1/2] Pruning MNIST model (70%)...
Sparsity: 69.96%
Evaluating MNIST_Pruned_70%...
  Accuracy: 99.34%
  Size: 1.61 MB
  Latency: 0.1621 ms/image


[2/2] Pruning CIFAR model (70%)...
Sparsity: 69.97%
Evaluating CIFAR_Pruned_70%...
  Accuracy: 78.03%
  Size: 9.44 MB
  Latency: 0.8364 ms/image


Testing 90% Pruning

[1/2] Pruning

## Quantization <a id='quantization'></a>

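The `quantize_model` function deep-copies a model to the CPU and applies PyTorch's dynamic post-training quantization, requesting both `nn.Linear` and `nn.Conv2d` layers. Dynamic quantization only provides quantized replacements for linear (and recurrent) layers, so the `nn.Conv2d` layers stay in float32. `get_quantized_model_size` serializes the model's `state_dict` to an in-memory buffer and reports its size in MB (cell below).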

In [None]:
def quantize_model(model):
    """
    Apply post-training dynamic quantization.
    
    Args:
        model: Neural network to quantize
    
    Returns:
        Quantized model
    """
    model_cpu = copy.deepcopy(model).cpu()
    model_cpu.eval()
   
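    quantized_model = torch.quantization.quantize_dynamic(
        model_cpu,
        {nn.Linear, nn.Conv2d},  # requested types; dynamic quantization converts only nn.Linear here
        dtype=torch.qint8        # store weights as signed 8-bit integers
    )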
    
    return quantized_model


def get_quantized_model_size(model):
    """
    Calculate model size without saving to disk.
    Uses in-memory size calculation.
    """
    import io
    
    buffer = io.BytesIO()
    
    torch.save(model.state_dict(), buffer)
    
    size_mb = buffer.tell() / 1024**2
    
    return size_mb
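
As a rough picture of what storing weights as `qint8` involves, the sketch below quantizes a handful of floats with a symmetric per-tensor scale and then restores them. It is a simplified plain-Python illustration under the assumption of a single scale and no zero-point; PyTorch's observers and kernels differ in detail, and the function names are hypothetical.

```python
def quantize_int8(weights):
    """Symmetric per-tensor quantization: map floats to integers in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale

def dequantize(q, scale):
    """Recover approximate float values from the stored integers."""
    return [qi * scale for qi in q]

weights = [0.42, -1.27, 0.003, 0.9, -0.61]   # made-up values for illustration
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))

print("int8 values:", q)
print(f"scale: {scale:.4f}, max round-trip error: {max_error:.5f}")
```

Each weight now needs one byte instead of four, at the cost of a rounding error of at most half the scale. Very small weights such as `0.003` collapse to zero.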

### Run Quantization 

In [None]:
print("QUANTIZATION EXPERIMENTS")
print("-" * 60)

print("\n[1/2] Quantizing MNIST model...")
mnist_quantized = quantize_model(mnist_baseline)
print("MNIST model quantized")

mnist_test_cpu = DataLoader(
    torchvision.datasets.MNIST(root='./data', train=False, transform=mnist_transform),
    batch_size=BATCH_SIZE, shuffle=False
)

print("Evaluating quantized MNIST model...")
mnist_quantized.eval()
correct = 0
total = 0
times = []

with torch.no_grad():
    for i, (data, target) in enumerate(mnist_test_cpu):
        start = time.perf_counter()  # monotonic, high-resolution timer
        output = mnist_quantized(data)
        end = time.perf_counter()
        
        if i > 0:  
            times.append((end - start) * 1000 / data.size(0))
        
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

accuracy = 100. * correct / total
size_mb = get_quantized_model_size(mnist_quantized)
latency_ms = np.mean(times)

metrics = {
    'model_name': 'MNIST_Quantized_INT8',
    'accuracy': accuracy,
    'size_mb': size_mb,
    'latency_ms': latency_ms,
    'compression_type': 'quantization',
    'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
all_results.append(metrics)

print(f"  Accuracy: {accuracy:.2f}%")
print(f"  Size: {size_mb:.2f} MB")
print(f"  Latency: {latency_ms:.4f} ms/image")

print("\n[2/2] Quantizing CIFAR model...")
cifar_quantized = quantize_model(cifar_baseline)
print("CIFAR model quantized")

cifar_test_cpu = DataLoader(
    torchvision.datasets.CIFAR10(root='./data', train=False, transform=cifar_transform),
    batch_size=BATCH_SIZE, shuffle=False
)

print("Evaluating quantized CIFAR model...")
cifar_quantized.eval()
correct = 0
total = 0
times = []

with torch.no_grad():
    for i, (data, target) in enumerate(cifar_test_cpu):
        start = time.perf_counter()  # monotonic, high-resolution timer
        output = cifar_quantized(data)
        end = time.perf_counter()
        
        if i > 0:
            times.append((end - start) * 1000 / data.size(0))
        
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

accuracy = 100. * correct / total
size_mb = get_quantized_model_size(cifar_quantized)
latency_ms = np.mean(times)

metrics = {
    'model_name': 'CIFAR_Quantized_INT8',
    'accuracy': accuracy,
    'size_mb': size_mb,
    'latency_ms': latency_ms,
    'compression_type': 'quantization',
    'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
all_results.append(metrics)

print(f"  Accuracy: {accuracy:.2f}%")
print(f"  Size: {size_mb:.2f} MB")
print(f"  Latency: {latency_ms:.4f} ms/image")

results_df = pd.DataFrame(all_results)
results_df.to_csv('quantization_results.csv', index=False)
print("\nQuantization results saved")

print("\nQUANTIZATION COMPLETE!")
print("-"*60)

QUANTIZATION EXPERIMENTS

[1/2] Quantizing MNIST model...
✓ MNIST model quantized
Evaluating quantized MNIST model...


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  quantized_model = torch.quantization.quantize_dynamic(


  Accuracy: 99.34%
  Size: 0.46 MB
  Latency: 0.1352 ms/image

[2/2] Quantizing CIFAR model...
✓ CIFAR model quantized



Evaluating quantized CIFAR model...
  Accuracy: 78.48%
  Size: 3.43 MB
  Latency: 0.4926 ms/image

✓ Quantization results saved

QUANTIZATION COMPLETE!
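
The reported sizes can be sanity-checked with simple arithmetic. The sketch below assumes float32 storage (4 bytes per parameter) and uses the MNIST teacher's parameter count as printed in the distillation output of this notebook, along with the sizes reported in the pruning and quantization outputs.

```python
BYTES_PER_MIB = 1024 ** 2

# Teacher parameter count as printed in the distillation output of this notebook.
mnist_teacher_params = 421_642
fp32_size_mb = mnist_teacher_params * 4 / BYTES_PER_MIB   # 4 bytes per float32
print(f"Estimated float32 size: {fp32_size_mb:.2f} MB")

# (dense float32 size, quantized size) in MB, copied from the outputs above.
reported = {"MNIST": (1.61, 0.46), "CIFAR": (9.44, 3.43)}
ratios = {name: before / after for name, (before, after) in reported.items()}
for name, ratio in ratios.items():
    print(f"{name}: {ratio:.2f}x smaller after quantization")
```

The estimate reproduces the 1.61 MB figure. Both reduction factors fall short of the ideal 4x for float32 to int8. The likely explanation is that the convolutional layers and biases remain in float32 and that serialization adds some overhead, which would also explain why the convolution-heavier CIFAR model shrinks less.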


## Knowledge Distillation <a id='distillation'></a>

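This cell defines the distillation components. `TinyMNIST` and `TinyCIFAR` are much smaller convolutional student networks, built from conv+pool stages with reduced channel counts and a smaller fully-connected head. `distillation_loss` combines a temperature-softened KL-divergence term against the teacher with the usual cross-entropy on the true labels, and `train_student` trains a student against a frozen teacher.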

In [None]:
class TinyMNIST(nn.Module):
    """
    Smaller student model for MNIST.
    Much fewer parameters than SimpleMNIST.
    """
    def __init__(self):
        super(TinyMNIST, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1) 
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 7 * 7, 64) 
        self.fc2 = nn.Linear(64, 10)
        self.dropout = nn.Dropout(0.25)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class TinyCIFAR(nn.Module):
    """
    Smaller student model for CIFAR-10.
    """
    def __init__(self):
        super(TinyCIFAR, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)   
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 256) 
        self.fc2 = nn.Linear(256, 10)
        self.dropout = nn.Dropout(0.25)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


def distillation_loss(student_logits, teacher_logits, labels, temperature=3.0, alpha=0.7):
    """
    Compute knowledge distillation loss.
    
    Args:
        student_logits: Raw outputs from student model
        teacher_logits: Raw outputs from teacher model
        labels: True class labels
        temperature: Softening parameter (higher = softer)
        alpha: Weight between distillation and classification loss
    
    Returns:
        Combined loss
    """
    soft_student = F.log_softmax(student_logits / temperature, dim=1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
    
    distill_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (temperature ** 2)
    
    student_loss = F.cross_entropy(student_logits, labels)
    
    total_loss = alpha * distill_loss + (1 - alpha) * student_loss
    
    return total_loss


def train_student(student, teacher, train_loader, epochs=10, temperature=3.0, alpha=0.7):
    """
    Train student model using knowledge distillation.
    
    Args:
        student: Small model to train
        teacher: Large pre-trained model
        train_loader: Training data
        epochs: Number of training epochs
        temperature: Distillation temperature
        alpha: Distillation weight
    """
    teacher.eval() 
    student.train()
    
    optimizer = optim.Adam(student.parameters(), lr=0.001)
    
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            
            with torch.no_grad():
                teacher_logits = teacher(data)
            
            student_logits = student(data)
            
            loss = distillation_loss(student_logits, teacher_logits, target, temperature, alpha)
            
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = student_logits.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, '
                      f'Loss: {loss.item():.4f}, Acc: {100.*correct/total:.2f}%')
        
        epoch_loss = running_loss / len(train_loader)
        print(f'\n>>> Epoch {epoch+1} complete. Avg Loss: {epoch_loss:.4f}, '
              f'Train Accuracy: {100.*correct/total:.2f}%\n')

print("Knowledge distillation functions defined")

✓ Knowledge distillation functions defined
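
The role of the temperature in `distillation_loss` can be seen on a single example. The plain-Python sketch below softens made-up teacher and student logits at two temperatures and computes the KL term scaled by `T ** 2`, mirroring the structure of the loss above for a batch of one. The helper names and logit values are illustrative assumptions.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over logits divided by a temperature (higher = flatter)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)                       # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with q > 0 everywhere."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

teacher_logits = [8.0, 2.0, 1.0]   # made-up values for illustration
student_logits = [5.0, 3.0, 0.5]

for T in (1.0, 3.0):
    teacher_probs = softmax(teacher_logits, T)
    student_probs = softmax(student_logits, T)
    distill = kl_divergence(teacher_probs, student_probs) * T ** 2
    print(f"T={T}: teacher probs {[round(p, 3) for p in teacher_probs]}, "
          f"scaled KL {distill:.4f}")
```

At the higher temperature the teacher's distribution is flatter, so the relative probabilities of the non-target classes carry more weight in the loss. The `T ** 2` factor compensates for the smaller gradients that softened distributions produce.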


### Run Knowledge Distillation

The block below runs the distillation experiments: it creates tiny student models for MNIST and CIFAR, prints their sizes and compression ratios, trains them against the pre-trained teacher networks, collects metrics, and saves the student weights and the combined results.
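
The student size and compression ratio printed by this cell can be reproduced by hand from the `TinyMNIST` layer shapes. The sketch below is plain-Python arithmetic; the two helper functions are hypothetical, and the teacher parameter count is taken from the notebook output rather than derived.

```python
def conv2d_params(in_ch, out_ch, kernel):
    """Weights plus one bias per output channel for a square-kernel Conv2d."""
    return in_ch * out_ch * kernel * kernel + out_ch

def linear_params(in_features, out_features):
    """Weight matrix plus one bias per output feature."""
    return in_features * out_features + out_features

tiny_mnist = (
    conv2d_params(1, 16, 3)            # conv1
    + conv2d_params(16, 32, 3)         # conv2
    + linear_params(32 * 7 * 7, 64)    # fc1
    + linear_params(64, 10)            # fc2
)
teacher_params = 421_642  # taken from the notebook output, not derived here
print(f"TinyMNIST parameters: {tiny_mnist:,}")
print(f"Compression ratio: {teacher_params / tiny_mnist:.2f}x")
```

Nearly all of the student's parameters sit in `fc1`, so shrinking the flattened feature map or the hidden width would be the most direct way to reduce its size further.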

In [None]:
print("KNOWLEDGE DISTILLATION EXPERIMENTS")
print("-" * 60)

print("\n[1/2] Training MNIST student model...\n")
mnist_student = TinyMNIST().to(device)
print(f"Student model size: {sum(p.numel() for p in mnist_student.parameters()):,} parameters")
print(f"Teacher model size: {sum(p.numel() for p in mnist_baseline.parameters()):,} parameters")
print(f"Compression ratio: {sum(p.numel() for p in mnist_baseline.parameters()) / sum(p.numel() for p in mnist_student.parameters()):.2f}x\n")

train_student(mnist_student, mnist_baseline, mnist_train_loader, epochs=10)

metrics = collect_metrics(mnist_student, mnist_test_loader, "MNIST_Distilled_Student")
metrics['compression_type'] = 'distillation'
all_results.append(metrics)

torch.save(mnist_student.state_dict(), 'mnist_student.pth')

print("\n[2/2] Training CIFAR student model...\n")
cifar_student = TinyCIFAR().to(device)
print(f"Student model size: {sum(p.numel() for p in cifar_student.parameters()):,} parameters")
print(f"Teacher model size: {sum(p.numel() for p in cifar_baseline.parameters()):,} parameters")
print(f"Compression ratio: {sum(p.numel() for p in cifar_baseline.parameters()) / sum(p.numel() for p in cifar_student.parameters()):.2f}x\n")

train_student(cifar_student, cifar_baseline, cifar_train_loader, epochs=10)

metrics = collect_metrics(cifar_student, cifar_test_loader, "CIFAR_Distilled_Student")
metrics['compression_type'] = 'distillation'
all_results.append(metrics)

torch.save(cifar_student.state_dict(), 'cifar_student.pth')

results_df = pd.DataFrame(all_results)
results_df.to_csv('all_compression_results.csv', index=False)
print("\n✓ All results saved to all_compression_results.csv")

print("\nKNOWLEDGE DISTILLATION COMPLETE!")
print("-"*60)

KNOWLEDGE DISTILLATION EXPERIMENTS

[1/2] Training MNIST student model...

Student model size: 105,866 parameters
Teacher model size: 421,642 parameters
Compression ratio: 3.98x

Epoch 1/10, Batch 0/469, Loss: 14.3606, Acc: 10.94%
Epoch 1/10, Batch 100/469, Loss: 3.3159, Acc: 70.41%
Epoch 1/10, Batch 200/469, Loss: 1.7718, Acc: 79.77%
Epoch 1/10, Batch 300/469, Loss: 1.2835, Acc: 84.15%
Epoch 1/10, Batch 400/469, Loss: 1.0060, Acc: 86.83%

>>> Epoch 1 complete. Avg Loss: 2.3494, Train Accuracy: 88.10%

Epoch 2/10, Batch 0/469, Loss: 0.6981, Acc: 95.31%
Epoch 2/10, Batch 100/469, Loss: 0.3798, Acc: 96.20%
Epoch 2/10, Batch 200/469, Loss: 0.5299, Acc: 96.53%
Epoch 2/10, Batch 300/469, Loss: 0.5114, Acc: 96.63%
Epoch 2/10, Batch 400/469, Loss: 0.4651, Acc: 96.76%

>>> Epoch 2 complete. Avg Loss: 0.5889, Train Accuracy: 96.78%

Epoch 3/10, Batch 0/469, Loss: 0.5059, Acc: 95.31%
Epoch 3/10, Batch 100/469, Loss: 0.4594, Acc: 97.61%
Epoch 3/10, Batch 200/469, Loss: 0.3564, Acc: 97.70%
Epoch 3