<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [1]:
import collections
import bisect
import pandas as pd
import os
import sys
import json
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
import matplotlib.pyplot as plt
from tqdm.auto import tqdm, trange

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import optim
from torch.optim import lr_scheduler
# import torchmetrics

import datasets
from datasets import load_metric
from transformers import AutoConfig, AutoTokenizer, BertModel, RobertaModel, BertPreTrainedModel, BertConfig
from transformers import BertForSequenceClassification, DataCollatorWithPadding
from transformers import TrainingArguments, Trainer

from sklearn.metrics import mean_squared_error, accuracy_score, precision_recall_fscore_support


In [2]:
# https://github.com/huggingface/transformers/issues/5486
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
with open('../data/xslue/tasks.json', 'r') as f:
    tasks = json.load(f)
tasks

{'CrowdFlower': 13,
 'DailyDialog': 7,
 'EmoBank_Valence': 1,
 'EmoBank_Arousal': 1,
 'EmoBank_Dominance': 1,
 'HateOffensive': 3,
 'PASTEL_age': 8,
 'PASTEL_country': 2,
 'PASTEL_education': 10,
 'PASTEL_ethnic': 10,
 'PASTEL_gender': 3,
 'PASTEL_politics': 3,
 'PASTEL_tod': 5,
 'SARC': 2,
 'SarcasmGhosh': 2,
 'SentiTreeBank': 1,
 'ShortHumor': 2,
 'ShortJokeKaggle': 2,
 'ShortRomance': 2,
 'StanfordPoliteness': 1,
 'TroFi': 2,
 'VUA': 2}

In [5]:
tasks = {'VUA': 2}

In [6]:
class MultiTasksDatasets(Dataset): 
    # currently it's a Mapping-style dataset. Not sure if a Iterable-style dataset will be better
    def __init__(self, tasks, split):
        self.tasks = tasks
        self.split = split
        self.data_folder = f'../data/xslue/processed/{self.split}'
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.dfs = []
        self.ranges = [0]
        for task in tasks:
            tsv_file = f'{self.data_folder}/{task}.tsv'
            df = pd.read_csv(tsv_file, sep='\t')
            df = df.dropna()
            df = df.reset_index(drop=True)
            if df['label'].dtype == 'float64':
                df['label'] = df['label'].astype('float32')
            self.ranges.append(len(df))
            self.dfs.append(df)
            
        self.ranges = np.cumsum(self.ranges).tolist()  
        
#         self.encodings = self.tokenizer(self.df['text'].tolist(), truncation=True, padding=True, max_length=64)
            
    def __len__(self):
        return self.ranges[-1]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        i_task = bisect.bisect_right(self.ranges, idx) - 1 # select task/df
        idx_in_df = idx - self.ranges[i_task]
        
        item = {k: torch.tensor(v) for k, v in self.tokenizer(self.dfs[i_task]["text"].iloc[idx_in_df], truncation=True, padding=True, max_length=64).items()}
#         item = {}
        item["i_task"] = torch.LongTensor([i_task])
        item["labels"] = torch.LongTensor([self.dfs[i_task]["label"].iloc[idx_in_df]])
#         item["text"] = self.dfs[i_task]["text"].iloc[idx_in_df]
        return item


In [7]:
class MultiTasksDataCollator(DataCollatorWithPadding):
    def __init__(self, tokenizer, **kwargs):
        super().__init__(tokenizer, **kwargs)
        
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        if "label" in batch:
            batch["labels"] = batch["label"]
            del batch["label"]
        if "label_ids" in batch:
            batch["labels"] = batch["label_ids"]
            del batch["label_ids"]
        return batch    

In [8]:
class RegressionHead(nn.Module):
    def __init__(self, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.hidden = nn.Linear(embedding_dim, 1)
        
        self.loss_fn = nn.MSELoss()
    def forward(self, sent_emb, label):
        output = self.dropout(self.hidden(sent_emb)).squeeze(1)

        loss = self.loss_fn(output, label)
        
        return output, loss

class ClassificationHead(nn.Module):
    def __init__(self, num_labels, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.num_labels = num_labels
        self.dropout = nn.Dropout(0.1)
        self.hidden = nn.Linear(embedding_dim, self.num_labels)
        
        self.loss_fn = nn.CrossEntropyLoss()
    def forward(self, sent_emb, label):
        output = self.dropout(self.hidden(sent_emb))
        
        loss = self.loss_fn(output.view(-1, self.num_labels), label.view(-1))
        
        return output, loss

In [9]:
class MultiTaskBert(BertPreTrainedModel):
    def __init__(self, config, selected_tasks):
        super().__init__(config)
        self.num_tasks = len(selected_tasks)
        self.basemodel = BertModel(config)
        self.style_heads = nn.ModuleList()
        for task in selected_tasks:
            if tasks[task] == 1:
                self.style_heads.append(RegressionHead())
            else:
                self.style_heads.append(ClassificationHead(tasks[task]))
    def forward(self, input_ids, token_type_ids, attention_mask, i_task=None, labels=None, return_loss=True):
        output = self.basemodel(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        sent_emb = output['pooler_output']
        
        return self.style_heads[0](sent_emb, labels)
        
        
        total_loss = torch.tensor(0.).to(device)
        total_pred = []
        for j_task in range(self.num_tasks):
            pred, loss = self.style_heads[j_task](sent_emb[i_task.view(-1)==j_task], labels[i_task.view(-1)==j_task])
            total_loss += loss
            total_pred.append(pred.detach().to('cpu'))
        if return_loss:
            return total_loss
        return total_pred, total_loss

In [10]:
class MyTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)    

    def compute_loss(self, model, inputs, return_outputs=False):
#         labels = inputs.pop("labels")
        outputs = model(**inputs)
        labels = inputs.pop("labels")
        
        logits = outputs[0]
        loss_fn = nn.CrossEntropyLoss()
        return (loss_fn(logits, labels.view(-1)), outputs) if return_outputs else loss_fn(logits, labels.view(-1))

In [12]:
freeze_bert = True
batch_size = 32

train_dataset = MultiTasksDatasets(tasks, 'train')
val_dataset = MultiTasksDatasets(tasks, 'dev')

bertconfig = BertConfig.from_pretrained("bert-base-uncased")
multitaskbert = MultiTaskBert(bertconfig, tasks)

if freeze_bert:
    for param in multitaskbert.basemodel.parameters():
        param.requires_grad = False
        
result_folder = '../result'
training_args = TrainingArguments(
    output_dir=f"{result_folder}/multitask/{'baseline'+'_freezed' if freeze_bert else 'baseline'}",   # output directory
    num_train_epochs=5,              # total number of training epochs
    per_device_train_batch_size=batch_size,  # batch size per device during training
    per_device_eval_batch_size=batch_size,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir=f"{result_folder}/multitask/{'baseline'+'_freezed' if freeze_bert else 'baseline'}/logs",  # directory for storing logs
    logging_first_step = True, 
#         logging_steps=500,               # log & save weights each logging_steps
#         save_steps=500,
    evaluation_strategy="epoch",     # evaluate each `logging_steps`
    save_total_limit = 1,
    save_strategy = 'epoch',
    load_best_model_at_end=True, # decide on loss
)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

trainer = MyTrainer(
    model=multitaskbert,   # the instantiated Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased"), 
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset,          # evaluation dataset
#     compute_metrics=compute_metrics,     # the callback that computes metrics of interest
)
trainer.train()

loading configuration file https://huggingface.co/bert-base-uncased/resolve/main/config.json from cache at /home/joey/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.37395cee442ab11005bcd270f3c34464dc1704b715b5d7d52b1a461abe3b9e4e
Model config BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.18.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

loading file https://huggingface.co/bert-base-uncased/resol

loading file https://huggingface.co/bert-base-uncased/resolve/main/tokenizer.json from cache at /home/joey/.cache/huggingface/transformers/534479488c54aeaf9c3406f647aa2ec13648c06771ffe269edabebd4c412da1d.7f2721073f19841be16f41b0a70b600ca6b880c8f3df6f3535cbc704371bdfa4
loading file https://huggingface.co/bert-base-uncased/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/bert-base-uncased/resolve/main/special_tokens_map.json from cache at None
loading file https://huggingface.co/bert-base-uncased/resolve/main/tokenizer_config.json from cache at /home/joey/.cache/huggingface/transformers/c1d7f0a763fb63861cc08553866f1fc3e5a6f4f07621be277452d26d71303b7e.20430bd8e10ef77a7d2977accefe796051e01bc2fc4aa146bc862997a1a15e79
loading configuration file https://huggingface.co/bert-base-uncased/resolve/main/config.json from cache at /home/joey/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.37395cee442ab11005bcd270f3

Epoch,Training Loss,Validation Loss
1,0.7003,0.581456
2,0.6215,0.57872
3,0.5984,0.577645
4,0.598,0.576772
5,0.5947,0.576694


***** Running Evaluation *****
  Num examples = 1638
  Batch size = 32
Saving model checkpoint to ../result/multitask/baseline_freezed/checkpoint-474
Configuration saved in ../result/multitask/baseline_freezed/checkpoint-474/config.json
Model weights saved in ../result/multitask/baseline_freezed/checkpoint-474/pytorch_model.bin
tokenizer config file saved in ../result/multitask/baseline_freezed/checkpoint-474/tokenizer_config.json
Special tokens file saved in ../result/multitask/baseline_freezed/checkpoint-474/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 1638
  Batch size = 32
Saving model checkpoint to ../result/multitask/baseline_freezed/checkpoint-948
Configuration saved in ../result/multitask/baseline_freezed/checkpoint-948/config.json
Model weights saved in ../result/multitask/baseline_freezed/checkpoint-948/pytorch_model.bin
tokenizer config file saved in ../result/multitask/baseline_freezed/checkpoint-948/tokenizer_config.json
Special tokens file saved

TrainOutput(global_step=2370, training_loss=0.6023955406518928, metrics={'train_runtime': 256.2481, 'train_samples_per_second': 295.749, 'train_steps_per_second': 9.249, 'total_flos': 2458619673482760.0, 'train_loss': 0.6023955406518928, 'epoch': 5.0})