In [1]:
# Re-import edited project modules (e.g. the diffmask package) automatically
# before each cell executes, so library edits are picked up without restarting.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import argparse
import torch

import matplotlib.pyplot as plt
from tqdm.auto import tqdm, trange

from diffmask.models.question_answering_squad import BertQuestionAnsweringSquad

from diffmask.attributions.integrated_gradient import squad_bert_integrated_gradient as integrated_gradient
from diffmask.attributions.schulz import bert_hidden_states_statistics, squad_bert_schulz_explainer
from diffmask.attributions.guan import squad_bert_guan_explainer

from diffmask.utils.plot import print_attributions

In [3]:
if __name__ == "__main__":
    # Experiment configuration. parse_known_args() returns unrecognised
    # command-line arguments instead of failing on them, so extra arguments
    # (such as those an IPython kernel may be launched with) are tolerated.
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpus", type=str, default="0")
    parser.add_argument(
        "--model",
        type=str,
        default="bert-large-uncased-whole-word-masking-finetuned-squad",
    )
    parser.add_argument(
        "--train_filename",
        type=str,
        default="./datasets/squad/train-v1.1_bert-large-uncased-whole-word-masking-finetuned-squad.json",
    )
    parser.add_argument(
        "--val_filename",
        type=str,
        default="./datasets/squad/dev-v1.1_bert-large-uncased-whole-word-masking-finetuned-squad.json",
    )
    parser.add_argument("--batch_size", type=int, default=12)
    parser.add_argument("--seed", type=int, default=0)
    hparams = parser.parse_known_args()[0]

    # Restrict the visible GPUs first, then seed torch's RNGs.
    os.environ["CUDA_VISIBLE_DEVICES"] = hparams.gpus
    torch.manual_seed(hparams.seed)

In [4]:
# Instantiate the SQuAD question-answering model from the parsed hparams and
# place it on the first visible GPU.
device = "cuda:0"
model = BertQuestionAnsweringSquad(hparams)
model = model.train().to(device)
# Presumably populates model.val_dataset / model.val_dataset_orig used below
# (the progress bars it prints match the SQuAD train/dev sizes) — confirm.
model.prepare_data()
# NOTE(review): if freeze() is LightningModule.freeze it disables gradients and
# switches to eval mode, which would make the .train() call above redundant.
model.freeze()

HBox(children=(FloatProgress(value=0.0, max=87599.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10570.0), HTML(value='')))




In [5]:
# Hand-written question/context pair encoded for the model.
# NOTE: the next cell rebinds `inputs_dict` and `tokens` from the validation
# set, so on a top-to-bottom run the values built here are superseded.
question = "Where did the Broncos practice for the Super Bowl ?"
context = (
    "The Panthers used the San Jose State practice facility and "
    "stayed at the San Jose Marriott . The Broncos practiced at Stanford University "
    "and stayed at the Santa Clara Marriott ."
)

inputs_dict = {
    name: tensor.to(device)
    for name, tensor in model.tokenizer.encode_plus(
        question, context, max_length=384, pad_to_max_length=True, return_tensors="pt"
    ).items()
}
# The model is called with the padding mask under the key "mask".
inputs_dict["mask"] = inputs_dict.pop("attention_mask")
# Hard-coded answer-span token indices for this exact input — presumably the
# "stanford university" tokens; confirm if the text above changes.
inputs_dict["start_positions"] = torch.tensor([33], device=device)
inputs_dict["end_positions"] = torch.tensor([34], device=device)

question = model.tokenizer.tokenize(question)
context = model.tokenizer.tokenize(context)
tokens = ["[CLS]", *question, "[SEP]", *context, "[SEP]"]

In [6]:
# Select validation example `i` and build a batch of one from it. This rebinds
# the `inputs_dict` and `tokens` built by the previous cell.
i = 410
# Index the dataset once; the field order (input_ids, mask, token_type_ids,
# answer positions) is assumed from how the fields are used below.
example = model.val_dataset[i]
inputs_dict = {
    "input_ids": example[0].unsqueeze(0).to(device),
    "mask": example[1].unsqueeze(0).to(device),
    "token_type_ids": example[2].unsqueeze(0).to(device),
    "start_positions": example[3].unsqueeze(0).to(device)[..., 0],
    # NOTE(review): end = start + 1, i.e. a two-token answer is assumed. If
    # example[3] stores [start, end] pairs this was probably meant to be [..., 1].
    "end_positions": (example[3].unsqueeze(0).to(device) + 1)[..., 0],
}

# Tokens exactly as stored in the original dataset (presumably word-pieces).
# Kept under their own name for cross-checking alignment with the display
# tokens below; list(...) over all values is built only once.
example_orig = list(model.val_dataset_orig.values())[i]
tokens_dataset = (
    ["[CLS]"]
    + example_orig["question"]
    + ["[SEP]"]
    + example_orig["context"]
    + ["[SEP]"]
)

# Cased, human-readable tokens used for display by print_attributions.
# .split() splits on any whitespace, including newlines. The previous
# replace("\n", "").split(" ") glued words together whenever an editor
# stripped the trailing spaces inside this literal.
tokens = """
[CLS] Where did the Broncos practice for the Super Bowl ? 
[SEP] The Panthers used the San Jose State practice 
facility and stayed at the San Jose Marriott . The 
Broncos practiced at Stanford University and 
stayed at the Santa Clara Marriott . [SEP]
""".split()

In [7]:
# Integrated-gradient attributions at hidden state 0 (steps=500).
attributions_ig = integrated_gradient(
    model, inputs_dict, hidden_state_idx=0, steps=500
)
# Sum over the last axis (presumably the hidden dimension), take the
# magnitude, and keep the first example's non-padding positions on the CPU.
attributions_ig = attributions_ig.sum(dim=-1).abs()
attributions_ig = attributions_ig[0, : inputs_dict["mask"].sum()].cpu()

HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))




In [8]:
# Render the integrated-gradient scores over the display tokens.
print_attributions(tokens, attributions_ig, special=False)

[48;2;152;153;0mWhere[49m [48;2;244;244;0mdid[49m [48;2;255;255;239mthe[49m [48;2;255;255;168mBroncos[49m [48;2;254;255;142mpractice[49m [48;2;254;255;137mfor[49m [48;2;254;254;204mthe[49m [48;2;255;255;168mSuper[49m [48;2;254;255;132mBowl[49m [48;2;158;158;0m?[49m [48;2;255;255;234mThe[49m [48;2;254;254;76mPanthers[49m [48;2;254;254;214mused[49m [48;2;255;255;255mthe[49m [48;2;255;255;188mSan[49m [48;2;255;255;188mJose[49m [48;2;255;255;168mState[49m [48;2;254;255;137mpractice[49m [48;2;255;255;224mfacility[49m [48;2;255;255;234mand[49m [48;2;254;255;122mstayed[49m [48;2;254;255;153mat[49m [48;2;255;255;239mthe[49m [48;2;255;255;178mSan[49m [48;2;255;255;229mJose[49m [48;2;254;254;214mMarriott[49m [48;2;255;255;234m.[49m [48;2;255;255;239mThe[49m [48;2;193;193;0mBroncos[49m [48;2;254;255;112mpracticed[49m [48;2;254;254;66mat[49m [48;2;254;254;173mStanford[49m [48;2;255;255;209mUniversity[49m [48;2;255;255;249mand[49m 

In [9]:
# Location/scale statistics of BERT hidden states; they are passed to the
# Schulz explainer below as all_q_z_loc / all_q_z_scale.
all_q_z_loc, all_q_z_scale = bert_hidden_states_statistics(
    model,
    input_only=True,
)

HBox(children=(FloatProgress(value=0.0, max=7226.0), HTML(value='')))




In [10]:
# Schulz et al.-style attributions at hidden state 0, using the hidden-state
# statistics computed above (steps=500, lr=0.1, la=10 — `la` presumably
# weights the regulariser; confirm in diffmask.attributions.schulz).
attributions_schulz = squad_bert_schulz_explainer(
    model,
    inputs_dict,
    hidden_state_idx=0,
    all_q_z_loc=all_q_z_loc,
    all_q_z_scale=all_q_z_scale,
    steps=500,
    lr=0.1,
    la=10,
)
# Keep the first example's non-padding positions, moved to the CPU.
attributions_schulz = attributions_schulz[0][: inputs_dict["mask"].sum()].cpu()

HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))




In [11]:
# Render the Schulz-explainer scores over the display tokens.
print_attributions(tokens, attributions_schulz, special=False)

[48;2;254;255;20mWhere[49m [48;2;254;255;15mdid[49m [48;2;255;255;244mthe[49m [48;2;224;224;0mBroncos[49m [48;2;188;188;0mpractice[49m [48;2;254;255;96mfor[49m [48;2;254;255;81mthe[49m [48;2;255;255;224mSuper[49m [48;2;255;255;219mBowl[49m [48;2;224;224;0m?[49m [48;2;255;255;224mThe[49m [48;2;255;255;224mPanthers[49m [48;2;255;255;244mused[49m [48;2;255;255;239mthe[49m [48;2;255;255;234mSan[49m [48;2;255;255;219mJose[49m [48;2;255;255;229mState[49m [48;2;255;255;239mpractice[49m [48;2;255;255;239mfacility[49m [48;2;255;255;229mand[49m [48;2;255;255;229mstayed[49m [48;2;255;255;224mat[49m [48;2;255;255;229mthe[49m [48;2;254;254;214mSan[49m [48;2;255;255;229mJose[49m [48;2;255;255;234mMarriott[49m [48;2;255;255;209m.[49m [48;2;255;255;188mThe[49m [48;2;219;219;0mBroncos[49m [48;2;254;255;153mpracticed[49m [48;2;244;244;0mat[49m [48;2;219;219;0mStanford[49m [48;2;178;178;0mUniversity[49m [48;2;254;254;35mand[49m [48;2;15

In [12]:
# Guan et al.-style attributions at hidden state 0 (steps=500, lr=0.1, la=5).
attributions_guan = squad_bert_guan_explainer(
    model, inputs_dict, hidden_state_idx=0, steps=500, lr=0.1, la=5
)
# Keep the non-padding positions and take the reciprocal. Presumably the
# explainer returns a per-token noise scale (larger = less relevant), so the
# reciprocal turns it into a relevance score — confirm.
attributions_guan = 1 / attributions_guan[0][: inputs_dict["mask"].sum()].cpu()

HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))




In [13]:
# Render the Guan-explainer scores over the display tokens.
print_attributions(tokens, attributions_guan, special=False)

[48;2;255;255;255mWhere[49m [48;2;254;255;25mdid[49m [48;2;254;255;137mthe[49m [48;2;254;255;112mBroncos[49m [48;2;254;255;71mpractice[49m [48;2;255;255;255mfor[49m [48;2;254;255;132mthe[49m [48;2;255;255;255mSuper[49m [48;2;255;255;255mBowl[49m [48;2;254;255;122m?[49m [48;2;254;254;76mThe[49m [48;2;254;255;5mPanthers[49m [48;2;255;255;255mused[49m [48;2;255;255;219mthe[49m [48;2;254;255;142mSan[49m [48;2;255;255;255mJose[49m [48;2;254;255;158mState[49m [48;2;254;255;101mpractice[49m [48;2;254;255;112mfacility[49m [48;2;255;255;255mand[49m [48;2;255;255;255mstayed[49m [48;2;254;255;132mat[49m [48;2;254;254;183mthe[49m [48;2;255;255;255mSan[49m [48;2;255;255;255mJose[49m [48;2;255;255;255mMarriott[49m [48;2;255;255;255m.[49m [48;2;229;229;0mThe[49m [48;2;254;255;81mBroncos[49m [48;2;244;244;0mpracticed[49m [48;2;254;255;127mat[49m [48;2;254;255;142mStanford[49m [48;2;255;255;147mUniversity[49m [48;2;254;255;112mand[49m 

In [14]:
def gen_attention_masks(n, max_length):
    """Return every 0/1 list of length ``max_length`` with exactly ``n`` ones.

    Masks are in ascending lexicographic order, so the mask whose ones are
    furthest to the right comes first, e.g.::

        gen_attention_masks(1, 3) == [[0, 0, 1], [0, 1, 0], [1, 0, 0]]

    Args:
        n: number of positions set to 1.
        max_length: length of every mask.

    Returns:
        A list of ``C(max_length, n)`` lists of ints. It is empty when
        ``n > max_length`` because no such mask exists (the previous recursive
        version recursed forever in that case).

    Raises:
        ValueError: if ``n`` is negative.
    """
    from itertools import combinations

    if n < 0:
        raise ValueError("n must be non-negative, got {}".format(n))

    masks = []
    # combinations() yields index sets in lexicographic order, which is the
    # reverse (descending) order of the corresponding 0/1 vectors. Reversing
    # restores the ascending order the recursive version produced.
    for ones in reversed(list(combinations(range(max_length), n))):
        mask = [0] * max_length
        for idx in ones:
            mask[idx] = 1
        masks.append(mask)
    return masks

In [15]:
# Exhaustive search for the smallest set of input tokens that, with every other
# token's word embedding zeroed, still makes the model predict the same answer
# span. The answer-span tokens themselves are always kept.
MAX_SUBSET_SIZE = 4      # largest number of kept tokens to try
SEARCH_BATCH_SIZE = 128  # candidate masks per forward pass

valid_positions = inputs_dict["mask"].bool()

with torch.no_grad():
    # One reference forward pass (the old cell ran the model twice).
    logits_start_orig, logits_end_orig = model(**inputs_dict)
    # Padding can never be the predicted position. Mask it once, outside the loop.
    logits_start_orig, logits_end_orig = (
        torch.where(valid_positions, e, torch.full_like(e, -float("inf")))
        for e in (logits_start_orig, logits_end_orig)
    )
    # Word embeddings do not depend on the candidate mask, so compute them once.
    word_embeds = model.net.bert.embeddings.word_embeddings(inputs_dict["input_ids"])

# The reference span comes from the same padding-masked logits that candidates
# are compared against, so the two cannot disagree.
target_start = logits_start_orig.argmax(-1)
target_end = logits_end_orig.argmax(-1)
answer = [target_start.item(), target_end.item()]

n_tokens = int(inputs_dict["mask"].sum())    # non-padding positions (replaces hard-coded 44)
seq_len = inputs_dict["input_ids"].shape[-1]  # padded length (replaces hard-coded 384)
span_len = answer[1] - answer[0] + 1
assert span_len > 0 and answer[1] < n_tokens, "predicted answer span is empty or in padding"

solutions = []
for n in range(span_len, min(MAX_SUBSET_SIZE, n_tokens) + 1):
    print("Trying N={}".format(n))
    # Enumerate masks over the positions outside the answer span only, then
    # splice a block of ones in at the span. This yields the same masks, in the
    # same order, as enumerating all C(n_tokens, n) masks and filtering for
    # those that keep the span, with far fewer candidates.
    free = torch.tensor(
        gen_attention_masks(n=n - span_len, max_length=n_tokens - span_len),
        dtype=torch.long,
    )
    all_mask = torch.cat(
        [free[:, : answer[0]], free.new_ones(len(free), span_len), free[:, answer[0]:]],
        dim=1,
    )

    solutions = []
    with torch.no_grad():
        progress = tqdm(all_mask.split(SEARCH_BATCH_SIZE))
        for mask in progress:
            # Zero the word embeddings of dropped tokens. Padding stays dropped.
            keep = torch.nn.functional.pad(mask.to(device), (0, seq_len - n_tokens)).unsqueeze(-1)
            logits_start, logits_end = model.net(
                inputs_embeds=word_embeds * keep,
                attention_mask=inputs_dict["mask"],
                token_type_ids=inputs_dict["token_type_ids"],
            )
            logits_start, logits_end = (
                torch.where(valid_positions, e, torch.full_like(e, -float("inf")))
                for e in (logits_start, logits_end)
            )
            same_answer = (logits_start.argmax(-1) == target_start) & (
                logits_end.argmax(-1) == target_end
            )
            # `mask` lives on the CPU, so index it with a CPU boolean tensor.
            solutions += list(mask[same_answer.cpu()])
            progress.set_postfix(solutions=len(solutions))

    if not solutions:
        print("No solutions with N={}".format(n))
    else:
        break

Trying N=2


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


No solutions with N=2
Trying N=3


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


No solutions with N=3
Trying N=4


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))




In [18]:
# Render each minimal kept-token mask found by the search over the display tokens.
for kept in solutions:
    print_attributions(tokens, kept[: inputs_dict["mask"].sum()].cpu(), special=False)

[48;2;255;255;255mWhere[49m [48;2;255;255;255mdid[49m [48;2;255;255;255mthe[49m [48;2;255;255;255mBroncos[49m [48;2;255;255;255mpractice[49m [48;2;255;255;255mfor[49m [48;2;255;255;255mthe[49m [48;2;255;255;255mSuper[49m [48;2;255;255;255mBowl[49m [48;2;255;255;255m?[49m [48;2;255;255;255mThe[49m [48;2;255;255;255mPanthers[49m [48;2;255;255;255mused[49m [48;2;255;255;255mthe[49m [48;2;255;255;255mSan[49m [48;2;255;255;255mJose[49m [48;2;255;255;255mState[49m [48;2;255;255;255mpractice[49m [48;2;255;255;255mfacility[49m [48;2;255;255;255mand[49m [48;2;255;255;255mstayed[49m [48;2;255;255;255mat[49m [48;2;255;255;255mthe[49m [48;2;255;255;255mSan[49m [48;2;255;255;255mJose[49m [48;2;255;255;255mMarriott[49m [48;2;255;255;255m.[49m [48;2;255;255;255mThe[49m [48;2;255;255;255mBroncos[49m [48;2;255;255;255mpracticed[49m [48;2;255;255;255mat[49m [48;2;152;153;0mStanford[49m [48;2;152;153;0mUniversity[49m [48;2;255;255;255mand