### Example

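A minimal notebook example: build a pretrained ViT, load the ImageNet-1k class labels (downloaded once and cached locally), and print the top predictions for an image.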

In [12]:
import os
import requests
import torch
from json import load, dump 
from PIL import Image
from torch import inference_mode, softmax, squeeze
from torchvision.transforms import Compose, Resize, ToTensor
from model.vit import ViT

# Checkpoint names understood by `ViT.build`: every combination of model size
# (B, L) and patch size (16, 32), in the order B-16, B-32, L-16, L-32.
MODELS = [
    f"ViT-{size}-{patch}p-Imagenet1k"
    for size in ("B", "L")
    for patch in (16, 32)
]

# Build the first checkpoint in the list and call eval() on it before running
# predictions (presumably switching layers such as dropout to inference
# behaviour — confirm ViT is a torch.nn.Module).
model = ViT.build(MODELS[0])
model.eval()

# Preprocessing applied to every PIL image: resize to the size the model
# reports, then convert to a tensor.
transform = Compose([Resize(model.image_size), ToTensor()])
 
# The class labels are stored as JSON (despite the .txt extension) and cached
# next to the notebook so the download only happens on the first run.
filename = "labels.txt"
if os.path.exists(filename):
    # Explicit encoding: do not depend on the platform's default locale.
    with open(filename, "r", encoding="utf-8") as file:
        labels = load(file)
else:
    url = "https://huggingface.co/eric-hermosis/ViT-Imagenet1k/resolve/main/labels.txt"
    # requests applies no timeout by default; without one a stalled connection
    # would block this cell indefinitely.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    labels = response.json()
    with open(filename, "w", encoding="utf-8") as file:
        dump(labels, file)

# The JSON object is keyed by the stringified class index ("0", "1", ...);
# flatten it into a list so that labels[i] is the name of class i.
labels = [labels[str(index)] for index in range(len(labels))]


def predict_image(image_path, topk=3):
    """Classify one image and return its most probable labels.

    Args:
        image_path: Image file to classify (anything ``PIL.Image.open``
            accepts).
        topk: Number of highest-probability classes to return.

    Returns:
        A list of ``topk`` dicts, most probable first, each with the keys
        ``"index"`` (class index), ``"label"`` (the matching entry of the
        module-level ``labels`` list) and ``"probability"`` (a percentage
        rounded to two decimals).
    """
    # Open through a context manager so the file handle is released right after
    # decoding instead of whenever the Image object happens to be collected.
    # convert() returns an independent image that stays valid afterwards.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    input_tensor = transform(image)

    with inference_mode():
        # NOTE(review): no batch dimension is added before the forward pass and
        # dim 0 of the output is squeezed afterwards — presumably the model
        # accepts a single image and returns a (1, classes) tensor; confirm.
        output = squeeze(model(input_tensor), 0)
        # Kept under its own name so the `topk` argument is not overwritten.
        best = torch.topk(softmax(output, dim=-1), k=topk)

    return [
        {"index": index, "label": labels[index], "probability": round(probability * 100, 2)}
        for index, probability in zip(best.indices.tolist(), best.values.tolist())
    ]

In [13]:
# Classify the sample image and print each returned label with its confidence.
predictions = predict_image("image.png")
for prediction in predictions:
    print(prediction["label"], ", ", prediction["probability"], "%", sep=" ")

giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca ,  99.74 %
lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens ,  0.13 %
sloth bear, Melursus ursinus, Ursus ursinus ,  0.02 %
