<a href="https://colab.research.google.com/github/mohsenfayyaz/edge-probe/blob/main/Edge_Probing_Full.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installations & Imports

In [None]:
! nproc
! lscpu
! nvidia-smi

In [None]:
from pydrive.auth import GoogleAuth
from google.colab import auth

# Authenticate and create the PyDrive client.
# auth.authenticate_user()
! gdown --id ***
! tar -xzf ontonotes_data.tar.gz

In [None]:
! git clone https://github.com/mohsenfayyaz/edge-probing-datasets.git
! pip install datasets
! pip install transformers
! pip install sentencepiece
# ! pip install wandb

In [None]:
from tqdm.notebook import tqdm
import pandas as pd
from IPython.display import display
import torch
import numpy as np
import shutil
import os
import datasets
import json
import gc
import datetime
import torch.nn as nn
from abc import ABC, abstractmethod
import torch.optim as optim
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from scipy.special import softmax
from sklearn.metrics import f1_score
import psutil  # RAM usage
# import wandb
# wandb.init()
print(torch.__version__)

# Configs

In [None]:
class Dataset_info:
    """Configuration describing which edge-probing task to load.

    Args:
        dataset_name: Task key understood by ``Dataset_handler`` (e.g. "ud",
            "ner", "srl", "coref", "dpr", "semeval").
        num_of_spans: Number of target spans per example (1 or 2).
        max_span_length: Stored as ``self.max_span_length``; not read by the
            code visible in this notebook.
        ignore_classes: Labels whose examples are dropped when loading
            (e.g. ``["Other"]`` for SemEval). Defaults to no filtering.
    """

    def __init__(self, dataset_name, num_of_spans, max_span_length=5, ignore_classes=None):
        self.dataset_name = dataset_name
        self.num_of_spans = num_of_spans
        self.max_span_length = max_span_length
        # A fresh list per instance: a `[]` default would be shared by every instance.
        self.ignore_classes = [] if ignore_classes is None else ignore_classes  # e.g. ignore "Other" in rel (semeval)

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration.
# Leave exactly one `model_checkpoint` line and one `my_dataset_info` line
# uncommented; if several are active, the last assignment wins.
# `model_checkpoint` is a Hugging Face model id loaded with AutoTokenizer /
# AutoModel in the next cell.
# ---------------------------------------------------------------------------
# model_checkpoint = 'xlnet-base-cased'
# model_checkpoint = "xlnet-large-cased"
# model_checkpoint = "distilbert-base-cased"
# model_checkpoint = "bert-base-cased"
model_checkpoint = "bert-base-uncased"
# model_checkpoint = "bert-large-uncased"
# model_checkpoint = "albert-base-v2"
# model_checkpoint = "albert-large-v2"
# model_checkpoint = "albert-xxlarge-v2"
# model_checkpoint = "t5-small"
# model_checkpoint = "t5-large"
# model_checkpoint = "roberta-large"
# model_checkpoint = "google/electra-large-discriminator"


# Task-specific checkpoints (names suggest fine-tuned models).
# model_checkpoint = "mohsenfayyaz/toxicity-classifier"
# model_checkpoint = "mohsenfayyaz/bert-base-uncased-toxicity"
# model_checkpoint = "textattack/xlnet-base-cased-SST-2"

# model_checkpoint = "mrm8488/albert-base-v2-finetuned-mnli-pabee"


# Probing task: `num_of_spans` must match the task's target format
# (1 for single-span tasks like NER, 2 for relations between spans).
my_dataset_info = Dataset_info("ud", num_of_spans=2)  # Dependency Labeling
# my_dataset_info = Dataset_info("ner", num_of_spans=1)  # Named Entity Labeling
# my_dataset_info = Dataset_info("srl", num_of_spans=2)  # Semantic Role Labeling
# my_dataset_info = Dataset_info("coref", num_of_spans=2)  # Coreference Ontonotes
# my_dataset_info = Dataset_info("dpr", num_of_spans=2)  # Coreference Winograd
# my_dataset_info = Dataset_info("semeval", num_of_spans=2, ignore_classes=["Other"])  # Relation Classification



# Span pooling name understood by get_span_module; presumably passed as
# `pool_method` to Edge_probe_trainer in a later cell — TODO confirm.
POOL_METHOD = "attn"  # 'max', 'attn'
# NOTE(review): presumably the batch size passed to Edge_probe_trainer.train — confirm.
BATCH_SIZE = 32
DEVICE = 'cuda' if torch.cuda.is_available() else "cpu"

print(DEVICE)

# Prepare Dataset & Spans

In [None]:
from transformers import AutoTokenizer, AutoModel
  
# add_prefix_space=True: byte-level BPE tokenizers (e.g. RoBERTa/GPT-2) need it
# when called with is_split_into_words=True, as tokenize_and_one_hot does.
# NOTE(review): confirm the kwarg is accepted when switching to a new checkpoint family.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

# Base model (no task head); presumably handed to Edge_probe_trainer, which
# enables output_hidden_states and freezes it.
model = AutoModel.from_pretrained(model_checkpoint)

In [None]:
# model.save_pretrained(model_checkpoint)
# tokenizer.save_pretrained(model_checkpoint)

In [None]:
class Utils:
    """Small stateless helpers used by the data pipeline."""

    @staticmethod
    def one_hot(idx, length):
        """Return an int8 vector of size ``length`` with a 1 at ``idx``.

        ``idx`` may be anything NumPy accepts as an index (e.g. an int or a
        list of ints), in which case every indexed position is set to 1.
        Declared ``@staticmethod`` so it works on both the class and instances.
        """
        vec = np.zeros(length, dtype=np.int8)
        vec[idx] = 1
        return vec

In [None]:
class Dataset_handler:
    """Load an edge-probing task from JSON-lines files and tokenize it.

    After construction:
        * ``self.dataset`` is a ``datasets.DatasetDict`` with "train", "dev"
          and "test" splits holding the raw columns "text", "span1",
          ("span2",) and "label".
        * ``self.tokenized_dataset`` is that DatasetDict mapped through
          ``tokenize_and_one_hot`` (tokenizer outputs, token-level spans,
          span lengths and a one-hot label).
        * ``self.labels_list`` / ``self.label_to_index`` hold the label
          vocabulary collected over all three splits (sorted, so indices are
          reproducible between runs).

    Raises:
        ValueError: if ``dataset_info.dataset_name`` is not a known task.
    """

    # Per task: split name -> (JSON-lines path, fraction of rows to keep).
    _TASK_SPLITS = {
        "dpr": {
            "train": ('./edge-probing-datasets/data/dpr_data/train.json', 1),
            "dev": ('./edge-probing-datasets/data/dpr_data/dev.json', 1),
            "test": ('./edge-probing-datasets/data/dpr_data/test.json', 1),
        },
        "ud": {
            "train": ('./edge-probing-datasets/data/ud_data/en_ewt-ud-train.json', 1),
            "dev": ('./edge-probing-datasets/data/ud_data/en_ewt-ud-dev.json', 1),
            "test": ('./edge-probing-datasets/data/ud_data/en_ewt-ud-test.json', 1),
        },
        "semeval": {
            "train": ('./edge-probing-datasets/data/semeval_data/train.all.json', 1),
            # No separate dev file: dev is a 1% sample of the test file.
            "dev": ('./edge-probing-datasets/data/semeval_data/test.json', 0.01),
            "test": ('./edge-probing-datasets/data/semeval_data/test.json', 1),
        },
        "srl": {
            "train": ('./ontonotes_data/srl/train.json', 1),
            "dev": ('./ontonotes_data/srl/test.json', 1),
            "test": ('./ontonotes_data/srl/conll-2012-test.json', 1),
        },
        "ner": {
            "train": ('./ontonotes_data/ner/train.json', 1),
            "dev": ('./ontonotes_data/ner/conll-2012-test.json', 1),
            "test": ('./ontonotes_data/ner/test.json', 1),
        },
        "coref": {
            "train": ('./ontonotes_data/coref/train.json', 1),
            "dev": ('./ontonotes_data/coref/development.json', 1),
            "test": ('./ontonotes_data/coref/test.json', 1),
        },
    }

    def __init__(self, dataset_info: Dataset_info):
        self.dataset = datasets.DatasetDict()
        self.tokenized_dataset = None
        self.dataset_info = dataset_info
        self.labels_list = None

        splits = self._TASK_SPLITS.get(dataset_info.dataset_name)
        if splits is None:
            raise ValueError(f"Error: Unknown dataset name {dataset_info.dataset_name!r}; "
                             f"expected one of {sorted(self._TASK_SPLITS)}")
        for data_type, (json_path, fraction) in splits.items():
            self.json_to_dataset(json_path, data_type=data_type, fraction=fraction,
                                 ignore_classes=dataset_info.ignore_classes)

        print("⌛ Tokenizing Dataset and Adding One Hot Representation of Labels")
        self.tokenized_dataset = self.tokenize_input_and_one_hot_labels(self.dataset)

    # Public:
    def json_to_dataset(self, json_path, data_type="train", fraction=1, ignore_classes=None):
        """Load one JSON-lines file into ``self.dataset[data_type]``.

        Args:
            json_path: path of the JSON-lines file (see ``json_to_df``).
            data_type: split name to store the result under.
            fraction: if != 1, keep a reproducible random sample
                (``random_state=1``) of this fraction, in original row order.
            ignore_classes: labels whose rows are dropped.

        Returns:
            ``self.dataset`` (the updated DatasetDict).
        """
        if ignore_classes is None:
            ignore_classes = []
        data_df = self.json_to_df(json_path)
        data_df = data_df[~data_df["label"].isin(ignore_classes)]
        if fraction != 1:
            data_df = data_df.sample(frac=fraction, random_state=1).sort_index().reset_index(drop=True)
        self.dataset[data_type] = datasets.Dataset.from_pandas(data_df)
        return self.dataset

    def tokenize_input_and_one_hot_labels(self, dataset):
        """Build the label vocabulary and map every split through ``tokenize_and_one_hot``.

        Uses the module-level ``tokenizer`` and ``Utils.one_hot``. Sets
        ``self.labels_list`` and ``self.label_to_index`` as a side effect.

        Returns:
            The mapped DatasetDict.
        """
        all_labels = set()
        for split in ("train", "dev", "test"):
            all_labels.update(dataset[split]["label"])
        # Sorted (not set order) so label indices are identical across runs and
        # stay compatible with saved probes / histories.
        self.labels_list = sorted(all_labels, key=str)
        self.label_to_index = {label: idx for idx, label in enumerate(self.labels_list)}
        return dataset.map(tokenize_and_one_hot,
                           fn_kwargs={"label_to_index": self.label_to_index,
                                      "labels_len": len(self.label_to_index),
                                      "tokenizer": tokenizer,
                                      "one_hot_func": Utils.one_hot,
                                      "num_of_spans": self.dataset_info.num_of_spans
                                      },
                           batched=False,
                           num_proc=None)

    # Private:
    def json_to_df(self, json_path):
        """Flatten a JSON-lines file into one DataFrame row per target.

        Each line is an instance with "text" and a "targets" list; each target
        contributes a row with "text", "span1" (and "span2" for two-span
        tasks) and "label". Blank lines are skipped.

        Raises:
            ValueError: if ``num_of_spans`` is not 1 or 2.
        """
        num_of_spans = self.dataset_info.num_of_spans
        if num_of_spans not in (1, 2):
            raise ValueError(f"num_of_spans must be 1 or 2, got {num_of_spans!r}")
        span_keys = ("span1", "span2")[:num_of_spans]

        rows = []
        with open(json_path, encoding='utf-8') as file:
            for line in file:
                if not line.strip():
                    continue
                instance = json.loads(line)
                for target in instance["targets"]:
                    row = {"text": instance["text"]}
                    for key in span_keys:
                        row[key] = target[key]
                    row["label"] = target["label"]
                    rows.append(row)
        # Explicit columns keep the schema (and the "label" column) even for empty files.
        return pd.DataFrame(rows, columns=["text", *span_keys, "label"])

def tokenize_and_one_hot(examples, **fn_kwargs):
    """Tokenize one example, align its spans to tokens and one-hot its label.

    Intended for ``Dataset.map(..., batched=False)``, so ``examples`` is a
    single row with "text", "span1" (plus "span2" when ``num_of_spans == 2``)
    and "label". Each span is ``[first_word, end_word)`` over the words of
    ``text.split()``; it is rewritten as ``[first_token, end_token)`` over
    the tokenizer output, together with a "<span>_len" token count.

    Keyword args (``fn_kwargs``):
        tokenizer: callable returning an encoding that supports item
            assignment and ``word_ids()`` (e.g. a fast HF tokenizer).
        one_hot_func: ``f(index, length)`` building the label vector.
        num_of_spans: 2 adds "span2"; any other value aligns only "span1".
        label_to_index: mapping from label to class index.
        labels_len: number of classes.

    Returns:
        The encoding with "span1", "span1_len", optionally "span2",
        "span2_len", and "one_hot_label" added.
    """
    # Feed pre-split words so word_ids() indexes the same words the spans refer to.
    encoding = fn_kwargs["tokenizer"](examples["text"].split(), is_split_into_words=True)
    token_to_word = encoding.word_ids()
    backwards = token_to_word[::-1]
    last_pos = len(token_to_word) - 1

    span_names = ("span1", "span2") if fn_kwargs["num_of_spans"] == 2 else ("span1",)
    for name in span_names:
        first_word, stop_word = examples[name][0], examples[name][1]
        begin = token_to_word.index(first_word)                # first token of the first word
        end = last_pos - backwards.index(stop_word - 1) + 1    # one past the last token of the last word
        encoding[name] = [begin, end]
        encoding[name + "_len"] = end - begin

    class_idx = fn_kwargs["label_to_index"][examples["label"]]
    encoding["one_hot_label"] = fn_kwargs["one_hot_func"](class_idx, fn_kwargs["labels_len"])
    return encoding

In [None]:
# Load train/dev/test for the configured task and tokenize them with the global `tokenizer`.
my_dataset_handler = Dataset_handler(my_dataset_info);

In [None]:
# Sanity check: print a random example before and after tokenization and show
# which tokens the aligned spans cover, then plot the label distribution.
part = "test"
check_split = my_dataset_handler.tokenized_dataset[part]
# Bound the index by the split size so small splits cannot overrun.
rnd_idx = np.random.randint(min(100, len(check_split)))
# rnd_idx = 17

display(pd.DataFrame(check_split[0:3]))
print("idx =", rnd_idx)
print(my_dataset_handler.tokenized_dataset)
print("Original Spans:", my_dataset_handler.dataset[part][rnd_idx])
# Decode the row once instead of re-indexing the dataset for every field.
check_row = check_split[rnd_idx]
print("Tokenized Spans:", check_row)
test_tokens = tokenizer.convert_ids_to_tokens(check_row["input_ids"])
print(test_tokens)

s10, s11 = check_row["span1"][0], check_row["span1"][-1]
print("span1:", s10, s11, test_tokens[s10:s11])
if my_dataset_info.num_of_spans == 2:
    s20, s21 = check_row["span2"][0], check_row["span2"][-1]
    print("span2:", s20, s21, test_tokens[s20:s21])
print("label:", check_row["label"])

pd.DataFrame(check_split["label"], columns=['label'])["label"].value_counts().plot(kind='barh', color="green", figsize=(10, 9));

# NOTE(review): suspected span misalignment observed with rnd_idx = 17:
#  What if Google Morphed Into GoogleOS ? cc 
# span1: 11 12 ['and']
# span2: 13 14 ['e']
# label: cc

# Edge Probe

In [None]:
class SpanRepr(ABC, nn.Module):
    """Base class for modules that pool a padded span of token vectors into one vector.

    Subclasses implement ``forward``. This constructor only records the
    configuration; it does not create any projection layer itself.
    """

    def __init__(self, input_dim, use_proj=False, proj_dim=256):
        super().__init__()
        # Size of the vectors being pooled (embedding size or projection size).
        self.input_dim = input_dim
        self.proj_dim = proj_dim
        self.use_proj = use_proj

    @abstractmethod
    def forward(self, spans, attention_mask):
        """Pool ``spans`` over the span-token axis.

        Args:
            spans: [batch_size, layers, span_max_len, proj_dim/embedding_dim],
                e.g. [32, 13, 4, 256].
            attention_mask: [batch_size, span_max_len], e.g. [32, 4].

        Returns:
            [batch_size, layers, dim], e.g. [32, 13, 256].
        """
        raise NotImplementedError

    def get_input_dim(self):
        """Return the configured input vector size."""
        return self.input_dim

class MaxSpanRepr(SpanRepr, nn.Module):
    """Max-pool each span over its non-padded tokens."""

    def forward(self, spans, attention_mask):
        """
        input:
            spans: [batch_size, layers, span_max_len, dim] ~ [32, 13, 4, 256]
            attention_mask: [batch_size, span_max_len] ~ [32, 4]; 0 marks padding
                (int, float or bool masks are accepted)
        returns:
            [batch_size, layers, dim] ~ [32, 13, 256]
        """
        batch_size, span_len = attention_mask.shape
        # [batch, 1, span_len, 1] so the mask broadcasts over layers and features;
        # moved to the spans' device in case the mask was built on the CPU.
        is_pad = (attention_mask == 0).reshape(batch_size, 1, span_len, 1).to(spans.device)
        # Padded positions get a huge negative value so they never win the max.
        # masked_fill (unlike `spans * mask - 1e10 * (1 - mask)`) works for bool
        # masks and does not turn inf * 0 into NaN.
        masked_spans = spans.masked_fill(is_pad, -1e10)
        return masked_spans.max(dim=-2).values

class AttnSpanRepr(SpanRepr, nn.Module):
    """Attention-pool each span: a learned score per token, softmax over the span, weighted sum."""

    def __init__(self, input_dim, use_proj=False, proj_dim=256, use_endpoints=False):
        """
        Args:
            input_dim: size of the token vectors passed to ``forward``.
            use_proj: if True, token vectors are first mapped to ``proj_dim``
                by a learned linear layer and attention runs in that space.
            proj_dim: projection size used when ``use_proj`` is True.
            use_endpoints: stored for API compatibility. Concatenating the
                span end points is NOT implemented, so it currently has no effect.
        """
        super(AttnSpanRepr, self).__init__(input_dim, use_proj=use_proj, proj_dim=proj_dim)
        self.use_endpoints = use_endpoints
        attn_dim = input_dim
        if use_proj:
            # forward() applies this projection before scoring, so it must exist.
            self.proj = nn.Linear(input_dim, proj_dim)
            attn_dim = proj_dim
        # One scalar score per token: z(k)i = W(k)att e(k)i
        self.attention_params = nn.Linear(attn_dim, 1)

    def forward(self, spans, attention_mask):
        """
        input:
            spans: [batch_size, layers, span_max_len, dim] ~ [32, 13, 4, 256]
            attention_mask: [batch_size, span_max_len] ~ [32, 4]; 0 marks padding
                (int, float or bool masks are accepted)
        returns:
            [batch_size, layers, dim (proj_dim if use_proj)] as float32
        """
        if self.use_proj:
            spans = self.proj(spans)

        batch_size, span_len = attention_mask.shape
        # [batch, 1, span_len, 1]: broadcasts over layers and the single score
        # column instead of materializing a mask as large as `spans`.
        is_pad = (attention_mask == 0).reshape(batch_size, 1, span_len, 1).to(
            device=spans.device, dtype=spans.dtype)
        # Push padded tokens' logits down by 1e10 so they get ~0 attention.
        attn_logits = self.attention_params(spans) - 1e10 * is_pad  # [batch, layers, span_len, 1]
        attention_wts = nn.functional.softmax(attn_logits, dim=-2)
        attention_term = torch.sum(attention_wts * spans, dim=-2)
        return attention_term.float()

def get_span_module(input_dim, method="max", use_proj=False, proj_dim=256):
    """Build the span-pooling module selected by ``method``.

    Args:
        input_dim: size of the token vectors the module will receive.
        method: one of "avg", "max", "diff", "diff_sum", "endpoint",
            "coherent", "coherent_original", "attn", "coref".
        use_proj: forwarded to the module constructor.
        proj_dim: forwarded to the module constructor.

    Returns:
        A ``SpanRepr``-style module instance.

    Raises:
        NotImplementedError: if ``method`` is not one of the names above.

    NOTE(review): only MaxSpanRepr and AttnSpanRepr ("max", "attn", "coref")
    are defined in the visible notebook; the other classes must be defined
    before those methods can be selected, otherwise a NameError is raised.
    """
    if method == "avg":
        return AvgSpanRepr(input_dim, use_proj=use_proj, proj_dim=proj_dim)
    elif method == "max":
        return MaxSpanRepr(input_dim, use_proj=use_proj, proj_dim=proj_dim)
    elif method == "diff":
        return DiffSpanRepr(input_dim, use_proj=use_proj, proj_dim=proj_dim)
    elif method == "diff_sum":
        return DiffSumSpanRepr(input_dim, use_proj=use_proj, proj_dim=proj_dim)
    elif method == "endpoint":
        return EndPointRepr(input_dim, use_proj=use_proj, proj_dim=proj_dim)
    elif method == "coherent":
        return CoherentSpanRepr(input_dim, use_proj=use_proj, proj_dim=proj_dim)
    elif method == "coherent_original":
        return CoherentOrigSpanRepr(input_dim, use_proj=use_proj, proj_dim=proj_dim)
    elif method == "attn":
        return AttnSpanRepr(input_dim, use_proj=use_proj, proj_dim=proj_dim)
    elif method == "coref":
        return AttnSpanRepr(input_dim, use_proj=use_proj, proj_dim=proj_dim, use_endpoints=True)
    else:
        raise NotImplementedError(f"Unknown span pooling method: {method!r}")

In [None]:
class Edge_probe_model(nn.Module):
    """Edge-probing classifier over per-layer span representations.

    Per span: optional linear projection -> span pooling (``get_span_module``)
    -> optional L2 normalization. The pooled spans are concatenated, mixed
    across layers with softmax-normalized scalar weights (``weighing_params``)
    and classified by an MLP ending in a sigmoid, trained with ``BCELoss``.
    The model owns its Adam optimizer (``self.optimizer``).

    Args:
        num_of_spans: 1 or 2.
        num_layers: number of layers mixed by ``weighing_params``.
        input_span_len: accepted for API compatibility; unused.
        embedding_dim: size of each token vector fed to the probe.
        num_classes: number of output labels.
        pool_method: name passed to ``get_span_module``.
        use_proj: project token vectors to ``proj_dim`` before pooling.
        proj_dim: projection size.
        hidden_dim: hidden size of the classifier MLP.
        device: device the pooling modules are moved to.
        normalize_layers: L2-normalize each pooled span vector per layer.

    Raises:
        ValueError: if ``num_of_spans`` is not 1 or 2.
    """

    def __init__(self, num_of_spans, num_layers, input_span_len, embedding_dim, 
                 num_classes, pool_method='max', use_proj=True, proj_dim=256, 
                 hidden_dim=256, device='cuda', normalize_layers=False):
        super(Edge_probe_model, self).__init__()
        if num_of_spans not in (1, 2):
            # forward() only knows how to combine one or two pooled spans.
            raise ValueError(f"num_of_spans must be 1 or 2, got {num_of_spans!r}")
        self.device = device
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.num_of_spans = num_of_spans
        # One learnable scalar per layer, softmax-normalized in forward().
        self.weighing_params = nn.Parameter(torch.ones(self.num_layers))
        self.input_dim = embedding_dim * num_of_spans
        self.use_proj = use_proj
        self.proj_dim = proj_dim
        self.normalize_layers = normalize_layers

        ## Projection
        if use_proj:
            self.proj1 = nn.Linear(embedding_dim, proj_dim)
            if self.num_of_spans == 2:
                self.proj2 = nn.Linear(embedding_dim, proj_dim)
            # The classifier sees the concatenation of projected spans.
            self.input_dim = proj_dim * num_of_spans

        ## Pooling
        self.pool_method = pool_method
        pool_input_dim = proj_dim if use_proj else embedding_dim
        self.span1_pooling_net = get_span_module(pool_input_dim, method=pool_method).to(device)
        if self.num_of_spans == 2:
            self.span2_pooling_net = get_span_module(pool_input_dim, method=pool_method).to(device)

        ## Classification
        self.label_net = nn.Sequential(
            nn.Linear(self.input_dim, hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, self.num_classes),
            nn.Sigmoid()
        )
        self.training_criterion = nn.BCELoss()
        self.optimizer = optim.Adam(self.parameters(), lr=5e-4, weight_decay=0)

    def _pool_span(self, reprs, attention_mask, proj, pooling_net):
        """Project (if ``proj`` is given), pool over span tokens, and optionally L2-normalize one span."""
        if proj is not None:
            reprs = proj(reprs)
        pooled = pooling_net(reprs, attention_mask)
        if self.normalize_layers:
            pooled = nn.functional.normalize(pooled, dim=-1)
        return pooled

    def forward(self, spans_torch_dict):
        """
        Args:
            spans_torch_dict: dict with "span1" / "span1_attention_mask" (plus
                "span2" / "span2_attention_mask" for two-span tasks). Span
                tensors are assumed [batch, layers, span_len, embedding_dim]
                as in the SpanRepr docstring — TODO confirm.

        Returns:
            Sigmoid scores [batch, num_classes] (squeezed to [batch] when
            num_classes == 1).
        """
        pooled = [self._pool_span(spans_torch_dict["span1"],
                                  spans_torch_dict["span1_attention_mask"],
                                  self.proj1 if self.use_proj else None,
                                  self.span1_pooling_net)]
        if self.num_of_spans == 2:
            pooled.append(self._pool_span(spans_torch_dict["span2"],
                                          spans_torch_dict["span2_attention_mask"],
                                          self.proj2 if self.use_proj else None,
                                          self.span2_pooling_net))
        output = torch.cat(pooled, dim=-1)  # e.g. [32, 13, 512] for two spans

        ## Mixing Weights: sum_l softmax(w)_l * output[:, l, :]
        soft_weight = nn.functional.softmax(self.weighing_params, dim=0)
        layers = output[:, :self.num_layers, :]
        mixed = (soft_weight.view(1, -1, 1) * layers).sum(dim=1)

        ## Classification
        pred_label = self.label_net(mixed)
        return torch.squeeze(pred_label, dim=-1)

    def summary(self):
        """Print the module tree, parameter counts and pooling/projection settings."""
        print(self)
        pytorch_total_params = sum(p.numel() for p in self.parameters())
        pytorch_total_params_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("Total Parameters:    ", pytorch_total_params)
        print("Trainable Parameters:", pytorch_total_params_trainable)
        print("Pool Method:", self.pool_method)
        print("Projection:", self.use_proj, self.proj_dim)

# Edge Probe Trainer

In [None]:
class Trainer(ABC):
    """Abstract trainer interface.

    Declares the constructor arguments a probe trainer accepts. The base
    constructor is abstract and always raises, so concrete trainers must
    override ``__init__`` without calling it. Edge_probe_trainer in this
    notebook does not subclass it.
    """

    @abstractmethod
    def __init__(
        self,
        language_model,
        dataset_handler: Dataset_handler,
        verbose=True,
        device='cuda',
        edge_probe_model_checkpoint=None,
        pool_method="max",
        start_eval=False,
        history_checkpoint=None,
        up_to_layer=-1,
        normalize_layers=False,
    ):
        raise NotImplementedError

class Edge_probe_trainer:
    # Public:
    def __init__(self, language_model, dataset_handler: Dataset_handler, 
                 verbose=True, device='cuda', edge_probe_model_checkpoint=None, 
                 pool_method="max", start_eval = False, 
                 history_checkpoint=None, up_to_layer=-1, normalize_layers=False):
        """Freeze ``language_model`` on ``device`` and build (or reuse) the probe and its history.

        Args:
            language_model: pretrained model whose hidden states are probed;
                its parameters are frozen and it is put in eval mode.
            dataset_handler: provides tokenized splits and ``dataset_info``.
            verbose: print timestamped progress messages via ``self.vprint``.
            device: device for the language model and the probe.
            edge_probe_model_checkpoint: an existing ``Edge_probe_model`` to
                continue training; a new one is created when None.
            pool_method: span pooling for a newly created probe.
            start_eval: evaluate all splits once before the first epoch.
            history_checkpoint: existing history dict to append to.
            up_to_layer: stored as ``self.up_to_layer``; presumably read by
                methods not visible here — TODO confirm.
            normalize_layers: forwarded to a newly created probe.
        """
        self.dataset_handler = dataset_handler
        self.num_of_spans = self.dataset_handler.dataset_info.num_of_spans
        self.up_to_layer = up_to_layer
        self.language_model = language_model
        # Ask the model to return the hidden states of every layer.
        self.language_model.config.output_hidden_states = True
        self.device = device
        self.verbose = verbose
        self.start_eval = start_eval

        def vprint(text):
            """Print ``text`` prefixed with the wall-clock time when verbose."""
            if verbose:
                print(datetime.datetime.now().time(), text)
        self.vprint = vprint

        # Caches; NOTE(review): presumably filled by the embedding-extraction methods.
        self.current_hidden_states = None
        self.last_input_ids = None
        self.extracted_batch_embeddings = {}

        self.vprint("Moving to device")
        # Freeze the language model: only the probe's parameters are trained.
        for param in self.language_model.parameters():
            param.requires_grad = False
        self.language_model.eval()
        self.language_model.to(self.device)
        num_layers, input_span_len, embedding_dim, num_classes = self.get_language_model_properties()
        self.MLP_device = self.device
        if edge_probe_model_checkpoint is None:
            print("Creating New EPM")
            self.edge_probe_model = Edge_probe_model(
                num_of_spans = self.num_of_spans,
                num_layers = num_layers,
                input_span_len = input_span_len,
                embedding_dim = embedding_dim, 
                num_classes = num_classes,
                device = self.MLP_device,
                pool_method = pool_method,
                normalize_layers = normalize_layers
            )
        else:
            print("Starting From a Pretrained EPM")
            self.edge_probe_model = edge_probe_model_checkpoint

        if history_checkpoint is None:
            self.history = {"loss": {"train": [], "dev": [], "test": []}, 
                            "metrics": 
                            {"micro_f1": {"dev": [], "test": []}},
                            "layers_weights": []
                            }
            print("Creating New History")
        else:
            print("Using History Checkpoint")
            self.history = history_checkpoint
    
    def train(self, batch_size, epochs=3):
        """Train the probe on the train split for ``epochs`` epochs.

        Before each epoch the current layer weights and histories are plotted
        (``draw_weights``). After each epoch the mean train batch loss is
        recorded and dev/test are evaluated (``update_history``). When
        ``start_eval`` is set, all splits are evaluated once before training.

        Raises:
            ValueError: if the train split is empty (no batches to average).
        """
        splits = self.dataset_handler.tokenized_dataset
        tokenized_dataset = splits["train"]
        dataset_len = len(tokenized_dataset)
        dev_dataset_len = len(splits["dev"])
        test_dataset_len = len(splits["test"])

        self.edge_probe_model.to(self.MLP_device)
        print(f"Train on {dataset_len} samples, validate on {dev_dataset_len} samples, test on {test_dataset_len} samples")
        if dataset_len == 0:
            raise ValueError("Training split is empty; nothing to train on.")
        if self.start_eval:
            self.update_history(epoch = 0)
        for epoch in range(epochs):
            running_loss = 0.0
            steps = 0
            self.draw_weights(epoch)
            print("----------------\n")
            self.edge_probe_model.train()
            for start in tqdm(range(0, dataset_len, batch_size), desc=f"[Epoch {epoch + 1}/{epochs}]"):
                self.vprint("Start")
                # The last batch may be shorter than batch_size.
                end = min(start + batch_size, dataset_len)

                self.vprint("Extracting")
                spans_torch_dict = self.prepare_batch_data(tokenized_dataset, start, end, pad=True)
                labels = spans_torch_dict["one_hot_labels"]
                self.edge_probe_model.optimizer.zero_grad()

                # forward + backward + optimize
                self.vprint("Forward MLP")
                outputs = self.edge_probe_model(spans_torch_dict)
                self.vprint("Loss")
                loss = self.edge_probe_model.training_criterion(outputs.to(self.device), labels.float().to(self.device))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.edge_probe_model.parameters(), 5.0)
                self.edge_probe_model.optimizer.step()

                running_loss += loss.item()
                steps += 1
                self.vprint("Done")

            self.update_history(epoch + 1, train_loss = running_loss / steps)
            

    def calc_loss(self, tokenized_dataset, batch_size=16, print_metrics=False, just_micro=False, desc=""):
        """Evaluate the probe on one tokenized split without updating it.

        Args:
            tokenized_dataset: a tokenized split (``datasets.Dataset``) with a
                "one_hot_label" column.
            batch_size: evaluation batch size.
            print_metrics: print micro-F1 and, unless ``just_micro``, a full
                classification report.
            just_micro: restrict printed metrics to micro-F1.
            desc: progress-bar description.

        Returns:
            (mean per-batch loss, micro-F1 of argmax predictions vs. argmax
            of the one-hot labels). Leaves the probe in eval mode.

        Raises:
            ValueError: if ``tokenized_dataset`` is empty.
        """
        self.edge_probe_model.eval()
        # len() of the split, not of its "input_ids" column, which would
        # materialize every row just to count them.
        dataset_len = len(tokenized_dataset)
        if dataset_len == 0:
            raise ValueError(f"Cannot evaluate on an empty dataset ({desc!r}).")

        batch_outputs = []
        running_loss = 0.0
        steps = 0
        with torch.no_grad():
            for start in tqdm(range(0, dataset_len, batch_size), desc=desc):
                end = min(start + batch_size, dataset_len)
                spans_torch_dict = self.prepare_batch_data(tokenized_dataset, start, end, pad=True)
                labels = spans_torch_dict["one_hot_labels"]
                outputs = self.edge_probe_model(spans_torch_dict)
                batch_outputs.append(outputs)
                loss = self.edge_probe_model.training_criterion(outputs.to(self.device), labels.float().to(self.device))
                running_loss += loss.item()
                steps += 1

        # Concatenate once; re-concatenating inside the loop is quadratic.
        preds = torch.cat(batch_outputs, 0).cpu().argmax(-1)
        y_true = np.array(tokenized_dataset["one_hot_label"]).argmax(-1)
        print(preds[0:9])
        print(y_true[0:9])
        micro_f1 = f1_score(y_true, preds, average='micro')

        if print_metrics:
            labels_list = self.dataset_handler.labels_list
            if not just_micro:
                print(classification_report(y_true, preds, target_names=labels_list, labels=range(len(labels_list))))
            print("MICRO F1:", micro_f1)
        return running_loss / steps, micro_f1

    # Private:
    def update_history(self, epoch, train_loss = None):
        """Evaluate the splits and append the results to ``self.history``.

        The train split is evaluated only when ``train_loss`` is not supplied.
        Appends train/dev/test losses, dev/test micro-F1 and a snapshot of the
        raw (pre-softmax) layer-mixing weights, then prints a one-line summary.
        """
        splits = self.dataset_handler.tokenized_dataset
        if train_loss is None:
            train_loss, _ = self.calc_loss(splits["train"], print_metrics=True, desc="Train Loss")
        dev_loss, dev_f1 = self.calc_loss(splits["dev"], print_metrics=True, desc="Dev Loss")
        test_loss, test_f1 = self.calc_loss(splits["test"], print_metrics=True, desc="Test Loss")

        losses = self.history["loss"]
        f1_scores = self.history["metrics"]["micro_f1"]
        for split_name, value in (("train", train_loss), ("dev", dev_loss), ("test", test_loss)):
            losses[split_name].append(value)
        f1_scores["dev"].append(dev_f1)
        f1_scores["test"].append(test_f1)
        self.history["layers_weights"].append(self.edge_probe_model.weighing_params.tolist())
        print('[%d] loss: %.4f, val_loss: %.4f, test_loss: %.4f' % (epoch, losses["train"][-1], losses["dev"][-1], losses["test"][-1]))

    def draw_weights(self, epoch=0):
        """Plot the layer-mixing weights and the loss / micro-F1 histories.

        Also prints the raw weights, the full history dict, and the "CG"
        (center of gravity) of the softmax-normalized weights, i.e.
        sum_l l * softmax(w)_l.

        Args:
            epoch: accepted for API compatibility; plots are drawn on every call.
        """
        # Detached: plotting and the CG value must not build an autograd graph.
        raw_weights = self.edge_probe_model.weighing_params.detach()
        w = raw_weights.tolist()
        print(w)
        print(self.history)
        plt.bar(np.arange(len(w), dtype=int), w)
        plt.ylabel('Weight')
        plt.xlabel('Layer')
        plt.show()

        # Explicit dim: the implicit-dim form of softmax is deprecated.
        wsoft = nn.functional.softmax(raw_weights, dim=0)
        layer_ids = torch.arange(len(wsoft), dtype=wsoft.dtype, device=wsoft.device)
        print("CG", float((layer_ids * wsoft).sum()))

        print("Loss History")
        loss_history = self.history["loss"]
        x = range(len(loss_history["train"]))
        plt.plot(x, loss_history["train"])
        plt.plot(x, loss_history["dev"])
        plt.plot(x, loss_history["test"])
        plt.legend(['Train', 'Dev', 'Test'], loc='lower left')
        plt.show()

        print("Micro f1 History")
        f1_history = self.history["metrics"]["micro_f1"]
        x = range(len(f1_history["dev"]))
        plt.plot(x, f1_history["dev"])
        plt.plot(x, f1_history["test"])
        plt.legend(['Dev', 'Test'], loc='upper left')
        plt.show()

    def prepare_batch_data(self, tokenized_dataset, start_idx, end_idx, pad=False):
        """Build stacked span tensors for rows [start_idx, end_idx) of ``tokenized_dataset``.

        Args:
            tokenized_dataset: split to read rows from.
            start_idx: first row of the batch (inclusive).
            end_idx: last row of the batch (exclusive).
            pad: kept for API compatibility. Spans are always padded, because
                ``torch.stack`` needs equal span lengths.

        Returns:
            dict containing:
                "span1": float tensor (batch_size, #layers, max_span_len, embd_dim) on
                    ``self.MLP_device``;
                "span1_attention_mask": stacked masks;
                "one_hot_labels": label tensor;
                "span2" and "span2_attention_mask" as well, for two-span tasks.

        Raises:
            ValueError: if ``self.num_of_spans`` is neither 1 nor 2.
        """
        span_representations_dict = self.extract_embeddings(tokenized_dataset, start_idx, end_idx, pad=True)
        span1_torch = torch.stack(span_representations_dict["span1"]).float().to(self.MLP_device)  # (batch_size, #layers, max_span_len, embd_dim)
        span1_attention_mask_torch = torch.stack(span_representations_dict["span1_attention_mask"])
        one_hot_labels_torch = torch.tensor(np.array(span_representations_dict["one_hot_label"]))
        if self.num_of_spans == 2:
            span2_torch = torch.stack(span_representations_dict["span2"]).float().to(self.MLP_device)
            # span2 must use its own mask; span1's mask has a different valid length.
            span2_attention_mask_torch = torch.stack(span_representations_dict["span2_attention_mask"])
            return {"span1": span1_torch,
                    "span2": span2_torch,
                    "span1_attention_mask": span1_attention_mask_torch,
                    "span2_attention_mask": span2_attention_mask_torch,
                    "one_hot_labels": one_hot_labels_torch}
        if self.num_of_spans == 1:
            return {"span1": span1_torch,
                    "span1_attention_mask": span1_attention_mask_torch,
                    "one_hot_labels": one_hot_labels_torch}
        raise ValueError(f"num_of_spans must be 1 or 2, got {self.num_of_spans}")

    def get_language_model_properties(self):
        """Infer probe input dimensions by extracting the first three training rows.

        Returns:
            (num_layers, span_len, embedding_dim, num_classes), read from the first
            padded span1 tensor and the dataset handler's label list.
        """
        train_split = self.dataset_handler.tokenized_dataset["train"]
        sample = self.extract_embeddings(train_split, 0, 3, pad=True)
        first_spans = sample["span1"]
        for tensor in first_spans:
            print(tensor.shape)
        reference_shape = first_spans[0].shape
        num_classes = len(self.dataset_handler.labels_list)
        return reference_shape[0], reference_shape[1], reference_shape[2], num_classes

    def pad_span(self, span_repr, max_len):
        """Right-pad one span's hidden states with zeros along the token axis.

        Args:
            span_repr: tensor of shape (num_layers, span_len, embedding_dim).
            max_len: target length of the token axis; must be >= span_len.

        Returns:
            padded_span_repr: tensor (num_layers, max_len, embedding_dim) with the
                same dtype/device as ``span_repr``, zero in the padded positions.
            attention_mask: int8 tensor (max_len,) on ``self.device``; 1 marks real
                tokens and 0 marks padding.

        Raises:
            ValueError: if the span is longer than ``max_len``.
        """
        num_layers, span_original_len, embedding_dim = span_repr.shape
        pad_len = max_len - span_original_len
        if pad_len < 0:
            raise ValueError(f"Span length {span_original_len} exceeds max_len {max_len}")
        attention_mask = torch.zeros(max_len, dtype=torch.int8, device=self.device)
        attention_mask[:span_original_len] = 1
        # Create the padding like the input, so torch.cat needs no dtype/device promotion.
        padding = torch.zeros((num_layers, pad_len, embedding_dim), dtype=span_repr.dtype, device=span_repr.device)
        padded_span_repr = torch.cat((span_repr, padding), dim=1)
        return padded_span_repr, attention_mask

    def init_span_dict(self, num_of_spans, pad):
        """Return a fresh accumulator dict with one empty list per output column.

        "span2" is included only for two-span tasks. Both attention-mask columns
        are added whenever ``pad`` is truthy.
        """
        column_names = ["span1", "span2"] if num_of_spans == 2 else ["span1"]
        column_names += ["label", "one_hot_label"]
        if pad:
            column_names += ["span1_attention_mask", "span2_attention_mask"]
        return {name: [] for name in column_names}

    def extract_batch(self, tokenized_dataset, idx, unique_batch_size=32):
        """Run the language model on up to ``unique_batch_size`` distinct texts starting at row ``idx``.

        Returns:
            dict mapping ``repr(text)`` to that text's stacked hidden states. All
            layers are kept when ``self.up_to_layer == -1``; otherwise layers
            0..``self.up_to_layer`` (inclusive) are kept.
        """
        self.vprint("e1")
        total_rows = len(tokenized_dataset)
        texts = []
        cursor = idx
        while cursor < total_rows and len(texts) < unique_batch_size:
            candidate = tokenized_dataset[cursor]["text"]
            if candidate not in texts:
                texts.append(candidate)
            cursor += 1
        # Right padding keeps token positions unchanged; left padding would shift the stored span indices.
        tokenizer.padding_side = 'right'
        encoded = tokenizer(texts, padding=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.language_model(**encoded)
        # Stacked over hidden-state layers, e.g. (13, 16, 34, 768) for a 12-layer base model.
        all_layers = torch.stack([layer.detach() for layer in outputs.hidden_states])

        layer_stop = None if self.up_to_layer == -1 else self.up_to_layer + 1
        embeddings = {}
        for position, text in enumerate(texts):
            embeddings[repr(text)] = all_layers[:layer_stop, position, :, :]
        self.vprint("e2")
        return embeddings
    
    def pad_sequence(list_of_torch, pad_len, pad_value=0):
        shape = list_of_torch[0].shape
        num_layers = shape[0]
        span_original_len = shape[1]
        embedding_dim = shape[2]
        output = torch.zeros()

    def extract_embeddings(self, tokenized_dataset, start_idx, end_idx, pad=True):
        """Collect per-span hidden states for rows [start_idx, end_idx) of ``tokenized_dataset``.

        Hidden states are read from ``self.extracted_batch_embeddings`` (keyed by
        ``repr(text)``). On a cache miss, ``extract_batch`` refills the cache starting
        at the missing row.

        Args:
            tokenized_dataset: split whose rows carry "text", "span1" (and "span2"),
                "span1_len" (and "span2_len"), "label" and "one_hot_label".
            start_idx: first row (inclusive).
            end_idx: last row (exclusive).
            pad: when True, every span is zero-padded to the longest span in the
                range and the attention masks are recorded.

        Returns:
            dict of per-row lists: "span1" (and "span2" for two-span tasks) holding
            (#layers, span_len, embedding_dim) tensors, "label", "one_hot_label",
            and "span*_attention_mask" when ``pad`` is True.

        Raises:
            ValueError: if the dataset's ``num_of_spans`` is neither 1 nor 2.
        """
        num_of_spans = self.dataset_handler.dataset_info.num_of_spans
        if num_of_spans not in (1, 2):
            raise ValueError(f"num_of_spans must be 1 or 2, got {num_of_spans}")
        span_keys = ("span1", "span2")[:num_of_spans]

        batch_columns = tokenized_dataset[start_idx:end_idx]
        max_span_len_in_batch = max(max(batch_columns[f"{key}_len"]) for key in span_keys)

        span_repr = self.init_span_dict(num_of_spans, pad)
        self.vprint("f1")
        for i in range(start_idx, end_idx):
            row = tokenized_dataset[i]  # fetch each row once; dataset row access is not free
            hashable_input = repr(row["text"])

            if hashable_input not in self.extracted_batch_embeddings:
                self.extracted_batch_embeddings = self.extract_batch(tokenized_dataset, i)
            self.current_hidden_states = self.extracted_batch_embeddings[hashable_input]

            for key in span_keys:
                # Slice the span's token range on every layer: (#layer, span_len, embd_dim).
                hidden = self.current_hidden_states[:, row[key][0]:row[key][1], :]
                if pad:
                    padded, mask = self.pad_span(hidden, max_span_len_in_batch)
                    span_repr[key].append(padded)
                    span_repr[f"{key}_attention_mask"].append(mask)
                else:
                    span_repr[key].append(hidden)
            span_repr["one_hot_label"].append(row["one_hot_label"])
            span_repr["label"].append(row["label"])
        self.vprint("f2")
        return span_repr

In [None]:
# Drop references to the previous trainer, probe checkpoint and history, then ask
# Python's GC and the CUDA caching allocator to release what they held.
my_edge_probe_trainer = edge_probe_model_checkpoint = history = None
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Reuse the probe of a previous trainer (if any) so that training resumes from it.
try:
    edge_probe_model_checkpoint = my_edge_probe_trainer.edge_probe_model
except (NameError, AttributeError):
    # NameError: the trainer variable was never defined in this session.
    # AttributeError: it was reset to None. Anything else is a real error.
    edge_probe_model_checkpoint = None
my_edge_probe_trainer = Edge_probe_trainer(model,
                                           my_dataset_handler,
                                           device=DEVICE,
                                           pool_method=POOL_METHOD,
                                           edge_probe_model_checkpoint=edge_probe_model_checkpoint,
                                           history_checkpoint=history,
                                           up_to_layer=6,
                                           normalize_layers=True,
                                           verbose=False)

In [None]:
# Summarize the run configuration, then the probe architecture.
print(f"Model: {model_checkpoint}")
print(f"Dataset: {my_dataset_info.dataset_name}")
print(f"Batch Size: {BATCH_SIZE}")
my_edge_probe_trainer.edge_probe_model.summary()

In [None]:
my_edge_probe_trainer.train(batch_size=BATCH_SIZE, epochs=40)  # history is updated after every epoch

In [None]:
# Name the checkpoint after the model/pooling actually used (was hard-coded to an xlnet run).
torch.save(my_edge_probe_trainer.edge_probe_model.state_dict(), f"EPM_{model_checkpoint.replace('/', '-')}-{POOL_METHOD}")

In [None]:
# Keep the history so that a re-created trainer can continue from it.
history = my_edge_probe_trainer.history
print(history)

In [None]:
# Print and plot per-epoch loss (train/dev/test) and micro-F1 (dev/test).
print("Loss History")
loss_history = my_edge_probe_trainer.history["loss"]
print(loss_history)
for split in ("train", "dev", "test"):
    print(f"{split.capitalize()} Loss:", loss_history[split])

epoch_axis = range(len(loss_history["train"]))
for split in ("train", "dev", "test"):
    plt.plot(epoch_axis, loss_history[split])
plt.legend(['Train', 'Dev', 'Test'], loc='lower left')
plt.show()
print(".")

print("Micro f1 History")
f1_history = my_edge_probe_trainer.history["metrics"]["micro_f1"]
print(f1_history)
for split in ("dev", "test"):
    print(f"{split.capitalize()} f1:", f1_history[split])

epoch_axis = range(len(f1_history["dev"]))
for split in ("dev", "test"):
    plt.plot(epoch_axis, f1_history[split])
plt.legend(['Dev', 'Test'], loc='upper left')
plt.show()
print(".")

# Diagnostic Probe Trainer

In [None]:
class Diagnostic_probe_trainer:
    """Train one independent probe per hidden layer of a frozen language model.

    A separate Edge_probe_model (``num_layers=1``) is created for every hidden-state
    layer. Each probe is trained only on that layer's span representations, so the
    per-layer micro-F1 scores can be compared directly.
    """

    # Public:
    def __init__(self, language_model, dataset_handler: Dataset_handler,
                 verbose=True, device='cuda',
                 pool_method="max", start_eval = False, normalize_layers=False):
        """Freeze the language model, inspect its output shapes and build the per-layer probes.

        Args:
            language_model: model called on tokenizer output. ``output_hidden_states``
                is switched on and all of its parameters are frozen.
            dataset_handler: supplies ``tokenized_dataset`` splits ("train"/"dev"/"test"),
                ``labels_list`` and ``dataset_info.num_of_spans``.
            verbose: when True, ``self.vprint`` prints timestamped progress messages.
            device: device for the language model and the probes.
            pool_method: forwarded to every Edge_probe_model.
            start_eval: evaluate all splits once before the first training epoch.
            normalize_layers: forwarded to every Edge_probe_model.
        """
        self.dataset_handler = dataset_handler
        self.num_of_spans = self.dataset_handler.dataset_info.num_of_spans
        self.language_model = language_model
        self.language_model.config.output_hidden_states = True
        self.device = device
        self.verbose = verbose
        self.start_eval = start_eval
        def vprint(text):
            if verbose:
                print(datetime.datetime.now().time(), text)
        self.vprint = vprint

        self.current_hidden_states = None
        self.last_input_ids = None
        # Cache: repr(text) -> stacked hidden states; extract_batch refills it on a miss.
        self.extracted_batch_embeddings = {}

        self.vprint("Moving to device")
        for param in self.language_model.parameters():
            param.requires_grad = False
        self.language_model.eval()
        self.language_model.to(self.device)
        # Extract a few rows to discover the layer count, span length and embedding width.
        num_layers, input_span_len, embedding_dim, num_classes = self.get_language_model_properties()
        print(num_layers)
        self.num_layers = num_layers
        self.MLP_device = self.device

        print("Creating New EPM")
        # One single-layer probe per hidden-state layer.
        self.edge_probe_models = [
            Edge_probe_model(
                num_of_spans = self.num_of_spans,
                num_layers = 1,
                input_span_len = input_span_len,
                embedding_dim = embedding_dim,
                num_classes = num_classes,
                device = self.MLP_device,
                pool_method = pool_method,
                normalize_layers = normalize_layers
            )
            for _ in range(num_layers)
        ]

        self.history = {"loss": {"train": [], "dev": [], "test": []},
                        "metrics":
                        {"micro_f1": {"dev": [], "test": []}},
                        "layers_weights": []
                        }
        print("Creating New History")

    def train(self, batch_size, epochs=3):
        """Train every per-layer probe on the train split.

        The span representations of each batch are extracted once. Each probe then
        receives its own layer slice. The history is updated after every epoch.

        Args:
            batch_size: number of dataset rows per optimisation step.
            epochs: number of passes over the train split.
        """
        tokenized_dataset = self.dataset_handler.tokenized_dataset["train"]
        tokenized_dataset_dev = self.dataset_handler.tokenized_dataset["dev"]
        tokenized_dataset_test = self.dataset_handler.tokenized_dataset["test"]

        for edge_probe_model in self.edge_probe_models:
            edge_probe_model.to(self.MLP_device)
        dataset_len = len(tokenized_dataset)
        dev_dataset_len = len(tokenized_dataset_dev)
        test_dataset_len = len(tokenized_dataset_test)
        print(f"Train on {dataset_len} samples, validate on {dev_dataset_len} samples, test on {test_dataset_len} samples")
        if self.start_eval:
            self.update_history(epoch = 0)
        for epoch in range(epochs):
            running_loss = 0.0
            steps = 0
            # Draws the previous epoch's results (no-op for epoch 0).
            self.draw_weights(epoch)
            print("----------------\n")
            for edge_probe_model in self.edge_probe_models:
                edge_probe_model.train()
            for i in tqdm(range(0, dataset_len, batch_size), desc=f"[Epoch {epoch + 1}/{epochs}]"):
                self.vprint("Start")
                # The last batch may be shorter than batch_size.
                step = min(batch_size, dataset_len - i)

                self.vprint("Extracting")
                spans_torch_dict = self.prepare_batch_data(tokenized_dataset, i, i + step, pad=True)
                labels = spans_torch_dict["one_hot_labels"].float().to(self.device)

                for epm_idx, edge_probe_model in enumerate(self.edge_probe_models):
                    edge_probe_model.optimizer.zero_grad()

                    self.vprint("dict")
                    span_torch_dict = self._layer_inputs(spans_torch_dict, epm_idx)

                    # forward + backward + optimize
                    self.vprint("Forward MLP")
                    outputs = edge_probe_model(span_torch_dict)
                    self.vprint("Loss")
                    loss = edge_probe_model.training_criterion(outputs.to(self.device), labels)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(edge_probe_model.parameters(), 5.0)
                    edge_probe_model.optimizer.step()

                    running_loss += loss.item()
                    steps += 1
                self.vprint("Done")

            # steps counts one per probe per batch, so this averages over batches and layers.
            self.update_history(epoch + 1, train_loss = running_loss / steps)

    def calc_loss(self, tokenized_dataset, batch_size=16, print_metrics=False, just_micro=False, desc=""):
        """Evaluate every per-layer probe on ``tokenized_dataset`` without gradients.

        Args:
            tokenized_dataset: split to evaluate.
            batch_size: rows per forward pass.
            print_metrics: print the per-layer micro-F1 list.
            just_micro: unused (kept for API parity; the classification report is disabled).
            desc: tqdm progress-bar label.

        Returns:
            (mean_loss, micro_f1): mean_loss averages over batches and layers;
            micro_f1 is a list with one score per layer.
        """
        for edge_probe_model in self.edge_probe_models:
            edge_probe_model.eval()
        with torch.no_grad():
            running_loss = 0
            dataset_len = len(tokenized_dataset)
            steps = 0
            preds = [None] * self.num_layers
            micro_f1 = [None] * self.num_layers
            for i in tqdm(range(0, dataset_len, batch_size), desc=desc):
                step = min(batch_size, dataset_len - i)

                spans_torch_dict = self.prepare_batch_data(tokenized_dataset, i, i + step, pad=True)
                labels = spans_torch_dict["one_hot_labels"].float().to(self.device)

                for epm_idx, edge_probe_model in enumerate(self.edge_probe_models):
                    outputs = edge_probe_model(self._layer_inputs(spans_torch_dict, epm_idx))
                    # Accumulate raw outputs per layer; the argmax is taken once at the end.
                    preds[epm_idx] = outputs if i == 0 else torch.cat((preds[epm_idx], outputs), 0)
                    loss = edge_probe_model.training_criterion(outputs.to(self.device), labels)
                    running_loss += loss.item()
                    steps += 1

        y_true = np.array(tokenized_dataset["one_hot_label"]).argmax(-1)
        for idx, pred in enumerate(preds):
            micro_f1[idx] = f1_score(y_true, pred.cpu().argmax(-1), average='micro')

        if print_metrics:
            print("MICRO F1:", micro_f1)
        return running_loss / steps, micro_f1

    # Private:
    def _layer_inputs(self, spans_torch_dict, layer_idx):
        """Restrict a batch dict to hidden layer ``layer_idx`` (keeping a size-1 layer axis)."""
        layer = slice(layer_idx, layer_idx + 1)
        span_torch_dict = {"span1": spans_torch_dict["span1"][:, layer, :, :],
                           "span1_attention_mask": spans_torch_dict["span1_attention_mask"]}
        if self.num_of_spans == 2:
            span_torch_dict["span2"] = spans_torch_dict["span2"][:, layer, :, :]
            span_torch_dict["span2_attention_mask"] = spans_torch_dict["span2_attention_mask"]
        return span_torch_dict

    def update_history(self, epoch, train_loss = None):
        """Evaluate dev/test (and train, when no running loss is given) and append to ``self.history``.

        The micro-F1 entries appended are per-layer lists returned by ``calc_loss``.
        """
        if train_loss is None:
            train_loss, _ = self.calc_loss(self.dataset_handler.tokenized_dataset["train"], print_metrics=True, desc="Train Loss")
        dev_loss, dev_f1 = self.calc_loss(self.dataset_handler.tokenized_dataset["dev"], print_metrics=True, desc="Dev Loss")
        test_loss, test_f1 = self.calc_loss(self.dataset_handler.tokenized_dataset["test"], print_metrics=True, desc="Test Loss")
        self.history["loss"]["train"].append(train_loss)
        self.history["loss"]["dev"].append(dev_loss)
        self.history["loss"]["test"].append(test_loss)
        self.history["metrics"]["micro_f1"]["dev"].append(dev_f1)
        self.history["metrics"]["micro_f1"]["test"].append(test_f1)
        print('[%d] loss:' % (epoch))
        print("Train Loss:", self.history["loss"]["train"][-1])
        print("Dev Loss:", self.history["loss"]["dev"][-1])
        print("Test Loss:", self.history["loss"]["test"][-1])

    def draw_weights(self, epoch=0):
        """For epoch > 0, plot the latest per-layer test micro-F1 and the loss curves so far."""
        if epoch > 0:
            w = self.history["metrics"]["micro_f1"]["test"][-1]
            print(self.history)
            plt.bar(np.arange(len(w), dtype=int), w)
            plt.ylabel('f1')
            plt.xlabel('Layer')
            plt.show()

            print("Loss History")
            loss_history = self.history["loss"]
            x = range(len(loss_history["train"]))
            plt.plot(x, loss_history["train"])
            plt.plot(x, loss_history["dev"])
            plt.plot(x, loss_history["test"])
            plt.legend(['Train', 'Dev', 'Test'], loc='lower left')
            plt.show()

    def prepare_batch_data(self, tokenized_dataset, start_idx, end_idx, pad=False):
        """Build stacked span tensors for rows [start_idx, end_idx).

        ``pad`` is kept for API compatibility; spans are always padded because
        stacking requires equal lengths. "span1"/"span2" are float tensors of shape
        (batch_size, #layers, max_span_len, embd_dim) on ``self.MLP_device``.

        Raises:
            ValueError: if ``self.num_of_spans`` is neither 1 nor 2.
        """
        span_representations_dict = self.extract_embeddings(tokenized_dataset, start_idx, end_idx, pad=True)
        span1_torch = torch.stack(span_representations_dict["span1"]).float().to(self.MLP_device)  # (batch_size, #layers, max_span_len, embd_dim)
        span1_attention_mask_torch = torch.stack(span_representations_dict["span1_attention_mask"])
        one_hot_labels_torch = torch.tensor(np.array(span_representations_dict["one_hot_label"]))
        if self.num_of_spans == 2:
            span2_torch = torch.stack(span_representations_dict["span2"]).float().to(self.MLP_device)
            # span2 must use its own mask (span1's mask was previously reused here).
            span2_attention_mask_torch = torch.stack(span_representations_dict["span2_attention_mask"])
            return {"span1": span1_torch,
                    "span2": span2_torch,
                    "span1_attention_mask": span1_attention_mask_torch,
                    "span2_attention_mask": span2_attention_mask_torch,
                    "one_hot_labels": one_hot_labels_torch}
        if self.num_of_spans == 1:
            return {"span1": span1_torch,
                    "span1_attention_mask": span1_attention_mask_torch,
                    "one_hot_labels": one_hot_labels_torch}
        raise ValueError(f"num_of_spans must be 1 or 2, got {self.num_of_spans}")

    def get_language_model_properties(self):
        """Return (num_layers, span_len, embedding_dim, num_classes) from the first three train rows."""
        span_representations_dict = self.extract_embeddings(self.dataset_handler.tokenized_dataset["train"], 0, 3, pad=True)
        for i in span_representations_dict["span1"]:
            print(i.shape)
        span1_torch = span_representations_dict["span1"]
        num_layers = span1_torch[0].shape[0]
        span_len = span1_torch[0].shape[1]
        embedding_dim = span1_torch[0].shape[2]
        return num_layers, span_len, embedding_dim, len(self.dataset_handler.labels_list)

    def pad_span(self, span_repr, max_len):
        """Right-pad one span's hidden states with zeros along the token axis.

        Args:
            span_repr: tensor (num_layers, span_len, embedding_dim).
            max_len: target token length; must be >= span_len.

        Returns:
            padded_span_repr: tensor (num_layers, max_len, embedding_dim), same dtype/device as input.
            attention_mask: int8 tensor (max_len,) on ``self.device``; 1 = data, 0 = padding.

        Raises:
            ValueError: if the span is longer than ``max_len``.
        """
        num_layers, span_original_len, embedding_dim = span_repr.shape
        pad_len = max_len - span_original_len
        if pad_len < 0:
            raise ValueError(f"Span length {span_original_len} exceeds max_len {max_len}")
        attention_mask = torch.zeros(max_len, dtype=torch.int8, device=self.device)
        attention_mask[:span_original_len] = 1
        padding = torch.zeros((num_layers, pad_len, embedding_dim), dtype=span_repr.dtype, device=span_repr.device)
        padded_span_repr = torch.cat((span_repr, padding), dim=1)
        return padded_span_repr, attention_mask

    def init_span_dict(self, num_of_spans, pad):
        """Return an accumulator dict of empty lists ("span2" only for two-span tasks; masks when pad)."""
        if num_of_spans == 2:
            span_repr = {"span1": [], "span2": [], "label": [], "one_hot_label": []}
        else:
            span_repr = {"span1": [], "label": [], "one_hot_label": []}

        if pad:
            span_repr["span1_attention_mask"] = []
            span_repr["span2_attention_mask"] = []
        return span_repr

    def extract_batch(self, tokenized_dataset, idx, unique_batch_size=32):
        """Run the language model on up to ``unique_batch_size`` distinct texts starting at row ``idx``.

        Returns:
            dict mapping ``repr(text)`` to that text's hidden states for all layers.
        """
        self.vprint("e1")
        dataset_len = len(tokenized_dataset)
        unique_texts_in_batch = []
        i = idx
        while len(unique_texts_in_batch) < unique_batch_size and i < dataset_len:
            text = tokenized_dataset[i]["text"]
            if text not in unique_texts_in_batch:
                unique_texts_in_batch.append(text)
            i += 1
        tokenizer.padding_side = 'right'  # Important: left padding would shift the span indices
        tokenized_batch = tokenizer(unique_texts_in_batch, padding=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.language_model(**tokenized_batch)
        current_hidden_states = torch.stack([val.detach() for val in outputs.hidden_states])  # e.g. (13, 16, 34, 768)

        extracted_batch_embeddings = {}
        for text_idx, unique_text in enumerate(unique_texts_in_batch):
            extracted_batch_embeddings[repr(unique_text)] = current_hidden_states[:, text_idx, :, :]
        self.vprint("e2")
        return extracted_batch_embeddings

    @staticmethod
    def pad_sequence(list_of_torch, pad_len, pad_value=0):
        """Pad each (num_layers, span_len, embedding_dim) tensor along dim 1 to ``pad_len`` and stack them.

        Previously broken (no ``self``/staticmethod, ``torch.zeros()`` without a size,
        returned None).

        Raises:
            ValueError: if the list is empty or a tensor is longer than ``pad_len``.
        """
        if not list_of_torch:
            raise ValueError("list_of_torch must contain at least one tensor")
        padded = []
        for tensor in list_of_torch:
            extra = pad_len - tensor.shape[1]
            if extra < 0:
                raise ValueError(f"Sequence length {tensor.shape[1]} exceeds pad_len {pad_len}")
            padded.append(torch.nn.functional.pad(tensor, (0, 0, 0, extra), value=pad_value))
        return torch.stack(padded)

    def extract_embeddings(self, tokenized_dataset, start_idx, end_idx, pad=True):
        """Collect per-span hidden states for rows [start_idx, end_idx) of ``tokenized_dataset``.

        Returns:
            dict of per-row lists: "span1" (and "span2" for two-span tasks) with
            (#layers, span_len, embedding_dim) tensors, "label", "one_hot_label",
            and "span*_attention_mask" when ``pad`` is True.

        Raises:
            ValueError: if the dataset's ``num_of_spans`` is neither 1 nor 2.
        """
        num_of_spans = self.dataset_handler.dataset_info.num_of_spans
        if num_of_spans not in (1, 2):
            raise ValueError(f"num_of_spans must be 1 or 2, got {num_of_spans}")
        span_keys = ("span1", "span2")[:num_of_spans]

        batch_columns = tokenized_dataset[start_idx:end_idx]
        max_span_len_in_batch = max(max(batch_columns[f"{key}_len"]) for key in span_keys)

        span_repr = self.init_span_dict(num_of_spans, pad)
        self.vprint("f1")
        for i in range(start_idx, end_idx):
            row = tokenized_dataset[i]  # fetch each row once
            hashable_input = repr(row["text"])

            if hashable_input not in self.extracted_batch_embeddings:
                self.extracted_batch_embeddings = self.extract_batch(tokenized_dataset, i)
            self.current_hidden_states = self.extracted_batch_embeddings[hashable_input]

            for key in span_keys:
                hidden = self.current_hidden_states[:, row[key][0]:row[key][1], :]  # (#layer, span_len, embd_dim)
                if pad:
                    padded, mask = self.pad_span(hidden, max_span_len_in_batch)
                    span_repr[key].append(padded)
                    span_repr[f"{key}_attention_mask"].append(mask)
                else:
                    span_repr[key].append(hidden)
            span_repr["one_hot_label"].append(row["one_hot_label"])
            span_repr["label"].append(row["label"])
        self.vprint("f2")
        return span_repr

In [None]:
# One independent probe per language-model layer; layers are not normalized here.
my_diagnostic_probe_trainer = Diagnostic_probe_trainer(
    model,
    my_dataset_handler,
    device=DEVICE,
    pool_method=POOL_METHOD,
    normalize_layers=False,
    verbose=False,
)

In [None]:
print(f"Model: {model_checkpoint}")
print(f"Dataset: {my_dataset_info.dataset_name}")
print(f"Batch Size: {BATCH_SIZE}")
# All per-layer probes are built with the same arguments, so summarizing the first is representative.
my_diagnostic_probe_trainer.edge_probe_models[0].summary()

In [None]:
my_diagnostic_probe_trainer.train(batch_size=BATCH_SIZE, epochs=30)  # trains all per-layer probes together

In [None]:
history = my_diagnostic_probe_trainer.history
print(history)

In [None]:
# Per-epoch losses, then the most recent per-layer micro-F1 for dev and test.
print("Loss History")
loss_history = my_diagnostic_probe_trainer.history["loss"]
print(loss_history)
for split in ("train", "dev", "test"):
    print(f"{split.capitalize()} Loss:", loss_history[split])

epoch_axis = range(len(loss_history["train"]))
for split in ("train", "dev", "test"):
    plt.plot(epoch_axis, loss_history[split])
plt.legend(['Train', 'Dev', 'Test'], loc='lower left')
plt.show()
print(".")

print("Micro f1 History")
f1_history = my_diagnostic_probe_trainer.history["metrics"]["micro_f1"]
print(f1_history)
for split in ("dev", "test"):
    print(f"{split.capitalize()} f1:", f1_history[split])

# Each history entry is a per-layer F1 list; plot the latest one against layer index.
latest_dev_f1 = f1_history["dev"][-1]
layer_axis = range(len(latest_dev_f1))
plt.plot(layer_axis, latest_dev_f1)
plt.plot(layer_axis, f1_history["test"][-1])
plt.legend(['Dev', 'Test'], loc='upper left')
plt.show()
print(".")