This code cross-evaluates two pretrained models, ResNet18 and a Vision Transformer (ViT-B/16), on images generated by activation maximization for each model. Images that were generated with and without augmentation are evaluated separately, and each model's accuracy is computed both on its own images and on the other model's images. A custom Dataset class loads the images and applies the transformations. The evaluation function measures accuracy over several randomly transformed views of every image.

In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
import glob
import matplotlib.pyplot as plt
import numpy as np

In [2]:
class ActivationMaximizationDataset(Dataset):
    """Activation-maximization images, each returned as a stack of random views.

    Images are looked up in ``path`` by the file-name pattern
    ``{class_id}_{model_type}_{aug_type}.png`` (example: ``294_ViT_withAug.png``).
    ``model_type`` and ``aug_type`` must match the file names exactly; glob
    matching is case-sensitive on most filesystems, so ``Vit`` does not match
    ``ViT``. Files whose ``class_id`` prefix is not an integer are skipped.

    Each item is ``(views, label)``: ``views`` stacks ``batch_size``
    independently transformed copies of the image along dim 0, and ``label``
    is the integer class id parsed from the file name.
    """

    # ImageNet mean/std used by both normalization pipelines.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, path, model_type, aug_type, batch_size):
        self.path = path
        self.model_type = model_type
        self.aug_type = aug_type
        self.batch_size = batch_size  # number of random views returned per image

        pattern = f"*_{model_type}_{aug_type}.png"
        self.labels = []  # class id of each image, parallel to image_paths
        self.valid_image_paths = []  # paths whose file name starts with an integer class id
        # Sort so the evaluation order is reproducible; glob order is arbitrary.
        for image_path in sorted(glob.glob(os.path.join(path, pattern))):
            class_id = self._class_id_from_path(image_path)
            if class_id is None:
                continue  # name matched the glob but has no integer class id
            self.labels.append(class_id)
            self.valid_image_paths.append(image_path)
        self.image_paths = self.valid_image_paths

        to_normalized_tensor = [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=self._MEAN, std=self._STD),
        ]
        if aug_type == 'withAug':  # image was created using random augmentations
            augmentations = [
                torchvision.transforms.RandomResizedCrop(size=[224, 224], scale=(0.1, 1.0)),
                torchvision.transforms.RandomRotation(degrees=20),
                torchvision.transforms.RandomHorizontalFlip(p=0.5),
                torchvision.transforms.RandomPerspective(distortion_scale=0.4, p=0.5),
                torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                torchvision.transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)),
            ]
        else:  # image was created without random augmentations
            # Note: every view is still a random resized crop, so views differ.
            augmentations = [
                torchvision.transforms.RandomResizedCrop(size=[224, 224], scale=(0.1, 1.0)),
            ]
        self.transforms = torchvision.transforms.Compose(augmentations + to_normalized_tensor)

    @staticmethod
    def _class_id_from_path(image_path):
        """Return the integer class id prefix of ``image_path``'s file name, or None."""
        prefix = os.path.basename(image_path).split('_', 1)[0]
        try:
            return int(prefix)
        except ValueError:
            return None

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, x):
        """Return ``(views, label)`` for the image at index ``x``.

        ``views`` has shape ``(batch_size, 3, 224, 224)``: ``batch_size``
        independent random transforms of the same RGB image.
        """
        with Image.open(self.image_paths[x]) as image:
            image = image.convert('RGB')
        label = self.labels[x]
        augmented_images = [self.transforms(image) for _ in range(self.batch_size)]
        return torch.stack(augmented_images), label

In [3]:
def create_dataloader(path, model_type, aug_type, batch_size=16):
    """Build a non-shuffling DataLoader over an ActivationMaximizationDataset.

    ``batch_size`` is used twice: as the number of random views the dataset
    stacks for each image, and as the number of images per loader batch.
    """
    return DataLoader(
        ActivationMaximizationDataset(path, model_type, aug_type, batch_size=batch_size),
        batch_size=batch_size,
        shuffle=False,
    )

In [4]:
def cross_evaluation(path, resnet_model, vit_model, visualize=False):
    """Score both models on their own and on each other's activation-maximization images.

    Four dataloaders are built from ``path`` with ``create_dataloader``:
    ResNet and ViT images, each with and without augmentation. Every model is
    evaluated on its own images ("self-model accuracy") and on the other
    model's images ("cross evaluation accuracy").

    Args:
        path: Directory holding images named ``{class_id}_{model_type}_{aug_type}.png``.
        resnet_model: ResNet model to evaluate; moved to the selected device.
        vit_model: ViT model to evaluate; moved to the selected device.
        visualize: If True, plot every evaluated view with its true and
            predicted label. This is slow and only meant for small debug sets.

    Returns:
        dict mapping a description string to an accuracy in [0, 1]. An entry
        is ``nan`` when its dataloader yielded no images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resnet_model = resnet_model.to(device)
    vit_model = vit_model.to(device)

    # Normalization statistics applied by the dataset transforms; only needed
    # to undo normalization when visualizing.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    def show_prediction(image, label, predicted_label, model_name, source_model, aug_type):
        """Plot one de-normalized view with its true and predicted labels."""
        img = (image.cpu().numpy().transpose((1, 2, 0)) * std + mean).clip(0, 1)
        prediction_status = "Correct" if label == predicted_label else "Incorrect"
        plt.figure(figsize=(4, 4))
        plt.imshow(img)
        plt.title(f"Inference Model: {model_name} - Image From: {source_model}({aug_type})\n"
                  f"True Label: {label}, Prediction: {predicted_label} ({prediction_status})")
        plt.axis("off")
        plt.show()

    def evaluate_model(model, dataloader, model_name, aug_type):
        """Return the top-1 accuracy of ``model`` over every view in ``dataloader``.

        Returns ``nan`` if the loader yields no images (e.g. no files matched).
        """
        model.eval()
        true_prediction = 0
        total_prediction = 0
        with torch.no_grad():
            for images, labels in dataloader:
                # Each dataset item stacks several random views of one image, so a
                # loader batch is (images, views, C, H, W); score one view at a time.
                labels = labels.to(device)
                for i in range(images.shape[1]):
                    augmented_image = images[:, i].to(device)
                    predicted = model(augmented_image).argmax(dim=1)
                    total_prediction += labels.size(0)
                    true_prediction += (predicted == labels).sum().item()

                    if visualize:
                        source_model = "ResNet" if "ResNet" in dataloader.dataset.model_type else "ViT"
                        for image, label, predicted_label in zip(
                            augmented_image, labels.tolist(), predicted.tolist()
                        ):
                            show_prediction(image, label, predicted_label, model_name, source_model, aug_type)

        if total_prediction == 0:
            # No matching images: avoid ZeroDivisionError so the other results survive.
            return float("nan")
        return true_prediction / total_prediction

    results = {}

    resnet_no_aug_loader = create_dataloader(path, 'ResNet', 'withoutAug')
    resnet_aug_loader = create_dataloader(path, 'ResNet', 'withAug')
    vit_no_aug_loader = create_dataloader(path, 'ViT', 'withoutAug')
    vit_aug_loader = create_dataloader(path, 'ViT', 'withAug')

    # Self-Model Accuracy

    results['ViT accuracy of ViT images (not augmented)'] = evaluate_model(vit_model, vit_no_aug_loader, 'ViT', 'withoutAug')
    results['ResNet accuracy of ResNet images (not augmented)'] = evaluate_model(resnet_model, resnet_no_aug_loader, 'ResNet', 'withoutAug')
    results['ViT accuracy of ViT images (augmented)'] = evaluate_model(vit_model, vit_aug_loader, 'ViT', 'withAug')
    results['ResNet accuracy of ResNet images (augmented)'] = evaluate_model(resnet_model, resnet_aug_loader, 'ResNet', 'withAug')

    # Cross Evaluation Accuracy

    results['ViT accuracy of ResNet images (not augmented)'] = evaluate_model(vit_model, resnet_no_aug_loader, 'ViT', 'withoutAug')
    results['ResNet accuracy of ViT images (not augmented)'] = evaluate_model(resnet_model, vit_no_aug_loader, 'ResNet', 'withoutAug')
    results['ViT accuracy of ResNet images (augmented)'] = evaluate_model(vit_model, resnet_aug_loader, 'ViT', 'withAug')
    results['ResNet accuracy of ViT images (augmented)'] = evaluate_model(resnet_model, vit_aug_loader, 'ResNet', 'withAug')

    return results

In [None]:
path = 'testdataset'  # Path to the images created by activation maximization
# Use the `weights` API for both models: `pretrained=True` is deprecated in
# torchvision and corresponds to the IMAGENET1K_V1 weights for resnet18.
resnet_model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
vit_model = torchvision.models.vit_b_16(weights='IMAGENET1K_V1')
results = cross_evaluation(path, resnet_model, vit_model)

print("Evaluation Results:")
for test, accuracy in results.items():
    print(f"{test}: {accuracy * 100:.2f}%")