In [39]:
import torch
from torch.utils.data import DataLoader, random_split, Subset
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.optim.lr_scheduler import CosineAnnealingLR
import random
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
import copy

In [40]:
def fix_random_seed(seed=42):
    """Seed every RNG this notebook relies on so runs are repeatable.

    Seeds Python's `random`, NumPy's global RNG and torch (CPU plus all CUDA
    devices), switches cuDNN to deterministic kernels with autotuning off, and
    prints the seed that was applied.

    Note: assigning PYTHONHASHSEED here does not change hash randomisation of the
    interpreter that is already running; it is only read by Python processes
    started afterwards.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for reproducible kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    print(f"Fixed random seed: {seed}")

fix_random_seed(42)

# For deterministic DataLoader behavior
def seed_worker(worker_id):
    """DataLoader `worker_init_fn`: derive NumPy and `random` seeds from torch's.

    Inside a worker, torch.initial_seed() returns the seed the DataLoader set for
    that worker; reducing it modulo 2**32 yields a value NumPy accepts.
    `worker_id` is required by the callback signature but not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

Fixed random seed: 42


In [41]:

# ==========================================================
# Configuration
# ==========================================================
# ImageNet normalisation statistics, matching the ImageNet-pretrained VGG backbone.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
BATCH_SIZE = 128
NUM_WORKERS = 4

# One seeded generator, shared by the index splits and the DataLoaders below.
g = torch.Generator().manual_seed(42)

# ==========================================================
# Transforms
# ==========================================================
# Common tail of both pipelines: PIL image -> float tensor -> ImageNet-normalised.
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
]

# Training: upsample to 224 with mild scale jitter, plus horizontal flips.
train_transform = transforms.Compose(
    [transforms.RandomResizedCrop(224, scale=(0.9, 1.0)), transforms.RandomHorizontalFlip()]
    + _tensor_and_normalize
)

# Evaluation: deterministic resize and centre crop to 224.
test_transform = transforms.Compose(
    [transforms.Resize(224), transforms.CenterCrop(224)]
    + _tensor_and_normalize
)

# ==========================================================
# Datasets and Splits
# ==========================================================
# Download once; these untransformed copies are only used for their lengths.
full_train = torchvision.datasets.CIFAR100(root='./data', train=True, download=True)
full_test = torchvision.datasets.CIFAR100(root='./data', train=False, download=True)

# Split indices with the shared seeded generator `g`. The order of these two
# calls matters: the train split consumes generator state before the test split.
train_indices, val_indices = random_split(range(len(full_train)), [45000, 5000], generator=g)
test_indices, small_test_indices = random_split(range(len(full_test)), [7500, 2500], generator=g)

# Guard against leakage between splits drawn from the same source set.
assert not set(train_indices.indices) & set(val_indices.indices), "train/val overlap"
assert not set(test_indices.indices) & set(small_test_indices.indices), "test/small-test overlap"

# Train and val need different transforms, so each gets its own dataset object.
# Test and small-test both use test_transform, so they share a single instance
# instead of loading the CIFAR-100 test split into memory twice.
train_base = torchvision.datasets.CIFAR100(root='./data', train=True, transform=train_transform)
val_base = torchvision.datasets.CIFAR100(root='./data', train=True, transform=test_transform)
test_base = torchvision.datasets.CIFAR100(root='./data', train=False, transform=test_transform)

train_dataset = Subset(train_base, train_indices.indices)
val_dataset = Subset(val_base, val_indices.indices)
test_dataset = Subset(test_base, test_indices.indices)
small_test_dataset = Subset(test_base, small_test_indices.indices)

print("Training set size:", len(train_dataset))
print("Validation set size:", len(val_dataset))
print("Test set size:", len(test_dataset))
print("Small test size:", len(small_test_dataset))

# ==========================================================
# Dataloaders
# ==========================================================
def get_loader(dataset, shuffle):
    """Build a DataLoader with the notebook's shared batch/worker settings.

    Args:
        dataset: a map-style dataset (here, the CIFAR-100 `Subset`s above).
        shuffle: reshuffle every epoch (use True for training only).

    Returns:
        A DataLoader using the module-level BATCH_SIZE and NUM_WORKERS, with
        `seed_worker` and the seeded generator `g` for reproducible shuffling and
        per-worker RNGs.
    """
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        # Pinned host memory only helps host->CUDA copies; on the MPS/CPU
        # fallbacks selected later in this notebook it is unused and PyTorch warns.
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=seed_worker,
        generator=g,
    )

train_loader = get_loader(train_dataset, shuffle=True)
val_loader = get_loader(val_dataset, shuffle=False)
test_loader = get_loader(test_dataset, shuffle=False)
small_test_loader = get_loader(small_test_dataset, shuffle=False)


Training set size: 45000
Validation set size: 5000
Test set size: 7500
Small test size: 2500




In [42]:
def plot_weight_histograms(model, title, max_layers=11, bins=40):
    """Plot one weight histogram per Conv2d/Linear layer of `model`.

    Args:
        model: torch.nn.Module; layers are visited in `named_modules()` order.
        title: figure-level title.
        max_layers: plot at most this many layers (default 11, for readability).
        bins: number of histogram bins per layer.
    """
    import math  # local so the cell stays self-contained

    # Collect (layer_name, flattened weights) pairs.
    layer_weights = [
        (name, module.weight.detach().cpu().numpy().ravel())
        for name, module in model.named_modules()
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
    ][:max_layers]
    if not layer_weights:
        print("No Conv2d/Linear layers to plot.")
        return

    # Size the grid to the number of layers instead of a fixed 5x6 grid, which
    # left most panels empty and shrank the ones that were drawn.
    ncols = min(4, len(layer_weights))
    nrows = math.ceil(len(layer_weights) / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 2.8 * nrows), squeeze=False)
    flat_axes = axes.ravel()
    for ax, (name, w) in zip(flat_axes, layer_weights):
        ax.hist(w, bins=bins, color='blue', alpha=0.7)
        ax.set_title(name, fontsize=9)  # actual layer name
    for ax in flat_axes[len(layer_weights):]:
        ax.set_visible(False)  # hide unused grid cells

    fig.suptitle(title, y=1.02)
    fig.tight_layout()  # once, after every panel exists
    plt.show()

In [43]:
def print_model_sparsity(model):
    """Print how many of `model`'s parameters are exactly zero.

    Every tensor yielded by `named_parameters()` (weights and biases alike) is
    counted. Prints the total, the zero count, and the zero share as a
    percentage (0.00% for a model with no parameters).
    """
    params = [p for _, p in model.named_parameters() if p is not None]
    total_params = sum(p.numel() for p in params)
    zero_params = sum((p == 0).sum().item() for p in params)

    if total_params <= 0:
        sparsity = 0.0
    else:
        sparsity = 100.0 * zero_params / total_params

    print(f"Total Parameters: {total_params:,}")
    print(f"Zero (Pruned) Parameters: {zero_params:,}")
    print(f"Sparsity: {sparsity:.2f}%")

In [44]:
# Rebuild VGG-11 with a 100-way head and load the pruned CIFAR-100 checkpoint.
# No pretrained weights are requested: load_state_dict (strict by default)
# overwrites every parameter, so downloading ImageNet weights would be wasted.
# The stock VGG classifier already has Dropout(p=0.5) at index 5, so it is not replaced.
vgg11_pruned = models.vgg11(weights=None, num_classes=100)
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
vgg11_pruned.load_state_dict(torch.load('models/vgg11_cifar100_pruned_unstructured.pt', weights_only=True, map_location=device))
vgg11_pruned.to(device)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 

In [45]:
def evaluate(model, test_loader, device):
    """Return the top-1 accuracy (in percent) of `model` over `test_loader`.

    Moves the model to `device` and switches it to eval mode (it is left in
    eval mode afterwards). Inference runs without autograd. `test_loader` must
    yield `(inputs, labels)` batches.
    """
    model.to(device)
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_y = batch_y.to(device)
            predicted = model(batch_x.to(device)).argmax(dim=1)
            n_correct += (predicted == batch_y).sum().item()
            n_seen += batch_y.size(0)

    return n_correct / n_seen * 100

In [46]:
# Sparsity of pruned model
# (as loaded from the checkpoint, before any retraining; counts weights and biases)
print("Unstructured Pruning before Retraining")
print_model_sparsity(vgg11_pruned)

Unstructured Pruning before Retraining
Total Parameters: 129,176,036
Zero (Pruned) Parameters: 90,415,675
Sparsity: 69.99%


In [47]:
# Accuracy of the loaded pruned checkpoint on both held-out test splits.
for eval_loader, split_tag in ((test_loader, ""), (small_test_loader, " (Small)")):
    acc = evaluate(vgg11_pruned, eval_loader, device)
    print(f"Unstructured Pruning Accuracy Before Retraining{split_tag}: {acc:.2f}%")

Unstructured Pruning Accuracy Before Retraining: 72.73%
Unstructured Pruning Accuracy Before Retraining (Small): 71.44%


In [48]:
def train_epoch(model, loader, criterion, optimizer, device, masks=None):
    """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    Args:
        model: network to train; moved to `device` and put in train mode.
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over `model`'s parameters; stepped once per batch.
        device: device to run on.
        masks: optional ``{parameter_name: mask}`` dict, with names as yielded by
            ``model.named_parameters()``. After every optimizer step each named
            parameter is multiplied by its mask, so entries where the mask is
            0/False stay exactly zero. This lets a pruned model be fine-tuned
            without its pruned weights growing back. ``None`` (the default)
            trains every weight freely, as before.

    Returns:
        Sample-weighted mean loss and top-1 training accuracy (%) over the epoch.

    Raises:
        KeyError: if `masks` names a parameter that `model` does not have.
    """
    model.to(device)
    model.train()

    # Resolve masks once: bind them to the live parameters and match each
    # parameter's device/dtype so the per-step multiply is a single cheap op.
    masked_params = []
    if masks:
        params_by_name = dict(model.named_parameters())
        for name, mask in masks.items():
            param = params_by_name[name]
            masked_params.append((param, mask.to(device=param.device, dtype=param.dtype)))

    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Undo any update to pruned entries (gradient, momentum or weight decay).
        if masked_params:
            with torch.no_grad():
                for param, mask in masked_params:
                    param.mul_(mask)

        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    avg_loss = running_loss / total
    accuracy = 100. * correct / total
    return avg_loss, accuracy

In [49]:
# Fine-tuning hyper-parameters: SGD with momentum, cosine-annealed per epoch.
base_lr, weight_decay, num_epochs = 0.04, 5e-4, 5
criterion = torch.nn.CrossEntropyLoss()
optim = torch.optim.SGD(
    vgg11_pruned.parameters(), lr=base_lr, momentum=0.9, weight_decay=weight_decay
)
# Anneal from base_lr down to eta_min over num_epochs scheduler steps.
scheduler = CosineAnnealingLR(optim, T_max=num_epochs, eta_min=1e-5)

In [50]:
# Fine-tune vgg11_pruned in place, keeping a copy (and a checkpoint) of the
# epoch with the highest validation accuracy.
# NOTE: no pruning mask is re-applied here, so pruned weights are free to
# become non-zero again during training.
best_val = 0
best_retrained = copy.deepcopy(vgg11_pruned)  # fallback if no epoch scores above 0
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(vgg11_pruned, train_loader, criterion, optim, device)
    scheduler.step()  # one cosine step per epoch
    val_acc = evaluate(vgg11_pruned, val_loader, device)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")
    if val_acc <= best_val:
        continue
    best_val = val_acc
    best_retrained = copy.deepcopy(vgg11_pruned)
    torch.save(vgg11_pruned.state_dict(), "vgg11_cifar100_pruned_unstructured_finetuned.pt")

Epoch 1/5 | Loss: 1.3131 | Train Acc: 62.96% | Val Acc: 57.08%
Epoch 2/5 | Loss: 1.1208 | Train Acc: 68.10% | Val Acc: 62.36%
Epoch 3/5 | Loss: 0.7556 | Train Acc: 77.70% | Val Acc: 67.04%
Epoch 4/5 | Loss: 0.4037 | Train Acc: 87.75% | Val Acc: 69.44%
Epoch 5/5 | Loss: 0.2013 | Train Acc: 93.74% | Val Acc: 72.14%


In [51]:
# Sparsity of the best fine-tuned model. The fine-tuning loop above calls
# train_epoch without any pruning mask, so pruned weights may have become non-zero.
print("Unstructured Pruning after Retraining")
print_model_sparsity(best_retrained)

Unstructured Pruning after Retraining
Total Parameters: 129,176,036
Zero (Pruned) Parameters: 110,576
Sparsity: 0.09%


In [52]:
# Accuracy of the best fine-tuned model on both held-out test splits.
for eval_loader, split_tag in ((test_loader, ""), (small_test_loader, " (Small)")):
    acc = evaluate(best_retrained, eval_loader, device)
    print(f"Unstructured Pruning Accuracy After Retraining{split_tag}: {acc:.2f}%")

Unstructured Pruning Accuracy After Retraining: 72.64%
Unstructured Pruning Accuracy After Retraining (Small): 72.32%


In [53]:
# For simplicity, use the same pruning settings as applied previously:
# recompute per-layer prune fractions for the fine-tuned model from the stored
# layer-sensitivity sweep.

# Equals the small-test accuracy printed before retraining above; the sweep
# accuracies are normalised by it. NOTE(review): presumably the baseline the
# sensitivity sweep was measured against — confirm, and prefer storing it
# alongside sensitivity_dict over hard-coding.
base_acc = 71.44

# pickle.load can execute arbitrary code: only load this trusted, locally produced file.
with open('models/sensitivity_dict_unstructured.pkl', 'rb') as f:
    sensitivity_dict = pickle.load(f)

prune_amts = [10, 20, 30, 40, 50, 60, 70, 80, 90]
desired_prune = 0.7
max_layer_prune = 0.9
min_layer_prune = 0.1  # don't prune below 10% to avoid imbalance

# 1) gather prunable layers (Conv2d and Linear) in depth order with their WEIGHT
#    counts. Only `module.weight` is pruned (prune_layer_unstructured_L2), so
#    biases are excluded: the removal budget and the per-layer fractions must
#    refer to the same tensors for the overall target to be met.
layer_param_count = {}
conv_order = []  # all prunable layers (Conv2d and Linear), in depth order
for name, module in best_retrained.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        layer_param_count[name] = module.weight.numel()
        conv_order.append(name)

total_params = sum(layer_param_count.values())
total_remove_target = int(desired_prune * total_params)

# 2) robustness R_l = mean sweep accuracy / base_acc (lower R = more sensitive).
#    Layers with no sweep data get a neutral R of 1.0.
R = {}
for name in layer_param_count:
    accs = [a for a in (sensitivity_dict.get((name, p)) for p in prune_amts) if a is not None]
    if not accs or base_acc == 0:
        R[name] = 1.0
    else:
        R[name] = (sum(accs) / len(accs)) / base_acc

print("Robustness:")
print(R)

# 3) sensitivity S = 1 / R, with R clamped away from zero
S = {n: 1.0 / max(R[n], 1e-6) for n in R}

# 4) depth weighting: linear ramp from 1.0 (first layer) approaching 2.0 (last)
depth_weights = {name: 1.0 + i / len(conv_order) for i, name in enumerate(conv_order)}

# 5) pruning score = depth / S. Higher S (more sensitive) -> lower score ->
#    LESS pruning; deeper layer -> higher score -> more pruning.
score = {n: depth_weights[n] / S[n] for n in layer_param_count}

# 6) split the global removal budget in proportion to score * layer size
denom = sum(score[n] * layer_param_count[n] for n in score)
remove_alloc = {
    n: total_remove_target * (score[n] * layer_param_count[n]) / denom
    for n in score
}

# 7) per-layer fraction, clamped to [min_layer_prune, max_layer_prune]
prune_pct = {
    n: min(max(remove_alloc[n] / layer_param_count[n], min_layer_prune), max_layer_prune)
    for n in remove_alloc
}

print("\nAdjusted Prune Percentages:")
for k, v in prune_pct.items():
    print(f"{k}: {v*100:.2f}%")

# Clamping in step 7 can move the total away from desired_prune; make that visible.
planned_prune = sum(prune_pct[n] * layer_param_count[n] for n in prune_pct) / total_params
print(f"Planned overall weight sparsity: {planned_prune*100:.2f}% (target {desired_prune*100:.2f}%)")

Robustness:
{'features.0': 0.8597113350752769, 'features.3': 1.016548463356974, 'features.6': 1.0189747418190866, 'features.8': 1.0093940525071545, 'features.11': 1.009269627970636, 'features.13': 1.0070921985815602, 'features.16': 0.9934677118327734, 'features.18': 0.995458504417071, 'classifier.0': 1.0212143834764218, 'classifier.3': 1.0217742938907552, 'classifier.6': 1.022769690182904}

Adjusted Prune Percentages:
features.0: 34.25%
features.3: 44.18%
features.6: 47.97%
features.8: 51.18%
features.11: 54.83%
features.13: 58.35%
features.16: 61.16%
features.18: 64.89%
classifier.0: 70.27%
classifier.3: 74.01%
classifier.6: 77.78%


In [54]:
def prune_layer_unstructured_L2(module, amount=0.3):
    """Zero the `amount` fraction of `module.weight` entries with the smallest magnitude.

    Unstructured magnitude pruning, applied in place. For scalar weights the L2
    norm is just the absolute value, so entries are ranked by ``|w|``.

    Args:
        module: layer to prune; anything other than Conv2d/Linear is left untouched.
        amount: fraction of weight entries to zero; clamped to [0, 1].

    Returns:
        Number of entries zeroed by this call's selection, i.e. ``int(amount * numel)``
        (0 if the module is not prunable or that count rounds down to 0). Exactly
        that many entries are selected, even when several weights tie at the
        cut-off magnitude.
    """
    # Only prune Conv2d or Linear layers
    if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        return 0

    # Out-of-range amounts previously crashed topk (k < 0 or k > numel).
    amount = min(max(float(amount), 0.0), 1.0)

    with torch.no_grad():
        weight = module.weight
        num_params = weight.numel()
        num_prune = int(amount * num_params)

        if num_prune == 0:
            return 0

        # Select the num_prune smallest |w| by index. A threshold test
        # ("|w| <= k-th smallest") would also zero every weight tied at the
        # cut-off and so prune more than requested.
        magnitudes = weight.detach().abs().reshape(-1)
        prune_idx = torch.topk(magnitudes, num_prune, largest=False).indices

        # Build the mask in a fresh contiguous tensor and apply it with
        # masked_fill_, which also works if the weight is non-contiguous.
        prune_mask = torch.zeros(num_params, dtype=torch.bool, device=weight.device)
        prune_mask[prune_idx] = True
        weight.masked_fill_(prune_mask.view_as(weight), 0)

    return num_prune

In [55]:
# 5) apply pruning (use a copy of the model) so best_retrained itself is untouched
vgg_pruned_finetuned = copy.deepcopy(best_retrained)

for layer_name, layer in vgg_pruned_finetuned.named_modules():
    if layer_name not in prune_pct:
        continue
    fraction = prune_pct[layer_name]
    if fraction <= 0 or not isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear)):
        continue
    try:
        prune_layer_unstructured_L2(layer, amount=fraction)
    except Exception as err:  # report and carry on with the remaining layers
        print(f"Cannot prune layer {layer_name}: {err}")


In [56]:
# Sparsity after re-applying the per-layer magnitude pruning to the fine-tuned copy.
print("Unstructured Pruning after Finetuning and Pruning Again")
print_model_sparsity(vgg_pruned_finetuned)

Unstructured Pruning after Finetuning and Pruning Again
Total Parameters: 129,176,036
Zero (Pruned) Parameters: 90,415,678
Sparsity: 69.99%


In [57]:
# Accuracy of the fine-tuned-then-re-pruned model on both held-out test splits.
for eval_loader, split_tag in ((test_loader, ""), (small_test_loader, " (Small)")):
    acc = evaluate(vgg_pruned_finetuned, eval_loader, device)
    print(f"Unstructured Pruning after Finetuning and Pruning Again{split_tag}: {acc:.2f}%")

Unstructured Pruning after Finetuning and Pruning Again: 72.28%
Unstructured Pruning after Finetuning and Pruning Again (Small): 71.72%


In [58]:
# Persist the re-pruned weights (state_dict only) to the current working directory.
torch.save(vgg_pruned_finetuned.state_dict(), "vgg11_cifar100_pruned_unstructured_finetuned_pruned.pt")