<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#functions" data-toc-modified-id="functions-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>functions</a></span><ul class="toc-item"><li><span><a href="#multi-task-dataloader" data-toc-modified-id="multi-task-dataloader-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>multi-task dataloader</a></span></li><li><span><a href="#multi-task-model" data-toc-modified-id="multi-task-model-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>multi-task model</a></span></li><li><span><a href="#train-and-validate" data-toc-modified-id="train-and-validate-1.3"><span class="toc-item-num">1.3&nbsp;&nbsp;</span>train and validate</a></span></li><li><span><a href="#bertology" data-toc-modified-id="bertology-1.4"><span class="toc-item-num">1.4&nbsp;&nbsp;</span>bertology</a></span></li></ul></li></ul></div>

# functions

In [None]:
# %load_ext pycodestyle_magic
# %flake8_on

In [None]:
import sys
import os
import collections
import json
from dataclasses import dataclass, asdict
from typing import List
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm, trange

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
from torch import optim
from torch.optim import lr_scheduler
import torchmetrics
from transformers import *
from transformers.modeling_outputs import SequenceClassifierOutput

import matplotlib.pyplot as plt
from IPython.display import display

In [None]:
# Set TOKENIZERS_PARALLELISM explicitly for the Hugging Face tokenizers.
# Context: https://github.com/huggingface/transformers/issues/5486
os.environ.update(TOKENIZERS_PARALLELISM="true")

In [None]:
# Use the default CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device  # notebook cell output: show the selected device

In [None]:
# Dictionary mapping task_name -> number_of_labels, read from the XSLUE task list.
tasks2labels = json.loads(Path('../data/xslue/tasks.json').read_text())
tasks2labels  # notebook cell output

In [None]:
# Dictionary mapping task_name -> position of that task in tasks2labels.
tasks2idx = dict(zip(tasks2labels, range(len(tasks2labels))))
tasks2idx  # notebook cell output

In [None]:
# Binary tasks used in the experiments, annotated with their train-split size.
# Commented-out entries are excluded from the selection.
selected_task = [
    'PASTEL_country',  # 33224
    # 'SARC',  # 205645
    'SarcasmGhosh',  # 39780
    'ShortHumor',  # 37801
    # 'ShortJokeKaggle',  # 406682
    # 'ShortRomance',  # 1902
    # 'TroFi',  # 3335
    'VUA',  # 15157
]


In [None]:
@dataclass
class TrainingArgs:
    '''
    Configuration for one multi-task training run.

    Creating an instance has a side effect: ``model_folder`` is created on disk.
    When ``model_folder`` is not given it is inferred as
    ``../result/<task1+task2+...>/<YYYYmmdd-HH:MM:SS>`` (current local time).
    '''
    # training args
    selected_tasks: List
    base_model_name: str
    freeze_bert: bool  # True/False, or an int that is passed on to freeze_model
    use_pooler: bool
    num_epoch: int
    lr: float = 5e-5
    model_folder: str = None  # inferred from the tasks and the current time when None

    # data loader args
    batch_size: int = 32
    max_length: int = 64
    shuffle: bool = False
    num_workers: int = 4
    data_limit: int = None  # if not None, truncate each dataset to its first {data_limit} rows

    # post training args
    save_best: bool = True
    load_best_at_end: bool = True

    # scheduler args. This is an annotated dataclass field so it can be passed to
    # the constructor and is included in asdict(). It is declared last so the
    # positional order of the other fields stays unchanged.
    num_warmup_steps: int = 500

    def __post_init__(self):
        if self.model_folder is None:
            # Take the timestamp once; calling now() again could give a different time.
            execute_time = datetime.now()
            result_folder = '../result'
            self.model_folder = (
                f"{result_folder}/{'+'.join(self.selected_tasks)}/"
                f"{execute_time.strftime('%Y%m%d-%H:%M:%S')}"
            )
        Path(self.model_folder).mkdir(parents=True, exist_ok=True)

## multi-task dataloader

The test/validation dataloader consumes the task datasets one after another, whereas the train dataloader picks a random task for every batch. The train dataloader is therefore more complicated: it must be able to restart a task's dataset once that dataset is exhausted.

In [None]:
class MyDataset(Dataset):
    '''
    Map-style dataset over a tab-separated file with ``text`` and ``label`` columns.

    Rows with any missing value are dropped and the index is renumbered. A
    float64 ``label`` column is stored as float32. When ``data_limit`` is truthy,
    only the first ``data_limit`` rows are kept; this can skew the label
    distribution. (An iterable-style dataset might be a better fit; untested.)
    '''

    def __init__(self, tsv_file, data_limit=None):
        frame = pd.read_csv(tsv_file, sep='\t').dropna().reset_index(drop=True)
        if frame['label'].dtype == 'float64':
            frame['label'] = frame['label'].astype('float32')
        self.df = frame.iloc[:data_limit] if data_limit else frame

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Samplers may hand over tensor indices; iloc needs plain Python ints/lists.
        position = idx.tolist() if torch.is_tensor(idx) else idx
        row = self.df.iloc[position]
        return {'text': row['text'], 'label': row['label']}


In [None]:
class MultiTaskTrainDataLoader():
    '''
    Iterator that yields ``(i_task, batch)`` pairs for multi-task training.

    Each step draws a task index uniformly at random and returns the next batch
    from that task's DataLoader. When a task's DataLoader is exhausted it is
    restarted, so one pass of this iterator is ``len(self)`` steps: the total
    number of batches summed over all tasks.

    Known issue: because tasks are sampled uniformly, a large dataset may not be
    fully iterated in one pass, while small datasets may be repeated many times.
    '''

    def __init__(self, training_args):
        self.tasks = training_args.selected_tasks
        self.split = 'train'
        self.batch_size = training_args.batch_size
        self.shuffle = training_args.shuffle
        self.num_workers = training_args.num_workers
        self.data_limit = training_args.data_limit

        self.num_tasks = len(self.tasks)
        self.datasets = []
        self.dataloaders = []
        self.dataloaderiters = []
        self.len = 0
        self.n = 0
        for task in self.tasks:
            dataset = MyDataset(f'../data/xslue/processed/{self.split}/{task}.tsv', data_limit=self.data_limit)
            dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers)
            self.datasets.append(dataset)
            self.dataloaders.append(dataloader)
            # Use the public iter() protocol rather than the private DataLoader._get_iterator().
            self.dataloaderiters.append(iter(dataloader))
            self.len += len(dataloader)

    def __len__(self):
        return self.len

    def __iter__(self):
        self.n = 0
        return self

    def __next__(self):
        if self.n >= self.len:
            raise StopIteration
        self.n += 1
        i_task = np.random.randint(self.num_tasks)
        try:
            batch = next(self.dataloaderiters[i_task])
        except StopIteration:
            # This task's data is exhausted: restart its loader and take the first batch.
            self.dataloaderiters[i_task] = iter(self.dataloaders[i_task])
            try:
                batch = next(self.dataloaderiters[i_task])
            except StopIteration:
                # Without this, an empty dataset would silently end the epoch early.
                raise RuntimeError(f'dataloader for task index {i_task} yields no batches') from None
        return i_task, batch

In [None]:
class MultiTaskTestDataLoader():
    '''
    Evaluation loader that walks the selected tasks' datasets in order.

    Iterating yields ``(i_task, batch)`` pairs: every batch of task 0, then every
    batch of task 1, and so on. ``len()`` is the total batch count over all tasks.
    '''

    def __init__(self, training_args, split):
        assert split in ['train', 'dev', 'test'], 'not implemented'
        self.tasks = training_args.selected_tasks
        self.split = split
        self.batch_size = training_args.batch_size
        self.shuffle = training_args.shuffle
        self.num_workers = training_args.num_workers
        self.data_limit = training_args.data_limit
        self.num_tasks = len(self.tasks)

        self.datasets = [
            MyDataset(f'../data/xslue/processed/{self.split}/{name}.tsv', data_limit=self.data_limit)
            for name in self.tasks
        ]
        self.dataloaders = [
            DataLoader(dataset, batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers)
            for dataset in self.datasets
        ]
        self.dataloaderiters = []  # not used by this loader
        self.len = sum(len(loader) for loader in self.dataloaders)
        self.i_task = 0  # not used by this loader

    def __len__(self):
        return self.len

    def __iter__(self):
        for task_index in range(self.num_tasks):
            yield from ((task_index, batch) for batch in self.dataloaders[task_index])
    
        

## multi-task model

Given the selected tasks, the model adds one head per task on top of a pretrained BERT-style encoder: a classification head, or a regression head for tasks with a single label.

In [None]:
class RegressionHead(nn.Module):
    '''
    Single-output linear head for regression tasks, trained with MSE loss.

    ``hidden_dim`` is accepted only for signature parity with ClassificationHead
    and is unused: the head is dropout followed by one linear layer.
    '''
    def __init__(self, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        # The attribute name ``hidden`` is kept so existing state_dicts still load.
        self.hidden = nn.Linear(embedding_dim, 1)

        self.loss_fn = nn.MSELoss()

    def forward(self, sent_emb, label):
        '''
        Predict one value per example and, when ``label`` is given, the MSE loss.

        Assumes ``sent_emb`` is (batch, embedding_dim); the trailing
        size-1 output dimension is squeezed, so predictions are (batch,).
        Returns ``(predictions, loss)``; ``loss`` is None when ``label`` is None.
        '''
        output = self.hidden(self.dropout(sent_emb)).squeeze(1)
        if label is None:
            return output, None
        # MSELoss needs the target dtype to match the float output; integer labels would fail.
        loss = self.loss_fn(output, label.to(output.dtype))
        return output, loss

In [None]:
class ClassificationHead(nn.Module):
    '''
    Linear classification head over ``num_labels`` classes, trained with cross-entropy.

    ``hidden_dim`` is unused: the head is dropout followed by one linear layer.
    '''
    def __init__(self, num_labels, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.num_labels = num_labels
        self.dropout = nn.Dropout(0.1)
        # The attribute name ``hidden`` is kept so existing state_dicts still load.
        self.hidden = nn.Linear(embedding_dim, self.num_labels)

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, sent_emb, label):
        '''
        Return ``(logits, loss)``; ``loss`` is None when ``label`` is None.

        Labels are flattened and cast to int64 class indices. pandas reads an
        integer column that contained missing values as float, and
        CrossEntropyLoss rejects float targets of that shape.
        '''
        output = self.hidden(self.dropout(sent_emb))
        if label is None:
            return output, None

        loss = self.loss_fn(output.view(-1, self.num_labels), label.view(-1).long())
        return output, loss

In [None]:
class MultiTaskBert(PreTrainedModel):
    '''
    Shared pretrained encoder with one prediction head per selected task.

    A task whose ``tasks2labels`` entry is 1 gets a RegressionHead; any other
    task gets a ClassificationHead with that many labels. Head ``i`` belongs to
    ``training_args.selected_tasks[i]``, and ``forward`` routes through the
    head chosen by ``i_task``.
    '''

    def __init__(self, config, training_args):
        super().__init__(config)
        self.use_pooler = training_args.use_pooler
        pretrained_name = training_args.base_model_name
        self.basemodel = AutoModel.from_pretrained(pretrained_name)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
        self.style_heads = nn.ModuleList([
            RegressionHead() if tasks2labels[task] == 1 else ClassificationHead(tasks2labels[task])
            for task in training_args.selected_tasks
        ])

    def forward(self, i_task=None, label=None, **kwargs):
        encoded = self.basemodel(**kwargs)
        # Prefer the pooler output when requested and the base model provides one;
        # otherwise use the first token's last hidden state.
        take_pooler = self.use_pooler and ('pooler_output' in encoded)
        sent_emb = encoded['pooler_output'] if take_pooler else encoded['last_hidden_state'][:, 0, :]

        logits, loss = self.style_heads[i_task](sent_emb, label)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
    
    

## train and validate

In [None]:
def validate(model, training_args, split):
    '''
    Evaluate ``model`` on every selected task of one data split.

    Parameters
    ----------
    model : module exposing ``tokenizer`` and accepting ``i_task``/``label`` kwargs
    training_args : TrainingArgs (uses selected_tasks, max_length and loader settings)
    split : 'train', 'dev' or 'test'

    Returns
    -------
    (val_loss, overall_acc, accs)
        val_loss: dict task index -> mean per-example loss on that task
        overall_acc: accuracy pooled over all tasks' examples
        accs: per-task accuracies, in the key order of ``val_loss``
    The model is put back into training mode before returning.
    '''
    # Size per-task metrics from the run's own tasks, not the notebook-level task list.
    num_tasks = len(training_args.selected_tasks)
    val_loss = collections.defaultdict(float)
    val_size = collections.defaultdict(int)
    overall_acc = torchmetrics.Accuracy()
    task_accs = [torchmetrics.Accuracy() for _ in range(num_tasks)]

    mt_dataloader = MultiTaskTestDataLoader(training_args, split=split)

    model.eval()
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for i_task, batch in tqdm(mt_dataloader, leave=False):
            label = batch.pop('label').to(device)
            size = len(label)
            tokens = model.tokenizer(**batch, return_tensors='pt', padding=True, truncation=True,
                                     max_length=training_args.max_length).to(device)
            output = model(**tokens, i_task=i_task, label=label)
            logits_cpu = output.logits.cpu()
            label_cpu = label.cpu()
            overall_acc.update(logits_cpu, label_cpu)
            task_accs[i_task].update(logits_cpu, label_cpu)
            val_loss[i_task] += output.loss.item() * size
            val_size[i_task] += size

    accs = []
    for i_task in val_loss:
        val_loss[i_task] /= val_size[i_task]
        accs.append(task_accs[i_task].compute())
    model.train()

    return val_loss, overall_acc.compute(), accs


In [None]:
def init_model(training_args):
    '''Build a MultiTaskBert for ``training_args`` and move it to ``device``.'''
    base_config = AutoConfig.from_pretrained(training_args.base_model_name)
    return MultiTaskBert(base_config, training_args).to(device)

In [None]:
def freeze_model(model, freeze_bert):
    '''
    Turn off gradients for all or part of ``model.basemodel``.

    freeze_bert:
        True          -> freeze every base-model parameter
        False / None  -> leave everything trainable
        int n         -> freeze ``model.basemodel.encoder.layer[:n]``, i.e. the
                         bottom n encoder layers; a negative n freezes all but
                         the top |n| layers. Embeddings stay trainable.
    '''
    basemodel = model.basemodel
    # bool must be checked before int: bool subclasses int and 1 == True, so a
    # plain ``== True`` test would make freeze_bert=1 freeze the whole model.
    if isinstance(freeze_bert, bool):
        to_freeze = basemodel.parameters() if freeze_bert else ()
    elif isinstance(freeze_bert, int):
        to_freeze = (
            param
            for layer in basemodel.encoder.layer[:freeze_bert]
            for param in layer.parameters()
        )
    else:
        to_freeze = ()
    for param in to_freeze:
        param.requires_grad = False

In [None]:
def load_best_model(model, model_folder):
    '''
    Load ``<model_folder>/pytorch_model.bin`` into ``model`` in place and return it.

    The state dict is mapped to CPU first, so a checkpoint saved on a GPU machine
    can also be restored on a CPU-only one. ``load_state_dict`` then copies each
    tensor into the model's existing parameters on whatever device they live.
    '''
    state_dict = torch.load(Path(model_folder) / "pytorch_model.bin", map_location="cpu")
    model.load_state_dict(state_dict)
    return model

In [None]:
def _epoch_result(i_epoch, selected_tasks, train_metrics, val_metrics):
    '''Flatten one epoch's (loss, overall acc, per-task accs) for train and dev into a table row.'''
    row = {'i_epoch': i_epoch}
    for prefix, (losses, overall_acc, task_accs) in (('train', train_metrics), ('val', val_metrics)):
        row[f'{prefix}_loss'] = sum(losses.values())
        row[f'{prefix}_acc'] = overall_acc.item()
        row.update({f'{prefix}_loss_{selected_tasks[i]}': loss for i, loss in losses.items()})
        row.update({f'{prefix}_acc_{selected_tasks[i]}': acc.item() for i, acc in enumerate(task_accs)})
    return row


def train_model(model, training_args):
    '''
    Train ``model`` on all selected tasks and evaluate after every epoch.

    Each step trains on a batch from a randomly chosen task. After each epoch
    ``validate`` is run on the full train and dev splits.

    If ``training_args.save_best`` is set:
    - the model, optimizer and scheduler states from the epoch with the best dev
      accuracy are saved to ``training_args.model_folder``;
    - after training, the training args, the model config and both log tables
      are also written there;
    - if ``load_best_at_end`` is set, the best checkpoint is reloaded into
      ``model`` before returning.

    Returns (df_evaluation, df_loss_per_step, model).
    '''
    # these two will be used frequently
    model_folder = training_args.model_folder
    selected_tasks = training_args.selected_tasks

    train_dataloader = MultiTaskTrainDataLoader(training_args)
    num_training_steps = training_args.num_epoch * len(train_dataloader)

    optimizer = optim.AdamW(model.parameters(), lr=training_args.lr)
    scheduler = get_scheduler("linear",
                              optimizer=optimizer,
                              num_warmup_steps=training_args.num_warmup_steps,
                              num_training_steps=num_training_steps)

    # Collect log rows in lists and build the DataFrames once at the end:
    # DataFrame.append is quadratic and was removed in pandas 2.0.
    columns = ['i_epoch']
    for metric in ('train_loss', 'train_acc', 'val_loss', 'val_acc'):
        columns += [metric] + [f'{metric}_{task}' for task in selected_tasks]
    step_columns = ['i_epoch', 'i_iter', 'i_task', 'task_name', 'train_loss']
    evaluation_rows = []
    step_rows = []

    # Start at -inf so the first evaluated epoch always produces a checkpoint.
    best_accuracy = float('-inf')
    model.train()
    progress_bar = tqdm(range(num_training_steps))
    for i_epoch in range(training_args.num_epoch):
        for i_iter, (i_task, batch) in enumerate(train_dataloader):
            optimizer.zero_grad()
            label = batch.pop('label').to(device)
            tokens = model.tokenizer(**batch, return_tensors='pt', padding=True, truncation=True,
                                     max_length=training_args.max_length).to(device)
            output = model(**tokens, i_task=i_task, label=label)
            loss = output.loss
            loss.backward()
            optimizer.step()
            scheduler.step()

            # log per step
            step_rows.append({'i_epoch': i_epoch, 'i_iter': i_iter, 'i_task': i_task,
                              'task_name': selected_tasks[i_task], 'train_loss': loss.item()})
            progress_bar.update(1)

        # run evaluation on train and validation set
        train_metrics = validate(model, training_args, split='train')
        val_metrics = validate(model, training_args, split='dev')

        # save best model and corresponding opt and scheduler states to disk
        val_accuracy = val_metrics[1].item()
        if training_args.save_best and val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), f"{model_folder}/pytorch_model.bin")
            torch.save(optimizer.state_dict(), f"{model_folder}/optimizer.pt")
            torch.save(scheduler.state_dict(), f"{model_folder}/scheduler.pt")

        evaluation_rows.append(_epoch_result(i_epoch, selected_tasks, train_metrics, val_metrics))
    progress_bar.close()

    df_evaluation = pd.DataFrame(evaluation_rows, columns=columns)
    df_loss_per_step = pd.DataFrame(step_rows, columns=step_columns)

    # save to disk
    if training_args.save_best:
        with open(f"{model_folder}/training_args.json", "w") as outfile:
            json.dump(asdict(training_args), outfile, indent=2)
        model.config.to_json_file(f"{model_folder}/config.json")
        df_evaluation.to_csv(f"{model_folder}/evaluation.csv", index=False)
        df_loss_per_step.to_csv(f"{model_folder}/loss_per_step.csv", index=False)

    display(df_evaluation)
    # df_loss_per_step is too long to display directly.

    checkpoint = Path(model_folder) / "pytorch_model.bin"
    if training_args.save_best and training_args.load_best_at_end and checkpoint.exists():
        model = load_best_model(model, model_folder)
    return df_evaluation, df_loss_per_step, model


## bertology

In [None]:
def entropy(p):
    """Return the entropy -sum(p * log p) over the last dimension of ``p``.

    Entries with p == 0 contribute 0 (the limit of p log p), not NaN.
    """
    weighted_logs = (p * p.log()).masked_fill(p == 0, 0.0)
    return -weighted_logs.sum(dim=-1)

In [None]:
def compute_heads_importance(
    model, eval_dataloader, training_args, compute_entropy=True, compute_importance=True, head_mask=None,
    dont_normalize_importance_by_layer=True, dont_normalize_global_importance=True
):
    """Compute per-head attention entropy and head importance scores.

    Head importance follows http://arxiv.org/abs/1905.10650: the absolute
    gradient of the loss with respect to a multiplicative per-head mask, summed
    over batches and divided by the total number of non-padding tokens.

    Parameters
    ----------
    model : module with ``basemodel`` (HF encoder), ``tokenizer`` and per-task heads
    eval_dataloader : iterable of ``(i_task, batch)`` where ``batch`` has 'text'/'label'
    training_args : provides ``model_folder`` (output dir) and ``max_length``
    head_mask : optional (n_layers, n_heads) mask; all-ones when None

    Returns
    -------
    attn_entropy, head_importance : (n_layers, n_heads) tensors
    preds, labels : numpy arrays of concatenated logits / labels (None if no batches)

    Side effects: writes attn_entropy.npy and head_importance.npy into
    ``training_args.model_folder``, and leaves parameter gradients populated.
    """
    model_folder = training_args.model_folder

    # Prepare our tensors
    n_layers, n_heads = model.basemodel.config.num_hidden_layers, model.basemodel.config.num_attention_heads
    head_importance = torch.zeros(n_layers, n_heads).to(device)
    attn_entropy = torch.zeros(n_layers, n_heads).to(device)

    if head_mask is None:
        head_mask = torch.ones(n_layers, n_heads).to(device)
    head_mask.requires_grad_(requires_grad=True)
    all_preds = []
    all_labels = []
    tot_tokens = 0.0

    # Disable dropout so importance/entropy do not depend on random masking;
    # the original mode is restored afterwards.
    was_training = model.training
    model.eval()
    for i_task, batch in tqdm(eval_dataloader, desc="Iteration"):
        label_ids = batch.pop('label').to(device)
        encoded = model.tokenizer(**batch, return_tensors='pt', padding=True, truncation=True,
                                  max_length=training_args.max_length).to(device)
        input_ids = encoded['input_ids']
        input_mask = encoded['attention_mask']
        # Some tokenizers (e.g. RoBERTa-style) produce no token_type_ids.
        segment_ids = encoded.get('token_type_ids')

        # Forward pass without torch.no_grad(): gradients w.r.t. head_mask are needed.
        outputs = model(
            i_task=i_task,
            input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask,
            label=label_ids, head_mask=head_mask, output_attentions=True,
        )
        outputs.loss.backward()  # populates head_mask.grad

        if compute_entropy:
            for layer, attn in enumerate(outputs.attentions):
                masked_entropy = entropy(attn.detach()) * input_mask.float().unsqueeze(1)
                attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

        if compute_importance:
            head_importance += head_mask.grad.abs().detach()
        # Clear the accumulated gradient so each batch adds only its own |dL/dmask|;
        # otherwise earlier batches would be counted again on every step.
        head_mask.grad = None

        # Also store our logits/labels if we want to compute metrics afterwards
        all_preds.append(outputs.logits.detach().cpu().numpy())
        all_labels.append(label_ids.detach().cpu().numpy())

        tot_tokens += input_mask.float().detach().sum().item()
    if was_training:
        model.train()

    preds = np.concatenate(all_preds, axis=0) if all_preds else None
    labels = np.concatenate(all_labels, axis=0) if all_labels else None

    # Normalize
    attn_entropy /= tot_tokens
    head_importance /= tot_tokens
    # Layerwise importance normalization
    if not dont_normalize_importance_by_layer:
        exponent = 2
        norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
        head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

    if not dont_normalize_global_importance:
        head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())

    # save matrices
    np.save(os.path.join(model_folder, "attn_entropy.npy"), attn_entropy.detach().cpu().numpy())
    np.save(os.path.join(model_folder, "head_importance.npy"), head_importance.detach().cpu().numpy())

    return attn_entropy, head_importance, preds, labels

In [None]:
def imshow(torch_mat):
    '''Display a tensor with matplotlib's ``imshow`` (detached and copied to CPU first).'''
    matrix = torch_mat.detach().cpu().numpy()
    plt.imshow(matrix)
    plt.show()

In [None]:
# training_args = TrainingArgs(selected_tasks=['VUA'],
#                              base_model_name='bert-base-uncased',
#                              freeze_bert=True,
#                              use_pooler=True,
#                              num_epoch=5,
#                              data_limit=30000,
#                             )

# model = init_model(training_args)
# freeze_model(model, training_args.freeze_bert)
# df_evaluation, df_loss_per_step, model = train_model(model, training_args)

# eval_dataloader = MultiTaskTestDataLoader(training_args, split='dev')
# attn_entropy, head_importance, preds, labels = compute_heads_importance(model, eval_dataloader, training_args)

# imshow(attn_entropy)
# imshow(head_importance)