In [1]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vgg16, resnet152, alexnet
import torch

import os
import json
from PIL import Image
import numpy as np

In [2]:
args = {"batch_size": 16}
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cpu


In [3]:
class ImageNetDataset(Dataset):
    def __init__(self, image_dir, annotations_file, transformations = None, device = 'cpu'):
        self.image_dir = image_dir
        self.images = os.listdir(self.image_dir)
        self.annotations = json.load(open(annotations_file))
        self.transformations = transformations
        self.device = device

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        img = Image.open(img_path).convert('RGB')
        label = int(self.annotations[self.images[idx].split('.')[0]])
        if self.transformations:
            img = self.transformations(img)
        img = img.to(self.device)
        return img, label


preprocessing = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
])
imagenetValDataset = ImageNetDataset(image_dir = './data/imagenet/ILSVRC/Data/CLS-LOC/val/', 
                                    annotations_file='./data/imagenet/ILSVRC/imagnet_classes_val.json', 
                                    transformations=preprocessing, 
                                    device = device)

imagenetValDataloader = torch.utils.data.DataLoader(imagenetValDataset, batch_size=args['batch_size'], shuffle=True)


In [4]:
model = resnet152(pretrained = True)

In [6]:
from tqdm import tqdm
accuracy = 0
count = 0
for batch_idx, (img, label) in enumerate(tqdm(imagenetValDataloader)):
    out = model(img).detach()
    count += out.shape[0]
    accuracy += np.sum((label == torch.argmax(out, dim = 1)).numpy()*1)
    if(batch_idx % 50 == 0):
        print("Evaluated {}/{}: Accuracy: {:.3f}".format(str(count), 
                                                         str(len(imagenetValDataset)), float(accuracy)/count))

  0%|▍                                                                                                                    | 11/3125 [00:46<3:38:02,  4.20s/it]

Evaluated 176/50000: Accuracy: 0.739


  1%|█▍                                                                                                                   | 39/3125 [02:41<3:33:36,  4.15s/it]


KeyboardInterrupt: 