<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#functions" data-toc-modified-id="functions-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>functions</a></span><ul class="toc-item"><li><span><a href="#multi-task-dataset" data-toc-modified-id="multi-task-dataset-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>multi-task dataset</a></span></li><li><span><a href="#multi-task-model" data-toc-modified-id="multi-task-model-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>multi-task model</a></span></li><li><span><a href="#trainer" data-toc-modified-id="trainer-1.3"><span class="toc-item-num">1.3&nbsp;&nbsp;</span>trainer</a></span></li></ul></li></ul></div>

# functions

In [None]:
# %load_ext autoreload
# %autoreload 2

In [None]:
import sys
import os
import collections
import json
from ast import literal_eval
from dataclasses import dataclass, asdict
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm, trange

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
from torch import optim
from torch.optim import lr_scheduler
import torchmetrics
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from transformers import *
from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput

import matplotlib.pyplot as plt
from IPython.display import display

In [None]:
# Output and data locations. Each one can be overridden through an environment variable.
result_folder = os.environ.get("scratch_result_folder", '../result')
scratch_data_folder = os.environ.get("scratch_data_folder")
repo_folder = os.environ.get("style_models_repo_folder")
# Use the data directory inside the repo checkout when its location is known,
# otherwise fall back to a path relative to the working directory.
if repo_folder:
    data_folder = f"{repo_folder}/data"
else:
    data_folder = '../../data'

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Set the tokenizers parallelism flag explicitly; background in
# https://github.com/huggingface/transformers/issues/5486
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [None]:
# Mapping task_name -> number_of_labels, read from the PASTEL metadata file.
with open(f'{data_folder}/pastel/pastel_tasks2labels.json', 'r') as f:
    tasks2labels = json.load(f)
# Mapping task_name -> task index, numbered in the key order of tasks2labels.
tasks2idx = dict(zip(tasks2labels, range(len(tasks2labels))))

In [None]:
@dataclass
class MyTrainingArgs:
    """Configuration for one training run: model, optimisation, data loading, checkpointing.

    After construction ``model_folder`` always holds a path. A value passed by
    the caller is kept as-is; when it is left as ``None`` the path
    ``{result_folder}/{name}/{timestamp}`` is generated, where ``name`` is
    ``model_name`` if given and otherwise the selected tasks joined with ``+``.
    """
    # training args
    selected_tasks: List[str]
    base_model_name: str
    freeze_bert: bool  # freeze_model in this file also accepts an int (number of layers)
    use_pooler: bool
    num_epoch: int
    lr: float = 5e-5
    model_folder: Optional[str] = None # if None, this will be inferred based on tasks
    model_name: Optional[str] = None # if provided, used to name model_folder, otherwise the tasks are used

    # data loader args
    batch_size: int = 32
    max_length: int = 64
    shuffle: bool = False
    num_workers: int = 4
    data_limit: Optional[int] = None # if not None, truncate dataset to keep only top {data_limit} rows

    # post training args
    save_best: bool = True
    load_best_at_end: bool = True

    # training arg, declared last on purpose: it used to be an un-annotated class
    # attribute (so not a dataclass field). Appending it here makes it settable via
    # the constructor and visible to asdict() without shifting the positional
    # order of the fields above.
    num_warmup_steps: int = 500

    def __post_init__(self):
        # Only derive a folder when the caller did not choose one.
        if self.model_folder is None:
            run_name = self.model_name if self.model_name else '+'.join(self.selected_tasks)
            timestamp = datetime.now().strftime('%Y%m%d-%H:%M:%S')
            self.model_folder = f"{result_folder}/{run_name}/{timestamp}"

## multi-task dataset

In [None]:
class MyDataset(Dataset):
    """Map-style dataset over one processed PASTEL split.

    Label columns may hold plain class indices or, for distillation, logits that
    were written to the CSV as strings; the latter are parsed and turned into
    class probabilities with a softmax.

    training_args.data_limit truncates the dataset to its first rows, which may
    change the label distribution. Whether an iterable-style dataset would be a
    better fit is an open question.
    """
    def __init__(self, training_args, split, label_prefix = None):
        self.tasks = training_args.selected_tasks
        self.max_length = training_args.max_length
        self.split = split
        self.label_prefix = label_prefix
        self.tokenizer = AutoTokenizer.from_pretrained(training_args.base_model_name)

        # A single task has its own CSV; several tasks share the combined file.
        csv_stem = self.tasks[0] if len(self.tasks)==1 else 'pastel'
        frame = pd.read_csv(f"{data_folder}/pastel/processed/{self.split}/{csv_stem}.csv")
        self.df = frame.dropna().reset_index(drop=True)

        # Logits stored as text must be evaluated back into lists before the
        # softmax can be applied; the first row decides whether a column is text.
        for task in self.tasks:
            column = self._label_column(task)
            if isinstance(self.df[column][0], str):
                parsed = self.df[column].apply(literal_eval)
                self.df[column] = torch.tensor(parsed).softmax(dim=1).numpy().tolist()

        if training_args.data_limit:
            self.df = self.df.iloc[:training_args.data_limit]

    def _label_column(self, task):
        # Name of the dataframe column that stores the labels of `task`.
        if self.label_prefix is None:
            return task
        return self.label_prefix + task

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        row = self.df.iloc[idx]
        encoded = self.tokenizer(row['output.sentences'], truncation=True, padding=True, max_length=self.max_length)
        item = dict(encoded.items())
        # Labels are always keyed by the bare task name, even when read from a prefixed column.
        for task in self.tasks:
            item[task] = row[self._label_column(task)]
        return item


## multi-task model

Given the selected tasks, the model adds one head per task (classification, or regression for single-label tasks) on top of a pretrained BERT or another BERT-like encoder.

In [None]:
class RegressionHead(nn.Module):
    """Single-output linear head trained with mean squared error.

    Args:
        embedding_dim: size of the incoming sentence embedding.
        hidden_dim: accepted for signature parity with ClassificationHead; not used.
    """
    def __init__(self, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.hidden = nn.Linear(embedding_dim, 1)

        self.loss_fn = nn.MSELoss()

    def forward(self, sent_emb, label):
        """Return ``(predictions, loss)`` for a batch.

        ``sent_emb`` is indexed as (batch, embedding_dim); ``label`` holds one
        value per example and may be shaped (batch,) or (batch, 1).
        """
        batchsize = sent_emb.shape[0]
        # (batch, 1) -> (batch,)
        output = self.hidden(self.dropout(sent_emb)).squeeze(1)

        # Bring the target to the prediction's dtype: integer or float64 labels
        # would otherwise trigger a dtype mismatch inside MSELoss.
        target = label.view(batchsize, -1).squeeze(-1).to(output.dtype)
        loss = self.loss_fn(output, target)
        return output, loss

In [None]:
class ClassificationHead(nn.Module):
    """Linear classifier over a sentence embedding, scored with cross-entropy.

    ``hidden_dim`` is part of the signature but is not used by the layers built here.
    """
    def __init__(self, num_labels, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.num_labels = num_labels
        self.dropout = nn.Dropout(0.1)
        self.hidden = nn.Linear(embedding_dim, self.num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, sent_emb, label):
        """Return ``(logits, loss)`` for one batch of sentence embeddings."""
        regularized = self.dropout(sent_emb)
        output = self.hidden(regularized)

        flat_logits = output.view(-1, self.num_labels)
        # One row per example; a trailing singleton dimension is dropped.
        target = label.view(sent_emb.shape[0], -1).squeeze(-1)
        return output, self.loss_fn(flat_logits, target)

In [None]:
@dataclass
class MultiTaskOutput(ModelOutput):
    """Container returned by ``MultiTaskBert.forward``.

    NOTE(review): the field order is left untouched; ``ModelOutput`` presumably
    treats the declaration order as significant — confirm before reordering.
    """
    # Sum of the per-task head losses; None when forward received no task labels.
    loss: torch.FloatTensor = None
    # Sentence embedding fed to the heads (pooler output or first-token hidden state).
    sent_emb: torch.FloatTensor = None
    # Detached logits keyed by task name; an empty dict unless forward was called with return_logits=True.
    all_logits: Optional[Dict[str, torch.FloatTensor]] = None
    # `hidden_states` of the base model's output, passed through unchanged.
    bert_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # `attentions` of the base model's output, passed through unchanged.
    bert_attentions: Optional[Tuple[torch.FloatTensor]] = None

In [None]:
class MultiTaskBert(PreTrainedModel):
    """Pretrained encoder with one task-specific head per selected task.

    Tasks whose entry in ``tasks2labels`` is 1 get a ``RegressionHead``; every
    other task gets a ``ClassificationHead`` with that many labels.
    """
    def __init__(self, config, training_args):
        super().__init__(config)
        self.tasks = training_args.selected_tasks
        self.use_pooler = training_args.use_pooler
        self.basemodel = AutoModel.from_pretrained(training_args.base_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(training_args.base_model_name)

        # Heads are stored in the same order as self.tasks so that forward()
        # can find a head through the task's position in that list.
        self.style_heads = nn.ModuleList()
        for task in self.tasks:
            if tasks2labels[task] == 1:
                self.style_heads.append(RegressionHead())
            else:
                self.style_heads.append(ClassificationHead(tasks2labels[task]))

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, return_logits=False, return_sent_emb=True, **kwargs):
        """Encode the batch and run the head of every task passed as a keyword.

        Args:
            input_ids: token ids for the base model.
            token_type_ids: optional; forwarded to the base model only when given,
                so encoders whose tokenizer emits no segment ids can be used.
            attention_mask: optional attention mask for the base model.
            return_logits: when True, ``all_logits`` maps each task to its detached logits.
            return_sent_emb: accepted but not consulted; ``sent_emb`` is always returned.
            **kwargs: ``task_name=labels`` pairs. Every key must be one of
                ``self.tasks``; an unknown key raises ``ValueError``.

        Returns:
            MultiTaskOutput whose ``loss`` is the sum of the per-task losses
            (None when no task labels were passed).
        """
        encoder_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids
        output = self.basemodel(**encoder_inputs)

        # Prefer the pooler output when requested and present; otherwise use the
        # hidden state of the first token.
        if self.use_pooler and ('pooler_output' in output):
            sent_emb = output['pooler_output']
        else:
            sent_emb = output['last_hidden_state'][:,0,:]

        total_loss = None
        # Always a dict (possibly empty), as before.
        all_logits = {}
        for task, label in kwargs.items():
            head = self.style_heads[self.tasks.index(task)]
            logits, loss = head(sent_emb, label)
            # Out-of-place sum: do not mutate the first head's loss tensor.
            total_loss = loss if total_loss is None else total_loss + loss
            if return_logits:
                all_logits[task] = logits.detach()
        return MultiTaskOutput(loss=total_loss, sent_emb=sent_emb, all_logits=all_logits, bert_hidden_states=output.hidden_states, bert_attentions=output.attentions)
    
    

In [None]:
def init_model(training_args):
    """Build a MultiTaskBert for ``training_args`` and move it to the module-level ``device``."""
    base_config = AutoConfig.from_pretrained(training_args.base_model_name)
    return MultiTaskBert(base_config, training_args).to(device)

In [None]:
def freeze_model(model, freeze_bert):
    '''
    Disable gradients for part or all of ``model.basemodel``.

    if freeze_bert is True, freeze every parameter of the base model.
    if freeze_bert is False, nothing is frozen.
    if freeze_bert is a positive integer, freeze the bottom {freeze_bert} attention layers
    (``model.basemodel.encoder.layer[:freeze_bert]``); a negative integer follows
    normal slice semantics and freezes all but the top layers.
    Any other value leaves the model untouched.
    '''
    # bool is checked first: `1 == True` holds in Python, so an equality test
    # would treat "freeze one layer" as "freeze everything".
    if isinstance(freeze_bert, bool):
        if freeze_bert:
            for param in model.basemodel.parameters():
                param.requires_grad = False
        return
    if isinstance(freeze_bert, int):
        for layer in model.basemodel.encoder.layer[:freeze_bert]:
            for param in layer.parameters():
                param.requires_grad = False

## trainer

In [None]:
def nested_detach(tensors):
    "Return `tensors` with every leaf detached; lists, tuples and dicts are rebuilt recursively."
    if isinstance(tensors, dict):
        return {key: nested_detach(tensors[key]) for key in tensors}
    if not isinstance(tensors, (list, tuple)):
        return tensors.detach()
    # Rebuild the sequence with its own type (list stays list, tuple stays tuple).
    container = type(tensors)
    return container(nested_detach(element) for element in tensors)

In [None]:
def nested_to(tensors, device):
    "Call `.to(device)` on every leaf of `tensors`; lists, tuples and dicts are rebuilt recursively."
    if isinstance(tensors, dict):
        return {key: nested_to(tensors[key], device) for key in tensors}
    if not isinstance(tensors, (list, tuple)):
        return tensors.to(device)
    # Rebuild the sequence with its own type (list stays list, tuple stays tuple).
    container = type(tensors)
    return container(nested_to(element, device) for element in tensors)

In [None]:
class MyTrainer(Trainer):
    """Trainer for models that return a MultiTaskOutput and take labels as per-task keywords."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the summed multi-task loss, plus the logits dict when ``return_outputs`` is set.

        ``num_items_in_batch`` is accepted because newer ``Trainer`` versions pass
        it as a keyword; it is not used here.
        """
        outputs = model(**inputs)
        return (outputs.loss, outputs.all_logits) if return_outputs else outputs.loss

    def prediction_step(self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run one evaluation step.

        Returns ``(loss, logits, labels)`` where ``logits`` and ``labels`` are
        dicts keyed by task name, or ``(loss, None, None)`` when
        ``prediction_loss_only`` is True. ``ignore_keys`` is not used.
        """
        inputs = nested_to(inputs, model.device)
        # Every task of the model must be present in the batch.
        labels = {task: inputs[task] for task in model.tasks}
        # No autograd graph is needed for evaluation; the results are detached anyway.
        with torch.no_grad():
            outputs = model(**inputs, return_logits=True)
        loss = outputs.loss.detach()

        if prediction_loss_only:
            return (loss, None, None)
        logits = nested_detach(outputs.all_logits)
        return (loss, logits, labels)
            

In [None]:
def compute_metrics(pred):
    """Compute accuracy, F1, precision and recall for every task in ``pred.label_ids``.

    ``pred.label_ids`` and ``pred.predictions`` are indexed by task name. The
    predicted class is the argmax over the last axis. Tasks with two labels
    (per ``tasks2labels``) use binary averaging, all others macro averaging.
    Returns a flat dict with keys ``accuracy_<task>``, ``f1_<task>``,
    ``precision_<task>`` and ``recall_<task>``.
    """
    gold_by_task = pred.label_ids
    scores_by_task = pred.predictions
    metrics = {}
    for task in gold_by_task:
        averaging = 'binary' if tasks2labels[task] == 2 else 'macro'
        predicted = scores_by_task[task].argmax(-1)
        precision, recall, f1, _ = precision_recall_fscore_support(gold_by_task[task], predicted, average=averaging)
        metrics[f'accuracy_{task}'] = accuracy_score(gold_by_task[task], predicted)
        metrics[f'f1_{task}'] = f1
        metrics[f'precision_{task}'] = precision
        metrics[f'recall_{task}'] = recall
    return metrics