In [None]:
from pcgrad import PCGrad

import sys
import os
import collections
import json
from itertools import cycle
from ast import literal_eval
from dataclasses import dataclass, asdict
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm, trange

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torch import optim
from torch.optim import lr_scheduler
import torchmetrics
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from transformers import *
from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput

import matplotlib.pyplot as plt
from IPython.display import display

In [None]:
# Tell wandb which notebook produced the runs; must be set before `import wandb`.
os.environ.update({'WANDB_NOTEBOOK_NAME': 'PASTEL all together with PCgrad.ipynb'})
import wandb


# Definitions: datasets, model heads and training helpers

In [None]:
# Output / data locations. Each can be overridden through an environment variable;
# otherwise paths relative to the notebook are used.
result_folder = os.environ.get("scratch_result_folder", '../result')
scratch_data_folder = os.environ.get("scratch_data_folder")
repo_folder = os.environ.get("style_models_repo_folder")
data_folder = f"{repo_folder}/data" if repo_folder else '../../data'

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Set tokenizer parallelism explicitly; background:
# https://github.com/huggingface/transformers/issues/5486
os.environ.update({"TOKENIZERS_PARALLELISM": "true"})

In [None]:
# Dictionary: task_name: number_of_labels
# task name -> number of labels (MultiTaskBert treats a value of 1 as a regression task).
with open(f'{data_folder}/pastel/pastel_tasks2labels.json', 'r') as f:
    tasks2labels = json.load(f)
# task name -> position of that task in tasks2labels (JSON key order).
tasks2idx = dict(zip(tasks2labels, range(len(tasks2labels))))

In [None]:
@dataclass
class MyTrainingArgs:
    """Hyper-parameters and bookkeeping for one multi-task training run.

    ``model_folder`` is inferred in ``__post_init__`` (from the module-level
    ``result_folder``, ``model_name`` or the ``+``-joined task names, and a timestamp)
    only when it is not passed explicitly.
    """
    # training args
    selected_tasks: List[str]
    base_model_name: str
    freeze_bert: Union[bool, int]  # True: freeze all; int k: freeze embeddings + bottom k layers
    use_pooler: bool
    num_epoch: int
    lr: float = 5e-5
    model_folder: Optional[str] = None  # if None, this will be inferred based on tasks
    model_name: Optional[str] = None  # if provided, used to name model_folder; otherwise the task names are used

    # data loader args
    batch_size: int = 32
    max_length: int = 64
    shuffle: bool = False
    num_workers: int = 4
    data_limit: Optional[int] = None  # if set, truncate dataset to keep only the top {data_limit} rows

    # post training args
    save_best: bool = True
    load_best_at_end: bool = True

    # scheduler args. They are annotated so they become real dataclass fields and are
    # therefore included in asdict() / training_args.json. They are declared last so
    # the positional order of the fields above is unchanged.
    num_warmup_steps: int = 500
    warmup_ratio: float = 0.1

    def __post_init__(self):
        if self.model_folder is None:
            name = self.model_name if self.model_name else '+'.join(self.selected_tasks)
            # Take a single timestamp for the folder name.
            timestamp = datetime.now().strftime('%Y%m%d-%H:%M:%S')
            self.model_folder = f"{result_folder}/{name}/{timestamp}"

In [None]:
class MyDataset(Dataset):
    """Map-style dataset over one split of the processed PASTEL CSVs.

    Reads ``{data_folder}/pastel/processed/{split}/<name>.csv``. ``<name>`` is the
    single selected task, or ``pastel`` when several tasks are selected. Rows with any
    missing value are dropped. Each item is ``{'text': <sentence>, <task>: <label>, ...}``;
    tokenisation is left to the DataLoader's ``collate_fn``.

    Label columns whose first value is a string (e.g. logits written to CSV by a distill
    model) are parsed with ``ast.literal_eval`` and softmax-normalised into class
    probabilities. Integer class indices are passed through unchanged.

    Args:
        training_args: provides ``selected_tasks``, ``max_length``, ``base_model_name``
            and ``data_limit``. If ``data_limit`` is truthy, only the first
            ``data_limit`` rows are kept; this may change the label distribution.
        split: sub-folder name, e.g. ``'train'`` or ``'valid'``.
        label_prefix: optional prefix prepended to each task name to locate its label
            column.
    """

    def __init__(self, training_args, split, label_prefix=None):
        self.tasks = training_args.selected_tasks
        self.max_length = training_args.max_length
        self.split = split
        self.label_prefix = label_prefix
        # Not used here (collate_fn tokenises); kept as a public attribute.
        self.tokenizer = AutoTokenizer.from_pretrained(training_args.base_model_name)

        file_stem = self.tasks[0] if len(self.tasks) == 1 else 'pastel'
        frame = pd.read_csv(f"{data_folder}/pastel/processed/{self.split}/{file_stem}.csv").dropna()
        # Truncate before parsing labels. Rows are independent, so the result is the
        # same as truncating afterwards, but literal_eval/softmax only run on kept rows.
        if training_args.data_limit:
            frame = frame.iloc[:training_args.data_limit]
        self.df = frame.reset_index(drop=True)

        for task in self.tasks:
            column = self._label_column(task)
            # Guard against an empty split before peeking at the first row.
            if len(self.df) > 0 and isinstance(self.df[column].iloc[0], str):
                logits = self.df[column].apply(literal_eval).tolist()
                self.df[column] = torch.tensor(logits).softmax(dim=1).numpy().tolist()

    def _label_column(self, task):
        """Name of the CSV column holding the labels for ``task``."""
        return task if self.label_prefix is None else self.label_prefix + task

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{'text': sentence, <task>: label, ...}`` for row(s) ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        row = self.df.iloc[idx]
        item = {'text': row['output.sentences']}
        item.update({task: row[self._label_column(task)] for task in self.tasks})
        return item


In [None]:
class RegressionHead(nn.Module):
    """Single-output linear head trained with mean-squared error.

    Args:
        embedding_dim: size of the incoming sentence embedding.
        hidden_dim: currently unused; kept so the signature matches ClassificationHead.
    """
    def __init__(self, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.hidden = nn.Linear(embedding_dim, 1)

        self.loss_fn = nn.MSELoss()

    def forward(self, sent_emb, label):
        """Return ``(predictions, loss)``; predictions have shape ``(batch,)``."""
        batchsize = sent_emb.shape[0]
        output = self.hidden(self.dropout(sent_emb)).squeeze(1)

        # collate_fn builds scalar label tensors as int64. MSELoss needs a floating-point
        # target, so cast to the prediction dtype before computing the loss.
        target = label.view(batchsize, -1).squeeze(-1).to(output.dtype)
        loss = self.loss_fn(output, target)
        return output, loss

In [None]:
class ClassificationHead(nn.Module):
    """Linear classification head trained with cross-entropy.

    Args:
        num_labels: number of output classes.
        embedding_dim: size of the incoming sentence embedding.
        hidden_dim: currently unused; accepted for signature compatibility.
    """
    def __init__(self, num_labels, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.num_labels = num_labels
        self.dropout = nn.Dropout(0.1)
        self.hidden = nn.Linear(embedding_dim, self.num_labels)

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, sent_emb, label):
        """Return ``(logits, loss)``; logits have shape ``(batch, num_labels)``."""
        n_rows = sent_emb.shape[0]
        logits = self.hidden(self.dropout(sent_emb))
        flat_logits = logits.view(-1, self.num_labels)
        # (batch,) class indices stay 1-D; (batch, C) probability targets keep their shape.
        target = label.view(n_rows, -1).squeeze(-1)
        return logits, self.loss_fn(flat_logits, target)

In [None]:
@dataclass
class MultiTaskOutput(ModelOutput):
    """Output container returned by ``MultiTaskBert.forward``.

    NOTE(review): ``ModelOutput`` also supports positional access, so field order is
    presumably significant; verify before reordering.
    """
    # Sum of the per-task losses (None if no task labels were passed).
    total_loss: Optional[torch.FloatTensor] = None
    # One loss per task, in the order the task labels were passed to forward().
    losses: Optional[List[torch.FloatTensor]] = None
    # Sentence embedding shared by all task heads.
    sent_emb: Optional[torch.FloatTensor] = None
    # task name -> detached logits; only populated when forward(return_logits=True).
    all_logits: Optional[Dict[str, torch.FloatTensor]] = None
    # Passed through from the base model's output; None unless the base model returns them.
    bert_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    bert_attentions: Optional[Tuple[torch.FloatTensor]] = None

In [None]:
class MultiTaskBert(PreTrainedModel):
    """Shared pretrained encoder with one prediction head per selected PASTEL task.

    The sentence embedding is the encoder's ``pooler_output`` when
    ``training_args.use_pooler`` is set and the encoder provides one. Otherwise it is
    the first token's final hidden state. Tasks with ``tasks2labels[task] == 1`` get a
    ``RegressionHead``; every other task gets a ``ClassificationHead`` with
    ``tasks2labels[task]`` classes. Heads are sized from ``config.hidden_size``, falling
    back to 768 when the config has no such attribute.
    """
    def __init__(self, config, training_args):
        super().__init__(config)
        self.tasks = training_args.selected_tasks
        self.use_pooler = training_args.use_pooler
        self.basemodel = AutoModel.from_pretrained(training_args.base_model_name)
        # Kept as a public attribute; not used inside this class.
        self.tokenizer = AutoTokenizer.from_pretrained(training_args.base_model_name)
        self.style_heads = nn.ModuleList()

        embedding_dim = getattr(config, 'hidden_size', 768)
        for task in self.tasks:
            num_labels = tasks2labels[task]
            if num_labels == 1:
                self.style_heads.append(RegressionHead(embedding_dim=embedding_dim))
            else:
                self.style_heads.append(ClassificationHead(num_labels, embedding_dim=embedding_dim))

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                return_logits=False, return_sent_emb=True, **kwargs):
        """Encode a batch and score it with the heads named in ``kwargs``.

        Args:
            input_ids, token_type_ids, attention_mask: tokenizer outputs forwarded to the
                base model. ``token_type_ids`` and ``attention_mask`` may be omitted,
                e.g. for tokenizers that do not emit token type ids.
            return_logits: if True, ``all_logits`` maps each task to its detached logits.
            return_sent_emb: unused; kept for backward compatibility.
            **kwargs: ``task_name=label_tensor`` for every task to compute a loss for.
                Each name must be in ``self.tasks``; otherwise ``list.index`` raises
                ``ValueError``.

        Returns:
            MultiTaskOutput containing:
              * ``total_loss``: sum of the per-task losses, or None when no labels were given.
              * ``losses``: per-task losses, in kwargs order.
              * ``sent_emb``: the shared sentence embedding.
              * ``all_logits``: a dict, empty unless ``return_logits`` is set.
              * the base model's hidden states and attentions.
        """
        output = self.basemodel(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        if self.use_pooler and ('pooler_output' in output):
            sent_emb = output['pooler_output']
        else:
            sent_emb = output['last_hidden_state'][:, 0, :]

        losses = []
        all_logits = {}
        for task, label in kwargs.items():
            head = self.style_heads[self.tasks.index(task)]
            logits, loss = head(sent_emb, label)
            losses.append(loss)
            if return_logits:
                all_logits[task] = logits.detach()

        # Sum out of place. An in-place `total_loss += loss` on the first loss tensor
        # would alias losses[0], so the per-task losses handed to PCGrad would be wrong.
        total_loss = torch.stack(losses).sum() if losses else None
        return MultiTaskOutput(total_loss=total_loss, losses=losses, sent_emb=sent_emb, all_logits=all_logits, bert_hidden_states=output.hidden_states, bert_attentions=output.attentions)
    
    

In [None]:
def init_model(training_args):
    """Build a MultiTaskBert for ``training_args.base_model_name`` and move it to ``device``."""
    return MultiTaskBert(
        AutoConfig.from_pretrained(training_args.base_model_name),
        training_args,
    ).to(device)

In [None]:
def freeze_model(model, freeze_bert):
    '''
    Disable gradients for (part of) ``model.basemodel`` and return ``model``.

    ``model.basemodel`` must expose ``embeddings`` and ``encoder.layer`` (BERT layout).

    freeze_bert:
      * True (Python or NumPy bool): freeze every base-model parameter.
      * False: freeze nothing.
      * integer k (Python or NumPy): freeze the embeddings and ``encoder.layer[:k]``.
        Negative k follows slice semantics, e.g. -1 freezes all layers except the last.
      * anything else (e.g. None): freeze nothing.
    '''
    # bool is a subclass of int, so it must be checked before the integer branch.
    # Otherwise False would freeze the embeddings, and `1 == True` would freeze everything.
    if isinstance(freeze_bert, (bool, np.bool_)):
        if freeze_bert:
            for param in model.basemodel.parameters():
                param.requires_grad = False
        return model
    if isinstance(freeze_bert, (int, np.integer)):
        n_layers = int(freeze_bert)
        frozen_modules = [model.basemodel.embeddings, *model.basemodel.encoder.layer[:n_layers]]
        for module in frozen_modules:
            for param in module.parameters():
                param.requires_grad = False
    return model

In [None]:
def load_best_model(model, model_folder, filename="pytorch_model.bin"):
    """Load the state dict saved in ``{model_folder}/{filename}`` into ``model``.

    ``map_location="cpu"`` lets a checkpoint written during a GPU run be read on a
    CPU-only machine. ``load_state_dict`` then copies the tensors into the model's
    existing parameters, on whatever device they already live on.

    NOTE: ``torch.load`` unpickles the file, so only load checkpoints you trust.

    Returns:
        The same ``model`` (updated in place).
    """
    state_dict = torch.load(f"{model_folder}/{filename}", map_location="cpu")
    model.load_state_dict(state_dict)
    return model

In [None]:
def do_eval(model, cycle_valid_loader, num_valid_steps):
    """Compute macro-F1 per task over ``num_valid_steps`` validation batches.

    Args:
        model: multi-task model exposing ``.tasks`` and returning ``all_logits`` when
            called with ``return_logits=True``.
        cycle_valid_loader: iterator of collated batches; ``next`` is called
            ``num_valid_steps`` times.
        num_valid_steps: number of batches to evaluate.

    Returns:
        dict with ``'f1_<task>'`` for every task plus ``'f1_avg'``, the unweighted mean
        of the per-task scores.

    The model is left in eval mode.
    """
    model.eval()
    tasks = model.tasks
    task_f1s = {task: torchmetrics.F1Score(num_classes=tasks2labels[task], average='macro') for task in tasks}

    # Evaluation needs no autograd graph. Without no_grad every forward pass would keep
    # activations around for a backward pass that never happens.
    with torch.no_grad():
        for _ in trange(num_valid_steps, leave=False):
            batch = next(cycle_valid_loader)
            logits = model(**batch, return_logits=True)['all_logits']
            for task in tasks:
                task_f1s[task].update(logits[task].cpu().argmax(-1), batch[task].cpu())

    evaluation = {f'f1_{task}': task_f1s[task].compute().item() for task in tasks}
    evaluation['f1_avg'] = np.mean(list(evaluation.values()))
    return evaluation
    

In [None]:
def collate_fn(batch, text_tokenizer=None, max_length=64, target_device=None):
    """Collate ``MyDataset`` items into a model-ready batch dict.

    Label columns become tensors on ``target_device``. The ``'text'`` column is
    tokenised (padded to the longest item, truncated to ``max_length``) and replaced by
    the tokenizer's outputs.

    Args:
        batch: list of dicts as returned by ``MyDataset.__getitem__``.
        text_tokenizer: tokenizer to use; defaults to the module-level ``tokenizer``.
        max_length: truncation length (default 64).
        target_device: device for all tensors; defaults to the module-level ``device``.

    Returns:
        A ``collections.defaultdict`` mapping column names to tensors.
    """
    if text_tokenizer is None:
        text_tokenizer = tokenizer
    if target_device is None:
        target_device = device

    batch_out = collections.defaultdict(list)
    for item in batch:
        for col, value in item.items():
            batch_out[col].append(value)

    for col, values in batch_out.items():
        if col == 'text':
            continue
        # MyDataset turns distilled labels into lists of class probabilities. Those must
        # stay floating point, since int64 would truncate every probability to 0.
        # Scalar labels keep the int64 class-index encoding.
        # NOTE(review): scalar regression targets (if any) are still cast to int64 here;
        # confirm they are integral.
        is_soft_label = len(values) > 0 and isinstance(values[0], (list, tuple))
        dtype = torch.float32 if is_soft_label else torch.int64
        batch_out[col] = torch.tensor(values, dtype=dtype).to(target_device)

    encoded = text_tokenizer(text=batch_out['text'], return_tensors='pt', padding=True, truncation=True, max_length=max_length).to(target_device)
    batch_out.update(encoded.items())

    del batch_out['text']
    return batch_out
    

# Training: hyper-parameter sweep with PCGrad

In [None]:
# Tokenizer used by collate_fn; keep it in sync with my_training_args.base_model_name.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

In [None]:
# Base configuration shared by all sweep runs; the sweep cell overrides lr,
# freeze_bert and warmup_ratio per run.
my_training_args = MyTrainingArgs(
    selected_tasks=list(tasks2labels),
    base_model_name='bert-base-uncased',
    model_name='PASTEL all together with PCgrad',
    freeze_bert=False,
    use_pooler=False,
    num_epoch=10,
    batch_size=64,
)

In [None]:
# Shared settings come from my_training_args, so the values recorded in
# training_args.json are the ones actually used.
num_epochs = my_training_args.num_epoch
batch_size = my_training_args.batch_size
save_best_only = True # if not save the best, save the last

# wandb scratch directory; can be overridden with the WANDB_DIR environment variable.
wandb_dir = os.environ.get('WANDB_DIR', '/scratch/data_jz17d/wandb_tmp/')

# some parameters to sweep (full Cartesian grid; run index i_run follows meshgrid order)
LR = [3e-5, 5e-5, 7e-5]
FREEZE_BERT = [9, 11]
WARMUP_RATIO = [0.1, 0.15, 0.2]

LR, FREEZE_BERT, WARMUP_RATIO = np.meshgrid(LR, FREEZE_BERT, WARMUP_RATIO)
LR, FREEZE_BERT, WARMUP_RATIO = LR.flatten(), FREEZE_BERT.flatten(), WARMUP_RATIO.flatten()
num_runs = len(LR)

# Create datasets. The loaders are iterated directly each epoch rather than through
# itertools.cycle. cycle() keeps a copy of every element it yields, which here means
# every collated batch already moved to `device`, and it replays the first pass's
# order forever, which defeats shuffle=True.
train_data = MyDataset(my_training_args, 'train')
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
num_training_steps = len(train_loader)

valid_data = MyDataset(my_training_args, 'valid')
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
num_valid_steps = len(valid_loader)


def _save_checkpoint(model, optimizer, scheduler, model_folder):
    """Write the model, inner-optimizer and scheduler state dicts into ``model_folder``."""
    torch.save(model.state_dict(), f"{model_folder}/pytorch_model.bin")
    torch.save(optimizer._optim.state_dict(), f"{model_folder}/optimizer.pt")
    torch.save(scheduler.state_dict(), f"{model_folder}/scheduler.pt")


# start runs
for i_run in range(num_runs):
    model_folder = f"{result_folder}/pastel_pcgrad/run_{i_run}"
    Path(model_folder).mkdir(parents=True, exist_ok=True)

    # meshgrid yields NumPy scalars. Convert them to builtins so that
    # json.dump(asdict(...)) works (np.int64 is not JSON serialisable) and
    # freeze_model receives a plain int.
    lr = float(LR[i_run])
    freeze_bert = int(FREEZE_BERT[i_run])
    warmup_ratio = float(WARMUP_RATIO[i_run])

    my_training_args.lr = lr
    my_training_args.freeze_bert = freeze_bert
    my_training_args.warmup_ratio = warmup_ratio
    # Record the folder this run's artefacts are actually written to.
    my_training_args.model_folder = model_folder

    # use wandb to track experiments
    wconfig = {'lr': lr, 'freeze_bert': freeze_bert, 'warmup_ratio': warmup_ratio}
    run = wandb.init(project="PASTEL all together with PCgrad",
                     entity="fsu-dsc-cil",
                     dir=wandb_dir,
                     config=wconfig,
                     name=f'run {i_run}',
                     reinit=True)

    model = init_model(my_training_args)
    model = freeze_model(model, freeze_bert)

    wandb.watch(model, log="all", log_freq=1000, log_graph=True)

    optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)
    optimizer = PCGrad(optimizer)

    scheduler = get_scheduler("linear",
                              optimizer=optimizer._optim,
                              num_warmup_steps=int(warmup_ratio*num_epochs*num_training_steps),
                              num_training_steps=num_epochs*num_training_steps)

    # start training and logging
    best_metric = 0.0
    eval_records = []
    pbar = trange(num_epochs*num_training_steps)
    for i_epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            output = model(**batch)
            optimizer.pc_backward(output['losses'])
            optimizer.step()
            # The linear schedule with warm-up starts the LR at 0. It must be stepped
            # once per optimizer step, otherwise the LR stays 0 for the entire run.
            scheduler.step()
            pbar.update(1)

        # One full pass over the validation set (do_eval switches to eval mode).
        evaluation = do_eval(model, iter(valid_loader), num_valid_steps)
        wandb.log(evaluation, step=pbar.n)
        evaluation['global_step'] = pbar.n
        eval_records.append(evaluation)

        # save best model
        if save_best_only and best_metric < evaluation['f1_avg']:
            best_metric = evaluation['f1_avg']
            _save_checkpoint(model, optimizer, scheduler, model_folder)
    pbar.close()

    # if not save best, save the last
    if not save_best_only:
        _save_checkpoint(model, optimizer, scheduler, model_folder)
    with open(f"{model_folder}/training_args.json", "w") as outfile:
        json.dump(asdict(my_training_args), outfile)
    # DataFrame.append was removed in pandas 2.0, so build the frame once from the
    # records, with global_step as the first column.
    metric_cols = [c for c in eval_records[0] if c != 'global_step'] if eval_records else []
    df = pd.DataFrame(eval_records, columns=['global_step', *metric_cols])
    df.to_csv(f"{model_folder}/evaluation.csv", index=False)
    run.finish()


[34m[1mwandb[0m: Currently logged in as: [33mcpuyyp[0m ([33mfsu-dsc-cil[0m). Use [1m`wandb login --relogin`[0m to force relogin


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


  0%|          | 0/4970 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Metrics from the last completed validation pass of the most recent (possibly interrupted) run.
evaluation

{'f1_country': 0.015550041571259499,
 'f1_politics': 0.2742529511451721,
 'f1_tod': 0.09492255747318268,
 'f1_age': 0.07916268706321716,
 'f1_education': 0.06269937008619308,
 'f1_ethnic': 0.01124968659132719,
 'f1_gender': 0.06537345051765442,
 'f1_avg': 0.08617296349257231,
 'global_step': 1988}