In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import torch

In [3]:
import os
import sys
module_path = os.path.join(os.getcwd(), '..')
sys.path.append(module_path)

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "deepset/roberta-large-squad2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name).to(device)

In [5]:
from lrp_engine import LRPEngine, checkpoint_hook
from lrp_engine.lrp_graph import make_graph

In [6]:
lrp = LRPEngine(topk=1, use_gamma=True)

In [7]:
context = "Welcome to the final examination for this term's offering of CS100. Please remove all headphones and earbuds, as well as hats and hoods. Place your bag under your desk so that it does not block the aisle. You are permitted writing instruments, a clear water bottle, and any aids listed on the front of your booklet. The exam will be 150 minutes in duration. You may now begin."
question = "What is this?"

In [9]:
input_ids = tokenizer(question, context, return_tensors="pt")["input_ids"]

In [10]:
output = model(input_ids.to(device))

In [13]:
g = make_graph(output.start_logits)
g[2]

{'AccumulateGrad',
 'AddBackward0',
 'CloneBackward0',
 'CopySlices',
 'EmbeddingBackward0',
 'GeluBackward0',
 'MmBackward0',
 'NativeLayerNormBackward0',
 'PermuteBackward0',
 'ScaledDotProductEfficientAttentionBackward0',
 'SplitBackward0',
 'SqueezeBackward1',
 'TBackward0',
 'TransposeBackward0',
 'ViewBackward0'}

In [11]:
start = torch.argmax(output.start_logits)
end = torch.argmax(output.end_logits) + 1

answer_tokens = input_ids[0][start:end]
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
print(answer)

 final examination


# Dummy Test (Don't run if you want to run the full SQuADv2 evaluation below)

In [12]:
checkpoint_vals, param_vals = lrp.run((output.start_logits, output.end_logits))

In [13]:
checkpoint_vals1, param_vals1 = lrp.run((output.start_logits, output.end_logits))

In [14]:
# Check results diff between 1st pass and 2nd pass on same input (should be ~0)
[
    ((p1 - p0)**2).sum() for p1, p0 in zip(param_vals1, param_vals)
]

[tensor(0., device='cuda:0'),
 tensor(0., device='cuda:0'),
 tensor(0., device='cuda:0')]

In [15]:
# Top 5 tokens from LRP
lrp_answer_ids = input_ids[0][param_vals[1].flatten().topk(k=5).indices.cpu()]
print(tokenizer.decode(lrp_answer_ids))

 well under listed 150s


# SQuADv2 Evaluation

In [7]:
lrp.starting_relevance=None
lrp.topk=3

In [8]:
from datasets import load_dataset

dataset = load_dataset("squad_v2")

In [9]:
from tqdm import tqdm

results = []
top1_label_hits = 0
top1_model_hits = 0
total_examples = 0
total_intersect = 0
total_union = 0

for example in tqdm(dataset["validation"]):
    question = example["question"]
    context = example["context"]
    answers = example["answers"]["text"]

    if not answers:
        continue

    input_ids = tokenizer(context + " " + question, return_tensors="pt")["input_ids"].to(device)
    if input_ids.shape[-1] > 512:
        continue

    input_ids = tokenizer(question, context, return_tensors="pt")["input_ids"]
    output = model(input_ids.to(device))

    start = torch.argmax(output.start_logits[:,1:]) + 1
    end = torch.argmax(output.end_logits[:,1:]) + 2
    if start > end:
        continue
    model_answer = tokenizer.decode(input_ids[0][start:end], skip_special_tokens=True)
    if model_answer == "":
        continue
    # output = run_flan_encode_decode(input_ids)

    lrp.topk = end - start
    if (end - start) > output.start_logits.shape[-1]:
        print(start, end)
        print(output.start_logits.shape)
        break

    lrp_input1, lrp_input2 = output.start_logits, output.end_logits
    lrp_input1[0][0] = 0
    lrp_input2[0][0] = 0
    checkpoint_vals, param_vals = lrp.run((lrp_input1, lrp_input2))

    lrp_max = param_vals[1].flatten()[1:].argmax() + 1
    lrp_top_token = tokenizer.convert_ids_to_tokens([input_ids[0][lrp_max]])[0].strip().replace(chr(9601), "")

    # Do model answer-based accuracy, i.e. is the attribution aligned with the model prediction
    if start <= lrp_max <= end:
        top1_model_hits += 1

    # Do IoU with the model prediction
    intersect = 0
    union = lrp.topk
    for top_ind in param_vals[1].flatten()[1:].topk(lrp.topk).indices:
        top_ind = top_ind + 1
        if start <= top_ind <= end:
            intersect += 1
        else:
            union += 1
    total_intersect += intersect
    total_union += union

    # Do label-based accuracy
    if any(lrp_top_token in ans for ans in answers):
        # Is the attribution aligned with the ground truth label
        top1_label_hits += 1
    total_examples += 1
    if not (total_examples % 100):
        print(top1_model_hits, top1_label_hits, total_examples)
        print(total_intersect / total_union)
    # lrp_top5 = param_vals[2].flatten().topk(k=5)
    # lrp_top5_tokens = tokenizer.decode(input_ids[0][lrp_top5.indices.cpu()])
    # results.append({
    #     "example": example,
    #     "lrp_top5_tokens": lrp_top5_tokens,
    #     "lrp_top5_relevances": lrp_top5.values.cpu(),
    #     "is_impossible": len(answers) == 0
    # })

  2%|█▍                                                                            | 217/11873 [00:27<18:04, 10.74it/s]

81 23 100
tensor(0.5099, device='cuda:0')


  4%|██▊                                                                           | 429/11873 [00:52<20:07,  9.48it/s]

167 28 200
tensor(0.5025, device='cuda:0')


  5%|████▏                                                                         | 643/11873 [01:18<20:53,  8.96it/s]

253 41 300
tensor(0.4810, device='cuda:0')


  7%|█████▍                                                                        | 829/11873 [01:43<28:06,  6.55it/s]

345 56 400
tensor(0.5099, device='cuda:0')


  9%|██████▋                                                                      | 1022/11873 [02:09<26:42,  6.77it/s]

430 69 500
tensor(0.5234, device='cuda:0')


 10%|███████▊                                                                     | 1214/11873 [02:35<27:11,  6.53it/s]

513 78 600
tensor(0.5235, device='cuda:0')


 12%|█████████▏                                                                   | 1419/11873 [03:01<17:34,  9.91it/s]

602 97 700
tensor(0.5179, device='cuda:0')


 14%|██████████▍                                                                  | 1614/11873 [03:28<35:09,  4.86it/s]

682 119 800
tensor(0.5045, device='cuda:0')


 15%|███████████▊                                                                 | 1812/11873 [03:56<18:00,  9.31it/s]

764 137 900
tensor(0.4990, device='cuda:0')


 17%|█████████████                                                                | 2016/11873 [04:23<23:05,  7.12it/s]

855 147 1000
tensor(0.5067, device='cuda:0')


 19%|██████████████▍                                                              | 2227/11873 [04:49<16:45,  9.59it/s]

945 160 1100
tensor(0.5122, device='cuda:0')


 20%|███████████████▍                                                             | 2390/11873 [05:16<20:50,  7.59it/s]

1023 176 1200
tensor(0.5084, device='cuda:0')


 22%|████████████████▋                                                            | 2570/11873 [05:43<23:16,  6.66it/s]

1101 189 1300
tensor(0.5029, device='cuda:0')


 23%|██████████████████                                                           | 2790/11873 [06:09<18:28,  8.19it/s]

1189 201 1400
tensor(0.4985, device='cuda:0')


 25%|███████████████████▍                                                         | 2991/11873 [06:36<23:22,  6.33it/s]

1280 208 1500
tensor(0.4958, device='cuda:0')


 26%|████████████████████                                                         | 3100/11873 [07:00<47:17,  3.09it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (586 > 512). Running this sequence through the model will result in indexing errors
 27%|████████████████████▌                                                        | 3177/11873 [07:08<16:06,  9.00it/s]

1361 213 1600
tensor(0.4916, device='cuda:0')


 29%|██████████████████████                                                       | 3411/11873 [07:39<23:24,  6.02it/s]

1447 223 1700
tensor(0.4917, device='cuda:0')


 30%|███████████████████████                                                      | 3555/11873 [08:05<18:38,  7.44it/s]

1534 246 1800
tensor(0.4913, device='cuda:0')


 31%|████████████████████████▏                                                    | 3728/11873 [08:33<27:19,  4.97it/s]

1608 262 1900
tensor(0.4885, device='cuda:0')


 33%|█████████████████████████▌                                                   | 3933/11873 [09:01<25:02,  5.28it/s]

1691 286 2000
tensor(0.4824, device='cuda:0')


 35%|██████████████████████████▊                                                  | 4137/11873 [09:28<25:28,  5.06it/s]

1775 300 2100
tensor(0.4851, device='cuda:0')


 37%|████████████████████████████▊                                                | 4438/11873 [09:57<05:42, 21.71it/s]

1836 316 2200
tensor(0.4679, device='cuda:0')


 40%|██████████████████████████████▌                                              | 4714/11873 [10:24<19:39,  6.07it/s]

1906 327 2300
tensor(0.4564, device='cuda:0')


 42%|████████████████████████████████▏                                            | 4958/11873 [10:51<07:04, 16.28it/s]

1986 343 2400
tensor(0.4557, device='cuda:0')


 44%|██████████████████████████████████                                           | 5245/11873 [11:17<14:35,  7.57it/s]

2071 362 2500
tensor(0.4582, device='cuda:0')


 46%|███████████████████████████████████▎                                         | 5439/11873 [11:45<11:28,  9.35it/s]

2149 369 2600
tensor(0.4552, device='cuda:0')


 47%|████████████████████████████████████▏                                        | 5589/11873 [12:11<25:12,  4.16it/s]

2228 386 2700
tensor(0.4535, device='cuda:0')


 49%|█████████████████████████████████████▍                                       | 5776/11873 [12:40<09:02, 11.23it/s]

2301 392 2800
tensor(0.4465, device='cuda:0')


 51%|██████████████████████████████████████▉                                      | 6012/11873 [13:06<12:15,  7.97it/s]

2391 402 2900
tensor(0.4496, device='cuda:0')


 52%|████████████████████████████████████████▎                                    | 6223/11873 [13:41<26:33,  3.54it/s]

2483 422 3000
tensor(0.4522, device='cuda:0')


 54%|█████████████████████████████████████████▉                                   | 6458/11873 [14:27<13:16,  6.80it/s]

2566 444 3100
tensor(0.4547, device='cuda:0')


 56%|███████████████████████████████████████████▏                                 | 6661/11873 [15:13<26:14,  3.31it/s]

2651 454 3200
tensor(0.4558, device='cuda:0')


 57%|████████████████████████████████████████████▎                                | 6826/11873 [15:59<35:05,  2.40it/s]

2735 464 3300
tensor(0.4576, device='cuda:0')


 59%|█████████████████████████████████████████████▍                               | 7000/11873 [16:48<17:50,  4.55it/s]

2825 476 3400
tensor(0.4592, device='cuda:0')


 61%|██████████████████████████████████████████████▋                              | 7200/11873 [17:35<22:19,  3.49it/s]

2910 478 3500
tensor(0.4604, device='cuda:0')


 62%|████████████████████████████████████████████████                             | 7403/11873 [18:21<21:22,  3.48it/s]

3002 491 3600
tensor(0.4622, device='cuda:0')


 64%|█████████████████████████████████████████████████▎                           | 7611/11873 [19:08<11:39,  6.10it/s]

3090 515 3700
tensor(0.4632, device='cuda:0')


 66%|██████████████████████████████████████████████████▋                          | 7807/11873 [19:54<13:52,  4.88it/s]

3168 538 3800
tensor(0.4643, device='cuda:0')


 67%|███████████████████████████████████████████████████▉                         | 8002/11873 [20:40<15:42,  4.11it/s]

3250 561 3900
tensor(0.4638, device='cuda:0')


 69%|█████████████████████████████████████████████████████▏                       | 8210/11873 [21:26<11:16,  5.41it/s]

3332 591 4000
tensor(0.4650, device='cuda:0')


 71%|██████████████████████████████████████████████████████▋                      | 8429/11873 [22:13<09:37,  5.97it/s]

3412 618 4100
tensor(0.4664, device='cuda:0')


 73%|████████████████████████████████████████████████████████                     | 8639/11873 [22:59<08:52,  6.07it/s]

3496 634 4200
tensor(0.4661, device='cuda:0')


 75%|█████████████████████████████████████████████████████████▍                   | 8848/11873 [23:46<08:19,  6.06it/s]

3578 643 4300
tensor(0.4661, device='cuda:0')


 76%|██████████████████████████████████████████████████████████▌                  | 9037/11873 [24:32<19:41,  2.40it/s]

3656 659 4400
tensor(0.4660, device='cuda:0')


 77%|███████████████████████████████████████████████████████████▋                 | 9195/11873 [25:19<09:31,  4.68it/s]

3746 670 4500
tensor(0.4673, device='cuda:0')


 79%|████████████████████████████████████████████████████████████▊                | 9379/11873 [26:05<07:41,  5.40it/s]

3832 699 4600
tensor(0.4675, device='cuda:0')


 81%|██████████████████████████████████████████████████████████████               | 9572/11873 [26:51<11:57,  3.21it/s]

3922 709 4700
tensor(0.4670, device='cuda:0')


 82%|███████████████████████████████████████████████████████████████▎             | 9765/11873 [27:38<11:55,  2.95it/s]

4006 716 4800
tensor(0.4669, device='cuda:0')


 84%|████████████████████████████████████████████████████████████████▌            | 9960/11873 [28:24<05:11,  6.15it/s]

4095 724 4900
tensor(0.4669, device='cuda:0')


 85%|████████████████████████████████████████████████████████████████▉           | 10149/11873 [29:10<08:53,  3.23it/s]

4182 732 5000
tensor(0.4675, device='cuda:0')


 87%|██████████████████████████████████████████████████████████████████▎         | 10354/11873 [29:57<05:09,  4.91it/s]

4266 738 5100
tensor(0.4674, device='cuda:0')


 89%|███████████████████████████████████████████████████████████████████▌        | 10550/11873 [30:43<04:46,  4.61it/s]

4348 750 5200
tensor(0.4672, device='cuda:0')


 90%|████████████████████████████████████████████████████████████████████▊       | 10741/11873 [31:30<05:22,  3.51it/s]

4434 769 5300
tensor(0.4681, device='cuda:0')


 92%|██████████████████████████████████████████████████████████████████████      | 10947/11873 [32:16<03:21,  4.60it/s]

4517 783 5400
tensor(0.4687, device='cuda:0')


 94%|███████████████████████████████████████████████████████████████████████▍    | 11157/11873 [33:02<01:49,  6.53it/s]

4599 797 5500
tensor(0.4692, device='cuda:0')


 96%|████████████████████████████████████████████████████████████████████████▉   | 11394/11873 [33:48<01:32,  5.20it/s]

4672 808 5600
tensor(0.4667, device='cuda:0')


 98%|██████████████████████████████████████████████████████████████████████████▎ | 11603/11873 [34:35<01:18,  3.43it/s]

4753 821 5700
tensor(0.4653, device='cuda:0')


 99%|███████████████████████████████████████████████████████████████████████████▌| 11798/11873 [35:21<00:13,  5.67it/s]

4843 831 5800
tensor(0.4658, device='cuda:0')


100%|████████████████████████████████████████████████████████████████████████████| 11873/11873 [35:40<00:00,  5.55it/s]


In [None]:
from tqdm import tqdm

results = []

for example in tqdm(dataset["validation"].select(range(100))):
    question = example["question"]
    context = example["context"]
    answers = example["answers"]["text"]

    input_ids = tokenizer(question, context, return_tensors="pt")["input_ids"]
    output = model(input_ids.to(device))

    start = torch.argmax(output.start_logits)
    end = torch.argmax(output.end_logits) + 1
    model_answer = tokenizer.decode(input_ids[0][start:end], skip_special_tokens=True)

    checkpoint_vals, param_vals = lrp.run((output.start_logits, output.end_logits))

    lrp_top5 = param_vals[1].flatten().topk(k=5)
    lrp_top5_tokens = tokenizer.decode(input_ids[0][lrp_top5.indices.cpu()], skip_special_tokens=True)
    lrp_start_end = param_vals[1].flatten().topk(k=2).indices.cpu().sort()
    lrp_start = lrp_start_end[0][0]
    if lrp_start == 0:
        lrp_start = lrp_start_end[0][1]
    lrp_end = lrp_start_end[0][-1]
    lrp_answer_ids = input_ids[0][lrp_start:lrp_end + 1]
    lrp_answer = tokenizer.decode(lrp_answer_ids, skip_special_tokens=True)

    results.append({
        "example": example,
        "model_answer": model_answer,
        "lrp_answer": lrp_answer,
        "lrp_top5_tokens": lrp_top5_tokens,
        "lrp_top5_relevances": lrp_top5.values.cpu(),
        "is_impossible": len(answers) == 0
    })

In [None]:
for res in results:
    print("Q: ", res["example"]["question"])
    print("A (labels): ", res["example"]["answers"]["text"])
    print("Model: ", res["model_answer"])
    print("LRP: ", res["lrp_answer"])
    print("LRP top5: ", res["lrp_top5_tokens"])
    print("LRP top5 attributions: ", res["lrp_top5_relevances"], '\n')