## Interpreting PrivacyQA on Different Models with Captum
In this notebook we interpret the results of different models by visualize attention layers and embeddings.  

In [210]:
import numpy as np
import matplotlib.pyplot as plt

import torch

from transformers import AutoModelForSequenceClassification, AutoTokenizer

from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients

In [239]:
device = "cpu"

In [244]:
def visualize_embedding_attribution_for_model(model, question, text, true_label):
    # replace <PATH-TO-SAVED-MODEL> with the real path of the saved model
    model_path = f'../../../../ashankar/git/privacy-glue/runs/{model}/privacy_qa/seed_0/'

    # load model
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.to(device)
    model.eval()
    model.zero_grad()
    
    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # transform question and corresponding text
    question_ids = tokenizer.encode(question, add_special_tokens=False)
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    ref_token_id = tokenizer.pad_token_id # token used for generating token reference
    sep_token_id = tokenizer.sep_token_id # token used as a separator between question and text and it is also added to the end of the text.
    cls_token_id = tokenizer.cls_token_id # token used for prepending to the concatenated question-text word sequence

    # construct input token ids by concatenating question and text
    input_ids = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id]

    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(question_ids) + [sep_token_id] + \
        [ref_token_id] * len(text_ids) + [sep_token_id]
    # make tensors
    input_ids, ref_input_ids = torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device)
    indices = input_ids[0].detach().tolist()

    # make prediction
    def forward_func(inputs):
        return model(inputs, attention_mask=torch.ones_like(inputs)).logits
    model_output = forward_func(input_ids)
    prediction_ind = np.argmax(model_output[0].detach().numpy()).item()
    label_itos = model.config.id2label
    print('Question: ', question)
    print('Text: ', text)
    print('Predicted Answer: ', f"{label_itos[prediction_ind]}")

    # attach the roberta embeddings if roBERTa model otherwise BERT
    if model.config.__class__.__name__.startswith("Roberta"):
        # remove these special characters representing spaces in roberta tokenizer
        all_tokens = [token.replace("Ġ","") for token in tokenizer.convert_ids_to_tokens(indices)]
        lig = LayerIntegratedGradients(forward_func, model.roberta.embeddings)
    else:
        all_tokens = tokenizer.convert_ids_to_tokens(indices)
        lig = LayerIntegratedGradients(forward_func, model.bert.embeddings)

    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=ref_input_ids,
        target=prediction_ind,
        return_convergence_delta=True
    )
    # sum and normalize
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()
    vis_record = viz.VisualizationDataRecord(
        attributions,
        torch.max(torch.softmax(model_output[0], dim=0)),
        label_itos[prediction_ind],
        true_label,
        label_itos[prediction_ind],
        attributions.sum(),
        all_tokens,
        delta
    )
    return viz.visualize_text([vis_record])


In [245]:
question = 'are my statistics kept private?'
text = 'We will never share with or sell the information gained through the use of Apple HealthKit, such as age, weight and heart rate data, to advertisers or other agencies without your authorization.'

In [246]:
_ = visualize_embedding_attribution_for_model("bert_base_uncased", question, text, "Relevant")

Question:  are my statistics kept private?
Text:  We will never share with or sell the information gained through the use of Apple HealthKit, such as age, weight and heart rate data, to advertisers or other agencies without your authorization.
Predicted Answer:  Irrelevant


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Relevant,Irrelevant (1.00),Irrelevant,2.52,"[CLS] are my statistics kept private ? [SEP] we will never share with or sell the information gained through the use of apple health ##kit , such as age , weight and heart rate data , to ad ##vert ##iser ##s or other agencies without your authorization . [SEP]"
,,,,


In [247]:
_ = visualize_embedding_attribution_for_model("mukund_privbert", question, text, "Relevant")

Question:  are my statistics kept private?
Text:  We will never share with or sell the information gained through the use of Apple HealthKit, such as age, weight and heart rate data, to advertisers or other agencies without your authorization.
Predicted Answer:  Relevant


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Relevant,Relevant (0.99),Relevant,1.82,"#s are my statistics kept private ? #/s We will never share with or sell the information gained through the use of Apple Health Kit , such as age , weight and heart rate data , to advertisers or other agencies without your authorization . #/s"
,,,,
