# EfficientNet Experiments

### Load packages (and IPython magics), define data transforms, and download datasets

In [None]:
%load_ext autoreload
import timm
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models.resnet import resnet34
from torch import nn
import torch.optim as optim
from torch import nn
import copy
import pruning_funcs
import numpy as np
import matplotlib.pyplot as plt

# Standard ImageNet per-channel statistics, matching the normalisation used for
# ImageNet-pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# CIFAR-10 images are 32x32; upsample then center-crop to 224x224 so inputs
# match the resolution the pretrained EfficientNet expects.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Load CIFAR-10 (both splits share root, download flag and transform).
cifar10_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.CIFAR10(train=True, **cifar10_kwargs)
test_dataset = datasets.CIFAR10(train=False, **cifar10_kwargs)

# Data loaders: shuffle only the training split; evaluation order is fixed.
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

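### Load Pre-trained EfficientNet Weights
`timm` loads ImageNet-pretrained EfficientNet-B0 weights. Because `num_classes=10` differs from the checkpoint's 1000 ImageNet classes, the classifier head is freshly initialized. It must be trained, or restored from saved weights, before the model is useful on CIFAR-10.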

In [2]:
# Load pre-trained EfficientNet-B0 from timm with a 10-class head for CIFAR-10.
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)

# Pick the best available accelerator instead of hard-coding "mps": moving a
# model to an MPS device raises on machines without Apple-silicon GPU support.
# On a Mac with MPS this selects exactly the same device as before.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

### Define Loss Function and Optimizer

In [3]:
# Loss and optimizer for the (optional) fine-tuning loop in the Train/Load cell.
LEARNING_RATE = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

### Train / Load Weights
This cell either fine-tunes the model on the CIFAR-10 training set or loads previously fine-tuned weights from the `efficientnet_state` state dict.
Use exactly one of the two paths, not both.

In [None]:
# Choose ONE path: fine-tune the model (True) or load saved weights (False).
TRAIN_MODEL = False
NUM_EPOCHS = 10
STATE_DICT_PATH = "efficientnet_state"

if TRAIN_MODEL:
    for epoch in range(NUM_EPOCHS):
        model.train()
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Report the mean loss over the epoch, not just the final batch's loss.
        print(f'Epoch {epoch+1}, Loss: {running_loss / max(num_batches, 1)}')
    # Persist the fine-tuned weights so the load path below can reuse them.
    torch.save(model.state_dict(), STATE_DICT_PATH)
else:
    # map_location remaps tensors saved on another device (e.g. MPS) onto the
    # currently selected device, so the checkpoint loads on any machine.
    state_dict = torch.load(STATE_DICT_PATH, map_location=device)
    model.load_state_dict(state_dict)

### Evaluate Accuracy

In [5]:
def evaluate_model(testing_model, dataloader, suppress_output=False, device=None):
    """Compute top-1 classification accuracy of a model over a dataloader.

    The model is switched to eval mode (and left there) and run without
    gradient tracking.

    Args:
        testing_model: module whose output has class scores along dim 1,
            i.e. shape (batch, num_classes).
        dataloader: iterable yielding ``(images, labels)`` batches.
        suppress_output: if True, do not print the accuracy.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model), so the
            function does not depend on a notebook-global ``device``.

    Returns:
        Accuracy as a percentage in [0, 100].

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        first_param = next(testing_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    testing_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = testing_model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of failing with ZeroDivisionError.
    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute accuracy")

    accuracy = 100 * correct / total
    if not suppress_output:
        print(f'Accuracy of the model on the test images: {accuracy}%')
    return accuracy

# Pruning Methods

### Normalized Laplace Pruning (Algorithm 3 in Paper)

In [None]:
# Normalized Laplace

%autoreload now

tests = np.arange(0.008, 0.081, 0.008)
num_trials = 30

print()
print('Avg. accuracy at scale values for Normalized Laplace Distribution')

print()
print('Unpruned Model')

accuracy = evaluate_model(model, test_loader, suppress_output=True)
percent_zero = pruning_funcs.percent_zero_weights(model)
print(f'Accuracy: {accuracy:.2f}%\tPercent Zero: {percent_zero:.2f}%')

print()
print("Pruned Models (Normalized Laplace)")
for prune_scale in tests:
    accuracy = 0.
    percent_zeros = 0.
    for i in range(num_trials):
        pruned_model = copy.deepcopy(model)
        pruning_funcs.normalized_laplace_prune(pruned_model, device, scale=prune_scale)
        accuracy += evaluate_model(pruned_model, test_loader, suppress_output=True)
        percent_zeros += pruning_funcs.percent_zero_weights(pruned_model)
    accuracy /= num_trials
    percent_zeros /= num_trials
    print(f'Scale: {prune_scale:.3f}\tAvg. Accuracy: {accuracy:.2f}%\tAvg. Percent Zero: {percent_zeros:.2f}%')

### Percent-by-Layer Pruning (Algorithm 1 in Paper)

In [7]:
# Percent by Layer

%autoreload now

percents = np.arange(5,51,5)
print()
print("Pruned Models (Standard Percent Pruning By Layer)")
for percent in percents:
    pruned_model = copy.deepcopy(model)
    pruning_funcs.percent_prune_by_layer(pruned_model, device, percent=percent)
    accuracy = evaluate_model(pruned_model, test_loader, suppress_output=True)
    percent_zero = pruning_funcs.percent_zero_weights(pruned_model)
    print(f'Theoretic percent pruned: {percent}%\tActual percent pruned: {percent_zero:.2f}%\tAccuracy: {accuracy:.2f}%')


Pruned Models (Standard Percent Pruning By Layer)
Theoretic percent pruned: 5%	Actual percent pruned: 5.00%	Accuracy: 82.65%
Theoretic percent pruned: 10%	Actual percent pruned: 10.00%	Accuracy: 26.52%
Theoretic percent pruned: 15%	Actual percent pruned: 15.00%	Accuracy: 10.16%
Theoretic percent pruned: 20%	Actual percent pruned: 20.00%	Accuracy: 10.00%
Theoretic percent pruned: 25%	Actual percent pruned: 25.00%	Accuracy: 10.00%
Theoretic percent pruned: 30%	Actual percent pruned: 30.00%	Accuracy: 10.00%
Theoretic percent pruned: 35%	Actual percent pruned: 35.00%	Accuracy: 10.00%
Theoretic percent pruned: 40%	Actual percent pruned: 40.00%	Accuracy: 10.00%
Theoretic percent pruned: 45%	Actual percent pruned: 45.00%	Accuracy: 10.00%
Theoretic percent pruned: 50%	Actual percent pruned: 50.00%	Accuracy: 10.00%


### Baseline Bottom-Percent Pruning (Algorithm 4 in Paper)

In [8]:
# Bottom Percent
%autoreload now
percents = np.arange(5,51,5)
print()
print("Pruned Models (Standard Bottom Percent Pruning)")
for percent in percents:
    pruned_model = copy.deepcopy(model)
    pruning_funcs.bottom_percent_prune(pruned_model, device, percent=percent)
    accuracy = evaluate_model(pruned_model, test_loader, suppress_output=True)
    percent_zero = pruning_funcs.percent_zero_weights(pruned_model)
    print(f'Theoretic percent pruned: {percent}%\tActual percent pruned: {percent_zero:.2f}%\tAccuracy: {accuracy:.2f}%')



Pruned Models (Standard Bottom Percent Pruning)
Theoretic percent pruned: 5%	Actual percent pruned: 5.00%	Accuracy: 92.98%
Theoretic percent pruned: 10%	Actual percent pruned: 10.00%	Accuracy: 92.89%
Theoretic percent pruned: 15%	Actual percent pruned: 15.00%	Accuracy: 92.89%
Theoretic percent pruned: 20%	Actual percent pruned: 20.00%	Accuracy: 92.80%
Theoretic percent pruned: 25%	Actual percent pruned: 25.00%	Accuracy: 92.80%
Theoretic percent pruned: 30%	Actual percent pruned: 30.00%	Accuracy: 92.22%
Theoretic percent pruned: 35%	Actual percent pruned: 35.00%	Accuracy: 92.04%
Theoretic percent pruned: 40%	Actual percent pruned: 40.00%	Accuracy: 90.31%
Theoretic percent pruned: 45%	Actual percent pruned: 45.00%	Accuracy: 88.48%
Theoretic percent pruned: 50%	Actual percent pruned: 50.00%	Accuracy: 79.47%


In [None]:
### Stochastic Percent Pruning (Algorithm 2 in Paper)

# Bernoulli Pruning
%autoreload now

percent_prune = []
percent_prune_with_bernoulli = []
percents = np.arange(1,10,1)

for percent in percents:
    pruned_model = copy.deepcopy(model)
    pruning_funcs.bottom_percent_prune(pruned_model, device, percent=percent)
    percent_prune.append(evaluate_model(pruned_model, testloader, suppress_output=True))
    
    ppwb_runs = []
    for i in range(30):
        pruned_model_2 = copy.deepcopy(model)
        pruning_funcs.percent_prune_with_bernoulli(pruned_model_2, device, percent=(2*percent), p_success=0.5)
        ppwb_runs.append(evaluate_model(pruned_model_2, testloader, suppress_output=True))

    percent_prune_with_bernoulli.append(np.mean(ppwb_runs))

fig, ax = plt.subplots()
ax.set_title("Percent Pruning")
line1, = ax.plot(percents, percent_prune, color='blue')
line2, = ax.plot(percents, percent_prune_with_bernoulli, color='orange')
line1.set_label("Percent Pruning")
line2.set_label("Percent Pruning With Bernoulli (avg)")
ax.legend()
ax.set_xlabel("Percent pruned (or expected pruned)")
ax.set_ylabel("Model accuracy")
ax.set_xticks(percents)