In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from transformers_interpret import QuestionAnsweringExplainer
from captum.attr import IntegratedGradients
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


#### Lab 8: Using LIG with Bert Q/A

#### 1. Load model and do a forward pass (predictdion)

In [2]:
# Load BERT QA model (fine-tuned on SQuAD)
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
model = AutoModelForQuestionAnswering.from_pretrained(model_name, ignore_mismatched_sizes=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set model to evaluation mode
model.eval()


Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,)

In [3]:
# Define question and context
question = "What is the capital of Canada?"
context = "Canada is a country in North Americas. The capital of Canada is Ottawa, which is known as the most educated city in Canada."

# Tokenize input
inputs = tokenizer(question, context, return_tensors="pt")


In [4]:
# Get model predictions
with torch.no_grad():
    outputs = model(**inputs)
    start_logits = outputs.start_logits
    end_logits = outputs.end_logits

# Extract answer span
start_idx = torch.argmax(start_logits)
end_idx = torch.argmax(end_logits) + 1  # +1 to include last token

# Convert token IDs back to text
answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][start_idx:end_idx]))

print(f"Predicted Answer: {answer}")


Predicted Answer: ottawa


#### 2. Use transformers_interpret import QuestionAnsweringExplainer to explain the prediction with LIG

In [16]:
# Initialize explainer with attribution type "lig"
qa_explainer = QuestionAnsweringExplainer(model, tokenizer)

explanation = qa_explainer(question, context)
print(explanation)

qa_explainer.visualize()

{'start': [('[CLS]', 0.0), ('what', -0.11392290855452557), ('is', 0.03452253920574694), ('the', 0.2139910656517729), ('capital', 0.714033470911247), ('of', 0.059143124926587805), ('canada', -0.07868388118912466), ('?', -0.03704502063019229), ('[SEP]', -0.042119023140083235), ('canada', 0.3794967739696823), ('is', 0.0172761949051078), ('a', 0.02415347671108553), ('country', 0.18471678605127456), ('in', 0.0013970164737983945), ('north', -0.2049442967756021), ('americas', -0.004818462113978016), ('.', 0.10013573290858906), ('the', 0.10692145051581424), ('capital', -0.1613762533979355), ('of', -0.10500553710549256), ('canada', -0.08743274650250663), ('is', -0.01612008266256812), ('ottawa', -0.19308209237097976), (',', 0.0539853765185147), ('which', 0.006980638055265022), ('is', 0.05836702960244064), ('known', 0.04826204780089213), ('as', 0.037114067302668315), ('the', -0.008527785166652741), ('most', -0.15225481618405518), ('educated', 0.09353504476849521), ('city', -0.1598213930935432), (

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
ottawa (22),ottawa (6.30),ottawa (22),0.57,"[CLS] what is the capital of canada ? [SEP] canada is a country in north americas . the capital of canada is ottawa , which is known as the most educated city in canada . [SEP]"
,,,,
ottawa (22),ottawa (7.01),ottawa (22),0.84,"[CLS] what is the capital of canada ? [SEP] canada is a country in north americas . the capital of canada is ottawa , which is known as the most educated city in canada . [SEP]"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
ottawa (22),ottawa (6.30),ottawa (22),0.57,"[CLS] what is the capital of canada ? [SEP] canada is a country in north americas . the capital of canada is ottawa , which is known as the most educated city in canada . [SEP]"
,,,,
ottawa (22),ottawa (7.01),ottawa (22),0.84,"[CLS] what is the capital of canada ? [SEP] canada is a country in north americas . the capital of canada is ottawa , which is known as the most educated city in canada . [SEP]"
,,,,
