In [3]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import timm
import torch.nn as nn
# Preprocessing applied to every CIFAR-10 image before it reaches the network:
# upscale so the shorter side is 256 px, keep the central 224x224 window,
# convert to a float tensor, then normalise each channel with the commonly
# used ImageNet channel means / standard deviations.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Load CIFAR-10 (both splits share the storage root, download flag and transform)
cifar_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.CIFAR10(train=True, **cifar_kwargs)
test_dataset = datasets.CIFAR10(train=False, **cifar_kwargs)

# Data loaders: only the training split is shuffled; the test split keeps a
# fixed order.
loader_kwargs = dict(batch_size=32)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Load pre-trained EfficientNet-B0 from timm, requesting a 10-class output
# (one class per CIFAR-10 label).
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)

# Pick the compute device. Apple's Metal backend ("mps") is tried first, as in
# the original notebook, but the cell falls back to CUDA and then CPU so the
# notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Move the model's parameters and buffers onto the selected device.
model = model.to(device)

model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

In [5]:
# Fine-tuning setup: Adam over every model parameter (learning rate 1e-3),
# with cross-entropy as the classification loss.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [16]:
# Training loop (disabled; kept for reference).
# NOTE(review): presumably this is how the "efficientnet_state" checkpoint
# loaded below was produced - confirm. As written it reports only the loss of
# the last batch in each epoch, not an epoch average.
# for epoch in range(10):  # number of epochs
#     model.train()
#     for images, labels in train_loader:
#         images, labels = images.to(device), labels.to(device)
#         optimizer.zero_grad()
#         outputs = model(images)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()
#     print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# Restore the fine-tuned weights from disk instead of retraining.
# - map_location=device remaps the stored tensors onto the active device, so a
#   checkpoint written from one device type can be loaded on another.
# - weights_only=True limits unpickling to tensors and plain containers, which
#   is all a state dict needs and avoids executing arbitrary pickled code.
state_dict = torch.load("efficientnet_state", map_location=device, weights_only=True)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [17]:
def evaluate_model(testing_model, dataloader, suppress_output=False, eval_device=None):
    """Compute top-1 classification accuracy of a model over a dataloader.

    The model is switched to evaluation mode (and left in it) and run under
    ``torch.no_grad()``. For each batch the predicted class is the index of
    the largest score along dimension 1 of the model output.

    Args:
        testing_model: ``torch.nn.Module`` to evaluate.
        dataloader: iterable yielding ``(images, labels)`` pairs of tensors.
        suppress_output: when ``False`` (default) the accuracy is also printed.
        eval_device: device the batches are moved to. When ``None`` (default)
            the device of the model's first parameter is used, so the function
            does not depend on a notebook-level ``device`` variable; a model
            without parameters falls back to the CPU.

    Returns:
        Accuracy as a percentage in the range [0, 100] (float).

    Raises:
        ValueError: if the dataloader yields no samples, since accuracy is
            undefined in that case.
    """
    if eval_device is None:
        first_param = next(testing_model.parameters(), None)
        eval_device = first_param.device if first_param is not None else torch.device("cpu")

    testing_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(eval_device), labels.to(eval_device)
            outputs = testing_model(images)
            # Index of the highest score along the class dimension.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("evaluate_model received an empty dataloader; accuracy is undefined")

    accuracy = 100 * correct / total
    if not suppress_output:
        print(f'Accuracy of the model on the test images: {accuracy}%')
    return accuracy

In [14]:
%load_ext autoreload
%autoreload now
import pruning_funcs
import copy
import numpy as np
tests = np.arange(0.001, 0.05, 0.005)
num_trials = 5

print()
print('Avg. accuracy at scale values for Normalized Laplace Distribution')

print()
print('Unpruned Model')


accuracy = evaluate_model(model, test_loader, suppress_output=True)
percent_zero = pruning_funcs.percent_zero_weights(model)
print(f'Accuracy: {accuracy:.2f}%\tPercent Zero: {percent_zero:.2f}%')

print()
print("Pruned Models (Normalized Laplace)")
for prune_scale in tests:
    accuracy = 0.
    percent_zeros = 0.
    for i in range(num_trials):
        pruned_model = copy.deepcopy(model)
        pruning_funcs.normalized_laplace_prune(pruned_model, device, scale=prune_scale)
        accuracy += evaluate_model(pruned_model, test_loader, suppress_output=True)
        percent_zeros += pruning_funcs.percent_zero_weights(pruned_model)
    accuracy /= num_trials
    percent_zeros /= num_trials
    print(f'Scale: {prune_scale:.3f}\tAvg. Accuracy: {accuracy:.2f}%\tAvg. Percent Zero: {percent_zeros:.2f}%')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

Avg. accuracy at scale values for Normalized Laplace Distribution

Unpruned Model
Accuracy: 93.01%	Percent Zero: 0.00%

Pruned Models (Normalized Laplace)
Scale: 0.001	Avg. Accuracy: 92.99%	Avg. Percent Zero: 0.68%
Scale: 0.006	Avg. Accuracy: 92.95%	Avg. Percent Zero: 4.07%
Scale: 0.011	Avg. Accuracy: 92.84%	Avg. Percent Zero: 7.41%
Scale: 0.016	Avg. Accuracy: 92.91%	Avg. Percent Zero: 10.66%
Scale: 0.021	Avg. Accuracy: 92.68%	Avg. Percent Zero: 13.80%
Scale: 0.026	Avg. Accuracy: 92.26%	Avg. Percent Zero: 16.82%
Scale: 0.031	Avg. Accuracy: 91.79%	Avg. Percent Zero: 19.69%
Scale: 0.036	Avg. Accuracy: 90.49%	Avg. Percent Zero: 22.43%
Scale: 0.041	Avg. Accuracy: 89.08%	Avg. Percent Zero: 25.03%
Scale: 0.046	Avg. Accuracy: 87.60%	Avg. Percent Zero: 27.49%


In [18]:
# Baseline sweep: ask percent_prune for a fixed percentage (1 through 5) and
# compare the requested amount with the measured share of zero weights and the
# resulting test accuracy.
percents = np.arange(1, 6)
print("Pruned Models (Standard Percent Pruning)")
for percent in percents:
    # Prune a deep copy so the reference `model` is left untouched.
    pruned_model = copy.deepcopy(model)
    pruning_funcs.percent_prune(pruned_model, device, percent=percent)
    accuracy = evaluate_model(pruned_model, test_loader, suppress_output=True)
    percent_zero = pruning_funcs.percent_zero_weights(pruned_model)
    report = (
        f'Theoretic percent pruned: {percent}%\t'
        f'Actual percent pruned: {percent_zero:.2f}%\t'
        f'Accuracy: {accuracy:.2f}%'
    )
    print(report)

Pruned Models (Standard Percent Pruning)
Theoretic percent pruned: 1%	Actual percent pruned: 1.00%	Accuracy: 93.12%
Theoretic percent pruned: 2%	Actual percent pruned: 2.00%	Accuracy: 92.73%


KeyboardInterrupt: 