In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from torchvision import models
import os

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_dir = '../data/'
test_dir = 'normal_test'
# ImageFolder derives label indices from the sorted list of *sub-directories*, so stray
# files (e.g. .DS_Store) must be skipped here or classes[label] would be misaligned.
classes = sorted(
    name for name in os.listdir(os.path.join(data_dir, test_dir))
    if os.path.isdir(os.path.join(data_dir, test_dir, name))
)
batch_size = 64

In [3]:
# Per-channel (RGB) normalisation statistics. Presumably computed on the training
# images — TODO confirm; they must match whatever the checkpoint was trained with.
mean = [0.44947562, 0.46524084, 0.40037745]
std = [0.18456618, 0.16353698, 0.20014246]

# Evaluation preprocessing: resize, convert to a [0, 1] tensor, then standardise.
# NOTE(review): Resize with a single int only fixes one side; default batch collation
# needs equal-sized tensors, so this assumes the test images share an aspect ratio.
data_transforms = dict(
    test=transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]),
)

test_images = datasets.ImageFolder(
    root=os.path.join(data_dir, test_dir),
    transform=data_transforms['test'],
)

# shuffle=False keeps the batch order deterministic across runs.
test_dataloader = DataLoader(
    dataset=test_images,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

In [4]:
model = models.alexnet()
# Swap the default 1000-way head for the 10-way head the checkpoint was trained with.
model.classifier[6] = nn.Linear(4096, 10)
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load('alexnet_pretrained.model', map_location=str(device))
model.load_state_dict(state_dict)
# load_state_dict copies into the model's existing (CPU) parameters; map_location does not
# move the model itself, so without this the CUDA inputs would meet CPU weights on a GPU.
model = model.to(device)
model.eval()

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.5)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace)
    (3): Dropout(p=0.5)
    (4): Linear(in_features=4096, out_feature

In [5]:
def image_show(img):
    """Display a normalised CHW image tensor with matplotlib.

    Reverses ``transforms.Normalize(mean, std)`` channel by channel using the
    notebook-level ``mean`` / ``std`` lists, clips to [0, 1] and draws the image.

    Args:
        img: float tensor laid out as (C, H, W); may live on any device.
    """
    # detach().cpu() so tensors that were moved to the GPU can still become numpy arrays.
    img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))
    # Per-channel inverse of Normalize: the length-C arrays broadcast over the last (channel) axis.
    # (Averaging mean/std into scalars, as before, distorts the colours.)
    img = img * np.asarray(std) + np.asarray(mean)
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    plt.pause(0.001)

In [8]:
import math
import copy
import seaborn as sns
from PIL import Image

def occlusion(image, occluding_size, occluding_stride, model, classes, groundTruth,
              norm_mean=None, norm_std=None, batch_size=128):
    """Occlusion-sensitivity map for a single image.

    A square patch of black (zero) pixels is slid over ``image``. At every position the
    occluded image is normalised, scored by ``model``, and the output for class
    ``groundTruth`` is recorded. Low values mark regions the prediction depends on.

    Args:
        image: HxWxC uint8 image (a PIL RGB image or a numpy array). Not modified.
        occluding_size: side length of the occluding patch, in pixels.
        occluding_stride: step between successive patch positions, in pixels.
        model: classifier, run in eval mode on the device of its first parameter
            (CPU if it has none). It is left in eval mode afterwards.
        classes: unused; kept for backward compatibility.
        groundTruth: class index (int or 0-d tensor) whose score is mapped.
        norm_mean, norm_std: per-channel Normalize statistics applied to every occluded
            image. They default to the notebook-level ``mean`` / ``std``, so the inputs match
            the preprocessing in ``data_transforms['test']``.
        batch_size: number of occluded images scored per forward pass.

    Returns:
        float64 array of shape (output_height, output_width). Each entry is the raw model
        output (no softmax) for ``groundTruth`` with the patch at that position.

    Raises:
        TypeError: if ``image`` is not a 3-D uint8 array-like.
    """
    # The notebook globals are only looked up when no explicit statistics are given.
    norm_mean = mean if norm_mean is None else norm_mean
    norm_std = std if norm_std is None else norm_std
    target = int(groundTruth)

    img = np.array(image)  # copy, so the caller's image is never touched
    if img.dtype != np.uint8 or img.ndim != 3:
        raise TypeError('expected an HxWxC uint8 image, got dtype=%s shape=%s'
                        % (img.dtype, img.shape))
    height, width, _ = img.shape
    # Clamp to at least one position so a patch larger than the image still yields a map.
    output_height = max(1, int(math.ceil((height - occluding_size) / occluding_stride + 1)))
    output_width = max(1, int(math.ceil((width - occluding_size) / occluding_stride + 1)))

    # Same arithmetic as ToTensor() followed by Normalize(norm_mean, norm_std).
    m = torch.tensor(norm_mean, dtype=torch.float32).view(-1, 1, 1)
    s = torch.tensor(norm_std, dtype=torch.float32).view(-1, 1, 1)
    base = (torch.from_numpy(img).permute(2, 0, 1).float().div(255) - m) / s
    fill = (0.0 - m) / s  # normalised value of a black pixel, per channel

    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    positions = [(h * occluding_stride, w * occluding_stride)
                 for h in range(output_height) for w in range(output_width)]

    model.eval()
    scores = []
    with torch.no_grad():
        # Build occluded images one batch at a time rather than materialising all of them.
        for start in range(0, len(positions), batch_size):
            chunk = positions[start:start + batch_size]
            batch = base.unsqueeze(0).repeat(len(chunk), 1, 1, 1)
            for k, (h_start, w_start) in enumerate(chunk):
                # Tensor slicing stops at the image border, matching min(height, ...) clipping.
                batch[k, :, h_start:h_start + occluding_size,
                      w_start:w_start + occluding_size] = fill
            outputs = model(batch.to(model_device))
            # .cpu() so the scores can be converted to numpy even when the model is on a GPU.
            scores.append(outputs[:, target].detach().cpu())

    heatmap = torch.cat(scores).numpy().astype(np.float64)
    return heatmap.reshape((output_height, output_width))

In [9]:
# Occluder geometry in pixels; equal size and stride tile the image with non-overlapping patches.
patch_size = patch_stride = 8

In [None]:
# Per-channel statistics as (C, 1, 1) tensors, used to undo transforms.Normalize below.
mean_tensor = torch.tensor(mean).view(-1, 1, 1)
std_tensor = torch.tensor(std).view(-1, 1, 1)
save_heatmaps = False  # set True to write every heatmap to disk

with torch.no_grad():
    # Only the first batch is visualised: each image costs one forward pass per occluder position.
    inputs, labels = next(iter(test_dataloader))
    for i in range(len(inputs)):
        image = inputs[i].cpu()
        label = int(labels[i])
        print(classes[label])
        image_show(image)

        # ToPILImage expects floats in [0, 1] and does not clamp, so the normalised tensor would be
        # corrupted when cast to uint8. Denormalise and clamp first.
        pil_image = transforms.ToPILImage()((image * std_tensor + mean_tensor).clamp(0, 1))
        heatmap = occlusion(pil_image, patch_size, patch_stride, model, classes, label)

        fig, ax = plt.subplots()
        sns.heatmap(heatmap, cmap='jet', square=True, ax=ax)
        ax.set_title('%s: occlusion score (patch %d, stride %d)'
                     % (classes[label], patch_size, patch_stride))
        if save_heatmaps:
            # Include the image index so successive heatmaps do not overwrite each other.
            fig.savefig('Heatmap %d (%d %d).png' % (i, patch_size, patch_stride))
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the loop
