In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import datasets
from tqdm.notebook import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding
)
import pickle
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, HTML

def preprocess_function_wrapped(tokenizer):
    """Build a tokenization callback bound to ``tokenizer``.

    The returned function reads the module-level ``SENTENCE1_KEY``,
    ``SENTENCE2_KEY`` and ``MAX_LENGTH`` at call time, so those must be
    defined before the callback is invoked.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text[, text_pair], ...)``.

    Returns:
        A function taking ``examples`` (a mapping from column name to values)
        and returning whatever ``tokenizer`` returns.
    """
    def preprocess_function(examples):
        # Single-sentence tasks pass one text column, pair tasks pass two.
        texts = [examples[SENTENCE1_KEY]]
        if SENTENCE2_KEY is not None:
            texts.append(examples[SENTENCE2_KEY])
        # No padding here: sequences are only truncated to MAX_LENGTH.
        return tokenizer(*texts, padding=False, max_length=MAX_LENGTH, truncation=True)

    return preprocess_function

def token_id_to_tokens_mapper(tokenizer, sample):
    """Convert a sample's ``input_ids`` to token strings.

    Args:
        tokenizer: Object exposing ``convert_ids_to_tokens(ids)``.
        sample: Mapping with an ``"input_ids"`` sequence.

    Returns:
        Tuple ``(tokens, length)`` where ``length`` is ``len(input_ids)`` and
        ``tokens`` is the converted sequence cut to that length.
    """
    ids = sample["input_ids"]
    n_tokens = len(ids)
    tokens = tokenizer.convert_ids_to_tokens(ids)
    return tokens[:n_tokens], n_tokens

def load_globenc(path, no_cls=False, no_sep=False):
    """Load a pickled attribution dump and drop examples with fewer than two tokens.

    The pickle is expected to hold a mapping with at least ``"tokens"`` and
    ``"globenc"`` entries of equal length (one item per example).

    NOTE(review): ``pickle.load`` executes arbitrary code from the file; only
    use this on trusted, locally produced dumps.

    Args:
        path: Path of the pickle file.
        no_cls: If true, drop the first token and the first attribution column.
        no_sep: If true, drop the last token and the last attribution column.

    Returns:
        Tuple ``(data, kept_index)``: ``data`` is a dict of column -> list for
        the kept examples and ``kept_index`` is the pandas index of the rows
        that survived the length filter.
    """
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)

    if no_cls or no_sep:
        # One slice covers both trims: [1:] for CLS, [:-1] for SEP.
        start = 1 if no_cls else None
        stop = -1 if no_sep else None
        payload["tokens"] = [toks[start:stop] for toks in payload["tokens"]]
        payload["globenc"] = [np.array(scores)[:, start:stop] for scores in payload["globenc"]]

    frame = pd.DataFrame(payload)
    n_before = len(frame)
    has_multiple_tokens = frame["tokens"].map(len) > 1
    frame = frame[has_multiple_tokens]
    n_after = len(frame)
    print(f"Read {path}: {n_before}->{n_after} ")
    return frame.to_dict(orient="list"), frame.index

def load_pickle(path, no_cls=False, no_sep=False):
    """Return the object stored in the pickle file at ``path``.

    ``no_cls`` and ``no_sep`` are accepted for signature parity with
    ``load_globenc`` but are not used.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only load
    trusted files.
    """
    with open(path, mode='rb') as stream:
        return pickle.load(stream)

def print_globenc(globenc, tokenized_text, discrete=False, prefix="", del_ratio=0.0):
    """Render tokens as an HTML heat map of their attribution scores.

    The scores used are row 0 of the last slice along the first axis of
    ``globenc``, divided by that row's maximum. A 2-D input is treated as a
    single slice.

    Args:
        globenc: NumPy array of rank 2 or 3 holding attribution scores.
        tokenized_text: Sequence of token strings, one per score column.
        discrete: If true, colour by rank (``rank / n``) instead of by the
            normalised score itself.
        prefix: HTML placed before the first token (inserted as-is, not escaped).
        del_ratio: Fraction of tokens to highlight in red; the
            ``floor(n * del_ratio)`` highest-scoring tokens get the red
            background, all others the yellow one.

    Returns:
        None. The markup is shown through IPython's ``display(HTML(...))``.
    """
    # Local import: the name ``html`` is not imported at file level.
    from html import escape

    if len(globenc.shape) == 2:
        globenc = np.expand_dims(globenc, axis=0)
    norm_cls = globenc[:, 0, :]
    norm_cls = np.flip(norm_cls, axis=0)
    row_max = norm_cls.max(axis=1)
    norm_cls = norm_cls / row_max[:, np.newaxis]

    scores = norm_cls[0, :]
    # Rank and deletion budget do not depend on the token, so compute them once.
    ranks = np.argsort(np.argsort(scores))
    del_count = np.floor(len(scores) * del_ratio)
    cls_attention = ranks / len(scores) if discrete else scores

    pieces = [prefix]
    for i, token in enumerate(tokenized_text):
        if len(ranks) - ranks[i] > del_count:
            color = f"background-color: rgba({cls_attention[i]*255}, {cls_attention[i]*255}, 0, {cls_attention[i] / 1.5}); "
        else:
            # Token is among the top ``del_count`` scores: mark it in red.
            color = f"background-color: rgba({cls_attention[i]*255}, 0, 0, {cls_attention[i] / 1.5}); "
        pieces.append(
            f"<span style='"
            f"{color}"
            f"font-weight: {int(800)};"
            "'>"
        )
        # Escape the token so that '<', '>' and '&' are shown literally
        # instead of being parsed as markup.
        pieces.append(escape(token, quote=False))
        pieces.append("</span> ")
    display(HTML("".join(pieces)))

In [None]:
# Fine-tuned checkpoint to analyse, the GLUE task it was trained on, and the
# dataset split to evaluate.
# NOTE(review): absolute, user-specific path - adjust before running elsewhere.
MODEL_PATH =  "/home/modaresi/projects/globenc_analysis/outputs/models/output_mnli_bert-base-uncased_0001_SEED0042/checkpoint-61360"
TASK = "mnli"
SET = "validation_matched"  # train/validation/validation_matched

# Alternative configuration (SST-2); uncomment all three lines together.
# MODEL_PATH = "/home/modaresi/projects/globenc_analysis/outputs/models/output_sst2_bert-large-uncased_0001_SEED0042/checkpoint-10525"
# TASK = "sst2"
# SET = "validation"

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Bare expression: shows the selected device as the cell output.
DEVICE

In [None]:
# Names of the GLUE tasks known to this notebook.
GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]

# Evaluation batch size and maximum tokenized sequence length.
BATCH_SIZE = 32
MAX_LENGTH = 128

# "mnli-mm" is loaded through the "mnli" dataset configuration.
actual_task = "mnli" if TASK == "mnli-mm" else TASK
dataset = datasets.load_dataset("glue", actual_task)
metric = datasets.load_metric('glue', actual_task)
# Per task: (first text column, second text column or None for single-sentence tasks).
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mnli-mm": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}
# Module-level keys read by the tokenization helper defined in the first cell.
SENTENCE1_KEY, SENTENCE2_KEY = task_to_keys[TASK]
# Bare expression: shows the loaded dataset as the cell output.
dataset

In [None]:
# Load the fine-tuned classifier and move it to the device chosen in the
# configuration cell (DEVICE falls back to CPU when CUDA is unavailable),
# instead of unconditionally requiring "cuda:0".
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
model.to(torch.device(DEVICE))
# Inference only: switch to evaluation mode.
model.eval()

def preprocess(e):
    """Remove every "." from the text field(s) of one example.

    The fields are taken from the module-level ``SENTENCE1_KEY`` and
    ``SENTENCE2_KEY`` so the function works for both sentence-pair tasks
    (e.g. MNLI: premise/hypothesis) and single-sentence tasks where
    ``SENTENCE2_KEY`` is ``None``.

    Args:
        e: Mapping holding the example's text column(s); modified in place.

    Returns:
        The same mapping ``e`` with periods stripped from its text fields.
    """
    for key in (SENTENCE1_KEY, SENTENCE2_KEY):
        if key is not None:
            e[key] = e[key].replace(".", "")
    return e

# Tokenizer stored alongside the fine-tuned checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True, max_length=MAX_LENGTH)

# Two passes over the selected split: first the per-example text clean-up
# (`preprocess`), then batched tokenization (no padding, truncated to MAX_LENGTH).
sel_dataset = dataset[SET].map(preprocess)
sel_dataset = sel_dataset.map(preprocess_function_wrapped(tokenizer), batched=True, batch_size=1024)
dataset_size = len(sel_dataset)
print(dataset_size)

# Bare expression: shows one raw (unprocessed) example as the cell output.
dataset[SET][10]

In [None]:
# Same example as shown in the previous cell, after clean-up and tokenization.
print(sel_dataset[10])

# Compute Faithfulness

In [None]:
# Directories holding the precomputed attribution outputs.
# NOTE(review): absolute, user-specific paths - adjust before running elsewhere.
HTAs_dir = "/home/modaresi/projects/globenc_analysis/outputs/HTAs"
globencs_v2_dir = "/home/modaresi/projects/globenc_analysis/outputs/globencs_elementwise"
globencs_dir = "/home/modaresi/projects/globenc_analysis/outputs/globencs"
saliencies_dir = "/home/modaresi/projects/globenc_analysis/outputs/saliencies"
# Per experiment: functions mapping an epoch number to the file path of each
# attribution method, plus the dataset name under "hf_ds".
# NOTE(review): the multipliers 12272 / 2105 turn an epoch into a checkpoint
# number - presumably optimizer steps per epoch; confirm against the training run.
configs = {
    "mnli-val": {
        "hta_path": lambda epoch : f"{HTAs_dir}/mnli_validation_matched_bert-base-uncased_0001_SEED0042_checkpoint-{epoch*12272}.pkl",
        "globenc_path": lambda epoch : f"{globencs_dir}/mnli-e{epoch}_validation_matched_bert-base-uncased.pickle",
        "globenc_v2_path": lambda epoch : f"{globencs_v2_dir}/mnli-e{epoch}_validation_matched_bert-base-uncased.pickle",
        "saliency_path": lambda epoch : f"{saliencies_dir}/mnli_bert-base-uncased_0001_SEED0042_checkpoint-{epoch*12272}.npy",
        "hf_ds": "mnli",
    },
    "sst2-val": {
        "hta_path": lambda epoch : f"{HTAs_dir}/sst2_validation_bert-base-uncased_0001_SEED0042_checkpoint-{epoch*2105}.pkl",
        "globenc_path": lambda epoch : f"{globencs_dir}/sst2-e{epoch}_validation_bert-base-uncased.pickle",
        "globenc_v2_path": lambda epoch : f"{globencs_v2_dir}/sst2-e{epoch}_validation_bert-base-uncased.pickle",
        "saliency_path": lambda epoch : f"{saliencies_dir}/sst2_bert-base-uncased_0001_SEED0042_checkpoint-{epoch*2105}.npy",
        "hf_ds": "sst2",
    }
}

# Select the experiment and epoch whose precomputed attributions are analysed.
CONFIG_NAME = "mnli-val"
# CONFIG_NAME = "sst2-val"
CONFIG = configs[CONFIG_NAME]
EPOCH = 5

globencs, DATASET_KEEP_IDX = load_globenc(CONFIG["globenc_path"](EPOCH), no_cls=False, no_sep=False)
globencs_v2, keep_idx_v2 = load_globenc(CONFIG["globenc_v2_path"](EPOCH), no_cls=False, no_sep=False)
# Both dumps are indexed by the same example position further down, so they
# must have kept exactly the same rows. Fail loudly instead of silently
# letting the second load overwrite the first kept-index.
if list(DATASET_KEEP_IDX) != list(keep_idx_v2):
    raise ValueError(
        "globenc and globenc_v2 dumps kept different examples; "
        "their attributions cannot be compared index-by-index"
    )
htas = load_pickle(CONFIG["hta_path"](EPOCH))
saliencies = np.load(CONFIG["saliency_path"](EPOCH))

In [None]:
# Faithfulness experiment: for every attribution method and masking ratio,
# replace a share of each example's tokens (chosen by attribution rank) and
# record the model's predictions on the modified inputs.
results = {}
INVERTED = False  # True=MaskMax, False=MaskMin
MASK = True

# Run on whatever device the model already lives on instead of assuming CUDA.
eval_device = model.device
# Pads every batch to its longest sequence; independent of the loop variables.
collator = DataCollatorWithPadding(tokenizer, padding=True, max_length=MAX_LENGTH, return_tensors="pt")

for exp_type in ["globencV1", "globencV2", "salsNorm"]:
    print("Masking based on", exp_type)
    results[exp_type] = dict()
    for i in [0, 7]:
        ratio_key = f"{i*10}%"

        # Masks i*10% of the tokens -- based on their attribution metric value
        def mapping_masks(example):
            """Replace the selected share of one example's tokens."""
            length = np.sum(example["attention_mask"])
            if exp_type == "globencV1":
                sal_rank = globencs['globenc'][example["idx"]][0][:length].argsort()
            elif exp_type == "globencV2":
                sal_rank = globencs_v2['globenc'][example["idx"]][0][:length].argsort()
            elif exp_type == "salsNorm":
                sal_rank = saliencies[example["idx"]][:length].argsort()
            elif exp_type == "hta":
                sal_rank = htas["HTAs"][example["idx"]][0][:length].argsort()
            else:
                raise ValueError(f"Unknown attribution method: {exp_type}")
            # Exclude CLS and SEPs
            # NOTE(review): relies on these special tokens having ids below 103
            # in the tokenizer's vocabulary - confirm before changing tokenizer.
            special_positions = np.argwhere(np.array(example["input_ids"]) < 103).flatten()
            sal_rank = sal_rank[~np.isin(sal_rank, special_positions)]
            mask_count = int(np.floor(len(sal_rank) * i / 10.0))
            # A set gives O(1) membership tests in the comprehensions below.
            if mask_count == 0:
                masked_positions = set()
            elif INVERTED:
                masked_positions = set(sal_rank[-mask_count:].tolist())
            else:
                masked_positions = set(sal_rank[:mask_count].tolist())
            replacement_token = tokenizer.mask_token_id if MASK else tokenizer.pad_token_id
            if not MASK:
                # When padding instead of masking, also hide the position from attention.
                example["attention_mask"] = [0 if j in masked_positions else example["attention_mask"][j] for j in range(length)]
            example["input_ids"] = [replacement_token if j in masked_positions else example["input_ids"][j] for j in range(length)]
            return example

        modified_set = sel_dataset.map(mapping_masks)

        modified_set.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
        dataloader = torch.utils.data.DataLoader(modified_set, batch_size=BATCH_SIZE, collate_fn=collator, shuffle=False)

        y_preds = torch.zeros(size=(dataset_size,), device=eval_device)
        y_trues = torch.zeros(size=(dataset_size,), dtype=torch.int32, device=eval_device)
        y_logits = list()
        with torch.no_grad():
            for j, batch in enumerate(tqdm(dataloader, desc=f"{exp_type}-{i*10}%")):
                batch = {k: v.to(eval_device) for k, v in batch.items()}
                # Labels are deliberately not forwarded: only logits are needed,
                # so the model does not have to compute a loss.
                inputs = {
                    'input_ids': batch['input_ids'],
                    'attention_mask': batch['attention_mask'],
                    'token_type_ids': batch['token_type_ids'],
                }
                batch_slice = slice(j * BATCH_SIZE, (j + 1) * BATCH_SIZE)
                y_trues[batch_slice] = batch['labels']
                output = model(**inputs)
                y_preds[batch_slice] = torch.argmax(output.logits, dim=-1)
                y_logits.extend(output.logits.cpu().numpy())

        results[exp_type][ratio_key] = {
            "preds": y_preds.cpu().numpy(),
            "correct": (y_trues == y_preds).cpu().numpy(),
            "logits": np.array(y_logits),
            # Masked token ids, kept for the qualitative inspection cells.
            "modified_set": modified_set["input_ids"],
        }

In [None]:
# Masking ratio to compare against the unmasked ("0%") baseline.
P = "70%"
labels = sel_dataset["label"]
for method in ["globencV1", "globencV2", "salsNorm"]:
    method_results = results[method]
    # Restrict to examples the model classified correctly without masking.
    idxs = (labels == method_results["0%"]["preds"])
    # First line: baseline accuracy on that subset; second line: accuracy at ratio P.
    for ratio in ("0%", P):
        print(method, np.mean(method_results[ratio]["correct"][idxs]))

In [None]:
# With nothing masked, both GlobEnc variants must yield identical correctness.
baseline_v1 = results["globencV1"]["0%"]
assert (baseline_v1["correct"] == results["globencV2"]["0%"]["correct"]).all()

v1_correct_at_p = results["globencV1"][P]["correct"]
v2_correct_at_p = results["globencV2"][P]["correct"]
sal_correct_at_p = results["salsNorm"][P]["correct"]

# Examples where the two GlobEnc variants disagree at ratio P ...
diff_indices = v1_correct_at_p != v2_correct_at_p
# ... limited to those the model got right before any masking.
diff_indices_correct_preds = diff_indices & (labels == baseline_v1["preds"])
# Within that subset: positions where each method's masked input is misclassified.
v1_error_idxs = np.where(diff_indices_correct_preds & ~v1_correct_at_p)[0]
v2_error_idxs = np.where(diff_indices_correct_preds & ~v2_correct_at_p)[0]
sal_error_idxs = np.where(diff_indices_correct_preds & ~sal_correct_at_p)[0]

In [None]:
# Qualitative view of the examples that GlobEnc V1 masking turns into errors.
for idx in v1_error_idxs[:]:
    # Only show examples with 18 to 20 tokens (inclusive).
    n_tokens = len(globencs["tokens"][idx])
    if not (18 <= n_tokens <= 20):
        continue
    print("idx:", idx)
    print("LABEL:", sel_dataset["label"][idx])
    # Prediction and logits before -> after masking, one line per method.
    for tag, method in (("G_V1", "globencV1"), ("G_V2", "globencV2"), ("salN", "salsNorm")):
        before = results[method]["0%"]
        after = results[method][P]
        print(f"{tag} Pred: {before['preds'][idx]}->{after['preds'][idx]} {before['logits'][idx]}->{after['logits'][idx]}")

    # Attribution maps with the first and last token positions removed.
    length = len(globencs["tokens"][idx][1:-1])
    heatmaps = (
        ("G_V1: ", globencs["globenc"][idx][:, 1:-1], globencs, "globencV1"),
        ("G_V2: ", globencs_v2["globenc"][idx][:, 1:-1], globencs_v2, "globencV2"),
        ("S_V0: ", np.expand_dims(saliencies[idx][1:length+1], 0), globencs_v2, "salsNorm"),
    )
    # First pass: heat maps over the original tokens.
    for label, scores, token_source, _ in heatmaps:
        print_globenc(scores, token_source["tokens"][idx][1:-1], prefix=label, discrete=True, del_ratio=0.33)
    ### REAL MASKED VERSIONS:
    for label, scores, _, method in heatmaps:
        masked_ids = results[method][P]["modified_set"][idx][1:-1]
        print_globenc(scores, tokenizer.convert_ids_to_tokens(masked_ids), prefix=label, discrete=True, del_ratio=0.33)

In [None]:
# Show the masked token sequence of a single example for both GlobEnc variants.
idx = 5705
# Use the ratio P selected above: the masking loop only stores the "0%" and
# "70%" ratios, so a hard-coded "30%" key raises KeyError.
print(tokenizer.convert_ids_to_tokens(results["globencV1"][P]["modified_set"][idx]))
print(tokenizer.convert_ids_to_tokens(results["globencV2"][P]["modified_set"][idx]))

In [None]:
# Raw masked input ids (GlobEnc V1) for the example chosen above, at ratio P;
# the masking loop does not produce a "30%" entry.
results["globencV1"][P]["modified_set"][idx]

In [None]:
# Raw masked input ids (GlobEnc V2) for the example chosen above, at ratio P;
# the masking loop does not produce a "30%" entry.
results["globencV2"][P]["modified_set"][idx]