In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import h5py
import numpy as np

# setup net
# Use the GPU when present, but fall back to the CPU so the notebook still
# runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: `pretrained=True` is deprecated in torchvision >= 0.13 in favour of
# `weights=...`; kept here for compatibility with older torchvision versions.
model = torchvision.models.resnet50(pretrained=True).to(device)
# Switch to inference mode before the model is used for attribution below:
# BatchNorm then uses its stored running statistics instead of per-batch
# statistics, and does not update them on every forward pass.
model.eval()
print("device is", device)

device is cuda


In [2]:
import PIL
import io
import yaml

def pil_bgr2rgb(im):
    """Reverse the band order of a 3-band PIL image.

    Intended for images whose pixel data is stored as BGR: the bands are
    swapped and a new ``"RGB"`` image is returned. Raises ``ValueError`` if
    the image does not have exactly three bands (the unpack below fails).
    """
    blue, green, red = im.split()
    return PIL.Image.merge("RGB", (red, green, blue))


class ImageNetDataset:
    """Map-style dataset over encoded images stored in a single HDF5 file.

    The file must contain ``<split>/images`` (one encoded image per entry,
    stored as an array whose raw bytes are the file contents — presumably
    JPEG) and ``<split>/targets`` (one label per entry, convertible with
    ``int``), where ``<split>`` is ``'train'`` or ``'validation'``.

    The HDF5 handle is opened lazily on first item access rather than in
    ``__init__``. h5py handles cannot be pickled, so the live handle is
    dropped when the dataset is pickled (e.g. when handed to DataLoader
    worker processes); each copy reopens the file on its own first access.

    Args:
        hdf5_filename: Path to the HDF5 file.
        train: Selects the ``'train'`` group when true, ``'validation'``
            otherwise.
        transform: Optional callable applied to the ``PIL.Image`` before it
            is returned.
    """

    def __init__(self, hdf5_filename, train, transform=None):
        self.hdf5_filename = hdf5_filename
        self.train = train
        self.dataset_name = 'train' if train else 'validation'
        self.transform = transform
        # Lazily-opened state, populated by _ensure_open().
        self.open = False
        self.h5 = None
        self.h5_images = None
        self.h5_targets = None

        # Only read the length up front; the file is closed again immediately
        # so no handle is held before first use.
        with h5py.File(hdf5_filename, 'r') as tmp_h5:
            self.length = len(tmp_h5[self.dataset_name + '/targets'])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(image, label)`` for entry ``idx``.

        Negative indices count from the end. An out-of-range index raises
        ``IndexError`` before the file is touched; this is also what ends a
        plain ``for`` loop over the dataset, since the class defines no
        ``__iter__`` and Python falls back to calling ``__getitem__(0, 1, ...)``.
        """
        n = self.length
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            raise IndexError(
                'index %s out of range for dataset of length %d' % (idx, n))
        self._ensure_open()
        target = self.h5_targets[pos]
        jpg_bytes = self.h5_images[pos].tobytes()
        pil_image = PIL.Image.open(io.BytesIO(jpg_bytes))
        img = pil_image if self.transform is None else self.transform(pil_image)
        return img, int(target)

    def close(self):
        """Close the HDF5 handle if it is open; the next access reopens it."""
        if self.h5 is not None:
            self.h5.close()
        self.h5 = None
        self.h5_images = None
        self.h5_targets = None
        self.open = False

    def __getstate__(self):
        """Pickle without the live h5py objects, which cannot be pickled."""
        state = self.__dict__.copy()
        state.update(open=False, h5=None, h5_images=None, h5_targets=None)
        return state

    def _ensure_open(self):
        """Open the HDF5 file and cache this split's datasets on first use."""
        if not self.open:
            self.h5 = h5py.File(self.hdf5_filename, 'r', swmr=True)
            self.h5_images = self.h5[self.dataset_name + '/images']
            self.h5_targets = self.h5[self.dataset_name + '/targets']
            self.open = True

# setup data
DATA_DIR = '../data/imagenet_full'
filename = DATA_DIR + '/imagenet.hdf5'
show_loader = ImageNetDataset(filename, train=False)
# Label names, indexed below by the integer class label. yaml.load() without
# an explicit Loader is deprecated since PyYAML 5.1 and a TypeError since 6.0;
# safe_load parses plain YAML (the file is presumably a simple
# index -> name mapping) without constructing arbitrary Python objects.
with open(DATA_DIR + '/dict.txt', encoding='utf-8') as data_file:
    dct = yaml.safe_load(data_file)


OSError: Unable to open file (unable to open file: name = '../data/imagenet_full/imagenet.hdf5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

In [None]:
from torchvision.transforms import ToTensor, RandomCrop, Resize

import attribution
import attribution.methods as am
import attribution.utils as au

import importlib

# Re-execute the local attribution modules so edits to them are picked up
# without restarting the kernel.
importlib.reload(attribution.methods)
importlib.reload(attribution.utils)

# Attribution methods compared side by side below, in display order.
methods = [
    am.Occlusion(size=10),
    am.Occlusion(size=10, patch_type="avg"),
    am.Occlusion(size=10, patch_type="inv"),
    am.SmoothGrad(steps=1, std=0),  # single noise-free sample = gradients
    am.SmoothGrad(steps=30, std=0.2),
    am.IntegratedGradients(steps=30),
    # Other configurations that can be switched in:
    # am.IntegratedGradients(steps=30, only_positive=True),
    # am.SmoothGrad(std=0, times_input=True),
    # am.SmoothGrad(times_input=True),
    # am.IntegratedGradients(steps=30, baseline=1),
]


def transform(img):
    """Move the leading (channel) axis to the end for displaying the image.

    For a 3-D input this maps ``(C, H, W)`` to ``(H, W, C)``; in general the
    first three axes are rotated so that axis 0 lands in position 2. Uses
    ``np.swapaxes``, which dispatches to the input's own ``swapaxes`` method,
    so array-likes that provide one are reordered without conversion.
    """
    channels_in_middle = np.swapaxes(img, 0, 1)   # (C, H, W) -> (H, C, W)
    return np.swapaxes(channels_in_middle, 1, 2)  # (H, C, W) -> (H, W, C)

# Fetch one validation image directly by index instead of loading every item
# before it in a loop.
SHOW_INDEX = 7
pil_img, real_label = show_loader[SHOW_INDEX]
print(dct[real_label])

# Resize + crop to the standard 224x224 ImageNet input. RandomCrop draws its
# offsets from torch's global RNG, so seed it to make the figure reproducible.
torch.manual_seed(0)
img = RandomCrop((224, 224))(Resize(256)(pil_img))
# ToTensor gives a (C, H, W) float tensor in [0, 1]; add a batch axis.
img = ToTensor()(img).unsqueeze(0).to(device)
# NOTE(review): torchvision's pretrained ResNet-50 expects ImageNet mean/std
# normalization, but this tensor stays in [0, 1] — confirm whether
# attribution.utils normalizes internally; otherwise predictions are degraded.
au.compare_methods(methods,
                   model=model,
                   img=img,
                   img_trafo=transform,
                   blur=0,
                   mode="heatmap")