In [None]:
# Init workspace
!mkdir dataset

# Download dataset and extract it
!gdown 111HiEoEvZDdg1Y2EefI6n5dA_p4sMV4V
!mv imagenet-a.tar ./dataset
!tar -xf ./dataset/imagenet-a.tar
!mv imagenet-a ./dataset

# Cleanup
!rm ./dataset/imagenet-a.tar

In [12]:
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
import torchvision
import torch.nn.functional as F
import torchvision.transforms as T

import json
from os.path import basename, join
from pathlib import Path
import requests

import re

In [16]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
def load_imagenet_a_labels(root: str = "./dataset/imagenet-a") -> list[str]:
    """Read the ImageNet-A class names from the dataset's ``README.txt``.

    Only lines that start with a WordNet id (``n`` followed by eight digits)
    and one whitespace character are treated as class entries; every other
    line of the file is ignored.

    Args:
        root: Directory that contains ``README.txt``.

    Returns:
        The class names in the order they appear in the file, with the
        WordNet id and any surrounding whitespace removed.

    Raises:
        OSError: If ``README.txt`` cannot be opened.
    """
    pattern = re.compile(r"n\d{8}\s(.+)")

    labels = []

    with open(join(root, "README.txt"), "r") as f:
        for line in f:
            match = pattern.match(line)
            if match:
                # Take the name from the capture group instead of splitting on single
                # spaces: the pattern accepts any whitespace after the id, and a tab or
                # a double space used to yield an empty or space-prefixed label.
                labels.append(match.group(1).strip())

    return labels


In [5]:
def load_model_labels() -> list[str]:
    """Return the ImageNet label list, downloading the JSON file on first use.

    The file is cached in the current working directory under the basename of
    the URL; when that file already exists it is read without any network
    access.

    Returns:
        The parsed JSON content of the labels file.

    Raises:
        requests.HTTPError: If the download answers with an error status.
        requests.RequestException: On connection problems or a timeout.
        json.JSONDecodeError: If the cached file is not valid JSON.
    """
    url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
    path = Path(basename(url))

    # Check if labels file already exists
    if not path.exists():
        # Bound the wait so an unreachable host cannot hang the notebook
        response = requests.get(url, timeout=30)
        # Fail here instead of caching an HTTP error page: a cached error page
        # would make every later call fail in json.load until the file is deleted
        response.raise_for_status()
        path.write_text(response.text)

    # Load labels
    with open(path, "r") as f:
        labels = json.load(f)

    return labels

In [6]:
def merge_labels() -> dict:
    """Map every ImageNet-A class name to its index in both label lists.

    Returns:
        A dict keyed by class name. Each value is a dict holding the position
        of the name in the ImageNet-A label list ("imagenet-a") and the
        position of its first occurrence in the model label list ("model").

    Raises:
        ValueError: If an ImageNet-A name does not occur in the model labels.
    """
    dataset_names = load_imagenet_a_labels()
    model_names = load_model_labels()

    return {
        name: {
            "imagenet-a": position,
            "model": model_names.index(name)
        }
        for position, name in enumerate(dataset_names)
    }

In [13]:
def load_model(model_name: str = "google/vit-base-patch16-384") -> ViTForImageClassification:
    """Load a pre-trained ViT image classifier and move it to ``DEVICE``.

    Args:
        model_name: Identifier passed to ``ViTForImageClassification.from_pretrained``.

    Returns:
        The loaded model, placed on ``DEVICE``.
    """
    pretrained = ViTForImageClassification.from_pretrained(model_name)
    return pretrained.to(DEVICE)

In [8]:
def load_dataset():
    """Build a DataLoader over the ImageNet-A images, preprocessed for the ViT model.

    Every image is resized to 384 x 384, converted to a tensor and normalized
    with mean 0.5 and standard deviation 0.5 per RGB channel, which is the
    preprocessing documented for the default ``google/vit-base-patch16-384``
    checkpoint.

    Returns:
        A ``torch.utils.data.DataLoader`` over the image folder with batch
        size 1, shuffling enabled and 8 worker processes. Each batch holds the
        image tensor and its class index as assigned by ``ImageFolder``.
    """
    imagenet_a = "./dataset/imagenet-a/"

    # Prepare data transformations for the evaluation loader
    transform = T.Compose([
        T.Resize((384, 384)),                                                   # Resize each PIL image to 384 x 384
        T.ToTensor(),                                                           # Convert PIL image to a PyTorch tensor
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),                 # Normalize as the ViT checkpoint expects
    ])

    # Load data
    imagenet_a_dataset = torchvision.datasets.ImageFolder(root=imagenet_a, transform=transform)
    return torch.utils.data.DataLoader(imagenet_a_dataset, 1, shuffle=True, num_workers=8)

In [9]:
def classify_image(model: ViTForImageClassification, img: torch.Tensor) -> tuple[dict, dict]:
    """Classify one image and return the top prediction plus all class probabilities.

    Args:
        model: Classifier whose output exposes ``logits``.
        img: Input batch passed to the model unchanged. The code below assumes
            a single image, i.e. that squeezing the softmax output leaves one
            probability per class.

    Returns:
        A tuple ``(predicted, results)``. ``predicted`` is a dict with the
        keys "index", "label" and "probability" for the most probable class.
        ``results`` maps every class index to a dict with the same three keys.
    """
    # Perform inference
    model.eval()

    with torch.no_grad():
        outputs = model(img)

    # Extract probabilities from model's output logits
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze()
    predicted_index = probabilities.argmax(-1).item()

    # Copy all probabilities to Python floats in one transfer; calling .item()
    # on every element forces one device-to-host synchronisation per class
    probability_values = probabilities.tolist()

    labels = load_model_labels()

    results = {
        index: {
            "index": index,
            "label": labels[index],
            "probability": probability
        }
        for index, probability in enumerate(probability_values)
    }

    # Return a separate dict so that mutating it does not change `results`
    predicted = dict(results[predicted_index])

    return predicted, results

In [17]:
# Load the model once; the evaluation loop reuses this instance for every image
model = load_model()

# Build the data loader once; it is iterated by the evaluation loop
data_loader = load_dataset()

# Map each ImageNet-A class name to its index in the dataset and model label lists
merged_labels = merge_labels()

In [None]:
# Counts correct predictions during the loop; replaced by the accuracy ratio afterwards
accuracy = 0

for index, img in enumerate(data_loader):

    # img[0] is the image tensor of the batch, img[1] its class index
    predicted, results = classify_image(model=model, img=img[0].to(DEVICE))

    # None when the predicted label is not one of the dataset's labels
    merged_label = merged_labels.get(predicted["label"])

    # A hit requires a known label whose dataset index equals the target index
    if merged_label is not None and merged_label["imagenet-a"] == img[1].item():
        accuracy += 1

    print(f"Image {index+1} / {len(data_loader)} | Accuracy: {round((accuracy / (index + 1)) * 100, 2)}% ({accuracy} / {index + 1})")

accuracy = accuracy / len(data_loader)