In [None]:
# Init workspace
!mkdir dataset

# Download dataset and extract it
!gdown 111HiEoEvZDdg1Y2EefI6n5dA_p4sMV4V
!mv imagenet-a.tar ./dataset
!tar -xf ./dataset/imagenet-a.tar
!mv imagenet-a ./dataset

# Cleanup
!rm ./dataset/imagenet-a.tar

# (optional) Upgrading pytorch for the latest augmentation functions
#!pip install --upgrade torch torchvision torchaudio

In [28]:
import torch
import torch.optim as optim
from transformers import ViTForImageClassification, ViTImageProcessor
import torchvision
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import v2

import json
from os.path import basename, join
from pathlib import Path
import requests

import re
from contextlib import nullcontext
from typing import Union
from copy import deepcopy

import warnings
warnings.filterwarnings("ignore")

In [4]:
# Use cuda if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input resolution of the ViT checkpoint used below (google/vit-base-patch16-384)
SIZE = (384, 384)

# Seed torch's global RNG: it drives the DataLoader shuffling and the random
# test-time augmentations, so accuracy figures are reproducible across runs
SEED = 0
torch.manual_seed(SEED)

In [5]:
def load_imagenet_a_labels(root: str = "./dataset/imagenet-a") -> list[str]:
    """Read the human-readable ImageNet-A class names from the dataset's README.

    Only lines of the form ``<wnid> <name>`` (``n`` followed by 8 digits, whitespace,
    then the name, e.g. ``n01498041 stingray``) are kept; every other line is ignored.

    Args:
        root: Directory of the extracted ImageNet-A dataset (must contain README.txt).

    Returns:
        The class names in the order they appear in the README.
        NOTE(review): callers use this position as the ImageFolder class index, which
        assumes the README lists classes sorted by wnid — confirm.
    """
    # \s+ tolerates tabs / repeated spaces between the wnid and the name
    pattern = re.compile(r"n\d{8}\s+(.+)")

    with open(join(root, "README.txt"), "r", encoding="utf-8") as f:
        return [hit.group(1).strip() for hit in map(pattern.match, f) if hit]


In [6]:
def load_model_labels() -> list[str]:
    """Return the class name for every model output index.

    The names come from the ``imagenet-simple-labels`` JSON file. It is downloaded
    into the current working directory on first use and read from disk afterwards.

    Raises:
        requests.HTTPError: if the download returns an error status (nothing is cached then).
        requests.Timeout: if the server does not answer in time.
    """
    url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
    path = Path(basename(url))

    # Download only when not cached yet; fail loudly rather than caching an error page,
    # which would otherwise make every later json.load fail
    if not path.exists():
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        path.write_text(response.text, encoding="utf-8")

    # Load labels
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)

In [7]:
def merge_labels(
    imagenet_a_labels: Union[list[str], None] = None,
    model_labels: Union[list[str], None] = None,
) -> dict:
    """Map every ImageNet-A class name to its index in the dataset and in the model output.

    Args:
        imagenet_a_labels: ImageNet-A class names in dataset order; loaded with
            ``load_imagenet_a_labels()`` when omitted.
        model_labels: Class name for every model output index; loaded with
            ``load_model_labels()`` when omitted.

    Returns:
        ``{label: {"imagenet-a": dataset_index, "model": model_output_index}}``.
        If a name occurs several times in ``model_labels``, its first index is used.

    Raises:
        ValueError: if some ImageNet-A class name does not appear among the model labels.
    """
    if imagenet_a_labels is None:
        imagenet_a_labels = load_imagenet_a_labels()
    if model_labels is None:
        model_labels = load_model_labels()

    # name -> first model index (setdefault keeps the first occurrence, like list.index)
    first_model_index = {}
    for index, label in enumerate(model_labels):
        first_model_index.setdefault(label, index)

    missing = [label for label in imagenet_a_labels if label not in first_model_index]
    if missing:
        raise ValueError(f"ImageNet-A labels not found among the model labels: {missing}")

    return {
        label: {"imagenet-a": imagenet_a_index, "model": first_model_index[label]}
        for imagenet_a_index, label in enumerate(imagenet_a_labels)
    }

In [8]:
def load_model(model_name: str = "google/vit-base-patch16-384") -> ViTForImageClassification:
    """Fetch a pre-trained ViT image classifier and place it on ``DEVICE``.

    Args:
        model_name: Hugging Face hub id (or local path) handed to ``from_pretrained``.

    Returns:
        The loaded model, already moved to ``DEVICE``.
    """
    pretrained = ViTForImageClassification.from_pretrained(model_name)
    return pretrained.to(DEVICE)

In [9]:
def load_dataset(
    resize: bool = True,
    root: str = "./dataset/imagenet-a/",
    batch_size: int = 1,
    shuffle: bool = True,
    num_workers: int = 8,
) -> torch.utils.data.DataLoader:
    """Build the evaluation DataLoader over the ImageNet-A image folder.

    Images are turned into float tensors in [0, 1] with ``T.ToTensor``; no
    normalisation is applied here. Class indices come from ``ImageFolder``.

    Args:
        resize: Resize every image to ``SIZE`` first. With ``resize=False`` images keep
            their own size, so ``batch_size`` should stay 1 (default collation needs
            equally-sized tensors).
        root: Directory of the extracted ImageNet-A dataset.
        batch_size: Images per batch.
        shuffle: Shuffle the image order on every pass.
        num_workers: Worker processes used for loading.

    Returns:
        A DataLoader yielding ``(images, class_indices)`` batches.
    """
    steps = [T.Resize(SIZE)] if resize else []
    steps.append(T.ToTensor())

    dataset = torchvision.datasets.ImageFolder(root=root, transform=T.Compose(steps))
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )

In [10]:
def classify(
    model: "ViTForImageClassification",
    img: torch.Tensor,
    no_grad: bool = True,
    device: Union[torch.device, str, None] = None,
) -> torch.Tensor:
    """Run ``model`` on ``img`` and return its class probabilities.

    Args:
        model: Classifier whose output exposes ``.logits``; it is switched to eval mode.
        img: Input batch; moved to ``device`` before inference.
        no_grad: If True (default) inference runs under ``torch.no_grad()``;
            pass False to keep the autograd graph.
        device: Where to run the input; defaults to the notebook-wide ``DEVICE``.

    Returns:
        Softmax of the logits over the last dimension, with size-1 dimensions squeezed
        (a single image therefore yields a 1-D tensor of class probabilities).
    """
    img = img.to(DEVICE if device is None else device)

    # eval() switches off training-only behaviour such as dropout
    model.eval()

    with torch.no_grad() if no_grad else nullcontext():
        outputs = model(img)

    # Extract probabilities from model's output logits
    return F.softmax(outputs.logits, dim=-1).squeeze()

In [23]:
def elaborate_results(
    results: torch.Tensor,
    model_labels: Union[list[str], None] = None,
    imagenet_a_labels: Union[list[str], None] = None,
) -> Union[dict, list]:
    """Turn model probabilities into per-class records restricted to ImageNet-A classes.

    Args:
        results: Class probabilities, either one row (1-D, as ``classify`` returns for a
            single image) or a batch of rows (2-D).
        model_labels: Label name for every model output index; loaded with
            ``load_model_labels()`` when omitted.
        imagenet_a_labels: ImageNet-A class names; loaded with
            ``load_imagenet_a_labels()`` when omitted. Outputs whose label is not
            in this list are skipped.

    Returns:
        Per row, ``{"predicted": record, "results": {index: record}}`` where each record is
        ``{"index", "label", "probability"}`` and ``predicted`` is the kept record with the
        highest probability (the first one on ties, ``None`` if nothing is kept).
        A single row returns its dict directly; otherwise a list of dicts is returned.
    """
    if model_labels is None:
        model_labels = load_model_labels()
    if imagenet_a_labels is None:
        imagenet_a_labels = load_imagenet_a_labels()

    # Set membership instead of a list scan for each of the model's outputs
    kept_labels = set(imagenet_a_labels)

    # One device->host transfer per call instead of one .item() sync per class
    rows = results.detach().cpu()
    if rows.dim() == 1:
        rows = rows.unsqueeze(0)

    final_results = []
    for row in rows.tolist():
        records = {
            index: {"index": index, "label": model_labels[index], "probability": probability}
            for index, probability in enumerate(row)
            if model_labels[index] in kept_labels
        }
        # max() returns the first maximal record, matching a strict "<" running maximum
        predicted = max(records.values(), key=lambda record: record["probability"], default=None)
        final_results.append({"predicted": predicted, "results": records})

    return final_results[0] if len(final_results) == 1 else final_results


In [26]:
# Build everything the evaluation loop below needs (each object is created only once).
# The label mapping comes first: it only reads small label files, so a label mismatch
# surfaces before the model is loaded.
merged_labels = merge_labels()
data_loader = load_dataset()
model = load_model()

In [None]:
# Baseline: evaluate the pre-trained model on ImageNet-A. A hit is counted when the
# highest-probability ImageNet-A class (see elaborate_results) is the image's class.
# The loader yields un-normalised [0, 1] tensors, so inputs are normalised here with the
# statistics of the checkpoint's own image processor (same checkpoint as load_model()'s
# default). NOTE: the 18.37 % accuracy previously noted on this cell was measured
# without this normalisation.
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-384")
normalize = T.Normalize(mean=processor.image_mean, std=processor.image_std)

correct = 0

for index, (images, targets) in enumerate(data_loader):

    # Get model prediction (batch size 1 -> elaborate_results returns a single dict)
    predicted = elaborate_results(results=classify(model=model, img=normalize(images)))["predicted"]

    # Correct when the predicted model label maps to the image's ImageNet-A class index
    merged_label = merged_labels.get(predicted["label"])
    if merged_label is not None and merged_label["imagenet-a"] == targets.item():
        correct += 1

    seen = index + 1
    print(f"Image {seen} / {len(data_loader)} | Accuracy: {round((correct / seen) * 100, 2)}% ({correct} / {seen})")

# Final accuracy as a fraction (0.0 for an empty loader instead of ZeroDivisionError)
accuracy = correct / len(data_loader) if len(data_loader) else 0.0
print(f"Final accuracy: {accuracy * 100:.2f}% ({correct} / {len(data_loader)})")

In [None]:
# Test-time adaptation (episodic marginal-entropy minimisation).
# For every image: reset the model to its pre-trained weights, take one AdamW step that
# minimises the entropy of the prediction averaged over the plain view and
# N_AUGMENTATIONS random augmentations, then re-classify the plain view.
# Relies on load_model, load_dataset, classify, elaborate_results, merge_labels, SIZE and
# DEVICE from the cells above.
MODEL_NAME = "google/vit-base-patch16-384"
N_AUGMENTATIONS = 3
LEARNING_RATE = 5e-5

accuracy_before = 0
accuracy_after = 0

# Load model (only once) and keep an untouched copy to reset from before every image
model = load_model(MODEL_NAME)
original_model = deepcopy(model)

# Extract merged labels
merged_labels = merge_labels()

# Images come out as [0, 1] tensors at their original size
data_loader = load_dataset(resize=False)

# Normalise every view with the statistics of the checkpoint's own image processor, so the
# plain view and the augmented views share one input distribution
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
normalize = v2.Normalize(mean=processor.image_mean, std=processor.image_std)

# Plain view: resize to the model resolution, then normalise
plain_transform = T.Compose([T.Resize(SIZE), normalize])

# Augmented views
augment_transform = v2.Compose([
    v2.RandomResizedCrop(size=SIZE, antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    normalize,
])


def is_correct(predicted: dict, target: int) -> bool:
    """True when the predicted model label maps to the ImageNet-A class index ``target``."""
    merged_label = merged_labels.get(predicted["label"])
    return merged_label is not None and merged_label["imagenet-a"] == target


for index, (image, target) in enumerate(data_loader):
    target = target.item()

    # Reset to the pre-trained weights (in place) and start from a fresh optimizer state,
    # so every image is adapted independently
    model.load_state_dict(original_model.state_dict())
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    optimizer.zero_grad()

    # Plain view followed by the augmented views, stacked along the batch dimension
    plain_view = plain_transform(image)
    views = torch.cat([plain_view] + [augment_transform(image) for _ in range(N_AUGMENTATIONS)], dim=0)

    # Prediction on the plain view before adaptation
    predicted_before = elaborate_results(results=classify(model=model, img=plain_view))["predicted"]
    if is_correct(predicted_before, target):
        accuracy_before += 1

    # Entropy of the marginal (view-averaged) distribution, computed in log space:
    # log(mean_v p_v) = logsumexp_v(log p_v) - log(#views), which avoids log(0) = -inf
    log_probs = F.log_softmax(model(views.to(DEVICE)).logits, dim=-1)
    log_marginal = log_probs.logsumexp(dim=0) - torch.log(log_probs.new_tensor(log_probs.shape[0]))
    entropy = -(log_marginal.exp() * log_marginal).sum()
    entropy.backward()

    # Gradient step
    optimizer.step()

    # Prediction on the same plain view after adaptation
    predicted_after = elaborate_results(results=classify(model=model, img=plain_view))["predicted"]
    if is_correct(predicted_after, target):
        accuracy_after += 1

    seen = index + 1
    print(f"Image {seen} / {len(data_loader)} | Accuracy before: {round((accuracy_before / seen) * 100, 1)}% ({accuracy_before} / {seen}) | Accuracy after: {round((accuracy_after / seen) * 100, 1)}% ({accuracy_after} / {seen})", "\t --> Predicted:", predicted_before["label"], "/", predicted_after["label"])