In [None]:
import sys
sys.path.append("../../")

import torch
from torch.utils.data import DataLoader
import torch.nn as nn

from torchvision import models, transforms
from torchvision.datasets import CIFAR10

import torch_pruning as tp
from utils.train import evaluate, train_one_epoch
import matplotlib.pyplot as plt

## Pruning

Model pruning removes unimportant weights or neurons from a trained network to reduce model size and computation while preserving most of the original accuracy.

<p align="center">
  <img src="../../assets/img/deployment/pruning.png" width="400">
</p>

In [None]:
# The ResNet-18 backbone used below starts from ImageNet weights, so CIFAR-10
# images are upsampled to 224x224 and normalized with ImageNet channel stats.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

t = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

# Train split first, then the held-out test split (used here for validation);
# both share the same preprocessing pipeline.
train_dataset, val_dataset = (
    CIFAR10(root="../../assets/cifar10", train=is_train, download=True, transform=t)
    for is_train in (True, False)
)

# Only the training loader is shuffled; validation order stays fixed.
train_loader, val_loader = (
    DataLoader(split, batch_size=64, shuffle=reshuffle)
    for split, reshuffle in ((train_dataset, True), (val_dataset, False))
)

In [None]:
# Pick the best available accelerator first, so the checkpoint can be mapped
# straight onto it when loading.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Rebuild the fine-tuned architecture: ResNet-18 with a 10-way CIFAR-10 head.
# No pretrained download is needed: the strict load_state_dict below
# overwrites every parameter and buffer anyway.
model = models.resnet18(weights=None)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # 10 classes
state_dict = torch.load(
    "../../assets/models/finetuned_resnet18.pth",
    map_location=device,  # the checkpoint may have been saved on another device
)
model.load_state_dict(state_dict)
model.to(device)

# NOTE: an optimizer keeps references to the parameter tensors that exist
# right now; if pruning replaces those tensors, the optimizer must be rebuilt.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Never prune the classifier head: its outputs are the class logits.
# (Matching on out_features == 1000 would miss it, since the head was replaced
# with a 10-way layer above.)
ignored_layers = [model.fc]

# Dummy input used by torch_pruning to trace the graph and count MACs/params.
example_inputs = torch.randn(1, 3, 224, 224, device=device)

In [None]:
# Measure the unpruned model once; later steps are compared against these.
model.eval()

base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
_, baseline_acc = evaluate(model, val_loader, loss_fn, device)

# Plot history for accuracy vs. sparsity; entry 0 is the unpruned baseline
# (zero sparsity, accuracy in percent).
sparsities = [0.0]
accuracies = [baseline_acc * 100]

print(f"Baseline MACs: {base_macs/1e6:.2f} M, Params: {base_params/1e6:.2f} M, Acc: {baseline_acc*100:.2f}%")

In [None]:
imp = tp.importance.TaylorImportance()  # gradient-based (Taylor) importance criterion

iterative_steps = 5
finetune_epochs = 3  # even 1-3 epochs of fine-tuning per step recovers a lot of accuracy
finetune_lr = optimizer.param_groups[0]["lr"]  # reuse the learning rate configured earlier

pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    # Final target: 50% of channels removed in total, reached gradually over
    # `iterative_steps` pruning steps (NOT 50% per step).
    ch_sparsity=0.5,
    ignored_layers=ignored_layers,
)

for step in range(iterative_steps):
    # Taylor importance needs gradients. Back-propagate the real task loss on
    # one labelled training batch rather than a dummy loss on random noise
    # (a noise forward pass in train mode would also skew BatchNorm stats).
    model.train()
    model.zero_grad()
    images, labels = next(iter(train_loader))
    loss = loss_fn(model(images.to(device)), labels.to(device))
    loss.backward()

    # prune
    pruner.step()
    model.zero_grad()  # discard importance-estimation gradients before fine-tuning

    # Pruning swaps in new (smaller) parameter tensors, so an optimizer built
    # before this step would keep updating the stale, detached ones.
    optimizer = torch.optim.Adam(model.parameters(), lr=finetune_lr)

    # fine-tune
    for _ in range(finetune_epochs):
        train_one_epoch(model, train_loader, optimizer, loss_fn, device)

    # evaluate
    model.eval()
    _, acc = evaluate(model, val_loader, loss_fn, device)

    macs, params = tp.utils.count_ops_and_params(model, example_inputs)

    # Fraction of the baseline parameter count that has been removed so far.
    sparsity = 1.0 - (params / base_params)

    sparsities.append(sparsity)
    accuracies.append(acc * 100)

    print(f"Step {step+1}:\n\tSparsity={sparsity:.2f}, Accuracy={acc*100:.2f}%")

In [None]:
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(sparsities, accuracies, marker="o")
# `sparsities` holds 1 - params / base_params (fraction of parameters removed),
# not the pruner's channel ratio, so label the axis by what is plotted.
ax.set_xlabel("Parameter Sparsity (1 - params / baseline params)")
ax.set_ylabel("Validation Accuracy (%)")  # accuracies were scaled by 100
ax.set_title("Accuracy vs Sparsity (ResNet-18)")
ax.grid(True)
plt.show()