### Set GPU ID

In [1]:
print('hi')

import torch

# Physical GPU used by every later `.to("cuda")` call; change here to run elsewhere.
GPU_ID = 7
torch.cuda.set_device(GPU_ID)

hi


  from .autonotebook import tqdm as notebook_tqdm


### Load HANS Dataset

In [2]:
from datasets import load_dataset
import json
import os
from tqdm import tqdm

# HANS (Heuristic Analysis for NLI Systems) from the Hugging Face hub; the
# splits are accessed by name below (e.g. dataset["validation"]).
dataset = load_dataset(path='hans')

Found cached dataset hans (/home/cwkang/.cache/huggingface/datasets/hans/plain_text/1.0.0/452e93cf5383f5ae39088254215b517d0da98ccaaf0af8f7ab04d8f23f67dbd9)
100%|██████████| 2/2 [00:00<00:00, 453.12it/s]


### Preprocess Function

In [3]:
def preprocess_function(examples, tokenizer, max_seq_length=128,
                        sentence1_key="premise", sentence2_key="hypothesis"):
    """Tokenize a batch of sentence pairs for NLI classification.

    Written for ``Dataset.map(..., batched=True)``, where every value in
    ``examples`` is a list with one entry per example.

    Args:
        examples: Batch dict holding the text column(s) named by
            ``sentence1_key`` / ``sentence2_key`` and, optionally, a
            ``"label"`` column of integer class ids.
        tokenizer: Callable tokenizer exposing ``model_max_length``.
        max_seq_length: Requested maximum token length; clipped to
            ``tokenizer.model_max_length``.
        sentence1_key: Column holding the first sentence.
        sentence2_key: Column holding the second sentence, or ``None`` to
            tokenize single sentences.

    Returns:
        The tokenizer output (truncated, not padded). When the batch has a
        ``"label"`` column, a ``"label"`` entry is added: ids 0/1/2 are kept
        as-is and -1 (unlabeled) is passed through.

    Raises:
        KeyError: if a label is not one of -1, 0, 1, 2.
    """
    max_length = min(max_seq_length, tokenizer.model_max_length)

    if sentence2_key is None:
        texts = (examples[sentence1_key],)
    else:
        texts = (examples[sentence1_key], examples[sentence2_key])
    # No padding here: the evaluation loops pad each sample later with tokenizer.pad.
    result = tokenizer(*texts, padding=False, max_length=max_length, truncation=True)

    if "label" in examples:
        # Identity map over the three class ids, kept as a dict so an unexpected
        # label raises KeyError instead of silently passing through.
        label_to_id = {0: 0, 1: 1, 2: 2}
        result["label"] = [
            label_to_id[label] if label != -1 else -1 for label in examples["label"]
        ]
    return result

### Load Model

In [9]:
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Fine-tuned MNLI checkpoint to evaluate on HANS. A local checkpoint directory
# (previously "results/mnli/bert-base-uncased") can be substituted here.
model_name_or_path = "ishan/bert-base-uncased-mnli"

config = AutoConfig.from_pretrained(
    model_name_or_path,
    num_labels=3,
    finetuning_task="mnli",
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name_or_path,
    # Only TensorFlow checkpoints (paths containing ".ckpt") need conversion.
    from_tf=".ckpt" in model_name_or_path,
    config=config,
    ignore_mismatched_sizes=False,
)

# Inference only: put the model on the selected GPU and make eval mode
# (dropout disabled) explicit rather than relying on from_pretrained's default.
model = model.to("cuda").eval()

### Prepare Dataset

In [10]:
from functools import partial

# Bind the tokenizer so Dataset.map can call the preprocessor with just the batch.
tokenize_batch = partial(preprocess_function, tokenizer=tokenizer)

# Original HANS columns are intentionally kept: the evaluation loops below read
# sample["template"] and the gold label next to the token ids.
eval_dataset = dataset["validation"].map(
    tokenize_batch,
    desc="Running tokenizer on dataset",
    batched=True,
)

Loading cached processed dataset at /home/cwkang/.cache/huggingface/datasets/hans/plain_text/1.0.0/452e93cf5383f5ae39088254215b517d0da98ccaaf0af8f7ab04d8f23f67dbd9/cache-e02ed34140c2f999.arrow


### Run Model on HANS Dataset

In [12]:
from collections import defaultdict

# Class order of the classifier head: the argmax index is looked up here.
# (The order ['entailment', 'neutral', 'contradiction'] was tried before.)
mnli_label_dict = ["contradiction", "entailment", "neutral"]
# HANS gold labels: 0 = entailment, 1 = non-entailment.
hans_label_dict = ["entailment", "nonentailment"]

e_correct = 0
n_correct = 0
e_total = 0
n_total = 0

# attention_dict: "<template>_<predicted MNLI label>" -> per-layer list of attention
# tensors, summed over samples here and averaged after the loop.
# key_count: number of samples accumulated under each key.
attention_dict = {}
key_count = defaultdict(int)

# Pure inference: no_grad keeps autograd from recording a graph for every forward pass.
with torch.no_grad():
    for example in tqdm(eval_dataset):
        # Every sample is padded to the same fixed length, so per-layer attention
        # tensors have matching shapes and can be summed across samples.
        sample = tokenizer.pad(
            example,
            padding="max_length",
            max_length=min(128, tokenizer.model_max_length),
            pad_to_multiple_of=None,
            return_tensors=None,
        )
        if "label" in sample:
            sample["labels"] = sample["label"]
            del sample["label"]

        model_inputs = {
            name: torch.tensor([value]).to("cuda")
            for name, value in sample.items()
            if name in ("input_ids", "token_type_ids", "attention_mask")
        }
        output = model(**model_inputs, output_attentions=True)
        predictions = output.logits.argmax(dim=-1)
        predictions_ = mnli_label_dict[predictions.item()]
        references = sample["labels"]
        references_ = hans_label_dict[references]

        key = sample["template"] + "_" + predictions_
        layer_attentions = [weights.detach().cpu() for weights in output.attentions]
        if key not in attention_dict:
            attention_dict[key] = layer_attentions
        else:
            attention_dict[key] = [
                running_sum + weights
                for running_sum, weights in zip(attention_dict[key], layer_attentions)
            ]
        key_count[key] += 1

        # HANS is scored binary: neutral and contradiction both count as non-entailment.
        if references_ == "entailment":
            e_total += 1
            if predictions_ == "entailment":
                e_correct += 1
        else:
            n_total += 1
            if predictions_ != "entailment":
                n_correct += 1

# Turn the per-key sums into means. Iterate the stored layers themselves rather
# than the last forward pass's output, so this does not depend on leftover state.
for key, layer_sums in attention_dict.items():
    attention_dict[key] = [layer_sum / key_count[key] for layer_sum in layer_sums]

100%|██████████| 30000/30000 [07:27<00:00, 67.00it/s]


In [13]:
# Baseline results: (correct, total) for entailment, then non-entailment gold labels.
for correct, total in ((e_correct, e_total), (n_correct, n_total)):
    print(correct, total)

14850 15000
2165 15000


### Attention Intervention with Hooks

In [14]:
from intervention_hook.intervention import bert_attention_intervention

In [15]:
# Move every averaged attention tensor to the GPU, where the model and its inputs
# live. Iterates the stored layer lists directly instead of the leftover global
# `output.attentions`, so the cell no longer depends on the previous loop's state.
for layer_means in attention_dict.values():
    layer_means[:] = [layer_mean.to("cuda") for layer_mean in layer_means]

In [16]:
intervention_e_correct = 0
intervention_n_correct = 0
intervention_e_total = 0
intervention_n_total = 0

# Samples run through bert_attention_intervention rather than a plain forward pass.
intervention_count = 0

for example in tqdm(eval_dataset):
    # Same padding as when the averaged attention maps were collected.
    sample = tokenizer.pad(
        example,
        padding="max_length",
        max_length=min(128, tokenizer.model_max_length),
        pad_to_multiple_of=None,
        return_tensors=None,
    )
    if "label" in sample:
        sample["labels"] = sample["label"]
        del sample["label"]

    model_inputs = {}
    for name in sample:
        if name in ("input_ids", "token_type_ids", "attention_mask"):
            model_inputs[name] = torch.tensor([sample[name]]).to("cuda")

    references = sample['labels']
    references_ = hans_label_dict[references]

    # Pick the averaged attention recorded for the same template whose prediction
    # matched the gold class; for non-entailment, "contradiction" is preferred
    # over "neutral".
    template = sample["template"]
    if references_ == "nonentailment":
        suffixes = ("_contradiction", "_neutral")
    else:
        suffixes = ("_entailment",)
    keys = [template + suffix for suffix in suffixes if template + suffix in attention_dict]
    key = keys[0] if keys else None

    if key is None:
        # No averaged pattern for this template/class: plain forward pass.
        output = model(**model_inputs, output_attentions=True)
    else:
        output = bert_attention_intervention(model, model_inputs, attention_dict[key])
        intervention_count += 1

    predictions = output.logits.argmax(dim=-1)
    predictions_ = mnli_label_dict[predictions.item()]

    predicted_entailment = predictions_ == "entailment"
    if references_ == "entailment":
        intervention_e_total += 1
        if predicted_entailment:
            intervention_e_correct += 1
    else:
        intervention_n_total += 1
        if not predicted_entailment:
            intervention_n_correct += 1


100%|██████████| 30000/30000 [06:54<00:00, 72.31it/s]


In [17]:
# Number of samples whose forward pass used an averaged attention pattern.
print(f"{intervention_count}")

29034


In [18]:
# Post-intervention results: (correct, total) for entailment, then non-entailment gold labels.
for correct, total in (
    (intervention_e_correct, intervention_e_total),
    (intervention_n_correct, intervention_n_total),
):
    print(correct, total)

14440 15000
6517 15000
