In [1]:
# Init workspace
!mkdir dataset

# Download dataset and extract it
!gdown 111HiEoEvZDdg1Y2EefI6n5dA_p4sMV4V
!mv imagenet-a.tar ./dataset
!tar -xf ./dataset/imagenet-a.tar
!mv imagenet-a ./dataset

# Cleanup
!rm ./dataset/imagenet-a.tar

mkdir: cannot create directory ‘dataset’: File exists
Downloading...
From (original): https://drive.google.com/uc?id=111HiEoEvZDdg1Y2EefI6n5dA_p4sMV4V
From (redirected): https://drive.google.com/uc?id=111HiEoEvZDdg1Y2EefI6n5dA_p4sMV4V&confirm=t&uuid=08d06f71-a2aa-4f84-88f3-68f4feaff201
To: /content/imagenet-a.tar
100% 688M/688M [00:04<00:00, 157MB/s]
mv: cannot move 'imagenet-a' to './dataset/imagenet-a': Directory not empty


In [None]:
# Upgrading pytorch for the latest augmentation functions
#!pip install --upgrade torch torchvision torchaudio

In [15]:
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
import torchvision
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import v2

import json
from os.path import basename, join
from pathlib import Path
import requests

import re

In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SIZE = (384, 384)

In [5]:
def load_imagenet_a_labels() -> list:

    imagenet_a = "./dataset/imagenet-a"

    with open(join(imagenet_a, "README.txt"), "r") as f:
        lines = f.readlines()

    pattern = re.compile(r"n\d{8}\s(.+)")

    labels = []

    for index, label in enumerate([line.strip() for line in lines if pattern.match(line)]):
        label = label.strip()
        label = label.split(" ")
        label = " ".join(label[1:])

        labels.append(label)

    return labels


In [6]:
def load_model_labels() -> list[str]:

    url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
    path = Path(basename(url))

    # Check if labels file already exists
    if not path.exists():
        response = requests.get(url)
        path.write_text(response.text)

    # Load labels
    with open(path, "r") as f:
        labels = json.load(f)

    return labels

In [7]:
def merge_labels() -> dict:

    imagenet_a_labels = load_imagenet_a_labels()
    model_labels = load_model_labels()

    labels = {}

    for imagenet_a_index, item in enumerate(imagenet_a_labels):

        model_index = model_labels.index(item)

        labels[item] = {
            "imagenet-a": imagenet_a_index,
            "model": model_index
        }

    return labels

In [8]:
def load_model(model_name: str = "google/vit-base-patch16-384") -> ViTForImageClassification:

    # Load the pre-trained model
    model = ViTForImageClassification.from_pretrained(model_name).to(DEVICE)

    return model

In [9]:
def load_dataset(resize: bool = True):

    imagenet_a = "./dataset/imagenet-a/"

    # Prepare data transformations for the train loader
    transforms = [] if not resize else [T.Resize(SIZE)]
    transforms.append(T.ToTensor())
    transform = T.Compose(transforms)

    # Load data
    imagenet_a_dataset = torchvision.datasets.ImageFolder(root=imagenet_a, transform=transform)
    return torch.utils.data.DataLoader(imagenet_a_dataset, 1, shuffle=True, num_workers=8)

In [10]:
def classify_image(model: ViTForImageClassification, img: torch.Tensor) -> tuple[dict, dict]:

    # Use GPU if available
    img = img.to(DEVICE)

    # Perform inference
    model.eval()

    with torch.no_grad():
        outputs = model(img)

    # Extract probabilities from model's output logits
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze()

    labels = load_model_labels()
    results = {}

    for index, probability in enumerate(probabilities):
        results[index] = {
            "index": index,
            "label": labels[index],
            "probability": probability.item()
        }

    predicted = probabilities.argmax(-1).item()
    predicted = {
        "index": predicted,
        "label": labels[predicted],
        "probability": probabilities[predicted].item()
    }

    return predicted, results

In [11]:
# Load model (only once)
model = load_model()

# Load data (only once)
data_loader = load_dataset()

# Extract merged labels
merged_labels = merge_labels()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [12]:
# Evaluate the model (Accuracy: 18.37 %)
# accuracy = 0

# for index, img in enumerate(data_loader):

#     # Get model prediction
#     predicted, results = classify_image(model=model, img=img[0])

#     # Check if the predicted label exists inside the dataset labels
#     if predicted["label"] in merged_labels:

#         merged_label = merged_labels[predicted["label"]]

#         # Check if the prediction was correct
#         if merged_label["imagenet-a"] == img[1].item():
#             accuracy = accuracy + 1

#     print(f"Image {index+1} / {len(data_loader)} | Accuracy: {round((accuracy / (index + 1)) * 100, 2)}% ({accuracy} / {index + 1})")

# accuracy = accuracy / len(data_loader)

In [None]:
# MEMO

data_loader = load_dataset(resize=False)

add_augmentation = lambda transformations: T.Compose(transformations)

augmentations = [
    add_augmentation([T.RandomCrop((100, 100)), T.Resize(SIZE)]),
    add_augmentation([v2.RandomChannelPermutation(), T.Resize(SIZE)])
]

augmentations = [
    T.CenterCrop((192, 192)),
    v2.RandomRotation((0, 360)),
    v2.RandomChannelPermutation(),
    v2.RandomGrayscale(),
    v2.RandomAutocontrast(),
    v2.RandomPerspective()
]

for index, img in enumerate(data_loader):

    augmentation = T.Compose([
        v2.RandomOrder(augmentations),
        T.Resize(SIZE)
    ])

    processed_img = img[0].detach().clone()
    processed_img = augmentation(processed_img)

    torchvision.utils.save_image(processed_img, "test.png")
    # Get model prediction
    predicted, results = classify_image(model=model, img=processed_img)

    input()



  self.pid = os.fork()





In [22]:
c = v2.RandomChoice([v2.RandomEqualize])
c

RandomChoice(transforms=[<class 'torchvision.transforms.v2._color.RandomEqualize'>], p=[1.0])

In [17]:
dir(v2)

['AugMix',
 'AutoAugment',
 'AutoAugmentPolicy',
 'CenterCrop',
 'ClampBoundingBoxes',
 'ColorJitter',
 'Compose',
 'ConvertBoundingBoxFormat',
 'ConvertImageDtype',
 'CutMix',
 'ElasticTransform',
 'FiveCrop',
 'GaussianBlur',
 'Grayscale',
 'Identity',
 'InterpolationMode',
 'JPEG',
 'Lambda',
 'LinearTransformation',
 'MixUp',
 'Normalize',
 'PILToTensor',
 'Pad',
 'RGB',
 'RandAugment',
 'RandomAdjustSharpness',
 'RandomAffine',
 'RandomApply',
 'RandomAutocontrast',
 'RandomChannelPermutation',
 'RandomChoice',
 'RandomCrop',
 'RandomEqualize',
 'RandomErasing',
 'RandomGrayscale',
 'RandomHorizontalFlip',
 'RandomInvert',
 'RandomIoUCrop',
 'RandomOrder',
 'RandomPerspective',
 'RandomPhotometricDistort',
 'RandomPosterize',
 'RandomResize',
 'RandomResizedCrop',
 'RandomRotation',
 'RandomShortestSize',
 'RandomSolarize',
 'RandomVerticalFlip',
 'RandomZoomOut',
 'Resize',
 'SanitizeBoundingBoxes',
 'ScaleJitter',
 'TenCrop',
 'ToDtype',
 'ToImage',
 'ToPILImage',
 'ToPureTensor