In [1]:
import os

if not os.path.exists("weights"):
  os.mkdir("weights")
  !wget https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-B_16.npz -O weights/ViT-B_16.npz
  !wget https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz -O weights/ViT-L_16.npz

In [2]:
import timm
import torch as th
from torchvision import datasets
from torchvision import transforms as T
from tqdm.auto import tqdm

device = th.device("cuda" if th.cuda.is_available() else "cpu")

In [3]:
def get_accuracy(model, dataset, batch_size):
  dataloader = th.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

  model.eval()
  good = 0
  total = 0
  tqdm_loader = tqdm(dataloader, desc="Evaluating", unit="batches")
  for images, labels in tqdm_loader:
    images, labels = images.to(device), labels.to(device)
    logits = model(images)

    good  += th.argmax(logits, dim=1).eq(labels).sum().item()
    total += labels.size(0)

    tqdm_loader.set_description_str(f"Accuracy: {good / total:.2%}")

  return good / total

In [4]:
transforms = T.Compose([
    T.ToTensor(),
    T.Resize((384, 384)),
])

dataset = datasets.ImageNet(root="ImageNet", split="val", transform=transforms)

In [5]:
ViT = timm.create_model("vit_base_patch16_384")
timm.models.load_checkpoint(ViT, "weights/ViT-B_16.npz")
ViT.to(device)

accuracy_b_16 = get_accuracy(ViT, dataset, batch_size=4)
print(f"ViT-B_16 w/ Imagenet21k accuracy: {accuracy_b_16:.2%}")

Evaluating:   0%|          | 0/12500 [00:00<?, ?batches/s]

ViT-B_16 w/ Imagenet21k accuracy: 81.58%


In [6]:
ViT = timm.create_model("vit_large_patch16_384")
timm.models.load_checkpoint(ViT, "weights/ViT-L_16.npz")
ViT.to(device)

accuracy_l_16 = get_accuracy(ViT, dataset, batch_size=1)
print(f"ViT-L_16 w/ Imagenet21k accuracy: {accuracy_l_16:.2%}")

Evaluating:   0%|          | 0/50000 [00:00<?, ?batches/s]

ViT-L_16 w/ Imagenet21k accuracy: 82.73%
