In [None]:
# Importing necessary libraries and modules

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import pandas as pd
import numpy as np
from scipy.spatial import distance

import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.models import (
    resnet34,
    ResNet34_Weights,
    resnet18,
    ResNet18_Weights,
    vgg11,
    VGG11_Weights,
)

from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
from tqdm.notebook import tqdm
import seaborn as sb

from copy import deepcopy

# Apply seaborn's default theme globally so every later matplotlib/seaborn
# figure produced in this notebook shares the same look.
sb.set_theme()

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)  # report which device the rest of the notebook will use

### Load Data

In [None]:
# Per-channel (R, G, B) normalization statistics of ImageNet. The pretrained
# torchvision backbones used below expect inputs standardized with these values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to every image before feature extraction.
transform = transforms.Compose(
    [
        # With an int size, Resize scales the *shorter* edge to 256 pixels and
        # keeps the aspect ratio; it does not force a 256x256 output. For square
        # inputs (CIFAR-10 images are 32x32) the result is 256x256.
        transforms.Resize(size=256),
        # Keep the central 224x224 region, the input size the extractors probe with.
        transforms.CenterCrop(size=224),
        # PIL image (H, W, C) with values in [0, 255] -> float tensor (C, H, W) in [0, 1].
        transforms.ToTensor(),
        # Standardize each channel: (x - mean) / std.
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

In [None]:
# Fetch CIFAR-10 into ./data (downloaded on first use). The training split
# comes first and the test split second; both share the same preprocessing.
train_data, test_data = (
    datasets.CIFAR10("data", train=is_train_split, download=True, transform=transform)
    for is_train_split in (True, False)
)

In [None]:
# Carve a validation set out of the official training split (85% train / 15%
# validation). The explicitly seeded generator makes the split identical on every run.
train_subset, val_subset = torch.utils.data.random_split(
    dataset=train_data,
    lengths=[0.85, 0.15],
    generator=torch.Generator().manual_seed(1),
)

In [None]:
class ResNet34:
    """Frozen, pretrained ResNet-34 used as an image feature extractor.

    The torchvision ResNet-34 is loaded with its default pretrained weights.
    Its last child module (the final classification layer) is dropped, and the
    remaining layers are wrapped in a ``torch.nn.Sequential`` that is put in
    evaluation mode and moved to ``device``.
    """

    def __init__(self):
        # Load the full ResNet34 with its default pretrained weights.
        full_model = resnet34(weights=ResNet34_Weights.DEFAULT)

        # Keep every child module except the last one (the classification
        # layer), so the network emits features instead of class logits.
        backbone = torch.nn.Sequential(*list(full_model.children())[:-1])

        # eval() makes layers such as BatchNorm use inference behaviour.
        self.resnet34 = backbone.eval().to(device)

    def __repr__(self):
        return "ResNet34"

    def get_features(self, images):
        """Run ``images`` through the backbone without tracking gradients.

        Args:
            images: Batch of preprocessed images. Presumably a float tensor of
                shape (N, 3, H, W) already on the backbone's device; TODO confirm.

        Returns:
            The backbone's raw output, NOT flattened (unlike ``get_size``).
            Flatten with ``torch.flatten(features, start_dim=1)`` to get the
            shape reported by ``get_size``.
        """
        with torch.no_grad():
            return self.resnet34(images)

    def get_size(self, image_size=224):
        """Return the shape of the flattened features for a single image.

        One all-zeros ``image_size`` x ``image_size`` RGB image is pushed
        through the backbone, and the output is flattened from dim 1.

        Args:
            image_size: Side length of the square probe image. Defaults to 224,
                the centre-crop size used by the preprocessing transform.

        Returns:
            ``torch.Size`` of the form ``(1, feature_dim)``. Note the leading
            batch dimension of 1.
        """
        # Probe on the backbone's own device, so this still works if the model
        # was moved after construction; fall back to the global device.
        first_param = next(self.resnet34.parameters(), None)
        probe_device = first_param.device if first_param is not None else device
        probe = torch.zeros(1, 3, image_size, image_size, device=probe_device)
        with torch.no_grad():
            features = torch.flatten(self.resnet34(probe), start_dim=1)
        return features.shape

In [None]:
class ResNet18:
    """Frozen, pretrained ResNet-18 used as an image feature extractor.

    The torchvision ResNet-18 is loaded with its default pretrained weights.
    Its last child module (the final classification layer) is dropped, and the
    remaining layers are wrapped in a ``torch.nn.Sequential`` that is put in
    evaluation mode and moved to ``device``.
    """

    def __init__(self):
        # Load the full ResNet18 with its default pretrained weights.
        full_model = resnet18(weights=ResNet18_Weights.DEFAULT)

        # Keep every child module except the last one (the classification
        # layer), so the network emits features instead of class logits.
        backbone = torch.nn.Sequential(*list(full_model.children())[:-1])

        # eval() makes layers such as BatchNorm use inference behaviour.
        self.resnet18 = backbone.eval().to(device)

    def __repr__(self):
        return "ResNet18"

    def get_features(self, images):
        """Run ``images`` through the backbone without tracking gradients.

        Args:
            images: Batch of preprocessed images. Presumably a float tensor of
                shape (N, 3, H, W) already on the backbone's device; TODO confirm.

        Returns:
            The backbone's raw output, NOT flattened (unlike ``get_size``).
            Flatten with ``torch.flatten(features, start_dim=1)`` to get the
            shape reported by ``get_size``.
        """
        with torch.no_grad():
            return self.resnet18(images)

    def get_size(self, image_size=224):
        """Return the shape of the flattened features for a single image.

        One all-zeros ``image_size`` x ``image_size`` RGB image is pushed
        through the backbone, and the output is flattened from dim 1.

        Args:
            image_size: Side length of the square probe image. Defaults to 224,
                the centre-crop size used by the preprocessing transform.

        Returns:
            ``torch.Size`` of the form ``(1, feature_dim)``. Note the leading
            batch dimension of 1.
        """
        # Probe on the backbone's own device, so this still works if the model
        # was moved after construction; fall back to the global device.
        first_param = next(self.resnet18.parameters(), None)
        probe_device = first_param.device if first_param is not None else device
        probe = torch.zeros(1, 3, image_size, image_size, device=probe_device)
        with torch.no_grad():
            features = torch.flatten(self.resnet18(probe), start_dim=1)
        return features.shape

In [None]:
class VGG11:
    """Frozen, pretrained VGG-11 used as an image feature extractor.

    The torchvision VGG-11 is loaded with its default pretrained weights and
    its last child module is dropped. In torchvision's VGG, the children are
    ``features``, ``avgpool`` and ``classifier``, so the whole fully-connected
    ``classifier`` head is removed. What remains is the convolutional stack
    plus average pooling. For 224x224 inputs this gives (N, 512, 7, 7)
    activations (25088 values when flattened), not the 4096-d fc activations.
    NOTE(review): confirm that conv-level features are the intended
    representation here.
    """

    def __init__(self):
        # Load the full VGG11 with its default pretrained weights.
        full_model = vgg11(weights=VGG11_Weights.DEFAULT)

        # Drop the last child module (for VGG that is the entire classifier
        # head, not just the final linear layer).
        backbone = torch.nn.Sequential(*list(full_model.children())[:-1])

        # eval() switches layers to inference behaviour.
        self.vgg11 = backbone.eval().to(device)

    def __repr__(self):
        return "VGG11"

    def get_features(self, images):
        """Run ``images`` through the backbone without tracking gradients.

        Args:
            images: Batch of preprocessed images. Presumably a float tensor of
                shape (N, 3, H, W) already on the backbone's device; TODO confirm.

        Returns:
            The backbone's raw output, NOT flattened (unlike ``get_size``).
            Flatten with ``torch.flatten(features, start_dim=1)`` to get the
            shape reported by ``get_size``.
        """
        with torch.no_grad():
            return self.vgg11(images)

    def get_size(self, image_size=224):
        """Return the shape of the flattened features for a single image.

        One all-zeros ``image_size`` x ``image_size`` RGB image is pushed
        through the backbone, and the output is flattened from dim 1.

        Args:
            image_size: Side length of the square probe image. Defaults to 224,
                the centre-crop size used by the preprocessing transform.

        Returns:
            ``torch.Size`` of the form ``(1, feature_dim)``. Note the leading
            batch dimension of 1.
        """
        # Probe on the backbone's own device, so this still works if the model
        # was moved after construction; fall back to the global device.
        first_param = next(self.vgg11.parameters(), None)
        probe_device = first_param.device if first_param is not None else device
        probe = torch.zeros(1, 3, image_size, image_size, device=probe_device)
        with torch.no_grad():
            features = torch.flatten(self.vgg11(probe), start_dim=1)
        return features.shape

In [None]:
# One frozen, pretrained feature extractor per backbone (built left to right).
fe_resnet18, fe_resnet34, fe_vgg11 = ResNet18(), ResNet34(), VGG11()