- **Date:** 2019-5-29
- **Author:** Zhanyuan Zhang
- **Purpose:** Reproduce the accuracy in the paper
- **References:** 
    - Create ImageFolder instance for ImageNet validation data set.: https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet
    - Create data loader: https://github.com/pytorch/examples/blob/97304e232807082c2e7b54c597615dc0ad8f6173/imagenet/main.py#L218



In [0]:
%load_ext autoreload
%autoreload 2

In [4]:
from bagnets.utils import plot_heatmap, generate_heatmap_pytorch
from bagnets.utils import pad_image, convert2channel_last, imagenet_preprocess, extract_patches, bagnet_predict, compare_heatmap
from bagnets.utils import bagnet33_debug, plot_saliency, compute_saliency_map
from foolbox.utils import samples
import bagnets.pytorch
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import time
import os
import cv2
img_path = "./ILSVRC2012_img_val"
root = "./"
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
if use_cuda:
    print(torch.cuda.get_device_name(0))

Tesla T4


In [0]:
# load pretrained model
bagnet33 = bagnets.pytorch.bagnet33(pretrained=True).to(device)
bagnet17 = bagnets.pytorch.bagnet17(pretrained=True).to(device)

In [0]:
def get_topk_acc(y_hat, y):
    """ Compute top-k accuracy
    Input:
    - y_hat: numpy array with shape (batchsize, K). top-k prediction classes
    - y: numpy array with shape(batchsize, ). target classes
    Return: top-k accuracy
    """
    is_correct = [y[i] in y_hat[i] for i in range(y.size)]
    is_correct = np.array(is_correct)
    return is_correct.sum()/y.size

In [0]:
bs = 20
original, label = samples(dataset='imagenet', index=1, batchsize=bs, shape=(224, 224), data_format='channels_first')

# preprocess sample image
images = imagenet_preprocess(original)
images = torch.from_numpy(images).to(device)

In [8]:
label

array([559, 438, 990, 949, 853, 609, 609, 915, 455, 541, 630, 741, 471,
       129,  99, 251,  22, 317, 305, 243])

In [9]:
y_hat = bagnet_predict(bagnet33, images, k=5)
acc = get_topk_acc(y_hat, label)
print('top-5 accuracy on 20 ImageNet samples from foolbox: {}'.format(acc))

top-5 accuracy on 20 ImageNet samples from foolbox: 1.0


In [10]:
y_hat

array([[559, 694, 532, 554, 540],
       [438, 441, 680, 647, 720],
       [990, 988, 924, 937, 997],
       [949, 927, 928, 644, 957],
       [853, 912, 425, 716, 825],
       [609, 654, 656, 436, 757],
       [751, 609, 817, 511, 479],
       [915, 672, 410, 879, 660],
       [455, 539, 911, 658, 735],
       [541, 423, 576, 424, 653],
       [636, 630, 748, 893, 709],
       [741, 406, 497, 668, 884],
       [471, 703, 506, 535, 835],
       [129, 132, 733,  23, 127],
       [ 99, 100,   9,  23,   8],
       [251, 246, 290, 275, 210],
       [ 22,  80, 128,  21, 146],
       [317, 316, 113, 311, 307],
       [302, 305, 304, 300, 306],
       [243, 242, 163, 159, 246]])

In [11]:
label

array([559, 438, 990, 949, 853, 609, 609, 915, 455, 541, 630, 741, 471,
       129,  99, 251,  22, 317, 305, 243])

In [0]:
def cvload(path):
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
imagenet_transform = transforms.Compose([transforms.Resize(256), 
                                          transforms.CenterCrop(224), 
                                          transforms.ToTensor(), 
                                          normalize])
imagenet_val = datasets.ImageNet('./', split='val', download=False, 
                                 transform=imagenet_transform)
val_loader = torch.utils.data.DataLoader(
    imagenet_val,
    batch_size=128)

In [0]:
def validate(val_loader, model, acc_fn, device):
    # switch to evaluate mode
    model.eval()
    total_iter = len(val_loader)
    cum_acc = 0
    with torch.no_grad():
        start = time.time()
        for i, (images, target) in enumerate(val_loader):
            images, target = images.to(device), target.to(device)
            tic = time.time()
            logits = model(images)
            tac = time.time()
            # measure accuracy
            p = torch.nn.Softmax(dim=1)(logits)
            _, y_hat = torch.topk(p, k=5, dim=1)
            acc = acc_fn(y_hat.cpu().numpy(), target.cpu().numpy())
            cum_acc += acc

            print('Iteration {}, validation accuracy: {:.3f}, time: {}s'.format(i, acc, tac-tic))
    end = time.time()
    val_acc = cum_acc / total_iter
    print('Validation accuracy: {:.3f}, time: {}s'.format(val_acc, end-start))
    return val_acc

In [14]:
val_acc_33 = validate(val_loader, bagnet33, get_topk_acc, device)

Iteration 0, validation accuracy: 0.930, time: 0.016033411026000977s
Iteration 1, validation accuracy: 0.930, time: 0.007769346237182617s
Iteration 2, validation accuracy: 0.945, time: 0.0077972412109375s
Iteration 3, validation accuracy: 0.953, time: 0.007807016372680664s
Iteration 4, validation accuracy: 0.906, time: 0.015614509582519531s
Iteration 5, validation accuracy: 0.922, time: 0.00794363021850586s
Iteration 6, validation accuracy: 0.938, time: 0.007865667343139648s
Iteration 7, validation accuracy: 0.898, time: 0.008091926574707031s
Iteration 8, validation accuracy: 0.867, time: 0.007904052734375s
Iteration 9, validation accuracy: 0.812, time: 0.008068561553955078s
Iteration 10, validation accuracy: 0.828, time: 0.007848978042602539s
Iteration 11, validation accuracy: 0.828, time: 0.008934497833251953s
Iteration 12, validation accuracy: 0.852, time: 0.009186744689941406s
Iteration 13, validation accuracy: 0.805, time: 0.009407520294189453s
Iteration 14, validation accuracy: 0

In [15]:
val_acc_17 = validate(val_loader, bagnet17, get_topk_acc, device)

Iteration 0, validation accuracy: 0.906, time: 0.008053064346313477s
Iteration 1, validation accuracy: 0.898, time: 0.007922172546386719s
Iteration 2, validation accuracy: 0.914, time: 0.007990598678588867s
Iteration 3, validation accuracy: 0.859, time: 0.007681608200073242s
Iteration 4, validation accuracy: 0.852, time: 0.008299112319946289s
Iteration 5, validation accuracy: 0.922, time: 0.008082151412963867s
Iteration 6, validation accuracy: 0.883, time: 0.007938146591186523s
Iteration 7, validation accuracy: 0.836, time: 0.007857561111450195s
Iteration 8, validation accuracy: 0.758, time: 0.007893800735473633s
Iteration 9, validation accuracy: 0.789, time: 0.007750511169433594s
Iteration 10, validation accuracy: 0.742, time: 0.008121013641357422s
Iteration 11, validation accuracy: 0.766, time: 0.008245229721069336s
Iteration 12, validation accuracy: 0.750, time: 0.008126974105834961s
Iteration 13, validation accuracy: 0.727, time: 0.00885152816772461s
Iteration 14, validation accura