# Notebook for performance evaluation of techniques

**We have four methods:**

- Raw Attention (seems to be only mean of weights from input embeddings on first layer).
- Attention flow (best in final layer).
- Attention Rollout (best in final layer).
- Gradient Attention Rollout (best in final layer (most likely)).

**We have two possible methods for metric evaluation:**

- Blank-out -> generate black spots and look at drop in accuracy.
- Gradient input -> **(?)**.

**Pipeline for evaluating models:**

- Select $N$ images for study.
- For each image (also can be done in multiple layers, but time constraint): 
    - Compute its raw attention, attention flow, rollout, and gradient rollout.
    - Compute the blank-out and gradient input.
    - Each will yield a vector of size $n_{embedding}$.
    - Compute Spearman rank correlation between them and store it

We should end up with a table of size $n_{images} \times n_{methods}$. It suffices to compute mean and standard deviation for each method.

In [20]:
import torch
from PIL import Image
import numpy
import sys
from torchvision import transforms
import numpy as np
import cv2
import matplotlib.pyplot as plt

from vit_rollout import VITAttentionRollout
from vit_grad_rollout import VITAttentionGradRollout
from vit_flow import VITAttentionFlow

import copy
import requests
import torch.nn.functional as F

from tqdm import tqdm

In [21]:
DEVICE = 'cpu'

In [22]:
model = torch.hub.load('facebookresearch/deit:main', 
        'deit_tiny_patch16_224', pretrained=True)
model.eval()
model.to(DEVICE)
print()




Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


In [23]:
def preprocess_image(image_path, transform):
    img = Image.open(image_path)
    input_tensor = transform(img).unsqueeze(0)
    return input_tensor.to(DEVICE)

def get_prediction(scores):
    '''Gets the index of max prob and the prob
    '''
    h_x = F.softmax(scores, dim=1).data.squeeze()
    probs, idx = h_x.sort(0, True)
    # output the prediction
    return idx[0].item(), probs[0].item()

# idx, prob = get_prediction(scores)

# Generating Attention Rollout masks

In [24]:
file_attention_rollout = 'attention_rollout.txt'
file_attention_grad_rollout = 'attention_grad_rollout.txt'
file_attention_flow = 'attention_flow.txt'

with open(file_attention_rollout, 'w') as f:
    pass  # Opening in 'w' mode clears the file

with open(file_attention_grad_rollout, 'w') as f:
    pass  # Opening in 'w' mode clears the file

with open(file_attention_flow, 'w') as f:
    pass  # Opening in 'w' mode clears the file

In [25]:
path_prefix = 'images/ILSVRC2012_val_00000'
path_suffix = '.JPEG'
discard_ratio = 0.9
image_size = 224

def convert_number(number):
    if number < 10:
        return '00'+str(number)
    if number < 100:
        return '0'+str(number)
    else:
        return str(number)

# img = Image.open(path_prefix + image_number_converted + path_suffix)

transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [27]:
for image_number in range(35,61):
    print(f"Current image being treated: {image_number}")

    image_number_converted = convert_number(image_number)
    image_path = path_prefix + image_number_converted + path_suffix
    input_tensor  = preprocess_image(image_path, transform)

    # Getting idx
    model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
    model.eval()
    model.to(DEVICE)
    scores = model(input_tensor)
    category_index, _ = get_prediction(scores)

    # Attention Rollout
    model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
    model.eval()
    attention_rollout = VITAttentionRollout(model, discard_ratio=discard_ratio)
    mask_attention_rollout = attention_rollout.get_attention_mask(input_tensor).numpy()
    mask_attention_rollout = mask_attention_rollout / np.max(mask_attention_rollout)
    
    with open(file_attention_rollout, 'a') as f:
        np.savetxt(f, [mask_attention_rollout], fmt='%.3f', delimiter=',')  # Saving as float

    # Gradient Attention Rollout
    model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
    model.eval()
    attention_grad_rollout = VITAttentionGradRollout(model, discard_ratio=discard_ratio)
    mask_attention_grad_rollout = attention_grad_rollout.get_attention_mask(input_tensor, category_index=category_index).numpy()
    mask_attention_grad_rollout = mask_attention_grad_rollout / np.max(mask_attention_grad_rollout)
    
    with open(file_attention_grad_rollout, 'a') as f:
        np.savetxt(f, [mask_attention_grad_rollout], fmt='%.3f', delimiter=',')  # Adjust format as needed

    # Attention Flow
    model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
    model.eval()
    attention_flow = VITAttentionFlow(model, discard_ratio=discard_ratio)
    mask_attention_flow = attention_flow.get_attention_mask(input_tensor).numpy()
    mask_attention_flow = mask_attention_flow / np.max(mask_attention_flow)
    
    with open(file_attention_flow, 'a') as f:
        np.savetxt(f, [mask_attention_flow], fmt='%.3f', delimiter=',')  # Adjust format as needed
    

Current image being treated: 35


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [10:26<00:00,  3.18s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 36


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [11:34<00:00,  3.52s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 37


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [10:10<00:00,  3.10s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 38


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [12:48<00:00,  3.90s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 39


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [17:56<00:00,  5.46s/it]  
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 40


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [10:33<00:00,  3.21s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 41


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [10:36<00:00,  3.23s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 42


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [11:03<00:00,  3.37s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 43


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [13:57<00:00,  4.25s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 44


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [27:17<00:00,  8.31s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 45


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [15:31<00:00,  4.73s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 46


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [29:32<00:00,  9.00s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 47


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [06:01<00:00,  1.84s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 48


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [06:11<00:00,  1.88s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 49


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [06:03<00:00,  1.85s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 50


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [16:02<00:00,  4.88s/it]   
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 51


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [10:51<00:00,  3.31s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 52


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [11:05<00:00,  3.38s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 53


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [18:48<00:00,  5.73s/it] 
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 54


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [10:27<00:00,  3.18s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 55


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [10:29<00:00,  3.20s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 56


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [08:23<00:00,  2.56s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 57


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [08:08<00:00,  2.48s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 58


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [06:01<00:00,  1.84s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 59


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [06:01<00:00,  1.84s/it]
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main


Current image being treated: 60


Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
Using cache found in C:\Users\mouha/.cache\torch\hub\facebookresearch_deit_main
100%|██████████| 197/197 [06:01<00:00,  1.84s/it]


In [None]:
# first train_1_to-33
#second 35 to end

In [None]:
# import requests
# import torch.nn.functional as F

# LABELS_URL = 'https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json'
# classes = {int(key): value for (key, value) in requests.get(LABELS_URL).json().items()}

In [None]:
# input_tensor  = preprocess_image("examples/input.png", transform)
# scores = model(input_tensor)

# def print_preds(scores):
#     # print the predictions with their 'probabilities' from the scores
#     h_x = F.softmax(scores, dim=1).data.squeeze()
#     probs, idx = h_x.sort(0, True)
#     probs = probs.numpy()
#     idx = idx.numpy()
#     # output the prediction
#     for i in range(0, 5):
#         print('{:.3f} -> {}'.format(probs[i], classes[idx[i]]))
#     return idx

# def get_prediction(scores):
#     '''Gets the index of max prob and the prob
#     '''
#     h_x = F.softmax(scores, dim=1).data.squeeze()
#     probs, idx = h_x.sort(0, True)
#     # output the prediction
#     return idx[0].item(), probs[0].item()

# idx, prob = get_prediction(scores)