In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import os
import sys
module_path = os.path.join(os.getcwd(), '..')
sys.path.append(module_path)

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "deepset/roberta-large-squad2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name).to(device)

In [5]:
from lrp_engine import LRPEngine, checkpoint_hook

In [6]:
context = "Welcome to the final examination for this term's offering of CS100. Please remove all headphones and earbuds, as well as hats and hoods. Place your bag under your desk so that it does not block the aisle. You are permitted writing instruments, a clear water bottle, and any aids listed on the front of your booklet. The exam will be 150 minutes in duration. You may now begin."
question = "What is this?"

In [7]:
input_ids = tokenizer(question, context, return_tensors="pt")["input_ids"]

In [8]:
output = model(input_ids.to(device))

In [9]:
start = torch.argmax(output.start_logits)
end = torch.argmax(output.end_logits) + 1

answer_tokens = input_ids[0][start:end]
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
print(answer)

 final examination


In [10]:
lrp = LRPEngine(topk=1)

In [11]:
lrp.starting_relevance=None
lrp.topk=3

In [12]:
checkpoint_vals, param_vals = lrp.run((output.start_logits, output.end_logits))

In [13]:
checkpoint_vals1, param_vals1 = lrp.run((output.start_logits, output.end_logits))

In [14]:
# Check results diff between 1st pass and 2nd pass on same input (should be ~0)
[
    ((p1 - p0)**2).sum() for p1, p0 in zip(param_vals1, param_vals)
]

[tensor(1.3388e-16, device='cuda:0'),
 tensor(3.5755e-25, device='cuda:0'),
 tensor(4.1561e-17, device='cuda:0')]

In [15]:
# Top 5 tokens from LRP
lrp_answer_ids = input_ids[0][param_vals[1].flatten().topk(k=5).indices.cpu()]
print(tokenizer.decode(lrp_answer_ids))

 examination CS100. final


In [6]:
from datasets import load_dataset

dataset = load_dataset("squad_v2")

In [12]:
from tqdm import tqdm

results = []
top1_label_hits = 0
top1_model_hits = 0
total_examples = 0
total_intersect = 0
total_union = 0

for example in tqdm(dataset["validation"]):
    question = example["question"]
    context = example["context"]
    answers = example["answers"]["text"]

    if not answers:
        continue

    input_ids = tokenizer(context + " " + question, return_tensors="pt")["input_ids"].to(device)
    if input_ids.shape[-1] > 512:
        continue

    input_ids = tokenizer(question, context, return_tensors="pt")["input_ids"]
    output = model(input_ids.to(device))

    start = torch.argmax(output.start_logits[:,1:]) + 1
    end = torch.argmax(output.end_logits[:,1:]) + 2
    if start > end:
        continue
    model_answer = tokenizer.decode(input_ids[0][start:end], skip_special_tokens=True)
    if model_answer == "":
        continue
    # output = run_flan_encode_decode(input_ids)

    lrp.topk = end - start
    if (end - start) > output.start_logits.shape[-1]:
        print(start, end)
        print(output.start_logits.shape)
        break

    lrp_input1, lrp_input2 = output.start_logits, output.end_logits
    lrp_input1[0][0] = 0
    lrp_input2[0][0] = 0
    checkpoint_vals, param_vals = lrp.run((lrp_input1, lrp_input2))

    lrp_max = param_vals[1].flatten()[1:].argmax() + 1
    lrp_top_token = tokenizer.convert_ids_to_tokens([input_ids[0][lrp_max]])[0].strip().replace(chr(9601), "")

    # Do model answer-based accuracy, i.e. is the attribution aligned with the model prediction
    if start <= lrp_max <= end:
        top1_model_hits += 1

    # Do IoU with the model prediction
    intersect = 0
    union = lrp.topk
    for top_ind in param_vals[1].flatten()[1:].topk(lrp.topk).indices:
        top_ind = top_ind + 1
        if start <= top_ind <= end:
            intersect += 1
        else:
            union += 1
    total_intersect += intersect
    total_union += union

    # Do label-based accuracy
    if any(lrp_top_token in ans for ans in answers):
        # Is the attribution aligned with the ground truth label
        top1_label_hits += 1
    total_examples += 1
    if not (total_examples % 100):
        print(top1_model_hits, top1_label_hits, total_examples)
        print(total_intersect / total_union)
    # lrp_top5 = param_vals[2].flatten().topk(k=5)
    # lrp_top5_tokens = tokenizer.decode(input_ids[0][lrp_top5.indices.cpu()])
    # results.append({
    #     "example": example,
    #     "lrp_top5_tokens": lrp_top5_tokens,
    #     "lrp_top5_relevances": lrp_top5.values.cpu(),
    #     "is_impossible": len(answers) == 0
    # })

  2%|█▍                                                                            | 219/11873 [00:16<17:15, 11.25it/s]

54 10 100
tensor(0.2482, device='cuda:0')


  4%|██▊                                                                           | 434/11873 [00:32<11:56, 15.96it/s]

113 16 200
tensor(0.2589, device='cuda:0')


  5%|████▏                                                                         | 643/11873 [00:49<14:05, 13.29it/s]

178 25 300
tensor(0.2613, device='cuda:0')


  7%|█████▍                                                                        | 835/11873 [01:05<11:49, 15.56it/s]

245 31 400
tensor(0.2959, device='cuda:0')


  9%|██████▋                                                                      | 1022/11873 [01:21<17:37, 10.26it/s]

312 41 500
tensor(0.2999, device='cuda:0')


 10%|███████▉                                                                     | 1220/11873 [01:38<11:47, 15.05it/s]

373 46 600
tensor(0.2972, device='cuda:0')


 12%|█████████▏                                                                   | 1419/11873 [01:54<11:17, 15.42it/s]

433 58 700
tensor(0.2901, device='cuda:0')


 14%|██████████▍                                                                  | 1619/11873 [02:11<11:05, 15.40it/s]

491 71 800
tensor(0.2854, device='cuda:0')


 15%|███████████▊                                                                 | 1812/11873 [02:27<10:39, 15.73it/s]

548 86 900
tensor(0.2801, device='cuda:0')


 17%|█████████████                                                                | 2022/11873 [02:43<11:10, 14.69it/s]

618 93 1000
tensor(0.2853, device='cuda:0')


 19%|██████████████▍                                                              | 2227/11873 [02:58<10:34, 15.20it/s]

684 105 1100
tensor(0.2869, device='cuda:0')


 20%|███████████████▌                                                             | 2391/11873 [03:14<13:51, 11.41it/s]

754 114 1200
tensor(0.2859, device='cuda:0')


 22%|████████████████▋                                                            | 2571/11873 [03:29<15:45,  9.84it/s]

813 125 1300
tensor(0.2863, device='cuda:0')


 24%|██████████████████                                                           | 2792/11873 [03:45<10:11, 14.85it/s]

873 133 1400
tensor(0.2819, device='cuda:0')


 25%|███████████████████▍                                                         | 2990/11873 [04:01<09:50, 15.05it/s]

942 138 1500
tensor(0.2786, device='cuda:0')


 26%|████████████████████                                                         | 3100/11873 [04:13<20:28,  7.14it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (586 > 512). Running this sequence through the model will result in indexing errors
 27%|████████████████████▌                                                        | 3178/11873 [04:18<10:17, 14.08it/s]

995 143 1600
tensor(0.2741, device='cuda:0')


 29%|██████████████████████                                                       | 3411/11873 [04:35<15:58,  8.83it/s]

1061 153 1700
tensor(0.2740, device='cuda:0')


 30%|███████████████████████                                                      | 3556/11873 [04:52<13:07, 10.56it/s]

1124 166 1800
tensor(0.2719, device='cuda:0')


 31%|████████████████████████▏                                                    | 3729/11873 [05:08<16:22,  8.29it/s]

1176 176 1900
tensor(0.2708, device='cuda:0')


 33%|█████████████████████████▌                                                   | 3933/11873 [05:23<11:09, 11.86it/s]

1227 193 2000
tensor(0.2667, device='cuda:0')


 35%|██████████████████████████▊                                                  | 4143/11873 [05:40<08:40, 14.84it/s]

1288 202 2100
tensor(0.2659, device='cuda:0')


 37%|████████████████████████████▊                                                | 4440/11873 [05:57<04:53, 25.34it/s]

1337 214 2200
tensor(0.2596, device='cuda:0')


 40%|██████████████████████████████▌                                              | 4720/11873 [06:13<06:32, 18.21it/s]

1393 223 2300
tensor(0.2537, device='cuda:0')


 42%|████████████████████████████████▏                                            | 4958/11873 [06:29<04:58, 23.15it/s]

1450 233 2400
tensor(0.2531, device='cuda:0')


 44%|██████████████████████████████████                                           | 5247/11873 [06:45<11:24,  9.68it/s]

1508 244 2500
tensor(0.2532, device='cuda:0')


 46%|███████████████████████████████████▎                                         | 5439/11873 [07:01<07:09, 14.99it/s]

1561 252 2600
tensor(0.2536, device='cuda:0')


 47%|████████████████████████████████████▎                                        | 5595/11873 [07:17<07:51, 13.31it/s]

1618 263 2700
tensor(0.2533, device='cuda:0')


 49%|█████████████████████████████████████▍                                       | 5776/11873 [07:33<06:08, 16.53it/s]

1679 264 2800
tensor(0.2523, device='cuda:0')


 51%|███████████████████████████████████████                                      | 6017/11873 [07:50<06:14, 15.65it/s]

1743 271 2900
tensor(0.2535, device='cuda:0')


 52%|████████████████████████████████████████▍                                    | 6229/11873 [08:05<05:26, 17.31it/s]

1814 283 3000
tensor(0.2552, device='cuda:0')


 54%|█████████████████████████████████████████▉                                   | 6458/11873 [08:21<05:23, 16.71it/s]

1882 300 3100
tensor(0.2565, device='cuda:0')


 56%|███████████████████████████████████████████▏                                 | 6666/11873 [08:37<06:02, 14.38it/s]

1942 304 3200
tensor(0.2569, device='cuda:0')


 58%|████████████████████████████████████████████▎                                | 6827/11873 [08:53<10:12,  8.24it/s]

2001 310 3300
tensor(0.2598, device='cuda:0')


 59%|█████████████████████████████████████████████▍                               | 6999/11873 [09:08<05:00, 16.20it/s]

2076 321 3400
tensor(0.2612, device='cuda:0')


 61%|██████████████████████████████████████████████▋                              | 7202/11873 [09:24<06:00, 12.96it/s]

2140 323 3500
tensor(0.2620, device='cuda:0')


 62%|████████████████████████████████████████████████                             | 7409/11873 [09:40<05:20, 13.94it/s]

2212 328 3600
tensor(0.2628, device='cuda:0')


 64%|█████████████████████████████████████████████████▎                           | 7611/11873 [09:56<04:27, 15.93it/s]

2270 342 3700
tensor(0.2630, device='cuda:0')


 66%|██████████████████████████████████████████████████▋                          | 7808/11873 [10:12<05:41, 11.90it/s]

2325 358 3800
tensor(0.2627, device='cuda:0')


 67%|███████████████████████████████████████████████████▉                         | 8004/11873 [10:28<05:06, 12.64it/s]

2385 374 3900
tensor(0.2622, device='cuda:0')


 69%|█████████████████████████████████████████████████████▏                       | 8210/11873 [10:43<03:53, 15.69it/s]

2441 390 4000
tensor(0.2630, device='cuda:0')


 71%|██████████████████████████████████████████████████████▋                      | 8429/11873 [10:59<04:03, 14.15it/s]

2496 406 4100
tensor(0.2621, device='cuda:0')


 73%|████████████████████████████████████████████████████████                     | 8639/11873 [11:15<03:20, 16.09it/s]

2562 416 4200
tensor(0.2619, device='cuda:0')


 75%|█████████████████████████████████████████████████████████▍                   | 8848/11873 [11:31<03:10, 15.86it/s]

2619 427 4300
tensor(0.2610, device='cuda:0')


 76%|██████████████████████████████████████████████████████████▌                  | 9037/11873 [11:47<05:29,  8.60it/s]

2676 440 4400
tensor(0.2605, device='cuda:0')


 77%|███████████████████████████████████████████████████████████▋                 | 9196/11873 [12:02<04:06, 10.84it/s]

2731 447 4500
tensor(0.2610, device='cuda:0')


 79%|████████████████████████████████████████████████████████████▊                | 9379/11873 [12:18<03:10, 13.09it/s]

2791 467 4600
tensor(0.2607, device='cuda:0')


 81%|██████████████████████████████████████████████████████████████               | 9578/11873 [12:34<02:17, 16.73it/s]

2854 474 4700
tensor(0.2608, device='cuda:0')


 82%|███████████████████████████████████████████████████████████████▎             | 9771/11873 [12:50<02:27, 14.24it/s]

2922 479 4800
tensor(0.2606, device='cuda:0')


 84%|████████████████████████████████████████████████████████████████▌            | 9960/11873 [13:06<01:55, 16.58it/s]

2994 484 4900
tensor(0.2610, device='cuda:0')


 85%|████████████████████████████████████████████████████████████████▉           | 10150/11873 [13:22<02:53,  9.93it/s]

3069 488 5000
tensor(0.2619, device='cuda:0')


 87%|██████████████████████████████████████████████████████████████████▎         | 10353/11873 [13:37<01:24, 18.04it/s]

3134 492 5100
tensor(0.2625, device='cuda:0')


 89%|███████████████████████████████████████████████████████████████████▌        | 10550/11873 [13:53<01:51, 11.83it/s]

3193 501 5200
tensor(0.2626, device='cuda:0')


 90%|████████████████████████████████████████████████████████████████████▊       | 10743/11873 [14:09<01:35, 11.89it/s]

3255 518 5300
tensor(0.2625, device='cuda:0')


 92%|██████████████████████████████████████████████████████████████████████      | 10946/11873 [14:24<01:05, 14.25it/s]

3314 524 5400
tensor(0.2623, device='cuda:0')


 94%|███████████████████████████████████████████████████████████████████████▍    | 11157/11873 [14:40<00:45, 15.87it/s]

3377 534 5500
tensor(0.2620, device='cuda:0')


 96%|████████████████████████████████████████████████████████████████████████▉   | 11393/11873 [14:56<00:26, 17.88it/s]

3424 545 5600
tensor(0.2610, device='cuda:0')


 98%|██████████████████████████████████████████████████████████████████████████▎ | 11605/11873 [15:12<00:23, 11.51it/s]

3474 552 5700
tensor(0.2603, device='cuda:0')


 99%|███████████████████████████████████████████████████████████████████████████▌| 11798/11873 [15:27<00:05, 14.01it/s]

3544 560 5800
tensor(0.2607, device='cuda:0')


100%|████████████████████████████████████████████████████████████████████████████| 11873/11873 [15:34<00:00, 12.71it/s]


In [18]:
from tqdm import tqdm

results = []

for example in tqdm(dataset["validation"].select(range(100))):
    question = example["question"]
    context = example["context"]
    answers = example["answers"]["text"]

    input_ids = tokenizer(question, context, return_tensors="pt")["input_ids"]
    output = model(input_ids.to(device))

    start = torch.argmax(output.start_logits)
    end = torch.argmax(output.end_logits) + 1
    model_answer = tokenizer.decode(input_ids[0][start:end], skip_special_tokens=True)

    checkpoint_vals, param_vals = lrp.run((output.start_logits, output.end_logits))

    lrp_top5 = param_vals[1].flatten().topk(k=5)
    lrp_top5_tokens = tokenizer.decode(input_ids[0][lrp_top5.indices.cpu()], skip_special_tokens=True)
    lrp_start_end = param_vals[1].flatten().topk(k=2).indices.cpu().sort()
    lrp_start = lrp_start_end[0][0]
    if lrp_start == 0:
        lrp_start = lrp_start_end[0][1]
    lrp_end = lrp_start_end[0][-1]
    lrp_answer_ids = input_ids[0][lrp_start:lrp_end + 1]
    lrp_answer = tokenizer.decode(lrp_answer_ids, skip_special_tokens=True)

    results.append({
        "example": example,
        "model_answer": model_answer,
        "lrp_answer": lrp_answer,
        "lrp_top5_tokens": lrp_top5_tokens,
        "lrp_top5_relevances": lrp_top5.values.cpu(),
        "is_impossible": len(answers) == 0
    })

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:17<00:00,  5.76it/s]


In [19]:
for res in results:
    print("Q: ", res["example"]["question"])
    print("A (labels): ", res["example"]["answers"]["text"])
    print("Model: ", res["model_answer"])
    print("LRP: ", res["lrp_answer"])
    print("LRP top5: ", res["lrp_top5_tokens"])
    print("LRP top5 attributions: ", res["lrp_top5_relevances"], '\n')

Q:  In what country is Normandy located?
A (labels):  ['France', 'France', 'France', 'France']
Model:   France
LRP:   France
LRP top5:   Franceimilation NormandyG
LRP top5 attributions:  tensor([8.0164e-07, 1.1496e-07, 6.0129e-08, 5.0107e-08, 4.3468e-08]) 

Q:  When were the Normans in Normandy?
A (labels):  ['10th and 11th centuries', 'in the 10th and 11th centuries', '10th and 11th centuries', '10th and 11th centuries']
Model:   10th and 11th centuries
LRP:   centuries gave their name to Normandy, a region in France. They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Franc
LRP top5:   centuries Franc centuries 10
LRP top5 attributions:  tensor([3.6011e-07, 1.4652e-07, 1.3250e-07, 7.8557e-08, 7.7170e-08]) 

Q:  From which countries did the Norse originate?
A (labels):  ['Denmark, Iceland and Norway', 'Denmark, Iceland and Norway', 'Denmark