<a href="https://colab.research.google.com/github/napsternxg/ipython-notebooks/blob/master/MultiTask_Transformer_for_Token_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install torch transformers datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


Prompt

> Write the python code for a multi task Data Loader using HuggingFace datasets. It should suport token classification datasets. Remember to include the support for mapping token labels to their corresponding subword labels.


Bing

> Here’s a sample Python code for a multi-task data loader using HuggingFace datasets that supports token classification datasets and includes support for mapping token labels to their corresponding subword labels:

> You can use this class by initializing it with the desired datasets and tokenizer name. For example:

> This will load the `conll2003` dataset for the ner task and tokenize it using the `bert-base-cased` tokenizer. The resulting data will be a dictionary with keys corresponding to the tasks and values being the tokenized and aligned datasets. Is there anything else you would like to know?

In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer
from functools import partial
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Sampler
import itertools
from transformers import DataCollatorForTokenClassification
from transformers import AutoModel
from torch import nn
from transformers.modeling_outputs import TokenClassifierOutput
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.optim import Adam
from tqdm.auto import tqdm
import torch

In [3]:
class MultiTaskDataSets:
    def __init__(self, datasets, tokenizer_name):
        self.datasets = datasets
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def tokenize_and_align_labels(self, examples, task=None):
        tokenized_inputs = self.tokenizer(examples['tokens'], truncation=True, is_split_into_words=True)

        labels = []
        dataset_name, label_key = self.datasets[task]
        for i, label in enumerate(examples[label_key]):            
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            # print(f"{i=}, {word_ids=}, {label}")
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                if word_idx is None:
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:
                    label_ids.append(label[word_idx])
                else:
                    label_ids.append(-100)
                previous_word_idx = word_idx
            labels.append(label_ids)

        tokenized_inputs['labels'] = labels
        return tokenized_inputs

    def load(self):
        data = {}
        self.label_to_id = {}
        self.task_tokenize_and_align_labels_fns = {}
        for task in self.datasets:
            dataset_name, label_key = self.datasets[task]
            dataset = load_dataset(*dataset_name)
            first_split_key = list(dataset.keys())[0]
            label_list = dataset[first_split_key].features[label_key].feature.names
            self.label_to_id[task] = {l: i for i, l in enumerate(label_list)}
            print(self.label_to_id[task])
            self.task_tokenize_and_align_labels_fns[task] = partial(self.tokenize_and_align_labels, task=task)
            data[task] = dataset.map(self.task_tokenize_and_align_labels_fns[task], batched=True)
            data[task].set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
        return data




class RoundRobinSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
        self.task_indices = {task: list(range(len(data))) for task, data in data_source.items()}
        self.task_iterators = {task: iter(indices) for task, indices in self.task_indices.items()}

    def __iter__(self):
        while True:
            for task in self.task_indices:
                try:
                    yield next(self.task_iterators[task])
                except StopIteration:
                    self.task_iterators[task] = iter(self.task_indices[task])

    def __len__(self):
        return sum(len(indices) for indices in self.task_indices.values())


class MultiTaskTokenClassificationModel(nn.Module):
    def __init__(self, model_name, task_info):
        super().__init__()
        self.shared = AutoModel.from_pretrained(model_name, return_dict=True)
        self.num_labels = {
            task: info["num_labels"]
            for task, info in task_info.items()
        }
        self.classifiers = nn.ModuleDict({
            task: nn.Linear(self.shared.config.hidden_size, info["num_labels"])
            for task, info in task_info.items()
        })

    def forward(self, task, input_ids, attention_mask=None, labels=None, return_dict=True):
        outputs = self.shared(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        logits = self.classifiers[task](sequence_output)
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels[task]), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

class RoundRobinIterator:
  def __init__(self, data_loaders):
    self.data_loaders = data_loaders

  def __iter__(self):
    iters = {task: iter(data_loader) for task, data_loader in self.data_loaders.items()}
    finished_tasks = set()
    while len(finished_tasks) < len(iters):
      for task, data_iter in iters.items():
        if task in finished_tasks:
          continue
        try:
          batch = next(data_iter)
          yield task, batch
        except StopIteration:
          finished_tasks.add(task)
          # iters[task] = iter(self.data_loaders[task])

  def __len__(self):
    return sum(len(data_loader) for data_loader in self.data_loaders.values())



    

In [4]:
datasets = {
    'conll2003': (('conll2003',), 'ner_tags'),
    'wnut_17': (('wnut_17',), 'ner_tags'),
    # 'twitter_pos_foster': (('strombergnlp/twitter_pos', "foster"), 'pos_tags'),
    'twitter_pos_ritter': (('strombergnlp/twitter_pos', "ritter"), 'pos_tags')
}
tokenizer_name = 'bert-base-cased'
multitask_datasets = MultiTaskDataSets(datasets, tokenizer_name)
data = multitask_datasets.load()
data



  0%|          | 0/3 [00:00<?, ?it/s]



{'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}


Map:   0%|          | 0/3250 [00:00<?, ? examples/s]



  0%|          | 0/3 [00:00<?, ?it/s]



{'O': 0, 'B-corporation': 1, 'I-corporation': 2, 'B-creative-work': 3, 'I-creative-work': 4, 'B-group': 5, 'I-group': 6, 'B-location': 7, 'I-location': 8, 'B-person': 9, 'I-person': 10, 'B-product': 11, 'I-product': 12}




  0%|          | 0/3 [00:00<?, ?it/s]



{'"': 0, "''": 1, '#': 2, '%': 3, '$': 4, '(': 5, ')': 6, ',': 7, '.': 8, ':': 9, '``': 10, 'CC': 11, 'CD': 12, 'DT': 13, 'EX': 14, 'FW': 15, 'IN': 16, 'JJ': 17, 'JJR': 18, 'JJS': 19, 'LS': 20, 'MD': 21, 'NN': 22, 'NNP': 23, 'NNPS': 24, 'NNS': 25, 'NN|SYM': 26, 'PDT': 27, 'POS': 28, 'PRP': 29, 'PRP$': 30, 'RB': 31, 'RBR': 32, 'RBS': 33, 'RP': 34, 'SYM': 35, 'TO': 36, 'UH': 37, 'VB': 38, 'VBD': 39, 'VBG': 40, 'VBN': 41, 'VBP': 42, 'VBZ': 43, 'WDT': 44, 'WP': 45, 'WP$': 46, 'WRB': 47, 'RT': 48, 'HT': 49, 'USR': 50, 'URL': 51}


{'conll2003': DatasetDict({
     train: Dataset({
         features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
         num_rows: 14041
     })
     validation: Dataset({
         features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
         num_rows: 3250
     })
     test: Dataset({
         features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
         num_rows: 3453
     })
 }),
 'wnut_17': DatasetDict({
     train: Dataset({
         features: ['id', 'tokens', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
         num_rows: 3394
     })
     validation: Dataset({
         features: ['id', 'tokens', 'ner_tags', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
         num_rows: 1009
     })
     test: Dataset({
         features: ['i

Prompt

> Can you write a data sampler function which can make batches from this dataset. Make sure each batch is homogenous i.e. each batch should only contain a single task examples. Also make sure that the batches from each task are extracted in a round robin fashion.


Bing
> Sure! Here’s a sample Python code for a data sampler that can create batches from the multi-task dataset generated by the MultiTaskDataLoader class. This sampler ensures that each batch is homogenous (i.e., contains examples from only a single task) and that batches from each task are extracted in a round-robin fashion:



In [5]:
batch_size = 4
collate_fn = DataCollatorForTokenClassification(multitask_datasets.tokenizer)

# sampler = RoundRobinSampler(data)
data_loaders = {task: DataLoader(dataset_dict['train'], batch_size=batch_size, collate_fn=collate_fn) for task, dataset_dict in data.items()}

In [6]:
for task, data_loader in data_loaders.items():
  for batch in data_loader:
    print(task)
    print({k: v.shape for k,v in batch.items()})
    break  


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


conll2003
{'input_ids': torch.Size([4, 34]), 'attention_mask': torch.Size([4, 34]), 'labels': torch.Size([4, 34])}
wnut_17
{'input_ids': torch.Size([4, 35]), 'attention_mask': torch.Size([4, 35]), 'labels': torch.Size([4, 35])}
twitter_pos_ritter
{'input_ids': torch.Size([4, 87]), 'attention_mask': torch.Size([4, 87]), 'labels': torch.Size([4, 87])}


In [7]:
task_info = {
    task: dict(num_labels=len(label2ids))
    for task, label2ids in multitask_datasets.label_to_id.items()
}
task_info

{'conll2003': {'num_labels': 9},
 'wnut_17': {'num_labels': 13},
 'twitter_pos_ritter': {'num_labels': 52}}

In [8]:
model_name = "distilbert-base-uncased"
multitask_model = MultiTaskTokenClassificationModel(model_name=model_name, task_info=task_info)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
optimizer = Adam(multitask_model.parameters(), lr=0.001)

In [10]:
num_epochs = 2
for epoch in range(num_epochs):
  running_loss = 0.0
  for task, data_loader in data_loaders.items():
    for batch in data_loader:
      optimizer.zero_grad()
      output = multitask_model(task=task, **batch)
      print(task)
      print({k: v.shape for k,v in batch.items()})
      print({k:v.shape for k,v in output.items()})
      loss = output.loss
      loss.backward()
      optimizer.step()
      running_loss += loss.item()
      break

conll2003
{'input_ids': torch.Size([4, 34]), 'attention_mask': torch.Size([4, 34]), 'labels': torch.Size([4, 34])}
{'loss': torch.Size([]), 'logits': torch.Size([4, 34, 9])}
wnut_17
{'input_ids': torch.Size([4, 35]), 'attention_mask': torch.Size([4, 35]), 'labels': torch.Size([4, 35])}
{'loss': torch.Size([]), 'logits': torch.Size([4, 35, 13])}
twitter_pos_ritter
{'input_ids': torch.Size([4, 87]), 'attention_mask': torch.Size([4, 87]), 'labels': torch.Size([4, 87])}
{'loss': torch.Size([]), 'logits': torch.Size([4, 87, 52])}
conll2003
{'input_ids': torch.Size([4, 34]), 'attention_mask': torch.Size([4, 34]), 'labels': torch.Size([4, 34])}
{'loss': torch.Size([]), 'logits': torch.Size([4, 34, 9])}
wnut_17
{'input_ids': torch.Size([4, 35]), 'attention_mask': torch.Size([4, 35]), 'labels': torch.Size([4, 35])}
{'loss': torch.Size([]), 'logits': torch.Size([4, 35, 13])}
twitter_pos_ritter
{'input_ids': torch.Size([4, 87]), 'attention_mask': torch.Size([4, 87]), 'labels': torch.Size([4, 87])

In [11]:
round_robin_iterator = RoundRobinIterator(data_loaders)
len(round_robin_iterator)

4498

In [12]:
num_epochs = 2
for epoch in range(num_epochs):
  for task, batch in round_robin_iterator:
      print(task)
      print({k: v.shape for k,v in batch.items()})
      print(batch["input_ids"][:, :10])
      # print({k:v.shape for k,v in output.items()})
      break

conll2003
{'input_ids': torch.Size([4, 34]), 'attention_mask': torch.Size([4, 34]), 'labels': torch.Size([4, 34])}
tensor([[  101,  7270, 22961,  1528,  1840,  1106, 21423,  1418,  2495, 12913],
        [  101,  1943, 14428,   102,     0,     0,     0,     0,     0,     0],
        [  101, 26660, 13329, 12649, 15928,  1820,   118,  4775,   118,  1659],
        [  101,  1109,  1735,  2827,  1163,  1113,  9170,  1122, 19786,  1114]])
conll2003
{'input_ids': torch.Size([4, 34]), 'attention_mask': torch.Size([4, 34]), 'labels': torch.Size([4, 34])}
tensor([[  101,  7270, 22961,  1528,  1840,  1106, 21423,  1418,  2495, 12913],
        [  101,  1943, 14428,   102,     0,     0,     0,     0,     0,     0],
        [  101, 26660, 13329, 12649, 15928,  1820,   118,  4775,   118,  1659],
        [  101,  1109,  1735,  2827,  1163,  1113,  9170,  1122, 19786,  1114]])


In [13]:
! nvidia-smi

Mon Apr 17 02:00:14 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P8     9W /  70W |      3MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [15]:
model_name = "Twitter/twhin-bert-base"
multitask_model = MultiTaskTokenClassificationModel(model_name=model_name, task_info=task_info)


datasets = {
    'conll2003': (('conll2003',), 'ner_tags'),
    'wnut_17': (('wnut_17',), 'ner_tags'),
    # 'twitter_pos_foster': (('strombergnlp/twitter_pos', "foster"), 'pos_tags'),
    'twitter_pos_ritter': (('strombergnlp/twitter_pos', "ritter"), 'pos_tags')
}
tokenizer_name = model_name
multitask_datasets = MultiTaskDataSets(datasets, tokenizer_name)

data = multitask_datasets.load()
data

Downloading (…)lve/main/config.json:   0%|          | 0.00/632 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

Some weights of the model checkpoint at Twitter/twhin-bert-base were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading (…)okenizer_config.json:   0%|          | 0.00/372 [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]



  0%|          | 0/3 [00:00<?, ?it/s]

{'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}


Map:   0%|          | 0/14041 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/3250 [00:00<?, ? examples/s]

Map:   0%|          | 0/3453 [00:00<?, ? examples/s]



  0%|          | 0/3 [00:00<?, ?it/s]

{'O': 0, 'B-corporation': 1, 'I-corporation': 2, 'B-creative-work': 3, 'I-creative-work': 4, 'B-group': 5, 'I-group': 6, 'B-location': 7, 'I-location': 8, 'B-person': 9, 'I-person': 10, 'B-product': 11, 'I-product': 12}


Map:   0%|          | 0/3394 [00:00<?, ? examples/s]

Map:   0%|          | 0/1009 [00:00<?, ? examples/s]

Map:   0%|          | 0/1287 [00:00<?, ? examples/s]



  0%|          | 0/3 [00:00<?, ?it/s]

{'"': 0, "''": 1, '#': 2, '%': 3, '$': 4, '(': 5, ')': 6, ',': 7, '.': 8, ':': 9, '``': 10, 'CC': 11, 'CD': 12, 'DT': 13, 'EX': 14, 'FW': 15, 'IN': 16, 'JJ': 17, 'JJR': 18, 'JJS': 19, 'LS': 20, 'MD': 21, 'NN': 22, 'NNP': 23, 'NNPS': 24, 'NNS': 25, 'NN|SYM': 26, 'PDT': 27, 'POS': 28, 'PRP': 29, 'PRP$': 30, 'RB': 31, 'RBR': 32, 'RBS': 33, 'RP': 34, 'SYM': 35, 'TO': 36, 'UH': 37, 'VB': 38, 'VBD': 39, 'VBG': 40, 'VBN': 41, 'VBP': 42, 'VBZ': 43, 'WDT': 44, 'WP': 45, 'WP$': 46, 'WRB': 47, 'RT': 48, 'HT': 49, 'USR': 50, 'URL': 51}


Map:   0%|          | 0/118 [00:00<?, ? examples/s]

Map:   0%|          | 0/118 [00:00<?, ? examples/s]

Map:   0%|          | 0/551 [00:00<?, ? examples/s]

{'conll2003': DatasetDict({
     train: Dataset({
         features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'attention_mask', 'labels'],
         num_rows: 14041
     })
     validation: Dataset({
         features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'attention_mask', 'labels'],
         num_rows: 3250
     })
     test: Dataset({
         features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags', 'input_ids', 'attention_mask', 'labels'],
         num_rows: 3453
     })
 }),
 'wnut_17': DatasetDict({
     train: Dataset({
         features: ['id', 'tokens', 'ner_tags', 'input_ids', 'attention_mask', 'labels'],
         num_rows: 3394
     })
     validation: Dataset({
         features: ['id', 'tokens', 'ner_tags', 'input_ids', 'attention_mask', 'labels'],
         num_rows: 1009
     })
     test: Dataset({
         features: ['id', 'tokens', 'ner_tags', 'input_ids', 'attention_mask', 'labels'],
         num_rows: 128

In [16]:
multitask_model = multitask_model.to(device)
optimizer = Adam(multitask_model.parameters(), lr=0.001)

In [17]:
batch_size = 32
# sampler = RoundRobinSampler(data)
collate_fn = DataCollatorForTokenClassification(multitask_datasets.tokenizer)
data_loaders = {
    task: DataLoader(dataset_dict['train'], batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
    for task, dataset_dict in data.items()
}
round_robin_iterator = RoundRobinIterator(data_loaders)
len(round_robin_iterator)

564

In [18]:
num_epochs = 2
for epoch in tqdm(range(num_epochs)):
  task_losses = {task: 0.0 for task in data_loaders}
  batch_losses = {task: float("inf") for task in data_loaders}
  task_batches = {f"{task}_batches": 0 for task in data_loaders}
  with tqdm(round_robin_iterator, desc="round_robin_iterator") as pbar:
    for task, batch in pbar:
        batch = batch.to(device)
        # print(task)
        # print({k: v.shape for k,v in batch.items()})
        # print(batch["input_ids"][:, :10])
        optimizer.zero_grad()
        output = multitask_model(task=task, **batch)
        task_batches[f"{task}_batches"] += 1
        # print({k:v.shape for k,v in output.items()})
        loss = output.loss
        loss.backward()
        optimizer.step()
        batch_losses[task] = loss.item()
        task_losses[task] += batch_losses[task]
        # print(task_losses)
        pbar.set_postfix(task=task, **batch_losses, **task_batches)
        # break
    print(task_losses)

  0%|          | 0/2 [00:00<?, ?it/s]

round_robin_iterator:   0%|          | 0/564 [00:00<?, ?it/s]

You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'conll2003': 346.3214177787304, 'wnut_17': 36.585519790649414, 'twitter_pos_ritter': 59.7191059589386}


round_robin_iterator:   0%|          | 0/564 [00:00<?, ?it/s]

{'conll2003': 343.96735060214996, 'wnut_17': 34.77236123383045, 'twitter_pos_ritter': 58.08308410644531}
