In [42]:
from torchvision.models import vit_b_16
from pathlib import Path
import torch
from transformers import AutoConfig, AutoModelForImageClassification
from torchvision import transforms
import PIL

root = Path().resolve().parent.parent
path = root / Path("data/output/classification/version_0/checkpoints/variant_b16_224.ckpt")

# Initialize with pretrained weights
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=100,
    ignore_mismatched_sizes=True,
    image_size=224,
)

# Load checkpoint weights
ckpt = torch.load(path, map_location="cpu")["state_dict"]

# Remove prefix from key names
new_state_dict = {}
for k, v in ckpt.items():
    if k.startswith("net"):
        k = k.replace("net" + ".", "")
        new_state_dict[k] = v

model.load_state_dict(new_state_dict, strict=True)

img_path = root / Path("data/input/fgvc-aircraft-2013b/data/images/1617066.jpg")
image = PIL.Image.open(img_path).convert("RGB")

transforms_test = transforms.Compose(
    [
        transforms.Resize(
            (224, 224),
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

image = transforms_test(image)

text_path = root / Path("data/input/fgvc-aircraft-2013b/data/variants.txt")
file = open(text_path, "r")
variants = file.readlines()
variants = [variant[:-1] for variant in variants]
file.close()

model.eval()
with torch.no_grad():
  outputs = model(image.unsqueeze(0))
logits = outputs.logits

prediction = logits.argmax(-1)
print(variants[prediction.item()])

top5_prob, top5_classes = torch.topk(logits, 5)
top5_prob = top5_prob[-1].numpy().tolist()
for i, x in enumerate(top5_classes[-1].numpy().tolist()):
    print(f"{variants[x]} ({top5_prob[i]})")

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


A380
A380 (12.680869102478027)
747-300 (11.449588775634766)
747-200 (9.584214210510254)
747-400 (7.507030487060547)
A310 (6.2847466468811035)


In [44]:
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
image_mean = processor.image_mean
image_std = processor.image_std
size = processor.size["height"]

print(image_mean, image_std, size)

[0.5, 0.5, 0.5] [0.5, 0.5, 0.5] 224
