In [1]:
import torch
from PIL import Image
import numpy
import sys
from torchvision import transforms
import numpy as np
import cv2
import matplotlib.pyplot as plt

from vit_rollout import VITAttentionRollout
from vit_grad_rollout import VITAttentionGradRollout

from tqdm import tqdm
import scipy.stats

  warn(


# Notebook for evaluating impact on precision by applying blank-out

In [2]:
import requests
import torch.nn.functional as F

LABELS_URL = 'https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json'
classes = {int(key): value for (key, value) in requests.get(LABELS_URL).json().items()}

In [3]:
img = Image.open("examples/input.png")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {DEVICE}")

IMAGE_SIZE = 224
DISCARD_RATIO = 0.9

def preprocess_image(image_path, transform):
    img = Image.open(image_path)
    input_tensor = transform(img).unsqueeze(0)
    return input_tensor.to(DEVICE)

def show_mask_on_image(img, mask):
    img = np.float32(img) / 255
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

model = torch.hub.load('facebookresearch/deit:main', 
        'deit_tiny_patch16_224', pretrained=True)
model.eval()
model.to(DEVICE)
print()

Using cuda


Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
  from .autonotebook import tqdm as notebook_tqdm





# Example and functions

In [4]:
input_tensor  = preprocess_image("examples/input.png", transform)
scores = model(input_tensor)

  return F.conv2d(input, weight, bias, self.stride,


In [10]:
def get_prediction(scores):
    '''Gets the index of max prob and the prob
    '''
    h_x = F.softmax(scores, dim=1).data.squeeze()
    probs, idx = h_x.sort(0, True)
    # output the prediction
    return idx[0].item(), probs[0].item()

idx, prob = get_prediction(scores)

In [11]:
import copy
import torch
import torch.nn.functional as F

def compute_blankout_importance_conv(model, input_tensor, initial_class, patch_size=16, discard_ratio=0.9):
    # Get the initial prediction
    initial_scores = model(input_tensor)
    initial_prob = F.softmax(initial_scores, dim=1)[0, initial_class].item()

    # Create a deep copy of the input tensor
    perturbed_tensor = copy.deepcopy(input_tensor)

    # Initialize the importance vector
    importance_vector = []

    # Iterate through each patch in the input tensor
    for i in range(0, input_tensor.size(2), patch_size):
        for j in range(0, input_tensor.size(3), patch_size):
            # Blank-out each patch in the input tensor
            perturbed_tensor = copy.deepcopy(input_tensor)
            perturbed_tensor[0, :, i:i+patch_size, j:j+patch_size] = 0  # Set all values to zero for each patch

            # Get the prediction after blanking-out the patch
            perturbed_scores = model(perturbed_tensor)
            perturbed_prob = F.softmax(perturbed_scores, dim=1)[0, initial_class].item()

            # Compute the impact on the probability
            impact = initial_prob - perturbed_prob
            importance_vector.append(impact)

    return importance_vector

# # Usage example:
input_tensor = preprocess_image("examples/input.png", transform)
initial_class = get_prediction(scores)[0]
importance_vector_direct = compute_blankout_importance_conv(model, input_tensor, initial_class, patch_size=16, discard_ratio=DISCARD_RATIO)

# Blank out pipeline

In [12]:
file_blank_out = 'blank_out.txt'

with open(file_blank_out, 'w') as f:
    pass  # Opening in 'w' mode clears the file

In [13]:
path_prefix = 'images/ILSVRC2012_val_00000'
path_suffix = '.JPEG'
discard_ratio = 0.9
image_size = 224

def convert_number(number):
    if number < 10:
        return '00'+str(number)
    if number < 100:
        return '0'+str(number)
    else:
        return str(number)

# img = Image.open(path_prefix + image_number_converted + path_suffix)

transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [14]:
for image_number in tqdm(range(1,251)):
    # print(f"Current image being treated: {image_number}")
    if image_number in [34, 107, 118, 126, 141, 223]:
        with open(file_blank_out, 'a') as f:
            f.write('\n') 
    else:
        image_number_converted = convert_number(image_number)
        image_path = path_prefix + image_number_converted + path_suffix
        input_tensor  = preprocess_image(image_path, transform)

        # Getting idx
        model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
        model.eval()
        model.to(DEVICE)
        scores = model(input_tensor)
        category_index, prob = get_prediction(scores)

        # initial_class = get_prediction(scores)[0]
        # importance_vector_direct = compute_blankout_importance_conv(model, input_tensor, category_index, patch_size=16, discard_ratio=DISCARD_RATIO)
        # importance_vector_direct = scipy.stats.rankdata(importance_vector_direct, method='average')
        
        with open(file_blank_out, 'a') as f:
            # np.savetxt(f, [importance_vector_direct], fmt='%.3f', delimiter=',') 
            np.savetxt(f, [prob], fmt='%.3f', delimiter=',') 
        

  0%|          | 0/250 [00:00<?, ?it/s]Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
  1%|          | 2/250 [00:00<00:18, 13.56it/s]Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
  2%|▏         | 4/250 [00:00<00:17, 13.89it/s]Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
  2%|▏         | 6/250 [00:00<00:17, 13.97it/s]Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
Using cache found in /users/eleves-a/2018/nicolas.lopes/.cache/torch/hub/facebookresearch_deit_main
  3%|▎         | 8/2