# Some import function and intialization. These include the download of the CIFAR100 dataset

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.utils.prune as prune

from torchinfo import summary
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Define the transform
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

# Load CIFAR100 test dataset
test_dataset = datasets.CIFAR100('data_src', train=False, download=True, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, drop_last=True)
print('Number of samples:',len(test_dataset))

Files already downloaded and verified
Number of samples: 10000


# User input to determine the pruning amount. The model was based on EfficientNetB0 and has 0% ,15% or 50% pruning done on the conv2d layer

In [15]:
pruning_amount=int(input('Please type in the pruning amount:0, 15 or 50:').strip())

Please type in the pruning amount:0, 15 or 50: 50


# Since the code is based on efficientnet's pruning, we have to reload efficientnet model, pruning from scratch first before we can load in the pretrain model as the pruning introduce new masking parameter

In [16]:
# Load the pre-trained EfficientNet B0 model
modified_EfficientNetB0 = models.efficientnet_b0(pretrained=True)


# Pruning function to prune 50% of connections in the given module
def prune_module(module, amount=pruning_amount/100):
    for name, module in module.named_modules():
        # Prune only Conv2d layers
        if isinstance(module, nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=amount)
            
if pruning_amount!=0:
    # Apply pruning to the entire model
    prune_module(modified_EfficientNetB0)

# Modify the classifier for CIFAR-100 (which has 100 classes)
in_features = modified_EfficientNetB0.classifier[1].in_features

# Replace the classifier
modified_EfficientNetB0.classifier = nn.Sequential(
    nn.Linear(in_features, 100)  # CIFAR-100 has 100 classes
)
modified_model = modified_EfficientNetB0.to(device)



# Loading the pretrain model and create the summary to look at the number of trainable parameter, might need to modify the map_location to ulilized GPU acceleration since this code was designed for CPU based system

In [17]:
# Load the trained model
model_path = "prune"+str(pruning_amount)+"_train.pth"
# Define the device

modified_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
modified_model.eval()
print(f"Loaded model from {model_path}")


model_input_size = (32, 3, 224, 224)  # Example input size (batch_size, channels, height, width)

# Print the summary
model_summary = summary(modified_model, input_size=model_input_size)
print(model_summary)




Loaded model from prune50_train.pth
Layer (type:depth-idx)                                  Output Shape              Param #
EfficientNet                                            [32, 100]                 --
├─Sequential: 1-1                                       [32, 1280, 7, 7]          --
│    └─Conv2dNormActivation: 2-1                        [32, 32, 112, 112]        --
│    │    └─Conv2d: 3-1                                 [32, 32, 112, 112]        432
│    │    └─BatchNorm2d: 3-2                            [32, 32, 112, 112]        64
│    │    └─SiLU: 3-3                                   [32, 32, 112, 112]        --
│    └─Sequential: 2-2                                  [32, 16, 112, 112]        --
│    │    └─MBConv: 3-4                                 [32, 16, 112, 112]        792
│    └─Sequential: 2-3                                  [32, 24, 56, 56]          --
│    │    └─MBConv: 3-5                                 [32, 24, 56, 56]          3,268
│    │    └─MBConv:

# Training loop

In [18]:
# CIFAR100 class names
class_names = test_dataset.classes

# Initialize arrays to keep track of correct predictions and total samples for each class
num_classes = len(class_names)
class_correct = np.zeros(num_classes)
class_total = np.zeros(num_classes)

# Test the model
with torch.no_grad():
    for data, labels in tqdm(test_dataloader, desc="Testing Progress"):
        data, labels = data.to(device), labels.to(device)
        outputs = modified_model(data)
        _, predicted = torch.max(outputs.data, 1)
        
        correct = (predicted == labels)
        for label, is_correct in zip(labels, correct):
            class_correct[label] += is_correct.item()
            class_total[label] += 1

# Calculate and print accuracy for each class
class_accuracies = 100 * class_correct / class_total
# Create a list of tuples (class_name, accuracy) and sort it by accuracy
acc_list = [(class_names[i], class_accuracies[i]) for i in range(num_classes)]
acc_list.sort(key=lambda x: x[1])


Testing Progress: 100%|██████████| 312/312 [03:46<00:00,  1.38it/s]


# Result for the most and lease accurate classes

In [19]:

# Top 10 most accurate classes
print("Top 10 Most Accurate Classes:")
for class_name, accuracy in acc_list[-10:][::-1]:
    print(f'{class_name}: {accuracy:.2f}%')

print("\nTop 10 Least Accurate Classes:")
# Top 10 least accurate classes
for class_name, accuracy in acc_list[:10]:
    print(f'{class_name}: {accuracy:.2f}%')


Top 10 Most Accurate Classes:
apple: 98.00%
wardrobe: 97.00%
motorcycle: 97.00%
road: 96.00%
chimpanzee: 96.00%
skunk: 95.96%
tractor: 95.00%
sunflower: 95.00%
orange: 94.00%
tank: 93.00%

Top 10 Least Accurate Classes:
girl: 49.00%
boy: 53.00%
seal: 60.00%
pine_tree: 61.00%
shrew: 61.00%
otter: 64.00%
bowl: 67.00%
willow_tree: 67.68%
possum: 68.00%
mouse: 68.69%
