In [1]:
%%capture
!pip install transformers
!pip install wandb

# MTL approach
- 1 output neuron for hl and 1 for sl
- Use Soft and Hard Labels (Alternate: Add hard loss every 2nd batch)
- Hard:
  - Encoding: ~~1_hot~~ / One label 
  - Loss: ~~Cross_entropy~~/ ~~Negative_Log_Like~~/ BCE_Loss
- Soft
  - Loss: ~~Cross_entropy~~ / ~~KL_div~~ / BCE_Loss

- Model output layer:
  - 1 Layer with 2 outputs (loss with [:1] & [1:])~~
  - ~~2 seperate headers~~
  - ~~1 header and alternating loss (different loss)~~

- Layers to train
 - ~~ Whole model~~
 - output
 - Specific number of last layers: 3

- Regularization:
  - ~~Dropout~~



**Results**: 
- https://wandb.ai/capture_disagreement/MTL/runs/1mxyze9u/overview?workspace=user-sheuschk
- Codalab: CE Average ~0.43

In [None]:
#TODO: Fit training to 2 loss functions

In [2]:
import wandb
wandb.login()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 

··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
from drive.MyDrive.cicl_data.helpers import read_data
# from drive.MyDrive.cicl_data.code import CustomLabelDataset

In [4]:
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler
# from datasets import Dataset
import torch.nn.functional as Fun
from torch.utils.data import Dataset, random_split, DataLoader
from torch.optim import AdamW
import torch.nn as nn

from tqdm.notebook import tqdm


In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
data_dict = read_data()
df_all = pd.concat([data_dict[k] for k in data_dict.keys()])

In [7]:
def extract_soft_labels(row):
  return row[1]

In [8]:
df_all["sl_1s"] = df_all["soft_list"].apply(extract_soft_labels)

### Pretrained model

In [9]:
# Maybe load from wandb in future
tokenizer = AutoTokenizer.from_pretrained("lanwuwei/GigaBERT-v4-Arabic-and-English", do_lower_case=True)

Downloading:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/578 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/458k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/157 [00:00<?, ?B/s]

In [10]:
from transformers import BertModel
base_model = BertModel.from_pretrained("lanwuwei/GigaBERT-v4-Arabic-and-English")
# class_model = AutoModelForSequenceClassification.from_pretrained("lanwuwei/GigaBERT-v4-Arabic-and-English", num_labels=2)

# output of model: https://huggingface.co/docs/transformers/main_classes/output#transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions

Downloading:   0%|          | 0.00/500M [00:00<?, ?B/s]

Some weights of the model checkpoint at lanwuwei/GigaBERT-v4-Arabic-and-English were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
class MTLModel(nn.Module):
  def __init__(self, base_model):
    super().__init__()
    self.bert = base_model
    self.linear = nn.Linear(768, 2)
    self.sigmoid = nn.Sigmoid()

  def forward(self, input_ids, attention_mask, token_type_ids):
    """a linear layer on top of the pooled output (https://huggingface.co/transformers/v3.0.2/model_doc/bert.html#bertforsequenceclassification)"""

    x = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    x = self.linear(x.pooler_output)
    x = self.sigmoid(x)
    return x[:,0].to(dtype=torch.float64), x[:,1].to(dtype=torch.float64)

In [12]:
model = MTLModel(base_model).to(device)

In [13]:
# Deactivate weights for backprob
for name, param in model.named_parameters():
  if "bert.embeddings" in name:
    param.requires_grad = False
  for i in range(0, 9):  # BERT Model has 11 Bert Layers
    if str(i) in name:
      param.requires_grad = False


## Prepare Data

In [None]:
# from drive.MyDrive.cicl_data.helpers import CustomLabelDataset
# return: input, hard_labels_1hot, soft_labels, hard_labels

In [14]:
class CustomLabelDataset(Dataset):
    def __init__(self, df_all):
        self.text = list(map(self.tokenize_func, df_all["text"]))
        self.soft_labels = df_all["soft_list"] 
        self.hard_labels = df_all["hard_label"]
        self.hard_labels_1h = Fun.one_hot(torch.tensor(df_all['hard_label'].values))
        self.soft_labels_1s = df_all["sl_1s"] # 0.33 of soft labels like {"1": 0.33, "0": 0.67}

    def __len__(self):
        return len(self.text)
      
    def tokenize_func(self, text):
        return tokenizer(text, padding="max_length", truncation=True, max_length=240)

    def __getitem__(self, idx):
        input = {"attention_mask": torch.tensor(self.text[idx]["attention_mask"]),
                 "token_type_ids": torch.tensor(self.text[idx]["token_type_ids"]),
                 "input_ids": torch.tensor(self.text[idx]["input_ids"])}
        return input, self.hard_labels_1h[idx], torch.tensor(self.soft_labels[idx]), torch.tensor(self.hard_labels[idx]), torch.tensor(self.soft_labels_1s[idx])


In [15]:
# Init dataset
dataset = CustomLabelDataset(df_all)
batch_size = 4

train_size = len(dataset)

train_dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True)


In [18]:
# initialize Evaluation dataset
data_dict_dev = read_data("dev")
df_dev = pd.concat([data_dict_dev[k] for k in data_dict_dev.keys()])

df_dev["sl_1s"] = df_dev["soft_list"].apply(extract_soft_labels)

dev_dataset = CustomLabelDataset(df_dev)
dev_batch_size = 4
dev_size = len(dev_dataset)

dev_dataloader = DataLoader(
    dev_dataset,
    batch_size=dev_batch_size)

## Optimization

In [19]:
# Optimizer
num_epochs = 10

num_training_steps = num_epochs * len(train_dataloader)
optimizer = AdamW(model.parameters())
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

In [20]:
# Loss
bce_loss = nn.BCELoss()

## Training

In [54]:
run = wandb.init(
    project="MTL",
    config={
        "epochs": num_epochs,
        "batch_size": batch_size,
        "device": device,
        "Run": "Fine tune transformer"
        },
      save_code = True,
      tags = ["bert_arabic_english", "MTL", "3_last_layers", "1_head", "2_Neurons", "BCE_Loss"]
      )
wandb.watch(model, log_freq=100)

[34m[1mwandb[0m: Currently logged in as: [33msheuschk[0m ([33mcapture_disagreement[0m). Use [1m`wandb login --relogin`[0m to force relogin


[]

In [52]:
# from drive.MyDrive.cicl_data.helpers import ce_eval_func
def ce_eval_func(model, eval_dataloader, eval_size, epsilon=1e-12, device="cuda"):
  model.eval()
  cross_error = 0

  for i, batch in enumerate(tqdm(eval_dataloader, 0)):
    input_ids = batch[0]["input_ids"].to(device, dtype = torch.long)
    attention_mask = batch[0]["attention_mask"].to(device, dtype = torch.long)
    token_type_ids = batch[0]["token_type_ids"].to(device, dtype = torch.long)
    soft_labels = batch[2].to(device)

    with torch.no_grad():
      _, pred = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    pred = pred.reshape(len(pred), 1)
    probabilities = torch.cat((1-pred, pred), dim=-1)
    predictions = torch.clip(probabilities, epsilon, 1. - epsilon)
    cross_error += -torch.sum(soft_labels * torch.log(predictions + 1e-9))

  return cross_error / eval_size



In [53]:
ce_eval_func(model, dev_dataloader, dev_size, device=device)

  0%|          | 0/557 [00:00<?, ?it/s]

tensor(0.7122, device='cuda:0', dtype=torch.float64)

CE before Training: 0.7122

In [55]:
# Train
last_ce = 10
smallest_ce = 10

for e in range(num_epochs):
  model.train()
  loss_batches = 0
  epoch_loss = 0
  epoch_len = len(train_dataloader)

  eval_counter = False

  for i, batch in enumerate(train_dataloader):
    input_ids = batch[0]["input_ids"].to(device, dtype=torch.long)
    attention_mask = batch[0]["attention_mask"].to(device, dtype=torch.long)
    token_type_ids = batch[0]["token_type_ids"].to(device, dtype=torch.long)
    soft_labels_1 = batch[4].to(device, dtype=torch.float64)
    hard_label = batch[3].to(device, dtype=torch.float64)

    # predict
    optimizer.zero_grad()
    pred_hl, pred_sl = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

    # Loss
    loss = bce_loss(pred_sl, soft_labels_1)
    if i % 2 == 0:
      loss += bce_loss(pred_hl, hard_label)

    loss.backward()
    optimizer.step()
    lr_scheduler.step()

    # Log
    loss_batches += loss.item()
    epoch_loss += loss.item()

    log_n_batches = 200
    if i % log_n_batches == 0:
      if i != 0:
        print(f"{e+1}: Last {log_n_batches} batches avg loss: {loss_batches/log_n_batches:>7f}  [{i}/{epoch_len}]")
        wandb.log({"train/loss_over_batches": loss_batches/log_n_batches})
        wandb.log({"train/epochs": e})
      loss_batches = 0
  
  epoch_loss /= i  # Not completely correct (Loss per batch but not every batch has same size)
  print(f"Epoch [{e+1}/{num_epochs}] mean loss: {epoch_loss:>6f}")
  wandb.log({"train/epoch_loss": epoch_loss})

  # Eval error
  ce = ce_eval_func(model, dev_dataloader, dev_size, device=device)
  print(f"Epoch [{e+1}/{num_epochs}] Eval CE  : {ce:>6f}")
  wandb.log({"eval/epoch_ce": ce})

  # Stop after Eval CE raises 2 times in a row (Simple early stopping)
  if ce > last_ce:
    if eval_counter is True:
      print("Interrupt: Eval Error is raising")
      break;
    eval_counter = True
  elif ce < smallest_ce:
    torch.save(model.state_dict(), 'model.pt')
    eval_counter = False
    smallest_ce = ce
  else:
    eval_counter = False
  
  last_ce = ce



1: Last 200 batches avg loss: 0.948368  [200/2608]
1: Last 200 batches avg loss: 0.909511  [400/2608]
1: Last 200 batches avg loss: 0.876303  [600/2608]
1: Last 200 batches avg loss: 0.827733  [800/2608]
1: Last 200 batches avg loss: 0.846425  [1000/2608]
1: Last 200 batches avg loss: 0.840334  [1200/2608]
1: Last 200 batches avg loss: 0.874573  [1400/2608]
1: Last 200 batches avg loss: 0.770077  [1600/2608]
1: Last 200 batches avg loss: 0.865336  [1800/2608]
1: Last 200 batches avg loss: 0.770354  [2000/2608]
1: Last 200 batches avg loss: 0.861430  [2200/2608]
1: Last 200 batches avg loss: 0.738943  [2400/2608]
1: Last 200 batches avg loss: 0.740746  [2600/2608]
Epoch [1/10] mean loss: 0.836461


  0%|          | 0/557 [00:00<?, ?it/s]

Epoch [1/10] Eval CE  : 0.473010
2: Last 200 batches avg loss: 0.767667  [200/2608]
2: Last 200 batches avg loss: 0.758180  [400/2608]
2: Last 200 batches avg loss: 0.774382  [600/2608]
2: Last 200 batches avg loss: 0.785679  [800/2608]
2: Last 200 batches avg loss: 0.714244  [1000/2608]
2: Last 200 batches avg loss: 0.725501  [1200/2608]
2: Last 200 batches avg loss: 0.738744  [1400/2608]
2: Last 200 batches avg loss: 0.760331  [1600/2608]
2: Last 200 batches avg loss: 0.787470  [1800/2608]
2: Last 200 batches avg loss: 0.705587  [2000/2608]
2: Last 200 batches avg loss: 0.738554  [2200/2608]
2: Last 200 batches avg loss: 0.742831  [2400/2608]
2: Last 200 batches avg loss: 0.734456  [2600/2608]
Epoch [2/10] mean loss: 0.749041


  0%|          | 0/557 [00:00<?, ?it/s]

Epoch [2/10] Eval CE  : 0.471134
3: Last 200 batches avg loss: 0.665461  [200/2608]
3: Last 200 batches avg loss: 0.730899  [400/2608]
3: Last 200 batches avg loss: 0.735598  [600/2608]
3: Last 200 batches avg loss: 0.701214  [800/2608]
3: Last 200 batches avg loss: 0.709773  [1000/2608]
3: Last 200 batches avg loss: 0.761297  [1200/2608]
3: Last 200 batches avg loss: 0.695142  [1400/2608]
3: Last 200 batches avg loss: 0.782070  [1600/2608]
3: Last 200 batches avg loss: 0.710887  [1800/2608]
3: Last 200 batches avg loss: 0.692325  [2000/2608]
3: Last 200 batches avg loss: 0.735591  [2200/2608]
3: Last 200 batches avg loss: 0.707999  [2400/2608]
3: Last 200 batches avg loss: 0.732890  [2600/2608]
Epoch [3/10] mean loss: 0.720801


  0%|          | 0/557 [00:00<?, ?it/s]

Epoch [3/10] Eval CE  : 0.466182
4: Last 200 batches avg loss: 0.693232  [200/2608]
4: Last 200 batches avg loss: 0.700020  [400/2608]
4: Last 200 batches avg loss: 0.674966  [600/2608]
4: Last 200 batches avg loss: 0.662925  [800/2608]
4: Last 200 batches avg loss: 0.707368  [1000/2608]
4: Last 200 batches avg loss: 0.702544  [1200/2608]
4: Last 200 batches avg loss: 0.694884  [1400/2608]
4: Last 200 batches avg loss: 0.652637  [1600/2608]
4: Last 200 batches avg loss: 0.633196  [1800/2608]
4: Last 200 batches avg loss: 0.684950  [2000/2608]
4: Last 200 batches avg loss: 0.718163  [2200/2608]
4: Last 200 batches avg loss: 0.693939  [2400/2608]
4: Last 200 batches avg loss: 0.714901  [2600/2608]
Epoch [4/10] mean loss: 0.687351


  0%|          | 0/557 [00:00<?, ?it/s]

Epoch [4/10] Eval CE  : 0.436345
5: Last 200 batches avg loss: 0.664598  [200/2608]
5: Last 200 batches avg loss: 0.660834  [400/2608]
5: Last 200 batches avg loss: 0.666818  [600/2608]
5: Last 200 batches avg loss: 0.630307  [800/2608]
5: Last 200 batches avg loss: 0.652617  [1000/2608]
5: Last 200 batches avg loss: 0.611367  [1200/2608]
5: Last 200 batches avg loss: 0.673986  [1400/2608]
5: Last 200 batches avg loss: 0.703544  [1600/2608]
5: Last 200 batches avg loss: 0.609465  [1800/2608]
5: Last 200 batches avg loss: 0.678497  [2000/2608]
5: Last 200 batches avg loss: 0.639217  [2200/2608]
5: Last 200 batches avg loss: 0.651493  [2400/2608]
5: Last 200 batches avg loss: 0.709024  [2600/2608]
Epoch [5/10] mean loss: 0.657184


  0%|          | 0/557 [00:00<?, ?it/s]

Epoch [5/10] Eval CE  : 0.449067
6: Last 200 batches avg loss: 0.584579  [200/2608]
6: Last 200 batches avg loss: 0.669171  [400/2608]
6: Last 200 batches avg loss: 0.633315  [600/2608]
6: Last 200 batches avg loss: 0.623668  [800/2608]
6: Last 200 batches avg loss: 0.587126  [1000/2608]
6: Last 200 batches avg loss: 0.663677  [1200/2608]
6: Last 200 batches avg loss: 0.638324  [1400/2608]
6: Last 200 batches avg loss: 0.669152  [1600/2608]
6: Last 200 batches avg loss: 0.630231  [1800/2608]
6: Last 200 batches avg loss: 0.611019  [2000/2608]
6: Last 200 batches avg loss: 0.619998  [2200/2608]
6: Last 200 batches avg loss: 0.647421  [2400/2608]
6: Last 200 batches avg loss: 0.639130  [2600/2608]
Epoch [6/10] mean loss: 0.633009


  0%|          | 0/557 [00:00<?, ?it/s]

Epoch [6/10] Eval CE  : 0.426625
7: Last 200 batches avg loss: 0.574778  [200/2608]
7: Last 200 batches avg loss: 0.606541  [400/2608]
7: Last 200 batches avg loss: 0.583655  [600/2608]
7: Last 200 batches avg loss: 0.625756  [800/2608]
7: Last 200 batches avg loss: 0.623919  [1000/2608]
7: Last 200 batches avg loss: 0.600113  [1200/2608]
7: Last 200 batches avg loss: 0.647340  [1400/2608]
7: Last 200 batches avg loss: 0.607402  [1600/2608]
7: Last 200 batches avg loss: 0.605404  [1800/2608]
7: Last 200 batches avg loss: 0.587983  [2000/2608]
7: Last 200 batches avg loss: 0.580345  [2200/2608]
7: Last 200 batches avg loss: 0.598003  [2400/2608]
7: Last 200 batches avg loss: 0.590923  [2600/2608]
Epoch [7/10] mean loss: 0.602676


  0%|          | 0/557 [00:00<?, ?it/s]

Epoch [7/10] Eval CE  : 0.424363
8: Last 200 batches avg loss: 0.554471  [200/2608]
8: Last 200 batches avg loss: 0.591779  [400/2608]
8: Last 200 batches avg loss: 0.578492  [600/2608]
8: Last 200 batches avg loss: 0.580108  [800/2608]
8: Last 200 batches avg loss: 0.581063  [1000/2608]
8: Last 200 batches avg loss: 0.526981  [1200/2608]
8: Last 200 batches avg loss: 0.636331  [1400/2608]
8: Last 200 batches avg loss: 0.594544  [1600/2608]
8: Last 200 batches avg loss: 0.589190  [1800/2608]
8: Last 200 batches avg loss: 0.638781  [2000/2608]
8: Last 200 batches avg loss: 0.601068  [2200/2608]
8: Last 200 batches avg loss: 0.561530  [2400/2608]
8: Last 200 batches avg loss: 0.560250  [2600/2608]
Epoch [8/10] mean loss: 0.584471


  0%|          | 0/557 [00:00<?, ?it/s]

Epoch [8/10] Eval CE  : 0.420388
9: Last 200 batches avg loss: 0.549006  [200/2608]
9: Last 200 batches avg loss: 0.567034  [400/2608]
9: Last 200 batches avg loss: 0.616433  [600/2608]
9: Last 200 batches avg loss: 0.566993  [800/2608]
9: Last 200 batches avg loss: 0.554144  [1000/2608]
9: Last 200 batches avg loss: 0.597888  [1200/2608]
9: Last 200 batches avg loss: 0.510797  [1400/2608]
9: Last 200 batches avg loss: 0.568511  [1600/2608]
9: Last 200 batches avg loss: 0.552991  [1800/2608]
9: Last 200 batches avg loss: 0.578847  [2000/2608]
9: Last 200 batches avg loss: 0.551563  [2200/2608]
9: Last 200 batches avg loss: 0.577245  [2400/2608]
9: Last 200 batches avg loss: 0.599034  [2600/2608]
Epoch [9/10] mean loss: 0.568288


  0%|          | 0/557 [00:00<?, ?it/s]

Epoch [9/10] Eval CE  : 0.427228
10: Last 200 batches avg loss: 0.531172  [200/2608]
10: Last 200 batches avg loss: 0.532908  [400/2608]
10: Last 200 batches avg loss: 0.588315  [600/2608]
10: Last 200 batches avg loss: 0.548708  [800/2608]
10: Last 200 batches avg loss: 0.528386  [1000/2608]
10: Last 200 batches avg loss: 0.556531  [1200/2608]
10: Last 200 batches avg loss: 0.538802  [1400/2608]
10: Last 200 batches avg loss: 0.546768  [1600/2608]
10: Last 200 batches avg loss: 0.554375  [1800/2608]
10: Last 200 batches avg loss: 0.527474  [2000/2608]
10: Last 200 batches avg loss: 0.568435  [2200/2608]
10: Last 200 batches avg loss: 0.542919  [2400/2608]
10: Last 200 batches avg loss: 0.542142  [2600/2608]
Epoch [10/10] mean loss: 0.547531


  0%|          | 0/557 [00:00<?, ?it/s]

Epoch [10/10] Eval CE  : 0.425183


Model dependent improvements:
- Balance between hard and soft loss
- Maybe first just hl, than just soft afterwards (only one output neuron, or two for distribution)



## Evaluation

In [56]:
# Final Cross Entropy Error
cross_error = ce_eval_func(model, dev_dataloader, dev_size, device=device)
print(f"CE error: {cross_error}")
wandb.log({"dev/ce": cross_error})

  0%|          | 0/557 [00:00<?, ?it/s]

CE error: 0.42518259605304065


In [58]:
model_best = MTLModel(base_model)
model_best.load_state_dict(torch.load('model.pt'))
model_best = model_best.to(device)

In [59]:
# Final Cross Entropy Error
cross_error = ce_eval_func(model_best, dev_dataloader, dev_size, device=device)
print(f"CE error: {cross_error}")
wandb.log({"dev/ce": cross_error})

  0%|          | 0/557 [00:00<?, ?it/s]

CE error: 0.42518259605304065


In [60]:
# from drive.MyDrive.cicl_data.helpers import f1_eval_func
from sklearn.metrics import f1_score

def f1_eval_func(model, eval_dataloader, eval_size, device="cuda"):
  model.eval()
  f1_error = 0

  for i, batch in enumerate(tqdm(eval_dataloader, 0)):
    input_ids = batch[0]["input_ids"].to(device, dtype = torch.long)
    attention_mask = batch[0]["attention_mask"].to(device, dtype = torch.long)
    token_type_ids = batch[0]["token_type_ids"].to(device, dtype = torch.long)
    hard_labels = batch[3].to(device)

    with torch.no_grad():
      pred, _ = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    f1_error += f1_score(hard_labels.cpu(), round(pred).cpu(), average='micro')

  return f1_error / eval_size

In [61]:
# F1 micro Error
f1 = f1_eval_func(model, dev_dataloader, dev_size, device)
print(f"F1 error: {f1}")
wandb.log({"dev/f1": f1})

  0%|          | 0/557 [00:00<?, ?it/s]

TypeError: ignored

In [62]:
model.eval()
cross_error = 0
epsilon = 1e-12
for i, batch in enumerate(dev_dataloader):
  input_ids = batch[0]["input_ids"].to(device, dtype = torch.long)
  attention_mask = batch[0]["attention_mask"].to(device, dtype = torch.long)
  token_type_ids = batch[0]["token_type_ids"].to(device, dtype = torch.long)
  soft_labels = batch[2].to(device)

  with torch.no_grad():
    _, pred = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
  pred = pred.reshape(len(pred), 1)
  probabilities = torch.cat((1-pred, pred), dim=-1)
  predictions = torch.clip(probabilities, epsilon, 1. - epsilon)
  cross_error += -torch.sum(soft_labels * torch.log(predictions + 1e-9))
  break



In [63]:
print(predictions)
print(soft_labels)

tensor([[0.6692, 0.3308],
        [0.0753, 0.9247],
        [0.7356, 0.2644],
        [0.4792, 0.5208]], device='cuda:0', dtype=torch.float64)
tensor([[0.6700, 0.3300],
        [0.0000, 1.0000],
        [0.3300, 0.6700],
        [0.3300, 0.6700]], device='cuda:0')


### Finish

In [64]:
torch.save(model.state_dict(), 'model.pt')
# model.load_state_dict(torch.load(PATH), strict=False)
artifact = wandb.Artifact(name='model_param', type='model')
artifact.add_file(local_path="model.pt")
run.log_artifact(artifact);

In [65]:
wandb.finish()

0,1
dev/ce,▁▁
eval/epoch_ce,██▇▃▅▂▂▁▂▂
train/epoch_loss,█▆▅▄▄▃▂▂▂▁
train/epochs,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
train/loss_over_batches,█▆▅▅▅▄▅▅▅▄▆▄▄▄▄▄▃▃▄▃▂▂▃▃▂▂▂▂▁▂▂▂▂▂▂▂▂▁▁▁

0,1
dev/ce,0.42518
eval/epoch_ce,0.42518
train/epoch_loss,0.54753
train/epochs,9.0
train/loss_over_batches,0.54214


## TSV files

In [66]:
import os
import csv

In [68]:
filepaths = ["/content/ArMIS_results.tsv", "/content/ConvAbuse_results.tsv", "/content/HS-Brexit_results.tsv", "/content/MD-Agreement_results.tsv"]
epsilon = 1e-12

for fp in filepaths:
  if os.path.exists(fp):
    os.remove(fp)

for key in data_dict_dev.keys():
  data_dict_dev[key]["sl_1s"] = data_dict_dev[key]["soft_list"].apply(extract_soft_labels)
  tsv_dataset = CustomLabelDataset(data_dict_dev[key])
  tsv_dataloader = DataLoader(tsv_dataset, shuffle=False, batch_size=1)
  filepath_write = f"/content/{key}_results.tsv"

  with open(filepath_write, 'w', newline='') as tsvfile:
      writer = csv.writer(tsvfile, delimiter='\t', lineterminator='\n')
      for i, batch in enumerate(tqdm(tsv_dataloader, 0)):
        input_ids = batch[0]["input_ids"].to(device, dtype = torch.long)
        attention_mask = batch[0]["attention_mask"].to(device, dtype = torch.long)
        token_type_ids = batch[0]["token_type_ids"].to(device, dtype = torch.long)

        with torch.no_grad():
          _ , pred = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # logits = pred.logits
        # probability = torch.softmax(pred, axis=-1)
        pred = pred.reshape(len(pred), 1)
        probability = torch.cat((1-pred, pred), dim=-1)
        prediction = torch.argmax(probability, dim=-1)
        probability = torch.clip(probability, epsilon, 1. - epsilon) # Really necessary?
        writer.writerow([prediction[0].item(), probability[0][0].item(), probability[0][1].item()])


  0%|          | 0/141 [00:00<?, ?it/s]

  0%|          | 0/1104 [00:00<?, ?it/s]

  0%|          | 0/812 [00:00<?, ?it/s]

  0%|          | 0/168 [00:00<?, ?it/s]

In [69]:
from zipfile import ZipFile

filepath = "res.zip" 

if os.path.exists(filepath):
    os.remove(filepath)

#loop over filepath names throws an string index out of range for whatever reason(also can't use content here, not sure why)
with ZipFile(filepath, 'w') as zipObj:
  zipObj.write("MD-Agreement_results.tsv")
  zipObj.write("ArMIS_results.tsv")
  zipObj.write("HS-Brexit_results.tsv")
  zipObj.write("ConvAbuse_results.tsv")

In [70]:
from google.colab import files
files.download("res.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>