In [None]:
# Mount the user's Google Drive inside the Colab VM at /content/gdrive.
# This triggers Colab's interactive authorisation prompt on a fresh runtime.
from google.colab import drive
drive.mount('/content/gdrive')

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models as torchvision_models
import timm
from sklearn.metrics import accuracy_score
import numpy as np
import tarfile

In [None]:
# Load the serialized training split into `t`; the next cell reads its
# 'data' and 'targets' entries.
# NOTE(review): the path is relative, so it only resolves when the working
# directory is the parent of the `gdrive` mount point (/content) — confirm no
# earlier `%cd` changes it.
# NOTE(review): torch.load unpickles the file; recent PyTorch releases default
# to weights_only=True, which may reject non-tensor payloads. If this call
# fails on a newer runtime, pass weights_only=False — only for trusted files.
t = torch.load(r"gdrive/My Drive/miniproject2/dataset/part_one_dataset/train_data/1_train_data.tar.pth")

In [None]:
# Unpack the image array and the label array from the loaded dict `t`.
data, targets = t['data'], t['targets']

# ToPILImage treats floating-point tensors as already scaled to [0, 1] and
# multiplies them by 255 before casting to bytes. Casting integer pixel data
# to float32 first therefore corrupts every image, so integer data is kept as
# uint8 and only data that is already floating point stays floating point.
# assumes integer pixel values fit in 0-255 — TODO confirm against the dataset.
data = torch.tensor(data)
data = data.to(torch.float32 if data.is_floating_point() else torch.uint8)
# Move the last axis to position 1: ToPILImage expects tensors as C x H x W.
# assumes `data` is stored as (N, H, W, C) — TODO confirm.
data = data.permute(0, 3, 1, 2)
targets = torch.tensor(targets, dtype=torch.long)

# Transformations: resize to 224x224, convert to a [0, 1] float tensor and
# normalise per channel.
# NOTE(review): the mean/std below are the usual ImageNet statistics; some
# pretrained checkpoints expect different values — confirm against the
# pretrained config of whichever backbone is selected.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Apply the transformation to every image and stack the results back into a
# single tensor (this materialises the whole resized dataset in memory).
data = torch.stack([transform(img) for img in data])

# Wrap images and labels together; shuffle=False keeps the feature order
# aligned with `targets` for the evaluation step.
dataset = TensorDataset(data, targets)
loader = DataLoader(dataset, batch_size=64, shuffle=False)

#Function to extract features
def extract_features(model, loader, device=None):
    """Run ``model`` over every batch of ``loader`` and collect its outputs.

    The model is put in evaluation mode and gradients are disabled for the
    whole pass.

    Args:
        model: ``torch.nn.Module`` whose forward output is used as the
            feature representation.
        loader: iterable yielding ``(inputs, labels)`` pairs; the labels are
            ignored.
        device: optional device to run on. When given, the model is moved
            there. When ``None`` the inputs follow the device of the model's
            first parameter (if it has any), so a model that was moved to an
            accelerator does not fail with a device mismatch.

    Returns:
        A CPU tensor with the per-batch outputs concatenated along dim 0, in
        loader order. An empty tensor is returned when the loader yields no
        batches.
    """
    model.eval()
    if device is not None:
        model.to(device)
    else:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device

    features = []
    with torch.no_grad():
        for inputs, _ in loader:
            if device is not None:
                inputs = inputs.to(device)
            # Bring results back to the CPU so callers can convert to NumPy
            # and accelerator memory is not held for the whole dataset.
            features.append(model(inputs).cpu())

    if not features:
        # torch.cat raises on an empty list; report "no features" instead.
        return torch.empty(0)
    return torch.cat(features)

In [None]:
# Candidate pretrained backbones, keyed by a display name. Only the
# uncommented entry is instantiated; the commented-out lines are kept as a
# log of earlier experiments.
# The trailing percentage on each line is the accuracy the author recorded
# for that backbone (presumably from the evaluation loop below — confirm).
# The keys matter: the head-stripping code below selects the layer to replace
# based on these names.
models = {
    # "EfficientNet-B3": timm.create_model('tf_efficientnet_b3_ns', pretrained=True), #69.76%
    # "VGG16": torchvision_models.vgg16(pretrained=True), # 60.72%
    #"DPN92": timm.create_model('dpn92', pretrained=True), #61.68%
    # "AlexNet": torchvision_models.alexnet(pretrained=True), # 57.36%
    # "ResNet50": torchvision_models.resnet50(pretrained=True), # 63.52%
    # "ResNet101": torchvision_models.resnet101(pretrained=True), # 67.32%
    # "DenseNet121": torchvision_models.densenet121(pretrained=True), #63.96%
    # "RegNetY-400MF": torchvision_models.regnet_y_400mf(pretrained=True), #64.00%
    #"SENet154": timm.create_model('senet154', pretrained=True), #69.24
    #"MobileNetV2": torchvision_models.mobilenet_v2(pretrained=True),  #66.48%
    #"InceptionV3": torchvision_models.inception_v3(pretrained=True), #62.28%
    # "EfficientNet-B0": torchvision_models.efficientnet_b0(pretrained=True), #79.88%
    "ViT-Base": timm.create_model('vit_base_patch16_224', pretrained=True) #84.68%
    # "Swin-Transformer": timm.create_model('swin_base_patch4_window7_224', pretrained=True), #0.24%
    #"ConvNeXt-Base": torchvision_models.convnext_base(pretrained=True)  #75.80%
}


# Where each architecture keeps its final classification layer, expressed as
# (attribute, index) candidates tried in order. `index` is None when the
# attribute itself is the layer to replace, otherwise it is the position of
# the layer inside that attribute (a Sequential).
# Family rules are matched by substring, in this order, before exact names.
_HEAD_BY_FAMILY = (
    ("ResNet", [("fc", None)]),
    ("MobileNet", [("classifier", 1)]),
    ("VGG", [("classifier", 6)]),
    ("Inception", [("fc", None)]),
    ("AlexNet", [("classifier", 6)]),
)
_HEAD_BY_NAME = {
    "EfficientNet-B3": [("classifier", None)],
    "DPN92": [("classifier", None)],
    "EfficientNet-B0": [("classifier", None)],
    "ViT-Base": [("head", None)],
    # NOTE(review): if the installed timm keeps pooling inside `head`,
    # replacing it with Identity also removes the pooling — check the output
    # shape before trusting results for this backbone.
    "Swin-Transformer": [("head", None)],
    "ConvNeXt-Base": [("classifier", 2)],
    "DenseNet121": [("classifier", None)],
    "RegNetY-400MF": [("fc", None)],
    # NOTE(review): timm's legacy SENet appears to name its classifier
    # `last_linear` rather than `fc`; both are tried — confirm against the
    # installed timm version.
    "SENet154": [("fc", None), ("last_linear", None)],
}


def _strip_classifier_head(model_name, model):
    """Replace the final classification layer of ``model`` with Identity.

    Args:
        model_name: display name used to look up where the head lives.
        model: the network to modify in place.

    Returns:
        True when a layer was replaced, False when ``model_name`` is not
        recognised (a warning is issued and the model is left untouched).

    Raises:
        AttributeError: if the model has none of the attributes expected for
            its name. Plain assignment to a missing attribute would silently
            register an unused module and leave the real classifier active.
    """
    import warnings

    candidates = next(
        (spec for family, spec in _HEAD_BY_FAMILY if family in model_name),
        _HEAD_BY_NAME.get(model_name),
    )
    if candidates is None:
        warnings.warn(
            f"No head-removal rule for {model_name!r}; "
            "the model's original output layer is kept."
        )
        return False

    for attr, index in candidates:
        if not hasattr(model, attr):
            continue
        if index is None:
            setattr(model, attr, torch.nn.Identity())
        else:
            getattr(model, attr)[index] = torch.nn.Identity()
        return True

    expected = ", ".join(attr for attr, _ in candidates)
    raise AttributeError(
        f"{model_name!r} has none of the expected head attributes: {expected}"
    )


for model_name, model in models.items():
    _strip_classifier_head(model_name, model)

# Labels and the sorted set of distinct class ids are identical for every
# model, so they are computed once. Using the actual ids (rather than
# range(num_classes)) keeps prototypes and predictions correct even when the
# labels are not exactly 0..K-1.
labels = targets.numpy()
classes = torch.unique(targets).numpy()
num_classes = len(classes)

for model_name, model in models.items():
    print(f"\nExtracting features and applying LwP using {model_name}...")

    # Feature extraction; move to CPU/NumPy for the prototype arithmetic.
    features = extract_features(model, loader)
    features = features.detach().cpu().numpy()

    # One prototype per class: the mean feature vector of that class's
    # samples. The mask is a NumPy array so the indexing does not depend on
    # implicit tensor-to-array conversion.
    prototypes = np.array([features[labels == c].mean(axis=0) for c in classes])

    # LwP prediction: assign each sample the class id of the prototype that
    # is closest in Euclidean distance.
    predictions = [
        classes[np.argmin(np.linalg.norm(prototypes - feature, axis=1))]
        for feature in features
    ]

    # Note: accuracy is measured on the same samples the prototypes were
    # built from, so it is a training-set figure.
    accuracy = accuracy_score(labels, predictions)
    print(f"Accuracy with {model_name}: {accuracy * 100:.2f}%")