# Multitask

Fine-tune a pretrained BERT model to perform both named entity recognition (token-level) and sequence-level classification in a single forward pass.

In [None]:
# If running in Google Colab, make sure the required packages are installed.
!pip install datasets evaluate transformers[sentencepiece]
!pip install accelerate
# To run the training on TPU, you will need to uncomment the following line:
# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!apt install git-lfs
# seqeval backs the token-level (NER) metric loaded via evaluate.load("seqeval")
!pip install seqeval

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
git-lfs is already the newest version (3.0.2-1ubuntu0.2).
0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.


In [None]:
# import packages
from datasets import load_dataset
import evaluate
import numpy as np
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import BertTokenizer
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer
import evaluate
from sklearn.metrics import accuracy_score, f1_score
from transformers import get_scheduler

In [None]:
# function to align labels with tokens, taken from NER.ipynb
def align_labels_with_tokens(labels, word_ids):
    """Expand word-level tag ids to one tag id per sub-word token.

    Args:
        labels: tag id for each word of the sentence.
        word_ids: for each token, the index of the word it came from, or
            None for special tokens.

    Returns:
        A list parallel to ``word_ids``:
        - -100 for special tokens;
        - the word's own tag for the first token of each word;
        - for continuation tokens, the word's tag, bumped by one when it is odd
          (turning a B-XXX id into the matching I-XXX id).
    """
    aligned = []
    previous_word = None
    for word_id in word_ids:
        if word_id is None:
            # special token ([CLS], [SEP], ...) is ignored by the loss
            aligned.append(-100)
        elif word_id != previous_word:
            # first sub-token of a new word keeps the word's tag
            aligned.append(labels[word_id])
        else:
            # continuation sub-token: odd (B-) ids become the following I- id
            tag = labels[word_id]
            aligned.append(tag + 1 if tag % 2 == 1 else tag)
        previous_word = word_id
    return aligned

In [None]:
# seqeval metric; compute_metrics reads its overall precision/recall/F1/accuracy
metric = evaluate.load("seqeval")

# function to evaluate performance metrics (precision, recall, F1 score, accuracy) for given predictions and ground-truth labels
def compute_metrics(predictions, labels):
    """Score token-level predictions with the seqeval metric.

    Positions whose gold id is -100 (special tokens / padding) are dropped
    from both the references and the predictions. The remaining integer ids
    are mapped to tag names through the global ``label_names``.

    Args:
        predictions: per-sequence predicted tag ids.
        labels: per-sequence gold tag ids, with -100 marking ignored positions.

    Returns:
        Dict with the overall "precision", "recall", "f1" and "accuracy".
    """
    references = [[label_names[gold] for gold in row if gold != -100] for row in labels]

    hypotheses = []
    for pred_row, gold_row in zip(predictions, labels):
        kept = [label_names[pred] for pred, gold in zip(pred_row, gold_row) if gold != -100]
        hypotheses.append(kept)

    scores = metric.compute(predictions=hypotheses, references=references)
    return {
        "precision": scores["overall_precision"],
        "recall": scores["overall_recall"],
        "f1": scores["overall_f1"],
        "accuracy": scores["overall_accuracy"],
    }

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# specify multitask model
class MultiTaskModel(nn.Module):
    """BERT encoder shared by a sequence-classification head and a token-classification head.

    A single forward pass returns logits for both tasks.

    Args:
        model_name: checkpoint name/path passed to ``BertModel.from_pretrained``.
        num_seq_labels: number of sequence-level classes.
        num_token_labels: number of token-level (NER) classes.
        dropout: dropout probability applied in front of each linear head
            (defaults to the previously hard-coded 0.1).
    """

    def __init__(self, model_name="bert-base-uncased", num_seq_labels=5, num_token_labels=9, dropout=0.1):
        super().__init__()
        # pretrained BERT encoder shared by both heads
        self.bert = BertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size

        # Learnable log-scale parameters the training loop reads to weight the
        # sequence loss (log_sigma1) and the token loss (log_sigma2).
        self.log_sigma1 = nn.Parameter(torch.tensor(0.0))
        self.log_sigma2 = nn.Parameter(torch.tensor(-1.4))

        # Sequence classification head. The (Dropout, Linear) order is kept so
        # state_dict keys stay compatible with existing checkpoints.
        self.seq_classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_seq_labels),
        )

        # Token classification head, applied independently at every position.
        self.token_classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_token_labels),
        )

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return ``(seq_logits, token_logits)`` for one batch.

        seq_logits has shape (batch_size, num_seq_labels) and token_logits has
        shape (batch_size, seq_length, num_token_labels).
        """
        # Keyword arguments, so the call does not depend on the positional
        # order of BertModel.forward's parameters.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        sequence_output = outputs.last_hidden_state  # (batch_size, seq_length, hidden_size)
        # BERT's pooled [CLS] representation, used to classify the whole sequence
        pooled_output = outputs.pooler_output  # (batch_size, hidden_size)

        seq_logits = self.seq_classifier(pooled_output)
        token_logits = self.token_classifier(sequence_output)
        return seq_logits, token_logits

In [None]:
# custom PyTorch dataset that tokenizes each sentence on access and returns its sequence and token labels
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each raw text with a sequence label and token labels.

    Texts are tokenized lazily in ``__getitem__``. They are padded and truncated
    to ``max_length``, and the tokenizer is asked for PyTorch tensors.
    """

    def __init__(self, texts, seq_labels, token_labels, tokenizer, max_length):
        self.texts, self.seq_labels, self.token_labels = texts, seq_labels, token_labels
        self.tokenizer, self.max_length = tokenizer, max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            self.texts[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # the tokenizer returns a batch of one; drop that leading dimension
        sample = {key: encoded[key].squeeze(0) for key in ('input_ids', 'attention_mask')}
        sample['seq_label'] = torch.tensor(self.seq_labels[idx], dtype=torch.long)
        sample['token_labels'] = torch.tensor(self.token_labels[idx], dtype=torch.long)
        return sample

# data collator
# make sure all vectors are correctly padded and have the right dimension
def collate_fn(batch):
    """Stack a list of dataset samples into batch tensors.

    input_ids and attention_mask are right-padded with 0, token_labels with
    -100 (ignored by the loss), and the scalar sequence labels are stacked.

    Returns:
        Tuple ``(input_ids, attention_mask, seq_labels, token_labels)``.
    """
    def _padded(key, fill):
        # pad each sample's 1-D tensor to the longest one in the batch
        return pad_sequence([sample[key] for sample in batch], batch_first=True, padding_value=fill)

    input_ids = _padded('input_ids', 0)
    attention_mask = _padded('attention_mask', 0)
    seq_labels = torch.stack([sample['seq_label'] for sample in batch])
    token_labels = _padded('token_labels', -100)
    return input_ids, attention_mask, seq_labels, token_labels

# training loop
def _pad_token_labels(token_labels, seq_len):
    """Right-pad ``token_labels`` with -100 along dim 1 to ``seq_len`` columns.

    If the labels are already longer than ``seq_len``, the negative pad width
    makes ``F.pad`` crop the surplus columns instead.
    """
    return F.pad(token_labels, (0, seq_len - token_labels.size(1)), "constant", -100)


def _multitask_losses(seq_logits, token_logits, seq_labels, token_labels):
    """Return the unweighted (sequence, token) cross-entropy losses of one batch.

    Token positions labelled -100 are ignored by the token loss.
    """
    seq_loss = F.cross_entropy(seq_logits, seq_labels)
    token_loss = F.cross_entropy(
        token_logits.view(-1, token_logits.size(-1)),
        _pad_token_labels(token_labels, token_logits.size(1)).view(-1),
        ignore_index=-100,
    )
    return seq_loss, token_loss


def _evaluate(model, dataloader, device):
    """Run ``model`` over ``dataloader`` in eval mode (dropout off) without gradients.

    The model's previous train/eval mode is restored before returning.

    Returns:
        Dict with the batch-summed "seq_loss" and "token_loss", plus numpy
        arrays "seq_preds"/"seq_labels" (one entry per sequence) and
        "token_preds"/"token_labels" (one row per sequence, -100 where ignored).
    """
    was_training = model.training
    model.eval()
    seq_loss_sum = 0.0
    token_loss_sum = 0.0
    seq_preds, seq_golds, token_preds, token_golds = [], [], [], []
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, leave=True):
                input_ids, attention_mask, seq_labels, token_labels = [x.to(device) for x in batch]
                seq_logits, token_logits = model(input_ids, attention_mask)

                seq_preds.append(seq_logits.argmax(-1).cpu())
                seq_golds.append(seq_labels.cpu())
                token_preds.append(token_logits.argmax(-1).cpu())
                token_golds.append(_pad_token_labels(token_labels.cpu(), token_logits.size(1)))

                seq_loss, token_loss = _multitask_losses(seq_logits, token_logits, seq_labels, token_labels)
                seq_loss_sum += seq_loss.item()
                token_loss_sum += token_loss.item()
    finally:
        model.train(was_training)

    return {
        "seq_loss": seq_loss_sum,
        "token_loss": token_loss_sum,
        "seq_preds": torch.cat(seq_preds).numpy(),
        "seq_labels": torch.cat(seq_golds).numpy(),
        # requires every batch to share one sequence length (as np.vstack did)
        "token_preds": torch.cat(token_preds).numpy(),
        "token_labels": torch.cat(token_golds).numpy(),
    }


def train_model(model, training_dataloader, validation_dataloader, optimizer, lr_scheduler, accelerator, device, epochs=3, alpha=0.5, beta=0.5):
    """Jointly fine-tune ``model`` on the sequence and token tasks, validating after each epoch.

    Args:
        model: a MultiTaskModel-like module returning ``(seq_logits, token_logits)``
            and exposing ``log_sigma1`` / ``log_sigma2`` parameters.
        training_dataloader / validation_dataloader: yield
            ``(input_ids, attention_mask, seq_labels, token_labels)`` batches.
        optimizer, lr_scheduler: stepped once per training batch.
        accelerator: object whose ``backward(loss)`` performs the backward pass.
        device: device the batches are moved to.
        epochs: number of passes over the training data.
        alpha, beta: currently unused; the task weights are learned through
            ``model.log_sigma1`` / ``model.log_sigma2``. Kept so existing calls
            keep working.

    Prints training/validation losses and validation metrics after each epoch.
    The model is left in training mode.
    """
    model.to(device)
    model.train()

    for epoch in range(epochs):
        total_seq_loss = 0
        total_token_loss = 0

        loop = tqdm(training_dataloader, leave=True)
        for batch in loop:
            input_ids, attention_mask, seq_labels, token_labels = [x.to(device) for x in batch]

            optimizer.zero_grad()
            seq_logits, token_logits = model(input_ids, attention_mask)
            seq_loss, token_loss = _multitask_losses(seq_logits, token_logits, seq_labels, token_labels)

            # Learnable weighting: each task loss is scaled down as its log_sigma
            # grows, and log_sigma itself is added as a penalty so the scale
            # cannot grow without bound.
            # NOTE(review): only the sequence term carries the extra factor 1/2;
            # confirm this asymmetry is intended.
            weighted_seq_loss = 1 / (2 * torch.exp(model.log_sigma1)) * seq_loss + model.log_sigma1
            weighted_token_loss = 1 / torch.exp(model.log_sigma2) * token_loss + model.log_sigma2
            loss = weighted_seq_loss + weighted_token_loss

            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()

            total_seq_loss += seq_loss.item()
            total_token_loss += token_loss.item()

            loop.set_description(f'Epoch {epoch + 1}')
            loop.set_postfix(seq_loss=seq_loss.item(), token_loss=token_loss.item())

        # validation runs in eval mode so dropout does not perturb the metrics
        val = _evaluate(model, validation_dataloader, device)

        print(f"Epoch {epoch+1}/{epochs}, Training Sequence Loss: {total_seq_loss/len(training_dataloader)}, Training Token Loss: {total_token_loss/len(training_dataloader)}")
        print(f"\t Validation Sequence Loss: {val['seq_loss']/len(validation_dataloader)}, Validation Token Loss: {val['token_loss']/len(validation_dataloader)}")

        f1_seq = f1_score(val["seq_labels"], val["seq_preds"], average="weighted")
        acc_seq = accuracy_score(val["seq_labels"], val["seq_preds"])
        print(f"\t Validation Sequence F_1: {f1_seq}, Accuracy: {acc_seq}")

        metrics_token = compute_metrics(val["token_preds"], val["token_labels"])
        print(f"\t Validation Token F_1: {metrics_token['f1']}, Accuracy: {metrics_token['accuracy']}")

In [None]:
# function to form a sequence from a list of strings
# from Classification.ipynb
def smart_join(strings):
    """Join word tokens into one sentence string, spacing them like normal prose.

    A space goes between two consecutive tokens unless the earlier one ends
    with '(' or a space, or the later one starts with closing punctuation or a
    quote (. , ; : % ' " ) ! ?).
    """
    no_space_after = ('(', ' ')
    no_space_before = ('.', ',', ';', ':', '%', "'", '"', ')', '!', '?')

    pieces = list(strings[:1])
    for prev, cur in zip(strings, strings[1:]):
        if not (prev.endswith(no_space_after) or cur.startswith(no_space_before)):
            pieces.append(' ')
        pieces.append(cur)
    return ''.join(pieces)

# load datasets
ds_train, ds_valid, ds_test = (
    load_dataset("conll2003", split=split)
    for split in ("train", "validation", "test")
)

# NER tag names, indexed by the integer ids stored in "ner_tags"
ner_feature = ds_train.features["ner_tags"]
label_names = ner_feature.feature.names

# training dataset
training_classification = ds_train.shuffle(seed=23).select(range(1000))
training_tokens = training_classification['tokens']

# rebuild each sentence's surface text from its word tokens
training_sentences = [smart_join(words) for words in training_tokens]

# Sequence-level class ids, one per sentence of the 1000-sentence training
# subset above and in the same order; generated in Classification.ipynb.
training_sentence_labels = [2, 4, 4, 1, 4, 4, 1, 1, 1, 1, 1, 2, 1, 0, 2, 1, 1, 0, 2, 4, 4, 1, 4, 4, 4, 2, 1, 1, 0, 4, 1, 2, 1, 0, 4, 4, 1, 4, 1, 0, 0, 1, 1, 0, 1, 0, 4, 1, 1, 1, 2, 1, 4, 4, 0, 1, 0, 2, 4, 4, 4, 1, 2, 0, 1, 1, 1, 0, 4, 1, 1, 2, 4, 4, 0, 1, 2, 1, 4, 1, 4, 1, 2, 0, 0, 1, 1, 4, 1, 2, 4, 0, 1, 1, 4, 4, 0, 2, 1, 4, 4, 0, 1, 0, 2, 4, 1, 1, 1, 2, 1, 1, 4, 0, 1, 1, 0, 4, 4, 1, 0, 2, 1, 2, 1, 0, 0, 1, 1, 1, 2, 4, 1, 1, 2, 2, 1, 0, 1, 1, 4, 2, 2, 1, 4, 2, 3, 4, 4, 0, 2, 4, 1, 0, 4, 1, 4, 4, 0, 4, 4, 1, 2, 1, 1, 1, 0, 4, 0, 1, 2, 2, 1, 4, 2, 2, 2, 2, 1, 4, 1, 2, 4, 0, 2, 4, 1, 2, 2, 4, 2, 4, 4, 0, 2, 0, 4, 4, 4, 4, 2, 1, 4, 1, 1, 2, 1, 4, 0, 1, 4, 2, 2, 4, 0, 2, 0, 1, 1, 1, 4, 4, 0, 2, 0, 0, 1, 2, 1, 2, 1, 1, 1, 0, 1, 4, 4, 4, 1, 1, 1, 1, 2, 0, 1, 0, 1, 1, 0, 0, 1, 1, 4, 4, 1, 0, 0, 2, 0, 0, 1, 0, 1, 0, 2, 1, 1, 1, 1, 1, 4, 4, 2, 4, 2, 4, 0, 4, 1, 0, 1, 4, 2, 4, 4, 4, 4, 4, 0, 1, 1, 4, 1, 4, 0, 1, 1, 2, 1, 0, 0, 0, 4, 4, 1, 4, 2, 4, 1, 0, 4, 4, 0, 0, 0, 2, 4, 1, 2, 2, 1, 3, 4, 4, 2, 0, 0, 0, 4, 1, 0, 1, 4, 2, 2, 1, 4, 0, 0, 4, 0, 1, 4, 1, 0, 1, 4, 4, 4, 4, 0, 1, 0, 4, 1, 1, 0, 4, 0, 4, 1, 4, 2, 4, 2, 1, 4, 1, 4, 0, 1, 4, 0, 4, 0, 1, 0, 1, 4, 1, 2, 2, 4, 2, 2, 0, 4, 1, 4, 1, 2, 1, 4, 1, 2, 2, 1, 4, 4, 4, 2, 1, 4, 4, 0, 4, 4, 1, 1, 2, 2, 2, 1, 4, 1, 4, 1, 4, 0, 2, 1, 2, 4, 4, 0, 1, 2, 0, 0, 2, 0, 4, 2, 4, 1, 0, 2, 0, 2, 4, 2, 0, 0, 4, 0, 4, 1, 4, 1, 0, 1, 4, 0, 4, 4, 1, 0, 0, 0, 2, 1, 1, 4, 0, 2, 0, 1, 1, 4, 4, 0, 1, 1, 4, 1, 0, 1, 1, 1, 1, 0, 4, 1, 4, 4, 1, 2, 4, 2, 1, 1, 1, 1, 4, 4, 1, 2, 0, 1, 1, 1, 1, 1, 1, 4, 1, 2, 4, 0, 4, 1, 4, 1, 1, 0, 0, 2, 1, 4, 1, 4, 2, 0, 4, 0, 0, 1, 1, 4, 4, 2, 1, 1, 1, 0, 4, 2, 2, 0, 1, 4, 4, 1, 4, 4, 2, 4, 4, 2, 1, 1, 0, 2, 4, 0, 2, 1, 4, 0, 4, 0, 1, 0, 4, 4, 4, 0, 4, 1, 1, 4, 2, 1, 4, 1, 4, 2, 1, 4, 1, 4, 1, 4, 4, 2, 4, 4, 0, 1, 0, 2, 2, 4, 4, 4, 0, 1, 2, 1, 4, 1, 1, 2, 1, 4, 2, 4, 4, 1, 1, 1, 0, 3, 4, 0, 4, 1, 4, 0, 4, 1, 2, 1, 1, 4, 0, 0, 4, 2, 4, 1, 2, 4, 0, 0, 4, 4, 1, 4, 4, 0, 2, 1, 1, 1, 4, 0, 0, 4, 1, 4, 0, 4, 4, 1, 4, 1, 
2, 1, 2, 4, 0, 1, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 0, 0, 0, 4, 0, 4, 4, 4, 2, 1, 2, 0, 2, 4, 0, 4, 4, 1, 1, 4, 4, 2, 2, 1, 1, 4, 1, 2, 4, 1, 1, 0, 1, 0, 2, 0, 4, 4, 0, 2, 0, 1, 0, 1, 1, 2, 1, 2, 4, 4, 1, 1, 1, 1, 2, 4, 4, 4, 2, 2, 4, 0, 4, 4, 4, 4, 4, 0, 4, 1, 0, 4, 1, 4, 4, 0, 0, 1, 0, 1, 0, 1, 4, 1, 4, 1, 1, 2, 1, 0, 1, 4, 2, 0, 4, 0, 2, 4, 0, 4, 1, 1, 4, 1, 0, 0, 1, 1, 4, 4, 0, 1, 0, 1, 2, 1, 2, 0, 0, 2, 0, 4, 2, 4, 2, 0, 2, 1, 4, 2, 2, 0, 4, 1, 0, 4, 1, 1, 1, 1, 4, 4, 2, 4, 2, 2, 1, 4, 1, 2, 4, 4, 4, 4, 4, 4, 0, 4, 1, 4, 2, 0, 4, 4, 1, 2, 1, 2, 0, 4, 4, 1, 0, 4, 4, 1, 1, 1, 1, 2, 4, 0, 1, 4, 1, 1, 0, 2, 0, 4, 1, 1, 0, 4, 4, 1, 1, 1, 4, 4, 1, 1, 1, 4, 1, 0, 1, 4, 0, 1, 1, 4, 4, 2, 2, 2, 0, 0, 2, 1, 4, 4, 1, 1, 2, 1, 1, 1, 0, 0, 4, 4, 0, 1, 0, 1, 4, 0, 1, 0, 4, 2, 0, 4, 1, 2, 2, 0, 1, 0, 1, 4, 2, 4, 0, 4, 2, 1, 4, 1, 0, 4, 4, 1, 1, 0, 4, 1, 0, 2, 4, 0, 1, 4, 0, 0, 4, 1, 1, 1, 2, 0, 4, 4, 0, 0, 4, 1, 4, 4, 1, 4, 0, 4, 4, 1, 1, 4, 1, 0, 4, 2, 0, 4, 1, 2, 0, 0, 4, 1, 4, 0, 0, 4, 1, 1, 1, 1, 4, 4, 1, 0, 4, 4, 4, 2, 4]

training_token_labels = training_classification['ner_tags']


# validation dataset: fixed random subset of 200 sentences
validation_classification = ds_valid.shuffle(seed=23).select(range(200))
validation_tokens = validation_classification['tokens']
validation_sentences = list(map(smart_join, validation_tokens))

# Sequence-level class ids, one per sentence of the 200-sentence validation
# subset above and in the same order; generated in Classification.ipynb.
validation_sentence_labels = [0, 0, 1, 1, 4, 4, 1, 2, 1, 4, 2, 1, 0, 1, 0, 2, 1, 1, 1, 3, 1, 0, 4, 0, 4, 4, 3, 0, 1, 2, 1, 2, 0, 0, 1, 2, 1, 4, 1, 1, 4, 1, 1, 1, 1, 0, 4, 4, 2, 4, 2, 2, 1, 2, 1, 1, 2, 4, 1, 1, 4, 1, 1, 4, 4, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 4, 0, 4, 1, 0, 1, 0, 2, 4, 1, 0, 1, 1, 1, 1, 0, 1, 1, 4, 2, 1, 0, 4, 4, 4, 1, 4, 1, 0, 4, 1, 1, 0, 1, 0, 2, 1, 1, 4, 0, 1, 0, 2, 0, 0, 2, 1, 2, 2, 2, 1, 1, 1, 4, 4, 4, 4, 0, 0, 2, 1, 4, 1, 1, 4, 0, 1, 4, 4, 1, 0, 1, 1, 0, 0, 4, 2, 4, 4, 4, 1, 4, 1, 4, 1, 0, 2, 4, 4, 2, 1, 4, 2, 4, 2, 1, 0, 0, 4, 1, 1, 0, 1, 4, 1, 2, 1, 2, 0, 1, 1, 4, 4, 4, 1, 4, 0, 0, 0, 0, 1, 1, 1]

validation_token_labels = validation_classification['ner_tags']


# test dataset: fixed random subset of 200 sentences
test_classification = ds_test.shuffle(seed=23).select(range(200))
test_tokens = test_classification['tokens']
test_sentences = list(map(smart_join, test_tokens))

# Sequence-level class ids, one per sentence of the 200-sentence test subset
# above and in the same order; generated in Classification.ipynb.
test_sentence_labels = [1, 1, 2, 2, 1, 4, 1, 4, 1, 1, 1, 2, 1, 4, 4, 2, 1, 0, 1, 4, 1, 4, 4, 1, 2, 1, 1, 0, 1, 1, 1, 2, 4, 1, 4, 0, 1, 0, 2, 4, 1, 2, 4, 2, 0, 1, 1, 2, 4, 1, 4, 2, 2, 4, 4, 2, 4, 1, 4, 1, 2, 0, 1, 4, 1, 2, 0, 4, 0, 2, 0, 2, 2, 1, 4, 4, 2, 4, 2, 1, 1, 1, 4, 2, 1, 1, 1, 1, 4, 4, 4, 1, 2, 1, 2, 1, 1, 1, 4, 2, 4, 1, 0, 1, 4, 2, 1, 1, 4, 0, 1, 2, 2, 4, 4, 1, 4, 4, 0, 1, 1, 2, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 0, 1, 4, 4, 2, 1, 0, 4, 4, 4, 2, 1, 4, 0, 1, 1, 1, 1, 1, 1, 2, 4, 4, 1, 1, 4, 4, 2, 0, 4, 2, 1, 4, 4, 1, 0, 0, 4, 0, 4, 2, 2, 4, 2, 1, 1, 1, 4, 1, 0, 1, 1, 1, 1, 0, 2, 4, 1, 4, 4, 4, 4, 1, 4, 3, 4, 2, 1]

test_token_labels = test_classification['ner_tags']

# The hand-written sequence-label lists must line up one-to-one with the sentences.
for _split_name, _sentences, _sentence_labels in (
    ("training", training_sentences, training_sentence_labels),
    ("validation", validation_sentences, validation_sentence_labels),
    ("test", test_sentences, test_sentence_labels),
):
    if len(_sentences) != len(_sentence_labels):
        raise ValueError(
            f"{_split_name} split has {len(_sentences)} sentences but "
            f"{len(_sentence_labels)} sequence labels"
        )

# batch size
batch_size = 64

# load BERT tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Pad every example to the longest sentence measured in word-piece tokens
# (including special tokens), capped at the model's maximum input length.
# Measuring characters instead over-pads heavily and can exceed that limit.
max_length = min(
    max(len(ids) for ids in tokenizer(training_sentences + validation_sentences + test_sentences)["input_ids"]),
    tokenizer.model_max_length,
)


def _align_token_labels(word_lists, tag_lists):
    """Map each sentence's word-level NER tags onto the tokens of its word list."""
    return [
        align_labels_with_tokens(tags, tokenizer(words, is_split_into_words=True).word_ids(0))
        for words, tags in zip(word_lists, tag_lists)
    ]


# align labels with tokens, using the original word boundaries (before joining into sentences)
# NOTE(review): labels are aligned to the tokenization of the word lists, but
# CustomDataset tokenizes the joined sentence strings; if the two tokenizations
# ever differ, the token labels will be misaligned — confirm they agree.
training_token_labels = _align_token_labels(training_tokens, training_token_labels)
validation_token_labels = _align_token_labels(validation_tokens, validation_token_labels)
test_token_labels = _align_token_labels(test_tokens, test_token_labels)

# create datasets and dataloaders
training_dataset = CustomDataset(training_sentences, training_sentence_labels, training_token_labels, tokenizer, max_length)
training_dataloader = DataLoader(training_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

validation_dataset = CustomDataset(validation_sentences, validation_sentence_labels, validation_token_labels, tokenizer, max_length)
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

test_dataset = CustomDataset(test_sentences, test_sentence_labels, test_token_labels, tokenizer, max_length)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)



In [None]:
from accelerate import Accelerator

# Hyperparameters
learning_rate = 5e-5
weight_decay = 0.01
epochs = 14

# Initialize model and optimizer
model = MultiTaskModel()
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# train_model steps the optimizer and scheduler once per training batch
num_update_steps_per_epoch = len(training_dataloader)
num_training_steps = epochs * num_update_steps_per_epoch

# Linear learning-rate schedule over all training steps, without warmup
# (weight decay is configured separately on AdamW above).
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

accelerator = Accelerator()

model, optimizer, training_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, training_dataloader, lr_scheduler
)

# Use the device Accelerate prepared the model on, so the validation batches
# that train_model moves manually end up on the same device.
device = accelerator.device

# start training
# NOTE(review): train_model does not read alpha/beta; the task weights are
# learned through model.log_sigma1 / model.log_sigma2.
train_model(model, training_dataloader, validation_dataloader, optimizer, lr_scheduler, accelerator, device, epochs, alpha=1, beta=7.5)

Epoch 1: 100%|██████████| 16/16 [00:44<00:00,  2.77s/it, seq_loss=1.12, token_loss=0.733]
100%|██████████| 4/4 [00:03<00:00,  1.23it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/14, Training Sequence Loss: 1.420078121125698, Training Token Loss: 1.07689917832613
	 Validation Sequence Loss: 1.201205462217331, Validation Token Loss: 0.6254969537258148
	 Validation Sequence F_1: 0.44054083343557027, Accuracy: 0.515
	 Validation Token F_1: 0.008810572687224669, Accuracy: 0.7961732660111613


Epoch 2: 100%|██████████| 16/16 [00:43<00:00,  2.71s/it, seq_loss=0.815, token_loss=0.339]
100%|██████████| 4/4 [00:03<00:00,  1.25it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 2/14, Training Sequence Loss: 1.076890666037798, Training Token Loss: 0.4318919759243727
	 Validation Sequence Loss: 0.8945852220058441, Validation Token Loss: 0.34693189710378647
	 Validation Sequence F_1: 0.6430362071649988, Accuracy: 0.695
	 Validation Token F_1: 0.5472636815920399, Accuracy: 0.9019399415360085


Epoch 3: 100%|██████████| 16/16 [00:43<00:00,  2.73s/it, seq_loss=0.6, token_loss=0.159]
100%|██████████| 4/4 [00:03<00:00,  1.23it/s]


Epoch 3/14, Training Sequence Loss: 0.7959059812128544, Training Token Loss: 0.22457178216427565
	 Validation Sequence Loss: 0.7253954857587814, Validation Token Loss: 0.21813618391752243
	 Validation Sequence F_1: 0.7058842058110888, Accuracy: 0.735
	 Validation Token F_1: 0.6952141057934509, Accuracy: 0.9431304809992027


Epoch 4: 100%|██████████| 16/16 [00:43<00:00,  2.73s/it, seq_loss=0.561, token_loss=0.1]
100%|██████████| 4/4 [00:03<00:00,  1.25it/s]


Epoch 4/14, Training Sequence Loss: 0.6369500085711479, Training Token Loss: 0.12317322427406907
	 Validation Sequence Loss: 0.5312941074371338, Validation Token Loss: 0.20217426493763924
	 Validation Sequence F_1: 0.7973087443432294, Accuracy: 0.805
	 Validation Token F_1: 0.7586206896551725, Accuracy: 0.9492426255647091


Epoch 5: 100%|██████████| 16/16 [00:43<00:00,  2.73s/it, seq_loss=0.25, token_loss=0.0572]
100%|██████████| 4/4 [00:03<00:00,  1.18it/s]


Epoch 5/14, Training Sequence Loss: 0.44866914488375187, Training Token Loss: 0.07945937383919954
	 Validation Sequence Loss: 0.5153326615691185, Validation Token Loss: 0.1460735034197569
	 Validation Sequence F_1: 0.7925536020714619, Accuracy: 0.795
	 Validation Token F_1: 0.8025806451612904, Accuracy: 0.9646558596864204


Epoch 6: 100%|██████████| 16/16 [00:44<00:00,  2.75s/it, seq_loss=0.183, token_loss=0.0457]
100%|██████████| 4/4 [00:03<00:00,  1.25it/s]


Epoch 6/14, Training Sequence Loss: 0.377736022695899, Training Token Loss: 0.05204859480727464
	 Validation Sequence Loss: 0.38438572362065315, Validation Token Loss: 0.16176610626280308
	 Validation Sequence F_1: 0.8718205404112865, Accuracy: 0.875
	 Validation Token F_1: 0.8167539267015708, Accuracy: 0.9670475684294446


Epoch 7: 100%|██████████| 16/16 [00:43<00:00,  2.74s/it, seq_loss=0.152, token_loss=0.0324]
100%|██████████| 4/4 [00:03<00:00,  1.25it/s]


Epoch 7/14, Training Sequence Loss: 0.2841995498165488, Training Token Loss: 0.036941458587534726
	 Validation Sequence Loss: 0.3669581972062588, Validation Token Loss: 0.13343499787151814
	 Validation Sequence F_1: 0.8575756703160995, Accuracy: 0.86
	 Validation Token F_1: 0.8270481144343304, Accuracy: 0.966781823013553


Epoch 8: 100%|██████████| 16/16 [00:43<00:00,  2.73s/it, seq_loss=0.0711, token_loss=0.0273]
100%|██████████| 4/4 [00:03<00:00,  1.25it/s]


Epoch 8/14, Training Sequence Loss: 0.21502772439271212, Training Token Loss: 0.030915238487068564
	 Validation Sequence Loss: 0.3694503456354141, Validation Token Loss: 0.12621749378740788
	 Validation Sequence F_1: 0.863428058765532, Accuracy: 0.865
	 Validation Token F_1: 0.8441558441558441, Accuracy: 0.9720967313313845


Epoch 9: 100%|██████████| 16/16 [00:43<00:00,  2.73s/it, seq_loss=0.0637, token_loss=0.0231]
100%|██████████| 4/4 [00:03<00:00,  1.26it/s]


Epoch 9/14, Training Sequence Loss: 0.18199307518079877, Training Token Loss: 0.02274439320899546
	 Validation Sequence Loss: 0.3297599144279957, Validation Token Loss: 0.13528584502637386
	 Validation Sequence F_1: 0.8777772260840782, Accuracy: 0.88
	 Validation Token F_1: 0.8426527958387515, Accuracy: 0.9691735317565772


Epoch 10: 100%|██████████| 16/16 [00:43<00:00,  2.73s/it, seq_loss=0.0632, token_loss=0.0213]
100%|██████████| 4/4 [00:03<00:00,  1.22it/s]


Epoch 10/14, Training Sequence Loss: 0.14171697618439794, Training Token Loss: 0.019768410187680274
	 Validation Sequence Loss: 0.37890270724892616, Validation Token Loss: 0.11835752241313457
	 Validation Sequence F_1: 0.8357888395351166, Accuracy: 0.84
	 Validation Token F_1: 0.8470588235294119, Accuracy: 0.9712994950837098


Epoch 11: 100%|██████████| 16/16 [00:43<00:00,  2.72s/it, seq_loss=0.0395, token_loss=0.02]
100%|██████████| 4/4 [00:03<00:00,  1.25it/s]


Epoch 11/14, Training Sequence Loss: 0.12068530544638634, Training Token Loss: 0.016409783391281962
	 Validation Sequence Loss: 0.39519762992858887, Validation Token Loss: 0.12327207252383232
	 Validation Sequence F_1: 0.8832088960312316, Accuracy: 0.885
	 Validation Token F_1: 0.8597640891218873, Accuracy: 0.9731597129949509


Epoch 12: 100%|██████████| 16/16 [00:43<00:00,  2.72s/it, seq_loss=0.0438, token_loss=0.0177]
100%|██████████| 4/4 [00:03<00:00,  1.23it/s]


Epoch 12/14, Training Sequence Loss: 0.10599148599430919, Training Token Loss: 0.01431423905887641
	 Validation Sequence Loss: 0.3199357446283102, Validation Token Loss: 0.11794183775782585
	 Validation Sequence F_1: 0.8811153552330023, Accuracy: 0.885
	 Validation Token F_1: 0.8341968911917099, Accuracy: 0.9702365134201435


Epoch 13: 100%|██████████| 16/16 [00:43<00:00,  2.73s/it, seq_loss=0.0585, token_loss=0.0173]
100%|██████████| 4/4 [00:03<00:00,  1.25it/s]


Epoch 13/14, Training Sequence Loss: 0.09256090852431953, Training Token Loss: 0.01303378288866952
	 Validation Sequence Loss: 0.31463278643786907, Validation Token Loss: 0.1389664225280285
	 Validation Sequence F_1: 0.896296712696389, Accuracy: 0.9
	 Validation Token F_1: 0.8229166666666667, Accuracy: 0.968110550093011


Epoch 14: 100%|██████████| 16/16 [00:43<00:00,  2.72s/it, seq_loss=0.0326, token_loss=0.0165]
100%|██████████| 4/4 [00:03<00:00,  1.24it/s]


Epoch 14/14, Training Sequence Loss: 0.08656661119312048, Training Token Loss: 0.013622149359434843
	 Validation Sequence Loss: 0.27501474507153034, Validation Token Loss: 0.15115107223391533
	 Validation Sequence F_1: 0.9062809164250882, Accuracy: 0.91
	 Validation Token F_1: 0.8322496749024707, Accuracy: 0.9683762955089025


In [None]:
# Switch to evaluation mode (disables dropout) before scoring the test set.
model.eval();
# seqeval metric used by compute_metrics below
metric = evaluate.load("seqeval")

# function to evaluate performance metrics (precision, recall, F1 score, accuracy) for given predictions and ground-truth labels
def compute_metrics(predictions, labels):
    """Score token-level predictions with the seqeval metric.

    Positions whose gold id is -100 (special tokens / padding) are dropped
    from both the references and the predictions. The remaining integer ids
    are mapped to tag names through the global ``label_names``.

    Args:
        predictions: per-sequence predicted tag ids.
        labels: per-sequence gold tag ids, with -100 marking ignored positions.

    Returns:
        Dict with the overall "precision", "recall", "f1" and "accuracy".
    """
    references = [[label_names[gold] for gold in row if gold != -100] for row in labels]

    hypotheses = []
    for pred_row, gold_row in zip(predictions, labels):
        kept = [label_names[pred] for pred, gold in zip(pred_row, gold_row) if gold != -100]
        hypotheses.append(kept)

    scores = metric.compute(predictions=hypotheses, references=references)
    return {
        "precision": scores["overall_precision"],
        "recall": scores["overall_recall"],
        "f1": scores["overall_f1"],
        "accuracy": scores["overall_accuracy"],
    }

# compute performance on test set
with torch.no_grad():
    total_seq_loss_test = 0
    total_token_loss_test = 0
    seq_pred_parts, seq_gold_parts = [], []
    token_pred_parts, token_gold_parts = [], []

    for batch in tqdm(test_dataloader, leave=True):
        input_ids, attention_mask, seq_labels, token_labels = [x.to(device) for x in batch]
        seq_logits, token_logits = model(input_ids, attention_mask)

        # pad (or crop) the token labels to the logits' sequence length with -100
        pad_width = (0, token_logits.size(1) - token_labels.size(1))

        seq_pred_parts.append(seq_logits.argmax(-1).cpu())
        seq_gold_parts.append(seq_labels.cpu())
        token_pred_parts.append(token_logits.argmax(-1).cpu())
        token_gold_parts.append(F.pad(token_labels.cpu(), pad_width, "constant", -100))

        seq_loss = F.cross_entropy(seq_logits, seq_labels)
        flat_logits = token_logits.view(-1, token_logits.size(-1))
        flat_labels = F.pad(token_labels, pad_width, "constant", -100).view(-1)
        token_loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

        total_seq_loss_test += seq_loss.item()
        total_token_loss_test += token_loss.item()

    # report average losses, then sequence- and token-level metrics
    print(f"\nTest \t Sequence Loss: {total_seq_loss_test/len(test_dataloader)}, Token Loss: {total_token_loss_test/len(test_dataloader)}")
    all_seq_preds = np.asarray(torch.cat(seq_pred_parts).numpy(), dtype=float)
    all_seq_labels = np.asarray(torch.cat(seq_gold_parts).numpy(), dtype=float)
    f1_seq = f1_score(all_seq_labels, all_seq_preds, average="weighted")
    acc_seq = accuracy_score(all_seq_labels, all_seq_preds)
    print(f"\t Sequence F_1: {f1_seq}, Accuracy: {acc_seq}")
    all_token_preds = torch.cat(token_pred_parts).numpy()
    all_token_labels = torch.cat(token_gold_parts).numpy()
    metrics_token = compute_metrics(all_token_preds, all_token_labels)
    print(f"\t Token F_1: {metrics_token['f1']}, Accuracy: {metrics_token['accuracy']}")

100%|██████████| 4/4 [00:03<00:00,  1.27it/s]



Test 	 Sequence Loss: 0.5435006022453308, Token Loss: 0.17799033969640732
	 Sequence F_1: 0.8779372528317265, Accuracy: 0.88
	 Token F_1: 0.834575260804769, Accuracy: 0.9630832841110455


In [None]:
# Test results recorded for lr = 5e-5, epochs = 14, alpha = 1, beta = 7.5:

Test 	 Sequence Loss: 0.525999091565609, Token Loss: 0.16994309797883034
	 Sequence F_1: 0.869773786585689, Accuracy: 0.87
	 Token F_1: 0.849772382397572, Accuracy: 0.9657412876550502