In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from scipy.stats import pearsonr

import datasets
from tqdm.notebook import tqdm

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding
)

In [2]:
# Fine-tuned BERT checkpoints to compare. Keys are "mnli-e1" ... "mnli-e5"
# (presumably one per training epoch -- confirm); entry n points at the
# checkpoint saved at step n * CHECKPOINT_STRIDE under CHECKPOINT_ROOT.
CHECKPOINT_ROOT = "/home/modaresi/projects/globenc_analysis/outputs/models/output_mnli_bert-base-uncased_0001_SEED0042"
CHECKPOINT_STRIDE = 12272

MODELS = {
    f"mnli-e{n}": f"{CHECKPOINT_ROOT}/checkpoint-{n * CHECKPOINT_STRIDE}"
    for n in range(1, 6)
}
# Disabled entry for the un-fine-tuned baseline:
# MODELS["mnli-e0"] = "bert-base-uncased"

# GLUE task whose dataset and text columns are used below.
TASK = "mnli"

# Dataset split to analyse. For MNLI the loaded splits are
# train / validation_matched / validation_mismatched / test_matched / test_mismatched.
SET = "train"

# Prefer the GPU when one is visible to torch.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

In [5]:
# Names of the GLUE tasks this notebook knows about ("mnli-mm" = MNLI mismatched).
GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]

BATCH_SIZE = 24
# Tokenized sequences are truncated to this many tokens (see preprocess_function).
MAX_LENGTH = 128

# "mnli-mm" is loaded from the "mnli" GLUE config; every other task maps to itself.
actual_task = {"mnli-mm": "mnli"}.get(TASK, TASK)
dataset = datasets.load_dataset("glue", actual_task)
# NOTE(review): `datasets.load_metric` is deprecated in recent `datasets`
# releases in favour of the `evaluate` package -- confirm against the
# installed version before upgrading.
metric = datasets.load_metric('glue', actual_task)

# Text column name(s) per task: (first sentence, second sentence or None
# for single-sentence tasks).
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mnli-mm": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}
SENTENCE1_KEY, SENTENCE2_KEY = task_to_keys[TASK]

# Caches filled lazily by the analysis cells: the tokenizer and the tokenized
# split are created on first use and reused afterwards.
tokenizer = None
sel_dataset = None

dataset

Reusing dataset glue (/opt/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/5 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9832
    })
    test_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9796
    })
    test_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9847
    })
})

In [4]:
def preprocess_function_wrapped(tokenizer):
    """Return a tokenization callback bound to ``tokenizer``.

    The returned function takes a batch of examples (a mapping from column
    name to values), picks the column(s) named by ``SENTENCE1_KEY`` and, when
    it is not ``None``, ``SENTENCE2_KEY``, and passes them to ``tokenizer``
    with padding disabled and truncation to ``MAX_LENGTH``. The tokenizer's
    result is returned unchanged.
    """
    def preprocess_function(examples):
        # Single-sentence tasks have no second text column.
        if SENTENCE2_KEY is None:
            texts = (examples[SENTENCE1_KEY],)
        else:
            texts = (examples[SENTENCE1_KEY], examples[SENTENCE2_KEY])
        return tokenizer(*texts, padding=False, max_length=MAX_LENGTH, truncation=True)

    return preprocess_function

def token_id_to_tokens_mapper(tokenizer, sample):
    """Convert a sample's ``input_ids`` to token strings.

    Returns a ``(tokens, length)`` pair where ``length`` is the number of
    input ids and ``tokens`` is the tokenizer's conversion of those ids,
    cut to at most ``length`` entries.
    """
    token_ids = sample["input_ids"]
    n_tokens = len(token_ids)
    token_strings = tokenizer.convert_ids_to_tokens(token_ids)
    return token_strings[:n_tokens], n_tokens

from IPython.display import display, HTML
# def print_globenc(globenc, tokenized_text, discrete=False, prefix=""): 
#     if len(globenc.shape) == 2: 
#         globenc = np.expand_dims(globenc, axis=0) 
#     # norm_cls = globenc
#     # norm_cls = np.flip(norm_cls, axis=0) 
#     # row_sums = norm_cls.max(axis=1) 
#     # norm_cls = norm_cls / row_sums[:, np.newaxis] 
#     html = prefix 
#     if discrete: 
#         cls_attention = np.argsort(np.argsort(norm_cls[0, :])) / len(norm_cls[0, :]) 
#     else: 
#         cls_attention = globenc / globenc.max()
#     for i in range(len(tokenized_text)): 
#         html += (f"<span style='background-color: rgba(10, {cls_attention[i]*255}, 10, {cls_attention[i] / 1.5}); " 
#                  f"font-size: {int(cls_attention[i]*18 + 1)}px; " 
#                  f"font-weight: {int(cls_attention[i]*900)};'>") 
#         html += tokenized_text[i] 
#         html += "</span> " 
#     display(HTML(html))

def print_globenc(globenc, tokenized_text, discrete=False, prefix="", no_cls=False):
    """Display tokens in the notebook, each highlighted according to its score.

    Args:
        globenc: 1-D array-like of per-token scores; entry ``i`` colours
            ``tokenized_text[i]``.
        tokenized_text: Sequence of token strings. It must not be longer than
            ``globenc`` (an ``IndexError`` is raised otherwise).
        discrete: If True, colour each token by the rank of its score
            (``rank / len(globenc)``, lowest score -> 0) instead of by
            ``score / max(score)``.
        prefix: Text placed before the tokens. It is inserted as-is, so it may
            contain HTML markup.
        no_cls: Accepted for backward compatibility; currently has no effect.

    Returns:
        None. The generated markup is shown via ``display(HTML(...))``.
    """
    # Local import so this cell does not rely on an edit to the import cell.
    from html import escape

    scores = np.asarray(globenc)
    if discrete:
        # argsort of argsort gives every score its rank; dividing by the
        # number of scores maps ranks into [0, 1).
        intensity = np.argsort(np.argsort(scores)) / len(scores)
    else:
        peak = scores.max() if scores.size else 0
        # Dividing by a zero maximum would put NaN into the CSS values;
        # render such input with zero intensity instead.
        intensity = scores / peak if peak != 0 else np.zeros(scores.shape)

    pieces = [prefix]
    for position, token in enumerate(tokenized_text):
        level = intensity[position]
        channel = level * 255
        # The token is escaped so characters such as "<" or "&" are shown
        # literally instead of being interpreted as markup.
        pieces.append(
            f"<span style='background-color: rgba({channel}, {channel}, 0, {level / 1.5}); "
            "font-weight: 800;'>"
            f"{escape(token)}</span> "
        )
    display(HTML("".join(pieces)))

In [49]:
# Input x gradient saliency of one example, compared across the checkpoints in MODELS.
# NOTE(review): this cell is repeated in this notebook for other `idx` values;
# consider factoring the loop into a function that takes `idx`.
# Uncomment to force the tokenizer / tokenized split to be rebuilt:
# tokenizer = None
# sel_dataset = None
idx = 314093
prev_sals = None  # saliencies of the previous checkpoint, for the correlation printout
for name, path in tqdm(MODELS.items(), desc="Models"):
    model = AutoModelForSequenceClassification.from_pretrained(path)
    # The tokenizer and the tokenized split are created once and then reused.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True, max_length=MAX_LENGTH)
    if sel_dataset is None:
        sel_dataset = dataset[SET].map(preprocess_function_wrapped(tokenizer), batched=True, batch_size=1024)
    dataset_size = len(sel_dataset)

    # Index the dataset once instead of once per field.
    sample = sel_dataset[idx]
    tokens, length = token_id_to_tokens_mapper(tokenizer, sample)
    inputs = {
        key: torch.tensor([sample[key]], dtype=torch.int32)
        for key in ('input_ids', 'attention_mask', 'token_type_ids')
    }
    labels = torch.tensor([sample['label']])
    output = model(**inputs, output_hidden_states=True)

    # hidden_states[0] is a non-leaf tensor (the embedding output, per the
    # transformers `output_hidden_states` docs), so its gradient must be
    # retained explicitly.
    embeddings = output.hidden_states[0]
    embeddings.retain_grad()
    logits = output.logits
    # Back-propagate the logit of the gold label only.
    target_class_l_sum = torch.gather(logits, 1, labels.unsqueeze(-1)).sum()
    target_class_l_sum.backward()

    inputXgradient = embeddings.grad * embeddings
    # One score per token: L2 norm over the hidden dimension.
    saliencies = torch.norm(inputXgradient, dim=-1).detach()
    model.zero_grad()

    sals = saliencies[0].numpy()
    if prev_sals is not None:
        print("Corr:", pearsonr(prev_sals, sals)[0])
    prev_sals = sals
    print_globenc(sals, tokens, prefix=name + " ")


Models:   0%|          | 0/5 [00:00<?, ?it/s]

Corr: 0.9636886075299456


Corr: 0.9864787814821108


Corr: 0.9864836155935834


Corr: 0.9950280429647553


In [50]:
# Input x gradient saliency of one example, compared across the checkpoints in MODELS.
# NOTE(review): this cell is repeated in this notebook for other `idx` values;
# consider factoring the loop into a function that takes `idx`.
# Uncomment to force the tokenizer / tokenized split to be rebuilt:
# tokenizer = None
# sel_dataset = None
idx = 81139
prev_sals = None  # saliencies of the previous checkpoint, for the correlation printout
for name, path in tqdm(MODELS.items(), desc="Models"):
    model = AutoModelForSequenceClassification.from_pretrained(path)
    # The tokenizer and the tokenized split are created once and then reused.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True, max_length=MAX_LENGTH)
    if sel_dataset is None:
        sel_dataset = dataset[SET].map(preprocess_function_wrapped(tokenizer), batched=True, batch_size=1024)
    dataset_size = len(sel_dataset)

    # Index the dataset once instead of once per field.
    sample = sel_dataset[idx]
    tokens, length = token_id_to_tokens_mapper(tokenizer, sample)
    inputs = {
        key: torch.tensor([sample[key]], dtype=torch.int32)
        for key in ('input_ids', 'attention_mask', 'token_type_ids')
    }
    labels = torch.tensor([sample['label']])
    output = model(**inputs, output_hidden_states=True)

    # hidden_states[0] is a non-leaf tensor (the embedding output, per the
    # transformers `output_hidden_states` docs), so its gradient must be
    # retained explicitly.
    embeddings = output.hidden_states[0]
    embeddings.retain_grad()
    logits = output.logits
    # Back-propagate the logit of the gold label only.
    target_class_l_sum = torch.gather(logits, 1, labels.unsqueeze(-1)).sum()
    target_class_l_sum.backward()

    inputXgradient = embeddings.grad * embeddings
    # One score per token: L2 norm over the hidden dimension.
    saliencies = torch.norm(inputXgradient, dim=-1).detach()
    model.zero_grad()

    sals = saliencies[0].numpy()
    if prev_sals is not None:
        print("Corr:", pearsonr(prev_sals, sals)[0])
    prev_sals = sals
    print_globenc(sals, tokens, prefix=name + " ")


Models:   0%|          | 0/5 [00:00<?, ?it/s]

Corr: 0.6708251144252246


Corr: 0.93518511826332


Corr: 0.7228107813981285


Corr: 0.991289830362293


In [52]:
# Input x gradient saliency of one example, compared across the checkpoints in MODELS.
# NOTE(review): this cell is repeated in this notebook for other `idx` values;
# consider factoring the loop into a function that takes `idx`.
# Uncomment to force the tokenizer / tokenized split to be rebuilt:
# tokenizer = None
# sel_dataset = None
idx = 364274
prev_sals = None  # saliencies of the previous checkpoint, for the correlation printout
for name, path in tqdm(MODELS.items(), desc="Models"):
    model = AutoModelForSequenceClassification.from_pretrained(path)
    # The tokenizer and the tokenized split are created once and then reused.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True, max_length=MAX_LENGTH)
    if sel_dataset is None:
        sel_dataset = dataset[SET].map(preprocess_function_wrapped(tokenizer), batched=True, batch_size=1024)
    dataset_size = len(sel_dataset)

    # Index the dataset once instead of once per field.
    sample = sel_dataset[idx]
    tokens, length = token_id_to_tokens_mapper(tokenizer, sample)
    inputs = {
        key: torch.tensor([sample[key]], dtype=torch.int32)
        for key in ('input_ids', 'attention_mask', 'token_type_ids')
    }
    labels = torch.tensor([sample['label']])
    output = model(**inputs, output_hidden_states=True)

    # hidden_states[0] is a non-leaf tensor (the embedding output, per the
    # transformers `output_hidden_states` docs), so its gradient must be
    # retained explicitly.
    embeddings = output.hidden_states[0]
    embeddings.retain_grad()
    logits = output.logits
    # Back-propagate the logit of the gold label only.
    target_class_l_sum = torch.gather(logits, 1, labels.unsqueeze(-1)).sum()
    target_class_l_sum.backward()

    inputXgradient = embeddings.grad * embeddings
    # One score per token: L2 norm over the hidden dimension.
    saliencies = torch.norm(inputXgradient, dim=-1).detach()
    model.zero_grad()

    sals = saliencies[0].numpy()
    if prev_sals is not None:
        print("Corr:", pearsonr(prev_sals, sals)[0])
    prev_sals = sals
    print_globenc(sals, tokens, prefix=name + " ")


Models:   0%|          | 0/5 [00:00<?, ?it/s]

Corr: 0.9426896125433669


Corr: 0.5417381184122058


Corr: 0.9659932744600644


Corr: 0.9537608660033299


In [54]:
# Input x gradient saliency of one example, compared across the checkpoints in MODELS.
# NOTE(review): this cell is repeated in this notebook for other `idx` values;
# consider factoring the loop into a function that takes `idx`.
# Uncomment to force the tokenizer / tokenized split to be rebuilt:
# tokenizer = None
# sel_dataset = None
idx = 24638
prev_sals = None  # saliencies of the previous checkpoint, for the correlation printout
print(f"### {idx} ###")
for name, path in MODELS.items():
    model = AutoModelForSequenceClassification.from_pretrained(path)
    # The tokenizer and the tokenized split are created once and then reused.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True, max_length=MAX_LENGTH)
    if sel_dataset is None:
        sel_dataset = dataset[SET].map(preprocess_function_wrapped(tokenizer), batched=True, batch_size=1024)
    dataset_size = len(sel_dataset)

    # Index the dataset once instead of once per field.
    sample = sel_dataset[idx]
    tokens, length = token_id_to_tokens_mapper(tokenizer, sample)
    inputs = {
        key: torch.tensor([sample[key]], dtype=torch.int32)
        for key in ('input_ids', 'attention_mask', 'token_type_ids')
    }
    labels = torch.tensor([sample['label']])
    output = model(**inputs, output_hidden_states=True)

    # hidden_states[0] is a non-leaf tensor (the embedding output, per the
    # transformers `output_hidden_states` docs), so its gradient must be
    # retained explicitly.
    embeddings = output.hidden_states[0]
    embeddings.retain_grad()
    logits = output.logits
    # Back-propagate the logit of the gold label only.
    target_class_l_sum = torch.gather(logits, 1, labels.unsqueeze(-1)).sum()
    target_class_l_sum.backward()

    inputXgradient = embeddings.grad * embeddings
    # One score per token: L2 norm over the hidden dimension.
    saliencies = torch.norm(inputXgradient, dim=-1).detach()
    model.zero_grad()

    sals = saliencies[0].numpy()
    if prev_sals is not None:
        print("Corr:", pearsonr(prev_sals, sals)[0])
    prev_sals = sals
    print_globenc(sals, tokens, prefix=name + " ")


### 24638 ###


Corr: 0.7877146852355619


Corr: 0.884317937354894


Corr: 0.941260779392162


Corr: 0.9842848461106477


In [62]:
# Input x gradient saliency of one example, compared across the checkpoints in MODELS.
# NOTE(review): this cell is repeated in this notebook for other `idx` values;
# consider factoring the loop into a function that takes `idx`.
# Uncomment to force the tokenizer / tokenized split to be rebuilt:
# tokenizer = None
# sel_dataset = None
idx = 192454
prev_sals = None  # saliencies of the previous checkpoint, for the correlation printout
print(f"### {idx} ###")
for name, path in MODELS.items():
    model = AutoModelForSequenceClassification.from_pretrained(path)
    # The tokenizer and the tokenized split are created once and then reused.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True, max_length=MAX_LENGTH)
    if sel_dataset is None:
        sel_dataset = dataset[SET].map(preprocess_function_wrapped(tokenizer), batched=True, batch_size=1024)
    dataset_size = len(sel_dataset)

    # Index the dataset once instead of once per field.
    sample = sel_dataset[idx]
    tokens, length = token_id_to_tokens_mapper(tokenizer, sample)
    inputs = {
        key: torch.tensor([sample[key]], dtype=torch.int32)
        for key in ('input_ids', 'attention_mask', 'token_type_ids')
    }
    labels = torch.tensor([sample['label']])
    output = model(**inputs, output_hidden_states=True)

    # hidden_states[0] is a non-leaf tensor (the embedding output, per the
    # transformers `output_hidden_states` docs), so its gradient must be
    # retained explicitly.
    embeddings = output.hidden_states[0]
    embeddings.retain_grad()
    logits = output.logits
    # Back-propagate the logit of the gold label only.
    target_class_l_sum = torch.gather(logits, 1, labels.unsqueeze(-1)).sum()
    target_class_l_sum.backward()

    inputXgradient = embeddings.grad * embeddings
    # One score per token: L2 norm over the hidden dimension.
    saliencies = torch.norm(inputXgradient, dim=-1).detach()
    model.zero_grad()

    sals = saliencies[0].numpy()
    if prev_sals is not None:
        print("Corr:", pearsonr(prev_sals, sals)[0])
    prev_sals = sals
    print_globenc(sals, tokens, prefix=name + " ")


### 192454 ###


Corr: 0.9374757943435119


Corr: 0.9692294978046911


Corr: 0.9869057326385663


Corr: 0.9977108173607421


In [6]:
### HTA
# Work in progress: runs the forward pass on DEVICE for one example per
# checkpoint; the saliency computation below is still disabled.
idx = 314093
prev_sals = None
for name, path in tqdm(MODELS.items(), desc="Models"):
    model = AutoModelForSequenceClassification.from_pretrained(path)
    model.to(DEVICE)
    # The tokenizer and the tokenized split are created once and then reused.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True, max_length=MAX_LENGTH)
    if sel_dataset is None:
        sel_dataset = dataset[SET].map(preprocess_function_wrapped(tokenizer), batched=True, batch_size=1024)
    dataset_size = len(sel_dataset)

    # Index the dataset once instead of once per field.
    sample = sel_dataset[idx]
    tokens, length = token_id_to_tokens_mapper(tokenizer, sample)
    inputs = {
        key: torch.tensor([sample[key]], dtype=torch.int32).to(DEVICE)
        for key in ('input_ids', 'attention_mask', 'token_type_ids')
    }
    labels = torch.tensor([sample['label']]).to(DEVICE)
    output = model(**inputs, output_hidden_states=True)

    # Keep the gradient of the first hidden state (a non-leaf tensor) for the
    # disabled saliency code below.
    output.hidden_states[0].retain_grad()
    # NOTE(review): when re-enabling the lines below with DEVICE == 'cuda',
    # `saliencies[0]` has to be moved to the CPU (`.cpu()`) before `.numpy()`.
    # logits = output.logits
    # target_class_l_sum = torch.gather(logits, 1, labels.unsqueeze(-1)).sum()
    # target_class_l_sum.backward()
    
    # inputXgradient = output.hidden_states[0].grad * output.hidden_states[0]
    # saliencies = torch.norm(inputXgradient, dim=-1).detach()
    # model.zero_grad()
    # if prev_sals is not None:
    #     print("Corr:", pearsonr(prev_sals, saliencies[0].numpy())[0])
    # prev_sals = saliencies[0].numpy()
    # print_globenc(saliencies[0].numpy(), tokens, prefix=name + " ")


Models:   0%|          | 0/5 [00:00<?, ?it/s]



  0%|          | 0/384 [00:00<?, ?ba/s]