In [3]:
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
from torch import optim
from torch.optim import lr_scheduler
import numpy as np
import collections
import json
from tqdm.auto import tqdm, trange
from transformers import AutoConfig, AutoTokenizer, BertModel, RobertaModel

In [4]:
# https://github.com/huggingface/transformers/issues/5486
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
with open('tasks.json', 'r') as f:
    tasks = json.load(f)

In [7]:
tasks

{'CrowdFlower': 13,
 'DailyDialog': 7,
 'EmoBank_Valence': 1,
 'EmoBank_Arousal': 1,
 'EmoBank_Dominance': 1,
 'HateOffensive': 3,
 'PASTEL_age': 8,
 'PASTEL_country': 2,
 'PASTEL_education': 10,
 'PASTEL_ethnic': 10,
 'PASTEL_gender': 3,
 'PASTEL_politics': 3,
 'PASTEL_tod': 5,
 'SARC': 2,
 'SarcasmGhosh': 2,
 'SentiTreeBank': 1,
 'ShortHumor': 2,
 'ShortJokeKaggle': 2,
 'ShortRomance': 2,
 'StanfordPoliteness': 1,
 'TroFi': 2,
 'VUA': 2}

# multi-task dataloader

In [8]:
class MyDataset(Dataset): 
    # currently it's a Mapping-style dataset. Not sure if a Iterable-style dataset will be better
    def __init__(self, tsv_file):
        self.df = pd.read_csv(tsv_file, sep='\t')
        self.df = self.df.dropna()
        self.df = self.df.reset_index(drop=True)
        if self.df['label'].dtype == 'float64':
            self.df['label'] = self.df['label'].astype('float32')
            
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        dataslice = self.df.iloc[idx]
        sample = {'text':dataslice['text'], 'label':dataslice['label']}
        return sample


In [9]:
class MultiTaskTrainDataLoader():
    '''
    Each time, a random integer selects a dataset and load a batch of data {text, label} from it. Return i_task and data
    
    Known issue: large dataset may have not iterate once, small datasets may have been iterated many times
    '''
    
    def __init__(self, tasks, batch_size, shuffle, num_workers):
        self.tasks = tasks
        self.split = 'train'
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        
        self.num_tasks = len(tasks)
        self.datasets = []
        self.dataloaders = []
        self.dataloaderiters = []
        self.len = 0
        for task in tasks:
            self.datasets.append(MyDataset('./processed/'+self.split+'/'+task+'.tsv'))
            self.dataloaders.append(DataLoader(self.datasets[-1], batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers)) 
            self.dataloaderiters.append(iter(self.dataloaders[-1]))
            self.len += len(self.dataloaders[-1])
    def __len__(self):   
        return self.len

    def __iter__(self):
        self.n = 0
        return self
    
    def __next__(self):
        i_task = np.random.randint(self.num_tasks)
        if self.n < self.len:
            self.n += 1
        else:
            raise StopIteration
            
        try: 
            dataloaderiter = self.dataloaderiters[i_task]
            batch = next(dataloaderiter)
        except StopIteration:
            self.dataloaderiters[i_task] = iter(self.dataloaders[i_task])
            dataloaderiter = self.dataloaderiters[i_task]
            batch = next(dataloaderiter)
        return i_task, batch

In [26]:
class MultiTaskTestDataLoader():
    '''
    For dev and test
    '''
    
    def __init__(self, tasks, split, batch_size, shuffle, num_workers):
        assert split in ['dev', 'test'], 'not implemented'
        self.tasks = tasks
        self.split = split
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        
        self.num_tasks = len(tasks)
        self.datasets = []
        self.dataloaders = []
        self.len = 0
        for task in tasks:
            self.datasets.append(MyDataset('./processed/'+self.split+'/'+task+'.tsv'))
            self.dataloaders.append(DataLoader(self.datasets[-1], batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers))
            self.len += len(self.dataloaders[-1])
        self.i_task = 0
    def __len__(self):   
        return self.len

    def __iter__(self):
        return self
    
    def __next__(self):
        if self.i_task < self.num_tasks:
            dataloader = self.dataloaders[self.i_task]
        else:
            self.i_task = 0
            raise StopIteration
        for batch in dataloader:
            return self.i_task, batch
        self.i_task += 1
        

# multi-task model

In [9]:
from transformers.models.bert.modeling_bert import BertPreTrainedModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel

In [10]:
class RegressionHead(nn.Module):
    def __init__(self, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.hidden1 = nn.Linear(embedding_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, 1)
        
        self.loss_fn = nn.MSELoss()
    def forward(self, sent_emb, label):
        output = self.hidden2(self.hidden1(sent_emb)).squeeze(1)
        
        loss = self.loss_fn(output, label)
        return output, loss

In [11]:
class ClassificationHead(nn.Module):
    def __init__(self, num_labels, embedding_dim = 768, hidden_dim = 128):
        super().__init__()
        self.hidden1 = nn.Linear(embedding_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, num_labels)
        self.activation = nn.Tanh()
        
        self.loss_fn = nn.CrossEntropyLoss()
    def forward(self, sent_emb, label):
        output = self.activation(self.hidden2(self.hidden1(sent_emb)))
        
        loss = self.loss_fn(output, label)
        return output, loss

In [12]:
class MultiTaskBert(BertPreTrainedModel):
    def __init__(self, config, tasks, use_pooler=True):
        super().__init__(config)
        self.use_pooler = use_pooler
        self.basemodel = BertModel(config)
        self.style_heads = nn.ModuleList()
        for task in tasks:
            if tasks[task] == 1:
                self.style_heads.append(RegressionHead())
            else:
                self.style_heads.append(ClassificationHead(tasks[task]))
    def forward(self, i_task, input_ids, token_type_ids, attention_mask, label):
        output = self.basemodel(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        if self.use_pooler:
            sent_emb = output['pooler_output']
        else:
            sent_emb = output['last_hidden_state'][:,0,:]
        output, loss = self.style_heads[i_task](sent_emb, label)
        return output, loss

In [13]:
class MultiTaskRoberta(RobertaPreTrainedModel):
    def __init__(self, config, tasks, use_pooler=True):
        super().__init__(config)
        self.use_pooler = use_pooler
        self.basemodel = RobertaModel(config)
        self.style_heads = nn.ModuleList()
        for task in tasks:
            if tasks[task] == 1:
                self.style_heads.append(RegressionHead())
            else:
                self.style_heads.append(ClassificationHead(tasks[task]))
    def forward(self, i_task, input_ids, attention_mask, label):
        output = self.basemodel(input_ids=input_ids, attention_mask=attention_mask)
        if self.use_pooler:
            sent_emb = output['pooler_output']
        else:
            sent_emb = output['last_hidden_state'][:,0,:]
        head = self.style_heads[i_task]
        output, loss = head(sent_emb, label)
        return output, loss

# train

In [14]:
def validate(mt_val_dataloader):
    val_loss = collections.defaultdict(float)
    val_size = collections.defaultdict(int)
    mt_model.eval()
    for data in tqdm(mt_val_dataloader):  
        i_task, batch = data
        label = batch['label'].to(device)
        size = len(label)
        del batch['label']
        tokens = tokenizer(**batch, return_tensors='pt', padding=True, truncation=True, max_length=64).to(device)
        output, loss = mt_model(**tokens, i_task=i_task,  label=label)
        tokens = None
        output = None
        val_loss[i_task] += loss.detach().item()*size
        val_size[i_task] += size
    for i_task in val_loss:
        val_loss[i_task] /= val_size[i_task]
    mt_model.train()

    return val_loss


In [15]:
def print_loss(losses):
    for k in losses:
        print(f'{losses[k]:4.2f}', end=' ')
    print('')

In [16]:
# larger batch_size will definitely lead to memory issue
mt_dataloader = MultiTaskTrainDataLoader(tasks, batch_size = 16, shuffle = True, num_workers = 6)
mt_dev_dataloader = MultiTaskTestDataLoader(tasks, split='dev', batch_size = 16, shuffle = True, num_workers = 6)
mt_test_dataloader = MultiTaskTestDataLoader(tasks, split='test', batch_size = 16, shuffle = True, num_workers = 6)

Process Process-391:
Process Process-393:
Process Process-392:
Process Process-323:
Process Process-395:
Process Process-352:
Process Process-388:
Process Process-384:
Process Process-390:
Process Process-381:
Process Process-385:
Process Process-383:
Traceback (most recent call last):
Process Process-389:
Process Process-347:
Process Process-386:
Traceback (most recent call last):
Process Process-377:
Process Process-382:
Process Process-372:
Traceback (most recent call last):
Process Process-353:
Traceback (most recent call last):
Process Process-319:
Process Process-380:
Process Process-378:
Traceback (most recent call last):
Process Process-318:
Process Process-369:
Process Process-371:
Traceback (most recent call last):
Traceback (most recent call last):
Process Process-370:
Traceback (most recent call last):
Process Process-305:
Process Process-311:
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function(

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiproce

  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing

  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocess

  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/queues.py", line 201, in _finalize_join
    thread.join()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1053, in join
    self._wait_for_tstate_lock()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/mu

  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/queues.py", line 201, in _finalize_join
    thread.join()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/jz17d/an

  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1053, in join
    self._wait_for_tstate_lock()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/queues.py", line 201, in _finalize_join
    thread.join()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/queues.py", line 201, in _finalize_join
    thread.join()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1053, in join
    self._wait_for_tstate_lock()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/jz17d/anacond

  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1053, in join
    self._wait_for_tstate_lock()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1053, in join
    self._wait_for_tstate_lock()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1053, in join
    self._wait_for_tstate_lock()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/queues.py", line 201, in _finalize_join
    thread.join()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1069, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/multiprocessing/queues.py", line 201, in _finalize_join
    thread.join()
KeyboardInterrupt
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1069, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py"

  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1069, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1069, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
KeyboardInterrupt
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1069, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1053, in join
    self._wait_for_tstate_lock()
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1069, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1069, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
KeyboardInterrupt
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/threading.py", line 1053, in join
    self._wait_for_tstate_lo

In [17]:
base_model = "bert-base-uncased"
# base_model = 'roberta-base'

config = AutoConfig.from_pretrained(base_model)
tokenizer = AutoTokenizer.from_pretrained(base_model)

mt_model = MultiTaskBert(config, tasks).to(device)
# mt_model = MultiTaskRoberta(config, tasks).to(device)


In [18]:
optimizer = optim.AdamW(mt_model.parameters(), lr=0.03)
scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.05, total_steps=len(mt_dataloader)) 

In [19]:
losses = collections.defaultdict(list)
df_dev = pd.DataFrame(columns=np.arange(0,len(tasks)))
for i_iter, data in enumerate(tqdm(mt_dataloader)):    
    i_task, batch = data
    optimizer.zero_grad()
    label = batch['label'].to(device)
    del batch['label']
    tokens = tokenizer(**batch, return_tensors='pt', padding=True, truncation=True, max_length=64).to(device)
    output, loss = mt_model(**tokens, i_task=i_task,  label=label)
    loss.backward()
    optimizer.step()
    scheduler.step()
    losses[i_task].append(loss.detach().item())
    tokens = None
    output = None
    
#     if i_iter%500 == 0 and i_iter != 0:
#         dev_loss = validate(mt_dev_dataloader)
#         df_dev = df_dev.append(dev_loss , ignore_index=True)
#         print(f'#####training iter {i_iter}/{len(mt_dataloader)}')
#         print_loss(dev_loss)

  0%|          | 0/85054 [00:00<?, ?it/s]

  0%|          | 0/7688 [00:00<?, ?it/s]

#####training iter 500/85054
2.37 1.05 0.05 0.01 0.01 0.70 1.73 0.17 1.98 1.29 1.02 1.07 1.58 1.13 0.69 0.08 1.12 1.13 1.13 0.02 0.69 0.67 


  0%|          | 0/7688 [00:00<?, ?it/s]

IndexError: list index out of range

In [None]:
losses

In [17]:
tasks

{'CrowdFlower': 13,
 'DailyDialog': 7,
 'EmoBank_Valence': 1,
 'EmoBank_Arousal': 1,
 'EmoBank_Dominance': 1,
 'HateOffensive': 3,
 'PASTEL_age': 8,
 'PASTEL_country': 2,
 'PASTEL_education': 10,
 'PASTEL_ethnic': 10,
 'PASTEL_gender': 3,
 'PASTEL_politics': 3,
 'PASTEL_tod': 5,
 'SARC': 2,
 'SarcasmGhosh': 2,
 'SentiTreeBank': 1,
 'ShortHumor': 2,
 'ShortJokeKaggle': 2,
 'ShortRomance': 2,
 'StanfordPoliteness': 1,
 'TroFi': 2,
 'VUA': 2}

In [18]:
tasks2 = {'CrowdFlower': 13}

In [27]:
mt_dev_dataloader = MultiTaskTestDataLoader(tasks2, split='dev', batch_size = 16, shuffle = True, num_workers = 6)


In [28]:
for batch in tqdm(mt_dev_dataloader):
    pass

  0%|          | 0/123 [00:00<?, ?it/s]

Exception ignored in: <function _releaseLock at 0x7f15003e65e0>
Traceback (most recent call last):
  File "/home/jz17d/anaconda3/envs/torch/lib/python3.9/logging/__init__.py", line 227, in _releaseLock
    def _releaseLock():
KeyboardInterrupt: 


RuntimeError: DataLoader worker (pid(s) 2088407) exited unexpectedly

In [25]:
batch

<generator object MultiTaskTestDataLoader.__next__ at 0x7f13ccf42e40>

In [None]:
test_loss = validate(mt_test_dataloader)

In [None]:
PATH = './mt_model_runs/mt_2.bin'
torch.save(mt_model.state_dict(), PATH)

In [None]:
test_loss