<a href="https://colab.research.google.com/github/cicl-iscl/LeWiDi_SemEval2023/blob/main/Notebooks/MTL/MTL_Bert_2heads.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers
!pip install wandb

# MTL approach
- Use 4 output nodes, 2 for hard labels, 2 for soft
- Cross Entropy error for both tasks
- Fine tune whole model - to share knowledge over tasks 

In [2]:
#TODO: Fit training to 2 loss functions

In [3]:
import wandb
wandb.login()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 

··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [4]:
from drive.MyDrive.cicl_data.code import read_data
# from drive.MyDrive.cicl_data.code import CustomLabelDataset

In [5]:
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler
# from datasets import Dataset
import torch.nn.functional as Fun
from torch.utils.data import Dataset, random_split, DataLoader
from torch.optim import AdamW
import torch.nn as nn

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [7]:
data_dict = read_data()
df_all = pd.concat([data_dict[k] for k in data_dict.keys()])

### Pretrained model

In [8]:
# Maybe load from wandb in future
tokenizer = AutoTokenizer.from_pretrained("lanwuwei/GigaBERT-v4-Arabic-and-English", do_lower_case=True)
# model_class = AutoModelForSequenceClassification.from_pretrained("lanwuwei/GigaBERT-v4-Arabic-and-English", num_labels=4)

Downloading:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/578 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/458k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/157 [00:00<?, ?B/s]

In [9]:
from transformers import BertModel
model_base = BertModel.from_pretrained("lanwuwei/GigaBERT-v4-Arabic-and-English", num_labels=4)
# output of model: https://huggingface.co/docs/transformers/main_classes/output#transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions

Downloading:   0%|          | 0.00/500M [00:00<?, ?B/s]

Some weights of the model checkpoint at lanwuwei/GigaBERT-v4-Arabic-and-English were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
class MTLModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.bert = model_base
    self.linear_hl = nn.Linear(768, 2)
    self.linear_sl = nn.Linear(768, 2)
    self.softmax = nn.Softmax
    # [8, 240, 768]

  def forward(self, input_ids, attention_mask, token_type_ids):
    """a linear layer on top of the pooled output (https://huggingface.co/transformers/v3.0.2/model_doc/bert.html#bertforsequenceclassification)"""
    x = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    x_hl = self.linear_hl(x.pooler_output)
    # x_hl = self.softmax(x_hl, axis=-1)

    x_sl = self.linear_sl(x.pooler_output)
    return x_hl, x_sl

In [11]:
mtl_model = MTLModel().to(device)

## Prepare Data

In [12]:
class CustomLabelDataset(Dataset):
    def __init__(self, df_all):
        self.text = list(map(self.tokenize_func, df_all["text"]))
        self.soft_labels = df_all["soft_list"] 
        self.hard_labels = df_all["hard_label"]
        self.hard_labels_1h = Fun.one_hot(torch.tensor(df_all['hard_label'].values))

    def __len__(self):
        return len(self.text)
      
    def tokenize_func(self, text):
        return tokenizer(text, padding="max_length", truncation=True, max_length=240)

    def __getitem__(self, idx):
        input = {"attention_mask": torch.tensor(self.text[idx]["attention_mask"]),
                 "token_type_ids": torch.tensor(self.text[idx]["token_type_ids"]),
                 "input_ids": torch.tensor(self.text[idx]["input_ids"])}
        return input, self.hard_labels_1h[idx], torch.tensor(self.soft_labels[idx]), torch.tensor(self.hard_labels[idx])

In [60]:
# Init dataset
dataset = CustomLabelDataset(df_all)
batch_size = 8

train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
train_dataset, eval_dataset = random_split(dataset, [train_size, test_size])

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True)

eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=batch_size)

## Optimization

In [14]:
# Optimizer
num_epochs = 4

num_training_steps = num_epochs * len(train_dataloader)
optimizer = AdamW(mtl_model.parameters())
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

In [15]:
# Loss
loss_hard = nn.CrossEntropyLoss()
loss_soft = nn.CrossEntropyLoss()

## Training

In [16]:
run = wandb.init(
    project="mtl-1l",
    config={
        "epochs": num_epochs,
        "batch_size": batch_size,
        "device": device
        },
      save_code = True,
      tags = ["bert_arabic_english", "mtl", "2_heads"],
      )
wandb.watch(mtl_model, log_freq=100)

[34m[1mwandb[0m: Currently logged in as: [33msheuschk[0m. Use [1m`wandb login --relogin`[0m to force relogin


[]

In [21]:
# Train
# num_epochs = wandb.config.epochs

for e in range(num_epochs):

  loss_batches = 0
  epoch_loss = 0
  epoch_len = len(train_dataloader)

  for i, batch in enumerate(train_dataloader):

    input_ids, attention_mask, token_type_ids = batch[0]["input_ids"].to(device), batch[0]["attention_mask"].to(device), batch[0]["token_type_ids"].to(device)
    hard_labels, soft_labels = batch[1].to(device), batch[2].to(device)

    pred_hl, pred_sl = mtl_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

    loss_hl = loss_hard(pred_hl, hard_labels.to(float))
    loss_sl = loss_soft(pred_sl, soft_labels)

    loss = loss_sl + loss_hl

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()

    loss_batches += loss.item()
    epoch_loss += loss.item()
    wandb.log({"train/loss": loss.item()})

    log_n_batches = 50
    if i % log_n_batches == 0:
      if i == 0:
        log_n_batches = 1
      print(f"{e+1}: Last {log_n_batches} batches avg loss: {loss_batches/log_n_batches:>7f}  [{i}/{epoch_len}]")
      wandb.log({"train/loss_over_batches": loss_batches/log_n_batches})
      wandb.log({"train/epochs": e})
      loss_batches = 0
  
  epoch_loss /= i
  print(f"Epoch loss: {epoch_loss:>7f}  [{e+1}/{num_epochs}]")
  wandb.log({"train/epoch_loss": epoch_loss})
  wandb.log({"train/epoch": e})



Last 1 batches avg loss: 6.565152  [0/1043.0]
Last 50 batches avg loss: 1.363249  [50/1043.0]
Last 50 batches avg loss: 1.248226  [100/1043.0]
Last 50 batches avg loss: 1.182887  [150/1043.0]
Last 50 batches avg loss: 1.154746  [200/1043.0]
Last 50 batches avg loss: 1.303767  [250/1043.0]
Last 50 batches avg loss: 1.239921  [300/1043.0]
Last 50 batches avg loss: 1.296868  [350/1043.0]
Last 50 batches avg loss: 1.189413  [400/1043.0]
Last 50 batches avg loss: 1.200521  [450/1043.0]
Last 50 batches avg loss: 1.197504  [500/1043.0]
Last 50 batches avg loss: 1.263680  [550/1043.0]
Last 50 batches avg loss: 1.264297  [600/1043.0]
Last 50 batches avg loss: 1.180400  [650/1043.0]
Last 50 batches avg loss: 1.209289  [700/1043.0]
Last 50 batches avg loss: 1.273286  [750/1043.0]
Last 50 batches avg loss: 1.175780  [800/1043.0]
Last 50 batches avg loss: 1.231676  [850/1043.0]
Last 50 batches avg loss: 1.174319  [900/1043.0]
Last 50 batches avg loss: 1.245590  [950/1043.0]
Last 50 batches avg loss

In [None]:
# Notes to improve (superficial):
# - counter for batches

Model dependent iprovements:
- Choose different loss?
  - CE should be good for 1hot and soft labels, as they represent distributions. But maybe KL Error, as in paper

- Is 1hot encoded hard labels really a benefit
  - As the tasks are seperate, maybe just keep it 0 and 1 and learn more regression style task

- Number of epochs?
  - @Dennis: For output layer it overfits after one epoch
  - But train the whole model, its harder to find sweet spot
  - * [ ] Add evaluation on onseen data at the end of epoch, to check generalization error while training

- 

## Evaluation

In [22]:
# from torcheval.metrics.functional import binary_f1_score
from sklearn.metrics import f1_score

In [50]:
# Eval
def cross_entropy(targets_soft, predictions_soft, epsilon = 1e-12):
  predictions = torch.clip(predictions_soft, epsilon, 1. - epsilon)
  ce = -torch.sum(targets_soft * torch.log(predictions + 1e-9)) # / predictions.shape[0]
  return ce


cross_error = 0
f1 = 0

for i, batch in enumerate(eval_dataloader):
  input_ids, attention_mask, token_type_ids = batch[0]["input_ids"].to(device), batch[0]["attention_mask"].to(device), batch[0]["token_type_ids"].to(device)
  _, soft_labels, hard_labels = batch[1].to(device), batch[2].to(device), batch[3].to(device)

  with torch.no_grad():
    pred_hl, pred_sl = mtl_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
  pred_sl = torch.softmax(pred_sl, axis=-1)
  cross_error += cross_entropy(soft_labels, pred_sl)
  f1 += f1_score(hard_labels.cpu(), pred_hl.argmax(1).cpu(), average='micro')

f1 /= i  # as not all batches are same size, this is not correct
cross_error /= test_size

print(f"F1 error: {f1}")
print(f"CE error: {cross_error}")
wandb.log({"eval/ce": cross_error})
wandb.log({"eval/f1": f1})



F1 error: 0.7424213353798925
CE error: 2.961099147796631


Error: ignored

In [39]:
from tqdm.notebook import tqdm
total = 0
ce = 0
epsilon = 1e-12

for i, batch in enumerate(tqdm(eval_dataloader, 0)):
  input_ids, attention_mask, token_type_ids = batch[0]["input_ids"].to(device), batch[0]["attention_mask"].to(device), batch[0]["token_type_ids"].to(device)
  _, soft_labels, hard_labels = batch[1].to(device), batch[2].to(device), batch[3].to(device)

  with torch.no_grad():
    pred_hl, pred_sl = mtl_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
  probabilities = torch.softmax(pred_sl, axis=-1)
  predictions = torch.clip(probabilities, epsilon, 1. - epsilon)                                      
  N = predictions.shape[0]
  total+= N #maybe should be 1 here as well
  ce += -torch.sum(soft_labels * torch.log(predictions + 1e-9))

ce = ce/total

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, batch in enumerate(tqdm_notebook(eval_dataloader, 0)):


  0%|          | 0/1304 [00:00<?, ?it/s]

In [40]:
ce

tensor(0.5924, device='cuda:0')

In [48]:
torch.softmax(pred_sl, axis=-1)

tensor([[0.7255, 0.2745],
        [0.7255, 0.2745],
        [0.7255, 0.2745],
        [0.7255, 0.2745],
        [0.7255, 0.2745],
        [0.7255, 0.2745],
        [0.7255, 0.2745]], device='cuda:0')

### Finish

In [None]:
# Save parameters
# artifact = wandb.Artifact(name='model_param', type='model')
# artifact.add_dir(local_path="classifier.pt")
# artifact.add_dir(local_path="bias.pt")
# artifact.add_file(local_path="model.pt")

# run.log_artifact(artifact)

In [None]:
torch.save(mtl_model.parameters, 'model.pt')
artifact = wandb.Artifact(name='model_param', type='model')
artifact.add_file(local_path="model.pt")
run.log_artifact(artifact)

In [27]:
wandb.finish()

0,1
eval/ce,▁
eval/f1,▁
train/epoch,▁▃▆█
train/epoch_loss,█▄▂▁
train/epochs,▁▁▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▃▆▆▆▆▆▆▆▆▆▆██████████
train/loss,▅▃▇▅▅▅▄█▆▅▆▇▂▅▆▃▄▄▄▁▄▆▅▅▇▃▅▆▅▆█▄▄▄▄▂▆▄▄▂
train/loss_over_batches,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval/ce,31.14395
eval/f1,0.74242
train/epoch,3.0
train/epoch_loss,1.16648
train/epochs,3.0
train/loss,1.42576
train/loss_over_batches,1.15365
