In [1]:
import json
from pathlib import Path
from urllib.request import urlretrieve

import numpy as np
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score, classification_report, f1_score
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import ViTForImageClassification, ViTImageProcessor

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Load only the "valid" split of the zh-plus/tiny-imagenet dataset.
dataset = load_dataset(path="zh-plus/tiny-imagenet", split="valid")

## ViT classification performance

In [4]:
# One checkpoint id is shared by the classifier weights and the image processor.
model_id = "google/vit-base-patch16-224"

# Load the pretrained classifier, move it to the selected device and switch it
# to evaluation mode (both `.to()` and `.eval()` return the module itself).
model = ViTForImageClassification.from_pretrained(model_id).to(device).eval()

# Preprocessing configuration that belongs to the same checkpoint.
processor = ViTImageProcessor.from_pretrained(model_id)

In [5]:
def preprocess_function(examples):
    """Turn a batch of dataset rows into model inputs.

    Args:
        examples: batch mapping with an "image" entry (iterable of objects
            exposing ``.convert`` -- presumably PIL images) and a "label"
            entry holding the matching class ids.

    Returns:
        The output of the module-level ``processor`` (called with
        ``return_tensors="pt"``) with the labels stored under "labels".
    """
    rgb_images = []
    for picture in examples["image"]:
        # Force three channels; the conversion is applied to every image.
        rgb_images.append(picture.convert("RGB"))

    batch = processor(rgb_images, return_tensors="pt")
    batch["labels"] = examples["label"]
    return batch

# Preprocess the whole split in batches of 32 rows and drop the original
# columns so that only the fields returned by `preprocess_function` remain.
tokenized_dataset = dataset.map(
    function=preprocess_function,
    remove_columns=dataset.column_names,
    batched=True,
    batch_size=32,
)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [6]:
# Have the dataset return torch tensors, then iterate it in its stored order.
tokenized_dataset.set_format(type="torch")
test_dataloader = DataLoader(
    dataset=tokenized_dataset,
    shuffle=False,
    batch_size=64,
)

In [7]:
# Per-sample predicted class indices and dataset labels, in dataloader order.
all_preds = []
all_labels = []

# Pure inference: no gradients are needed anywhere in the loop.
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        batch_logits = model(batch['pixel_values'].to(device)).logits

        # The highest-scoring class along the last axis is the prediction.
        all_preds.extend(batch_logits.argmax(dim=-1).cpu().numpy())
        # Labels were never moved to the device, so they convert directly.
        all_labels.extend(batch['labels'].numpy())

  0%|          | 0/157 [00:00<?, ?it/s]

In [73]:
# Tiny ImageNet has 200 classes while the model predicts the 1000 ImageNet
# classes, so Tiny ImageNet label ids must be translated into ImageNet class
# indices before predictions and labels can be compared.
IMAGENET_INDEX_URL = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
IMAGENET_INDEX_PATH = Path("imagenet_class_index.json")

# Fetch the index file on first use so the notebook survives
# "Restart Kernel -> Run All" on a machine that does not have it yet.
if not IMAGENET_INDEX_PATH.exists():
    urlretrieve(IMAGENET_INDEX_URL, IMAGENET_INDEX_PATH)

with IMAGENET_INDEX_PATH.open("r", encoding="utf-8") as f:
    imagenet_index_data = json.load(f)

# Keys are class indices stored as strings; the first element of each value
# is the WordNet id (wnid). Invert that into wnid -> integer class index.
wnid_to_idx = {v[0]: int(k) for k, v in imagenet_index_data.items()}

# Position i in this list is the wnid of Tiny ImageNet label id i.
tiny_imagenet_wnids = dataset.features['label'].names

# Entry i is the ImageNet index for Tiny ImageNet label i, or -1 when the wnid
# has no entry in the index file (reported below so the gap is visible).
tiny_to_vit_idx = []
for wnid in tiny_imagenet_wnids:
    vit_idx = wnid_to_idx.get(wnid, -1)
    if vit_idx == -1:
        print(wnid, 'not in map')
    tiny_to_vit_idx.append(vit_idx)

# NOTE: `mapping` is rebound further down for the MobileNetV2 label space, so
# re-run this cell before re-running the ViT metrics cell.
mapping = torch.tensor(tiny_to_vit_idx)
mapping

n02666347 not in map
n03373237 not in map
n04465666 not in map
n04598010 not in map
n07056680 not in map
n07646821 not in map
n07647870 not in map
n07657664 not in map
n07975909 not in map
n08496334 not in map
n08620881 not in map
n08742578 not in map
n12520864 not in map
n13001041 not in map
n13652335 not in map
n13652994 not in map
n13719102 not in map
n14991210 not in map


tensor([  1,  25,  30,  32,  50,  61,  69,  71,  75,  76,  79, 105, 107, 109,
        113, 115, 122, 123, 128, 145, 146, 149, 187, 207, 208, 235, 267, 281,
        283, 285, 286, 291, 294, 301, 311, 313, 314, 315, 319, 323, 325, 329,
        338, 341, 345, 347, 349, 353, 354, 365, 367, 372, 386, 387,  -1, 400,
        406, 414, 421, 424, 425, 427, 430, 435, 436, 437, 438, 440, 445, 447,
        448, 457, 458, 463, 466, 467, 470, 471, 474, 480, 485, 492, 496, 500,
        508, 509, 511, 517, 525, 526, 532, 543, 557,  -1, 562, 565, 567, 568,
        570, 573, 576, 604, 605, 612, 614, 619, 621, 625, 627, 635, 645, 652,
        655, 675, 678, 682, 683, 687, 704, 707, 716, 720, 731, 734, 735, 737,
        739, 744, 747, 760, 761, 765, 768, 774, 779, 781, 786, 801, 806, 808,
        811, 815, 817, 821, 826, 837, 839, 842, 845, 849, 850, 853, 862,  -1,
        873, 874, 877, 879, 887, 888, 890, 899, 900, 909,  -1, 917,  -1, 924,
        928, 929,  -1,  -1,  -1, 932, 935, 938, 945, 951, 954, 9

In [74]:
# Translate Tiny ImageNet label ids into the model's class indices via `mapping`.
y_true_tiny = torch.tensor(all_labels)
y_true_full = mapping[y_true_tiny].numpy()

# Samples whose class has no mapped index carry -1; score only the others.
is_mapped = np.not_equal(y_true_full, -1)
y_true_valid = y_true_full[is_mapped]
y_pred_valid = np.asarray(all_preds)[is_mapped]

accuracy = accuracy_score(y_true=y_true_valid, y_pred=y_pred_valid)
print(f'accuracy: {accuracy}')

f1 = f1_score(y_true=y_true_valid, y_pred=y_pred_valid, average='weighted')
print(f"weighted f1: {f1}")

accuracy: 0.6291208791208791
weighted f1: 0.7241718171110046


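These metrics seem low, but that is expected: we are evaluating a model that predicts the 1000 ImageNet classes against the 200-class Tiny ImageNet validation set, and the 18 Tiny ImageNet classes with no entry in the ImageNet class index are excluded from the scores.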

## MobileNetV2 classification performance

In [45]:
from transformers import AutoModelForImageClassification, AutoImageProcessor

In [56]:
# One checkpoint id is shared by the classifier weights and the image processor.
# NOTE: this rebinds `model_id`, `model` and `processor` from the ViT section.
model_id = "google/mobilenet_v2_1.0_224"

# Load the pretrained classifier, move it to the selected device and switch it
# to evaluation mode (both `.to()` and `.eval()` return the module itself).
model = AutoModelForImageClassification.from_pretrained(model_id).to(device).eval()

# Preprocessing configuration that belongs to the same checkpoint.
processor = AutoImageProcessor.from_pretrained(model_id)

preprocessor_config.json:   0%|          | 0.00/406 [00:00<?, ?B/s]

In [57]:
def preprocess_function(examples):
    """Turn a batch of dataset rows into model inputs.

    Args:
        examples: batch mapping with an "image" entry (iterable of objects
            exposing ``.convert`` -- presumably PIL images) and a "label"
            entry holding the matching class ids.

    Returns:
        The output of the module-level ``processor`` (called with
        ``return_tensors="pt"``) with the labels stored under "labels".
    """
    rgb_images = []
    for picture in examples["image"]:
        # Force three channels; the conversion is applied to every image.
        rgb_images.append(picture.convert("RGB"))

    batch = processor(rgb_images, return_tensors="pt")
    batch["labels"] = examples["label"]
    return batch

# Preprocess the whole split again, this time with the processor currently
# bound to `processor`, in batches of 32 rows; the original columns are dropped
# so that only the fields returned by `preprocess_function` remain.
tokenized_dataset = dataset.map(
    function=preprocess_function,
    remove_columns=dataset.column_names,
    batched=True,
    batch_size=32,
)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [58]:
# Have the dataset return torch tensors, then iterate it in its stored order.
tokenized_dataset.set_format(type="torch")
test_dataloader = DataLoader(
    dataset=tokenized_dataset,
    shuffle=False,
    batch_size=64,
)

In [59]:
# Per-sample predicted class indices and dataset labels, in dataloader order.
all_preds2 = []
all_labels2 = []

# Pure inference: no gradients are needed anywhere in the loop.
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        batch_logits = model(batch['pixel_values'].to(device)).logits

        # The highest-scoring class along the last axis is the prediction.
        all_preds2.extend(batch_logits.argmax(dim=-1).cpu().numpy())
        # Labels were never moved to the device, so they convert directly.
        all_labels2.extend(batch['labels'].numpy())

  0%|          | 0/157 [00:00<?, ?it/s]

In [75]:
# model.config.id2label has every class shifted up by 1 relative to the
# ImageNet class index (presumably an extra class at index 0 -- confirm
# against model.config.id2label), so each mapped index is offset accordingly.
label_offset = 1

# Guard the hard-coded shift: if `model` is not the checkpoint this cell was
# written for (e.g. cells were run out of order), fail loudly here instead of
# silently producing wrong metrics.
assert model.config.num_labels == len(wnid_to_idx) + label_offset, (
    f"expected {len(wnid_to_idx) + label_offset} output classes for a label "
    f"shift of {label_offset}, but the loaded model has {model.config.num_labels}"
)

# Entry i is the model class index for Tiny ImageNet label i, or -1 when the
# wnid has no entry in the ImageNet class index.
tiny_to_vit_idx = []
for wnid in tiny_imagenet_wnids:
    if wnid in wnid_to_idx:
        tiny_to_vit_idx.append(wnid_to_idx[wnid] + label_offset)
    else:
        print(wnid, 'not in map')
        tiny_to_vit_idx.append(-1)

# NOTE: this rebinds `mapping`, replacing the unshifted mapping built earlier.
mapping = torch.tensor(tiny_to_vit_idx)
mapping

n02666347 not in map
n03373237 not in map
n04465666 not in map
n04598010 not in map
n07056680 not in map
n07646821 not in map
n07647870 not in map
n07657664 not in map
n07975909 not in map
n08496334 not in map
n08620881 not in map
n08742578 not in map
n12520864 not in map
n13001041 not in map
n13652335 not in map
n13652994 not in map
n13719102 not in map
n14991210 not in map


tensor([  2,  26,  31,  33,  51,  62,  70,  72,  76,  77,  80, 106, 108, 110,
        114, 116, 123, 124, 129, 146, 147, 150, 188, 208, 209, 236, 268, 282,
        284, 286, 287, 292, 295, 302, 312, 314, 315, 316, 320, 324, 326, 330,
        339, 342, 346, 348, 350, 354, 355, 366, 368, 373, 387, 388,  -1, 401,
        407, 415, 422, 425, 426, 428, 431, 436, 437, 438, 439, 441, 446, 448,
        449, 458, 459, 464, 467, 468, 471, 472, 475, 481, 486, 493, 497, 501,
        509, 510, 512, 518, 526, 527, 533, 544, 558,  -1, 563, 566, 568, 569,
        571, 574, 577, 605, 606, 613, 615, 620, 622, 626, 628, 636, 646, 653,
        656, 676, 679, 683, 684, 688, 705, 708, 717, 721, 732, 735, 736, 738,
        740, 745, 748, 761, 762, 766, 769, 775, 780, 782, 787, 802, 807, 809,
        812, 816, 818, 822, 827, 838, 840, 843, 846, 850, 851, 854, 863,  -1,
        874, 875, 878, 880, 888, 889, 891, 900, 901, 910,  -1, 918,  -1, 925,
        929, 930,  -1,  -1,  -1, 933, 936, 939, 946, 952, 955, 9

In [76]:
# Translate Tiny ImageNet label ids into the model's class indices via `mapping`.
y_true_tiny = torch.tensor(all_labels2)
y_true_full = mapping[y_true_tiny].numpy()

# Samples whose class has no mapped index carry -1; score only the others.
is_mapped = np.not_equal(y_true_full, -1)
y_true_valid = y_true_full[is_mapped]
y_pred_valid = np.asarray(all_preds2)[is_mapped]

accuracy = accuracy_score(y_true=y_true_valid, y_pred=y_pred_valid)
print(f'accuracy: {accuracy}')

f1 = f1_score(y_true=y_true_valid, y_pred=y_pred_valid, average='weighted')
print(f"weighted f1: {f1}")

accuracy: 0.19648351648351647
weighted f1: 0.2754208517960667
