In [None]:
import torch
import numpy as np
import datasets
import pickle
import pathlib
import os
from importlib import reload
import sys
import pandas as pd
from IPython.display import display, HTML
import json
import datetime

from tqdm.auto import tqdm
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from transformers.trainer_pt_utils import LengthGroupedSampler

# Make the parent directory importable so the local `GlobEnc` package (imported below) resolves.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)
print("sys.path:", sys.path)

from GlobEnc.src.modeling.globenc_utils import GlobencConfig
from GlobEnc.src.modeling.modeling_bert_v3 import BertForSequenceClassification
from GlobEnc.src.modeling.modeling_roberta import RobertaForSequenceClassification

# Print the GPU status reported by nvidia-smi (IPython shell escape, not plain Python).
! nvidia-smi

# Configs

In [None]:
# Where results are written and where fine-tuned checkpoints live. The defaults are
# machine-specific absolute paths; set GLOBENC_OUTPUT_DIR / GLOBENC_MODELS_DIR to run elsewhere.
OUTPUT_DIR = os.environ.get("GLOBENC_OUTPUT_DIR", "/home/modaresi/projects/globenc_analysis/outputs/globencs_v3")
MODELS_DIR = os.environ.get("GLOBENC_MODELS_DIR", "/home/modaresi/projects/globenc_analysis/outputs/models")

# (model checkpoint, task name, dataset split) triples to process; uncomment entries to enable them.
MODEL_DATASET_SET = [
### BERT: evaluation splits
#     (f"{MODELS_DIR}/output_sst2_bert-base-uncased_0001_SEED0042/checkpoint-10525", "sst2", "validation"),
#     (f"{MODELS_DIR}/output_hatexplain_bert-base-uncased_0001_SEED0042/checkpoint-2405", "hatexplain", "validation"),
#     (f"{MODELS_DIR}/output_qnli_bert-base-uncased_0001_SEED0042/checkpoint-16370", "qnli", "validation"),
#     (f"{MODELS_DIR}/output_mnli_bert-base-uncased_0001_SEED0042/checkpoint-61360", "mnli", "validation_matched"),
#     (f"{MODELS_DIR}/output_cola_bert-base-uncased_0001_SEED0042/checkpoint-1340", "cola", "validation"),
#     (f"{MODELS_DIR}/output_mrpc_bert-base-uncased_0001_SEED0042/checkpoint-575", "mrpc", "validation"),
#     (f"{MODELS_DIR}/output_sst2_bert-large-uncased_0001_SEED0042/checkpoint-10525", "sst2", "validation"),

### TRAINING MAPS
#     (f"{MODELS_DIR}/output_sst2_bert-base-uncased_0001_SEED0042/checkpoint-10525", "sst2", "train"),
#     (f"{MODELS_DIR}/output_hatexplain_bert-base-uncased_0001_SEED0042/checkpoint-2405", "hatexplain", "train"),
#     (f"{MODELS_DIR}/output_qnli_bert-base-uncased_0001_SEED0042/checkpoint-16370", "qnli", "train"),
#     (f"{MODELS_DIR}/output_mnli_bert-base-uncased_0001_SEED0042/checkpoint-61360", "mnli", "train"),
#     (f"{MODELS_DIR}/output_cola_bert-base-uncased_0001_SEED0042/checkpoint-1340", "cola", "train"),
#     (f"{MODELS_DIR}/output_mrpc_bert-base-uncased_0001_SEED0042/checkpoint-575", "mrpc", "train"),

### RoBERTa (model IDs passed directly to from_pretrained)
#     ("WillHeld/roberta-base-sst2", "sst2", "validation"),
    ("WillHeld/roberta-base-mnli", "mnli", "validation_matched"),
]

def _globenc_config_kwargs(aggregation: str, **overrides) -> dict:
    """Build the keyword arguments for one ``GlobencConfig`` variant.

    All variants in this notebook share the same LN and output settings and differ only
    in a few knobs, so the full argument list is written once here.

    Args:
        aggregation: ``"rollout"`` (GlobEnc variants) or ``"vector"`` (Decomposition
            variants). ``"vector"`` also enables ``include_classifier_w_pooler`` and the
            pooler/classifier outputs; any other value leaves them disabled.
        **overrides: values replacing the defaults below (``include_biases``,
            ``bias_decomp_type``, ``include_FFN``, ``FFN_approx_type``,
            ``tanh_approx_type``, ...) or extra ``GlobencConfig`` arguments such as
            ``include_bias_token``.

    Returns:
        A dict to be splatted into ``GlobencConfig(**kwargs)``.
    """
    is_vector = aggregation == "vector"
    kwargs = dict(
        include_biases=False,
        bias_decomp_type="absdot",
        include_LN1=True,
        include_FFN=False,
        FFN_approx_type="GeLU_ZO",
        include_LN2=True,
        aggregation=aggregation,
        include_classifier_w_pooler=is_vector,
        tanh_approx_type="ZO",
        output_all_layers=True,
        output_attention=None,
        output_res1=None,
        output_LN1=None,
        output_FFN=None,
        output_res2=None,
        output_encoder=None,
        output_aggregated="norm",
        output_pooler="norm" if is_vector else None,
        output_classifier=is_vector,
    )
    kwargs.update(overrides)
    return kwargs


# name -> (aggregation, overrides for _globenc_config_kwargs). Uncomment a line to enable a variant.
_GLOBENC_VARIANTS = {
    # FFN_approx_type="GeLU_LA" kept exactly as in the original "GlobEnc" config.
    "GlobEnc": ("rollout", dict(include_biases=False, include_FFN=False, FFN_approx_type="GeLU_LA")),
    # "GlobEnc AbsDot Bias": ("rollout", dict(include_biases=True, bias_decomp_type="absdot")),
    # "GlobEnc AbsSim Bias": ("rollout", dict(include_biases=True, bias_decomp_type="abssim")),
    # "GlobEnc Equal Bias": ("rollout", dict(include_biases=True, bias_decomp_type="equal")),
    # "GlobEnc Norm Bias": ("rollout", dict(include_biases=True, bias_decomp_type="norm")),
    # "GlobEnc No Bias FFN": ("rollout", dict(include_biases=False, bias_decomp_type="absdot", include_FFN=True)),
    # "GlobEnc AbsDot Bias FFN": ("rollout", dict(include_biases=True, bias_decomp_type="absdot", include_FFN=True)),
    # "GlobEnc Equal Bias FFN": ("rollout", dict(include_biases=True, bias_decomp_type="equal", include_FFN=True)),
    # "GlobEnc AbsSim Bias FFN": ("rollout", dict(include_biases=True, bias_decomp_type="abssim", include_FFN=True)),
    # NOTE: the long-form listing had "abssim" here, duplicating "AbsSim Bias FFN"; "norm" matches the name.
    # "GlobEnc Norm Bias FFN": ("rollout", dict(include_biases=True, bias_decomp_type="norm", include_FFN=True)),
    # "Decomposition No Bias": ("vector", dict(include_biases=False, bias_decomp_type="absdot", include_FFN=True)),
    "Decomposition AbsDot Bias": ("vector", dict(include_biases=True, bias_decomp_type="absdot", include_FFN=True)),
    # "Decomposition AbsDot Bias No FFN": ("vector", dict(include_biases=True, bias_decomp_type="absdot", include_FFN=False)),
    # "Decomposition No Bias No FFN": ("vector", dict(include_biases=False, bias_decomp_type="absdot", include_FFN=False)),
    # "Decomposition AbsDot Bias ReLU": ("vector", dict(include_biases=True, bias_decomp_type="absdot", include_FFN=True, FFN_approx_type="ReLU")),
    # "Decomposition Equal Bias": ("vector", dict(include_biases=True, bias_decomp_type="equal", include_FFN=True)),
    # "Decomposition Norm Bias": ("vector", dict(include_biases=True, bias_decomp_type="norm", include_FFN=True)),
    # "Decomposition AbsSim Bias": ("vector", dict(include_biases=True, bias_decomp_type="abssim", include_FFN=True)),
    # "Decomposition Norm Bias Token": ("vector", dict(include_biases=True, bias_decomp_type="norm", include_bias_token=True, include_FFN=True)),
    # "GlobEnc FFN LinearApproximation": ("rollout", dict(include_biases=True, bias_decomp_type="absdot", include_FFN=True, FFN_approx_type="GeLU_LA", tanh_approx_type="LA")),
    # "Decomposition LinearApproximation": ("vector", dict(include_biases=True, bias_decomp_type="absdot", include_FFN=True, FFN_approx_type="GeLU_LA", tanh_approx_type="LA")),
}

GLOBENC_CONFIGS = {
    name: GlobencConfig(**_globenc_config_kwargs(aggregation, **overrides))
    for name, (aggregation, overrides) in _GLOBENC_VARIANTS.items()
}

# Optional extra outputs; when False, the matching DataFrame column ("cls" /
# "importance_all_layers_aggregated") is filled with None.
SAVE_CLS, SAVE_ALL_LAYERS = False, False  # SAVE_ALL_LAYERS: 12 layers + last layer
GLOBENC_CONFIGS.keys()

# Retrieve and Save
Output file names are determined automatically from the configs:
```
file_name = f"[{task_name}]_[{set_of_data}]_[{'-'.join(model_checkpoint.split('/')[-2:])}]_[{globenc_cfg_name}]"
```

In [None]:
class CheckpointToGlobenc:
    """Compute GlobEnc / decomposition importance maps for one checkpoint on one dataset split.

    ``retrieve()`` tokenizes ``set_of_data`` of ``task_name``, runs the model with
    ``globenc_config`` and returns one DataFrame row per example, in dataset order.
    """

    # Tasks loaded from the GLUE benchmark; "mnli-mm" reuses the "mnli" dataset.
    GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]
    # Text column(s) fed to the tokenizer per task (second entry None for single-text tasks).
    TASK_TO_KEYS = {
        "cola": ("sentence", None),
        "mnli": ("premise", "hypothesis"),
        "mnli-mm": ("premise", "hypothesis"),
        "mrpc": ("sentence1", "sentence2"),
        "qnli": ("question", "sentence"),
        "qqp": ("question1", "question2"),
        "rte": ("sentence1", "sentence2"),
        "sst2": ("sentence", None),
        "stsb": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
        "hatexplain": ("text", None),
    }
    BATCH_SIZE = 8
    MAX_LENGTH = 128  # tokenizer truncation length

    def __init__(
        self,
        model_checkpoint: str,
        globenc_config: "GlobencConfig",
        task_name: str,
        set_of_data: str,
        save_cls: bool = False,
        save_all_layers: bool = False
    ) -> None:
        """Store the run parameters; no model or data is loaded until ``retrieve()``.

        Args:
            model_checkpoint: path or model ID; RoBERTa is used when it contains
                "roberta", otherwise BERT.
            globenc_config: configuration passed to the model's forward call.
            task_name: a name in ``GLUE_TASKS`` or "hatexplain".
            set_of_data: dataset split to process, e.g. "validation".
            save_cls: also store the last-layer hidden state of the first token.
            save_all_layers: also store the aggregated map of every layer.
        """
        self.model_checkpoint = model_checkpoint
        self.globenc_config = globenc_config
        self.task_name = task_name
        self.set_of_data = set_of_data
        self.save_cls = save_cls
        self.save_all_layers = save_all_layers

    @staticmethod
    def _aggregate_hatexplain(example):
        """Collapse one hatexplain example: most frequent annotator label (ties arbitrary), tokens joined into "text"."""
        annotator_labels = example["annotators"]["label"]
        example["label"] = max(set(annotator_labels), key=annotator_labels.count)
        example["text"] = " ".join(example["post_tokens"])
        return example

    @staticmethod
    def _stack_per_example(tensors, batch_size: int, trailing_ndim: int, layered: bool = False) -> np.ndarray:
        """Stack model outputs into a numpy array whose leading axis is the batch.

        Each item of ``tensors`` is moved to CPU. The result is reshaped to
        ``(batch_size, *trailing)``, where ``trailing`` are the last ``trailing_ndim``
        axes. With ``layered=True`` the items are treated as per-layer outputs and the
        result is ``(batch_size, n_layers, *trailing)``. Unlike ``.squeeze()``, this keeps
        the batch axis for single-example batches and keeps size-1 trailing axes.
        """
        stacked = np.array([t.cpu().numpy() for t in tensors])
        trailing = stacked.shape[stacked.ndim - trailing_ndim:]
        if layered:
            return stacked.reshape(-1, batch_size, *trailing).swapaxes(0, 1)
        return stacked.reshape(batch_size, *trailing)

    def _load_split(self):
        """Load the ``set_of_data`` split of ``task_name`` from the local datasets cache only."""
        if self.task_name in self.GLUE_TASKS:
            actual_task = "mnli" if self.task_name == "mnli-mm" else self.task_name
            dataset = datasets.load_dataset("glue", actual_task, download_config=datasets.DownloadConfig(local_files_only=True))
        else:  # "hatexplain" (validated by retrieve())
            dataset = datasets.load_dataset("hatexplain", download_config=datasets.DownloadConfig(local_files_only=True)).map(self._aggregate_hatexplain)
            # Positional "idx" column, used by retrieve() to undo the sampler's shuffling.
            for part in ["train", "validation", "test"]:
                dataset[part] = dataset[part].add_column("idx", list(range(len(dataset[part]))))
        return dataset[self.set_of_data]

    def retrieve(self) -> pd.DataFrame:
        """Run the model over the whole split and collect per-example outputs.

        Returns:
            DataFrame with columns ``tokens``, ``importance_last_layer_aggregated``,
            ``importance_last_layer_pooler``, ``importance_last_layer_classifier``,
            ``importance_all_layers_aggregated``, ``cls``, ``logits`` and ``label``, one row
            per example in dataset order. Importance arrays are trimmed to each example's
            unpadded length; outputs that were not produced or not requested are ``None``.
            ``df.attrs["time"]`` is the forward-pass wall time, ``"time/example"`` its mean.

        Raises:
            ValueError: if ``task_name`` is neither a GLUE task nor "hatexplain".
        """
        if self.task_name not in self.GLUE_TASKS and self.task_name != "hatexplain":
            raise ValueError(f"Unsupported task_name: {self.task_name!r}")
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        is_roberta = "roberta" in self.model_checkpoint
        max_length = self.MAX_LENGTH
        batch_size_cfg = self.BATCH_SIZE

        ### DATASET ###
        raw_split = self._load_split()

        ### MODEL ###
        if is_roberta:
            model = RobertaForSequenceClassification.from_pretrained(self.model_checkpoint)
        else:
            model = BertForSequenceClassification.from_pretrained(self.model_checkpoint)
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(self.model_checkpoint, use_fast=True, max_length=max_length)

        ### TOKENIZE ###
        sentence1_key, sentence2_key = self.TASK_TO_KEYS[self.task_name]

        def preprocess_function(examples):
            # Only truncate here; padding is done per batch by the collator.
            args = (
                (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
            )
            return tokenizer(*args, padding=False, max_length=max_length, truncation=True)

        sel_dataset = raw_split.map(preprocess_function, batched=True, batch_size=1024)
        dataset_size = len(sel_dataset)
        steps = int(np.ceil(dataset_size / batch_size_cfg))

        tokens, lengths = [], []
        for input_ids in tqdm(sel_dataset["input_ids"], desc="Tokenize"):
            tokens.append(tokenizer.convert_ids_to_tokens(input_ids))
            lengths.append(len(input_ids))

        ### DATALOADER ###
        # Batches group examples of similar length; the resulting random order is
        # undone below using the "idx" column.
        generator = torch.Generator()
        generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
        sampler = LengthGroupedSampler(
            batch_size_cfg,
            lengths=lengths,
            model_input_name=tokenizer.model_input_names[0],
            generator=generator,
        )
        collator = DataCollatorWithPadding(tokenizer=tokenizer)

        sel_dataset = sel_dataset.add_column("length", lengths)
        cols = ["input_ids", "attention_mask", "length", "idx"]
        if not is_roberta:
            cols.append("token_type_ids")
        sel_dataset.set_format(type="torch", columns=cols)

        dataloader = DataLoader(
            sel_dataset,
            batch_size=batch_size_cfg,
            sampler=sampler,
            collate_fn=collator
        )
        model.to(device)

        ### FORWARD ###
        idxes = []
        shuffled_data = {
            "importance_last_layer_aggregated": [],
            "importance_last_layer_pooler": [],
            "importance_last_layer_classifier": [],
            "importance_all_layers_aggregated": [],
            "cls": [],
            "logits": [],
        }
        first_time = datetime.datetime.now()
        with torch.no_grad():
            for batch in tqdm(dataloader, total=steps, desc="Forward"):
                input_batch = {k: v.to(device) for k, v in batch.items() if k not in ("idx", "length")}
                logits, hidden_states, globenc_last_layer_outputs, globenc_all_layers_outputs = model(
                    **input_batch,
                    output_attentions=False,
                    return_dict=False,
                    output_hidden_states=True,
                    globenc_config=self.globenc_config
                )
                # Example shapes (batch of 8, padded length 55):
                # last_layer.aggregated ~ (1, 8, 55, 55); .pooler ~ (1, 8, 55); .classifier ~ (8, 55, 2)
                # all_layers.aggregated ~ (12, 8, 55, 55); logits ~ (8, 2); hidden_states ~ (13, 8, 55, 768)
                batch_lengths = batch["length"].tolist()
                batch_size = len(batch_lengths)
                idxes.extend(batch["idx"].tolist())

                shuffled_data["logits"].extend(logits.cpu().numpy())

                aggregated = self._stack_per_example(globenc_last_layer_outputs.aggregated, batch_size, trailing_ndim=2)
                shuffled_data["importance_last_layer_aggregated"].extend(
                    a[:n, :n] for a, n in zip(aggregated, batch_lengths)
                )

                if globenc_last_layer_outputs.pooler is not None:
                    pooler = self._stack_per_example(globenc_last_layer_outputs.pooler, batch_size, trailing_ndim=1)
                    shuffled_data["importance_last_layer_pooler"].extend(
                        p[:n] for p, n in zip(pooler, batch_lengths)
                    )

                if globenc_last_layer_outputs.classifier is not None:
                    classifier = self._stack_per_example(globenc_last_layer_outputs.classifier, batch_size, trailing_ndim=2)
                    shuffled_data["importance_last_layer_classifier"].extend(
                        c[:n, :] for c, n in zip(classifier, batch_lengths)
                    )

                if self.save_all_layers:
                    all_layers = self._stack_per_example(
                        globenc_all_layers_outputs.aggregated, batch_size, trailing_ndim=2, layered=True
                    )  # (batch, layers, seq_len, seq_len)
                    shuffled_data["importance_all_layers_aggregated"].extend(
                        a[:, :n, :n] for a, n in zip(all_layers, batch_lengths)
                    )

                if self.save_cls:
                    # Last layer, first token only -> (batch, hidden)
                    shuffled_data["cls"].extend(hidden_states[-1][:, 0, :].cpu().numpy())
        later_time = datetime.datetime.now()

        ### REORDER ###
        # Restore dataset order by sorting on "idx" (assumes idx increases with dataset
        # position — true for the hatexplain idx added above; TODO confirm for new tasks).
        order = np.argsort(idxes)
        final_data = {"tokens": tokens}
        for key, values in shuffled_data.items():
            if not values:  # output not produced or not requested
                values = [None] * dataset_size
            final_data[key] = [values[j] for j in order]
        final_data["label"] = sel_dataset["label"]

        df = pd.DataFrame(final_data)
        df.attrs["time"] = later_time - first_time
        df.attrs["time/example"] = df.attrs["time"] / len(sel_dataset)
        return df

In [None]:
# Compute and pickle one DataFrame per (checkpoint, task, split) and GlobEnc config.
# Errors propagate so a failing run stops the loop instead of being silently skipped.
for model_checkpoint, task_name, set_of_data in tqdm(MODEL_DATASET_SET, desc="Models_Dataset_Sets"):
    for globenc_cfg_name, globenc_cfg in tqdm(GLOBENC_CONFIGS.items(), desc="Globenc Configs"):
        # Release unused cached GPU memory from the previous run before loading the next model.
        torch.cuda.empty_cache()
        file_name = f"[{task_name}]_[{set_of_data}]_[{'-'.join(model_checkpoint.split('/')[-2:])}]_[{globenc_cfg_name}]"
        output_path = f"{OUTPUT_DIR}/{file_name}.pkl"
        print(f"### {file_name}")
        ctg = CheckpointToGlobenc(
            model_checkpoint=model_checkpoint,
            globenc_config=globenc_cfg,
            task_name=task_name,
            set_of_data=set_of_data,
            save_cls=SAVE_CLS,
            save_all_layers=SAVE_ALL_LAYERS,
        )
        df = ctg.retrieve()
        display(df.head(1))
        # Run metadata, stored next to the timing attrs set by retrieve().
        df.attrs.update({
            "model_checkpoint": model_checkpoint,
            "task_name": task_name,
            "set_of_data": set_of_data,
            "save_cls": SAVE_CLS,
            "save_all_layers": SAVE_ALL_LAYERS,
            "globenc_config_name": globenc_cfg_name,
            "globenc_config": globenc_cfg.__dict__,
        })
        pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        df.to_pickle(output_path)  # Can add zip/bz2/...
        print(output_path)
        print(df.attrs)

In [None]:
# Scratch: load the locally cached GLUE SST-2 dataset for inspection.
ds = datasets.load_dataset(
    "glue",
    "sst2",
    download_config=datasets.DownloadConfig(local_files_only=True),
)

In [None]:
# Display the SST-2 train split.
ds["train"]

In [None]:
# Disabled reference snippet: standalone hatexplain preprocessing (most frequent annotator
# label, post tokens joined into "text"), the same aggregation CheckpointToGlobenc applies.
# from datasets import load_dataset
# from sklearn.metrics import accuracy_score
# from transformers import EvalPrediction

# def mode(lst):
#     return max(set(lst), key=lst.count)

# def update_data(example):
#     example["label"] = mode(example["annotators"]["label"])
#     example["text"] = " ".join(example["post_tokens"])
#     return example

# ds = load_dataset("hatexplain").map(update_data)

In [None]:
# Scratch: (re)load the locally cached GLUE SST-2 dataset — same call as the earlier scratch cell.
ds = datasets.load_dataset(
    "glue", "sst2",
    download_config=datasets.DownloadConfig(local_files_only=True),
)

In [None]:
# Display the loaded dataset splits.
ds