In [None]:
import os, sys

sys.path.append(os.path.realpath('./pytorch-vqa'))
sys.path.append(os.path.realpath('./pytorch-resnet'))

import threading
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

import resnet

import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.colors import LinearSegmentedColormap

from model import Net, apply_attention, tile_2d_over_nd
from utils import get_transform

from captum.attr import IntegratedGradients, LayerConductance, Saliency, NoiseTunnel
from captum.attr import InterpretableEmbeddingBase, TokenReferenceBase
from captum.attr import visualization, configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class ResNetLayer4(torch.nn.Module):
    """Pretrained ResNet-152 whose forward pass returns the ``layer4`` activation.

    ``forward`` runs the whole ResNet, but a forward hook registered on
    ``r_model.layer4`` captures that stage's output, and the captured tensor
    (not the network's final output) is what ``forward`` returns.
    """
    def __init__(self):
        super().__init__()
        self.r_model = resnet.resnet152(pretrained=True)
        self.r_model.eval()
        self.r_model.to(device)  # `device` is the notebook-global torch.device

        # Hooked layer4 outputs, keyed by the device they were computed on, so a
        # forward pass on a given device reads back its own activation.
        # NOTE(review): presumably keyed per device (and guarded by a lock)
        # because torch.nn.DataParallel runs replicas concurrently — confirm.
        self.buffer = {}
        lock = threading.Lock()

        def save_output(module, input, output):
            # Forward hook on layer4: stash its output for `forward` to return.
            with lock:
                self.buffer[output.device] = output

        self.r_model.layer4.register_forward_hook(save_output)

    def forward(self, x):
        """Return the layer4 activation of the ResNet for input batch ``x``."""
        # The full-network result is discarded; the hook fills self.buffer.
        self.r_model(x)
        return self.buffer[x.device]

In [None]:
class VQA_Resnet_Model(Net):
    """VQA ``Net`` that consumes images directly instead of precomputed features.

    Image features come from a ``ResNetLayer4`` sub-module, so the whole path
    from input image to answer logits is one differentiable model.
    """

    def __init__(self, embedding_tokens):
        super().__init__(embedding_tokens)
        self.resnet_layer4 = ResNetLayer4()

    def forward(self, v, q, q_len):
        # Question encoding first, then image features (same order as before).
        q_lengths = list(q_len.data)
        q_encoded = self.text(q, q_lengths)

        img_feats = self.resnet_layer4(v)
        # L2-normalise along dim 1 (presumably channels); epsilon avoids 0/0.
        # Broadcasting replaces the former explicit expand_as.
        img_feats = img_feats / (img_feats.norm(p=2, dim=1, keepdim=True) + 1e-8)

        attention_map = self.attention(img_feats, q_encoded)
        attended = apply_attention(img_feats, attention_map)

        return self.classifier(torch.cat([attended, q_encoded], dim=1))

In [None]:
def image_to_features(img):
    """Preprocess a PIL image into a one-image batch on ``device``.

    Applies the notebook-global ``transform`` and prepends a batch dimension.
    Despite the name, no network is run here; the result is the model input.
    """
    return transform(img).unsqueeze(0).to(device)

In [None]:
def vqa_resnet_interpret_ig(image_filename, questions, targets):
    """Explain VQA answers for one image with Integrated Gradients.

    For each (question, target) pair: run the model to get its predicted
    answer, compute Integrated Gradients attributions for that predicted answer
    with respect to both the image and the question embedding, render a
    word-level text heat map and an image heat map, and print the summed
    text / image / total contributions.

    Args:
        image_filename: path of the image file; it is converted to RGB.
        questions: question strings.
        targets: expected answer strings, paired with ``questions`` by
            position. They are only displayed as the true label; attributions
            are always computed for the model's own predicted answer.

    Uses notebook globals defined in other cells: ``vqa_resnet``, ``ig``,
    ``interpretable_embedding``, ``token_reference``, ``encode_question``,
    ``image_to_features``, ``answer_words``, ``image_size``,
    ``central_fraction``, ``default_cmap`` and ``device``.
    """
    img = Image.open(image_filename).convert('RGB')
    # Un-normalized, cropped copy of the image used only as the heat-map
    # backdrop. `transforms.Scale` is deprecated/removed in torchvision;
    # `transforms.Resize` is its drop-in replacement.
    original_image = transforms.Compose([transforms.Resize(int(image_size / central_fraction)),
                                         transforms.CenterCrop(image_size),
                                         transforms.ToTensor()])(img)
    # Identical for every question, so convert once (CHW -> HWC for matplotlib).
    original_im_mat = np.transpose(original_image.numpy(), (1, 2, 0))

    image_features = image_to_features(img).requires_grad_()
    image_baseline = torch.zeros_like(image_features)

    for question, target in zip(questions, targets):
        q, q_len = encode_question(question)
        q_len_batch = q_len.unsqueeze(0)
        q_input_embedding = interpretable_embedding.indices_to_embeddings(q).unsqueeze(0)

        # Prediction only: no autograd graph through ResNet-152 is needed here.
        with torch.no_grad():
            ans = vqa_resnet(image_features, q_input_embedding, q_len_batch)
        pred, answer_idx = F.softmax(ans, dim=1).cpu().max(dim=1)
        # Plain Python scalars: `answer_idx` is a 1-element CPU tensor while the
        # model runs on `device`; an int target sidesteps any device mismatch
        # and indexes `answer_words` directly.
        pred_prob = pred.item()
        answer_idx = answer_idx.item()
        answer = answer_words[answer_idx]

        # Reference question of the same length, built from the reference token.
        q_reference_indices = token_reference.generate_reference(q_len.item(),
                                                                 device=device).unsqueeze(0)
        q_reference = interpretable_embedding.indices_to_embeddings(q_reference_indices).to(device)
        image_attr, text_attr = ig.attribute(inputs=(image_features, q_input_embedding),
                                             baselines=(image_baseline, q_reference),
                                             target=answer_idx,
                                             additional_forward_args=q_len_batch,
                                             n_steps=30)

        # Visualize text attributions: one score per word, L2-normalized
        # (skipped when every score is zero, which would divide by zero).
        word_attr = text_attr.sum(dim=2).squeeze(0)
        word_attr_norm = word_attr.norm()
        if word_attr_norm > 0:
            word_attr = word_attr / word_attr_norm
        text_total = text_attr.sum()
        vis_data_records = [visualization.VisualizationDataRecord(
                                word_attr,
                                pred_prob,
                                answer,   # predicted label
                                target,   # true (expected) label
                                answer,   # label the attributions explain
                                text_total,
                                question.split(),
                                0.0)]
        visualization.visualize_text(vis_data_records)

        # Visualize image attributions.
        attr = np.transpose(image_attr.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
        visualization.visualize_image_attr_multiple(attr, original_im_mat,
                                                    ["original_image", "heat_map"], ["all", "absolute_value"],
                                                    titles=["Original Image", "Attribution Magnitude"],
                                                    cmap=default_cmap,
                                                    show_colorbar=True)
        text_contrib = text_total.item()
        image_contrib = image_attr.sum().item()
        print('Text Contributions: ', text_contrib)
        print('Image Contributions: ', image_contrib)
        print('Total Contribution: ', image_contrib + text_contrib)
        


In [None]:
def vqa_resnet_interpret_saliency(image_filename, questions, targets):
    """Explain VQA answers for one image with Saliency (input gradients).

    For each (question, target) pair: run the model to get its predicted
    answer, compute Saliency attributions for that predicted answer with
    respect to both the image and the question embedding, render a word-level
    text heat map and an image heat map, and print the summed contributions.

    Args:
        image_filename: path of the image file; it is converted to RGB.
        questions: question strings.
        targets: expected answer strings, paired with ``questions`` by
            position. They are only displayed as the true label; attributions
            are always computed for the model's own predicted answer.

    Uses notebook globals defined in other cells: ``vqa_resnet``, ``saliency``,
    ``interpretable_embedding``, ``encode_question``, ``image_to_features``,
    ``answer_words``, ``image_size``, ``central_fraction`` and ``default_cmap``.
    """
    img = Image.open(image_filename).convert('RGB')
    # Un-normalized, cropped copy of the image used only as the heat-map
    # backdrop. `transforms.Scale` is deprecated/removed in torchvision;
    # `transforms.Resize` is its drop-in replacement.
    original_image = transforms.Compose([transforms.Resize(int(image_size / central_fraction)),
                                         transforms.CenterCrop(image_size),
                                         transforms.ToTensor()])(img)
    # Identical for every question, so convert once (CHW -> HWC for matplotlib).
    original_im_mat = np.transpose(original_image.numpy(), (1, 2, 0))

    image_features = image_to_features(img).requires_grad_()

    for question, target in zip(questions, targets):
        q, q_len = encode_question(question)
        q_len_batch = q_len.unsqueeze(0)
        q_input_embedding = interpretable_embedding.indices_to_embeddings(q).unsqueeze(0)

        # Prediction only: no autograd graph through ResNet-152 is needed here.
        with torch.no_grad():
            ans = vqa_resnet(image_features, q_input_embedding, q_len_batch)
        pred, answer_idx = F.softmax(ans, dim=1).cpu().max(dim=1)
        # Plain Python scalars: an int target sidesteps any CPU/`device`
        # mismatch and indexes `answer_words` directly.
        pred_prob = pred.item()
        answer_idx = answer_idx.item()
        answer = answer_words[answer_idx]

        # Saliency is a plain gradient method: `Saliency.attribute` takes no
        # `baselines` or `n_steps` (passing them raised a TypeError), so no
        # reference question is needed.
        image_attr, text_attr = saliency.attribute(inputs=(image_features, q_input_embedding),
                                                   target=answer_idx,
                                                   additional_forward_args=q_len_batch)

        # Visualize text attributions: one score per word, L2-normalized
        # (skipped when every score is zero, which would divide by zero).
        word_attr = text_attr.sum(dim=2).squeeze(0)
        word_attr_norm = word_attr.norm()
        if word_attr_norm > 0:
            word_attr = word_attr / word_attr_norm
        text_total = text_attr.sum()
        vis_data_records = [visualization.VisualizationDataRecord(
                                word_attr,
                                pred_prob,
                                answer,   # predicted label
                                target,   # true (expected) label
                                answer,   # label the attributions explain
                                text_total,
                                question.split(),
                                0.0)]
        visualization.visualize_text(vis_data_records)

        # Visualize image attributions.
        attr = np.transpose(image_attr.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
        visualization.visualize_image_attr_multiple(attr, original_im_mat,
                                                    ["original_image", "heat_map"], ["all", "absolute_value"],
                                                    titles=["Original Image", "Attribution Magnitude"],
                                                    cmap=default_cmap,
                                                    show_colorbar=True)
        text_contrib = text_total.item()
        image_contrib = image_attr.sum().item()
        print('Text Contributions: ', text_contrib)
        print('Image Contributions: ', image_contrib)
        print('Total Contribution: ', image_contrib + text_contrib)
        



In [None]:
def vqa_resnet_interpret_sg(image_filename, questions, targets):
    """Explain VQA answers for one image with SmoothGrad via Captum's NoiseTunnel.

    For each (question, target) pair: run the model to get its predicted
    answer, compute noise-averaged ('smoothgrad') attributions for that
    predicted answer with respect to both the image and the question
    embedding, render a word-level text heat map and an image heat map, and
    print the summed contributions.

    Args:
        image_filename: path of the image file; it is converted to RGB.
        questions: question strings.
        targets: expected answer strings, paired with ``questions`` by
            position. They are only displayed as the true label; attributions
            are always computed for the model's own predicted answer.

    Uses notebook globals defined in other cells: ``vqa_resnet``, ``nt``,
    ``interpretable_embedding``, ``encode_question``, ``image_to_features``,
    ``answer_words``, ``image_size``, ``central_fraction`` and ``default_cmap``.
    """
    img = Image.open(image_filename).convert('RGB')
    # Un-normalized, cropped copy of the image used only as the heat-map
    # backdrop. `transforms.Scale` is deprecated/removed in torchvision;
    # `transforms.Resize` is its drop-in replacement.
    original_image = transforms.Compose([transforms.Resize(int(image_size / central_fraction)),
                                         transforms.CenterCrop(image_size),
                                         transforms.ToTensor()])(img)
    # Identical for every question, so convert once (CHW -> HWC for matplotlib).
    original_im_mat = np.transpose(original_image.numpy(), (1, 2, 0))

    image_features = image_to_features(img).requires_grad_()

    for question, target in zip(questions, targets):
        q, q_len = encode_question(question)
        q_len_batch = q_len.unsqueeze(0)
        q_input_embedding = interpretable_embedding.indices_to_embeddings(q).unsqueeze(0)

        # Prediction only: no autograd graph through ResNet-152 is needed here.
        with torch.no_grad():
            ans = vqa_resnet(image_features, q_input_embedding, q_len_batch)
        pred, answer_idx = F.softmax(ans, dim=1).cpu().max(dim=1)
        # Plain Python scalars: a tensor target would be expanded per noise
        # sample on the CPU while the model output lives on `device`; an int
        # target avoids that mismatch and indexes `answer_words` directly.
        pred_prob = pred.item()
        answer_idx = answer_idx.item()
        answer = answer_words[answer_idx]

        # The attribution object is the notebook's `nt` (the previously used
        # name `sg` was never defined). `nt` wraps the IntegratedGradients
        # instance; with no `baselines` passed, IG uses its default zero baselines.
        # NOTE(review): newer Captum releases renamed `n_samples` to
        # `nt_samples` — confirm against the installed version.
        image_attr, text_attr = nt.attribute(inputs=(image_features, q_input_embedding),
                                             target=answer_idx,
                                             additional_forward_args=q_len_batch,
                                             nt_type='smoothgrad',
                                             n_samples=10)

        # Visualize text attributions: one score per word, L2-normalized
        # (skipped when every score is zero, which would divide by zero).
        word_attr = text_attr.sum(dim=2).squeeze(0)
        word_attr_norm = word_attr.norm()
        if word_attr_norm > 0:
            word_attr = word_attr / word_attr_norm
        text_total = text_attr.sum()
        vis_data_records = [visualization.VisualizationDataRecord(
                                word_attr,
                                pred_prob,
                                answer,   # predicted label
                                target,   # true (expected) label
                                answer,   # label the attributions explain
                                text_total,
                                question.split(),
                                0.0)]
        visualization.visualize_text(vis_data_records)

        # Visualize image attributions.
        attr = np.transpose(image_attr.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
        visualization.visualize_image_attr_multiple(attr, original_im_mat,
                                                    ["original_image", "heat_map"], ["all", "absolute_value"],
                                                    titles=["Original Image", "Attribution Magnitude"],
                                                    cmap=default_cmap,
                                                    show_colorbar=True)
        text_contrib = text_total.item()
        image_contrib = image_attr.sum().item()
        print('Text Contributions: ', text_contrib)
        print('Image Contributions: ', image_contrib)
        print('Total Contribution: ', image_contrib + text_contrib)
        



In [None]:
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code — only load checkpoint files from a trusted source.
saved_state = torch.load('./2017-08-04_00.55.19.pth', map_location=device)

vocab = saved_state['vocab']
token_to_index = vocab['question']
answer_to_index = vocab['answer']
# One extra slot beyond the question vocabulary (presumably index 0, which
# encode_question uses for unknown tokens) — TODO confirm.
num_tokens = len(token_to_index) + 1

# Invert answer_to_index into a list indexed by answer id; any id that is not
# assigned keeps the placeholder 'unk'.
answer_words = ['unk'] * len(answer_to_index)
for word, answer_id in answer_to_index.items():
    answer_words[answer_id] = word

In [None]:
# Original question-answering network restored from the checkpoint. Weights
# are loaded into the DataParallel wrapper, so the saved parameter names
# presumably carry its `module.` prefix — TODO confirm.
vqa_net = torch.nn.DataParallel(Net(num_tokens))
vqa_net.load_state_dict(saved_state['weights'])
vqa_net.to(device).eval()

In [None]:
def encode_question(question):
    """Map a question string to vocabulary indices.

    The question is lower-cased and split on whitespace; tokens missing from
    ``token_to_index`` map to index 0.

    Returns:
        (indices, length): a 1-D LongTensor with one index per token, and a
        0-dim tensor holding the token count, both on ``device``.
    """
    tokens = question.lower().split()
    indices = [token_to_index.get(token, 0) for token in tokens]
    return (torch.tensor(indices, dtype=torch.long, device=device),
            torch.tensor(len(tokens), device=device))

In [None]:
# End-to-end model: every entry of vqa_net's state dict is copied over; the
# resnet_layer4 weights keep whatever ResNetLayer4 loaded at construction.
vqa_resnet = torch.nn.DataParallel(
    VQA_Resnet_Model(vqa_net.module.text.embedding.num_embeddings))
merged_state = vqa_resnet.state_dict()
merged_state.update(vqa_net.state_dict())
vqa_resnet.load_state_dict(merged_state)
vqa_resnet.to(device).eval()

In [None]:
# Preprocessing configuration. The interpret functions also read image_size
# and central_fraction to build the un-normalized display image.
central_fraction = 1.0
image_size = 448
transform = get_transform(image_size, central_fraction=central_fraction)

In [None]:
# Replace the question embedding layer with Captum's pass-through wrapper so
# the model is fed embeddings (produced by indices_to_embeddings), which can
# be attributed, instead of token indices.
interpretable_embedding = configure_interpretable_embedding_layer(
    vqa_resnet, 'module.text.embedding')


In [None]:
# Reference (baseline) questions are sequences of the 'pad' token.
PAD_IND = token_to_index['pad']
token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)

In [None]:
# NOTE(review): presumably disabled because cuDNN's RNN backward pass is only
# supported in training mode, while the model here runs in eval() mode — confirm.
torch.backends.cudnn.enabled = False

In [None]:
# Captum attribution methods over the end-to-end model.
ig = IntegratedGradients(vqa_resnet)
saliency = Saliency(vqa_resnet)
# Noise-averaged attributions built on top of Integrated Gradients.
nt = NoiseTunnel(ig)

In [None]:
# Heat-map colormap: white at 0, dark slate at 0.25, black at 1.
default_cmap = LinearSegmentedColormap.from_list(
    'custom blue',
    [(0, '#ffffff'), (0.25, '#252b36'), (1, '#000000')],
    N=256,
)

In [None]:
# Example inputs: a local cat photo plus two images from the Captum VQA tutorial.
images = [
    './siamese.jpg',
    './captum/tutorials/img/vqa/elephant.jpg',
    './captum/tutorials/img/vqa/zebra.jpg',
]

In [None]:
image_idx = 0  # images[0] is the cat photo

# There is no plain `vqa_resnet_interpret` in this notebook (calling it raised
# NameError); the per-method variants are _ig, _saliency and _sg. Integrated
# Gradients is used here.
vqa_resnet_interpret_ig(images[image_idx], [
    "what is on the picture",
    "what color are the cat's eyes",
    "is the animal in the picture a cat or a fox",
    "what color is the cat",
    "how many ears does the cat have",
    "where is the cat"
], ['cat', 'blue', 'cat', 'white and brown', '2', 'at the wall'])