In [78]:
import torch
from PIL import Image
import numpy
import sys
from torchvision import transforms
import numpy as np
import cv2
import matplotlib.pyplot as plt

from vit_rollout import VITAttentionRollout
from vit_grad_rollout import VITAttentionGradRollout

from tqdm import tqdm
import scipy.stats

# Notebook for evaluating impact on precision by applying blank-out

In [79]:
import requests
import torch.nn.functional as F

LABELS_URL = 'https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json'
classes = {int(key): value for (key, value) in requests.get(LABELS_URL).json().items()}

In [80]:
img = Image.open("examples/input.png")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {DEVICE}")

IMAGE_SIZE = 224
DISCARD_RATIO = 0.9

def preprocess_image(image_path, transform):
    img = Image.open(image_path)
    input_tensor = transform(img).unsqueeze(0)
    return input_tensor.to(DEVICE)

def show_mask_on_image(img, mask):
    img = np.float32(img) / 255
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

model = torch.hub.load('facebookresearch/deit:main', 
        'deit_tiny_patch16_224', pretrained=True)
model.eval()
model.to(DEVICE)
print()

Using cuda



Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main


# Example and functions

In [72]:
input_tensor  = preprocess_image("examples/input.png", transform)
scores = model(input_tensor)


In [73]:
def get_prediction(scores):
    '''Gets the index of max prob and the prob
    '''
    h_x = F.softmax(scores, dim=1).data.squeeze()
    probs, idx = h_x.sort(0, True)
    # output the prediction
    return idx[0].item(), probs[0].item()

idx, prob = get_prediction(scores)

In [74]:
import copy
import torch
import torch.nn.functional as F

def compute_blankout_importance_conv(model, input_tensor, initial_class, patch_size=16, discard_ratio=0.9):
    # Get the initial prediction
    initial_scores = model(input_tensor)
    initial_prob = F.softmax(initial_scores, dim=1)[0, initial_class].item()

    # Create a deep copy of the input tensor
    perturbed_tensor = copy.deepcopy(input_tensor)

    # Initialize the importance vector
    importance_vector = []

    # Iterate through each patch in the input tensor
    for i in range(0, input_tensor.size(2), patch_size):
        for j in range(0, input_tensor.size(3), patch_size):
            # Blank-out each patch in the input tensor
            perturbed_tensor = copy.deepcopy(input_tensor)
            perturbed_tensor[0, :, i:i+patch_size, j:j+patch_size] = 0  # Set all values to zero for each patch

            # Get the prediction after blanking-out the patch
            perturbed_scores = model(perturbed_tensor)
            perturbed_prob = F.softmax(perturbed_scores, dim=1)[0, initial_class].item()

            # Compute the impact on the probability
            impact = initial_prob - perturbed_prob
            importance_vector.append(impact)

    return importance_vector

# # Usage example:
input_tensor = preprocess_image("examples/input.png", transform)
initial_class = get_prediction(scores)[0]
importance_vector_direct = compute_blankout_importance_conv(model, input_tensor, initial_class, patch_size=16, discard_ratio=DISCARD_RATIO)

# Blank out pipeline

In [75]:
file_blank_out = 'blank_out.txt'

with open(file_blank_out, 'w') as f:
    pass  # Opening in 'w' mode clears the file

In [76]:
path_prefix = 'images/ILSVRC2012_val_00000'
path_suffix = '.JPEG'
discard_ratio = 0.9
image_size = 224

def convert_number(number):
    if number < 10:
        return '00'+str(number)
    if number < 100:
        return '0'+str(number)
    else:
        return str(number)

# img = Image.open(path_prefix + image_number_converted + path_suffix)

transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [77]:
for image_number in tqdm(range(1,121)):
    # print(f"Current image being treated: {image_number}")
    if image_number in [34, 107, 118]:
        with open(file_blank_out, 'a') as f:
            f.write('\n') 
    else:
        image_number_converted = convert_number(image_number)
        image_path = path_prefix + image_number_converted + path_suffix
        input_tensor  = preprocess_image(image_path, transform)

        # Getting idx
        model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
        model.eval()
        model.to(DEVICE)
        scores = model(input_tensor)
        category_index, _ = get_prediction(scores)

        initial_class = get_prediction(scores)[0]
        importance_vector_direct = compute_blankout_importance_conv(model, input_tensor, category_index, patch_size=16, discard_ratio=DISCARD_RATIO)
        importance_vector_direct = scipy.stats.rankdata(importance_vector_direct, method='average')
        
        with open(file_blank_out, 'a') as f:
            np.savetxt(f, [importance_vector_direct], fmt='%.3f', delimiter=',') 
        

  0%|          | 0/120 [00:00<?, ?it/s]Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
  1%|          | 1/120 [00:01<02:08,  1.08s/it]Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
  2%|▏         | 2/120 [00:02<02:06,  1.08s/it]Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
  2%|▎         | 3/120 [00:03<02:07,  1.09s/it]Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
  3%|▎         | 4/120 [00:04<02:05,  1.08s/it]Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
  4%|▍         | 5/120 [00:05<02:04,  1.08s/it]Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
  5%|▌         | 6/120 [00:06<02:03,  1.08s/it]Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebo

In [60]:
# Your original list of values
values = [-0.5, 3, 9, 8, 2, 7, 10, 1, 4, 6]  # For example

# Get the Spearman ranks of the values
ranks = scipy.stats.rankdata(values, method='average')  # 'average' deals with ties by assigning the average rank


In [46]:
import pandas as pd

In [50]:
pd.read_csv('images_done/attention_flow.txt', header = None)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,186,187,188,189,190,191,192,193,194,195
0,0.091,0.086,0.087,0.088,1.000,0.086,0.089,0.086,0.086,0.088,...,0.087,0.090,0.452,0.091,0.090,0.083,0.090,0.091,0.089,0.088
1,0.737,0.076,0.079,0.074,0.093,0.088,0.085,0.082,0.084,0.083,...,0.087,0.077,0.084,0.086,0.083,0.076,0.318,0.076,0.385,0.747
2,1.000,0.031,0.030,0.030,0.030,0.030,0.031,0.030,0.030,0.030,...,0.030,0.030,0.030,0.030,0.030,0.030,0.030,0.031,0.030,0.030
3,0.066,0.066,0.067,0.066,0.065,0.064,0.065,0.066,0.065,0.074,...,0.135,0.197,0.066,0.067,0.062,0.121,0.063,0.062,0.065,0.070
4,0.038,0.037,0.037,0.037,0.036,0.037,0.037,0.037,0.036,0.036,...,0.046,0.037,0.038,0.037,0.038,0.038,0.038,0.041,0.038,0.038
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
112,0.073,0.787,0.072,0.466,0.073,0.072,0.076,0.160,0.074,0.072,...,0.071,0.069,0.072,0.072,0.072,0.072,0.070,0.071,0.069,0.071
113,0.032,0.032,0.029,0.029,0.029,0.029,0.092,0.255,0.029,0.028,...,0.029,0.029,0.030,0.030,0.030,0.030,0.030,0.029,0.029,0.031
114,0.067,0.068,0.070,1.000,0.754,0.799,0.068,0.067,0.444,0.190,...,0.068,0.070,0.068,0.064,0.064,0.066,0.065,0.064,0.063,0.065
115,0.047,0.044,0.046,0.044,0.045,0.044,0.045,0.045,0.045,0.045,...,0.600,0.047,0.047,0.045,0.046,0.045,0.045,0.044,0.045,0.047


In [52]:
pd.read_csv('blank_out.txt', header = None)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,246,247,248,249,250,251,252,253,254,255
0,-0.037,-0.023,-0.035,0.002,-0.020,0.010,-0.014,-0.016,-0.028,-0.016,...,-0.034,-0.024,-0.025,-0.021,-0.039,-0.027,-0.021,-0.017,-0.008,-0.031
1,-0.071,-0.003,-0.040,0.017,0.039,0.068,0.025,-0.027,-0.033,0.027,...,0.002,-0.018,-0.017,-0.002,-0.079,-0.023,-0.029,-0.024,-0.010,-0.023
2,-0.071,-0.013,-0.066,-0.016,-0.003,-0.022,-0.013,-0.068,-0.015,-0.026,...,-0.047,-0.003,-0.032,-0.005,-0.022,-0.003,0.010,0.015,0.002,0.015
3,0.003,-0.020,-0.010,-0.004,-0.010,-0.012,0.009,-0.004,-0.003,0.003,...,-0.099,-0.034,-0.041,-0.121,-0.099,-0.063,-0.065,-0.030,0.001,0.005
4,-0.059,-0.027,-0.036,-0.010,0.008,-0.001,-0.011,-0.048,-0.014,-0.016,...,0.019,0.006,-0.023,0.009,-0.005,0.014,0.004,-0.021,-0.002,-0.010
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
112,-0.111,-0.014,-0.027,-0.059,-0.031,-0.023,-0.053,-0.034,-0.042,-0.081,...,0.004,0.008,0.010,0.007,0.001,-0.003,-0.004,-0.001,0.005,0.003
113,-0.077,-0.001,0.021,0.001,-0.039,-0.052,-0.046,-0.114,-0.053,-0.040,...,-0.075,-0.134,-0.140,-0.139,-0.133,-0.107,-0.031,-0.039,0.007,-0.136
114,0.088,0.019,0.102,0.014,0.064,0.085,0.056,0.006,0.051,-0.001,...,0.012,0.003,0.045,-0.014,0.106,0.044,0.014,0.002,-0.002,0.031
115,-0.016,-0.005,0.015,-0.000,-0.001,0.013,0.016,0.010,0.002,0.010,...,-0.007,-0.006,0.009,0.009,-0.007,-0.019,-0.036,0.004,-0.009,-0.001


# IMPORTANT


MUST CHECK IF THE CORRESPONDANCE OF EMBEDDING IS THE SAME AS IN THE BLANK OUT !!!