# Compute Relevancy Of Transformer Networks

Self-attention models, and Transformers in particular, have taken the computer vision field by storm, from OpenAI’s DALL-E to Google’s ViT models. This creates a need for tools that can interpret and visualize the decision process behind Transformer models. Such visualizations can be used to debug models and to check that they are fair and unbiased. A new approach for computing token relevance for Transformer models was proposed in the paper “Transformer Interpretability Beyond Attention Visualization” by Hila Chefer, Shir Gur, and Lior Wolf. The method assigns local relevance based on the Deep Taylor Decomposition principle and then propagates these relevancy scores through the layers. This propagation involves attention layers and skip connections; both mix activation maps and have posed unique challenges to existing approaches.

To learn more about its architecture, please refer to [this article](https://analyticsindiamag.com/compute-relevancy-of-transformer-networks-via-novel-interpretable-transformer/).
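
The propagation step described above can be illustrated with a toy sketch. It assumes a rough reading of the paper's aggregation rule: within each block, attention gradients are multiplied element-wise with attention relevance, only the positive part is kept and averaged over heads, the identity is added to account for the skip connection, and the per-block matrices are then multiplied together. All function names, the toy numbers, and the multiplication order are assumptions of this sketch, not the repository's code.

```python
def identity(n):
    """n x n identity matrix as nested lists."""
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def matmul(a, b):
    """Plain matrix product of two nested-list matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def block_relevance(grads, relevances):
    """Identity plus the head-averaged positive part of grad * relevance."""
    n = len(grads[0])
    out = identity(n)
    for grad, rel in zip(grads, relevances):
        for i in range(n):
            for j in range(n):
                out[i][j] += max(grad[i][j] * rel[i][j], 0.0) / len(grads)
    return out


def aggregate(blocks):
    """blocks: one (grads, relevances) pair per Transformer block, first block first."""
    n = len(blocks[0][0][0])
    joint = identity(n)
    for grads, relevances in blocks:
        # The multiplication order is an assumption of this sketch
        joint = matmul(block_relevance(grads, relevances), joint)
    return joint


# Toy input: one block, one head, two tokens (made-up numbers)
toy_grads = [[[1.0, -1.0], [0.5, 2.0]]]
toy_relevances = [[[0.2, 0.3], [0.4, 0.1]]]
toy_map = aggregate([(toy_grads, toy_relevances)])
print(toy_map)
```

In a ViT-style model, the row of the aggregated matrix that belongs to the classification token would then be read off as the per-patch relevance; the negative product in the toy input is dropped by the positive-part step.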

# Get the GitHub repo and install the requirements

You'll have to restart the runtime after installing the requirements. Once it has restarted, navigate into the cloned `Transformer-Explainability` directory again; otherwise you'll encounter `ImportError`s.

Clone the Transformer-Explainability GitHub repository, navigate into the newly created Transformer-Explainability directory and install the requirements.

In [None]:
!python -m pip install pip --upgrade --user -q --no-warn-script-location
!python -m pip install numpy pandas seaborn matplotlib scipy statsmodels scikit-learn tensorflow keras opencv-python pillow scikit-image torch torchvision \
     tqdm --user -q --no-warn-script-location

!python -m pip install timm==0.3.2 --user -q

import IPython
IPython.Application.instance().kernel.do_shutdown(True)


In [10]:
!git clone https://github.com/hila-chefer/Transformer-Explainability.git

import os
# Work from inside the repository so that the `baselines` package can be imported
os.chdir('./Transformer-Explainability')

!pip install -r requirements.txt

fatal: destination path 'Transformer-Explainability' already exists and is not an empty directory.
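
Because the runtime is restarted and cells may be re-run, the directory change above can fail or nest unexpectedly. A minimal sketch of a re-runnable version follows; `enter_repo` is a hypothetical helper name, and it assumes the repository was cloned into the current working directory.

```python
import os


def enter_repo(repo_dir="Transformer-Explainability"):
    """Change into repo_dir unless the working directory is already inside it."""
    cwd = os.getcwd()
    if os.path.basename(cwd) == repo_dir:
        return cwd  # already there, e.g. the cell was run twice
    target = os.path.join(cwd, repo_dir)
    if not os.path.isdir(target):
        raise FileNotFoundError(f"{target} not found; clone the repository first")
    os.chdir(target)
    return os.getcwd()
```

Calling `enter_repo()` at the top of the notebook after each restart keeps later imports such as `baselines.ViT` resolvable without changing directory twice.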



# **Transformer Interpretability Beyond Attention Visualization**

Import necessary libraries and classes.

In [11]:
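from PIL import Image, ImageFile
# The flag lives on ImageFile; it lets PIL load truncated image files instead of raising an error
ImageFile.LOAD_TRUNCATED_IMAGES = True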
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import numpy as np
import cv2

from baselines.ViT.ViT_LRP import deit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_explanation_generator import LRP

# ImageNet class indices to names

Download the ImageNet class labels and create an index-to-class labels dictionary.

In [12]:
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
with open("imagenet_classes.txt", "r") as f:
    index_to_class = {i: s.strip() for i, s in enumerate(f.readlines())}
# index_to_class

--2021-06-17 05:45:00--  https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10472 (10K) [text/plain]
Saving to: ‘imagenet_classes.txt.1’


2021-06-17 05:45:00 (96.6 MB/s) - ‘imagenet_classes.txt.1’ saved [10472/10472]
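
The mapping relies on the label file listing one class name per line, so that the zero-based line number is the class index. The sketch below shows the same dictionary comprehension on an in-memory stand-in for the first three lines of the file; the name `sample_index_to_class` is chosen here so that it does not overwrite the notebook's `index_to_class`.

```python
import io

# Stand-in for the first three lines of imagenet_classes.txt
sample_file = io.StringIO("tench\ngoldfish\ngreat white shark\n")

# Line number (starting at 0) becomes the class index
sample_index_to_class = {i: line.strip() for i, line in enumerate(sample_file)}
print(sample_index_to_class)
```

Stripping each line matters because the trailing newline would otherwise end up in the printed class names.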



# **DeiT examples**

Load a pre-trained DeiT model and define the preprocessing applied to every input image.

In [13]:
# initialize ViT pretrained with DeiT
model = vit_LRP(pretrained=True).cpu()
model.eval()
attribution_generator = LRP(model)

normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])
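
To see what the normalization step does numerically: `ToTensor` scales pixel values to the range [0, 1], and `Normalize` then computes `(x - mean) / std` per channel. With a mean and standard deviation of 0.5 this maps [0, 1] to [-1, 1]. The helper names below are hypothetical and only reproduce that arithmetic for single values.

```python
def normalize_value(x, mean=0.5, std=0.5):
    """Per-channel arithmetic of Normalize: (x - mean) / std."""
    return (x - mean) / std


def denormalize_value(y, mean=0.5, std=0.5):
    """Inverse of normalize_value, useful when displaying a normalized image."""
    return y * std + mean


for x in (0.0, 0.5, 1.0):
    print(x, "->", normalize_value(x))
```

This is also why the visualization code later rescales the normalized image with a min-max step before displaying it.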

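Create the helper functions: `show_cam_on_image` overlays a relevance mask on an image as a heatmap, `generate_visualization` produces that overlay for a given class, and `print_top_classes` applies softmax to the model's output logits and prints the five most likely classes.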

In [14]:
# create heatmap from mask on image
def show_cam_on_image(img, mask):
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return cam

# Interpret the prediction process of the DeiT model.
# generate_LRP is the only novel, paper-specific method used in this function.
def generate_visualization(original_image, class_index=None):
    transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0).cpu(), method="transformer_attribution", index=class_index).detach()
    transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
    transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
    transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
    transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
    image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()
    image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())
    vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
    vis =  np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis

def print_top_classes(predictions, **kwargs):    
    # Print Top-5 predictions
    prob = torch.softmax(predictions, dim=1)
    class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()
    max_str_len = 0
    class_names = []
    for cls_idx in class_indices:
        class_names.append(index_to_class[cls_idx])
        if len(index_to_class[cls_idx]) > max_str_len:
            max_str_len = len(index_to_class[cls_idx])
    
    print('Top 5 classes:')
    for cls_idx in class_indices:
        output_string = '\t{} : {}'.format(cls_idx, index_to_class[cls_idx])
        output_string += ' ' * (max_str_len - len(index_to_class[cls_idx])) + '\t\t'
        output_string += 'value = {:.3f}\t prob = {:.1f}%'.format(predictions[0, cls_idx], 100 * prob[0, cls_idx])
        print(output_string)
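
Two pieces of arithmetic in `generate_visualization` are worth spelling out: a 224 x 224 input cut into 16 x 16 patches gives a 14 x 14 grid, which is why the attribution is reshaped to 14 x 14 and upsampled by a factor of 16, and both the attribution and the image are rescaled to [0, 1] with a min-max step. The sketch below uses hypothetical helper names and plain lists; unlike the notebook code, it returns zeros for a constant input instead of dividing by zero.

```python
def patch_grid(image_size=224, patch_size=16):
    """Return (patches per side, total number of patches)."""
    side = image_size // patch_size
    return side, side * side


def min_max_scale(values):
    """Rescale a list of numbers linearly to the range [0, 1]."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0 for _ in values]  # constant input; avoids division by zero
    return [(v - lo) / (hi - lo) for v in values]


print(patch_grid())
print(min_max_scale([2.0, 4.0, 6.0]))
```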

Visualizing the relevance of image patches for particular predictions in a class-specific manner.

Perform inference on the image to get the class index of the objects in the image.

In [15]:
image = Image.open('samples/dogcat2.png')
dog_cat_image = transform(image)

output = model(dog_cat_image.unsqueeze(0).cpu())
print_top_classes(output)


Top 5 classes:
	207 : golden retriever  		value = 6.523	 prob = 35.7%
	208 : Labrador retriever		value = 4.288	 prob = 3.8%
	285 : Egyptian cat      		value = 3.641	 prob = 2.0%
	222 : kuvasz            		value = 3.422	 prob = 1.6%
	281 : tabby             		value = 2.778	 prob = 0.8%
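
The `value` column holds the raw logits and the `prob` column holds their softmax over all 1000 classes. A minimal pure-Python softmax (a sketch, not the `torch.softmax` call used in the notebook) makes the relationship concrete: the ratio of two probabilities depends only on the difference between their logits.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)  # subtracting the maximum avoids overflow in exp
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Logits of the two top classes from the output above
golden, labrador = 6.523, 4.288
print(math.exp(golden - labrador))
```

The printed ratio is about 9.3, which is in line with the rounded probabilities reported above (35.7% against 3.8%).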


Visualize the relevance of image patches for different classes.

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(21,7))
axs[0].imshow(image);
axs[0].axis('off');

# dog - generate visualization for the predicted class (207: 'golden retriever' in the output above)
# by default the predicted class is visualized
dog = generate_visualization(dog_cat_image)
axs[1].imshow(dog);
axs[1].axis('off');

# cat - generate visualization for class 285: 'Egyptian cat'
#cat = generate_visualization(dog_cat_image, class_index=285)
#axs[2].imshow(cat);
#axs[2].axis('off');
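
The `class_index` argument expects an ImageNet index, which is easy to get wrong when typed by hand. One option is to look the index up by name. The sketch below assumes a dictionary shaped like the notebook's `index_to_class`; `known_classes` is a stand-in limited to the five entries printed above, and `find_class_indices` is a hypothetical helper.

```python
# Stand-in for index_to_class, limited to the classes printed above
known_classes = {
    207: "golden retriever",
    208: "Labrador retriever",
    285: "Egyptian cat",
    222: "kuvasz",
    281: "tabby",
}


def find_class_indices(name, index_to_class):
    """Return all indices whose label matches name, ignoring case."""
    wanted = name.strip().lower()
    return sorted(i for i, label in index_to_class.items()
                  if label.lower() == wanted)


print(find_class_indices("egyptian cat", known_classes))
```

Returning a list rather than a single index is deliberate: a few ImageNet labels occur more than once, and an empty list signals a misspelled name.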

In [None]:
image = Image.open('samples/adver.png').convert('RGB')
dog_cat_image = transform(image)

output = model(dog_cat_image.unsqueeze(0).cpu())
print_top_classes(output)

In [None]:
fig, axs = plt.subplots(1, 3,figsize=(21,7))
axs[0].imshow(image);
axs[0].axis('off');
# golden retriever - the predicted class
dog = generate_visualization(dog_cat_image)
axs[1].imshow(dog);
axs[1].axis('off');

# generate visualization for class 285: 'Egyptian cat'
#cat = generate_visualization(dog_cat_image, class_index=285)
#axs[2].imshow(cat);
#axs[2].axis('off');

In [None]:
image = Image.open('samples/normal.png').convert('RGB')
tusker_zebra_image = transform(image)

output = model(tusker_zebra_image.unsqueeze(0).cpu())
print_top_classes(output)

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(21,7))
axs[0].imshow(image);
axs[0].axis('off');

# zebra - the predicted class
zebra = generate_visualization(tusker_zebra_image)
axs[1].imshow(zebra);
axs[1].axis('off');

# tusker - generate visualization for class 101: 'tusker'
tusker = generate_visualization(tusker_zebra_image, class_index=101)
axs[2].imshow(tusker);
axs[2].axis('off');