In [None]:
import json
import os
import sys

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from transformers import BertTokenizer, BertForQuestionAnswering, BertConfig, AutoTokenizer, AutoModel

from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients, Saliency, NoiseTunnel
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

## Helpers: build model inputs, baselines and embeddings

In [None]:
def construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id):
    """Tokenize a question/context pair into model input ids plus a same-length baseline.

    The input is laid out as ``[CLS] question [SEP] text [SEP]``. The baseline keeps the
    three special tokens in place and replaces every question/text token with
    ``ref_token_id``. Uses the module-level ``tokenizer`` and ``device``.

    Returns:
        ``(input_ids, ref_input_ids, n_question_tokens)``: two ``(1, seq_len)`` tensors on
        ``device`` and the number of question tokens. Note that the third value is the
        question length, not the position of the first ``[SEP]``. That position is
        ``n_question_tokens + 1``.
    """
    question_tokens = tokenizer.encode(question, add_special_tokens=False)
    context_tokens = tokenizer.encode(text, add_special_tokens=False)
    n_question = len(question_tokens)

    def as_batch(ids):
        # The outer list adds a batch dimension of size 1.
        return torch.tensor([ids], device=device)

    real = [cls_token_id, *question_tokens, sep_token_id, *context_tokens, sep_token_id]
    baseline = ([cls_token_id] + [ref_token_id] * n_question + [sep_token_id]
                + [ref_token_id] * len(context_tokens) + [sep_token_id])
    return as_batch(real), as_batch(baseline), n_question

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build segment (token type) ids for ``input_ids`` and an all-zero baseline.

    Positions ``0..sep_ind`` (inclusive) get token type 0; every later position gets 1.
    For ``[CLS] question [SEP] text [SEP]``, pass the index of the first ``[SEP]`` so that
    it belongs to segment 0.

    Args:
        input_ids: ``(batch, seq_len)`` id tensor. Only its length and device are used.
        sep_ind: last index (inclusive) that belongs to segment 0.

    Returns:
        ``(token_type_ids, ref_token_type_ids)``: both are ``(1, seq_len)`` long tensors on
        the same device as ``input_ids``. The reference is all zeros.
    """
    # Follow input_ids' device instead of a global defined in a later cell, so the
    # result always lives next to the tensor it describes.
    positions = torch.arange(input_ids.size(1), device=input_ids.device)
    token_type_ids = (positions > sep_ind).long().unsqueeze(0)
    ref_token_type_ids = torch.zeros_like(token_type_ids)

    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    """Build position ids ``0..seq_len-1`` for ``input_ids`` and an all-zero baseline.

    Args:
        input_ids: ``(batch, seq_len)`` id tensor. Its shape and device are used.

    Returns:
        ``(position_ids, ref_position_ids)``: long tensors shaped like ``input_ids`` on
        the same device. They are broadcast views produced by ``expand_as``, not copies.
    """
    seq_len = input_ids.size(1)
    # Use input_ids' device rather than a global defined in a later cell.
    position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
    ref_position_ids = torch.zeros(seq_len, dtype=torch.long, device=input_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)

    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    """Return a mask of ones shaped like ``input_ids``, so every position is attended.

    The mask has the same dtype and device as ``input_ids``. No padding is masked out.
    """
    mask_shape = input_ids.shape
    return input_ids.new_ones(mask_shape)

def construct_bert_sub_embedding(input_ids, ref_input_ids, token_type_ids, ref_token_type_ids, position_ids, ref_position_ids):
    """Embed each (input, reference) id pair with its own interpretable embedding layer.

    ``interpretable_embedding1`` embeds ``input_ids``/``ref_input_ids``,
    ``interpretable_embedding2`` embeds the token-type pair, and
    ``interpretable_embedding3`` embeds the position pair. All three are module-level
    objects that expose ``indices_to_embeddings``.

    NOTE(review): none of the three names is assigned anywhere in this notebook, so calling
    this raises NameError until they are configured (presumably via
    ``configure_interpretable_embedding_layer``). TODO confirm.

    Returns:
        Three ``(input_embeddings, reference_embeddings)`` pairs, in the order
        (input ids, token-type ids, position ids).
    """
    def embed_pair(layer, ids, ref_ids):
        # Input first, then reference.
        return layer.indices_to_embeddings(ids), layer.indices_to_embeddings(ref_ids)

    id_pair = embed_pair(interpretable_embedding1, input_ids, ref_input_ids)
    token_type_pair = embed_pair(interpretable_embedding2, token_type_ids, ref_token_type_ids)
    position_pair = embed_pair(interpretable_embedding3, position_ids, ref_position_ids)
    return id_pair, token_type_pair, position_pair
    
def construct_whole_bert_embeddings(input_ids, ref_input_ids, token_type_ids=None, ref_token_type_ids=None, position_ids=None, ref_position_ids=None):
    """Embed the input and reference ids through the module-level ``interpretable_embedding``.

    The input is embedded with ``token_type_ids``/``position_ids``. The reference is
    embedded with ``ref_token_type_ids``/``ref_position_ids`` when they are given. When
    they are ``None``, it falls back to the input's token-type / position ids.

    NOTE(review): ``interpretable_embedding`` is not assigned anywhere in this notebook;
    presumably it comes from ``configure_interpretable_embedding_layer``. TODO confirm.

    Returns:
        ``(input_embeddings, ref_input_embeddings)``.
    """
    # Omitting the ref_* arguments gives the reference the same token types and positions
    # as the input, so only the token ids differ.
    if ref_token_type_ids is None:
        ref_token_type_ids = token_type_ids
    if ref_position_ids is None:
        ref_position_ids = position_ids

    input_embeddings = interpretable_embedding.indices_to_embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = interpretable_embedding.indices_to_embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)

    return input_embeddings, ref_input_embeddings

## Forward functions wrapped by the attribution methods

In [None]:
def squad_pos_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
    """Score used as the attribution target: the best start or end logit per example.

    Args:
        inputs: model input, passed through to ``predict``.
        token_type_ids, position_ids, attention_mask: forwarded unchanged to ``predict``.
        position: index into the model output. 0 selects the start logits and 1 the end
            logits, matching how the loops unpack ``predict``'s result.

    Returns:
        The maximum logit along dimension 1 (the sequence axis) of the selected output.
    """
    outputs = predict(inputs, token_type_ids=token_type_ids, position_ids=position_ids, attention_mask=attention_mask)
    selected_logits = outputs[position]
    best_logit, _best_index = selected_logits.max(1)
    return best_logit

def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    """Run the module-level ``model`` and return its outputs as a plain tuple.

    For question answering without labels, the first two items are the start and end
    logits. That lets callers write ``start, end = predict(...)`` and
    ``predict(...)[position]``.
    """
    outputs = model(inputs, token_type_ids=token_type_ids, position_ids=position_ids, attention_mask=attention_mask)
    # Newer transformers versions return a ModelOutput, which is a mapping. Tuple-unpacking
    # it yields the key *names* ("start_logits", ...) rather than tensors, so convert it
    # to a tuple. Older versions already return a tuple and are left alone.
    if hasattr(outputs, "to_tuple"):
        outputs = outputs.to_tuple()
    return outputs

## Attribution post-processing

In [None]:
def summarize_attributions(attributions):
    """Collapse attributions to one L2-normalised score per token.

    Args:
        attributions: tensor whose last dimension is summed away. It presumably has shape
            ``(1, seq_len, hidden)``; the leading dim is dropped by ``squeeze(0)``.
            TODO confirm for batched inputs.

    Returns:
        The summed scores scaled to unit L2 norm. When every summed score is zero they are
        returned unscaled (all zeros), instead of the NaNs that 0/0 would produce.
    """
    per_token = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(per_token)
    if norm == 0:
        return per_token
    return per_token / norm

def get_top_attributed_tokens(attrs, k=5, tokens=None):
    """Return the ``k`` tokens with the largest attribution scores.

    Args:
        attrs: tensor of per-token scores along its last dimension (e.g. the output of
            ``summarize_attributions``).
        k: number of tokens to return. It is clamped to the sequence length so that short
            sequences do not make ``torch.topk`` raise.
        tokens: token strings aligned with ``attrs``. Defaults to the module-level
            ``all_tokens`` (set by the attribution loops for the current example) for
            backward compatibility.

    Returns:
        ``(top_tokens, values, indices)``, ordered from highest to lowest score.
    """
    if tokens is None:
        tokens = all_tokens
    k = min(k, attrs.size(-1))
    values, indices = torch.topk(attrs, k)
    top_tokens = [tokens[i] for i in indices.tolist()]

    return top_tokens, values, indices

In [None]:
def process_squad(data_file):
    """Flatten a SQuAD-format JSON file into one row per question.

    Args:
        data_file: path to a JSON file with a top-level ``"data"`` list of articles, each
            containing ``"paragraphs"`` with a ``"context"`` and ``"qas"``.

    Returns:
        List of ``(id, context, question, answers)`` tuples, where ``answers`` is the list
        of answer strings. The list may be empty, e.g. for SQuAD v2.0 unanswerable questions.
    """
    # JSON interchange is UTF-8. Without an explicit encoding, open() uses the platform
    # default and can mis-decode or fail on non-ASCII contexts.
    with open(data_file, encoding="utf-8") as f:
        articles = json.load(f)['data']

    rows = []
    for article in articles:
        for paragraph in article['paragraphs']:
            context = paragraph['context']
            for qa in paragraph['qas']:
                answer_texts = [a['text'] for a in qa['answers']]
                rows.append((qa['id'], context, qa['question'], answer_texts))
    return rows

In [None]:
# Flatten dev-v2.0.json (resolved relative to the working directory) into
# (id, context, question, answers) rows.
dataset = process_squad(data_file="dev-v2.0.json")

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# One name for both checkpoint loads keeps the model and tokenizer in sync.
MODEL_NAME = 'bert-large-uncased-whole-word-masking-finetuned-squad'

# Reuse the local cache when present. force_download=True re-fetched the whole
# bert-large checkpoint on every run and broke offline re-runs.
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME)
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model.to(device)
# Inference only: eval() switches off training-mode layers such as dropout.
model.eval()
model.zero_grad()

In [None]:
# List every parameter name to see which sub-modules can be targeted for attribution.
parameter_names = [param_name for param_name, _ in model.named_parameters()]
for parameter_name in parameter_names:
    print(parameter_name)

In [None]:
# Display the embedding sub-module that LayerIntegratedGradients is attached to below.
embedding_layer = model.bert.embeddings
embedding_layer

In [None]:
# The padding token serves as the "neutral" baseline token. [SEP] and [CLS] are kept
# verbatim in both the input and the baseline.
ref_token_id, sep_token_id, cls_token_id = (
    tokenizer.pad_token_id,
    tokenizer.sep_token_id,
    tokenizer.cls_token_id,
)

## Layer Integrated Gradients on the BERT embedding layer

In [None]:
# Attribute to the output of the embedding layer. That works even though the model
# input (token ids) is integer-valued.
lig = LayerIntegratedGradients(forward_func=squad_pos_forward_func, layer=model.bert.embeddings)

In [None]:
# Attribute the predicted start/end positions of each answerable example to the
# embedding layer with Layer Integrated Gradients.
for example in dataset:
    # Rows from process_squad are (id, context, question, answers).
    text = example[1]
    question = example[2]
    answers = example[3]

    # SQuAD v2.0 contains unanswerable questions with an empty answer list;
    # answers[0] below would raise IndexError on them.
    if not answers:
        continue

    input_ids, ref_input_ids, question_len = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
    # The position-embedding table has model.config.max_position_embeddings rows, so
    # longer question+context pairs cannot be fed to the model as-is.
    if input_ids.size(1) > model.config.max_position_embeddings:
        continue

    # Layout is [CLS] question [SEP] context [SEP], so the first [SEP] is at question_len + 1.
    # Passing that index puts "[CLS] question [SEP]" in segment 0, the same split a BERT
    # tokenizer produces for sentence pairs.
    sep_ind = question_len + 1
    token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_ind)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
    attention_mask = construct_attention_mask(input_ids)

    indices = input_ids[0].detach().tolist()
    all_tokens = tokenizer.convert_ids_to_tokens(indices)

    # Locate the first ground-truth answer inside the context segment only, so a matching
    # token in the question is not picked up. Both indices stay None when the answer's last
    # token does not appear verbatim in the tokenized context.
    ground_truth = answers[0]
    ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
    ground_truth_start_ind = ground_truth_end_ind = None
    context_start = sep_ind + 1
    if ground_truth_tokens and ground_truth_tokens[-1] in indices[context_start:]:
        ground_truth_end_ind = indices.index(ground_truth_tokens[-1], context_start)
        ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1

    # The plain prediction needs no autograd graph.
    with torch.no_grad():
        outputs = predict(input_ids, token_type_ids=token_type_ids, position_ids=position_ids, attention_mask=attention_mask)
    # Index rather than tuple-unpack: a transformers ModelOutput is a mapping, and
    # unpacking it yields key names instead of the logit tensors.
    start_scores, end_scores = outputs[0], outputs[1]

    print('Question: ', question)
    print(start_scores)
    print(end_scores)
    print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))

    # The last forward arg selects the start (0) or end (1) logits in squad_pos_forward_func.
    attributions_start, delta_start = lig.attribute(inputs=input_ids, baselines=ref_input_ids,
                                                    additional_forward_args=(token_type_ids, position_ids, attention_mask, 0),
                                                    return_convergence_delta=True)
    attributions_end, delta_end = lig.attribute(inputs=input_ids, baselines=ref_input_ids,
                                                additional_forward_args=(token_type_ids, position_ids, attention_mask, 1),
                                                return_convergence_delta=True)

    attributions_start_sum = summarize_attributions(attributions_start)
    attributions_end_sum = summarize_attributions(attributions_end)


## Saliency

In [None]:
# Plain input-gradient attribution. NOTE(review): Saliency differentiates with respect to its
# `inputs`, so it needs float inputs such as embeddings. Integer token ids cannot carry
# gradients. TODO confirm how the loop below feeds it.
saliency = Saliency(forward_func=squad_pos_forward_func)

In [None]:
# Gradient-saliency attribution of the predicted start/end positions for each answerable example.
for example in dataset:
    # Rows from process_squad are (id, context, question, answers).
    text = example[1]
    question = example[2]
    answers = example[3]

    # SQuAD v2.0 contains unanswerable questions with an empty answer list;
    # answers[0] below would raise IndexError on them.
    if not answers:
        continue

    input_ids, ref_input_ids, question_len = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
    # The position-embedding table has model.config.max_position_embeddings rows, so
    # longer question+context pairs cannot be fed to the model as-is.
    if input_ids.size(1) > model.config.max_position_embeddings:
        continue

    # The first [SEP] is at question_len + 1. Passing that index puts "[CLS] question [SEP]"
    # in segment 0, the same split a BERT tokenizer produces for sentence pairs.
    sep_ind = question_len + 1
    token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_ind)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
    attention_mask = construct_attention_mask(input_ids)

    indices = input_ids[0].detach().tolist()
    all_tokens = tokenizer.convert_ids_to_tokens(indices)

    # Locate the first ground-truth answer inside the context segment only. Both indices
    # stay None when the answer's last token does not appear verbatim in the context.
    ground_truth = answers[0]
    ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
    ground_truth_start_ind = ground_truth_end_ind = None
    context_start = sep_ind + 1
    if ground_truth_tokens and ground_truth_tokens[-1] in indices[context_start:]:
        ground_truth_end_ind = indices.index(ground_truth_tokens[-1], context_start)
        ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1

    with torch.no_grad():
        outputs = predict(input_ids, token_type_ids=token_type_ids, position_ids=position_ids, attention_mask=attention_mask)
    # Index rather than tuple-unpack: a transformers ModelOutput is a mapping, and
    # unpacking it yields key names instead of the logit tensors.
    start_scores, end_scores = outputs[0], outputs[1]

    print('Question: ', question)
    print(start_scores)
    print(end_scores)
    print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))

    # Saliency.attribute takes no `baselines` argument (it is not a path method), so none is
    # passed. NOTE(review): input_ids is an integer tensor, and Saliency needs gradients with
    # respect to its inputs. Expect these calls to fail until float embeddings are passed
    # instead. TODO confirm.
    attributions_start = saliency.attribute(inputs=input_ids,
                                            additional_forward_args=(token_type_ids, position_ids, attention_mask, 0))
    attributions_end = saliency.attribute(inputs=input_ids,
                                          additional_forward_args=(token_type_ids, position_ids, attention_mask, 1))

    attributions_start_sum = summarize_attributions(attributions_start)
    attributions_end_sum = summarize_attributions(attributions_end)

## SmoothGrad (NoiseTunnel)

In [None]:
# NoiseTunnel wraps a base attribution method. No earlier cell defines `ig`, so build it
# here with the same forward function the other methods use.
# NOTE(review): IntegratedGradients interpolates numerically between baseline and input, so
# it needs float inputs (e.g. embeddings). Integer input_ids are expected to fail. TODO confirm.
ig = IntegratedGradients(squad_pos_forward_func)
nt = NoiseTunnel(ig)

In [None]:
# SmoothGrad attribution of the predicted start/end positions for each answerable example.
for example in dataset:
    # Rows from process_squad are (id, context, question, answers).
    text = example[1]
    question = example[2]
    answers = example[3]

    # SQuAD v2.0 contains unanswerable questions with an empty answer list;
    # answers[0] below would raise IndexError on them.
    if not answers:
        continue

    input_ids, ref_input_ids, question_len = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
    # The position-embedding table has model.config.max_position_embeddings rows, so
    # longer question+context pairs cannot be fed to the model as-is.
    if input_ids.size(1) > model.config.max_position_embeddings:
        continue

    # The first [SEP] is at question_len + 1. Passing that index puts "[CLS] question [SEP]"
    # in segment 0, the same split a BERT tokenizer produces for sentence pairs.
    sep_ind = question_len + 1
    token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_ind)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
    attention_mask = construct_attention_mask(input_ids)

    indices = input_ids[0].detach().tolist()
    all_tokens = tokenizer.convert_ids_to_tokens(indices)

    # Locate the first ground-truth answer inside the context segment only. Both indices
    # stay None when the answer's last token does not appear verbatim in the context.
    ground_truth = answers[0]
    ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
    ground_truth_start_ind = ground_truth_end_ind = None
    context_start = sep_ind + 1
    if ground_truth_tokens and ground_truth_tokens[-1] in indices[context_start:]:
        ground_truth_end_ind = indices.index(ground_truth_tokens[-1], context_start)
        ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1

    with torch.no_grad():
        outputs = predict(input_ids, token_type_ids=token_type_ids, position_ids=position_ids, attention_mask=attention_mask)
    # Index rather than tuple-unpack: a transformers ModelOutput is a mapping, and
    # unpacking it yields key names instead of the logit tensors.
    start_scores, end_scores = outputs[0], outputs[1]

    print('Question: ', question)
    print(start_scores)
    print(end_scores)
    print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))

    # NOTE(review): SmoothGrad adds Gaussian noise to `inputs`. On integer input_ids that gives
    # non-integer "token ids" the embedding lookup cannot consume, so expect failure until
    # embeddings are passed. Recent Captum releases also name the sample count `nt_samples`
    # (older ones use `n_samples`); confirm against the installed version.
    attributions_start = nt.attribute(inputs=input_ids, nt_type='smoothgrad', n_samples=5, additional_forward_args=(token_type_ids, position_ids, attention_mask, 0))
    attributions_end = nt.attribute(inputs=input_ids, nt_type='smoothgrad', n_samples=5, additional_forward_args=(token_type_ids, position_ids, attention_mask, 1))

    attributions_start_sum = summarize_attributions(attributions_start)
    attributions_end_sum = summarize_attributions(attributions_end)