- **Date:** 2019-05-30
- **Author:** Zhanyuan Zhang
- **Purpose:** Reproduce the ImageNet validation accuracy of BagNet-33 and BagNet-17 reported in the paper
- **References:**
    - Create an `ImageNet` dataset instance for the ImageNet validation set: https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet
    - Create the data loader: https://github.com/pytorch/examples/blob/97304e232807082c2e7b54c597615dc0ad8f6173/imagenet/main.py#L218



In [0]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from google.colab import drive
drive.mount('/content/gdrive')
os.chdir('/content/gdrive/My Drive/dl-security/')  # Change this path to the directory that contains all code and data

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
#!pip install https://github.com/bethgelab/foolbox/archive/master.zip

In [4]:
from bagnets.utils import plot_heatmap, generate_heatmap_pytorch
from bagnets.utils import pad_image, convert2channel_last, imagenet_preprocess, extract_patches, bagnet_predict, compare_heatmap
from bagnets.utils import bagnet33_debug, plot_saliency, compute_saliency_map
from bagnets.utils import get_topk_acc, validate
from foolbox.utils import samples
import bagnets.pytorch
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import time
import os
img_path = "./ILSVRC2012_img_val"
root = "./"
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
if use_cuda:
    print(torch.cuda.get_device_name(0))

Tesla T4


## 1. Prepare the data loaders

In [0]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
imagenet_transform = transforms.Compose([transforms.Resize(256), 
                                          transforms.CenterCrop(224), 
                                          transforms.ToTensor(), 
                                          normalize])
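# Load the validation split from `root`; the ILSVRC2012 devkit and validation
# images must already be available locally (newer torchvision releases removed
# the `download` argument, so it is not passed here)
imagenet_val = datasets.ImageNet(root, split='val',
                                 transform=imagenet_transform)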
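
`Resize(256)` scales the shorter image side to 256 pixels while keeping the aspect ratio, `CenterCrop(224)` cuts out the central 224×224 window, `ToTensor()` scales pixel values to [0, 1], and `Normalize` then subtracts the per-channel mean and divides by the per-channel standard deviation. The stdlib-only sketch below traces that arithmetic for a single 500×375 image. `resize_then_center_crop` and `normalize_pixel` are hypothetical helpers, and the rounding (truncating the longer side, rounding the crop offset) is assumed to follow torchvision's implementation.

```python
# Hypothetical helpers that mimic the geometry and arithmetic of
# Resize(256) -> CenterCrop(224) -> Normalize(mean, std).
# Assumption: torchvision truncates the longer side with int() and
# rounds the crop offsets with Python's round().

def resize_then_center_crop(width, height, resize=256, crop=224):
    """Return ((resized_w, resized_h), (left, top, right, bottom))."""
    # Resize(int) scales the *shorter* side to `resize`, keeping the aspect ratio
    if width <= height:
        new_w, new_h = resize, int(resize * height / width)
    else:
        new_w, new_h = int(resize * width / height), resize
    # CenterCrop takes a crop x crop window centred in the resized image
    left = int(round((new_w - crop) / 2.0))
    top = int(round((new_h - crop) / 2.0))
    return (new_w, new_h), (left, top, left + crop, top + crop)


def normalize_pixel(rgb, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Per-channel (x - mean) / std on values already scaled to [0, 1] by ToTensor()."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


size, box = resize_then_center_crop(500, 375)
print(size, box)                                # (341, 256) (58, 16, 282, 240)
print(normalize_pixel((0.485, 0.456, 0.406)))   # (0.0, 0.0, 0.0)
```

A pixel equal to the ImageNet channel means maps to zero in every channel, which is the point of the normalization.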

In [6]:
val_loader_test = torch.utils.data.DataLoader(imagenet_val, batch_size=1)
print('number of validation images: {}'.format(len(imagenet_val)))  # len() of the dataset counts images

number of validation images: 50000


In [0]:
val_loader = torch.utils.data.DataLoader(
    imagenet_val,
    batch_size=128)
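
`len()` of a `DataLoader` counts batches, not images; with `batch_size=1` the two coincide, so `val_loader_test` has 50000 batches. With `batch_size=128` and the default `drop_last=False`, the 50,000 validation images split into 390 full batches plus one final batch of 80 images. A minimal stdlib sketch of that arithmetic follows; `loader_batches` is a hypothetical helper, not a PyTorch function, and it assumes the default sequential batching with no custom sampler.

```python
# Hypothetical helper: batch count of a DataLoader over a map-style dataset
# using the default sequential batching (no custom sampler).

def loader_batches(num_samples, batch_size, drop_last=False):
    """Return (number_of_batches, size_of_last_batch)."""
    full, remainder = divmod(num_samples, batch_size)
    if drop_last or remainder == 0:
        # Only full batches are produced
        return full, (batch_size if full else 0)
    # drop_last=False (the default) keeps the final, smaller batch
    return full + 1, remainder


print(loader_batches(50000, 1))    # (50000, 1)
print(loader_batches(50000, 128))  # (391, 80)
```

The smaller last batch matters when per-batch accuracies are later combined into one number.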

## 2. Check BagNet-33's validation accuracy on ImageNet

In [8]:
bagnet33 = bagnets.pytorch.bagnet33(pretrained=True).to(device)
bagnet33.eval()
print('Start validating Bagnet-33...')

Start validating Bagnet-33...
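
BagNets classify each local image patch (33×33 pixels for BagNet-33, 17×17 for BagNet-17) and average the resulting class evidence over all patches. The stdlib sketch below uses made-up toy numbers to illustrate why a spatial average pool followed by a linear classifier gives the same logits as averaging per-patch logits: a linear layer commutes with the mean. That the `bagnets.pytorch` models end in such an "average pool then linear layer" head is an assumption here; `linear` and `mean_vectors` are hypothetical helpers.

```python
import math

# Toy numbers (made up): 4 patches, 3 feature channels, 2 classes.
# Assumption: the model head is "average pool over patches -> linear layer".

def linear(features, weights, bias):
    """Dense layer: one logit per row of `weights`."""
    return [sum(w * f for w, f in zip(row, features)) + b
            for row, b in zip(weights, bias)]


def mean_vectors(vectors):
    """Element-wise mean of equally sized vectors."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]


patch_features = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0], [3.0, 1.0, 0.0], [0.0, 2.0, 1.0]]
weights = [[0.5, -1.0, 0.25], [-0.5, 1.0, 0.5]]
bias = [0.1, -0.1]

# (a) average the patch features, then classify once (avgpool -> fc)
pooled_logits = linear(mean_vectors(patch_features), weights, bias)
# (b) classify every patch, then average the per-patch class evidence
averaged_logits = mean_vectors([linear(f, weights, bias) for f in patch_features])

print(all(math.isclose(a, b, abs_tol=1e-12)
          for a, b in zip(pooled_logits, averaged_logits)))  # True
print(max(range(2), key=lambda c: averaged_logits[c]))        # 1
```

Both routes give logits of roughly [-0.15, 0.9], so the predicted class is the same whichever order the model uses.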


In [9]:
val_acc_33 = validate(val_loader, bagnet33, device)

Iteration 0, validation accuracy: 0.945, time: 1.6943097114562988s
Iteration 1, validation accuracy: 0.930, time: 0.009383440017700195s
Iteration 2, validation accuracy: 0.906, time: 0.012142658233642578s
Iteration 3, validation accuracy: 0.953, time: 0.008998394012451172s
Iteration 4, validation accuracy: 0.969, time: 0.009142875671386719s
Iteration 5, validation accuracy: 0.984, time: 0.009967803955078125s
Iteration 6, validation accuracy: 0.953, time: 0.009306669235229492s
Iteration 7, validation accuracy: 0.953, time: 0.009193658828735352s
Iteration 8, validation accuracy: 0.930, time: 0.009288787841796875s
Iteration 9, validation accuracy: 0.969, time: 0.011897087097167969s
Iteration 10, validation accuracy: 0.930, time: 0.009054422378540039s
Iteration 11, validation accuracy: 0.922, time: 0.00908660888671875s
Iteration 12, validation accuracy: 0.914, time: 0.008973360061645508s
Iteration 13, validation accuracy: 0.844, time: 0.0094451904296875s
Iteration 14, validation accuracy: 
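
`validate` logs one accuracy value per batch. To summarize such a log, the dataset-level accuracy should be the total number of correct predictions divided by the total number of images; a plain mean of per-batch accuracies gives the final, smaller batch (80 images when 50,000 images are split into batches of 128) the same weight as a full one. The sketch below is hypothetical and is not the implementation of `bagnets.utils.validate`, whose internals are not shown here. It also includes a simple top-k correctness check on raw class scores; all numbers are made up.

```python
def topk_correct(scores, label, k=1):
    """Return True if `label` is among the k highest-scoring class indices."""
    ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
    return label in ranked[:k]


def dataset_accuracy(batch_results):
    """Pool (num_correct, batch_size) pairs into one accuracy over all images."""
    batch_results = list(batch_results)  # generators can be consumed only once
    correct = sum(c for c, _ in batch_results)
    total = sum(n for _, n in batch_results)
    return correct / total if total else 0.0


def mean_of_batch_accuracies(batch_results):
    """Naive summary: unweighted mean of per-batch accuracies."""
    batch_results = list(batch_results)
    if not batch_results:
        return 0.0
    return sum(c / n for c, n in batch_results) / len(batch_results)


# Hypothetical log: one full batch of 128 and a final partial batch of 80 images
results = [(120, 128), (40, 80)]
print(dataset_accuracy(results))          # 160 / 208, about 0.769
print(mean_of_batch_accuracies(results))  # 0.71875: the small batch counts as much as the full one
print(topk_correct([0.1, 0.7, 0.2], label=2, k=1),
      topk_correct([0.1, 0.7, 0.2], label=2, k=2))  # False True
```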

## 3. Check BagNet-17's validation accuracy on ImageNet

In [10]:
bagnet17 = bagnets.pytorch.bagnet17(pretrained=True).to(device)
bagnet17.eval()
print('Start validating Bagnet-17...')

Start validating Bagnet-17...


In [11]:
val_acc_17 = validate(val_loader, bagnet17, device)

Iteration 0, validation accuracy: 0.914, time: 0.011829137802124023s
Iteration 1, validation accuracy: 0.867, time: 0.00889134407043457s
Iteration 2, validation accuracy: 0.859, time: 0.008825302124023438s
Iteration 3, validation accuracy: 0.938, time: 0.009114503860473633s
Iteration 4, validation accuracy: 0.953, time: 0.008967161178588867s
Iteration 5, validation accuracy: 0.961, time: 0.009353399276733398s
Iteration 6, validation accuracy: 0.922, time: 0.00931096076965332s
Iteration 7, validation accuracy: 0.906, time: 0.009261608123779297s
Iteration 8, validation accuracy: 0.867, time: 0.013370990753173828s
Iteration 9, validation accuracy: 0.938, time: 0.009995698928833008s
Iteration 10, validation accuracy: 0.875, time: 0.008860588073730469s
Iteration 11, validation accuracy: 0.898, time: 0.009206056594848633s
Iteration 12, validation accuracy: 0.828, time: 0.008755683898925781s
Iteration 13, validation accuracy: 0.750, time: 0.008812904357910156s
Iteration 14, validation accurac