In [1]:
import collections
import json
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
# from torch.utils.data.sampler import RandomSampler
from tqdm.auto import tqdm, trange
from transformers import AutoConfig, AutoTokenizer, BertModel, RobertaModel

In [2]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Task registry: maps each task name to an integer output size. The model cells
# below build a regression head when the value is 1 and a classification head
# with that many labels otherwise.
with open('tasks.json', mode='r') as f:
    tasks = json.loads(f.read())

In [4]:
# Display the task registry (task name -> output size) as the cell result.
tasks

{'CrowdFlower': 13,
 'DailyDialog': 7,
 'EmoBank_Valence': 1,
 'EmoBank_Arousal': 1,
 'EmoBank_Dominance': 1,
 'HateOffensive': 3,
 'PASTEL_age': 8,
 'PASTEL_country': 2,
 'PASTEL_education': 10,
 'PASTEL_ethnic': 10,
 'PASTEL_gender': 3,
 'PASTEL_politics': 3,
 'PASTEL_tod': 5,
 'SARC': 2,
 'SarcasmGhosh': 2,
 'SentiTreeBank': 1,
 'ShortHumor': 2,
 'ShortJokeKaggle': 2,
 'ShortRomance': 2,
 'StanfordPoliteness': 1,
 'TroFi': 2,
 'VUA': 2}

# multi-task dataloader

In [5]:
class MyDataset(Dataset):
    """Map-style dataset backed by a single tab-separated file.

    The file is read with pandas and must provide a ``text`` and a ``label``
    column. A ``label`` column parsed as float64 is downcast to float32; any
    other label dtype is left untouched.
    """

    def __init__(self, tsv_file):
        frame = pd.read_csv(tsv_file, sep='\t')
        labels = frame['label']
        if labels.dtype == 'float64':
            frame['label'] = labels.astype('float32')
        self.df = frame

    def __len__(self):
        """Number of rows in the underlying table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{'text': ..., 'label': ...}`` for the row(s) at position ``idx``."""
        # Tensor indices are converted to plain Python values before positional lookup.
        position = idx.tolist() if torch.is_tensor(idx) else idx
        row = self.df.iloc[position]
        return {'text': row['text'], 'label': row['label']}


In [6]:
class MultiTaskTrainDataLoader():
    '''
    Training loader that mixes batches from several single-task datasets.

    Every step picks one task uniformly at random and yields the next batch
    {text, label} from that task's DataLoader as ``(i_task, batch)``. One pass
    yields as many batches as all per-task loaders contain together.

    Known issue: because tasks are drawn uniformly, a large dataset may not be
    fully visited within one pass while a small dataset may be cycled several times.
    '''

    def __init__(self, tasks, batch_size, shuffle, num_workers):
        self.tasks = tasks
        self.split = 'train'
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

        self.num_tasks = len(tasks)
        self.datasets = []
        self.dataloaders = []
        self.dataloaderiters = []
        self.len = 0
        for task in tasks:
            dataset = MyDataset('./processed/' + self.split + '/' + task + '.tsv')
            self.datasets.append(dataset)
            loader = DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                num_workers=self.num_workers,
            )
            self.dataloaders.append(loader)
            self.dataloaderiters.append(iter(loader))
            # The pass length is the total number of batches over all tasks.
            self.len += len(loader)

    def __len__(self):
        return self.len

    def __iter__(self):
        # Only the step counter is reset; per-task iterators keep their position
        # from the previous pass.
        self.n = 0
        return self

    def __next__(self):
        # The task is drawn before the length check, so the random generator is
        # consumed once per call, including the final call that stops iteration.
        i_task = np.random.randint(self.num_tasks)
        if self.n >= self.len:
            raise StopIteration
        self.n += 1

        try:
            batch = next(self.dataloaderiters[i_task])
        except StopIteration:
            # This task's loader is exhausted: restart it and take its first batch.
            restarted = iter(self.dataloaders[i_task])
            self.dataloaderiters[i_task] = restarted
            batch = next(restarted)
        return i_task, batch

In [7]:
class MultiTaskTestDataLoader():
    '''
    Evaluation loader over the 'dev' or 'test' split of several tasks.

    One pass yields every batch of every task exactly once, task by task in the
    order of ``tasks``, as ``(i_task, batch)``. Unlike random task sampling,
    this guarantees that each evaluation example is scored once and only once.

    Args:
        tasks: iterable of task names; each name selects './processed/<split>/<task>.tsv'.
        split: 'dev' or 'test'.
        batch_size, shuffle, num_workers: forwarded to every per-task DataLoader.

    Raises:
        AssertionError: if ``split`` is not 'dev' or 'test'.
    '''

    def __init__(self, tasks, split, batch_size, shuffle, num_workers):
        assert split in ['dev', 'test'], 'not implemented'
        self.tasks = tasks
        self.split = split
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

        self.num_tasks = len(tasks)
        self.datasets = []
        self.dataloaders = []
        self.len = 0
        for task in tasks:
            self.datasets.append(MyDataset('./processed/'+self.split+'/'+task+'.tsv'))
            self.dataloaders.append(DataLoader(self.datasets[-1], batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers))
            self.len += len(self.dataloaders[-1])
        # Per-task iterators are opened lazily, one task at a time, so that the
        # worker processes of all tasks are not started at once.
        self.dataloaderiters = [None] * self.num_tasks
        self.n = 0
        self._task_cursor = 0

    def __len__(self):
        return self.len

    def __iter__(self):
        # Start a fresh pass: rewind to the first task and drop stale iterators.
        self.n = 0
        self._task_cursor = 0
        self.dataloaderiters = [None] * self.num_tasks
        return self

    def __next__(self):
        while self._task_cursor < self.num_tasks:
            i_task = self._task_cursor
            if self.dataloaderiters[i_task] is None:
                self.dataloaderiters[i_task] = iter(self.dataloaders[i_task])
            try:
                batch = next(self.dataloaderiters[i_task])
            except StopIteration:
                # Current task is finished (or empty): release it and move on.
                self.dataloaderiters[i_task] = None
                self._task_cursor += 1
                continue
            self.n += 1
            return i_task, batch
        raise StopIteration

# multi-task model

In [8]:
from transformers.models.bert.modeling_bert import BertPreTrainedModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel

In [9]:
class RegressionHead(nn.Module):
    """Two-layer head that predicts one scalar per sentence embedding.

    Args:
        embedding_dim: size of the incoming sentence embedding.
        hidden_dim: size of the hidden layer.
    """

    def __init__(self, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.hidden1 = nn.Linear(embedding_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, 1)
        # Without a nonlinearity the two Linear layers collapse into a single
        # affine map and the hidden layer adds no capacity.
        self.activation = nn.Tanh()

        self.loss_fn = nn.MSELoss()

    def forward(self, sent_emb, label=None):
        """Return ``(prediction, loss)``.

        ``prediction`` has the trailing singleton dimension squeezed away.
        ``loss`` is the MSE against ``label``, or ``None`` when no label is
        given (inference).
        """
        hidden = self.activation(self.hidden1(sent_emb))
        output = self.hidden2(hidden).squeeze(1)

        if label is None:
            return output, None
        # MSELoss needs matching floating dtypes; integer-valued targets are cast.
        loss = self.loss_fn(output, label.to(output.dtype))
        return output, loss

In [10]:
class ClassificationHead(nn.Module):
    """Two-layer head that produces ``num_labels`` logits per sentence embedding.

    Args:
        num_labels: number of classes.
        embedding_dim: size of the incoming sentence embedding.
        hidden_dim: size of the hidden layer.
    """

    def __init__(self, num_labels, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.hidden1 = nn.Linear(embedding_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, num_labels)
        self.activation = nn.Tanh()

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, sent_emb, label=None):
        """Return ``(logits, loss)``.

        ``loss`` is the cross-entropy against ``label``, or ``None`` when no
        label is given (inference).
        """
        # The nonlinearity belongs between the two layers. Applying tanh to the
        # final output would bound every logit to [-1, 1] and cap how confident
        # the softmax inside CrossEntropyLoss can ever become.
        hidden = self.activation(self.hidden1(sent_emb))
        output = self.hidden2(hidden)

        loss = None if label is None else self.loss_fn(output, label)
        return output, loss

In [11]:
class MultiTaskBert(BertPreTrainedModel):
    """BERT encoder shared by one output head per task.

    ``tasks`` maps a task name to an output size: size 1 gets a RegressionHead,
    any other size a ClassificationHead with that many labels. Heads are stored
    in ``style_heads`` in the iteration order of ``tasks``, which is the order
    ``i_task`` indexes into.

    NOTE(review): the encoder is instantiated from ``config`` only; per the
    Transformers documentation this does not load checkpoint weights.
    """

    def __init__(self, config, tasks, use_pooler=True):
        super().__init__(config)
        self.use_pooler = use_pooler
        self.basemodel = BertModel(config)
        self.style_heads = nn.ModuleList()
        for task_name in tasks:
            n_outputs = tasks[task_name]
            head = RegressionHead() if n_outputs == 1 else ClassificationHead(n_outputs)
            self.style_heads.append(head)

    def forward(self, i_task, input_ids, token_type_ids, attention_mask, label):
        """Encode the batch and run the head of task ``i_task``; returns ``(output, loss)``."""
        encoded = self.basemodel(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        # Sentence embedding: the encoder's pooled output, or the hidden state
        # at the first token position when the pooler is bypassed.
        sent_emb = encoded['pooler_output'] if self.use_pooler else encoded['last_hidden_state'][:,0,:]
        prediction, loss = self.style_heads[i_task](sent_emb, label)
        return prediction, loss

In [12]:
class MultiTaskRoberta(RobertaPreTrainedModel):
    """RoBERTa encoder shared by one output head per task.

    ``tasks`` maps a task name to an output size: size 1 gets a RegressionHead,
    any other size a ClassificationHead with that many labels. Heads are stored
    in ``style_heads`` in the iteration order of ``tasks``, which is the order
    ``i_task`` indexes into. Unlike the BERT variant, ``forward`` takes no
    ``token_type_ids``.

    NOTE(review): the encoder is instantiated from ``config`` only; per the
    Transformers documentation this does not load checkpoint weights.
    """

    def __init__(self, config, tasks, use_pooler=True):
        super().__init__(config)
        self.use_pooler = use_pooler
        self.basemodel = RobertaModel(config)
        self.style_heads = nn.ModuleList()
        for task_name in tasks:
            n_outputs = tasks[task_name]
            head = RegressionHead() if n_outputs == 1 else ClassificationHead(n_outputs)
            self.style_heads.append(head)

    def forward(self, i_task, input_ids, attention_mask, label):
        """Encode the batch and run the head of task ``i_task``; returns ``(output, loss)``."""
        encoded = self.basemodel(input_ids=input_ids, attention_mask=attention_mask)
        # Sentence embedding: the encoder's pooled output, or the hidden state
        # at the first token position when the pooler is bypassed.
        sent_emb = encoded['pooler_output'] if self.use_pooler else encoded['last_hidden_state'][:,0,:]
        prediction, loss = self.style_heads[i_task](sent_emb, label)
        return prediction, loss

# train

In [13]:
# A batch size of 32 leads to memory issues, hence 16.
mt_dataloader = MultiTaskTrainDataLoader(
    tasks,
    batch_size=16,
    shuffle=True,
    num_workers=6,
)

In [14]:
base_model = "bert-base-uncased"
# base_model = 'roberta-base'

config = AutoConfig.from_pretrained(base_model)
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Building the multi-task wrapper from `config` alone leaves the encoder with
# freshly initialised weights, so the pretrained checkpoint is loaded into the
# encoder explicitly. The task heads stay randomly initialised, as intended.
mt_model = MultiTaskBert(config, tasks)
mt_model.basemodel = BertModel.from_pretrained(base_model, config=config)
# mt_model = MultiTaskRoberta(config, tasks)
# mt_model.basemodel = RobertaModel.from_pretrained(base_model, config=config)

# from_pretrained() returns the encoder in eval mode (dropout off); put the
# whole model back into training mode because the next cells fine-tune it.
mt_model = mt_model.to(device).train()


In [15]:
# Fine-tuning a BERT-size encoder needs learning rates around 2e-5 to 5e-5;
# the previous peak of 0.05 is about a thousand times too large and makes
# training diverge. OneCycleLR controls the schedule itself (it starts at
# max_lr / div_factor and peaks at max_lr), so max_lr is the value that counts.
optimizer = optim.AdamW(mt_model.parameters(), lr=3e-5)
# One scheduler step is taken per batch, so the cycle spans exactly one pass.
scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=5e-5, total_steps=len(mt_dataloader))

In [16]:
def print_loss(losses):
    """Print the most recent loss of every task on one line.

    ``losses`` maps a key to a list of loss values. The last value of each list
    is written with two decimals, separated by spaces, followed by a newline.
    """
    for history in losses.values():
        latest = history[-1]
        print(f'{latest:4.2f}', end=' ')
    print()

In [None]:
# Loss history per task index; each optimisation step appends one value.
losses = collections.defaultdict(list)
for i_iter, data in enumerate(tqdm(mt_dataloader)):
    i_task, batch = data
    optimizer.zero_grad()

    # Split the labels off the batch; only the raw text is left for the tokenizer.
    label = batch.pop('label').to(device)
    tokens = tokenizer(
        text=batch['text'],
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=256,
    ).to(device)

    output, loss = mt_model(i_task=i_task, label=label, **tokens)
    loss.backward()
    optimizer.step()
    scheduler.step()

    losses[i_task].append(loss.item())
    # Report every 1000 iterations, skipping iteration 0.
    if i_iter and i_iter % 1000 == 0:
        print(f'##### iter {i_iter}/{len(mt_dataloader)}')
        print_loss(losses)

  0%|          | 0/85054 [00:00<?, ?it/s]

In [None]:
PATH = './mt_model_runs/mt_1.bin'
# torch.save does not create missing parent directories. Create the run folder
# first so a long training run is not lost to a missing-directory error.
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(mt_model.state_dict(), PATH)