In [1]:
import torch
from torch.autograd import Variable
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import transforms
import ipdb
import sys, glob
import numpy as np
from PIL import Image
import time
import os
from robustness_lib.robustness import model_utils, datasets
from user_constants import DATA_PATH_DICT

import warnings
warnings.filterwarnings("ignore")

print('Imports done')

Imports done


In [2]:
# Constants
DATA = 'ImageNet'  # Choices: ['CIFAR', 'ImageNet', 'RestrictedImageNet']
# Fail fast on a typo here rather than deep inside the dataset/model loaders.
assert DATA in ('CIFAR', 'ImageNet', 'RestrictedImageNet'), f'Unknown DATA: {DATA!r}'

# Base batch size; the data-loading cell passes BATCH_SIZE // 2 to the loader.
BATCH_SIZE = 1000
NUM_WORKERS = 8  # passed as `workers` to dataset.make_loaders
# NOTE(review): NOISE_SCALE is not referenced by any cell shown here — confirm before removing.
NOISE_SCALE = 20

DATA_SHAPE = 32 if DATA == 'CIFAR' else 224  # Image size (fixed for dataset)
REPRESENTATION_SIZE = 2048  # Size of representation vector (fixed for model)

In [3]:
# Load dataset: look up the dataset class named by DATA in robustness_lib's
# `datasets` module, point it at the configured data path, and build only the
# validation loader (no training augmentation).
# NOTE(review): the loader receives BATCH_SIZE // 2 as its batch size — confirm
# the halving is intentional.
dataset_function = getattr(datasets, DATA)
dataset = dataset_function(DATA_PATH_DICT[DATA])

loader_kwargs = dict(
    workers=NUM_WORKERS,
    batch_size=BATCH_SIZE // 2,
    data_aug=False,
    only_val=True,
)
_, test_loader = dataset.make_loaders(**loader_kwargs)

# One-shot iterator: it is exhausted after a single full pass.
data_iterator = enumerate(test_loader)
print('Data iterator created')

==> Preparing dataset imagenet..
Data iterator created


In [4]:
# Load model: build a ResNet-50 for the chosen dataset and restore its weights
# from ./models/<DATA>.pt, reading the 'model' entry as the state-dict path.
model_kwargs = dict(
    arch='resnet50',
    dataset=dataset,
    resume_path=f'./models/{DATA}.pt',
    state_dict_path='model',
)
model, _ = model_utils.make_and_restore_model(**model_kwargs)

# Inference only: switch to eval mode and freeze every weight.
model.eval()
for weight in model.parameters():
    weight.requires_grad_(False)

print('Model created')

=> loading checkpoint './models/ImageNet.pt'
=> loaded checkpoint './models/ImageNet.pt' (epoch 105)
Model created


In [5]:
# Evaluate top-1 accuracy over the whole validation loader.
# Iterate `test_loader` directly instead of the shared `data_iterator`: an
# `enumerate(...)` object is exhausted after one pass, so re-running this cell
# on it would silently evaluate zero images.
PRINT_EVERY = 20  # batches between progress reports

correct = 0
im_count = 0
with torch.no_grad():  # pure inference; never build an autograd graph
    for idx, (im, targ, path) in enumerate(test_loader):
        im_count += im.shape[0]
        targ = targ.cpu().numpy()
        log, _ = model(im)

        # Softmax is monotonic, so the argmax over raw logits gives the same
        # predicted class without the extra computation.
        labels = torch.argmax(log, dim=1).cpu().numpy()
        correct += int(np.sum(labels == targ))

        if idx % PRINT_EVERY == 0:
            print(f'Batch is: {idx}')
            print(f'Correct images are {correct}/{im_count}')

print(f'Final results - Correct images are {correct}/{im_count}')


Batch is: 0
Correct images are 14/500
Batch is: 20
Correct images are 335/10500
Batch is: 40
Correct images are 650/20500
Batch is: 60
Correct images are 981/30500
Batch is: 80
Correct images are 1240/40500
Final results - Correct images are 1555/50000


In [6]:
# Report top-1 accuracy as a percentage. Guard the division: im_count is 0 when
# the evaluation loop saw no batches (e.g. it ran on an exhausted iterator).
# NOTE(review): the recorded run reports ~3% top-1, which looks far too low for
# a trained ResNet-50 — check checkpoint / dataset / label alignment.
if im_count == 0:
    print('No images were evaluated; re-run the evaluation cell.')
else:
    print(f'Percentage of correct images are {correct*100/im_count}%')

Percentage of correct images are 3.11%
