In [None]:
!pip install torch torchvision matplotlib


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils.prune as prune
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import os, time


In [None]:
class CNN(nn.Module):
    """Small convolutional classifier: two conv/pool stages, then two linear layers.

    The first linear layer expects 64 * 7 * 7 features, which is what two
    padded 3x3 convolutions and two 2x2 max-pools produce for a single-channel
    28x28 image. The last layer emits 10 raw scores (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: padding=1 keeps the spatial size through each 3x3 conv.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Classifier head.
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return 10 un-normalised class scores per row of flattened features."""
        features = F.max_pool2d(F.relu(self.conv1(x)), 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)
        # -1 lets the leading (batch) dimension be inferred from the element count.
        flat = features.reshape(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


In [None]:
def train(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over `dataloader`, updating `model` in place.

    The model is switched to training mode. For every (inputs, targets) batch
    the tensors are moved to `device`, gradients are reset, the loss computed
    by `criterion` is back-propagated and `optimizer` takes one step.
    Nothing is returned.
    """
    model.train()
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()  # drop gradients left over from the previous batch
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

def evaluate(model, dataloader, device):
    """Measure top-1 accuracy of `model` over `dataloader` and time the pass.

    The model is switched to eval mode and run under `torch.no_grad()`.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        dataloader: iterable of (inputs, targets) batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        (accuracy, elapsed): accuracy as a percentage and the elapsed seconds
        for the whole loop. If the loader yields no samples, accuracy is
        reported as 0.0 instead of raising ZeroDivisionError.
    """
    model.eval()
    total, correct = 0, 0
    # perf_counter is monotonic, so the interval cannot be distorted by
    # system clock adjustments the way time.time() differences can.
    start = time.perf_counter()
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    elapsed = time.perf_counter() - start
    accuracy = 100 * correct / total if total else 0.0
    return accuracy, elapsed


In [None]:
# Normalise with the same mean/std that add_noise() applies, so the training
# data, the clean test data and the noisy test data are all on the same scale.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
# The training split is deliberately NOT bound to the name `train`: doing so
# would overwrite the train() function defined in an earlier cell and make the
# later train(model, ...) call fail with "'MNIST' object is not callable".
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader_clean = DataLoader(test, batch_size=1000)

def add_noise(dataset, std=0.3, batch_size=1000, generator=None):
    """Build a DataLoader of Gaussian-noised, normalised copies of `dataset`.

    Args:
        dataset: object exposing `.data` (pixel values on a 0-255 scale) and
            `.targets`. Presumably a torchvision MNIST split with `.data`
            shaped (N, H, W) -- confirm against the caller.
        std: standard deviation of the additive noise, in units of the
            0-1 scaled pixel values.
        batch_size: batch size of the returned loader (default matches the
            previously hard-coded 1000).
        generator: optional `torch.Generator` used to draw the noise, so a
            caller can make the noisy set reproducible. `None` uses the
            global RNG, as before.

    Returns:
        DataLoader over (noisy_images, targets); a channel axis is inserted
        at position 1 of the image tensor.
    """
    pixels = dataset.data.float() / 255.
    noise = torch.randn(
        pixels.shape, generator=generator, dtype=pixels.dtype, device=pixels.device
    )
    # Clip before normalising so the noisy pixels stay in the valid 0-1 range.
    noisy = torch.clip(pixels + noise * std, 0., 1.)
    # 0.1307 / 0.3081 are the widely used MNIST mean / std.
    noisy = (noisy - 0.1307) / 0.3081
    return DataLoader(TensorDataset(noisy.unsqueeze(1), dataset.targets), batch_size=batch_size)

# Noisy counterpart of the held-out split, with the default noise level spelled out.
test_loader_noisy = add_noise(dataset=test, std=0.3)


In [None]:
def _l1_structured(module, name):
    """Prune 30% of the dim-0 slices of `module.<name>`, ranked by L1 norm."""
    return prune.ln_structured(module, name=name, amount=0.3, n=1, dim=0)


def _random_structured(module, name):
    """Prune a random 30% of the dim-0 slices of `module.<name>`."""
    return prune.random_structured(module, name=name, amount=0.3, dim=0)


def _l1_unstructured_half(module, name):
    """Prune the 50% smallest-magnitude entries of `module.<name>`."""
    return prune.l1_unstructured(module, name=name, amount=0.5)


def _conv_only(module, name):
    """Prune 30% of entries by magnitude, but only for Conv2d modules; else return None."""
    if not isinstance(module, nn.Conv2d):
        return None
    return prune.l1_unstructured(module, name=name, amount=0.3)


def _fc_only(module, name):
    """Prune 30% of entries by magnitude, but only for Linear modules; else return None."""
    if not isinstance(module, nn.Linear):
        return None
    return prune.l1_unstructured(module, name=name, amount=0.3)


# Experiment name -> pruning strategy.
# The first two values are pruning-method classes (the form that
# prune.global_unstructured takes as `pruning_method`); the rest are
# callables invoked as f(module, parameter_name) on one layer at a time.
pruning_methods = {
    'L1Unstructured': prune.L1Unstructured,
    'RandomUnstructured': prune.RandomUnstructured,
    'L1Structured': _l1_structured,
    'RandomStructured': _random_structured,
    'CustomAmount50': _l1_unstructured_half,
    'ConvOnly': _conv_only,
    'FCOnly': _fc_only,
}


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
results = []

# Seed PyTorch's global RNG so weight initialisation, batch shuffling and the
# random pruning methods start from the same state on every run of this cell.
torch.manual_seed(0)

# Train ONE baseline and reuse its weights for every pruning method. Training
# a fresh, differently-initialised network per method would multiply the
# training cost by the number of methods and mix initialisation noise into
# the comparison between methods.
base_model = CNN().to(device)
opt = optim.Adam(base_model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
for _ in range(3):
    train(base_model, train_loader, opt, loss_fn, device)
base_state = {key: value.detach().clone() for key, value in base_model.state_dict().items()}

for name, method in pruning_methods.items():
    # Fresh copy of the trained baseline for this method.
    model = CNN().to(device)
    model.load_state_dict(base_state)

    # Apply pruning
    modules = [(model.conv1, 'weight'), (model.conv2, 'weight'), (model.fc1, 'weight'), (model.fc2, 'weight')]
    # Pruning-method *classes* are meant for global pruning across all layers.
    # callable() cannot tell them apart from per-layer functions (classes are
    # callable too), so test for the class explicitly.
    is_global_method = isinstance(method, type) and issubclass(method, prune.BasePruningMethod)
    if is_global_method:
        prune.global_unstructured(modules, pruning_method=method, amount=0.3)
    else:
        for module, param in modules:
            method(module, param)

    # Make the pruning permanent. Layers a method chose to skip (ConvOnly /
    # FCOnly) carry no pruning hook, and prune.remove would raise for them,
    # so check first instead of swallowing every exception.
    for m, p in modules:
        if prune.is_pruned(m):
            prune.remove(m, p)

    acc_clean, t_clean = evaluate(model, test_loader_clean, device)
    acc_noisy, t_noisy = evaluate(model, test_loader_noisy, device)

    # NOTE: after prune.remove the pruned weights are stored as ordinary dense
    # tensors containing zeros, so this file size is not expected to shrink
    # with the amount of pruning.
    torch.save(model.state_dict(), f"{name}_model.pth")
    size_mb = os.path.getsize(f"{name}_model.pth") / (1024 ** 2)

    results.append((name, acc_clean, t_clean, acc_noisy, t_noisy, size_mb))

for name, acc_clean, t_clean, acc_noisy, t_noisy, size_mb in results:
    print(f"🔧 {name} | ✅ Clean: {acc_clean:.2f}% in {t_clean:.2f}s | 🧪 Noisy: {acc_noisy:.2f}% in {t_noisy:.2f}s | 📦 {size_mb:.2f} MB")
