In [1]:
import argparse
import torch
import torchvision
import torch.nn.functional as F
import numpy as np

from nn.enums import ExplainingMethod
from nn.networks import ExplainableNet
from nn.utils import get_expl, plot_overview, clamp, load_image, make_dir

no display found. Using non-interactive Agg backend


In [2]:
import os
# Number GPUs by PCI bus ID and expose only device 4 to CUDA. These variables
# only take effect if set before CUDA is first initialised in this process.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",   # see issue #152
    CUDA_VISIBLE_DEVICES="4",
)

In [3]:
def get_beta(i, num_iter, start_beta=10.0, end_beta=100.0):
    """
    Helper method for beta growth.

    Interpolates geometrically from ``start_beta`` (at ``i == 0``) to
    ``end_beta`` (at ``i == num_iter``):

        beta(i) = start_beta * (end_beta / start_beta) ** (i / num_iter)

    Args:
        i: current iteration index.
        num_iter: total number of iterations; must be non-zero.
        start_beta: value returned at ``i == 0`` (defaults to the previously
            hard-coded 10.0).
        end_beta: value returned at ``i == num_iter`` (defaults to the
            previously hard-coded 100.0).

    Returns:
        float: the interpolated beta value.

    Raises:
        ZeroDivisionError: if ``num_iter`` is 0 or ``start_beta`` is 0.
    """
    return start_beta * (end_beta / start_beta) ** (i / num_iter)

In [4]:
# args — presumably stand-ins for the command-line flags of a script version
# of this notebook (note the unused `argparse` import).
args_cuda = True            # request GPU execution
args_method = 'lrp'         # name of an ExplainingMethod member (looked up via getattr)
args_beta_growth = None     # falsy -> model is built with beta=None
args_img, args_target_img = (
    '../data/collie4.jpeg',     # source image path
    '../data/tiger_cat.jpeg',   # target image path
)

In [6]:
# options
# Use CUDA only when it was requested AND is actually available. Without this
# check, a host lacking the GPU selected via CUDA_VISIBLE_DEVICES fails at
# `.to(device)` below.
use_cuda = args_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print('Using device {}'.format(device))
method = getattr(ExplainingMethod, args_method)
print('Explanation method {} will be used'.format(method))

# load model
# Per-channel normalisation constants; the same arrays are passed to the model
# here and to load_image when the inputs are loaded. They match the ImageNet
# statistics documented for torchvision's pretrained models.
data_mean = np.array([0.485, 0.456, 0.406])
data_std = np.array([0.229, 0.224, 0.225])
vgg_model = torchvision.models.vgg16(pretrained=True)
# A truthy args_beta_growth builds the model with beta=1000, otherwise beta=None.
model = ExplainableNet(vgg_model, data_mean=data_mean, data_std=data_std, beta=1000 if args_beta_growth else None)
if method == ExplainingMethod.pattern_attribution:
    # strict=False ignores missing/unexpected keys in the checkpoint.
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on
    # CPU-only hosts; the model is moved to `device` immediately afterwards.
    model.load_state_dict(torch.load('../models/model_vgg16_pattern_small.pth', map_location='cpu'), strict=False)
model = model.eval().to(device)


Explanation method ExplainingMethod.lrp will be used
in_channels:  3
in_channels:  64
in_channels:  64
in_channels:  128
in_channels:  128
in_channels:  256
in_channels:  256
in_channels:  256
in_channels:  512
in_channels:  512
in_channels:  512
in_channels:  512
in_channels:  512


In [7]:
# load images — the source image and the target image are both normalised
# with the same data_mean / data_std arrays that were passed to the model.
x, x_target = (
    load_image(data_mean, data_std, device, image_path)
    for image_path in (args_img, args_target_img)
)
# Gradient-tracking leaf copy of x with no autograd history; x itself is left
# untouched. Presumably this is the adversarial input optimised later.
x_adv = x.detach().clone()
x_adv.requires_grad_(True)

In [9]:
# Quick sanity check of the loaded input tensor's dimensions.
x.size()

torch.Size([1, 3, 224, 224])

In [8]:
# Explanation for the unperturbed input x. The three results are unpacked as
# returned by nn.utils.get_expl; judging by the names, presumably
# (explanation map, score, predicted class index) — confirm in nn.utils.
(org_expl,
 org_acc,
 org_idx) = get_expl(model, x, method)

ZA:  tensor([[[[ 7.4757, 10.7397,  9.6943,  ..., 11.2909, 11.0325,  7.1471],
          [ 8.9987, 12.8986, 12.7169,  ..., 17.3652, 15.5353,  9.4430],
          [ 6.4053,  9.2738, 10.1147,  ..., 21.2259, 17.4284,  8.9212],
          ...,
          [ 3.4782,  6.1923,  7.6292,  ..., 12.0124, 10.3241,  7.8390],
          [ 3.5625,  6.4150,  7.6562,  ..., 12.3634, 13.1586, 10.2126],
          [ 2.7336,  4.5150,  5.1407,  ...,  8.1313,  9.3188,  7.7767]],

         [[ 9.1312, 11.8667,  8.9114,  ..., 10.0830,  8.7897,  4.5493],
          [10.9227, 14.8181, 11.6095,  ..., 16.4554, 14.7167,  8.2421],
          [ 6.3305,  8.9964,  8.0854,  ..., 17.4102, 15.3199,  8.8214],
          ...,
          [ 3.3898,  6.4435,  8.3286,  ...,  7.9644,  7.0637,  4.6214],
          [ 2.6356,  4.9827,  6.5096,  ...,  8.0329,  7.3549,  4.7616],
          [ 1.9262,  3.3951,  4.2173,  ...,  5.4237,  5.4596,  3.5717]],

         [[ 9.4925, 10.4853,  8.8753,  ..., 11.0573,  9.7402,  6.6518],
          [10.9765, 13.05

ZA:  tensor([[[[33.0891, 44.8010, 52.3599,  ..., 40.7703, 32.9391, 22.3261],
          [42.4926, 63.8724, 65.2409,  ..., 57.9117, 50.1491, 33.9188],
          [42.7086, 62.3400, 53.3712,  ..., 57.0426, 45.8448, 31.0287],
          ...,
          [10.6743, 12.6223, 10.4058,  ..., 34.4880, 38.1798, 27.6517],
          [14.9996, 18.4802, 16.8317,  ..., 36.5787, 37.3451, 25.9655],
          [11.2717, 14.2941, 11.9228,  ..., 26.8149, 23.3515, 15.5482]],

         [[37.0695, 45.7606, 52.1274,  ..., 29.7729, 24.1075, 14.2268],
          [46.5857, 52.3499, 51.3737,  ..., 42.3455, 32.7214, 17.7614],
          [35.1536, 43.2523, 37.0679,  ..., 36.0679, 26.9837, 14.0455],
          ...,
          [ 7.2974,  9.2189,  8.6425,  ..., 20.8646, 24.9196, 19.0575],
          [12.0441, 15.4120, 13.7432,  ..., 26.2184, 27.3730, 18.5678],
          [12.3971, 14.0013, 10.2948,  ..., 26.3363, 21.4279, 11.8353]],

         [[26.0487, 37.4274, 38.9446,  ..., 28.7056, 20.9918, 13.5915],
          [38.5637, 54.22

ZA:  tensor([[[[ 7.5175, 12.1490, 11.1058,  ..., 12.3007, 13.1552, 19.4738],
          [ 8.7341, 12.6583, 12.1637,  ..., 12.2441, 14.8479, 19.6192],
          [ 7.3738,  7.0305,  9.3248,  ...,  7.4109, 10.1384, 15.5093],
          ...,
          [ 5.7200,  4.9888,  3.2224,  ...,  3.6404,  4.3586,  6.8736],
          [ 6.3530,  5.8708,  4.7956,  ...,  4.5157,  5.4849,  7.9280],
          [ 4.2783,  4.7594,  3.8945,  ...,  3.6312,  3.8957,  5.0558]],

         [[ 8.6729, 11.1429, 10.8726,  ...,  7.1816, 10.0628,  8.7542],
          [ 5.6576, 11.5236, 10.0490,  ...,  5.7255,  8.5449,  7.8605],
          [ 4.1367,  9.1640, 10.3603,  ...,  4.7541,  7.3535,  7.2823],
          ...,
          [ 3.8415,  4.9637,  3.3384,  ...,  3.7300,  6.0557,  4.8927],
          [ 4.9168,  7.0524,  5.6708,  ...,  4.9902,  6.9914,  6.1939],
          [ 1.4594,  2.8749,  2.5341,  ...,  2.0528,  2.3966,  2.1632]],

         [[10.5996, 10.7604, 11.8378,  ..., 10.4450, 13.1489, 16.4083],
          [10.9752, 15.64