In [1]:
import os
import sys

# Make the project root importable so that `src.*` modules resolve.
# The root is taken as the parent of the kernel's current working directory,
# so this assumes the notebook is run from a direct subdirectory of the
# project (e.g. notebooks/) — TODO confirm / make configurable.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Guard keeps the cell idempotent: re-running it must not append a duplicate
# entry to sys.path each time.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import torch
from transformers import ViTForImageClassification

from src.make_dataset import vit_transform
from src.process import open

  from .autonotebook import tqdm as notebook_tqdm


# Load Icon Classifier

In [3]:
# Fine-tuned ViT icon classifier saved by the training run.
# `.eval()` returns the module itself; chaining it makes inference mode
# (dropout disabled) explicit rather than relying on the loader's default.
model = ViTForImageClassification.from_pretrained(
    pretrained_model_name_or_path="../models/results/final_model",
).eval()

# Inference

In [4]:
def tensors_from_imgs(img_paths) -> torch.Tensor:
    """Load images and batch them into one tensor for the ViT classifier.

    Each path is opened with ``open`` and preprocessed with ``vit_transform``.
    Note that in this notebook ``open`` is ``src.process.open``, which shadows
    the builtin. The per-image tensors are then stacked along a new leading
    batch dimension.

    Args:
        img_paths: Iterable of image file paths. All images must produce
            tensors of the same shape after ``vit_transform``, because
            ``torch.stack`` requires equal shapes.

    Returns:
        A tensor whose first dimension is the number of paths. Presumably
        ``(N, C, H, W)`` if ``vit_transform`` yields ``(C, H, W)`` — TODO confirm.

    Raises:
        ValueError: If ``img_paths`` is empty. Without this check,
            ``torch.stack`` would fail with a less descriptive error.
    """
    img_tensors = [vit_transform(open(img_path)) for img_path in img_paths]
    if not img_tensors:
        raise ValueError("tensors_from_imgs() requires at least one image path")
    return torch.stack(img_tensors)

In [5]:
# Example inputs: ISO 7000 icons 0246 and 0082, plus 0098, which is not in the
# 'selected' set. 0098 is presumably a class the classifier was not trained on,
# so it serves as an out-of-distribution probe (TODO confirm).
img_paths = [
    "../data/raw/wikimedia/png/480px-ISO_7000_-_Ref-No_0246.svg.png",
    "../data/raw/wikimedia/png/480px-ISO_7000_-_Ref-No_0082.svg.png",
    "../data/raw/wikimedia/png/480px-ISO_7000_-_Ref-No_0098.svg.png",
]
img_tensors = tensors_from_imgs(img_paths=img_paths)

In [6]:
# Forward pass only. No backward pass follows, so skip building the autograd graph.
with torch.no_grad():
    outputs = model(pixel_values=img_tensors)
    # Raw, unnormalised class scores for each image in the batch.
    logits = outputs["logits"]

In [7]:
# Softmax over the last (class) dimension turns each row of logits into
# a probability distribution.
probs = torch.softmax(logits, dim=-1)

# Top-1 class index per image, taken from the logits.
predicted_class_idx = logits.argmax(dim=-1)
predicted_idx_list = predicted_class_idx.tolist()

# Map indices to human-readable labels via the model config.
predicted_class_labels = [model.config.id2label[idx] for idx in predicted_idx_list]

# gather selects probs[i, predicted_class_idx[i]] for every row in one
# vectorized op, replacing a Python loop of per-element .item() calls.
predicted_class_probs = (
    probs.gather(-1, predicted_class_idx.unsqueeze(-1)).squeeze(-1).tolist()
)

In [8]:
# Predicted label (from model.config.id2label) for each image, in img_paths order.
predicted_class_labels

['0246', '0082', '0087']

In [9]:
# Softmax probability assigned to each image's predicted class, in img_paths order.
# A low value means the model is not confident in that prediction.
predicted_class_probs

[0.9996232986450195, 0.9991819262504578, 0.21028754115104675]