# Prune the pre-trained model.

In [1]:
import os
from datetime import datetime

import pandas as pd
import torch
import torch.nn.utils.prune as prune
from accelerate import Accelerator
from evaluate import load
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import models, datasets
from torchvision.transforms import transforms

from src.utils import calculate_sparsity


In [2]:
# Set device
accelerator = Accelerator(device_placement=True)
device = accelerator.device

In [3]:
# load pretrained model.
model = models.resnet50()
model.fc = nn.Linear(model.fc.in_features, 100)  # resnet
model.load_state_dict(torch.load('resnet_cifar100.pth'))

# Fuse the model layers
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [4]:
# Increase pruning ratio
pruning_ratio = .5

# Apply unstructured pruning
for module in model.modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=pruning_ratio)
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=pruning_ratio)
    elif isinstance(module, torch.nn.BatchNorm2d):
        prune.l1_unstructured(module, name='weight', amount=pruning_ratio)
        prune.l1_unstructured(module, name='bias', amount=pruning_ratio)

# Apply structured pruning (channel pruning)
for module in model.modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.ln_structured(module, name='weight', amount=pruning_ratio, n=1, dim=0)

# Remove the pruning reparameterization
for module in model.modules():
    if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) or \
            isinstance(module, torch.nn.BatchNorm2d):
        prune.remove(module, 'weight')
        if isinstance(module, torch.nn.BatchNorm2d):
            prune.remove(module, 'bias')

f'Sparsity of the model is {calculate_sparsity(model) * 100}%.'

'Sparsity of the model is 73.4324038028717%.'

In [5]:
norm = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    norm,
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    norm,
])

In [6]:
# Load CIFAR-100 dataset
train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=512, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Load evaluation metrics
accuracy = load("accuracy")
f1 = load("f1")

model, optimizer, scheduler, train_dataloader, test_dataloader = accelerator.prepare(
    model, optimizer, scheduler, train_dataloader, test_dataloader
)

In [8]:
training_result = []

In [9]:
num_epochs = 20

# Fine-tune the pruned model and evaluate for sparsed model.
for epoch in range(num_epochs):
    tic = datetime.now()

    model.train()

    train_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        accelerator.backward(loss)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    train_loss /= len(train_dataloader)

    model.eval()
    test_loss = 0.0
    test_preds = []
    test_labels = []

    with torch.no_grad():
        for images, labels in test_dataloader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            test_preds.extend(accelerator.gather(preds).cpu().numpy())
            test_labels.extend(accelerator.gather(labels).cpu().numpy())

    test_loss /= len(test_dataloader)
    test_acc = accuracy.compute(references=test_labels, predictions=test_preds)["accuracy"]
    test_f1 = f1.compute(references=test_labels, predictions=test_preds, average="macro")["f1"]

    # Update the learning rate based on validation loss
    scheduler.step()

    # Time calculation
    toc = datetime.now()
    elapsed_time = toc - tic
    elapsed_time_in_hh_mm_ss = str(elapsed_time).split('.')[0]

    print(
        f"Epoch [{epoch + 1}/{num_epochs}]: Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, "
        f'Train Accuracy: {correct / total:.3f}, '
        f"Test Accuracy: {test_acc:.4f}, Test F1: {test_f1:.4f}, "
        f'lr: {optimizer.param_groups[0]["lr"]}, '
        f'Elapsed Time: {elapsed_time_in_hh_mm_ss}\n'
    )

    training_result.append({
        'train_loss': train_loss,
        'test_loss': test_loss,
        'train_acc': correct / total,
        'test_acc': test_acc,
        'lr': optimizer.param_groups[0]["lr"]
    })

Epoch [1/20]: Train Loss: 2.4450, Test Loss: 2.1234, Train Accuracy: 0.381, Test Accuracy: 0.4480, Test F1: 0.4425, lr: 0.009999383162408303, Elapsed Time: 0:00:16
Epoch [2/20]: Train Loss: 1.7028, Test Loss: 2.0002, Train Accuracy: 0.538, Test Accuracy: 0.4789, Test F1: 0.4771, lr: 0.009997532801828657, Elapsed Time: 0:00:16
Epoch [3/20]: Train Loss: 1.4485, Test Loss: 1.9606, Train Accuracy: 0.601, Test Accuracy: 0.4913, Test F1: 0.4867, lr: 0.00999444937480985, Elapsed Time: 0:00:16
Epoch [4/20]: Train Loss: 1.2841, Test Loss: 1.9638, Train Accuracy: 0.641, Test Accuracy: 0.4949, Test F1: 0.4927, lr: 0.009990133642141357, Elapsed Time: 0:00:16
Epoch [5/20]: Train Loss: 1.1541, Test Loss: 1.9660, Train Accuracy: 0.674, Test Accuracy: 0.4987, Test F1: 0.4924, lr: 0.009984586668665639, Elapsed Time: 0:00:16
Epoch [6/20]: Train Loss: 1.0611, Test Loss: 1.9747, Train Accuracy: 0.697, Test Accuracy: 0.4990, Test F1: 0.4978, lr: 0.009977809823015398, Elapsed Time: 0:00:16
Epoch [7/20]: Tra

In [10]:
f'Sparsity of the model is {calculate_sparsity(model) * 100}%.'

'Sparsity of the model is 65.28058648109436%.'

In [11]:
# index is epoch number.
tr = pd.DataFrame(training_result, columns=['train_loss', 'test_loss', 'train_acc', 'test_acc', 'lr'])
tr.to_csv('pruned_model_result.csv')
tr.head()

Unnamed: 0,train_loss,test_loss,train_acc,test_acc,lr
0,2.445029,2.123417,0.3812,0.448,0.009999
1,1.702789,2.000239,0.53774,0.4789,0.009998
2,1.448468,1.960616,0.60126,0.4913,0.009994
3,1.284137,1.963774,0.6408,0.4949,0.00999
4,1.154144,1.966032,0.67362,0.4987,0.009985


In [12]:
# Save the pruned model to disk
torch.save(model.state_dict(), "pruned_model.pth")

# Get the size of the saved model file
model_size = os.path.getsize("pruned_model.pth") / (1024 * 1024)  # Size in MB
print(f"Pruned model size: {model_size:.2f} MB")

Pruned model size: 90.76 MB
