In [9]:
import torch
from torch.utils.data import DataLoader, random_split, Subset
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from torch.optim.lr_scheduler import CosineAnnealingLR
import random
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
import copy

In [10]:
def fix_random_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Fixed random seed: {seed}")

fix_random_seed(42)

# For deterministic DataLoader behavior
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

Fixed random seed: 42


In [11]:

# ==========================================================
# Configuration
# ==========================================================
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
BATCH_SIZE = 128
NUM_WORKERS = 4

g = torch.Generator()
g.manual_seed(42)

# ==========================================================
# Transforms
# ==========================================================
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# ==========================================================
# Datasets and Splits
# ==========================================================
# Load once (without transform)
full_train = torchvision.datasets.CIFAR100(root='./data', train=True, download=True)
full_test = torchvision.datasets.CIFAR100(root='./data', train=False, download=True)

# Split indices using a fixed seed generator
train_indices, val_indices = random_split(range(len(full_train)), [45000, 5000], generator=g)
test_indices, small_test_indices = random_split(range(len(full_test)), [7500, 2500], generator=g)

# Wrap each subset with its own dataset and transform
train_dataset = Subset(
    torchvision.datasets.CIFAR100(root='./data', train=True, transform=train_transform),
    train_indices.indices
)
val_dataset = Subset(
    torchvision.datasets.CIFAR100(root='./data', train=True, transform=test_transform),
    val_indices.indices
)
test_dataset = Subset(
    torchvision.datasets.CIFAR100(root='./data', train=False, transform=test_transform),
    test_indices.indices
)
small_test_dataset = Subset(
    torchvision.datasets.CIFAR100(root='./data', train=False, transform=test_transform),
    small_test_indices.indices
)

print("Training set size:", len(train_dataset))
print("Validation set size:", len(val_dataset))
print("Test set size:", len(test_dataset))
print("Small test size:", len(small_test_dataset))

# ==========================================================
# Dataloaders
# ==========================================================
def get_loader(dataset, shuffle):
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g
    )

train_loader = get_loader(train_dataset, shuffle=True)
val_loader = get_loader(val_dataset, shuffle=False)
test_loader = get_loader(test_dataset, shuffle=False)
small_test_loader = get_loader(small_test_dataset, shuffle=False)


Training set size: 45000
Validation set size: 5000
Test set size: 7500
Small test size: 2500


In [12]:
def plot_weight_histograms(model, title):
    # Collect (layer_name, weights) pairs
    layer_weights = []
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            layer_weights.append((name, module.weight.detach().cpu().numpy().flatten()))

    # Plot histograms
    plt.figure(figsize=(15, 8))
    for i, (name, w) in enumerate(layer_weights[:11]):  # limit to first 11 for readability
        plt.subplot(5, 6, i + 1)
        plt.hist(w, bins=40, color='blue', alpha=0.7)
        plt.title(name, fontsize=9)  # <-- actual layer name
        plt.tight_layout()
    plt.suptitle(title, y=1.02)
    plt.show()

In [13]:
state_dict = torch.load('models/vgg11_cifar100_pruned_structured.pt', map_location='cpu')
for k, v in state_dict.items():
    if 'weight' in k and len(v.shape) == 4:
        print(f"{k}: {v.shape}")

features.0.weight: torch.Size([52, 3, 3, 3])
features.3.weight: torch.Size([101, 52, 3, 3])
features.6.weight: torch.Size([196, 101, 3, 3])
features.8.weight: torch.Size([186, 196, 3, 3])
features.11.weight: torch.Size([361, 186, 3, 3])
features.13.weight: torch.Size([345, 361, 3, 3])
features.16.weight: torch.Size([330, 345, 3, 3])
features.18.weight: torch.Size([512, 330, 3, 3])


In [14]:
def make_vgg11_pruned():
    cfg = [52, 'M', 101, 'M', 196, 186, 'M', 361, 345, 'M', 330, 512, 'M']

    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            layers += [nn.Conv2d(in_channels, v, kernel_size=3, padding=1),
                       nn.ReLU(inplace=True)]
            in_channels = v
    features = nn.Sequential(*layers)

    # Classifier for CIFAR-100
    classifier = nn.Sequential(
        nn.Linear(512 * 1 * 1, 4096),  # adjust if you used adaptive pooling
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, 100)
    )

    model = models.vgg.VGG(features, num_classes=100)
    return model

# Instantiate pruned model
vgg11_pruned = make_vgg11_pruned()

# Load the pruned weights
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"
state_dict = torch.load('models/vgg11_cifar100_pruned_structured.pt', map_location=device)
vgg11_pruned.load_state_dict(state_dict)
vgg11_pruned.to(device)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 52, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(52, 101, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(101, 196, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(196, 186, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(186, 361, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(361, 345, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 

In [16]:
def evaluate(model, test_loader, device):
    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            preds = model(images)
            preds = torch.argmax(preds, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total * 100
    return accuracy

In [17]:
acc = evaluate(vgg11_pruned, test_loader, device)
print(f"Structured Pruning Accuracy Before Retraining: {acc:.2f}%")

acc = evaluate(vgg11_pruned, small_test_loader, device)
print(f"Structured Pruning Accuracy Before Retraining (Small): {acc:.2f}%")

Structured Pruning Accuracy Before Retraining: 54.93%
Structured Pruning Accuracy Before Retraining (Small): 54.60%


In [18]:
def train_epoch(model, loader, criterion, optimizer, device):
    model.to(device)
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    avg_loss = running_loss / total
    accuracy = 100. * correct / total
    return avg_loss, accuracy

In [19]:
base_lr = 0.04
weight_decay = 5e-4
num_epochs = 5
optim = torch.optim.SGD(
    vgg11_pruned.parameters(),
    lr=base_lr,
    momentum=0.9,
    weight_decay=weight_decay
)
scheduler = CosineAnnealingLR(optim, T_max=num_epochs, eta_min=1e-5)
criterion = torch.nn.CrossEntropyLoss()

In [20]:
best_val = 0
best_retrained = copy.deepcopy(vgg11_pruned)
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(vgg11_pruned, train_loader, criterion, optim, device)
    scheduler.step()
    val_acc = evaluate(vgg11_pruned, val_loader, device)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")
    if val_acc > best_val:
      best_retrained = copy.deepcopy(vgg11_pruned)
      best_val = val_acc
      torch.save(vgg11_pruned.state_dict(), "vgg11_cifar100_pruned_structured_finetuned.pt")

Epoch 1/5 | Loss: 1.3185 | Train Acc: 62.44% | Val Acc: 56.26%
Epoch 2/5 | Loss: 1.1066 | Train Acc: 68.00% | Val Acc: 60.74%
Epoch 3/5 | Loss: 0.7368 | Train Acc: 78.24% | Val Acc: 64.42%
Epoch 4/5 | Loss: 0.3972 | Train Acc: 87.92% | Val Acc: 68.96%
Epoch 5/5 | Loss: 0.1966 | Train Acc: 93.85% | Val Acc: 71.36%


In [21]:
acc = evaluate(best_retrained, test_loader, device)
print(f"Structured Pruning Accuracy After Retraining: {acc:.2f}%")

acc = evaluate(best_retrained, small_test_loader, device)
print(f"Structured Pruning Accuracy After Retraining (Small): {acc:.2f}%")

Structured Pruning Accuracy After Retraining: 71.64%
Structured Pruning Accuracy After Retraining (Small): 70.40%
