- **Date:** 2019-5-31
- **Author:** Zhanyuan Zhang
- **Purpose:**
    1. Make predictions by averaging BagNet-33's patch-level logits, and check that they match the predictions of the average-pooled model.
    2. Generate a low-resolution BagNet-33 heatmap.

In [0]:
# Reload edited modules (e.g. bagnets.utils) automatically before each cell runs,
# so changes to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [0]:
import os
from google.colab import drive
# Mount Google Drive, then switch into the project folder so the relative paths
# used later in the notebook resolve against it.
# Change the path below to the directory that contains all code and data.
drive.mount('/content/gdrive')
os.chdir(os.path.join('/content/gdrive', 'My Drive', 'dl-security', ''))

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
#!pip install https://github.com/bethgelab/foolbox/archive/master.zip

In [0]:
from bagnets.utils import plot_heatmap, generate_heatmap_pytorch
from bagnets.utils import pad_image, convert2channel_last, imagenet_preprocess, extract_patches, bagnet_predict, compare_heatmap
from bagnets.utils import bagnet33_debug, plot_saliency, compute_saliency_map
from bagnets.utils import get_topk_acc, validate
from bagnets.utils import bagnet33_debug
from foolbox.utils import samples
import bagnets.pytorch
from bagnets.pytorch import Bottleneck, BagNet
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import time
import os
import cv2
# Data locations, relative to the current working directory.
img_path = "./ILSVRC2012_img_val"
root = "./"

# Use the first GPU when CUDA is available; fall back to the CPU otherwise.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
if use_cuda:
    print(torch.cuda.get_device_name(0))

Tesla T4


In [0]:
# BagNet-33 without average pooling: outputs logits for every image patch.
# nn.Module.to / .eval return the module itself, so the calls can be chained.
bagnet33_patch = bagnets.pytorch.bagnet33(pretrained=True, avg_pool=False).to(device).eval()

# BagNet-33 with average pooling: outputs one class-logit vector per image.
bagnet33_avg = bagnets.pytorch.bagnet33(pretrained=True, avg_pool=True).to(device).eval()

# A fixed batch of ImageNet samples (foolbox), preprocessed and moved to `device`.
bs = 20
original, labels = samples(dataset='imagenet', index=1, batchsize=bs,
                           shape=(224, 224), data_format='channels_first')
images = torch.from_numpy(imagenet_preprocess(original)).to(device)
targets = torch.from_numpy(labels).to(device)

## 1. Prediction by patches

In [0]:
def bagnet_patch_predict(bagnet, images, k=1, return_class=True):
    """bagnet makes top-k predictions by averaging its patch-level logits
    Input:
    - bagnet (pytorch model): bagnet without average pooling; its output is treated
      as (N, H, W, num_classes), i.e. class logits on the last axis for every patch
      position (H, W)
    - images (pytorch tensor): a batch of images
    - k (int): top-k prediction
    - return_class (bool): if true, return class index; otherwise return class evidence
    Output: (numpy array) of shape (N, k): top-k class indices, or their averaged
    logits when return_class is False
    """
    with torch.no_grad():
        patch_logits = bagnet(images)
    # Average over the two spatial (patch-grid) axes, leaving one (N, num_classes)
    # score per image. The number of classes is taken from the tensor itself rather
    # than hard-coded to 1000, so any classifier head works.
    avg_logits = patch_logits.mean(dim=(1, 2))
    values, indices = torch.topk(avg_logits, k, dim=1)
    return (indices if return_class else values).cpu().numpy()

In [0]:
# Top-5 class indices per image, obtained by averaging BagNet-33's patch logits.
y_hat = bagnet_patch_predict(bagnet33_patch, images, k=5)
y_hat

array([[559, 765, 612, 532, 428],
       [438, 441, 907, 711, 737],
       [955,  47,  46,  39, 990],
       [949, 989, 927,   7, 644],
       [853, 425, 663, 915, 497],
       [609, 656, 654, 757, 864],
       [817, 751, 511, 479, 717],
       [915, 672, 879, 660, 853],
       [455, 560, 555, 856, 584],
       [541, 542, 423, 820, 822],
       [748, 630, 636, 893, 464],
       [668, 741, 406, 497, 663],
       [471, 506, 535, 821, 835],
       [602, 129, 733, 456, 134],
       [ 83,  85,  82,   8, 135],
       [251, 290, 288, 293, 246],
       [ 22,  21,  80, 128, 146],
       [317, 307, 304, 316, 308],
       [305, 302, 304, 306, 300],
       [243, 242, 180, 163, 159]])

In [0]:
# Score the patch-averaged top-5 predictions against the ground-truth labels
# (get_topk_acc comes from bagnets.utils; presumably top-k accuracy — confirm there).
get_topk_acc(y_hat, labels)

0.9

In [0]:
# Predictions (k=5) from the model built with avg_pool=True, for comparison with
# the patch-averaged y_hat.
y_hat2 = bagnet_predict(bagnet33_avg, images, k=5)
y_hat2

array([[559, 765, 612, 532, 428],
       [438, 441, 907, 711, 737],
       [955,  47,  46,  39, 990],
       [949, 989, 927,   7, 644],
       [853, 425, 663, 915, 497],
       [609, 656, 654, 757, 864],
       [817, 751, 511, 479, 717],
       [915, 672, 879, 660, 853],
       [455, 560, 555, 856, 584],
       [541, 542, 423, 820, 822],
       [748, 630, 636, 893, 464],
       [668, 741, 406, 497, 663],
       [471, 506, 535, 821, 835],
       [602, 129, 733, 456, 134],
       [ 83,  85,  82,   8, 135],
       [251, 290, 288, 293, 246],
       [ 22,  21,  80, 128, 146],
       [317, 307, 304, 316, 308],
       [305, 302, 304, 306, 300],
       [243, 242, 180, 163, 159]])

In [0]:
# Same score for the average-pooled model's predictions.
get_topk_acc(y_hat2, labels)

0.9

## 2. Heatmap by Bagnet-33

In [0]:
# Raw patch-level logits of the batch, computed without gradient tracking.
# NOTE(review): get_heatmap recomputes these internally, so this variable is not
# used by the heatmap cell — confirm whether any other cell relies on it.
with torch.no_grad():
    patch_logits = bagnet33_patch(images)

In [0]:
def get_heatmap(bagnet, images, targets, patch_size=33, stride=8):
    """Generates low-resolution heatmap for Bagnet-33
    Each patch logit (for the target class) is painted over the patch's
    patch_size x patch_size receptive field. Overlapping patches overwrite one
    another in row-major order, so every pixel receives the value of the last
    (bottom-right-most) patch covering it. Pixels not covered by any patch stay 0.
    Input:
    - bagnet (pytorch): Bagnet without average pooling; its output is treated as
      (N, H_out, W_out, num_classes)
    - images (pytorch tensor): images; the last two axes are taken as (height, width)
    - targets (int list): for which class the heatmap is computed for, one per image
    - patch_size (int): side length of a patch in pixels (33 for Bagnet-33)
    - stride (int): pixel offset between neighbouring patches (8 for Bagnet-33)
    Output: (numpy array) heatmap of shape (N, height, width)
    """
    with torch.no_grad():
        patch_logits = bagnet(images)
    # (N, H_out, W_out, C) -> (N, C, H_out, W_out)
    patch_logits = patch_logits.permute([0, 3, 1, 2]).cpu().numpy()
    N = images.shape[0]
    img_h, img_w = images.shape[-2], images.shape[-1]
    # Walk the patch grid the model actually produced, instead of assuming a
    # 224x224 input, so other image sizes map correctly as well.
    out_h, out_w = patch_logits.shape[2], patch_logits.shape[3]
    heatmaps = np.zeros((N, img_h, img_w))
    for z in range(N):
        patch_target_logits = patch_logits[z, targets[z]]
        for p in range(out_h):
            top = p * stride
            for q in range(out_w):
                left = q * stride
                # Scalar assignment broadcasts over the slice; no temp array needed.
                heatmaps[z, top:top + patch_size, left:left + patch_size] = patch_target_logits[p, q]
    return heatmaps

In [0]:
# Heatmap for each image's ground-truth label, shown next to the original image.
heatmaps = get_heatmap(bagnet33_patch, images, labels)
# Iterate over however many heatmaps were produced instead of a hard-coded 20,
# so changing `bs` in the data-loading cell cannot cause an IndexError here.
for i, heatmap in enumerate(heatmaps):
    original_image = convert2channel_last(original[i]) / 255.
    fig, (ax_orig, ax_heat) = plt.subplots(1, 2, figsize=(8, 4))

    ax_orig.set_title('original')
    ax_orig.imshow(original_image)
    ax_orig.axis('off')

    ax_heat.set_title('heatmap')
    plot_heatmap(heatmap, original_image, ax_heat, dilation=0.5, percentile=99, alpha=.25)
    ax_heat.axis('off')

    # Show each figure as soon as it is built rather than accumulating all of them
    # until the cell ends (matplotlib warns once more than
    # `figure.max_open_warning`, 20 by default, figures are open).
    plt.show()

Output hidden; open in https://colab.research.google.com to view.