<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#README-BEFORE-RUN" data-toc-modified-id="README-BEFORE-RUN-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>README BEFORE RUN</a></span></li><li><span><a href="#functions" data-toc-modified-id="functions-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>functions</a></span><ul class="toc-item"><li><span><a href="#multi-task-dataset" data-toc-modified-id="multi-task-dataset-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>multi-task dataset</a></span></li><li><span><a href="#multi-task-model" data-toc-modified-id="multi-task-model-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>multi-task model</a></span></li><li><span><a href="#trainer" data-toc-modified-id="trainer-2.3"><span class="toc-item-num">2.3&nbsp;&nbsp;</span>trainer</a></span></li></ul></li><li><span><a href="#train" data-toc-modified-id="train-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>train</a></span><ul class="toc-item"><li><span><a href="#bertology" data-toc-modified-id="bertology-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>bertology</a></span></li></ul></li></ul></div>

# README BEFORE RUN
**Before running this notebook, you must find the source code of the functions below and modify them so that they work with dicts, as shown below.**  

These functions are in \<your-env\>/site-packages/transformers/trainer_pt_utils.py. Search for the function names and replace them with the code below.

In [None]:
# !!! Must find the source code of these functions and modify them to work with dict as below

# def nested_concat(tensors, new_tensors, padding_index=-100):
#     """
#     Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
#     nested list/tuples/dict of tensors.
#     """
#     assert type(tensors) == type(
#         new_tensors
#     ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
#     if isinstance(tensors, (list, tuple)):
#         return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
#     elif isinstance(tensors, torch.Tensor):
#         return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
#     elif isinstance(tensors, np.ndarray):
#         return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
#     elif isinstance(tensors, dict): ### JOEY EDITTED
#         return {k: nested_concat(tensors[k], new_tensors[k], padding_index=padding_index) for k in tensors}
#     else:
#         raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")

# def nested_numpify(tensors):
#     "Numpify `tensors` (even if it's a nested list/tuple/dict of tensors)."
#     if isinstance(tensors, (list, tuple)):
#         return type(tensors)(nested_numpify(t) for t in tensors)
#     elif isinstance(tensors, dict): ### JOEY EDITTED
#         return {k: nested_numpify(tensors[k]) for k in tensors}
#     return tensors.cpu().numpy()

# def nested_detach(tensors):
#     "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
#     if isinstance(tensors, (list, tuple)):
#         return type(tensors)(nested_detach(t) for t in tensors)
#     elif isinstance(tensors, dict): ### JOEY EDITTED
#         return {k:nested_detach(tensors[k]) for k in tensors}
#     return tensors.detach()

# def nested_truncate(tensors, limit):
#     "Truncate `tensors` at `limit` (even if it's a nested list/tuple of tensors)."
#     if isinstance(tensors, (list, tuple)):
#         return type(tensors)(nested_truncate(t, limit) for t in tensors)
#     elif isinstance(tensors, dict): ### JOEY EDITTED
#         return {k: nested_truncate(tensors[k], limit) for k in tensors}
#     return tensors[:limit]


# functions

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os
import collections
import json
from dataclasses import dataclass, asdict
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm, trange

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
from torch import optim
from torch.optim import lr_scheduler
import torchmetrics
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from transformers import *
from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput

import matplotlib.pyplot as plt
from IPython.display import display

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
del use_cuda

# result_folder comes from the scratch_result_folder environment variable when it is set,
# else './result'; scratch_data_folder is None when its environment variable is unset.
result_folder = os.environ.get("scratch_result_folder", './result')
scratch_data_folder = os.environ.get("scratch_data_folder")
data_folder = '../data'

# https://github.com/huggingface/transformers/issues/5486
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [None]:
# tasks2labels: task name -> number of labels, read from the PASTEL metadata JSON.
with open(f'{data_folder}/pastel/pastel_tasks2labels.json', 'r') as labels_file:
    tasks2labels = json.load(labels_file)

# tasks2idx: task name -> position of that task in tasks2labels.
tasks2idx = dict(zip(tasks2labels, range(len(tasks2labels))))

In [None]:
# Display the task -> number-of-labels mapping loaded from the JSON file.
tasks2labels

{'country': 3,
 'politics': 3,
 'tod': 5,
 'age': 8,
 'education': 10,
 'ethnic': 10,
 'gender': 3}

In [None]:
@dataclass
class MyTrainingArgs:
    """Options for one multi-task training run.

    Unless ``model_folder`` is given explicitly it is derived in ``__post_init__`` as
    ``{result_folder}/{name}/{timestamp}``, where ``name`` is ``model_name`` if provided and
    otherwise the selected task names joined with ``+``.
    """
    # training args
    selected_tasks: List  # task names; used as keys into tasks2labels
    base_model_name: str  # passed to the Auto* ``from_pretrained`` loaders
    freeze_bert: Union[bool, int]  # True/False, or an int layer count (see freeze_model)
    use_pooler: bool  # use the encoder's pooler_output when it is available
    num_epoch: int
    lr: float = 5e-5
    model_folder: Optional[str] = None # inferred in __post_init__ when left as None
    model_name: Optional[str] = None # if provided, used to name model_folder; otherwise the task names are used

    # data loader args
    batch_size: int = 32
    max_length: int = 64
    shuffle: bool = False
    num_workers: int = 4
    data_limit: Optional[int] = None # if not None, truncate dataset to keep only top {data_limit} rows

    # post training args
    save_best: bool = True
    load_best_at_end: bool = True

    # Annotated so that it is a real dataclass field (settable through the constructor and
    # visible to asdict). Declared last so the positional order of the earlier fields is
    # unchanged.
    num_warmup_steps: int = 500

    def __post_init__(self):
        # Respect an explicitly supplied folder; only infer one when none was given.
        if self.model_folder is None:
            execute_time = datetime.now()
            model_name = self.model_name if self.model_name else '+'.join(self.selected_tasks)
            self.model_folder = f"{result_folder}/{model_name}/{execute_time.strftime('%Y%m%d-%H:%M:%S')}"

## multi-task dataset

One sentence with multiple labels

In [None]:
class MyDataset(Dataset):
    """Map-style dataset over one split of the processed PASTEL csv.

    Every item is the tokenizer output for the row's ``output.sentences`` text plus one
    label entry for each task in ``tasks2labels``. Rows containing missing values are
    dropped. When ``training_args.data_limit`` is truthy only the first ``data_limit``
    remaining rows are kept, which may influence the label distribution.

    NOTE: this is a map-style dataset; an iterable-style one has not been evaluated.
    """

    def __init__(self, training_args, split):
        self.max_length = training_args.max_length
        self.split = split
        self.tokenizer = AutoTokenizer.from_pretrained(training_args.base_model_name)

        frame = pd.read_csv(f'{data_folder}/pastel/processed/{self.split}/pastel.csv')
        frame = frame.dropna().reset_index(drop=True)
        row_limit = training_args.data_limit
        self.df = frame.iloc[:row_limit] if row_limit else frame

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        row_index = idx.tolist() if torch.is_tensor(idx) else idx
        row = self.df.iloc[row_index]
        encoding = self.tokenizer(
            row['output.sentences'],
            truncation=True,
            padding=True,
            max_length=self.max_length,
        )
        item = dict(encoding.items())
        # Attach the label of every known task under the task's own name.
        for task in tasks2labels:
            item[task] = row[task]
        return item


## multi-task model

Given the selected tasks, the model adds the corresponding classification heads on top of a pretrained BERT (or another BERT-style encoder).

In [None]:
class RegressionHead(nn.Module):
    """Dropout + single-output linear head trained with mean-squared error.

    Args:
        embedding_dim: size of the incoming sentence embedding.
        hidden_dim: accepted for signature compatibility; not used by this head.
    """
    def __init__(self, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.hidden = nn.Linear(embedding_dim, 1)

        self.loss_fn = nn.MSELoss()

    def forward(self, sent_emb, label):
        """Return ``(prediction, loss)`` where prediction has the trailing size-1 dim removed."""
        output = self.hidden(self.dropout(sent_emb)).squeeze(-1)

        # Labels may arrive as integers (e.g. read from a csv column) or with a trailing
        # singleton dimension. MSELoss needs a floating-point target, and a shape mismatch
        # would silently broadcast, so align both dtype and shape with the prediction.
        target = label.to(dtype=output.dtype).reshape(output.shape)
        loss = self.loss_fn(output, target)
        return output, loss

In [None]:
class ClassificationHead(nn.Module):
    """Dropout + linear classifier over a sentence embedding, scored with cross-entropy.

    Args:
        num_labels: number of output classes.
        embedding_dim: size of the incoming sentence embedding.
        hidden_dim: accepted but not used by this head.
    """

    def __init__(self, num_labels, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.num_labels = num_labels
        self.dropout = nn.Dropout(0.1)
        self.hidden = nn.Linear(embedding_dim, self.num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, sent_emb, label):
        """Return ``(logits, loss)`` for a batch of embeddings and their labels."""
        dropped = self.dropout(sent_emb)
        logits = self.hidden(dropped)

        # Flatten to (N, num_labels) / (N,) before handing both to the loss.
        flat_logits = logits.view(-1, self.num_labels)
        flat_labels = label.view(-1)
        return logits, self.loss_fn(flat_logits, flat_labels)

In [None]:
@dataclass
class MultiTaskOutput(ModelOutput):
    """Container returned by ``MultiTaskBert.forward``."""
    # Sum of the per-task head losses; stays None when no task labels are passed to forward.
    loss: torch.FloatTensor = None
    # Sentence embedding fed to the heads: the encoder's pooler_output, or the first-token
    # slice of last_hidden_state.
    sent_emb: torch.FloatTensor = None
    # Task name -> detached logits; entries are only filled when forward is called with
    # return_logits=True.
    all_logits: Optional[Dict[str, torch.FloatTensor]] = None
    # Passed through from the base model's ``hidden_states`` output.
    bert_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Passed through from the base model's ``attentions`` output.
    bert_attentions: Optional[Tuple[torch.FloatTensor]] = None

In [None]:
class MultiTaskBert(PreTrainedModel):
    """Pretrained encoder with one prediction head per selected task.

    ``style_heads[i]`` is the head for ``training_args.selected_tasks[i]``: a
    ``RegressionHead`` when ``tasks2labels[task] == 1`` and a ``ClassificationHead`` with
    ``tasks2labels[task]`` classes otherwise.
    """

    def __init__(self, config, training_args):
        super().__init__(config)
        self.tasks = training_args.selected_tasks
        self.use_pooler = training_args.use_pooler
        self.basemodel = AutoModel.from_pretrained(training_args.base_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(training_args.base_model_name)
        self.style_heads = nn.ModuleList()

        # Size the heads from the loaded encoder instead of assuming 768, so encoders with
        # a different hidden size work too; 768 remains the fallback.
        embedding_dim = getattr(self.basemodel.config, 'hidden_size', 768)
        for task in self.tasks:
            if tasks2labels[task] == 1:
                self.style_heads.append(RegressionHead(embedding_dim=embedding_dim))
            else:
                self.style_heads.append(ClassificationHead(tasks2labels[task], embedding_dim=embedding_dim))

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, return_logits=False, return_sent_emb=True, **kwargs):
        """Encode the batch and run every selected task head whose labels were supplied.

        Args:
            input_ids: token ids for the base model.
            token_type_ids: optional; only forwarded to the base model when given, so
                encoders whose tokenizers do not emit it can be used.
            attention_mask: optional; only forwarded to the base model when given.
            return_logits: when True, ``all_logits`` maps each evaluated task to its
                detached logits; otherwise ``all_logits`` is an empty dict.
            return_sent_emb: currently unused; ``sent_emb`` is always returned.
            **kwargs: labels keyed by task name. Keys that are not in ``self.tasks`` (for
                example labels of tasks that were not selected) and ``None`` values are
                ignored.

        Returns:
            MultiTaskOutput whose ``loss`` is the sum of the evaluated heads' losses, or
            None when no head was evaluated.
        """
        encoder_inputs = {'input_ids': input_ids}
        if token_type_ids is not None:
            encoder_inputs['token_type_ids'] = token_type_ids
        if attention_mask is not None:
            encoder_inputs['attention_mask'] = attention_mask
        output = self.basemodel(**encoder_inputs)

        if self.use_pooler and ('pooler_output' in output):
            sent_emb = output['pooler_output']
        else:
            sent_emb = output['last_hidden_state'][:,0,:]

        # Heads are stored in the order of self.tasks, so look them up by that position
        # rather than by the task's index in the global task table.
        head_position = {task: i for i, task in enumerate(self.tasks)}

        total_loss = None
        all_logits = {}
        for task, label in kwargs.items():
            if task not in head_position or label is None:
                continue
            logits, loss = self.style_heads[head_position[task]](sent_emb, label)
            # Out-of-place add: do not mutate the first head's loss tensor.
            total_loss = loss if total_loss is None else total_loss + loss
            if return_logits:
                all_logits[task] = logits.detach()
        return MultiTaskOutput(loss=total_loss, sent_emb=sent_emb, all_logits=all_logits, bert_hidden_states=output.hidden_states, bert_attentions=output.attentions)
    
    

In [None]:
def init_model(training_args):
    """Build a MultiTaskBert for ``training_args.base_model_name`` and move it to ``device``."""
    base_config = AutoConfig.from_pretrained(training_args.base_model_name)
    return MultiTaskBert(base_config, training_args).to(device)

In [None]:
def freeze_model(model, freeze_bert):
    '''
    Disable gradients for part or all of ``model.basemodel``.

    if freeze_bert is True, freeze every parameter of the base model.
    if freeze_bert is an integer, freeze ``model.basemodel.encoder.layer[:freeze_bert]``,
    i.e. the bottom {freeze_bert} encoder layers; a negative integer leaves only the top
    ``-freeze_bert`` layers trainable and 0 freezes nothing.
    if freeze_bert is False, nothing is frozen.
    '''
    # bool is a subclass of int and ``1 == True``, so real integers must be recognised
    # first; otherwise freeze_bert=1 would freeze the whole encoder instead of one layer.
    if isinstance(freeze_bert, int) and not isinstance(freeze_bert, bool):
        for layer in model.basemodel.encoder.layer[:freeze_bert]:
            for param in layer.parameters():
                param.requires_grad = False
    elif freeze_bert == True:
        for param in model.basemodel.parameters():
            param.requires_grad = False

## trainer

In [None]:
def nested_detach(tensors):
    "Detach every tensor in `tensors`, recursing through nested lists, tuples and dicts."
    if isinstance(tensors, dict):
        return {key: nested_detach(tensors[key]) for key in tensors}
    if not isinstance(tensors, (list, tuple)):
        # Leaf: anything that is not a supported container is detached directly.
        return tensors.detach()
    container = type(tensors)
    return container(map(nested_detach, tensors))

In [None]:
def nested_to(tensors, device):
    "Call `.to(device)` on every leaf of `tensors`, recursing through lists, tuples and dicts."
    if isinstance(tensors, dict):
        moved = {}
        for key in tensors:
            moved[key] = nested_to(tensors[key], device)
        return moved
    if isinstance(tensors, (list, tuple)):
        # Rebuild the same container type around the moved elements.
        return type(tensors)([nested_to(item, device) for item in tensors])
    return tensors.to(device)

In [None]:
class MyTrainer(Trainer):
    """Trainer for MultiTaskBert, whose labels and logits are dicts keyed by task name."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the model's summed task loss, plus ``all_logits`` when ``return_outputs``.

        ``**kwargs`` absorbs extra keyword arguments (such as ``num_items_in_batch``) that
        newer Trainer versions pass to ``compute_loss``; they are not used here.
        """
        outputs = model(**inputs)
        return (outputs.loss, outputs.all_logits) if return_outputs else outputs.loss

    def prediction_step(self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[Dict[str, torch.Tensor]], Optional[Dict[str, torch.Tensor]]]:
        """Run one evaluation forward pass.

        Returns ``(loss, logits, labels)`` where logits and labels are dicts keyed by task
        name, or ``(loss, None, None)`` when ``prediction_loss_only`` is set.
        ``ignore_keys`` is accepted for signature compatibility and is not used.
        """
        # The Trainer may hand over a wrapped model (e.g. nn.DataParallel); the custom
        # attributes live on the underlying module.
        core_model = getattr(model, 'module', model)
        inputs = nested_to(inputs, core_model.device)
        labels = {task: inputs[task] for task in core_model.tasks}

        # Evaluation does not need gradients; skipping the graph saves memory.
        with torch.no_grad():
            outputs = model(**inputs, return_logits=True)
        loss = outputs.loss.detach()

        if prediction_loss_only:
            return (loss, None, None)
        logits = nested_detach(outputs.all_logits)
        return (loss, logits, labels)
            

In [None]:
def compute_metrics(pred):
    """Compute accuracy, F1, precision and recall for every task.

    ``pred.label_ids`` and ``pred.predictions`` are read as dicts keyed by task name; the
    predicted class of a task is the argmax over the last axis of its predictions. Tasks
    with exactly two labels in ``tasks2labels`` use binary averaging, all others macro
    averaging.

    Returns:
        dict with the keys ``accuracy_{task}``, ``f1_{task}``, ``precision_{task}`` and
        ``recall_{task}`` for each task.
    """
    labels = pred.label_ids
    preds = pred.predictions
    res = {}
    for task in labels:
        average = 'binary' if tasks2labels[task] == 2 else 'macro'
        task_preds = preds[task].argmax(-1)  # computed once, shared by both metric calls
        # zero_division=0 yields the same scores as the default "warn" setting but without
        # emitting a warning for every class that is never predicted.
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels[task], task_preds, average=average, zero_division=0
        )
        acc = accuracy_score(labels[task], task_preds)
        res.update({
            f'accuracy_{task}': acc,
            f'f1_{task}': f1,
            f'precision_{task}': precision,
            f'recall_{task}': recall
        })
    return res

# train

In [None]:
# Run configuration: train a head for every task on top of bert-base-uncased.
my_training_args = MyTrainingArgs(selected_tasks=list(tasks2labels.keys()),
                             base_model_name='bert-base-uncased',
                             freeze_bert=False,
                             use_pooler=False,
                             num_epoch=5,
                             data_limit=30000,
                            )

hg_training_args = TrainingArguments(
    output_dir=my_training_args.model_folder,   # output directory
    num_train_epochs=my_training_args.num_epoch,     # total number of training epochs
    per_device_train_batch_size=my_training_args.batch_size,  # batch size per device during training
    per_device_eval_batch_size=my_training_args.batch_size,   # batch size for evaluation
    learning_rate=my_training_args.lr,   # forward the configured learning rate so that changing `lr` takes effect
    warmup_steps=my_training_args.num_warmup_steps,    # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir=f"{my_training_args.model_folder}/logs",  # directory for storing logs
    logging_first_step = True, 
    evaluation_strategy="steps",     # evaluate each `logging_steps`
    save_total_limit = 1,
    save_strategy = 'epoch',
#     load_best_model_at_end=True, # decide on loss
)

model = init_model(my_training_args)
freeze_model(model, my_training_args.freeze_bert)

train_dataset = MyDataset(my_training_args, 'train')
val_dataset = MyDataset(my_training_args, 'valid')

trainer = MyTrainer(
    model=model,   # the instantiated Transformers model to be trained
    args=hg_training_args,                  # training arguments, defined above
    tokenizer=model.tokenizer, 
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset,          # evaluation dataset
    compute_metrics=compute_metrics,     # the callback that computes metrics of interest
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Start training with the MyTrainer instance built in the previous cell.
trainer.train()

***** Running training *****
  Num examples = 30000
  Num Epochs = 5
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 4690


Step,Training Loss,Validation Loss,Accuracy Country,F1 Country,Precision Country,Recall Country,Accuracy Politics,F1 Politics,Precision Politics,Recall Politics,Accuracy Tod,F1 Tod,Precision Tod,Recall Tod,Accuracy Age,F1 Age,Precision Age,Recall Age,Accuracy Education,F1 Education,Precision Education,Recall Education,Accuracy Ethnic,F1 Ethnic,Precision Ethnic,Recall Ethnic,Accuracy Gender,F1 Gender,Precision Gender,Recall Gender
500,7.4657,6.71232,0.975188,0.329146,0.325063,0.333333,0.471178,0.338772,0.317624,0.37616,0.448622,0.207172,0.432388,0.256449,0.421053,0.163771,0.201873,0.175018,0.337093,0.165237,0.312272,0.186655,0.82807,0.182579,0.178509,0.188811,0.73609,0.443897,0.513352,0.44187
1000,6.6042,6.594169,0.975188,0.329146,0.325063,0.333333,0.490727,0.355664,0.390922,0.391111,0.43183,0.235156,0.234724,0.264786,0.434085,0.188127,0.240019,0.187776,0.41604,0.196126,0.251324,0.196373,0.828321,0.204089,0.296138,0.201432,0.737343,0.441228,0.523945,0.440119
1500,6.3557,6.53455,0.975188,0.329146,0.325063,0.333333,0.489724,0.383332,0.468841,0.40173,0.450376,0.218197,0.314425,0.262195,0.438847,0.204605,0.234889,0.204669,0.419048,0.218343,0.263148,0.20741,0.831328,0.21799,0.379092,0.210724,0.736842,0.455421,0.495026,0.451151
2000,6.198,6.552981,0.975188,0.329146,0.325063,0.333333,0.513283,0.427695,0.476989,0.432176,0.44812,0.239641,0.324544,0.269985,0.449875,0.217161,0.243137,0.213642,0.415789,0.225031,0.268267,0.21574,0.83183,0.221282,0.32001,0.213328,0.739599,0.462612,0.491339,0.457825
2500,5.7568,6.531356,0.975188,0.335879,0.491809,0.336684,0.515288,0.429238,0.491544,0.434767,0.452381,0.225795,0.316301,0.265657,0.453383,0.21046,0.250206,0.209222,0.429825,0.231064,0.293924,0.224936,0.829073,0.235061,0.385206,0.222489,0.746867,0.464859,0.502547,0.459645
3000,5.4215,6.76831,0.975689,0.349031,0.575305,0.343557,0.518797,0.425691,0.490425,0.43317,0.447619,0.230828,0.299303,0.266513,0.452381,0.218716,0.236233,0.216365,0.417794,0.232732,0.301486,0.230848,0.827569,0.238431,0.356211,0.223541,0.737093,0.462652,0.486474,0.458139
3500,4.9726,6.846441,0.97594,0.361321,0.563562,0.350344,0.508772,0.455003,0.463876,0.45266,0.442356,0.263891,0.357665,0.27924,0.449123,0.226067,0.233774,0.226387,0.416541,0.241904,0.289828,0.233589,0.82807,0.242996,0.338047,0.226461,0.740602,0.462699,0.492899,0.457873
4000,4.6453,7.032868,0.975188,0.34856,0.491968,0.343386,0.517544,0.451317,0.470335,0.44997,0.425815,0.264964,0.306352,0.278487,0.442857,0.222537,0.229282,0.220794,0.425564,0.241917,0.294418,0.232075,0.828571,0.252534,0.378623,0.232904,0.734837,0.464048,0.481542,0.459938
4500,4.3209,7.108984,0.975689,0.355117,0.547607,0.346908,0.511779,0.455131,0.464497,0.452666,0.433083,0.262308,0.329649,0.277303,0.443358,0.225579,0.230294,0.225131,0.415789,0.237092,0.279466,0.231582,0.82782,0.250089,0.38145,0.231965,0.733584,0.462331,0.481048,0.458213


***** Running Evaluation *****
  Num examples = 3990
  Batch size = 32
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Saving model checkpoint to /scratch/data_jz17d/result/country+politics+tod+age+education+ethnic+gender/20220919-15:01:59/checkpoint-938
Configuration saved in /scratch/data_jz17d/result/country+politics+tod+age+education+ethnic+gender/20220919-15:01:59/checkpoint-938/config.json
Model weights saved in /scratch/data_jz17d/result/country+politics+tod+age+education+ethnic+gender/20220919-15:01:59/checkpoint-938/pytorch_model.bin
tokenizer config file saved in /scratch/data_jz17d/result/country+politics+tod+age+education+ethnic+gender/20220919-15:01:59/checkpoint-938/tokenizer_config.json
Special

***** Running Evaluation *****
  Num examples = 3990
  Batch size = 32
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Saving model checkpoint to /scratch/data_jz17d/result/country+politics+tod+age+education+ethnic+gender/20220919-15:01:59/checkpoint-2814
Configuration saved in /scratch/data_jz17d/result/country+politics+tod+age+education+ethnic+gender/20220919-15:01:59/checkpoint-2814/config.json
Model weights saved in /scratch/data_jz17d/result/country+politics+tod+age+education+ethnic+gender/20220919-15:01:59/checkpoint-2814/pytorch_model.bin
tokenizer config file saved in /scratch/data_jz17d/result/country+politics+tod+age+education+ethnic+gender/20220919-15:01:59/checkpoint-2814/tokenizer_config.json
Spe

***** Running Evaluation *****
  Num examples = 3990
  Batch size = 32
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Saving model checkpoint to /scratch/data_jz17d/result/country+politics+tod+age+education+ethnic+gender/20220919-15:01:59/checkpoint-4690
Configuration saved in /scratch/data_jz17d/result/country+politics+tod+age+education+ethnic+gender/20220919-15:01:59/checkpoint-4690/config.json
Model weights saved in /scratch/data_jz17d/result/country+politics+tod+age+education+ethnic+gender/20220919-15:01:59/checkpoint-4690/pytorch_model.bin
tokenizer config file saved in /scratch/data_jz17d/result/country+politics+tod+age+education+ethnic+gender/20220919-15:01:59/checkpoint-4690/tokenizer_config.json
Special tokens file saved in /scratch/data_jz17d/result/country+politics+tod+age+education+ethnic+gender/20220919

TrainOutput(global_step=4690, training_loss=5.691822536159426, metrics={'train_runtime': 1120.7916, 'train_samples_per_second': 133.834, 'train_steps_per_second': 4.185, 'total_flos': 2206427450800320.0, 'train_loss': 5.691822536159426, 'epoch': 5.0})

## bertology

In [None]:
# def entropy(p):
#     """ Compute the entropy of a probability distribution """
#     plogp = p * torch.log(p)
#     plogp[p == 0] = 0
#     return -plogp.sum(dim=-1)

In [None]:
# def compute_heads_importance(
#     model, eval_dataloader, training_args, diagnose_per_step=False, diagnose_normalize=True, compute_entropy=True, compute_importance=True, head_mask=None, 
#     dont_normalize_importance_by_layer = True, dont_normalize_global_importance=True
# ):
#     """ This method shows how to compute:
#         - head attention entropy
#         - head importance scores according to http://arxiv.org/abs/1905.10650
#     """
#     model_folder = training_args.model_folder
    
#     # Prepare our tensors
#     n_layers, n_heads = model.basemodel.config.num_hidden_layers, model.basemodel.config.num_attention_heads
#     head_importance = torch.zeros(n_layers, n_heads).to(device)
#     attn_entropy = torch.zeros(n_layers, n_heads).to(device)

#     if head_mask is None:
#         head_mask = torch.ones(n_layers, n_heads).to(device)
#     head_mask.requires_grad_(requires_grad=True)
#     preds = None
#     labels = None
#     tot_tokens = 0.0
#     if diagnose_per_step:
#         entropy_per_step = None
#         importance_per_step = None

#     for step, batch in enumerate(tqdm(eval_dataloader, desc="Iteration")):
#         i_task, batch = batch
#         label_ids = batch['label'].to(device)
#         size = len(label_ids)
#         del batch['label']
#         batch = model.tokenizer(**batch, return_tensors='pt', padding=True, truncation=True, max_length=64).to(device)
#         input_ids, input_mask, segment_ids = batch['input_ids'], batch['attention_mask'], batch['token_type_ids']
        
#         # Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
#         outputs = model(i_task=i_task,
#             input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, label=label_ids, head_mask=head_mask, 
#             output_attentions = True, 
#         )
#         loss, logits, all_attentions = (
#             outputs.loss,
#             outputs.logits,
#             outputs.attentions,
#         )  # Loss and logits are the first, attention the last
#         loss.backward()  # Backpropagate to populate the gradients in the head mask
        
#         batch_entropy = torch.zeros(n_layers, n_heads).to(device) 
#         if compute_entropy:
#             for layer, attn in enumerate(all_attentions):
#                 masked_entropy = entropy(attn.detach()) * input_mask.float().unsqueeze(1)
#                 batch_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()
#                 attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

#         if compute_importance:
#             batch_importance = head_mask.grad.abs().detach()
#             head_importance += batch_importance

#         # Also store our logits/labels if we want to compute metrics afterwards
#         if preds is None:
#             preds = logits.detach().cpu().numpy()
#             labels = label_ids.detach().cpu().numpy()
#         else:
#             preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
#             labels = np.append(labels, label_ids.detach().cpu().numpy(), axis=0)
        
#         batch_num_tokens = input_mask.float().detach().sum().item()
#         tot_tokens += batch_num_tokens
        
#         if diagnose_per_step:
#             if diagnose_normalize:
#                 batch_entropy = batch_entropy.detach().cpu().unsqueeze(0).numpy()/batch_num_tokens
#                 batch_importance = batch_importance.cpu().unsqueeze(0).numpy()/batch_num_tokens
                
#             else:
#                 batch_entropy = batch_entropy.detach().cpu().unsqueeze(0).numpy()
#                 batch_importance = batch_importance.detach().cpu().unsqueeze(0).numpy()
                
#             if entropy_per_step is None:
#                 entropy_per_step = batch_entropy
#             else:
#                 entropy_per_step = np.append(entropy_per_step, batch_entropy, axis=0)
#             if importance_per_step is None:
#                 importance_per_step = batch_importance
#             else:
#                 importance_per_step = np.append(importance_per_step, batch_importance, axis=0)
    
#     # Normalize
#     attn_entropy /= tot_tokens
#     head_importance /= tot_tokens
#     # Layerwise importance normalization
#     if not dont_normalize_importance_by_layer:
#         exponent = 2
#         norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
#         head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

#     if not dont_normalize_global_importance:
#         head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())

#     # save matrices
#     np.save(os.path.join(model_folder, "attn_entropy.npy"), attn_entropy.detach().cpu().numpy())
#     np.save(os.path.join(model_folder, "head_importance.npy"), head_importance.detach().cpu().numpy())

#     head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=device)
#     head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
#         head_importance.numel(), device=device
#     )
#     head_ranks = head_ranks.view_as(head_importance)
    
#     plt.figure(figsize = (9,4))
#     plt.subplot(1,2,1)
#     plt.title('attn_entropy')
#     plt.imshow(attn_entropy.detach().cpu().numpy())
#     plt.colorbar()
#     plt.subplot(1,2,2)
#     plt.title('head_importance')
#     plt.imshow(head_importance.detach().cpu().numpy())
#     plt.colorbar()
#     plt.show()
    
#     if diagnose_per_step:
#         return attn_entropy, head_importance, preds, labels, entropy_per_step, importance_per_step
    
#     return attn_entropy, head_importance, preds, labels

In [None]:
# def imshow(torch_mat):
#     plt.imshow(torch_mat.detach().cpu().numpy())
#     plt.show()

In [None]:
# eval_dataloader = MultiTaskTestDataLoader(training_args, split='dev')
# attn_entropy, head_importance, preds, labels = compute_heads_importance(model, eval_dataloader, training_args)

# imshow(attn_entropy)
# imshow(head_importance)