### **Import libraries**

In [None]:
import numpy as np
import random
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
import torch
# Repeatability settings for cuDNN, following
# https://pytorch.org/docs/stable/notes/randomness.html#cuda-convolution-benchmarking
torch.backends.cudnn.benchmark = False       # no benchmark-based algorithm selection: picking the fastest
                                             # algorithm can be non-deterministic and hurt repeatability
torch.backends.cudnn.deterministic = True    # force cuDNN to use deterministic algorithms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torch.nn.utils.prune as prune

from torch.utils.data import random_split, DataLoader

from torchvision import datasets, transforms, models

from tqdm.notebook import tqdm

import copy
import types

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

In [None]:
# Device used for all computations: the GPU when CUDA is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Control random number generation for repeatability: seed every RNG the notebook relies on
# (torch, Python's `random`, and NumPy's legacy global RNG, which was previously left unseeded).
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

## **Configuration**

In [None]:
# Switches that activate (1) or deactivate (0) whole parts of the notebook
do_mnist_data_analysis = 0   # 1: show the exploratory analysis of the MNIST data
do_cifar_data_analysis = 0   # 1: show the exploratory analysis of the CIFAR data
run_LeNet_mnist = 1          # 1: run LeNet on the MNIST data
run_ResNet_cifar = 0         # 1: run ResNet on the CIFAR data

# Training hyperparameters for MNIST
epochs_mnist = 30            # number of training epochs
lr_mnist = 0.01              # optimizer learning rate

# Training hyperparameters for CIFAR
epochs_cifar = 150           # number of training epochs
lr_cifar = 0.1               # optimizer learning rate

# SNIP hyperparameters: sparsity levels to evaluate
sparsities_mnist = [0, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 0.995, 0.999]  # LeNet5 on MNIST
sparsities_cifar = [0, 0.25, 0.5, 0.75]                                 # ResNet18 on CIFAR-10

### Values indicated in the assignment:


1. Hyperparameters for training on MNIST
   * epochs_mnist = 30  
   * lr_mnist = 0.01    
   * sparsities_mnist = $[0, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 0.995, 0.999]$


2. Hyperparameters for training on CIFAR
   * epochs_cifar = 150
   * lr_cifar = 0.1
   * sparsities_cifar = $[0, 0.25, 0.5, 0.75]$



# **Datasets**

## MNIST

In [None]:
if run_LeNet_mnist == 1:
  # MNIST with ToTensor only (pixels scaled to [0, 1], no normalization).
  # The official training split is later divided into train/validation, hence "trainval".
  transform = transforms.ToTensor()
  trainval_mnist = datasets.MNIST(
      root='./data',
      train=True,
      download=True,
      transform=transform,
  )
  test_mnist = datasets.MNIST(
      root='./data',
      train=False,
      download=True,
      transform=transform,
  )

### Dataset Analysis

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Number of images in each split
  for split_name, split in (("Training + Validation", trainval_mnist), ("Test", test_mnist)):
      print(f"{split_name} Set Size: {len(split)} images")

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Sanity check: does either split contain NaN pixels? (one boolean tensor printed per split)
  for split in (trainval_mnist, test_mnist):
      print(torch.isnan(split.data).any())

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Show training sample #1 with its 0-255 intensity written on top of every pixel
  image, label = trainval_mnist[1]
  pixels = image[0]                      # single-channel image -> (height, width) plane

  plt.figure(figsize=(10, 10))
  plt.imshow(image.squeeze(), cmap='gray')

  height, width = image.shape[1], image.shape[2]
  for i in range(height):
      for j in range(width):
          # ToTensor scaled the pixel to [0, 1]; scale it back before printing
          pixel_value = pixels[i, j].item() * 255
          plt.text(j, i, f'{pixel_value:.0f}', ha='center', va='center', color='red', fontsize=8)

  plt.axis('off')
  plt.show()

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Collect the spatial size (everything after the channel dimension) of every image
  # in both splits and report the distinct values
  trainval_resolutions = [sample[0].shape[1:] for sample in trainval_mnist]
  test_resolutions = [sample[0].shape[1:] for sample in test_mnist]
  resolutions = [*trainval_resolutions, *test_resolutions]
  unique_resolutions = set(resolutions)
  print("Unique resolutions:", unique_resolutions)

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Grid of the first 10 test images of each digit: one column per class, one image per row.
  # The class name is the title of each column's top cell, which also holds that class's first image.
  n_per_class = 10
  fig, axes = plt.subplots(nrows=n_per_class, ncols=10, figsize=(8, 8))
  for ax in axes.flat:
      ax.axis('off')  # hide every axis up front so any cell left without an image stays blank

  # label -> list of up to `n_per_class` images
  class_images = {cls: [] for cls in range(10)}  # 10 labels in MNIST
  for img, label in test_mnist:
      if len(class_images[label]) < n_per_class:
          class_images[label].append(img)
      # stop as soon as every class has enough images
      if all(len(images) == n_per_class for images in class_images.values()):
          break

  for i, (cls, images) in enumerate(class_images.items()):
      axes[0, i].set_title("class: " + str(cls), fontsize=12)
      # Images occupy rows 0..n_per_class-1 of column i (row 0 included)
      for j, img in enumerate(images):
          axes[j, i].imshow(img.numpy().squeeze(), cmap="gray")  # drop the channel dim for a grayscale plot

  plt.tight_layout()
  plt.show()

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Frequency of every raw pixel intensity (0-255) over the Train+Val MNIST images.
  # The stored uint8 pixels are read directly instead of iterating the dataset and
  # rescaling ToTensor output: `(x * 255).astype(int)` truncates values such as
  # 127.9999 down to 127, and collecting every pixel in a Python list is very slow
  # and memory hungry.
  from matplotlib.ticker import ScalarFormatter

  all_pixels = trainval_mnist.data.numpy().ravel()   # torchvision keeps MNIST pixels as a uint8 tensor
  pixel_values, pixel_counts = np.unique(all_pixels, return_counts=True)

  fig, axs = plt.subplots(1, 2, figsize=(20, 6))
  titles = ('Frequency of Pixel Values in Train+Val MNIST Dataset',
            'Frequency of Pixel Values in Train+Val MNIST Dataset (Zoomed)')
  for ax, title in zip(axs, titles):
      ax.bar(pixel_values, pixel_counts, color='royalblue', width=3)
      ax.set_xlabel('Pixel Values', fontsize=14, fontweight='bold')
      ax.set_ylabel('Frequency', fontsize=14, fontweight='bold')
      ax.set_title(title, fontsize=16, fontweight='bold')
      ax.grid(True)
      ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
      ax.ticklabel_format(style='plain', axis='y')

  # Zoomed panel: cap the y-axis so the less frequent intensities are visible
  axs[1].set_ylim(0, 2000000)

  plt.tight_layout()
  plt.show()


In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Per class, histogram of the mean pixel intensity of each Train+Val image
  from collections import defaultdict

  # label -> mean intensity of every image with that label (ToTensor scale)
  class_avg_pixel_values = defaultdict(list)
  for image, label in trainval_mnist:
      class_avg_pixel_values[label].append(image.numpy().mean())

  # lists -> numpy arrays, so they can be rescaled in one operation below
  class_avg_pixel_values.update(
      {class_label: np.array(values) for class_label, values in class_avg_pixel_values.items()}
  )

  # 2 x 5 grid: class k goes to row k // 5, column k % 5
  fig, axs = plt.subplots(2, 5, figsize=(20, 10))
  for class_label, avg_values in class_avg_pixel_values.items():
      ax = axs[divmod(class_label, 5)]
      ax.hist(avg_values * 255, bins=30, color='royalblue', alpha=0.7, width=1)
      ax.set_title(f'Class {class_label}')
      ax.set_xlabel('Average Pixel Value')
      ax.set_ylabel('Frequency')
      ax.grid(True)

  plt.tight_layout()
  plt.show()



In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  def select_images_from_class(dataset, desired_class, num_images):
      """Randomly pick `num_images` samples labelled `desired_class` from `dataset`.

      Args:
          dataset: iterable of (image, label) pairs.
          desired_class: label to keep.
          num_images: how many samples to draw, without replacement.

      Returns:
          List of (index_in_dataset, image, label) tuples in random order.

      Raises:
          ValueError: if the class has fewer than `num_images` samples
              (raised by `random.sample`).
      """
      class_images = [(idx, image, label)
                      for idx, (image, label) in enumerate(dataset)
                      if label == desired_class]
      # Python's `random` is seeded at the top of the notebook, so the draw is repeatable
      return random.sample(class_images, num_images)


In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Show 10 random Train+Val images of one digit, each titled with its dataset index
  desired_class = 4
  selected_images = select_images_from_class(trainval_mnist, desired_class, num_images=10)

  fig, axs = plt.subplots(2, 5, figsize=(10, 8))
  # axs.flat walks the grid row by row, matching the 2 x 5 layout
  for ax, (image_id, image, label) in zip(axs.flat, selected_images):
      ax.imshow(image.squeeze(), cmap='gray')
      ax.set_title(f'ID: {image_id}')
      ax.axis('off')

  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.show()

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Split Train+Val into training (90%) and validation (10%) before counting classes per split.
  # NOTE(review): this split draws from torch's global RNG, so it is not guaranteed to be
  # the same partition as the one built in the data-loader cell below.
  val_ratio_mnist = 0.1
  val_size_mnist = int(val_ratio_mnist * len(trainval_mnist))
  train_size_mnist = len(trainval_mnist) - val_size_mnist
  train_mnist, val_mnist = data.random_split(trainval_mnist, [train_size_mnist, val_size_mnist])

  for split_name, split in (("Training", train_mnist), ("Validation", val_mnist), ("Test", test_mnist)):
      print(f"{split_name} Set Size: {len(split)} images")

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  def compute_class_frequencies(dataset, num_classes=10):
      """Count how many samples of each class `dataset` contains.

      Args:
          dataset: iterable of (image, label) pairs with integer labels.
          num_classes: number of classes (10 for MNIST).

      Returns:
          Integer numpy array of length `num_classes`; entry k is the number
          of samples labelled k.

      Raises:
          ValueError: if a label lies outside [0, num_classes).
      """
      labels = np.fromiter((label for _, label in dataset), dtype=int)
      # explicit check: a negative label would otherwise be miscounted or crash obscurely
      if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
          raise ValueError(f"labels must lie in [0, {num_classes})")
      return np.bincount(labels, minlength=num_classes)

  # Class frequency in each split
  train_counts = compute_class_frequencies(train_mnist)
  val_counts = compute_class_frequencies(val_mnist)
  test_counts = compute_class_frequencies(test_mnist)

  # One bar chart per split (1 row x 3 columns), with each count written above its bar
  fig, axes = plt.subplots(1, 3, figsize=(14, 6))
  panels = ((train_counts, 'royalblue', 'Train'),
            (val_counts, 'orange', 'Validation'),
            (test_counts, 'seagreen', 'Test'))
  for ax, (counts, color, split_name) in zip(axes, panels):
      bars = ax.bar(range(10), counts, tick_label=range(10), color=color)
      ax.set_xlabel('Class')
      ax.set_ylabel('Frequency')
      ax.set_title(f'Class Frequency in MNIST {split_name} Dataset')
      for bar in bars:
          height = bar.get_height()
          ax.text(bar.get_x() + bar.get_width() / 2, height, f'{height}', ha='center', va='bottom')

  plt.tight_layout()
  plt.show()

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Imbalance check: each class count divided by the count of the most frequent class
  # of the same split (1.0 = most frequent class)
  highest_frequency_train = np.max(train_counts)
  highest_frequency_val = np.max(val_counts)
  highest_frequency_test = np.max(test_counts)

  ratios_train = train_counts / highest_frequency_train
  ratios_val = val_counts / highest_frequency_val
  ratios_test = test_counts / highest_frequency_test

  ratios_tot = dict(Class=range(10), train=ratios_train, val=ratios_val, test=ratios_test)
  df = pd.DataFrame(ratios_tot)
  print(df)


In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Compare the per-class ratios (vs. the most frequent class) of the three splits,
  # writing the value above every point
  plt.figure(figsize=(10, 6))
  for ratios, split_name in ((ratios_train, 'Train'), (ratios_val, 'Val'), (ratios_test, 'Test')):
      plt.plot(ratios, label=split_name, linestyle='--', marker='o')
      for i, value in enumerate(ratios):
          plt.text(i, value, f'{value:.2f}', ha='center', va='bottom')

  plt.xlabel('Class')
  plt.ylabel('Ratio vs higher frequency class in each dataset')
  plt.title('Class ratio in Train | Val | Test Datasets')
  plt.xticks(range(len(ratios_train)))
  plt.legend()
  plt.grid(True)
  plt.show()

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Class proportions: each class count divided by the total number of samples of the same split
  total_samples_train, total_samples_val, total_samples_test = (
      np.sum(counts) for counts in (train_counts, val_counts, test_counts)
  )
  ratios_train_total = train_counts / total_samples_train
  ratios_val_total = val_counts / total_samples_val
  ratios_test_total = test_counts / total_samples_test

  ratios_total = dict(Class=range(10), train=ratios_train_total,
                      val=ratios_val_total, test=ratios_test_total)
  df = pd.DataFrame(ratios_total)
  print(df)

  # Sanity check of the split sizes
  for split_name, total in (('Train', total_samples_train),
                            ('Validation', total_samples_val),
                            ('Test', total_samples_test)):
      print(f'Total samples in {split_name} Dataset: {total}')

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Compare the class proportions (vs. split size) of train, validation and test,
  # writing the value above every point
  plt.figure(figsize=(10, 6))
  for ratios, split_name in ((ratios_train_total, 'Train set'),
                             (ratios_val_total, 'Validation set'),
                             (ratios_test_total, 'Test set')):
      plt.plot(ratios, label=split_name, linestyle='--', marker='o')
      for i, value in enumerate(ratios):
          plt.text(i, value, f'{value:.3f}', ha='center', va='bottom')

  plt.xlabel('Class')
  plt.ylabel('Ratio vs total #samples in each dataset')
  plt.title('Class ratio in Train | Val | Test Datasets')
  plt.xticks(range(len(ratios_train_total)))
  plt.legend()
  plt.grid(True)
  plt.show()

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Per-channel mean / std of the pixel values: first on the training split only, then on
  # Train+Val (the variables keep the Train+Val values, as in the loader cell below)
  for split, suffix in ((train_mnist, "Train"), (trainval_mnist, "Train+Val")):
      imgs_mnist = torch.stack([img for img, _ in split])
      mean_mnist = imgs_mnist.mean(dim=(0, 2, 3))
      std_mnist = imgs_mnist.std(dim=(0, 2, 3))
      print(f"Mean {suffix}: {mean_mnist}")
      print(f"Std dev {suffix}: {std_mnist}")

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Class histograms for Train+Val and Test.
  # `compute_class_frequencies` is defined in the earlier class-frequency cell, which runs
  # under exactly the same flags, so it is not redefined here.
  trainval_counts = compute_class_frequencies(trainval_mnist)
  test_counts = compute_class_frequencies(test_mnist)

  # 1 row x 2 columns, each count written above its bar
  fig, axes = plt.subplots(1, 2, figsize=(14, 6))
  panels = ((trainval_counts, 'royalblue', 'Class Frequency in MNIST Train+Val Dataset'),
            (test_counts, 'seagreen', 'Class Frequency in MNIST Test Dataset'))
  for ax, (counts, color, title) in zip(axes, panels):
      bars = ax.bar(range(10), counts, tick_label=range(10), color=color)
      ax.set_xlabel('Class')
      ax.set_ylabel('Frequency')
      ax.set_title(title)
      for bar in bars:
          height = bar.get_height()
          ax.text(bar.get_x() + bar.get_width() / 2, height, f'{height}', ha='center', va='bottom')

  plt.tight_layout()
  plt.show()

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Train+Val: each class count relative to the most frequent class (1.00 = most frequent)
  highest_frequency_trainval = np.max(trainval_counts)
  ratios_trainval = trainval_counts / highest_frequency_trainval
  print("\n".join(f'Class {i}: {ratio:.2f}' for i, ratio in enumerate(ratios_trainval)))

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Test: each class count relative to the most frequent class (1.00 = most frequent)
  highest_frequency_test = np.max(test_counts)
  ratios_test = test_counts / highest_frequency_test
  print("\n".join(f'Class {i}: {ratio:.2f}' for i, ratio in enumerate(ratios_test)))

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Overlay the Train+Val and Test class ratios (vs. most frequent class), labelling every point
  plt.figure(figsize=(10, 6))
  plt.plot(ratios_trainval, label='Train + Val', linestyle='--', marker='o')
  plt.plot(ratios_test, label='Test', linestyle='--', marker='o', color='seagreen')
  plt.xlabel('Class')
  plt.ylabel('Ratio')
  plt.title('Class ratio in Train+Val vs Test Datasets')
  plt.xticks(range(len(ratios_trainval)))
  plt.legend()
  plt.grid(True)

  for series in (ratios_trainval, ratios_test):
      for i, value in enumerate(series):
          plt.text(i, value, f'{value:.2f}', ha='center', va='bottom')
  plt.show()

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Train+Val class proportions: each class count over the total number of Train+Val samples
  total_samples_trainval = np.sum(trainval_counts)
  ratios_trainval_total = trainval_counts / total_samples_trainval
  print("\n".join(f'Class {i}: {ratio:.3f}' for i, ratio in enumerate(ratios_trainval_total)))

  # sanity check of the split size
  print(f'Total samples in Train+Val Dataset: {total_samples_trainval}')

In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Test class proportions: each class count over the total number of Test samples
  total_samples_test = np.sum(test_counts)
  ratios_test_total = test_counts / total_samples_test
  print("\n".join(f'Class {i}: {ratio:.3f}' for i, ratio in enumerate(ratios_test_total)))

  # sanity check of the split size
  print(f'Total samples in Test Dataset: {total_samples_test}')


In [None]:
if run_LeNet_mnist == 1 and do_mnist_data_analysis == 1:
  # Overlay the Train+Val and Test class proportions (vs. split size), labelling every point
  plt.figure(figsize=(10, 6))
  plt.plot(ratios_trainval_total, label='Train + Val', linestyle='--', marker='o')
  plt.plot(ratios_test_total, label='Test', linestyle='--', marker='o', color='seagreen')
  plt.xlabel('Class')
  plt.ylabel('Ratio')
  plt.title('Class ratio in Train+Val vs Test Datasets')
  plt.xticks(range(len(ratios_trainval_total)))
  plt.legend()
  plt.grid(True)

  for series in (ratios_trainval_total, ratios_test_total):
      for i, value in enumerate(series):
          plt.text(i, value, f'{value:.3f}', ha='center', va='bottom')

  plt.show()

### Data Loader

In [None]:
if run_LeNet_mnist == 1:
  # Per-channel mean / std of the Train+Val pixels (on the ToTensor [0, 1] scale), used by
  # the Normalize transform in the next cell.
  # Computed from the raw stored pixels rather than by iterating the dataset: this is much
  # faster and keeps the cell idempotent. Once the next cell reloads `trainval_mnist` with
  # Normalize, iterating it would return already-normalized images and corrupt these stats.
  # NOTE(review): the statistics include the validation images; computing them on the
  # training split only would keep validation data out of preprocessing.
  imgs_mnist = trainval_mnist.data.unsqueeze(1).float().div(255)   # (N, 1, H, W), same values ToTensor yields
  mean_mnist = imgs_mnist.mean(dim=(0, 2, 3))
  std_mnist = imgs_mnist.std(dim=(0, 2, 3))
  print(f"Mean: {mean_mnist}")
  print(f"Std: {std_mnist}")

In [None]:
if run_LeNet_mnist == 1:
  transform_mnist = transforms.Compose([
      transforms.ToTensor(),                                  # PIL image -> float tensor, [0, 255] scaled to [0.0, 1.0]
      transforms.Normalize(mean=mean_mnist, std=std_mnist),   # standardize with the Train+Val statistics computed above
  ])

  # Reload MNIST with normalization. Same root as the raw copy loaded earlier, so the files
  # already downloaded there are reused (download=True skips files that are already present,
  # see https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html).
  trainval_mnist = datasets.MNIST(root='./data', train=True, download=True,    # trainval = train + validation
                                  transform=transform_mnist)
  test_mnist = datasets.MNIST(root='./data', train=False, download=True,
                              transform=transform_mnist)

  # Hold out 10% of the official training set as validation set.
  # A dedicated seeded generator makes the partition independent of how many random
  # numbers earlier (optional) cells consumed from torch's global RNG.
  val_ratio_mnist = 0.1
  val_size_mnist = int(val_ratio_mnist * len(trainval_mnist))
  train_mnist, val_mnist = random_split(
      trainval_mnist,
      [len(trainval_mnist) - val_size_mnist, val_size_mnist],
      generator=torch.Generator().manual_seed(42),
  )

  # Reproducible data loading, see
  # https://pytorch.org/docs/stable/notes/randomness.html#dataloader
  num_workers = 0  # DataLoader default: load in the main process

  def seed_worker(worker_id):
      """Seed NumPy and Python's `random` in a DataLoader worker from the worker's torch seed."""
      worker_seed = torch.initial_seed() % 2**32
      np.random.seed(worker_seed)   # numpy is imported as `np`; `numpy.random` raised NameError
      random.seed(worker_seed)

  g = torch.Generator()
  g.manual_seed(0)

  # The training loader reshuffles every epoch (reproducibly, through `g`);
  # the evaluation loaders keep a fixed order.
  train_mnist_loader = data.DataLoader(train_mnist, batch_size=100, shuffle=True,
                                       num_workers=num_workers, worker_init_fn=seed_worker, generator=g)
  val_mnist_loader = data.DataLoader(val_mnist, batch_size=1000, shuffle=False,
                                     num_workers=num_workers, worker_init_fn=seed_worker, generator=g)
  test_mnist_loader = data.DataLoader(test_mnist, batch_size=1000, shuffle=False,
                                      num_workers=num_workers, worker_init_fn=seed_worker, generator=g)


## CIFAR10

In [None]:
if run_ResNet_cifar == 1:
  # Human-readable names of the 10 CIFAR-10 classes, indexed by label
  classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

  # CIFAR-10 with ToTensor only (pixels scaled to [0, 1], no normalization)
  transform = transforms.ToTensor()
  trainval_cifar = datasets.CIFAR10(
      root='./data',
      train=True,
      download=True,
      transform=transform,
  )
  test_cifar = datasets.CIFAR10(
      root='./data',
      train=False,
      download=True,
      transform=transform,
  )

### Dataset Analysis

In [None]:
if run_ResNet_cifar == 1 and do_cifar_data_analysis == 1:
  # Number of images in each split
  for split_name, split in (("Training + Validation", trainval_cifar), ("Test", test_cifar)):
      print(f"{split_name} Set Size: {len(split)} images")

In [None]:
if run_ResNet_cifar == 1 and do_cifar_data_analysis == 1:
  # Sanity check: does either split contain NaN pixels? (one boolean printed per split)
  for split in (trainval_cifar, test_cifar):
      print(np.isnan(split.data).any())

In [None]:
if run_ResNet_cifar == 1 and do_cifar_data_analysis == 1:
  # Distinct spatial sizes (everything after the channel dimension) across Train+Val and Test.
  # Each split is scanned exactly once; previously the test images were appended to
  # the Train+Val list and then added again.
  trainval_resolutions = [img.shape[1:] for img, _ in trainval_cifar]
  test_resolutions = [img.shape[1:] for img, _ in test_cifar]
  resolutions = trainval_resolutions + test_resolutions
  unique_resolutions = set(resolutions)
  print("Unique resolutions:", unique_resolutions)

In [None]:
if run_ResNet_cifar == 1 and do_cifar_data_analysis == 1:
  # Grid of the first 10 test images of each class: one column per class, one image per row.
  # The class name is the title of each column's top cell, which also holds that class's first image.
  n_per_class = 10
  fig, axes = plt.subplots(nrows=n_per_class, ncols=10, figsize=(12, 12))
  for ax in axes.flat:
      ax.axis('off')  # hide every axis up front so any cell left without an image stays blank

  # label -> list of up to `n_per_class` images
  class_images = {cls: [] for cls in range(10)}  # 10 classes in CIFAR-10
  for img, label in test_cifar:
      if len(class_images[label]) < n_per_class:
          class_images[label].append(img)
      # stop as soon as every class has enough images
      if all(len(images) == n_per_class for images in class_images.values()):
          break

  for i, (cls, images) in enumerate(class_images.items()):
      axes[0, i].set_title(classes[cls], fontsize=12)
      # Images occupy rows 0..n_per_class-1 of column i (row 0 included)
      for j, img in enumerate(images):
          # CHW tensor -> HWC array, the channel order imshow expects
          axes[j, i].imshow(np.transpose(img.numpy(), (1, 2, 0)))

  plt.tight_layout()
  plt.show()


In [None]:
if run_ResNet_cifar == 1 and do_cifar_data_analysis == 1:
  # Per-channel frequency of every raw pixel intensity (0-255) over the Train+Val CIFAR-10 images.
  # The stored uint8 pixels are read directly instead of rescaling ToTensor output:
  #  * `(x * 255).astype(int)` truncates values such as 127.9999 down to 127;
  #  * np.unique only counts values that occur, which misaligns with the fixed 0..255 x-axis
  #    if any intensity is missing -- np.bincount(minlength=256) always yields 256 counts;
  #  * collecting every pixel in Python lists is very slow and memory hungry.
  raw_pixels = trainval_cifar.data   # torchvision keeps CIFAR-10 as an (N, H, W, 3) uint8 RGB array
  all_pixels_trainval = {channel: raw_pixels[..., c].ravel()
                         for c, channel in enumerate(("R", "G", "B"))}

  fig, axs = plt.subplots(1, 3, figsize=(20, 6))
  colors = {"R": "red", "G": "green", "B": "blue"}

  for i, (channel, color) in enumerate(colors.items()):
      pixel_counts_trainval = np.bincount(all_pixels_trainval[channel], minlength=256)
      axs[i].bar(np.arange(256), pixel_counts_trainval, color=color, width=3)
      axs[i].set_xlabel("Pixel Values", fontsize=14, fontweight="bold")
      axs[i].set_ylabel("Frequency", fontsize=14, fontweight="bold")
      axs[i].set_title(f"{channel}-channel Pixel Values in Train+Val Set", fontsize=16, fontweight="bold")
      axs[i].grid(True)
      axs[i].ticklabel_format(style="plain", axis="y")

  plt.tight_layout()
  plt.show()

In [None]:
if run_ResNet_cifar == 1 and do_cifar_data_analysis == 1:
  from collections import defaultdict

  # channel -> {label -> per-image mean intensity on the 0-255 scale}
  class_avg_pixel_values = {"R": defaultdict(list), "G": defaultdict(list), "B": defaultdict(list)}

  # Walk the Train+Val set and record each image's mean value per channel
  for image, label in trainval_cifar:
      image = image.numpy() * 255  # tensor -> numpy array, rescaled from [0, 1] to 0-255
      for c, channel in enumerate(("R", "G", "B")):
          class_avg_pixel_values[channel][label].append(image[c].mean())

  # Turn every list into a numpy array (existing keys only, so mutating while iterating is safe)
  for per_class in class_avg_pixel_values.values():
      for class_label, values in per_class.items():
          per_class[class_label] = np.array(values)

  # Tall figure: one row per class, one column per channel, shared axes for easy comparison
  fig, axs = plt.subplots(10, 3, figsize=(12, 25), sharex=True, sharey=True)
  colors = {"R": "red", "G": "green", "B": "blue"}

  for class_label in range(10):
      for i, (channel, color) in enumerate(colors.items()):
          ax = axs[class_label, i]
          ax.hist(class_avg_pixel_values[channel][class_label], bins=50, color=color)
          ax.set_title(f'Class {classes[class_label]} - {channel} channel', fontsize=10)
          ax.set_xlabel('Avg Pixel Value')
          ax.set_ylabel('Frequency')
          ax.grid(True)

  plt.tight_layout()
  plt.show()


In [None]:
if run_ResNet_cifar == 1 and do_cifar_data_analysis == 1:
  def compute_class_frequencies(dataset, num_classes=10):
      """
      Count the samples of each class label in a dataset.

      When the dataset exposes a ``targets`` attribute (torchvision's CIFAR10 does),
      the labels are read from it directly. This avoids loading and transforming
      every image just to read its label. Otherwise the function iterates over
      ``(image, label)`` pairs.

      Parameters:
          dataset: Dataset yielding ``(image, label)`` pairs with non-negative integer labels.
          num_classes (int, optional): Minimum length of the returned array (default: 10).

      Returns:
          np.ndarray: Integer counts indexed by class label.
      """
      labels = getattr(dataset, "targets", None)
      if labels is None:
          labels = [label for _, label in dataset]
      return np.bincount(np.asarray(labels, dtype=int), minlength=num_classes)

  def _plot_class_counts(ax, counts, color, title):
      """Draw a bar chart of per-class counts on ax and write each bar's value on top of it."""
      bars = ax.bar(range(len(counts)), counts, tick_label=classes, color=color)
      ax.set_xlabel('Class')
      ax.set_ylabel('Frequency')
      ax.set_title(title)
      ax.set_xticklabels(classes, rotation=45)
      for bar in bars:
          height = bar.get_height()
          ax.text(bar.get_x() + bar.get_width() / 2, height, f'{height}', ha='center', va='bottom')
      return bars

  # Per-class counts for the train+validation pool and for the test set
  trainval_counts = compute_class_frequencies(trainval_cifar)
  test_counts = compute_class_frequencies(test_cifar)

  # Side-by-side bar charts: train+validation (left), test (right)
  fig, axes = plt.subplots(1, 2, figsize=(14, 6))
  bars_trainval = _plot_class_counts(axes[0], trainval_counts, 'royalblue',
                                     'Class Frequency in CIFAR-10 Training Data')
  bars_test = _plot_class_counts(axes[1], test_counts, 'seagreen',
                                 'Class Frequency in CIFAR-10 Test Data')

  plt.tight_layout()
  plt.show()

### Data Loader

In [None]:
if run_ResNet_cifar == 1:
  # Per-channel statistics of the training images, used by transforms.Normalize below.
  # NOTE(review): this assumes trainval_cifar still holds the un-normalized ToTensor()
  # dataset. The next cell rebinds trainval_cifar to a normalized dataset, so re-running
  # this cell after it would give ~0 / ~1 statistics.
  imgs_cifar = torch.stack([img for img, _ in trainval_cifar])
  reduce_dims = (0, 2, 3)  # reduce over every axis except dim 1 (channels)
  mean_cifar, std_cifar = imgs_cifar.mean(dim=reduce_dims), imgs_cifar.std(dim=reduce_dims)
  for stat_name, stat in (("Mean", mean_cifar), ("Std", std_cifar)):
      print(f"{stat_name}: {stat}")

In [None]:
if run_ResNet_cifar == 1:
  # Evaluation-time preprocessing: ToTensor() followed by per-channel standardization
  # with the training-set statistics computed in the previous cell.
  transform_cifar = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(mean=mean_cifar, std=std_cifar),
  ])

  # Official CIFAR-10 training split (later divided into train/val) and test split
  trainval_cifar = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_cifar)
  test_cifar = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_cifar)

  # Hold out 10% of the official training split for validation
  val_ratio_cifar = 0.1
  val_size_cifar = int(val_ratio_cifar * len(trainval_cifar))

  # Use a dedicated, seeded generator so the split depends only on this seed. With the
  # global torch RNG, the split would shift with whatever earlier cells (e.g. other
  # random splits or torch.randint calls) happened to draw, depending on which notebook
  # flags were enabled.
  split_generator_cifar = torch.Generator().manual_seed(42)
  train_cifar, val_cifar = data.random_split(
      trainval_cifar,
      [len(trainval_cifar) - val_size_cifar, val_size_cifar],
      generator=split_generator_cifar,
  )

In [None]:
if run_ResNet_cifar == 1:
  # Training-only pipeline:
  #   - random 32x32 crops of the image padded by 4 pixels on each side;
  #   - random horizontal flips;
  #   - then the same ToTensor + Normalize steps used for validation/test, so every
  #     split is standardized identically.
  augment_cifar = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(mean=mean_cifar, std=std_cifar),
  ])

  # Both Subsets from random_split point at the same underlying dataset. Giving the
  # training Subset its own shallow copy lets us swap in the augmenting transform while
  # the validation Subset (still backed by trainval_cifar) stays un-augmented.
  augmented_trainval = copy.copy(trainval_cifar)
  augmented_trainval.transform = augment_cifar
  train_cifar.dataset = augmented_trainval

In [None]:
if run_ResNet_cifar == 1:
  # Report how many images ended up in each split
  for split_name, split in (("Training", train_cifar), ("Validation", val_cifar), ("Test", test_cifar)):
      print(f"{split_name} Set Size: {len(split)} images")

In [None]:
if run_ResNet_cifar == 1:
  # Training uses small batches reshuffled every epoch. Evaluation runs under
  # torch.no_grad(), so it can afford much larger batches.
  eval_batch_size = 1024
  train_cifar_loader = data.DataLoader(train_cifar, batch_size=128, shuffle=True)
  # NOTE(review): shuffling is not needed for validation; kept to preserve current behavior
  val_cifar_loader = data.DataLoader(val_cifar, batch_size=eval_batch_size, shuffle=True)
  test_cifar_loader = data.DataLoader(test_cifar, batch_size=eval_batch_size, shuffle=False)

In [None]:
if run_ResNet_cifar == 1:
  # Preview the first few (augmented) training images in a single row
  n_preview = 10
  fig, axes = plt.subplots(nrows=1, ncols=n_preview, figsize=(12, 4))

  for i in range(min(n_preview, len(train_cifar))):
      img, _ = train_cifar[i]
      # Undo Normalize (x * std + mean), broadcasting the per-channel stats over H and W
      img = img * std_cifar[:, None, None] + mean_cifar[:, None, None]
      # Floating-point round-off can push values slightly outside [0, 1]; clamp so
      # imshow does not warn and clip on its own.
      axes[i].imshow(img.clamp(0, 1).permute(1, 2, 0))  # (C, H, W) -> (H, W, C)
      axes[i].axis("off")

  plt.tight_layout()
  plt.show()


# **Architectures**

### LeNet-5

In [None]:
if run_LeNet_mnist == 1:
  # Adapted from https://github.com/pytorch/examples/blob/main/mnist/main.py

  class LeNet5(nn.Module):
      """
      Small two-conv / two-FC classifier for single-channel images and 10 classes.

      fc1 expects 9216 = 64 * 12 * 12 features. That is what a 28x28 input yields after
      two unpadded 3x3 convolutions (28 -> 26 -> 24) and one 2x2 max-pool (24 -> 12).
      The output is per-class log-probabilities (log_softmax), so it pairs with
      F.nll_loss.
      """

      def __init__(self):
          super().__init__()
          # Registration order matters: pruning masks are matched to layers in the
          # order model.modules() yields the Conv2d/Linear layers.
          self.conv1 = nn.Conv2d(1, 32, 3, 1)     # 1 -> 32 channels, 3x3 kernel, stride 1
          self.conv2 = nn.Conv2d(32, 64, 3, 1)    # 32 -> 64 channels, 3x3 kernel, stride 1
          self.dropout1 = nn.Dropout(0.25)        # applied after pooling
          self.dropout2 = nn.Dropout(0.5)         # applied between the two FC layers
          self.fc1 = nn.Linear(9216, 128)
          self.fc2 = nn.Linear(128, 10)           # one logit per class

      def forward(self, x):
          # Feature extractor: conv -> ReLU -> conv -> ReLU -> 2x2 max-pool
          features = F.max_pool2d(F.relu(self.conv2(F.relu(self.conv1(x)))), 2)
          # Dropout, then collapse everything but the batch dimension
          features = torch.flatten(self.dropout1(features), 1)
          # Classifier head: fc1 -> ReLU -> dropout -> fc2
          hidden = self.dropout2(F.relu(self.fc1(features)))
          return F.log_softmax(self.fc2(hidden), dim=1)



### ResNet-18

In [None]:
if run_ResNet_cifar == 1:
  class BasicBlock(nn.Module):
      """
      Two-convolution residual block (the ResNet "basic block").

      Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
      The block input is added back through a shortcut:
        - an identity when shapes already match;
        - a 1x1 conv + BN projection when the stride or channel count changes
          (He et al., https://doi.org/10.48550/arXiv.1512.03385, pp. 4-5).
      ReLU is applied after the sum.
      """

      def __init__(self, in_channels, out_channels, stride=1):
          super().__init__()

          self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
          self.bn1 = nn.BatchNorm2d(out_channels)
          self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
          self.bn2 = nn.BatchNorm2d(out_channels)

          needs_projection = stride != 1 or in_channels != out_channels
          if needs_projection:
              # 1x1 projection so the shortcut output matches the main path's shape
              self.shortcut = nn.Sequential(
                  nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                  nn.BatchNorm2d(out_channels),
              )
          else:
              # An empty Sequential passes its input through unchanged (identity)
              self.shortcut = nn.Sequential()

      def forward(self, x):
          residual = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
          return F.relu(residual + self.shortcut(x))



  #ref. doc [(https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html)]


In [None]:
if run_ResNet_cifar == 1:
  class ResNet18(nn.Module):
      """
      ResNet-18 adapted to 32x32 (CIFAR-sized) inputs.

      Unlike the ImageNet torchvision model, the stem is a single 3x3, stride-1
      convolution with padding 1 (no 7x7 conv, no max-pool). The four residual stages
      therefore see 32x32, 16x16, 8x8 and 4x4 feature maps for a 32x32 input.
      The output is per-class log-probabilities (log_softmax).

      Parameters:
          num_classes (int, optional): Size of the output layer (default: 10).
      """

      def __init__(self, num_classes: int = 10):
          super().__init__()
          self.in_channels = 64   # running input-channel count consumed by _make_layer
          # padding=1 keeps the input resolution, like every other 3x3 conv in the
          # network. Without it the stem shrinks 32x32 maps to 30x30 and under-weights
          # the image border.
          self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
          self.bn1 = nn.BatchNorm2d(64)

          # Four stages of two blocks each; stages 2-4 halve the resolution and double
          # the channel count
          self.layer1 = self._make_layer(BasicBlock, 64,  2, 1)
          self.layer2 = self._make_layer(BasicBlock, 128, 2, 2)
          self.layer3 = self._make_layer(BasicBlock, 256, 2, 2)
          self.layer4 = self._make_layer(BasicBlock, 512, 2, 2)

          self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
          self.fc = nn.Linear(512, num_classes)

      def _make_layer(self, block, out_channels, num_blocks, stride):
          """
          Stack num_blocks blocks into one stage.

          Only the first block uses ``stride`` (for downsampling). Also advances
          self.in_channels to out_channels.
          """
          layers = []
          for block_stride in [stride] + [1] * (num_blocks - 1):
              layers.append(block(self.in_channels, out_channels, block_stride))
              self.in_channels = out_channels
          return nn.Sequential(*layers)

      def forward(self, x):
          out = F.relu(self.bn1(self.conv1(x)))

          out = self.layer1(out)
          out = self.layer2(out)
          out = self.layer3(out)
          out = self.layer4(out)

          out = torch.flatten(self.avgpool(out), 1)   # (N, 512)
          out = self.fc(out)
          return F.log_softmax(out, dim=1)


# **Functions**

### Training and Test Functions

In [None]:
if run_LeNet_mnist == 1 or run_ResNet_cifar == 1:
  class AverageMeter:
      """
      Running tracker of a scalar metric: last value, weighted total, count and mean.

      train/valid/test use one instance per metric (loss, error) to accumulate values
      across mini-batches.
      """

      def __init__(self, name: str, fmt: str = ":f") -> None:
          """
          Create an empty meter.

          Parameters:
              name (str): Label of the tracked metric.
              fmt (str, optional): Display format spec stored for callers (default: ':f').
                  The class itself does not use it.
          """
          self.name, self.fmt = name, fmt
          self.reset()

      def reset(self) -> None:
          """Return the meter to its empty state (all statistics zero)."""
          self.val = self.avg = self.sum = 0.0
          self.count = 0

      def update(self, val: float, n: int = 1) -> None:
          """
          Record ``val`` as the mean of ``n`` new observations and refresh the average.

          Parameters:
              val (float): Value (or mean value) being added.
              n (int, optional): How many observations ``val`` stands for (default: 1).
          """
          self.val = val
          self.sum = self.sum + val * n
          self.count = self.count + n
          self.avg = self.sum / self.count

In [None]:
if run_LeNet_mnist == 1 or run_ResNet_cifar == 1:
  #This is the train function
  def train(model: nn.Module, train_loader: data.DataLoader, optimizer: torch.optim.Optimizer,
            loss: nn.Module, spars: float = 0, epoch: int = 1, device: str = "cpu") -> tuple[float, float]:
      """
      Trains the model for one epoch.

      Parameters:
          model (nn.Module): The PyTorch model to train.
          train_loader (DataLoader): DataLoader providing (inputs, target) batches.
          optimizer (torch.optim.Optimizer): Optimizer used for updating model parameters.
          loss (callable): Loss called as ``loss(output, target)`` (e.g. F.nll_loss).
              Assumed to use mean reduction — batch values are re-weighted by batch size.
          spars (float, optional): Sparsity level of the model, shown in the progress bar (default: 0).
          epoch (int, optional): Current epoch number, shown in the progress bar (default: 1).
          device (str, optional): Device to run training on ("cpu" or "cuda", default: "cpu").

      Returns:
          tuple[float, float]: Per-sample mean loss and mean classification error (in %)
          over the epoch (0.0, 0.0 if the loader is empty).
      """

      model.train()  # Set the model to training mode (enables dropout, BN batch stats)

      progress_bar = tqdm(train_loader, total=len(train_loader),
                          desc=f"TRAIN | Epoch {epoch} | Sparsity: {spars}")

      mean_train_loss = AverageMeter("train_loss")
      mean_train_error = AverageMeter("train_error")

      # `inputs` rather than `data` so the torch.utils.data module name is not shadowed
      for inputs, target in progress_bar:
          inputs, target = inputs.to(device), target.to(device)
          batch_size = target.size(0)

          # Forward pass
          output = model(inputs)
          this_loss = loss(output, target)

          # Percentage of misclassified samples in this batch
          this_error = torch.mean((output.argmax(dim=1) != target) * 100.0)

          # Weight each batch by its size. Otherwise a smaller final batch counts as
          # much as a full one and skews the epoch mean.
          mean_train_loss.update(this_loss.item(), n=batch_size)
          mean_train_error.update(this_error.item(), n=batch_size)

          # Backpropagation and optimization step
          optimizer.zero_grad()
          this_loss.backward()
          optimizer.step()

          progress_bar.set_postfix(loss=mean_train_loss.avg, error=mean_train_error.avg)

      return mean_train_loss.avg, mean_train_error.avg


In [None]:
if run_LeNet_mnist == 1 or run_ResNet_cifar == 1:

  def _evaluate(model: nn.Module, loader: data.DataLoader, loss: nn.Module,
                desc: str, meter_prefix: str, device: str = "cpu") -> tuple[float, float]:
      """
      Run the model over a loader without gradients and report average metrics.

      Shared implementation of valid() and test().

      Parameters:
          model (nn.Module): Model to evaluate; it is switched to eval mode.
          loader (DataLoader): Source of (inputs, target) batches.
          loss (callable): Loss called as ``loss(output, target)``. Assumed to use mean
              reduction — batch values are re-weighted by batch size.
          desc (str): Progress-bar description.
          meter_prefix (str): Prefix for the AverageMeter names ("val" / "test").
          device (str, optional): Device to run on (default: "cpu").

      Returns:
          tuple[float, float]: Per-sample mean loss and mean classification error (in %).
      """
      progress_bar = tqdm(loader, total=len(loader), desc=desc)

      mean_loss = AverageMeter(f"{meter_prefix}_loss")
      mean_error = AverageMeter(f"{meter_prefix}_error")

      model.eval()  # disable dropout, use BN running statistics
      with torch.no_grad():  # no graph needed for evaluation
          for inputs, target in progress_bar:
              inputs, target = inputs.to(device), target.to(device)
              batch_size = target.size(0)

              output = model(inputs)
              this_loss = loss(output, target)
              this_error = torch.mean((output.argmax(dim=1) != target) * 100.0)

              # Weight by batch size so a smaller final batch does not skew the mean
              mean_loss.update(this_loss.item(), n=batch_size)
              mean_error.update(this_error.item(), n=batch_size)

              progress_bar.set_postfix(loss=mean_loss.avg, error=mean_error.avg)

      return mean_loss.avg, mean_error.avg

  def valid(model: nn.Module, val_loader: data.DataLoader, loss: nn.Module,
            spars: float = 0, epoch: int = 1, device: str = "cpu") -> tuple[float, float]:
      """
      Runs validation on the given model.

      Parameters:
          model (nn.Module): The PyTorch model to validate.
          val_loader (DataLoader): DataLoader for validation data.
          loss (callable): Loss function used for evaluation (e.g. F.nll_loss).
          spars (float, optional): Sparsity level of the model, shown in the progress bar (default: 0).
          epoch (int, optional): Current epoch number, shown in the progress bar (default: 1).
          device (str, optional): Device to run validation on ("cpu" or "cuda", default: "cpu").

      Returns:
          tuple[float, float]: Per-sample mean loss and classification error (in %).
      """
      return _evaluate(model, val_loader, loss,
                       desc=f"VALIDATION | Epoch {epoch} | Sparsity: {spars}",
                       meter_prefix="val", device=device)

  def test(model: nn.Module, test_loader: data.DataLoader, loss: nn.Module,
          spars: float = 0, device: str = "cpu") -> tuple[float, float]:
      """
      Runs testing on the given model.

      Parameters:
          model (nn.Module): The PyTorch model to test.
          test_loader (DataLoader): DataLoader for test data.
          loss (callable): Loss function used for evaluation (e.g. F.nll_loss).
          spars (float, optional): Sparsity level of the model, shown in the progress bar (default: 0).
          device (str, optional): Device to run testing on ("cpu" or "cuda", default: "cpu").

      Returns:
          tuple[float, float]: Per-sample mean loss and classification error (in %).
      """
      return _evaluate(model, test_loader, loss,
                       desc=f"TEST | Sparsity: {spars}",
                       meter_prefix="test", device=device)


### Pruning

In [None]:
if run_LeNet_mnist == 1 or run_ResNet_cifar == 1:
  from typing import Iterable, List  # Import Iterable and List type hints


  def snip_forward_conv2d(self: nn.Conv2d, x: torch.Tensor) -> torch.Tensor:
      """
      Conv2d forward that convolves with ``weight * weight_mask``.

      Stride, padding (including string modes such as 'same'), dilation and groups are
      all passed through to F.conv2d. The result is therefore what nn.Conv2d.forward
      would compute with the masked weight.
      """
      masked_weight = self.weight * self.weight_mask
      if self.padding_mode != "zeros":
          # reflect/replicate/circular padding needs the module's own padding logic;
          # delegate to the helper nn.Conv2d.forward itself uses.
          return self._conv_forward(x, masked_weight, self.bias)
      return F.conv2d(x, masked_weight, self.bias, self.stride, self.padding,
                      self.dilation, self.groups)


  def snip_forward_linear(self: nn.Linear, x: torch.Tensor) -> torch.Tensor:
      """Linear forward that multiplies by ``weight * weight_mask``."""
      return F.linear(x, self.weight * self.weight_mask, self.bias)


  def SNIP(model: nn.Module, spars: float, train_loader: data.DataLoader,
          loss: nn.Module, device: str = "cpu") -> List[torch.Tensor]:
      """
      Select which Conv2d/Linear weights to keep with the SNIP criterion.

      Method:
        1. Give a deep copy of the model an all-ones ``weight_mask`` parameter on every
           Conv2d and Linear layer, with its forward rewritten to use
           ``weight * weight_mask``.
        2. Run one forward/backward pass on the first mini-batch of train_loader, so
           every mask entry receives a gradient.
        3. Keep the entries with the largest absolute gradient.

      The model passed in is not modified.

      Parameters:
          model (nn.Module): The model to score.
          spars (float): Fraction of prunable weights to remove, 0 <= spars < 1.
          train_loader (DataLoader): Source of the single scoring mini-batch.
          loss (callable): Loss called as ``loss(output, target)`` (e.g. F.nll_loss).
          device (str): Device to run the scoring pass on ("cpu" or "cuda").

      Returns:
          List[torch.Tensor]: One {0., 1.} mask per Conv2d/Linear layer, in
          ``model.modules()`` order, each shaped like that layer's weight.
          Exactly ``max(1, int(total_weights * (1 - spars)))`` entries are 1 overall.

      Raises:
          ValueError: If ``spars`` is not in [0, 1).
      """
      if not 0 <= spars < 1:
          raise ValueError(f"spars must be in [0, 1), got {spars}")

      # One mini-batch is enough for SNIP's single-shot saliency estimate.
      # (`inputs` so the torch.utils.data module name is not shadowed.)
      inputs, target = next(iter(train_loader))
      inputs, target = inputs.to(device), target.to(device)

      snip_model = copy.deepcopy(model).to(device)
      prunable = [layer for layer in snip_model.modules()
                  if isinstance(layer, (nn.Conv2d, nn.Linear))]
      if not prunable:
          return []

      for layer in prunable:
          # Mask of ones: the forward pass is unchanged, but each connection now has
          # its own gradient d(loss)/d(mask).
          layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))
          # Only the mask gradients are needed; freeze the original weights
          layer.weight.requires_grad = False
          masked_forward = snip_forward_conv2d if isinstance(layer, nn.Conv2d) else snip_forward_linear
          layer.forward = types.MethodType(masked_forward, layer)

      snip_model.zero_grad()
      loss_value = loss(snip_model(inputs), target)
      loss_value.backward()

      # Connection sensitivity = |gradient of the loss w.r.t. the mask|. A layer that
      # took no part in the forward pass gets no gradient and is scored as zero.
      grads_abs = []
      for layer in prunable:
          grad = layer.weight_mask.grad
          grads_abs.append(torch.zeros_like(layer.weight_mask) if grad is None else grad.abs())

      # Global ranking across all layers. The paper divides by the total score first,
      # but a positive rescaling cannot change the ranking, so it is skipped (this also
      # avoids NaNs when every gradient is zero).
      all_scores = torch.cat([g.flatten() for g in grads_abs])
      num_params_to_keep = max(1, int(all_scores.numel() * (1 - spars)))

      # Build the mask from the top-k *indices* rather than a ">= threshold" test.
      # With tied scores (e.g. many exactly-zero gradients) a threshold keeps every tied
      # weight and overshoots the requested density.
      _, keep_idx = torch.topk(all_scores, num_params_to_keep, sorted=False)
      flat_mask = torch.zeros_like(all_scores)
      flat_mask[keep_idx] = 1.0

      # Cut the flat mask back into one piece per layer, shaped like its weight
      chunks = torch.split(flat_mask, [g.numel() for g in grads_abs])
      keep_masks = [chunk.reshape(g.shape) for chunk, g in zip(chunks, grads_abs)]

      total_kept = sum(mask.sum().item() for mask in keep_masks)
      print(f"Sparsity: {spars}, Total weights kept: {int(total_kept)}")

      return keep_masks

In [None]:
if run_LeNet_mnist == 1 or run_ResNet_cifar == 1:
  def apply_prune_mask(model: nn.Module, keep_masks: Iterable[torch.Tensor]) -> None:
      """
      Apply pruning masks in place to the model's prunable layers (Conv2d and Linear).

      For each such layer, in ``model.modules()`` order:
        - weights at zero positions of its keep mask are set to zero;
        - a gradient hook multiplies the weight gradient by the mask, so pruned
          weights receive zero gradient during backpropagation.
      Biases are not touched.

      Parameters:
          model (nn.Module): The PyTorch model to prune (modified in place).
          keep_masks (iterable of torch.Tensor): Exactly one mask per prunable layer,
              each shaped like that layer's weight.

      Raises:
          ValueError: If the number of masks differs from the number of prunable
              layers, or a mask's shape differs from its layer's weight shape.
      """
      prunable_layers = [
          layer for layer in model.modules()
          if isinstance(layer, (nn.Conv2d, nn.Linear))
      ]

      # zip() would silently ignore extra layers or masks; fail loudly instead
      keep_masks = list(keep_masks)
      if len(keep_masks) != len(prunable_layers):
          raise ValueError(
              f"Expected {len(prunable_layers)} masks (one per Conv2d/Linear layer), "
              f"got {len(keep_masks)}"
          )

      for layer, keep_mask in zip(prunable_layers, keep_masks):
          if layer.weight.shape != keep_mask.shape:
              raise ValueError(
                  f"Shape mismatch: {layer.weight.shape} vs {keep_mask.shape}"
              )

          # Match the weight's device/dtype so both the zeroing and the gradient hook
          # work even if the masks were computed on another device.
          keep_mask = keep_mask.to(device=layer.weight.device, dtype=layer.weight.dtype)

          # Zero-out the pruned weights without recording the op in autograd
          with torch.no_grad():
              layer.weight[keep_mask == 0] = 0.

          # The default argument binds this layer's mask now. A plain closure would see
          # only the last loop value.
          layer.weight.register_hook(lambda grad, mask=keep_mask: grad * mask)

In [None]:
if run_LeNet_mnist == 1 or run_ResNet_cifar == 1:
  def plot_conv_filters(tensor, n_rows, RGB=False, title="Filters"):
      """
      Plot the non-zero pattern (1 = kept, 0 = pruned) of every kernel in a Conv2d weight.

      Parameters:
          tensor (torch.Tensor): Conv2d weight indexed as
              (out_channels, in_channels, kernel_h, kernel_w).
          n_rows (int): Grid rows per plotted input channel. There are
              ceil(out_channels / n_rows) columns, so every filter gets a cell even when
              out_channels is not a multiple of n_rows.
          RGB (bool, optional): If truthy, draw one grid per input channel, colored with
              Reds/Greens/Blues (at most three input channels are supported). Otherwise
              only input channel 0 is drawn, with the default colormap.
          title (str, optional): Figure title.
      """
      weights = tensor.detach().cpu().numpy()
      num_filters = weights.shape[0]  # number of output channels
      # Ceiling division, so trailing filters are not dropped
      n_cols = max(1, -(-num_filters // n_rows))

      cmaps = ["Reds", "Greens", "Blues"]
      channels = list(range(weights.shape[1])) if RGB else [0]
      grid_rows = n_rows * len(channels)

      # squeeze=False always returns a 2-D array of Axes, even for a 1 x 1 grid
      fig, axes = plt.subplots(grid_rows, n_cols, figsize=(n_cols, grid_rows), squeeze=False)
      for ax in axes.flat:
          ax.axis("off")  # also hides the padding cells after the last filter

      for block, channel in enumerate(channels):
          cmap = cmaps[channel] if RGB else None
          for f in range(num_filters):
              row, col = divmod(f, n_cols)
              binary_mask = (weights[f, channel] != 0).astype(float)
              # Fixed vmin/vmax so an all-kept or all-pruned kernel is not rescaled
              axes[row + n_rows * block, col].imshow(binary_mask, cmap=cmap, vmin=0, vmax=1)

      fig.suptitle(title, fontsize=12)
      plt.show()

In [None]:
if run_LeNet_mnist == 1 or run_ResNet_cifar == 1:
  def plot_linear_weights(tensor, title="Linear Weights", manual_figsize=(14, 2)):
      """
      Render the non-zero pattern of a Linear weight matrix as a 0/1 image.

      Rows are output neurons and columns are input features (nn.Linear stores its
      weight as (out_features, in_features)). Pruned (zero) entries appear as 0.

      Parameters:
          tensor (torch.Tensor): 2-D Linear weight.
          title (str, optional): Figure title.
          manual_figsize (tuple, optional): Matplotlib figure size in inches.
      """
      weights = tensor.cpu().detach().numpy()
      _n_out, _n_in = weights.shape  # unpacking doubles as a check that the weight is 2-D

      fig, ax = plt.subplots(figsize=manual_figsize)
      # aspect='auto' stretches very wide matrices (fc1 is 128 x 9216) to fill the axes;
      # fixed vmin/vmax keep 0 = pruned and 1 = kept consistent across figures
      ax.imshow((weights != 0).astype(float), vmin=0, vmax=1, aspect='auto')
      ax.axis("off")
      ax.grid(False)

      plt.suptitle(title, fontsize=8, ha='center')
      plt.tight_layout()
      plt.show()

# **MAIN with MNIST**

### Pruning

In [None]:
if run_LeNet_mnist == 1:
  # One freshly initialized LeNet-5. Every sparsity run below starts from a deep copy
  # of these same initial weights.
  lenet = LeNet5().to(device)

  # LeNet5.forward returns log-probabilities, so negative log-likelihood is the
  # matching loss (the sparsity levels come from sparsities_mnist in the config cell)
  loss = F.nll_loss

In [None]:
if run_LeNet_mnist == 1:
  # Pruned copies of the initial `lenet`, keyed by sparsity level
  snipped_lenets = {}

  # Count only tensors named "weight" (the pruning masks cover weights, not biases)
  total_weights_lenet = sum(p.numel() for name, p in lenet.named_parameters() if "weight" in name)
  print(f"Total weights: {total_weights_lenet}")

  for spar in sparsities_mnist:
      # Every copy starts from exactly the same initialization as `lenet`
      lenet_copy = copy.deepcopy(lenet)
      if spar != 0:
          # Score connections on the unpruned `lenet`, then zero the dropped ones
          # (and block their gradients) in the copy
          keep_masks = SNIP(lenet, spar, train_mnist_loader, loss, device)
          apply_prune_mask(lenet_copy, keep_masks)
      snipped_lenets[spar] = lenet_copy

In [None]:
if run_LeNet_mnist == 1:
  # conv1 has a single input channel, so one row of its 32 kernels shows the whole mask
  for spar in snipped_lenets:
      conv1_weight = snipped_lenets[spar].conv1.weight
      plot_conv_filters(conv1_weight, n_rows=1,
                        title=f"Pruned Filters (conv1) with sparsity={spar}")

In [None]:
if run_LeNet_mnist == 1:
  # Kept/pruned pattern of fc1 (128 x 9216) for every sparsity level
  for spar in snipped_lenets:
      fc1_weight = snipped_lenets[spar].fc1.weight
      plot_linear_weights(fc1_weight,
                          title=f"Pruned Parameters (fc1) with sparsity={spar}",
                          manual_figsize=(14, 2))

In [None]:
if run_LeNet_mnist == 1:
  # Kept/pruned pattern of the output layer fc2 (10 x 128) for every sparsity level
  for spar in snipped_lenets:
      fc2_weight = snipped_lenets[spar].fc2.weight
      plot_linear_weights(fc2_weight,
                          title=f"Pruned Parameters (fc2) with sparsity={spar}",
                          manual_figsize=(8, 1))

### Training

In [None]:
if run_LeNet_mnist == 1:
  # Learning curves keyed by sparsity level; each value is a list with one entry per epoch
  train_losses_mnist, train_errors_mnist = {}, {}
  val_losses_mnist, val_errors_mnist = {}, {}
  curves = (train_losses_mnist, train_errors_mnist, val_losses_mnist, val_errors_mnist)

  for spar, snipped_lenet in snipped_lenets.items():
      for store in curves:
          store[spar] = []

      # Plain SGD (no momentum or weight decay). Pruned weights therefore stay at zero,
      # because the hooks installed by apply_prune_mask zero their gradients.
      optimizer = optim.SGD(snipped_lenet.parameters(), lr=lr_mnist)

      for epoch in range(1, epochs_mnist + 1):
          # One training pass, then one validation pass; record all four metrics
          train_loss, train_error = train(
              snipped_lenet, train_mnist_loader, optimizer, loss, spar, epoch, device
          )
          val_loss, val_error = valid(
              snipped_lenet, val_mnist_loader, loss, spar, epoch, device
          )
          for store, value in zip(curves, (train_loss, train_error, val_loss, val_error)):
              store[spar].append(value)


In [None]:
if run_LeNet_mnist == 1:
  # One figure per sparsity level: loss curves (left) and error curves (right)
  epochs = range(1, epochs_mnist + 1)
  line_style = dict(linestyle='dashed', linewidth=2, markersize=6)

  for sparsity in sparsities_mnist:
      fig, (ax_loss, ax_err) = plt.subplots(1, 2, figsize=(12, 5))

      # (axes, train curves, validation curves, metric name, y-axis label)
      panels = (
          (ax_loss, train_losses_mnist, val_losses_mnist, "Loss", "Loss"),
          # train/valid report error as a percentage of misclassified samples
          (ax_err, train_errors_mnist, val_errors_mnist, "Error", "Error (%)"),
      )
      for ax, train_curves, val_curves, metric, ylabel in panels:
          ax.plot(epochs, train_curves[sparsity], label=f"Train {metric}",
                  color='tab:blue', marker='o', **line_style)
          ax.plot(epochs, val_curves[sparsity], label=f"Validation {metric}",
                  color='tab:orange', marker='s', **line_style)
          ax.set_xlabel("Epochs", fontsize=12)
          ax.set_ylabel(ylabel, fontsize=12)
          ax.set_title(f"{metric} vs Epochs (Sparsity: {sparsity})", fontsize=14)
          ax.legend(loc='upper right', fontsize=10)
          ax.grid(True, which='both', linestyle='--', linewidth=0.5)

      fig.tight_layout()
      plt.show()

  # Dump the raw per-epoch errors once, after all figures (previously this was repeated
  # inside the loop, once per sparsity level)
  print(train_errors_mnist)
  print()
  print(val_errors_mnist)

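Recorded training error (%) per sparsity level (`errors_train_mnist`):

| Sparsity | Train error (%) |
|---|---|
| 0 | 1.35 |
| 0.25 | 1.3962962962962964 |
| 0.5 | 1.5777777777777777 |
| 0.75 | 2.011111111111111 |
| 0.9 | 3.4185185185185185 |
| 0.95 | 5.174074074074074 |
| 0.99 | 10.924074074074074 |
| 0.995 | 16.44814814814815 |
| 0.999 | 81.20925925925926 |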

In [None]:
if run_LeNet_mnist == 1:
  # Pickle each trained model (architecture + weights) to snipped_lenet_<sparsity>.pth.
  # NOTE(review):
  #   - whole-module pickles need the LeNet5 class importable at load time;
  #   - torch.load's weights_only=True (the default in recent PyTorch) rejects them.
  # Saving state_dict() would be more portable if no consumer depends on this format.
  for spar in snipped_lenets:
      torch.save(snipped_lenets[spar], f'snipped_lenet_{spar}.pth')

### Evaluation

In [None]:
if run_LeNet_mnist == 1:
  # Final test-set loss and error (%) for every pruned model, keyed by sparsity level
  test_losses_mnist, test_errors_mnist = {}, {}
  for spar in snipped_lenets:
      test_loss, test_error = test(snipped_lenets[spar], test_mnist_loader, loss, spar, device)
      test_losses_mnist[spar] = test_loss
      test_errors_mnist[spar] = test_error

In [None]:
if run_LeNet_mnist == 1:
    # Show one randomly drawn test image per sparsity level together with the
    # prediction of the pruned LeNet for that level.
    n_models = len(sparsities_mnist)
    fig, axes = plt.subplots(nrows=1, ncols=n_models, figsize=(25, 5))

    # With a single column plt.subplots returns a bare Axes rather than an
    # ndarray; wrap it so the loop works for any number of sparsity levels.
    axes_list = list(axes) if n_models > 1 else [axes]

    for ax, spar in zip(axes_list, sparsities_mnist):
        sample_idx = torch.randint(len(test_mnist), size=(1,)).item()  # random test index
        img, label = test_mnist[sample_idx]

        lenet_model = snipped_lenets[spar]
        # Inference mode: layers such as dropout/batch-norm (if present) use
        # their evaluation behaviour; no autograd graph is needed for a prediction.
        lenet_model.eval()
        with torch.no_grad():
            pred_label = lenet_model(img.unsqueeze(0).to(device)).argmax(dim=1).item()

        ax.imshow(img.squeeze(), cmap="gray")
        ax.set_title(f"Spar:{spar} Class:{label} Pred:{pred_label}")
        ax.axis("off")  # hide axes for a cleaner image grid

    plt.tight_layout()
    plt.show()


In [None]:
# Plot one fixed test image and show how each pruned LeNet classifies it
if run_LeNet_mnist == 1:
    image_index = 710  # fixed test-set index so every sparsity level sees the same image

    # The image does not depend on the model: fetch it and build the batch once.
    img, label = test_mnist[image_index]
    img_batch = img.unsqueeze(0).to(device)

    n_models = len(sparsities_mnist)
    fig, axes = plt.subplots(nrows=1, ncols=n_models, figsize=(25, 5))
    # A single column yields a bare Axes instead of an ndarray; wrap it.
    axes_list = list(axes) if n_models > 1 else [axes]

    for ax, spar in zip(axes_list, sparsities_mnist):
        lenet_model = snipped_lenets[spar]
        lenet_model.eval()  # inference behaviour for any dropout/batch-norm layers
        with torch.no_grad():
            pred_label = lenet_model(img_batch).argmax(dim=1).item()

        ax.imshow(img.squeeze(), cmap="gray")
        ax.set_title(f"Spar: {spar}, Class: {label}, Pred: {pred_label}")
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
if run_LeNet_mnist == 1:
    # Final-epoch training/validation error and test error versus sparsity,
    # one subplot per split.
    fig, axs = plt.subplots(1, 3, figsize=(18, 5))

    # Error reached at the last training epoch, per sparsity level.
    last_train_errors_mnist = {spar: train_errors_mnist[spar][-1] for spar in sparsities_mnist}
    last_val_errors_mnist = {spar: val_errors_mnist[spar][-1] for spar in sparsities_mnist}

    errors = [last_train_errors_mnist, last_val_errors_mnist, test_errors_mnist]
    titles = ["Training Error", "Validation Error", "Test Error"]
    colors = ['tab:blue', 'tab:orange', 'tab:green']

    for ax, error, title, color in zip(axs, errors, titles, colors):
        # Look each value up by its sparsity key instead of relying on the
        # dictionary's insertion order, so every point sits above its own
        # sparsity on the x axis (test_errors_mnist follows snipped_lenets order).
        y_values = [error[spar] for spar in sparsities_mnist]

        ax.plot(sparsities_mnist, y_values, label=title, color=color, linewidth=2, linestyle='-', marker='o')
        ax.scatter(sparsities_mnist, y_values, color=color, s=80, zorder=5)  # emphasise the points

        ax.set_ylim(0, 100)  # same scale on every panel
        ax.set_xlabel("Sparsity", fontsize=12)
        ax.set_ylabel("Error", fontsize=12)
        ax.set_title(title, fontsize=14)
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)
        ax.legend(loc='upper left', fontsize=10)

    plt.tight_layout()
    plt.show()

In [None]:
# combine previous plots in one
if run_LeNet_mnist == 1:
    # A single axes holding the training, validation and test error curves.
    fig, ax = plt.subplots(1, 1, figsize=(8, 5))

    # Error reached at the last training epoch, per sparsity level.
    last_train_errors_mnist = {spar: train_errors_mnist[spar][-1] for spar in sparsities_mnist}
    last_val_errors_mnist = {spar: val_errors_mnist[spar][-1] for spar in sparsities_mnist}

    errors = [last_train_errors_mnist, last_val_errors_mnist, test_errors_mnist]
    titles = ["Training Error", "Validation Error", "Test Error"]  # used as series labels
    colors = ['tab:blue', 'tab:orange', 'tab:green']

    for error, title, color in zip(errors, titles, colors):
        # Index by sparsity key so values never depend on dict insertion order.
        y_values = [error[spar] for spar in sparsities_mnist]
        ax.plot(sparsities_mnist, y_values, label=title, color=color, linewidth=2, linestyle='-', marker='o')
        ax.scatter(sparsities_mnist, y_values, color=color, s=80, zorder=5)

    # Axes-level settings are applied once, after all series are drawn. The
    # title describes all three curves (setting it per series left only the
    # last one, "Test Error").
    ax.set_ylim(0, 100)
    ax.set_xlabel("Sparsity", fontsize=12)
    ax.set_ylabel("Error", fontsize=12)
    ax.set_title("Training, Validation and Test Error vs Sparsity", fontsize=14)
    ax.grid(True, which='both', linestyle='--', linewidth=0.5)
    ax.legend(loc='upper left', fontsize=10)

    plt.tight_layout()
    plt.show()

In [None]:
if run_LeNet_mnist == 1:
    # One confusion matrix per sparsity level, computed on the whole MNIST test set.
    ncols = len(sparsities_mnist)
    nrows = 1
    eval_batch_size = 1000  # images per forward pass; bounds device memory use

    # Each dataset access applies the transforms, so read every sample exactly
    # once and split the pairs into an image tensor and a label tensor.
    test_samples = [test_mnist[i] for i in range(len(test_mnist))]
    images = torch.stack([sample_img for sample_img, _ in test_samples])
    labels = torch.tensor([sample_label for _, sample_label in test_samples])

    # Single row of panels; width scales with the number of sparsity levels.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5.5 * ncols, 7))

    # plt.subplots returns a bare Axes for a 1x1 grid and an ndarray otherwise.
    axes = [axes] if (ncols == 1 and nrows == 1) else axes.flatten()

    for idx, spar in enumerate(sparsities_mnist):
        model = snipped_lenets[spar]
        model.eval()

        # Predict chunk by chunk: images stay on the CPU and only the current
        # chunk is moved to the device.
        with torch.no_grad():
            pred = torch.cat([
                model(chunk.to(device)).argmax(dim=1).cpu()
                for chunk in images.split(eval_batch_size)
            ])

        cm = confusion_matrix(labels, pred)
        disp = ConfusionMatrixDisplay(cm)
        disp.plot(cmap='Blues', values_format='d', ax=axes[idx], colorbar=False)

        # Colorbar attached to this panel; fraction/pad keep it roughly as tall as the matrix.
        axes[idx].figure.colorbar(disp.im_, ax=axes[idx], fraction=0.046, pad=0.04)

        axes[idx].set_title(f"Confusion Matrix for Sparsity: {spar}", fontsize=16)
        axes[idx].set_xlabel("Predicted Labels", fontsize=30)
        axes[idx].set_ylabel("True Labels", fontsize=30)
        axes[idx].tick_params(axis='both', which='major', labelsize=20)

        # Font size of the per-cell count annotations drawn by disp.plot.
        for text in axes[idx].texts:
            text.set_fontsize(12)

    plt.tight_layout()
    plt.show()

# **MAIN with CIFAR10**

### Pruning

In [None]:
if run_ResNet_cifar == 1:
    # Fresh (untrained) ResNet18 placed on the selected device (CPU/GPU).
    resnet = ResNet18().to(device)

    # NOTE: `sparsities_cifar` is expected to be defined elsewhere (presumably the
    # configuration cell); a former local override was
    # sparsities_cifar = [0, 0.25, 0.5, 0.75]

    # Negative log-likelihood loss — presumably ResNet18 ends with log_softmax;
    # confirm in its definition.
    loss = F.nll_loss

In [None]:
if run_ResNet_cifar == 1:
    # One SNIP-pruned copy of `resnet` per sparsity level, keyed by that level.
    snipped_resnets = {}

    # Number of elements in every parameter whose name contains "weight"
    # (this name filter also matches e.g. batch-norm scales, if the model has any).
    total_weights_resnet = sum(
        param.numel()
        for param_name, param in resnet.named_parameters()
        if "weight" in param_name
    )
    print(f"Total weights: {total_weights_resnet}")

    for spar in sparsities_cifar:
        # Prune a deep copy, never `resnet` itself.
        resnet_copy = copy.deepcopy(resnet)

        if spar != 0:
            # Masks are scored on the original network; the copy has identical
            # weights, so the masks apply to it directly.
            keep_masks = SNIP(resnet, spar, train_cifar_loader, loss, device)
            apply_prune_mask(resnet_copy, keep_masks)

        # Sparsity 0 keeps the unpruned copy as the dense baseline.
        snipped_resnets[spar] = resnet_copy

In [None]:
if run_ResNet_cifar == 1:
    # Visualise the first convolution's (possibly pruned) filters for every sparsity level.
    for spar in snipped_resnets:
        conv1_weight = snipped_resnets[spar].conv1.weight
        plot_conv_filters(
            conv1_weight,
            n_rows=2,
            RGB=True,
            title=f"Pruned Filters (conv1) with sparsity={spar}",
        )

In [None]:
if run_ResNet_cifar == 1:
    # Visualise the final fully-connected layer's (possibly pruned) weights per sparsity level.
    for spar in snipped_resnets:
        fc_weight = snipped_resnets[spar].fc.weight
        plot_linear_weights(
            fc_weight,
            title=f"Pruned Parameters (fc) with sparsity={spar}",
            manual_figsize=(14, 2),
        )

### Training

In [None]:
if run_ResNet_cifar == 1:
    # Per-epoch metric histories, keyed by sparsity level (one list per level).
    train_losses_cifar = {}
    train_errors_cifar = {}
    val_losses_cifar = {}
    val_errors_cifar = {}

    for spar, snipped_resnet in snipped_resnets.items():
        # Register the (initially empty) history lists before training, so
        # whatever has been recorded stays reachable if training is interrupted.
        epoch_train_losses, epoch_train_errors = [], []
        epoch_val_losses, epoch_val_errors = [], []
        train_losses_cifar[spar] = epoch_train_losses
        train_errors_cifar[spar] = epoch_train_errors
        val_losses_cifar[spar] = epoch_val_losses
        val_errors_cifar[spar] = epoch_val_errors

        # SGD with momentum; the learning rate is multiplied by 0.1 once the
        # scheduler has been stepped 60 and 120 times (it is stepped once per epoch).
        optimizer = optim.SGD(snipped_resnet.parameters(), lr=lr_cifar, momentum=0.9)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120], gamma=0.1)

        for epoch in range(1, epochs_cifar + 1):
            tr_loss, tr_error = train(
                snipped_resnet, train_cifar_loader, optimizer, loss, spar, epoch, device
            )
            epoch_train_losses.append(tr_loss)
            epoch_train_errors.append(tr_error)

            va_loss, va_error = valid(
                snipped_resnet, val_cifar_loader, loss, spar, epoch, device
            )
            epoch_val_losses.append(va_loss)
            epoch_val_errors.append(va_error)

            scheduler.step()  # advance the LR schedule after this epoch

In [None]:
if run_ResNet_cifar == 1:
    # For every sparsity level draw a two-panel figure: loss curves on the left,
    # error curves on the right, each showing training and validation series.
    for sparsity in sparsities_cifar:
        epochs = range(1, epochs_cifar + 1)
        train_style = dict(color='tab:blue', linestyle='dashed', linewidth=2, marker='o', markersize=6)
        val_style = dict(color='tab:orange', linestyle='dashed', linewidth=2, marker='s', markersize=6)

        fig, (loss_ax, error_ax) = plt.subplots(1, 2, figsize=(12, 5))

        # Left panel: loss
        loss_ax.plot(epochs, train_losses_cifar[sparsity], label="Train Loss", **train_style)
        loss_ax.plot(epochs, val_losses_cifar[sparsity], label="Validation Loss", **val_style)
        loss_ax.set_xlabel("Epochs", fontsize=12)
        loss_ax.set_ylabel("Loss", fontsize=12)
        loss_ax.set_title(f"Loss vs Epochs (Sparsity: {sparsity})", fontsize=14)
        loss_ax.legend(loc='upper right', fontsize=10)
        loss_ax.grid(True, which='both', linestyle='--', linewidth=0.5)

        # Right panel: error
        error_ax.plot(epochs, train_errors_cifar[sparsity], label="Train Error", **train_style)
        error_ax.plot(epochs, val_errors_cifar[sparsity], label="Validation Error", **val_style)
        error_ax.set_xlabel("Epochs", fontsize=12)
        error_ax.set_ylabel("Error", fontsize=12)
        error_ax.set_title(f"Error vs Epochs (Sparsity: {sparsity})", fontsize=14)
        error_ax.legend(loc='upper right', fontsize=10)
        error_ax.grid(True, which='both', linestyle='--', linewidth=0.5)

        plt.tight_layout()
        plt.show()

In [None]:
if run_ResNet_cifar == 1:
    # Persist every pruned ResNet as a whole pickled module (not just its
    # state_dict); the file name encodes the sparsity level.
    for sparsity_level in snipped_resnets:
        torch.save(snipped_resnets[sparsity_level], f'snipped_resnet_{sparsity_level}.pth')

### Evaluation

In [None]:
if run_ResNet_cifar == 1:
    # Evaluate every pruned ResNet on the CIFAR-10 test split. Both result
    # dictionaries are keyed by sparsity level, in `snipped_resnets` order.
    test_losses_cifar = {}  # sparsity -> test loss
    test_errors_cifar = {}  # sparsity -> test error

    for spar in snipped_resnets:
        resnet_under_test = snipped_resnets[spar]
        result_loss, result_error = test(resnet_under_test, test_cifar_loader, loss, spar, device)
        test_losses_cifar[spar] = result_loss
        test_errors_cifar[spar] = result_error

In [None]:
if run_ResNet_cifar == 1:
    # Grid of random test images: one row per sparsity level, six images per row,
    # each titled with the true class and the pruned model's prediction.
    n_images_per_row = 6
    n_rows = len(sparsities_cifar)
    # Height scales with the row count (3 in per row, i.e. 12 in for 4 rows).
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_images_per_row, figsize=(18, 3 * n_rows))
    # plt.subplots squeezes a single row to a 1-D array; force (rows, cols).
    axes = axes.reshape(n_rows, n_images_per_row)

    for row, sparsity in enumerate(sparsities_cifar):
        resnet_model = snipped_resnets[sparsity]
        # Inference mode: if ResNet18 contains batch-norm (likely — confirm in its
        # definition), a train-mode forward on a single image would normalise with
        # that image's own statistics and update the running stats.
        resnet_model.eval()

        for col in range(n_images_per_row):
            ax = axes[row, col]
            sample_idx = torch.randint(len(test_cifar), size=(1,)).item()  # random test index
            img, label = test_cifar[sample_idx]

            with torch.no_grad():
                pred_label = resnet_model(img.unsqueeze(0).to(device)).argmax(dim=1).item()

            # Undo the normalisation for display; clip to [0, 1] because
            # imshow expects float RGB in that range.
            img = (img * std_cifar[:, None, None] + mean_cifar[:, None, None]).clamp(0, 1)
            ax.imshow(img.permute(1, 2, 0))  # (C, H, W) -> (H, W, C)
            ax.set_title(f"Label: {classes[label]}\nPred: {classes[pred_label]}", fontsize=10)
            ax.axis("off")

        # Sparsity label to the left of the row.
        axes[row, 0].annotate(f"Spar: {sparsity}", xy=(-0.2, 0.5), xycoords='axes fraction', fontsize=12,
                              ha='right', va='center', rotation=90, fontweight='bold')

    plt.tight_layout()
    plt.show()

In [None]:
if run_ResNet_cifar == 1:
    # Final-epoch training/validation error and test error versus sparsity,
    # one subplot per split.
    fig, axs = plt.subplots(1, 3, figsize=(18, 5))

    # Error reached at the last training epoch, per sparsity level.
    last_train_errors_cifar = {spar: train_errors_cifar[spar][-1] for spar in sparsities_cifar}
    last_val_errors_cifar = {spar: val_errors_cifar[spar][-1] for spar in sparsities_cifar}

    errors = [last_train_errors_cifar, last_val_errors_cifar, test_errors_cifar]
    titles = ["Training Error", "Validation Error", "Test Error"]
    colors = ['tab:blue', 'tab:orange', 'tab:green']

    for ax, error, title, color in zip(axs, errors, titles, colors):
        # Look each value up by its sparsity key instead of relying on the
        # dictionary's insertion order (test_errors_cifar follows snipped_resnets order).
        y_values = [error[spar] for spar in sparsities_cifar]

        ax.plot(sparsities_cifar, y_values, label=title, color=color, linewidth=2, linestyle='-', marker='o')
        ax.scatter(sparsities_cifar, y_values, color=color, s=80, zorder=5)  # emphasise the points

        ax.set_ylim(0, 90)  # same scale on every panel
        ax.set_xlabel("Sparsity", fontsize=12)
        ax.set_ylabel("Error", fontsize=12)
        ax.set_title(title, fontsize=14)
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)
        ax.legend(loc='upper right', fontsize=10)

    plt.tight_layout()
    plt.show()

In [None]:
if run_ResNet_cifar == 1:
    # One confusion matrix per sparsity level, computed on the whole CIFAR-10 test set.
    ncols = len(sparsities_cifar)
    nrows = 1
    # Images per forward pass. Pushing all test images through ResNet18 at once
    # can exhaust GPU memory, so predictions are computed chunk by chunk.
    eval_batch_size = 500

    # Each dataset access applies the transforms, so read every sample exactly
    # once and split the pairs into an image tensor and a label tensor.
    test_samples = [test_cifar[i] for i in range(len(test_cifar))]
    images = torch.stack([sample_img for sample_img, _ in test_samples])
    labels = torch.tensor([sample_label for _, sample_label in test_samples])

    # Single row of panels; width scales with the number of sparsity levels.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(7.5 * ncols, 9))

    # plt.subplots returns a bare Axes for a 1x1 grid and an ndarray otherwise.
    axes = [axes] if (ncols == 1 and nrows == 1) else axes.flatten()

    for idx, spar in enumerate(sparsities_cifar):
        model = snipped_resnets[spar].to(device)
        model.eval()

        # Images stay on the CPU; only the current chunk is moved to the device.
        with torch.no_grad():
            pred = torch.cat([
                model(chunk.to(device)).argmax(dim=1).cpu()
                for chunk in images.split(eval_batch_size)
            ])

        cm = confusion_matrix(labels, pred)
        disp = ConfusionMatrixDisplay(cm, display_labels=classes)
        disp.plot(cmap='Blues', values_format='d', ax=axes[idx], colorbar=False)

        # Colorbar attached to this panel; fraction/pad keep it roughly as tall as the matrix.
        axes[idx].figure.colorbar(disp.im_, ax=axes[idx], fraction=0.046, pad=0.04)

        axes[idx].set_title(f"Confusion Matrix for Sparsity: {spar}", fontsize=30)
        axes[idx].set_xlabel("Predicted Label", fontsize=20)
        axes[idx].set_ylabel("True Label", fontsize=20)
        axes[idx].tick_params(axis='both', which='major', labelsize=16)

        # Font size of the per-cell count annotations drawn by disp.plot.
        for text in axes[idx].texts:
            text.set_fontsize(16)

    plt.tight_layout()
    plt.show()