In [None]:
# Environment setup. The pinned installs are left commented out; uncomment them if
# `datasets` / `transformers` are not already available in this environment.
# ! pip install datasets==1.18.3
# ! pip install transformers==4.18.0
# Clone the GlobEnc repository so that the `GlobEnc.src...` imports in the next cell resolve.
! git clone https://github.com/mohsenfayyaz/GlobEnc

In [None]:
import torch
import numpy as np
import datasets
import pickle
import pathlib
import os
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

from GlobEnc.src.modeling.modeling_bert import BertForSequenceClassification
from GlobEnc.src.modeling.modeling_electra import ElectraForSequenceClassification
from GlobEnc.src.attention_rollout import AttentionRollout

In [None]:
# Directory that receives one pickle file per (checkpoint, split).
ROOT_DIR = "./outputs_globenc"

# GLUE task and dataset split to process.
TASK = "mnli"  # sst2/mnli
SET = "train"  # train/validation/validation_matched

# Checkpoints to analyse, keyed by a short name used in the output file names.
# "mnli-e0" is the pre-trained model; "mnli-e<k>" points at "checkpoint-<12272 * k>"
# of the MNLI fine-tuning run.
# (The SST-2 run used the same layout: keys "sst2-e<k>", run directory
#  ".../output_sst2_bert-base-uncased_0001_SEED0042", checkpoints "checkpoint-<2105 * k>".)
MODELS = dict(
    [("mnli-e0", "bert-base-uncased")]
    + [
        (
            f"mnli-e{epoch}",
            "/home/modaresi/globenc_extension/outputs/"
            "output_mnli_bert-base-uncased_0001_SEED0042/"
            f"checkpoint-{12272 * epoch}",
        )
        for epoch in range(1, 6)
    ]
)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE

In [None]:
# Task names recognised by this notebook.
GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]

batch_size = 1

# The "mnli-mm" task is loaded through the "mnli" GLUE configuration.
if TASK == "mnli-mm":
    actual_task = "mnli"
else:
    actual_task = TASK
dataset = datasets.load_dataset("glue", actual_task)
metric = datasets.load_metric('glue', actual_task)

# For each task: the dataset column(s) holding the input text.
# The second entry is None for single-sentence tasks.
task_to_keys = dict(
    [
        ("cola", ("sentence", None)),
        ("mnli", ("premise", "hypothesis")),
        ("mnli-mm", ("premise", "hypothesis")),
        ("mrpc", ("sentence1", "sentence2")),
        ("qnli", ("question", "sentence")),
        ("qqp", ("question1", "question2")),
        ("rte", ("sentence1", "sentence2")),
        ("sst2", ("sentence", None)),
        ("stsb", ("sentence1", "sentence2")),
        ("wnli", ("sentence1", "sentence2")),
    ]
)
SENTENCE1_KEY = task_to_keys[TASK][0]
SENTENCE2_KEY = task_to_keys[TASK][1]
dataset

In [None]:
def extract_globenc(model, tokenizer, data):
    """Compute the GlobEnc attribution for a single dataset example.

    Args:
        model: Sequence classifier from the GlobEnc modeling code. It is called with
            ``output_norms=True`` and ``return_dict=False`` and is expected to return
            ``(logits, norms)``, where ``norms`` is indexable per encoder layer.
        tokenizer: Tokenizer matching ``model``.
        data: Mapping holding the input text under ``SENTENCE1_KEY`` (and under
            ``SENTENCE2_KEY`` for sentence-pair tasks).

    Returns:
        A ``(globenc, tokens)`` tuple: the result of ``AttentionRollout.compute_flows``
        converted to a NumPy array, and the tokens of the encoded input.
    """
    if SENTENCE2_KEY is None:
        text_args = (data[SENTENCE1_KEY],)
    else:
        text_args = (data[SENTENCE1_KEY], data[SENTENCE2_KEY])
    tokenized_sentence = tokenizer.encode_plus(*text_args, return_tensors="pt")
    with torch.no_grad():
        tokenized_sentence = tokenized_sentence.to(DEVICE)
        model.to(DEVICE)
        logits, norms = model(**tokenized_sentence, output_attentions=False, output_norms=True, return_dict=False)
    # Take the layer count from the model instead of assuming a 12-layer encoder,
    # so that checkpoints with a different depth are handled as well.
    num_layers = model.config.num_hidden_layers
    # Entry 4 of each layer's norms is the one fed to the rollout.
    # squeeze() drops the size-1 batch axis of the single example.
    norm_nenc = torch.stack([norms[i][4] for i in range(num_layers)]).squeeze().cpu().numpy()
    globenc = AttentionRollout().compute_flows([norm_nenc], output_hidden_states=True, disable_tqdm=True)[0]
    globenc = np.array(globenc)
    tokens = tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][0])
    return globenc, tokens

def save_pickle(obj, path):
    """Pickle ``obj`` to ``path``, creating missing parent directories first.

    The highest available pickle protocol is used, and a confirmation line is
    printed once the file has been written.
    """
    target_dir = pathlib.Path(os.path.dirname(path))
    target_dir.mkdir(parents=True, exist_ok=True)
    with open(path, mode='wb') as out_file:
        pickler = pickle.Pickler(out_file, protocol=pickle.HIGHEST_PROTOCOL)
        pickler.dump(obj)
    print(f"Saved {path}")

In [None]:
# Per-example extraction: for every checkpoint in MODELS, run each example of the
# selected split through `extract_globenc` and write all results to one pickle file.
for name, path in tqdm(MODELS.items(), desc="Models"):
    model = BertForSequenceClassification.from_pretrained(path)
    tokenizer = AutoTokenizer.from_pretrained(path)
    examples = dataset[SET]
    all_globencs, all_tokens = [], []
    for data in tqdm(examples, total=len(examples)):
        globenc, tokens = extract_globenc(model, tokenizer, data)
        all_globencs.append(globenc)
        all_tokens.append(tokens)
    globencs = {"globenc": all_globencs, "tokens": all_tokens}
    output_path = f"{ROOT_DIR}/{name}_{SET}.pickle"
    save_pickle(globencs, output_path)

# Batched GlobEnc

In [None]:
def extract_globenc_batch(model, tokenizer, batch):
    """Compute GlobEnc attributions for a batch of examples in one forward pass.

    Args:
        model: Sequence classifier from the GlobEnc modeling code. It is called with
            ``output_globenc=True`` and ``return_dict=False`` and is expected to
            return ``(logits, globenc)``.
        tokenizer: Tokenizer matching ``model``; it is called with ``padding=True``.
        batch: Mapping from column name to a list of texts, read through
            ``SENTENCE1_KEY`` (and ``SENTENCE2_KEY`` for sentence-pair tasks).

    Returns:
        A ``(globencs, tokens)`` tuple of two lists with one entry per example:
        the GlobEnc array cropped to the example's number of attended tokens, and
        the example's tokens cropped to the same length.
        NOTE(review): the cropping assumes padding is appended on the right — confirm
        for the tokenizer in use.
    """
    if SENTENCE2_KEY is None:
        tokenized_sentences = tokenizer(batch[SENTENCE1_KEY], return_tensors="pt", padding=True)
    else:
        tokenized_sentences = tokenizer(batch[SENTENCE1_KEY], batch[SENTENCE2_KEY], return_tensors="pt", padding=True)
    tokenized_sentences = tokenized_sentences.to(DEVICE)
    with torch.no_grad():
        model.to(DEVICE)
        _, batch_globenc = model(**tokenized_sentences, output_attentions=False, output_norms=False, return_dict=False, output_globenc=True)

    # Number of attended (non-padding) tokens per example, converted once to plain
    # Python ints so they can be used directly as slice bounds below.
    tokenized_len = torch.sum(tokenized_sentences['attention_mask'], dim=-1).tolist()

    globenc = batch_globenc.squeeze().cpu().numpy()
    if len(tokenized_len) == 1:
        # With a single example in the batch, squeeze() also removed the batch axis;
        # put it back so the per-example indexing below keeps working (this happens
        # for the last batch whenever the split size is not a multiple of the batch size).
        globenc = globenc[np.newaxis, ...]

    globencs, tokens = [], []
    for idx in range(len(globenc)):
        length = tokenized_len[idx]
        globencs.append(np.array(globenc[idx, :length, :length]))
        tokens.append(tokenizer.convert_ids_to_tokens(tokenized_sentences["input_ids"][idx])[:length])
    return globencs, tokens

In [None]:
# Batched extraction: for every checkpoint in MODELS, feed the selected split to
# `extract_globenc_batch` in mini-batches and write all results to one pickle file.
# The output file names match those of the per-example loop, so running both
# overwrites the earlier files.
for checkpoint_name, path in tqdm(MODELS.items(), desc="Models"):
    model = BertForSequenceClassification.from_pretrained(path)
    # NOTE(review): `max_length=128` is given at load time, but the tokenizer call in
    # `extract_globenc_batch` does not request truncation — confirm whether inputs
    # were meant to be truncated to 128 tokens.
    tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True, max_length=128)
    dataloader = DataLoader(dataset[SET], batch_size=4)
    collected_globencs, collected_tokens = [], []
    for batch in tqdm(dataloader, total=len(dataloader)):
        globenc, tokens = extract_globenc_batch(model, tokenizer, batch)
        collected_globencs += globenc
        collected_tokens += tokens
    globencs = {"globenc": collected_globencs, "tokens": collected_tokens}
    output_path = f"{ROOT_DIR}/{checkpoint_name}_{SET}.pickle"
    save_pickle(globencs, output_path)

In [None]:
# Print the output of `nvidia-smi` to inspect GPU usage after the extraction runs.
! nvidia-smi