# Initialization and data loading

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.utils.prune as prune

from torchinfo import summary
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Define the transform
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

# Load CIFAR100 test dataset
test_dataset = datasets.CIFAR100('data_src', train=False, download=True, transform=transform)

Files already downloaded and verified


# Make the dataset smaller so it easier for demo 10 for each class

In [2]:
# Create an empty list to store the selected samples
selected_samples = []

# Create a dictionary to keep track of how many samples you have collected per class
samples_per_class = {class_idx: 0 for class_idx in range(100)}

# Iterate through the dataset and select 10 samples from each class
for i in range(len(test_dataset)):
    image, target = test_dataset[i]
    class_idx = target
    if samples_per_class[class_idx] < 10:
        selected_samples.append((image, target))
        samples_per_class[class_idx] += 1

# Create a custom dataset from the selected samples
class CustomCIFAR100Dataset(torch.utils.data.Dataset):
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

# Create the final test dataset with 10 samples from each class
final_test_dataset = CustomCIFAR100Dataset(selected_samples)

# Create a DataLoader for the final test dataset
test_dataloader = DataLoader(final_test_dataset, batch_size=32, shuffle=False, drop_last=True)


# 50% prune

In [3]:

# Load the pre-trained EfficientNet B0 model
modified_EfficientNetB0 = models.efficientnet_b0(pretrained=True)


# Pruning function to prune 50% of connections in the given module
def prune_module(module, amount=0.50):
    for name, module in module.named_modules():
        # Prune only Conv2d layers
        if isinstance(module, nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=amount)

# Apply pruning to the entire model
prune_module(modified_EfficientNetB0)

# Modify the classifier for CIFAR-100 (which has 100 classes)
in_features = modified_EfficientNetB0.classifier[1].in_features

# Replace the classifier
modified_EfficientNetB0.classifier = nn.Sequential(
    nn.Linear(in_features, 100)  # CIFAR-100 has 100 classes
)
modified_model = modified_EfficientNetB0.to(device)

# Load the trained model
model_path = "prune50_train.pth"
# Define the device

modified_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
modified_model.eval()
print(f"Loaded model from {model_path}")

print('Number of samples:',len(test_dataset))

model_input_size = (32, 3, 224, 224)  # Example input size (batch_size, channels, height, width)

# Print the summary
model_summary = summary(modified_model, input_size=model_input_size)
print(model_summary)

# CIFAR100 class names
class_names = test_dataset.classes

# Initialize arrays to keep track of correct predictions and total samples for each class
num_classes = len(class_names)
class_correct = np.zeros(num_classes)
class_total = np.zeros(num_classes)

# Test the model
with torch.no_grad():
    for data, labels in tqdm(test_dataloader, desc="Testing Progress"):
        data, labels = data.to(device), labels.to(device)
        outputs = modified_model(data)
        _, predicted = torch.max(outputs.data, 1)
        
        correct = (predicted == labels)
        for label, is_correct in zip(labels, correct):
            class_correct[label] += is_correct.item()
            class_total[label] += 1

# Calculate and print accuracy for each class
class_accuracies = 100 * class_correct / class_total
# Create a list of tuples (class_name, accuracy) and sort it by accuracy
prune50_acc = [(class_names[i], class_accuracies[i]) for i in range(num_classes)]
prune50_acc.sort(key=lambda x: x[1])





Loaded model from prune50_train.pth
Number of samples: 10000
Layer (type:depth-idx)                                  Output Shape              Param #
EfficientNet                                            [32, 100]                 --
├─Sequential: 1-1                                       [32, 1280, 7, 7]          --
│    └─Conv2dNormActivation: 2-1                        [32, 32, 112, 112]        --
│    │    └─Conv2d: 3-1                                 [32, 32, 112, 112]        432
│    │    └─BatchNorm2d: 3-2                            [32, 32, 112, 112]        64
│    │    └─SiLU: 3-3                                   [32, 32, 112, 112]        --
│    └─Sequential: 2-2                                  [32, 16, 112, 112]        --
│    │    └─MBConv: 3-4                                 [32, 16, 112, 112]        792
│    └─Sequential: 2-3                                  [32, 24, 56, 56]          --
│    │    └─MBConv: 3-5                                 [32, 24, 56, 56]          

Testing Progress: 100%|██████████| 31/31 [00:20<00:00,  1.53it/s]


# 0% prune

In [4]:
# Load the pre-trained EfficientNet B0 model
modified_EfficientNetB0 = models.efficientnet_b0(pretrained=True)

# Modify the classifier for CIFAR-100 (which has 100 classes)
in_features = modified_EfficientNetB0.classifier[1].in_features

# Replace the classifier
modified_EfficientNetB0.classifier = nn.Sequential(
    nn.Linear(in_features, 100)  # CIFAR-100 has 100 classes
)
modified_model = modified_EfficientNetB0.to(device)

# Load the trained model
model_path = "prune0_train.pth"
# Define the device

modified_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
modified_model.eval()
print(f"Loaded model from {model_path}")

print('Number of samples:',len(test_dataset))

model_input_size = (32, 3, 224, 224)  # Example input size (batch_size, channels, height, width)

# Print the summary
model_summary = summary(modified_model, input_size=model_input_size)
print(model_summary)

# CIFAR100 class names
class_names = test_dataset.classes

# Initialize arrays to keep track of correct predictions and total samples for each class
num_classes = len(class_names)
class_correct = np.zeros(num_classes)
class_total = np.zeros(num_classes)

# Test the model
with torch.no_grad():
    for data, labels in tqdm(test_dataloader, desc="Testing Progress"):
        data, labels = data.to(device), labels.to(device)
        outputs = modified_model(data)
        _, predicted = torch.max(outputs.data, 1)
        
        correct = (predicted == labels)
        for label, is_correct in zip(labels, correct):
            class_correct[label] += is_correct.item()
            class_total[label] += 1

# Calculate and print accuracy for each class
class_accuracies = 100 * class_correct / class_total
# Create a list of tuples (class_name, accuracy) and sort it by accuracy
prune0_acc = [(class_names[i], class_accuracies[i]) for i in range(num_classes)]
prune0_acc.sort(key=lambda x: x[1])



Loaded model from prune0_train.pth
Number of samples: 10000
Layer (type:depth-idx)                                  Output Shape              Param #
EfficientNet                                            [32, 100]                 --
├─Sequential: 1-1                                       [32, 1280, 7, 7]          --
│    └─Conv2dNormActivation: 2-1                        [32, 32, 112, 112]        --
│    │    └─Conv2d: 3-1                                 [32, 32, 112, 112]        864
│    │    └─BatchNorm2d: 3-2                            [32, 32, 112, 112]        64
│    │    └─SiLU: 3-3                                   [32, 32, 112, 112]        --
│    └─Sequential: 2-2                                  [32, 16, 112, 112]        --
│    │    └─MBConv: 3-4                                 [32, 16, 112, 112]        1,448
│    └─Sequential: 2-3                                  [32, 24, 56, 56]          --
│    │    └─MBConv: 3-5                                 [32, 24, 56, 56]         

Testing Progress: 100%|██████████| 31/31 [00:20<00:00,  1.53it/s]


# 15% prune

In [5]:
# Load the pre-trained EfficientNet B0 model
modified_EfficientNetB0 = models.efficientnet_b0(pretrained=True)


# Pruning function to prune 50% of connections in the given module
def prune_module(module, amount=0.15):
    for name, module in module.named_modules():
        # Prune only Conv2d layers
        if isinstance(module, nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=amount)

# Apply pruning to the entire model
prune_module(modified_EfficientNetB0)

# Modify the classifier for CIFAR-100 (which has 100 classes)
in_features = modified_EfficientNetB0.classifier[1].in_features

# Replace the classifier
modified_EfficientNetB0.classifier = nn.Sequential(
    nn.Linear(in_features, 100)  # CIFAR-100 has 100 classes
)
modified_model = modified_EfficientNetB0.to(device)

# Load the trained model
model_path = "prune15_train.pth"
# Define the device

modified_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
modified_model.eval()
print(f"Loaded model from {model_path}")

print('Number of samples:',len(test_dataset))

model_input_size = (32, 3, 224, 224)  # Example input size (batch_size, channels, height, width)

# Print the summary
model_summary = summary(modified_model, input_size=model_input_size)
print(model_summary)

# CIFAR100 class names
class_names = test_dataset.classes

# Initialize arrays to keep track of correct predictions and total samples for each class
num_classes = len(class_names)
class_correct = np.zeros(num_classes)
class_total = np.zeros(num_classes)

# Test the model
with torch.no_grad():
    for data, labels in tqdm(test_dataloader, desc="Testing Progress"):
        data, labels = data.to(device), labels.to(device)
        outputs = modified_model(data)
        _, predicted = torch.max(outputs.data, 1)
        
        correct = (predicted == labels)
        for label, is_correct in zip(labels, correct):
            class_correct[label] += is_correct.item()
            class_total[label] += 1

# Calculate and print accuracy for each class
class_accuracies = 100 * class_correct / class_total
# Create a list of tuples (class_name, accuracy) and sort it by accuracy
prune15_acc = [(class_names[i], class_accuracies[i]) for i in range(num_classes)]
prune15_acc.sort(key=lambda x: x[1])



Loaded model from prune15_train.pth
Number of samples: 10000
Layer (type:depth-idx)                                  Output Shape              Param #
EfficientNet                                            [32, 100]                 --
├─Sequential: 1-1                                       [32, 1280, 7, 7]          --
│    └─Conv2dNormActivation: 2-1                        [32, 32, 112, 112]        --
│    │    └─Conv2d: 3-1                                 [32, 32, 112, 112]        734
│    │    └─BatchNorm2d: 3-2                            [32, 32, 112, 112]        64
│    │    └─SiLU: 3-3                                   [32, 32, 112, 112]        --
│    └─Sequential: 2-2                                  [32, 16, 112, 112]        --
│    │    └─MBConv: 3-4                                 [32, 16, 112, 112]        1,252
│    └─Sequential: 2-3                                  [32, 24, 56, 56]          --
│    │    └─MBConv: 3-5                                 [32, 24, 56, 56]        

Testing Progress: 100%|██████████| 31/31 [00:20<00:00,  1.50it/s]


# Final Result

In [6]:
# Function to print a table with class names and accuracies
def print_accuracy_table(title, data):
    print(title)
    print("+------------+------------+")
    print("| Class Name | Accuracy % |")
    print("+------------+------------+")
    for class_name, accuracy in data:
        print(f"| {class_name:10s} | {accuracy:10.2f}% |")
    print("+------------+------------+\n")

print_accuracy_table("Top 10 Most Accurate Classes Prune50:", prune50_acc[-10:][::-1])
print_accuracy_table("Top 10 Most Accurate Classes Prune15:", prune15_acc[-10:][::-1])
print_accuracy_table("Top 10 Most Accurate Classes Prune0:", prune0_acc[-10:][::-1])



print_accuracy_table("Top 10 Most Least Classes Prune50:", prune50_acc[:10][::-1])
print_accuracy_table("Top 10 Most Least Classes Prune15:", prune15_acc[:10][::-1])
print_accuracy_table("Top 10 Most Least Classes Prune0:", prune0_acc[:10][::-1])



Top 10 Most Accurate Classes Prune50:
+------------+------------+
| Class Name | Accuracy % |
+------------+------------+
| wardrobe   |     100.00% |
| tractor    |     100.00% |
| tank       |     100.00% |
| sunflower  |     100.00% |
| spider     |     100.00% |
| sea        |     100.00% |
| rose       |     100.00% |
| road       |     100.00% |
| poppy      |     100.00% |
| plate      |     100.00% |
+------------+------------+

Top 10 Most Accurate Classes Prune15:
+------------+------------+
| Class Name | Accuracy % |
+------------+------------+
| wolf       |     100.00% |
| wardrobe   |     100.00% |
| train      |     100.00% |
| tank       |     100.00% |
| sunflower  |     100.00% |
| sea        |     100.00% |
| rose       |     100.00% |
| road       |     100.00% |
| plate      |     100.00% |
| plain      |     100.00% |
+------------+------------+

Top 10 Most Accurate Classes Prune0:
+------------+------------+
| Class Name | Accuracy % |
+------------+-----------