In [None]:
# Install the development build of transformers straight from the GitHub main
# branch. This is not version-pinned, so the installed code can change between runs.
pip install -q git+https://github.com/huggingface/transformers.git

In [1]:
from transformers import ViTForImageClassification
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTFeatureExtractor, ViTImageProcessor

from pathlib import Path as path
from PIL import Image
import requests
import csv

In [2]:
# Prefer the first GPU, but fall back to CPU so the notebook also runs without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Classifier under evaluation. ImageNetDataset uses the image processor of this
# same checkpoint by default. nn.Module.to()/.eval() return the module itself.
# Binding the result, rather than ending the cell on a bare expression, keeps
# the cell from echoing the full model repr.
model = (
    ViTForImageClassification.from_pretrained('google/vit-base-patch16-384')
    .to(device)
    .eval()
)




In [3]:
class ImageNetDataset(Dataset):
    """Map-style ImageNet dataset yielding ViT-preprocessed images.

    Items are ``(pixel_values, label)``:

    - ``pixel_values`` is the image processor's output for one image, with the
      batch axis removed.
    - ``label`` is the integer class index, taken from the line order of
      ``LOC_synset_mapping.txt`` (see ``from_path``).

    Args:
        folder_path: Dataset root (``str`` or ``Path``). Image paths stored in
            ``sample_tuple`` are resolved relative to it.
        sample_tuple: Sequence of ``(relative_image_path, label_index)`` pairs.
        model_name: Hugging Face checkpoint whose ``ViTImageProcessor`` does the
            resizing/normalisation. Keep it in sync with the classifier being
            evaluated.
    """

    def __init__(self, folder_path, sample_tuple, model_name='google/vit-base-patch16-384'):
        self.folder_path = folder_path
        self.sample_tuple = sample_tuple
        self.feature_extractor = ViTImageProcessor.from_pretrained(model_name)

    def __len__(self):
        return len(self.sample_tuple)

    def __getitem__(self, idx):
        img_path, label = self.sample_tuple[idx]

        # path(...) accepts both str and Path roots. The `with` closes the file
        # handle; convert() returns a new image, so it stays valid afterwards.
        with Image.open(path(self.folder_path) / img_path) as raw:
            img = raw.convert('RGB')
        encoding = self.feature_extractor(images=img, return_tensors="pt")

        # A single image is passed in, so take element 0 of the returned batch.
        # DataLoader re-adds the batch axis when collating.
        return encoding['pixel_values'][0], label

    @staticmethod
    def _read_synset_index(folder_path):
        """Map each synset id in ``LOC_synset_mapping.txt`` to its line index (blank lines skipped)."""
        with open(path(folder_path) / 'LOC_synset_mapping.txt', encoding='utf-8') as f:
            synsets = [line.split(maxsplit=1)[0] for line in f if line.strip()]
        return {synset: idx for idx, synset in enumerate(synsets)}

    @staticmethod
    def _image_path(sample_id):
        """Return the image path for a CSV ``ImageId``, relative to the dataset root.

        ``ILSVRC2012_<split>_<n>`` ids live directly in ``CLS-LOC/<split>/``.
        Any other id is treated as a training id ``<synset>_<n>``, which lives
        in ``CLS-LOC/train/<synset>/``.
        """
        root = path('ILSVRC', 'Data', 'CLS-LOC')
        parts = sample_id.split('_')
        if parts[0] == 'ILSVRC2012':
            return root / parts[1] / f'{sample_id}.JPEG'
        return root / 'train' / parts[0] / f'{sample_id}.JPEG'

    @classmethod
    def from_path(cls, folder_path, split, **kwargs):
        """Build the dataset for ``split`` from an ILSVRC localisation directory.

        Reads ``LOC_synset_mapping.txt`` for class indices and
        ``LOC_<split>_solution.csv`` for samples. The label of each sample is
        the first token of its second CSV column.

        Args:
            folder_path: Dataset root (``str`` or ``Path``).
            split: Solution-file split name, e.g. ``'val'`` or ``'train'``.
            **kwargs: Forwarded to the constructor (e.g. ``model_name``).

        Raises:
            FileNotFoundError: If either metadata file is missing.
            KeyError: If a CSV label is not listed in the synset mapping.
        """
        folder_path = path(folder_path)
        c2i = cls._read_synset_index(folder_path)

        # newline='' is how the csv module expects its input files to be opened.
        with open(folder_path / f'LOC_{split}_solution.csv', newline='', encoding='utf-8') as csv_file:
            rows = csv.reader(csv_file)
            next(rows)  # skip header row
            sample_tuple = [
                (cls._image_path(image_id), c2i[prediction.split(maxsplit=1)[0]])
                for image_id, prediction, *_ in rows
            ]
        return cls(folder_path, sample_tuple, **kwargs)
        

# Validation split of the ILSVRC copy stored under ~/datasets/ImageNet.
val_ds = ImageNetDataset.from_path(
    folder_path=path.home() / 'datasets' / 'ImageNet',
    split='val',
)
# shuffle=False keeps a fixed sample order for evaluation.
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False)

In [4]:
# Top-1 accuracy of `model` over every batch of the validation loader.
count, total = 0, 0
n_batches = len(val_loader)
with torch.no_grad():  # inference only: build no autograd graph for any batch
    for i, (images, labels) in enumerate(val_loader):
        # 1-based "done/total" progress, overwritten in place via '\r'.
        print(f'\r{i + 1}/{n_batches}', end='')
        outputs = model(images.to(device))
        pred = outputs.logits.argmax(-1)
        acc = labels.eq(pred.cpu())  # labels were never moved off the CPU
        count += acc.sum().item()
        total += labels.shape[0]
print()  # finish the progress line so the result is not glued onto it

assert total > 0, 'val_loader produced no samples; check the dataset path and CSV'
print(count / total)


3124\31250.8391


In [5]:
# Top-1 accuracy from the evaluation loop: correct predictions / samples seen.
top1 = count / total
print(top1)

0.8391
