# B-LRP: ImageNet experiment
This notebook visualises B-LRP results using the LRP composite (LRP-CMP) rule and performs a pixel-flipping experiment with the LRP-epsilon rule on a downloaded subset of the ImageNet dataset.


In [None]:
#imports

import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib.ticker as ticker
import copy

from torch import nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

from tqdm import tqdm
import cv2

import seaborn as sns

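First, we load a VGG16 network pretrained on ImageNet. We replace the standard dropout layers with a variant that remembers which neurons were dropped (by storing the random seed used to draw the mask), which we need in order to propagate LRP relevances through the dropout layers.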


In [None]:
class MyDropout(nn.Module):
    """Dropout whose random mask can be replayed.

    In training mode, a call with ``freeze=False`` draws a new integer seed from
    NumPy's global RNG, stores it in ``self.seed`` and reseeds torch with it
    before calling ``F.dropout``. A later call with ``freeze=True`` reseeds torch
    with the stored value instead, so the same mask is generated again; the LRP
    backward pass relies on this to reuse the forward-pass mask.

    In eval mode the layer returns its input unchanged.

    Side effect: reseeds torch's global RNG on every training-mode call.
    """

    def __init__(self, p=0.5):
        super(MyDropout, self).__init__()
        self.p = p       # drop probability passed to F.dropout
        self.seed = 0    # seed used by the most recent non-frozen call

    def forward(self, input, freeze = False):
        if self.training:
            if not freeze:
                # Remember a fresh seed so that this mask can be reproduced later.
                self.seed = np.random.randint(10000000, size = 1)[0]
            torch.manual_seed(self.seed)
            return F.dropout(input, p=self.p)
        return input
        
        
class VGG(nn.Module):
    """VGG network whose classifier uses the replayable ``MyDropout`` layers.

    Layout: ``features`` -> 7x7 adaptive average pool -> flatten ->
    Linear(512*7*7, 4096) -> ReLU -> MyDropout -> Linear(4096, 4096) -> ReLU ->
    MyDropout -> Linear(4096, num_classes). The structure is kept compatible with
    torchvision's VGG so that a torchvision VGG state dict can be loaded into it.

    Args:
        features: convolutional trunk (e.g. from ``make_layers``); it must emit
            512 channels, since the first Linear expects 512 * 7 * 7 inputs.
        num_classes: output width of the last Linear layer.
        init_weights: if True, re-initialise all weights via ``_initialize_weights``.
    """

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        flat_dim = 512 * 7 * 7
        hidden_dim = 4096
        self.classifier = nn.Sequential(
            nn.Linear(flat_dim, hidden_dim), nn.ReLU(inplace=True), MyDropout(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True), MyDropout(),
            nn.Linear(hidden_dim, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))

    def _initialize_weights(self):
        """Kaiming-normal convolutions, N(0, 0.01) linears, unit/zero batch-norm; zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)


def make_layers(cfg, batch_norm=False):
    """Build a VGG convolutional trunk from a layer specification.

    Each entry of ``cfg`` is either ``'M'`` (2x2 max-pool, stride 2) or an int:
    the output channels of a 3x3, padding-1 convolution. Each convolution is
    followed by an in-place ReLU, with a BatchNorm2d in between when
    ``batch_norm`` is True. The first convolution takes 3 input channels.

    Returns:
        nn.Sequential with the layers in order.
    """
    modules = []
    channels = 3
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)


# Layer layouts for make_layers: ints are 3x3 conv output channels, 'M' is a 2x2 max-pool.
# The conv counts (8, 10, 13, 16) plus VGG's three Linear layers give 11, 13, 16 and 19
# weight layers for 'A', 'B', 'D' and 'E'; vgg16 uses 'D'.
cfgs = dict(
    A=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    B=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    D=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    E=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
)


def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    """Build a VGG variant, optionally with torchvision's ImageNet weights.

    Args:
        arch: torchvision model name (e.g. ``'vgg16'``) whose pretrained weights
            are loaded when ``pretrained`` is True.
        cfg: key into ``cfgs`` selecting the convolutional layout.
        batch_norm: insert BatchNorm2d after every convolution.
        pretrained: if True, skip random init and load pretrained weights.
        progress: show a download progress bar when fetching weights.
        **kwargs: forwarded to ``VGG`` (e.g. ``num_classes``).

    Returns:
        The constructed ``VGG`` model.
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        # The original code called `load_state_dict_from_url(model_urls[arch])`, but
        # neither name is defined in this notebook, so this branch raised NameError.
        # Take the weights from torchvision's own constructor instead (as the model
        # setup cell does); MyDropout has no parameters, and load_state_dict is strict,
        # so any layout mismatch is reported.
        reference = getattr(torchvision.models, arch)(pretrained=True, progress=progress)
        model.load_state_dict(reference.state_dict())
    return model


def vgg16(pretrained=False, progress=True, **kwargs):
    r"""VGG 16-layer model (configuration "D") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): if True, return a model with ImageNet-pretrained weights
        progress (bool): if True, show a download progress bar on stderr
        **kwargs: forwarded to ``VGG`` (e.g. ``num_classes``)
    """
    arch_name, layout, use_batch_norm = 'vgg16', 'D', False
    return _vgg(arch_name, layout, use_batch_norm, pretrained, progress, **kwargs)
    

# Build the dropout-aware VGG16 and copy torchvision's pretrained ImageNet weights into it.
# load_state_dict is strict by default, so this also checks that the two layouts match.
model = vgg16()
model.load_state_dict(torchvision.models.vgg16(pretrained=True).state_dict())

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
print(device)

model.to(device)

In the following cell we define the functions for the LRP composite (LRP-CMP) rule.

In [None]:
# --------------------------------------------------------------
# LRP Composite rule
# --------------------------------------------------------------

# More information at http://www.heatmapping.org/tutorial/

def LRP_CMP(image, class_id, model, dropout = True, verbose = False, device = 'cpu', class_names = None):
    """Composite-rule LRP heatmap for VGG16, optionally through a sampled dropout mask.

    Adapted from the heatmapping.org LRP tutorial. Layers are indexed in the
    flattened list ``features + toconv(classifier)``; for convolution/pooling layers:
      * l <= 16      : rho(p) = p + 0.25 * max(p, 0) (gamma rule), incr(z) = z + 1e-9
      * 17 <= l <= 30: rho(p) = p, incr(z) = z + 1e-9 + 0.25 * RMS(z)
      * l >= 31      : rho(p) = p, incr(z) = z + 1e-9
      * l == 0       : z^B rule with pixel bounds 0 and 1 (after normalisation)
    Max-pooling layers are treated as 2x2 average pooling. With ``dropout=True``
    the model is put in train mode; each MyDropout layer samples a mask in the
    forward pass and replays it (``freeze=True``) when redistributing relevance.

    Args:
        image: HxWx3 tensor with RGB values in [0, 1] (the z^B bounds assume this
            range); it is reshaped into a (1, 3, 224, 224) batch, so 224x224 is required.
        class_id: index of the output logit to explain.
        model: VGG with ``features`` and ``classifier`` submodules.
        dropout: sample a dropout mask (B-LRP) instead of running in eval mode.
        verbose: print the five highest-scoring classes.
        device: device on which to run the computation.
        class_names: optional sequence of class names printed when ``verbose``;
            class indices are printed when it is None.

    Returns:
        Relevance tensor of shape (1, 3, 224, 224) on ``device``.

    Side effect: leaves ``model`` in train (dropout=True) or eval mode.
    """
    # ImageNet normalisation constants, broadcastable over (1, 3, H, W).
    mean = torch.Tensor([0.485, 0.456, 0.406]).reshape(1,-1,1,1).to(device)
    std  = torch.Tensor([0.229, 0.224, 0.225]).reshape(1,-1,1,1).to(device)

    image = image.to(device)

    if dropout:
        model.train()
    else:
        model.eval()

    # HWC -> CHW -> (1, C, H, W), then normalise.
    X = (image.permute(2, 0, 1).reshape([1, 3, 224, 224]) - mean) / std

    layers = list(model._modules['features']) + toconv(list(model._modules['classifier']))
    L = len(layers)

    # Forward pass, keeping every layer's input activation. In-place ReLUs alias
    # some entries, which is harmless because only layer inputs are reused below.
    A = [X] + [None] * L
    with torch.no_grad():
        for l in range(L):
            A[l + 1] = layers[l].forward(A[l])

    scores = np.array(A[-1].data.view(-1).cpu())

    if verbose:
        for i in np.argsort(-scores)[:5]:
            name = class_names[i][:20] if class_names is not None else str(i)
            print('%20s (%3d): %6.3f' % (name, i, scores[i]))

    # Keep only the logit of the explained class as the initial relevance.
    T = torch.zeros_like(A[-1])
    T[0, class_id] = 1.0
    R = [None] * L + [(A[-1] * T)]

    for l in reversed(range(1, L)):

        A[l] = A[l].requires_grad_(True)

        if isinstance(layers[l], torch.nn.MaxPool2d):
            layers[l] = torch.nn.AvgPool2d(2)

        if isinstance(layers[l], (torch.nn.Conv2d, torch.nn.AvgPool2d)):

            if l <= 16:
                rho = lambda p: p + 0.25*p.clamp(min=0); incr = lambda z: z+1e-9
            elif l <= 30:
                rho = lambda p: p;                       incr = lambda z: z+1e-9+0.25*((z**2).mean()**.5).data
            else:
                rho = lambda p: p;                       incr = lambda z: z+1e-9

            z = incr(newlayer(layers[l], rho).forward(A[l]))                    # step 1
            s = (R[l+1]/z).data                                                 # step 2
            (z*s).sum().backward(); c = A[l].grad                               # step 3
            R[l] = (A[l]*c).data                                                # step 4

        elif dropout and isinstance(layers[l], MyDropout):
            # Replay the forward-pass mask so relevance only flows through kept units.
            z = layers[l].forward(A[l], freeze = True) + 1e-9                   # step 1
            s = (R[l+1]/z).data                                                 # step 2
            (z*s).sum().backward(); c = A[l].grad                               # step 3
            R[l] = (A[l]*c).data                                                # step 4

        else:
            # ReLU (and dropout in eval mode): relevance passes through unchanged.
            R[l] = R[l+1]

    A[0] = A[0].requires_grad_(True)

    # Lowest/highest admissible normalised pixel values for the z^B rule.
    with torch.no_grad():
        lb = (A[0]*0+(0-mean)/std)
        hb = (A[0]*0+(1-mean)/std)

    lb = lb.requires_grad_(True)
    hb = hb.requires_grad_(True)

    # Use a copy of the first conv so backward() does not accumulate gradients
    # into the model's own parameters on every call.
    first = newlayer(layers[0], lambda p: p)
    z = first.forward(A[0]) + 1e-9                                              # step 1 (a)
    z -= newlayer(layers[0], lambda p: p.clamp(min=0)).forward(lb)              # step 1 (b)
    z -= newlayer(layers[0], lambda p: p.clamp(max=0)).forward(hb)              # step 1 (c)
    s = (R[1]/z).data                                                           # step 2
    (z*s).sum().backward(); c, cp, cm = A[0].grad, lb.grad, hb.grad             # step 3
    R[0] = (A[0]*c + lb*cp + hb*cm).data

    return R[0].data

# --------------------------------------------------------------
# Clone a layer and pass its parameters through the function g
# --------------------------------------------------------------

def newlayer(layer,g):
    """Return a deep copy of ``layer`` with ``g`` applied to its weight and bias.

    The original layer is left untouched. Attributes that are missing (e.g. on
    pooling layers) or set to None (e.g. ``bias=False``) are skipped, so the copy
    never ends up with an empty placeholder parameter.

    Args:
        layer: module to clone.
        g: function mapping a parameter tensor to a new tensor (e.g. clamping).

    Returns:
        The modified copy.
    """
    layer = copy.deepcopy(layer)

    for name in ('weight', 'bias'):
        param = getattr(layer, name, None)
        if param is not None:
            setattr(layer, name, nn.Parameter(g(param)))

    return layer

# --------------------------------------------------------------
# convert VGG classifier's dense layers to convolutional layers
# --------------------------------------------------------------

def toconv(layers, spatial_size = 7):
    """Rewrite a VGG classifier's Linear layers as equivalent convolutions.

    The first Linear layer becomes a ``spatial_size`` x ``spatial_size`` convolution
    whose input channels are ``in_features // spatial_size**2``; every later Linear
    becomes a 1x1 convolution. Weights and biases are reshaped from the originals,
    so the converted layers compute the same values. Non-Linear layers are passed
    through unchanged (same objects).

    Args:
        layers: sequence of modules, e.g. ``list(model.classifier)``.
        spatial_size: side length of the feature map feeding the first Linear layer
            (7 for VGG, whose first Linear receives 512 * 7 * 7 inputs).

    Returns:
        New list of modules.

    Raises:
        ValueError: if the first Linear's input size is not divisible by spatial_size**2.
    """
    converted = []
    seen_linear = False

    for layer in layers:

        if not isinstance(layer, nn.Linear):
            converted.append(layer)
            continue

        n_out, n_in = layer.weight.shape
        k = 1 if seen_linear else spatial_size
        seen_linear = True

        if n_in % (k * k) != 0:
            raise ValueError('in_features=%d is not divisible by %d' % (n_in, k * k))
        channels = n_in // (k * k)

        has_bias = layer.bias is not None
        conv = nn.Conv2d(channels, n_out, k, bias=has_bias)
        conv.weight = nn.Parameter(layer.weight.reshape(n_out, channels, k, k))
        if has_bias:
            conv.bias = nn.Parameter(layer.bias)

        converted.append(conv)

    return converted

# --------------------------------------------------------------
# Function for MinMax Normalisation of Relevances
# --------------------------------------------------------------

def normalise_relevance(relevance_matrix):
    """Scale a relevance map into [-1, 1], treating positive and negative parts separately.

    Positive entries are divided by the maximum and negative entries by the
    magnitude of the minimum, so the strongest positive value maps to 1 and the
    strongest negative value to -1. A sign that does not occur is simply left
    out, which avoids 0/0 (the previous version returned all-NaN maps when the
    minimum or maximum was exactly 0). Works with numpy arrays and torch tensors.

    Args:
        relevance_matrix: array/tensor of relevances.

    Returns:
        The normalised map (the input itself if it is identically zero).
    """
    lo = relevance_matrix.min()
    hi = relevance_matrix.max()

    if (lo == 0.) & (hi == 0.):
        return relevance_matrix

    result = relevance_matrix * 0
    if hi > 0.:
        result = result + (relevance_matrix > 0.) * relevance_matrix / hi
    if lo < 0.:
        result = result + (relevance_matrix < 0.) * relevance_matrix / (-lo)
    return result

### Visualization of B-LRP on the 'castle' example
In the following cells we visually inspect B-LRP on the example of the 'castle' image.

In [None]:
img_name = 'castle.jpg'
bgr = cv2.imread(img_name)
if bgr is None:                                         # cv2.imread returns None instead of raising
    raise FileNotFoundError('Could not read image: {}'.format(img_name))
if bgr.shape[:2] != (224, 224):
    # LRP_CMP reshapes its input to 224x224, so bring other sizes to that resolution.
    bgr = cv2.resize(bgr, (224, 224))
img = np.array(bgr)[..., ::-1] / 255.0                  # BGR -> RGB, scaled to [0, 1]
plt.imshow(img)

img = torch.tensor(img).float().to(device)
class_ind = 483                                         # index corresponding to a 'castle' class in Imagenet

In [None]:
N_MC = 100                                              # Number of samples from the posterior

model.to(device)

# One LRP-CMP heatmap per posterior sample (a fresh dropout mask on every call),
# summed over the colour channels.
LRPs = torch.zeros([N_MC, 224, 224])
for i in tqdm(range(N_MC)):
    LRPs[i] = LRP_CMP(img, class_ind, model, dropout = True, device = device)[0].sum(axis = 0).data
LRPs[torch.isnan(LRPs)] = 0.                            # NaN relevances are treated as zero

# Deterministic reference heatmap with dropout disabled (eval mode).
Standard_LRP = LRP_CMP(img, class_ind, model, dropout = False, device = device)[0].sum(axis = 0).data
Standard_LRP[torch.isnan(Standard_LRP)] = 0.

In [None]:
plt.rcParams.update(plt.rcParamsDefault)
alphas = [5, 25, 50, 75, 95]

LRPs = LRPs.to('cpu')

# Row 1: input image, standard LRP, and the mean of the sampled B-LRP heatmaps.
fig, ax = plt.subplots(1, 3, figsize=(15.6, 3))
panels = [
    (img.cpu(), 'gray', 'Original Image'),
    (normalise_relevance(Standard_LRP.cpu()), 'seismic', 'Standard LRP'),
    (normalise_relevance(LRPs.mean(axis=0)), 'seismic', 'Expected LRP'),
]
for axe, (picture, colormap, title) in zip(ax, panels):
    axe.imshow(picture, cmap=colormap)
    axe.title.set_text(title)
    axe.xaxis.set_major_locator(plt.NullLocator())
    axe.yaxis.set_major_locator(plt.NullLocator())

# Row 2: per-pixel percentiles of the sampled heatmaps.
fig, ax = plt.subplots(1, len(alphas), figsize=(14, 6))
flat_samples = LRPs.reshape([N_MC, -1]).numpy()
for axe, alpha in zip(ax, alphas):
    axe.set_title('{}-th Percentile'.format(alpha))
    heat = np.percentile(flat_samples, alpha, axis=0).reshape([224, 224])
    axe.imshow(normalise_relevance(heat), cmap='seismic')
    axe.xaxis.set_major_locator(plt.NullLocator())
    axe.yaxis.set_major_locator(plt.NullLocator())

plt.show()

### B-LRP Pixelflipping with LRP-epsilon
In the following, we perform a pixel-flipping comparison between B-LRP and standard LRP, both computed with the LRP-epsilon rule.

As a first step, we load the downloaded subset of the ImageNet data.

In [None]:
# Resize and centre-crop every image to the 224x224 input VGG16 expects.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

testset = torchvision.datasets.ImageFolder('imagenet_data', transform=transform)

# ImageNet class index for each ImageFolder label: label k -> imagenet_class_ids[k].
# NOTE(review): torchvision's ImageFolder numbers sub-folders in sorted order, so this
# list must follow the sorted folder names of 'imagenet_data' — confirm against the data.
imagenet_class_ids = [483, 562, 951, 355, 721, 932, 849, 282, 980, 907]
target = dict(enumerate(imagenet_class_ids))

In [None]:
# --------------------------------------------------------------
# LRP Epsilon rule
# --------------------------------------------------------------

# More information at http://www.heatmapping.org/tutorial/

def LRP_epsilon(image, class_id, model, dropout = True, verbose = False, device = 'cpu', epsilon = 1e-9, class_names = None):
    """LRP heatmap using the same stabilised rule, z + epsilon, at every layer.

    Adapted from the heatmapping.org LRP tutorial. Every convolution (the converted
    classifier layers and the first layer included) uses rho(p) = p and
    incr(z) = z + epsilon. Max-pooling layers are treated as 2x2 average pooling.
    The stabiliser adds ``epsilon`` itself, not ``epsilon * sign(z)``.
    With ``dropout=True`` the model is put in train mode and each MyDropout layer's
    sampled mask is replayed (``freeze=True``) when redistributing relevance.

    Args:
        image: 3x224x224 channels-first tensor (e.g. ``transforms.ToTensor`` output)
            with values in [0, 1]; it is ImageNet-normalised here. Unlike ``LRP_CMP``,
            no HWC -> CHW transpose is applied.
        class_id: index of the output logit to explain.
        model: VGG with ``features`` and ``classifier`` submodules; moved to ``device``.
        dropout: sample a dropout mask (B-LRP) instead of running in eval mode.
        verbose: print the three highest-scoring classes.
        device: device on which to run the computation.
        epsilon: stabiliser added to every layer's pre-activations.
        class_names: optional sequence of class names printed when ``verbose``;
            class indices are printed when it is None.

    Returns:
        Relevance tensor of shape (1, 3, 224, 224) on ``device``.

    Side effect: leaves ``model`` on ``device`` and in train (dropout=True) or eval mode.
    """
    mean = torch.Tensor([0.485, 0.456, 0.406]).reshape(1,-1,1,1).to(device)  # ImageNet normalisation constants
    std  = torch.Tensor([0.229, 0.224, 0.225]).reshape(1,-1,1,1).to(device)
    image = image.to(device)

    if dropout:
        model.train()
    else:
        model.eval()

    model.to(device)

    X = (image.reshape([1, 3, 224, 224]) - mean) / std

    layers = list(model._modules['features']) + toconv(list(model._modules['classifier']))
    L = len(layers)

    # Forward pass, keeping every layer's input activation.
    A = [X] + [None] * L
    with torch.no_grad():
        for l in range(L):
            A[l + 1] = layers[l].forward(A[l])

    scores = np.array(A[-1].data.view(-1).cpu())

    if verbose:
        print('New instance:')
        for i in np.argsort(-scores)[:3]:
            name = class_names[i][:20] if class_names is not None else str(i)
            print('%20s (%3d): %6.3f' % (name, i, scores[i]))

    # Keep only the logit of the explained class as the initial relevance.
    T = torch.zeros_like(A[-1])
    T[0, class_id] = 1.0
    R = [None] * L + [(A[-1] * T)]

    rho = lambda p: p
    incr = lambda z: z + epsilon

    for l in reversed(range(L)):

        A[l] = A[l].requires_grad_(True)

        if isinstance(layers[l], torch.nn.MaxPool2d):
            layers[l] = torch.nn.AvgPool2d(2)

        if isinstance(layers[l], (torch.nn.Conv2d, torch.nn.AvgPool2d)):
            z = incr(newlayer(layers[l], rho).forward(A[l]))             # step 1
            s = (R[l+1]/z).data                                          # step 2
            (z*s).sum().backward(); c = A[l].grad                        # step 3
            R[l] = (A[l]*c).data                                         # step 4

        elif dropout and isinstance(layers[l], MyDropout):
            # Replay the forward-pass mask so relevance only flows through kept units.
            z = incr(layers[l].forward(A[l], freeze = True))             # step 1
            s = (R[l+1]/z).data                                          # step 2
            (z*s).sum().backward(); c = A[l].grad                        # step 3
            R[l] = (A[l]*c).data                                         # step 4

        else:
            # ReLU (and dropout in eval mode): relevance passes through unchanged.
            R[l] = R[l+1]

    return R[0].data

In the following cell we define a function for pixel flipping.

To speed up the computation, we do not evaluate the model's output score after every flipped pixel, but only after every *steps* flipped pixels.

In [None]:
# --------------------------------------------------------------
# Function for performing a pixelflipping
# --------------------------------------------------------------

def pixelflipping(image, class_ind, R, model, inner_steps, step_size, normalize = True):
    """Pixel-flipping curve: score of ``class_ind`` while pixels are flipped by relevance.

    At each of ``inner_steps`` evaluations the current score is recorded first.
    Then the ``step_size`` most relevant pixels not yet flipped are replaced, in all
    three channels, by uniform noise from NumPy's global RNG, mapped through the
    ImageNet normalisation. So score i is measured after i * step_size flips.

    Previously the raw [0, 1] image was fed to the model while the flipped pixels
    were set to normalised values, and LRP itself was computed on normalised input.
    The image is now normalised before evaluation, so the whole input lives in one
    space.

    Args:
        image: original image, 3xHxW tensor. With ``normalize=True`` (default) it is
            expected in [0, 1] (e.g. ``transforms.ToTensor`` output) and is
            ImageNet-normalised here; pass ``normalize=False`` if it already is.
        class_ind: index of the output logit to track.
        R: HxW relevance map; higher values are flipped first.
        model: network returning (1, num_classes) scores; it is set to eval mode.
        inner_steps: number of score evaluations.
        step_size: number of pixels flipped between evaluations.
        normalize: whether to normalise ``image`` first (see ``image``).

    Returns:
        CPU tensor of length ``inner_steps`` with the recorded scores.

    Side effects: puts ``model`` in eval mode and consumes NumPy's global RNG.
    """
    # Run on whatever device the model lives on (previously the global `device`).
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    mean = torch.tensor([0.485, 0.456, 0.406], device=run_device)
    std  = torch.tensor([0.229, 0.224, 0.225], device=run_device)

    img = image.detach().clone().to(run_device)
    if normalize:
        img = (img - mean.view(3, 1, 1)) / std.view(3, 1, 1)
    relevances = R.clone()
    scores = torch.zeros(inner_steps)
    model.eval()

    with torch.no_grad():
        for i in range(inner_steps):
            scores[i] = model(img.unsqueeze(0))[0, class_ind].item()

            for _ in range(step_size):
                row, col = (int(v) for v in np.unravel_index(int(torch.argmax(relevances)), relevances.shape))
                noise = torch.as_tensor(np.random.uniform(0, 1, 3), dtype=img.dtype, device=run_device)
                img[:, row, col] = (noise - mean) / std
                # -inf (np.Inf was removed in NumPy 2.0) so this pixel is never picked again.
                relevances[row, col] = float('-inf')

    return scores

In [None]:
# Parameters:

N_pics = 1                        # Number of images used in pixelflipping, (in main experiment 300 was used)
N_inner = 128                     # Number of times we evaluate scores on the augmented image,
N_MC = 100                        # Number of times we sample the posterior to get an estimation for percentiles in B-LRP
steps = 66                        # Number of pixels flipped between the evaluations,
alphas = [5, 25, 50, 75, 95]      # Alphas used in B-LRP
SEED = 0                          # Seed for NumPy's global RNG: image sampling, MyDropout seeds and flip noise

np.random.seed(SEED)

score_lrp    = torch.zeros([N_pics, N_inner], device=device)
score_random = torch.zeros([N_pics, N_inner], device=device)   # allocated for a random-order baseline; not filled here
score_alphas = torch.zeros([len(alphas), N_pics, N_inner], device=device)

# Sample distinct images; sampling with replacement could evaluate the same image twice.
samples = np.random.choice(len(testset), N_pics, replace=False)

model.to(device)

for counter, q in enumerate(tqdm(samples)):

    img, label = testset[q]                      # load each image only once
    img = img.to(device)
    class_ind = target[label]

    # Standard LRP-epsilon heatmap (eval mode), summed over colour channels.
    LRP_MAP = LRP_epsilon(img, class_ind, model, dropout = False, verbose = False, device = device)[0].sum(axis = 0).data
    LRP_MAP[torch.isnan(LRP_MAP)] = 0.           # as in the castle example, NaN relevances count as zero

    # B-LRP: one heatmap per dropout sample, then per-pixel percentiles.
    LRPs = torch.zeros([N_MC, 224, 224])
    for i in range(N_MC):
        LRPs[i] = LRP_epsilon(img, class_ind, model, dropout = True, verbose = False, device = device)[0].sum(axis = 0).data
    LRPs[torch.isnan(LRPs)] = 0.

    percentiles = np.percentile(LRPs.reshape([N_MC, -1]).numpy(), alphas, axis = 0)
    LRP_ALPHAs = torch.tensor(percentiles.reshape([len(alphas), 224, 224]), dtype=torch.float32)

    score_lrp[counter] = pixelflipping(img, class_ind, LRP_MAP, model, N_inner, steps)

    for i in range(len(alphas)):
        score_alphas[i][counter] = pixelflipping(img, class_ind, LRP_ALPHAs[i], model, N_inner, steps)

In [None]:
plt.style.use('default')

matplotlib.rcParams.update({'font.size': 22})
fig, ax = plt.subplots(figsize=(8,6))

ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=100))

# Score i of every curve is recorded after i * steps pixels have been flipped
# (i = 0 .. N_inner - 1). The previous np.linspace(0, N_inner*steps, N_inner)
# stretched the axis by one step.
n_pixels = 224 * 224
x = 100 * np.arange(N_inner) * steps / n_pixels
ax.grid(True, which="both", ls="-")
ax.set_ylabel('Mean output score')
ax.set_xlabel('Percentage of pixels flipped')
ax.set_title('Imagenet')

# Curves are averaged over all N_pics evaluated images.
ax.plot(x, score_lrp.cpu().mean(dim=0).numpy(), label='Standard LRP', linewidth=2, markersize=10, marker='*')

for alpha, curves in zip(alphas, score_alphas):
    ax.plot(x, curves.cpu().mean(dim=0).numpy(),
            label=r'B-LRP $\alpha = $' + str(alpha), linewidth=1.5, markersize=4, marker='.')

ax.legend()
plt.savefig("IMAGENET_PIXELFLIPPING_FINAL.pdf", bbox_inches='tight')
plt.show()