In [None]:
# Init workspace: make sure ./dataset exists and drop only a stale ImageNet-A copy,
# so this cell can be re-run on a fresh workspace without errors and without
# deleting other datasets stored under ./dataset
!mkdir -p dataset
!rm -rf ./dataset/imagenet-a ./dataset/imagenet-a.tar

# Download dataset and extract it directly into ./dataset
!gdown 111HiEoEvZDdg1Y2EefI6n5dA_p4sMV4V
!mv imagenet-a.tar ./dataset
!tar -xf ./dataset/imagenet-a.tar -C ./dataset

# Cleanup
!rm -f ./dataset/imagenet-a.tar

# (optional) Upgrading pytorch for the latest augmentation functions
#!pip install --upgrade torch torchvision torchaudio

In [None]:
# Init workspace: make sure ./dataset exists and drop only a stale ImageNetV2 copy,
# so this cell can be re-run on a fresh workspace without errors and without
# deleting other datasets stored under ./dataset
!mkdir -p dataset
!rm -rf ./dataset/imagenetv2-matched-frequency-format-val ./dataset/imagenetv2-matched-frequency.tar.gz

# Download dataset and extract it directly into ./dataset
!gdown 1WKQGHjHUkIwZT0P2TpU9h-lY-6CnrsDd
!mv imagenetv2-matched-frequency.tar.gz ./dataset
!tar -xf ./dataset/imagenetv2-matched-frequency.tar.gz -C ./dataset

# Cleanup
!rm -f ./dataset/imagenetv2-matched-frequency.tar.gz

# (optional) Upgrading pytorch for the latest augmentation functions
#!pip install --upgrade torch torchvision torchaudio

In [150]:
import json
import os
from os import listdir
from os.path import basename, isfile, join
from pathlib import Path
import requests
from contextlib import nullcontext
from copy import deepcopy
from typing import Union
from PIL import Image

import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from torchvision.transforms import v2
from transformers import ViTForImageClassification, ViTImageProcessor

import warnings
warnings.filterwarnings("ignore")

In [138]:
# Compute device: the CUDA GPU when PyTorch reports one, the CPU otherwise
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Target image size passed to the resize / crop transforms
SIZE = (224, 224)

In [140]:
def load_imagenet_v2_labels(root: str = "./dataset/imagenetv2-matched-frequency-format-val") -> list[int]:
    """Return the sorted integer class labels found in an ImageNetV2 folder.

    Each entry of ``root`` that is not a regular file and whose name consists
    only of decimal digits is treated as one class folder; its name is the label.

    Args:
        root: Directory holding one sub-folder per class. Defaults to the
            location the dataset is extracted to by the download cell.

    Returns:
        The class labels as integers, in ascending order.

    Raises:
        FileNotFoundError: If ``root`` does not exist.
    """
    labels = [
        int(entry.name)
        for entry in Path(root).iterdir()
        # Skip regular files and non-numeric folders (e.g. ``.ipynb_checkpoints``),
        # which would otherwise make ``int()`` raise ValueError
        if entry.name.isdecimal() and not entry.is_file()
    ]
    labels.sort()

    return labels

In [141]:
def load_model_labels() -> list[str]:
    """Return the human-readable ImageNet labels, downloading them on first use.

    The JSON file is cached in the current working directory under the file
    name taken from the URL; later calls read the cached copy.

    Returns:
        The decoded content of the labels JSON file.

    Raises:
        requests.HTTPError: If the download answers with an error status.
            Nothing is written to disk in that case, so a failed download
            cannot leave a broken cache file behind.
    """
    url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
    path = Path(basename(url))

    # Download only when there is no cached copy yet
    if not path.exists():
        # The timeout keeps a stalled connection from hanging the notebook
        response = requests.get(url, timeout=30)
        # Refuse to cache an error page as if it were the labels file
        response.raise_for_status()
        path.write_text(response.text, encoding="utf-8")

    # Load labels
    with open(path, "r", encoding="utf-8") as f:
        labels = json.load(f)

    return labels

In [142]:
def load_model(model_name: str = "google/vit-base-patch16-224") -> ViTForImageClassification:
    """Load a pre-trained ViT image classifier and move it to ``DEVICE``.

    Args:
        model_name: Checkpoint identifier handed to ``from_pretrained``.

    Returns:
        The loaded model, placed on ``DEVICE``.
    """
    pretrained = ViTForImageClassification.from_pretrained(model_name)
    return pretrained.to(DEVICE)

In [151]:
class ImageNetV2(torch.utils.data.Dataset):
    """Map-style dataset over a folder laid out as ``<img_dir>/<class index>/<image file>``.

    The folder is scanned once at construction time. Class folders are visited
    in ascending numeric order and the files inside each folder in sorted name
    order, so every image is served exactly once and the index -> sample
    mapping does not depend on the order ``listdir`` happens to return.

    Args:
        annotations_file: Unused; accepted only to keep the constructor
            signature compatible with existing callers.
        img_dir: Root directory containing one numerically named folder per class.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        target_transform: Optional callable applied to the integer label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

        # Only numerically named directories are classes; this skips stray
        # entries such as ``.ipynb_checkpoints``
        class_names = [
            name for name in listdir(img_dir)
            if name.isdecimal() and not isfile(join(img_dir, name))
        ]
        class_names.sort(key=int)

        # (image path, label) pairs, one per image file found on disk
        self.samples = []
        for class_name in class_names:
            class_dir = join(img_dir, class_name)
            for file_name in sorted(listdir(class_dir)):
                file_path = join(class_dir, file_name)
                if isfile(file_path):
                    self.samples.append((file_path, int(class_name)))

        # Label of every sample, aligned with ``self.samples``
        self.img_labels = [label for _, label in self.samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        img_path, label = self.samples[idx]

        # ``convert`` returns a fully loaded copy, so the file can be closed right away
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label

In [144]:
def load_dataset(resize: bool = True) -> torch.utils.data.dataloader.DataLoader:
    """Build a shuffled, batch-size-1 data loader over the ImageNetV2 folder.

    Args:
        resize: When true, images are resized to ``SIZE`` before being
            converted to tensors; otherwise they keep their original size.

    Returns:
        A ``DataLoader`` over ``ImageNetV2`` using 8 worker processes.
    """
    dataset_root = "./dataset/imagenetv2-matched-frequency-format-val"

    # Optional resize first, then conversion of the PIL image to a tensor
    steps = [T.Resize(SIZE), T.ToTensor()] if resize else [T.ToTensor()]
    pipeline = T.Compose(steps)

    # Wrap the folder in the dataset and hand it to the loader
    dataset = ImageNetV2(annotations_file=[], img_dir=dataset_root, transform=pipeline)
    return torch.utils.data.DataLoader(dataset, 1, shuffle=True, num_workers=8)

In [145]:
def classify(model: ViTForImageClassification, img: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
    """Run the model on ``img`` and return the softmax class probabilities.

    Side effect: the model is switched to evaluation mode and left in it.

    Args:
        model: Classifier whose output exposes a ``logits`` attribute.
        img: Input tensor; it is moved to ``DEVICE`` before the forward pass.
        no_grad: When true the forward pass runs under ``torch.no_grad()``;
            pass ``False`` to keep the autograd graph.

    Returns:
        The softmax over the last dimension of the logits with all
        singleton dimensions squeezed out.
    """
    # Use GPU if available
    img = img.to(DEVICE)

    # Perform inference
    model.eval()

    with torch.no_grad() if no_grad else nullcontext():
        outputs = model(img)

    # Extract probabilities from model's output logits
    return F.softmax(outputs.logits, dim=-1).squeeze()

In [146]:
def elaborate_results(results: torch.Tensor) -> Union[dict, list]:
    """Turn class probabilities into labelled, per-class result records.

    Args:
        results: Either one probability vector (1-D) or a batch of them
            (first dimension = batch).

    Returns:
        For each probability vector a dict with two keys:
        ``"results"`` maps every class index to
        ``{"index", "label", "probability"}`` and ``"predicted"`` is the entry
        with the highest probability (the first one on ties).
        A single vector yields that dict directly; several vectors yield a
        list of dicts; an empty batch yields an empty list.
    """
    # Load model's labels
    model_labels = load_model_labels()

    # A single probability vector is handled as a batch of one
    batch = [results] if len(results.shape) == 1 else results

    final_results = []

    for result in batch:
        # One device-to-host transfer per row instead of one ``.item()`` call per class
        probabilities = result.tolist()

        per_class = {
            index: {
                "index": index,
                "label": model_labels[index],
                "probability": probability
            }
            for index, probability in enumerate(probabilities)
        }

        # ``max`` returns the first maximal entry, i.e. ties go to the lowest index
        predicted = (
            max(per_class.values(), key=lambda entry: entry["probability"])
            if per_class else None
        )

        final_results.append({"predicted": predicted, "results": per_class})

    # Unwrap only when there is exactly one result
    return final_results[0] if len(final_results) == 1 else final_results


In [152]:
# # Load model (only once)
# model = load_model()

# # Load data (only once)
# data_loader = load_dataset()

# # Evaluate the model (Accuracy: 18.37 %)
# accuracy = 0

# for index, img in enumerate(data_loader):

#     # Get model prediction
#     results = classify(model=model, img=img[0])
#     results = elaborate_results(results=results)
#     predicted, results = results["predicted"], results["results"]

#     if img[1].item() == predicted["index"]:
#         accuracy = accuracy + 1

#     print(f"Image {index+1} / {len(data_loader)} | Accuracy: {round((accuracy / (index + 1)) * 100, 2)}% ({accuracy} / {index + 1})")

# accuracy = accuracy / len(data_loader)

In [None]:
# Per-image test-time adaptation: for every sample the model takes one gradient
# step that minimizes the entropy of the prediction averaged over the resized
# image and three random augmentations; accuracy on the resized image is
# tracked before and after that step.
# Assumes load_model, load_dataset, classify and elaborate_results are defined in earlier cells.
MODEL_NAME = "google/vit-base-patch16-224"

accuracy_before = 0
accuracy_after = 0

# Load model (only once)
model = load_model(MODEL_NAME)

# Frozen reference copy: it is never trained, so its state dict is the
# pre-trained state the working model is reset to for every image
original_model = deepcopy(model)
initial_state = original_model.state_dict()

data_loader = load_dataset(resize=False)

# Normalize every view with the statistics stored in the checkpoint's own
# image processor, so the plain and the augmented views are preprocessed identically
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
image_mean, image_std = processor.image_mean, processor.image_std

transforms = v2.Compose([
    v2.RandomResizedCrop(size=SIZE, antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=image_mean, std=image_std),
])

resize_transformation = T.Compose([
    T.Resize(SIZE),
    T.Normalize(mean=image_mean, std=image_std),
])

for index, img in enumerate(data_loader):
    # Restore the pre-trained weights in place (cheaper than deep-copying the model)
    model.load_state_dict(initial_state)
    # A fresh optimizer per image also discards AdamW's running statistics
    optimizer = optim.AdamW(model.parameters(), lr=5e-5)
    # Clear gradients before computing new ones
    optimizer.zero_grad()

    image, target = img[0], img[1].item()

    # Plain view: resized to SIZE and normalized
    img1 = resize_transformation(image)
    # Randomly augmented views
    img2 = transforms(image)
    img3 = transforms(image)
    img4 = transforms(image)

    # Stack all views along the batch dimension
    imgs = [img1, img2, img3, img4]
    batch = torch.cat(imgs, dim=0)

    # Classify the plain view before adaptation
    results = classify(model=model, img=img1)
    results = elaborate_results(results=results)
    predicted, results = results["predicted"], results["results"]

    # Update the accuracy of the un-adapted model
    if target == predicted["index"]:
        accuracy_before = accuracy_before + 1

    predicted_before = predicted["label"]

    # Forward pass on all views with gradients enabled
    # (classify() above left the model in eval mode)
    output = model(batch.to(DEVICE))

    # Per-view class probabilities
    probabilities = F.softmax(output.logits, dim=-1)

    # Entropy of the prediction averaged over the views; the clamp keeps
    # log(0) from producing NaN losses and gradients
    marginal = torch.mean(probabilities, dim=0)
    entropy = -torch.sum(marginal * torch.log(marginal.clamp_min(1e-12)))
    entropy.backward()

    # Gradient step
    optimizer.step()

    # Classify the plain view again after adaptation
    results = classify(model=model, img=img1)
    results = elaborate_results(results=results)
    predicted, results = results["predicted"], results["results"]

    # Update the accuracy of the adapted model
    if target == predicted["index"]:
        accuracy_after = accuracy_after + 1

    label1 = f"Image {index + 1} / {len(data_loader)}"
    label2 = f"Accuracy before: {round((accuracy_before / (index + 1)) * 100, 1)}% ({accuracy_before} / {index + 1})"
    label3 = f"Accuracy after: {round((accuracy_after / (index + 1)) * 100, 1)}% ({accuracy_after} / {index + 1})"
    label4 = f"Diff: {round((accuracy_after / (index + 1)) * 100 - (accuracy_before / (index + 1)) * 100, 2)}%"

    print(f"{label1} | {label2} | {label3} | {label4}")