In [1]:
import torch
import spacy
import json
import numpy as np
from tqdm.notebook import tqdm
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit

from irproject.historical_events import (
    tag_paragraph, preprocess_data, load_data
)
from transformers import (
    BertTokenizerFast, BertForTokenClassification, BertForPreTraining, BertForMaskedLM,
    BertModel, BertForSequenceClassification, BertTokenizer, AdamW
)

spacy.prefer_gpu()

False

In [2]:
nlp = spacy.load("en_core_web_sm")
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
with open(
    f"datasets/historical_events/wiki_dataset.json",
    encoding="utf-8"
) as f_in:
    dataset = json.load(f_in)

In [4]:
texts, tags, labels = load_data(dataset, nlp)

In [5]:
train_texts, test_texts, train_tags, test_tags, train_labels, test_labels = train_test_split(
    texts, tags, labels, test_size=0.2, random_state=42, stratify=labels
)

In [6]:
train_texts, valid_texts, train_tags, valid_tags, train_labels, valid_labels = train_test_split(
    train_texts, train_tags, train_labels, test_size=0.2, random_state=42, stratify=train_labels
)

In [7]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

In [8]:
encodings, tokens_labels, labels, tag2idx, idx2tag = preprocess_data(
    train_texts.tolist(), train_tags.tolist(), 
    train_labels.tolist(), tokenizer, padding="max_length"
)

HBox(children=(HTML(value='Adjusting tags to encodings'), FloatProgress(value=1.0, bar_style='info', layout=La…

In [10]:
torch.save(encodings.input_ids, "datasets/historical_events/train/input_ids.pkl")
torch.save(encodings.attention_mask, "datasets/historical_events/train/attention_mask.pkl")
torch.save(tokens_labels, "datasets/historical_events/train/tokens_labels.pkl")
torch.save(labels, "datasets/historical_events/train/labels.pkl")
torch.save(tag2idx, "datasets/historical_events/train/tag2idx.pkl")
torch.save(idx2tag, "datasets/historical_events/train/idx2tag.pkl")

In [12]:
input_ids = torch.load("datasets/historical_events/train/input_ids.pkl")
tokens_labels = torch.load("datasets/historical_events/train/tokens_labels.pkl")

In [None]:
train_data = TensorDataset(
    torch.tensor(encodings.input_ids, device=device), 
    torch.tensor(encodings.attention_mask, device=device), 
    torch.tensor(tokens_labels, device=device),
    torch.tensor(labels, device=device, dtype=torch.float32)
)
train_loader = DataLoader(
    train_data, shuffle=True, batch_size=1
)

In [86]:
train_loader = DataLoader(
    train_data, shuffle=True, batch_size=128
)

In [87]:
torch.save(train_loader, "datasets/historical_events/train.pkl")

In [3]:
train_loader = torch.load("datasets/historical_events/train.pkl")
valid_loader = torch.load("datasets/historical_events/valid.pkl")

In [11]:
encodings, tokens_labels, labels, tag2idx, idx2tag = preprocess_data(
    valid_texts.tolist(), valid_tags.tolist(), 
    valid_labels.tolist(), tokenizer, padding="max_length"
)

HBox(children=(HTML(value='Adjusting tags to encodings'), FloatProgress(value=1.0, bar_style='info', layout=La…

In [12]:
torch.save(encodings.input_ids, "datasets/historical_events/valid/input_ids.pkl")
torch.save(encodings.attention_mask, "datasets/historical_events/valid/attention_mask.pkl")
torch.save(tokens_labels, "datasets/historical_events/valid/tokens_labels.pkl")
torch.save(labels, "datasets/historical_events/valid/labels.pkl")
torch.save(tag2idx, "datasets/historical_events/valid/tag2idx.pkl")
torch.save(idx2tag, "datasets/historical_events/valid/idx2tag.pkl")

In [None]:
valid_data = TensorDataset(
    torch.tensor(encodings.input_ids, device=device), 
    torch.tensor(encodings.attention_mask, device=device), 
    torch.tensor(tokens_labels, device=device),
    torch.tensor(labels, device=device, dtype=torch.float32)
)

In [88]:
valid_loader = DataLoader(
    valid_data, shuffle=True, batch_size=128
)

In [89]:
torch.save(valid_loader, "datasets/historical_events/valid.pkl")

In [13]:
encodings, tokens_labels, labels, tag2idx, idx2tag = preprocess_data(
    test_texts.tolist(), test_tags.tolist(), 
    test_labels.tolist(), tokenizer, padding="max_length"
)


HBox(children=(HTML(value='Adjusting tags to encodings'), FloatProgress(value=1.0, bar_style='info', layout=La…

In [14]:
torch.save(encodings.input_ids, "datasets/historical_events/test/input_ids.pkl")
torch.save(encodings.attention_mask, "datasets/historical_events/test/attention_mask.pkl")
torch.save(tokens_labels, "datasets/historical_events/test/tokens_labels.pkl")
torch.save(labels, "datasets/historical_events/test/labels.pkl")
torch.save(tag2idx, "datasets/historical_events/test/tag2idx.pkl")
torch.save(idx2tag, "datasets/historical_events/test/idx2tag.pkl")

In [None]:
test_data = TensorDataset(
    torch.tensor(encodings.input_ids, device=device), 
    torch.tensor(encodings.attention_mask, device=device), 
    torch.tensor(tokens_labels, device=device),
    torch.tensor(labels, device=device, dtype=torch.float32)
)

In [90]:
test_loader = DataLoader(
    test_data, shuffle=True, batch_size=128
)

In [91]:
torch.save(test_loader, "datasets/historical_events/test.pkl")

In [4]:
bert = BertModel.from_pretrained("bert-base-cased")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
import torch.nn.functional as F
from torch import nn

class MultiTaskLearningModel(nn.Module):

    def __init__(self, base_model, dropout_rates, hidden_size, num_labels):
        super(MultiTaskLearningModel, self).__init__()
        self.base_model = base_model

        # We could avoid sigmoid here and use the
        # BCEWithLogitsLoss, which computes both the sigmoid
        # and the BCE with a trick for numerical stability.
        self.seq_clf = nn.Sequential(
            nn.Dropout(p=dropout_rates[0]),
            nn.Linear(in_features=768, out_features=hidden_size),
            nn.ReLU(),
            nn.Dropout(p=dropout_rates[1]),
            nn.Linear(in_features=hidden_size, out_features=1),
            nn.Sigmoid()
        )

        self.tokens_clf = nn.Sequential(
            nn.Dropout(p=dropout_rates[2]),
            nn.Linear(hidden_size, num_labels)
            # nn.LogSoftmax(dim=1)
        )

        for p in self.seq_clf.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        for p in self.tokens_clf.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        
    def forward(self, input_ids, attention_mask):
        output = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        last_hidden_state = output.last_hidden_state

        # TODO Since we are using mean, maybe it's better not to pad to max?
        seq_clf_out = self.seq_clf(torch.mean(last_hidden_state, dim=1))
        tokens_clf_out = self.tokens_clf(last_hidden_state)

        return seq_clf_out, tokens_clf_out

    def freeze_base(self):
        for param in self.base_model.named_parameters():
            param[1].requires_grad=False

    def unfreeze_base(self):
        for param in self.base_model.named_parameters():
            param[1].requires_grad=True


In [63]:
from torch.nn import BCELoss, CrossEntropyLoss
from torch import nn

class MultiTaskLoss(nn.Module):
    """Uncertainty weighted loss for multi-task learning
    proposed in Kendall et al., Multi-Task Learning Using 
    Uncertainty to Weigh Losses for Scene Geometry and 
    Semantics, arXiv:1705.07115v3.

    Possible alternative implementations are available
    at:
        - https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example-pytorch.ipynb
        - https://github.com/lorenmt/mtan/blob/master/im2im_pred/utils.py
    """

    def __init__(self, losses_num: int = 2, num_tokens_labels: int = 5):
        super(MultiTaskLoss, self).__init__()
        self.losses_num = losses_num
        self.num_tokens_labels = num_tokens_labels
        self.log_vars = nn.Parameter(torch.zeros((losses_num)))

    def forward(self, seq_clf_out, tokens_clf_out, labels, tokens_labels, attention_mask):

        loss_ce = CrossEntropyLoss()
        loss_bce = BCELoss()

        if attention_mask is not None:
            active_loss = attention_mask.view(-1) == 1
            active_logits = tokens_clf_out.view(-1, self.num_tokens_labels)
            active_labels = torch.where(
                active_loss, 
                tokens_labels.view(-1), 
                torch.tensor(loss_ce.ignore_index).type_as(tokens_labels)
            )
            loss0 = loss_ce(active_logits, active_labels)
        else:
            loss0 = loss_ce(
                tokens_clf_out.view(-1, self.num_tokens_labels), 
                tokens_labels.view(-1)
            )

        loss1 = loss_bce(seq_clf_out.view(-1), labels.view(-1))

        losses = [loss0, loss1]

        loss = sum(
            torch.exp(-self.log_vars[i]) * losses[i] + (self.log_vars[i] / 2)
            for i in range(self.losses_num)
        )
        
        return loss

In [8]:
# WE NEED TO SET -100 on all "O"

In [10]:
num_epochs = 3
max_grad_norm = 1.0

In [11]:


def tensor_accuracy(y_true, y_preds):

    # The model outputs probabilities for each token in the
    # sentence. So the output for the tokens will be of size
    # (#batches, #tokens, #classes). Method torch.max returns
    # the max value and the index of the max value. So we can get
    # the class predicted for each token in the batch.

    # _, y_preds = torch.max(y_preds_probas, -1)
    y_correct = (y_preds == y_true).sum().detach()
    acc = y_correct / y_true.size(0)

    return acc
    

def tensor_binary_accuracy(y_true, y_preds_probas):
    y_preds = (y_preds_probas > 0.5).view(1, -1)
    y_correct = (y_preds == y_true).sum().detach()
    acc = y_correct / y_true.size(0)

    return acc

    preds_batch_np = np.round(probas.cpu().detach().numpy())
    y_batch_np = y_batch.cpu().detach().numpy()
    acc = accuracy_score(y_true=y_batch_np, y_pred=preds_batch_np)
    f1 = f1_score(y_true=y_batch_np, y_pred=preds_batch_np, average='weighted')
    return acc, f1

In [14]:
from torch.optim import lr_scheduler
from sklearn.metrics import accuracy_score, f1_score

model = MultiTaskLearningModel(bert, [0.1, 0.3, 0.1], 768, 6)
model.to(device)
model.train()

optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, patience=2)
criterion = MultiTaskLoss(2, 6).to(device)

train_loss_values, validation_loss_values = [], []
freeze_epochs = [0, 1, 2]
unfreeze_epochs = [3, 4]

for epoch in tqdm(
    range(num_epochs),
    desc="Epoch",
    leave=False
):
    print("=== Epoch: ", epoch, " / ", num_epochs, " ===")
    print("LR:", optimizer.state_dict()["param_groups"][0]["lr"])
    if epoch in freeze_epochs:
        print("Freeze base model")
        model.freeze_base()
            
    if epoch in unfreeze_epochs:
        print("Unfreeze base model")
        model.unfreeze_base()

    train_loss = 0
    valid_loss = 0
    best_f1 = 0
    history = []

    # Seqc = classification for the whole sequence.
    # Tokc = token classification (NER).
    train_seqc_labels, train_seqc_preds = [], []
    train_tokc_labels, train_tokc_preds = [], []

    valid_seqc_labels, valid_seqc_preds = [], []
    valid_tokc_labels, valid_tokc_preds = [], []

    # Training.
    model.train()
    for step, batch in tqdm(
        enumerate(train_loader),
        desc="Training step",
        leave=False
    ):
        if step == 3:
            break

        model.zero_grad()

        input_ids, attention_mask, tokens_labels, labels = batch

        seqc_out, tokc_out = model(input_ids, attention_mask)
        loss = criterion(seqc_out, tokc_out, labels, tokens_labels, attention_mask)
        loss.backward()
        train_loss += loss.detach()

        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(),
            max_norm=max_grad_norm
        )

        optimizer.step()

        train_seqc_labels.extend(labels)
        train_tokc_labels.extend(tokens_labels)
        train_seqc_preds.extend(seqc_out)
        train_tokc_preds.extend(torch.argmax(tokc_out, dim=-1).view(-1))

    train_seqc_labels = torch.tensor(train_seqc_labels, device=device)
    train_tokc_labels = torch.cat(train_tokc_labels)
    train_seqc_preds = (torch.tensor(train_seqc_preds, device=device) > 0.5)
    train_tokc_preds = torch.tensor(train_tokc_preds, device=device)

    # Validation.
    model.eval()
    for step, batch in tqdm(
        enumerate(valid_loader),
        desc="Validation step",
        leave=False
    ):
        if step == 3:
            break

        input_ids, attention_mask, tokens_labels, labels = batch

        with torch.no_grad():
            seqc_out, tokc_out = model(input_ids, attention_mask)
            loss = criterion(seqc_out, tokc_out, labels, tokens_labels, attention_mask)
            valid_loss += loss.detach()

            valid_seqc_labels.extend(labels)
            valid_tokc_labels.extend(tokens_labels)
            valid_seqc_preds.extend(seqc_out)
            valid_tokc_preds.extend(torch.argmax(tokc_out, dim=-1).view(-1))

    valid_seqc_labels = torch.tensor(valid_seqc_labels, device=device)
    valid_tokc_labels = torch.cat(valid_tokc_labels)
    valid_seqc_preds = (torch.tensor(valid_seqc_preds, device=device) > 0.5)
    valid_tokc_preds = torch.tensor(valid_tokc_preds, device=device)

    train_avg_loss = train_loss / len(train_loader)
    valid_avg_loss = valid_loss / len(valid_loader)
    train_seqc_acc = tensor_binary_accuracy(
        train_seqc_labels, train_seqc_preds
    )
    train_tokc_acc = tensor_accuracy(train_tokc_labels, train_tokc_preds)
    train_seqc_f1 = f1_score(train_seqc_labels.cpu(), train_seqc_preds.cpu())
    train_tokc_f1 = f1_score(train_tokc_labels.cpu(), train_tokc_preds.cpu(), average="weighted")

    valid_seqc_acc = tensor_binary_accuracy(
        valid_seqc_labels, valid_seqc_preds
    )
    valid_tokc_acc = tensor_accuracy(valid_tokc_labels, valid_tokc_preds)
    valid_seqc_f1 = f1_score(valid_seqc_labels.cpu(), valid_seqc_preds.cpu())
    valid_tokc_f1 = f1_score(valid_tokc_labels.cpu(), valid_tokc_preds.cpu(), average="weighted")

    print("Train avg loss", train_avg_loss)
    print("Valid avg loss", valid_avg_loss)
    print("Train seqc accuracy:", train_seqc_acc)
    print("Train tokc accuracy:", train_tokc_acc)
    print("Train seqc f1-score:", train_seqc_f1)
    print("Train tokc f1-score:", train_tokc_f1)
    print("Valid seqc accuracy:", valid_seqc_acc)
    print("Valid tokc accuracy:", valid_tokc_acc)
    print("Valid seqc f1-score:", valid_seqc_f1)
    print("Valid tokc f1-score:", valid_tokc_f1)
    if valid_seqc_f1 > best_f1:
        best_f1 = valid_seqc_f1
        torch.save(model, "best_mtl_model.pt")

    history.append(
        {
            "train_seqc_acc": train_seqc_acc, 
            "train_tokc_acc": train_tokc_acc,
            "train_seqc_f1": train_seqc_f1, 
            "train_tokc_f1": train_tokc_f1, 
            "valid_seqc_acc": valid_seqc_acc, 
            "valid_tokc_acc": valid_tokc_acc,
            "valid_seqc_f1": valid_seqc_f1, 
            "valid_tokc_f1": valid_tokc_f1, 
        }
    )

    scheduler.step(valid_avg_loss)
    break


HBox(children=(HTML(value='Epoch'), FloatProgress(value=0.0, max=3.0), HTML(value='')))

=== Epoch:  0  /  3  ===
LR: 2e-05
Freeze base model


HBox(children=(HTML(value='Training step'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20…

HBox(children=(HTML(value='Validation step'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='…

Train avg loss tensor(0.0380, device='cuda:0')
Valid avg loss tensor(0.1503, device='cuda:0')
Train seqc accuracy: tensor(0.6120, device='cuda:0')
Train tokc accuracy: tensor(0.0056, device='cuda:0')
Train seqc f1-score: 0.2512562814070351
Train tokc f1-score: 0.0005162976379357603
Valid seqc accuracy: tensor(0.5729, device='cuda:0')
Valid tokc accuracy: tensor(0.0064, device='cuda:0')
Valid seqc f1-score: 0.0
Valid tokc f1-score: 0.0006589554572518321


In [159]:
train_seqc_labels

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0.,
        1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.,
        0., 1., 1., 1., 1., 1., 0., 1., 

In [160]:
train_seqc_preds

tensor([0.3689, 0.1840, 0.5953, 0.5562, 0.3302, 0.3291, 0.3773, 0.2022, 0.4021,
        0.6044, 0.7476, 0.4082, 0.5453, 0.4340, 0.7084, 0.6595, 0.6811, 0.7641,
        0.7952, 0.7075, 0.8425, 0.5931, 0.8157, 0.7328, 0.6031, 0.8657, 0.7726,
        0.8644, 0.8393, 0.7330, 0.7311, 0.8303, 0.7542, 0.7391, 0.8119, 0.8245,
        0.8378, 0.8597, 0.7686, 0.8109, 0.8461, 0.9022, 0.9361, 0.8982, 0.7767,
        0.9475, 0.8419, 0.8587, 0.8516, 0.8036, 0.9068, 0.7920, 0.8010, 0.7871,
        0.9416, 0.8206, 0.7776, 0.8981, 0.8875, 0.8977, 0.8617, 0.9382, 0.9478,
        0.9169, 0.8798, 0.9268, 0.9399, 0.8918, 0.9196, 0.9167, 0.9597, 0.9372,
        0.8534, 0.8761, 0.9129, 0.8624, 0.9420, 0.8994, 0.8853, 0.9656, 0.9527,
        0.9619, 0.9463, 0.9455, 0.9462, 0.8733, 0.8719, 0.8756, 0.8060, 0.9374,
        0.9641, 0.9117, 0.9266, 0.8690, 0.8798, 0.8781, 0.9080, 0.9262, 0.9679,
        0.9401, 0.9263, 0.9388, 0.9554, 0.9245, 0.8737, 0.9601, 0.9595, 0.9547,
        0.9087, 0.9168, 0.9068, 0.9027, 

In [144]:
from pytorch_lightning.core.lightning import LightningModule
from torchmetrics import Accuracy, F1
from torch.optim import lr_scheduler

class MultiTaskLearningModel(LightningModule):
    def __init__(
        self, base_model = None, dropout_rate: float = 0.1, 
        hidden_size: int = 768, 
        num_tokens_labels: int = 5
    ):
        super(MultiTaskLearningModel, self).__init__()
        if base_model is None:
            self.base_model = BertModel.from_pretrained("bert-base-cased")
        else:
            self.base_model = base_model

        self.num_tokens_labels = num_tokens_labels

        # We could avoid sigmoid here and use the
        # BCEWithLogitsLoss, which computes both the sigmoid
        # and the BCE with a trick for numerical stability.
        self.seq_clf = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(in_features=768, out_features=hidden_size),
            nn.ReLU(),
            nn.Linear(in_features=hidden_size, out_features=1),
            nn.Sigmoid()
        )

        self.tokens_clf = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_size, num_tokens_labels)
            # nn.LogSoftmax(dim=1)
        )

        self.multi_loss = MultiTaskLoss(2, num_tokens_labels)
        self.seqc_accuracy = Accuracy()
        self.tokc_accuracy = Accuracy(
            ignore_index=(num_tokens_labels - 1)
        )
        self.seqc_f1 = F1()
        self.tokc_f1 = F1(
            average="macro", 
            num_classes=num_tokens_labels, 
            ignore_index=(num_tokens_labels - 1)
        )

        for param in self.seq_clf.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

        for param in self.tokens_clf.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
        
    def forward(self, input_ids, attention_mask):
        output = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        last_hidden_state = output.last_hidden_state

        # TODO Since we are using mean, maybe it's better not to pad to max?
        seq_clf_out = self.seq_clf(torch.mean(last_hidden_state, dim=1))
        tokens_clf_out = self.tokens_clf(last_hidden_state)

        return seq_clf_out, tokens_clf_out

    def configure_optimizers(self):
        optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, patience=2)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "monitor": "val_loss"
            }
        }

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, tokens_labels, labels = batch
        seqc_out, tokc_out = self(input_ids, attention_mask)
        loss = self.multi_loss(seqc_out, tokc_out, labels, tokens_labels, attention_mask)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, tokens_labels, labels = batch
        seqc_out, tokc_out = self(input_ids, attention_mask)
        loss = self.multi_loss(seqc_out, tokc_out, labels, tokens_labels, attention_mask)

        tokens_labels = torch.where(
            tokens_labels == -100,
            self.num_tokens_labels - 1, 
            tokens_labels
        )

        torch.save(tokc_out, "tokc_out.pkl")
        torch.save(tokens_labels, "tokens_labels.pkl")

        tokc_preds = torch.argmax(tokc_out, -1)

        #seqc_acc = self.seqc_accuracy(seqc_out, labels.int())
        tokc_acc = self.tokc_accuracy(tokc_preds, tokens_labels)
        #seqc_f1 = self.seqc_f1(seqc_out, labels.int())
        tokc_f1 = self.tokc_f1(tokc_preds.view(-1), tokens_labels.view(-1))

        self.log("val_loss", loss, prog_bar=True)
        #self.log("val_seqc_acc", seqc_acc, prog_bar=True, on_step=True, on_epoch=True)
        self.log("val_tokc_acc", tokc_acc, prog_bar=True, on_step=True, on_epoch=True)
        #self.log("val_seqc_f1", seqc_f1, prog_bar=True, on_step=True, on_epoch=True)
        self.log("val_tokc_f1", tokc_f1, prog_bar=True, on_step=True, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        input_ids, attention_mask, tokens_labels, labels = batch
        seqc_out, tokc_out = self(input_ids, attention_mask)
        loss = self.multi_loss(seqc_out, tokc_out, labels, tokens_labels, attention_mask)
        return loss

    def freeze_base(self):
        for param in self.base_model.named_parameters():
            param[1].requires_grad=False

    def unfreeze_base(self):
        for param in self.base_model.named_parameters():
            param[1].requires_grad=True

In [85]:
from pytorch_lightning import LightningDataModule

class WikiDataModule(LightningDataModule):

    def __init__(
        self, data_dir: str = "datasets/historical_events",
        batch_size: int = 128, num_workers: int = 0, 
        num_tokens_labels: int = 5
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.dims = 512
        self.num_tokens_labels = num_tokens_labels

        self.train_dict, self.valid_dict, self.test_dict = {}, {}, {}

    def data_load_util(self, data_split: str):
        return {
            "input_ids": torch.load(f"{self.data_dir}/{data_split}/input_ids.pkl"),
            "attention_mask": torch.load(f"{self.data_dir}/{data_split}/attention_mask.pkl"),
            "tokens_labels": torch.load(f"{self.data_dir}/{data_split}/tokens_labels.pkl"),
            "labels": torch.load(f"{self.data_dir}/{data_split}/labels.pkl"),
            "tag2idx": torch.load(f"{self.data_dir}/{data_split}/tag2idx.pkl"),
            "idx2tag": torch.load(f"{self.data_dir}/{data_split}/idx2tag.pkl")   
        }

    def prepare_data(self):
        self.train_dict = self.data_load_util("train")
        self.valid_dict = self.data_load_util("valid")
        self.test_dict = self.data_load_util("test")

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            self.wiki_train = TensorDataset(
                torch.tensor(self.train_dict["input_ids"]), 
                torch.tensor(self.train_dict["attention_mask"]),
                torch.tensor(self.train_dict["tokens_labels"]),
                torch.tensor(self.train_dict["labels"], dtype=torch.float32)
            )
            self.wiki_valid = TensorDataset(
                torch.tensor(self.valid_dict["input_ids"]), 
                torch.tensor(self.valid_dict["attention_mask"]),
                torch.tensor(self.valid_dict["tokens_labels"]),
                torch.tensor(self.valid_dict["labels"], dtype=torch.float32)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.wiki_test = TensorDataset(
                torch.tensor(self.test_dict["input_ids"]), 
                torch.tensor(self.test_dict["attention_mask"]),
                torch.tensor(self.test_dict["tokens_labels"]),
                torch.tensor(self.test_dict["labels"], dtype=torch.float32)
            )

    def train_dataloader(self):
        return DataLoader(
            self.wiki_train, batch_size=self.batch_size, 
            num_workers=self.num_workers
        )

    def val_dataloader(self):
        return DataLoader(
            self.wiki_valid, batch_size=self.batch_size,
            num_workers=self.num_workers
        )

    def test_dataloader(self):
        return DataLoader(
            self.wiki_test, batch_size=self.batch_size,
            num_workers=self.num_workers
        )

In [130]:
from torchmetrics import Accuracy, F1
target = torch.tensor([0, 1, 2, 1])
preds = torch.tensor([
    [0.8, 0.1, 0.1], 
    [0.1, 0.8, 0.1],
    [0.1, 0.1, 0.8],
    [0.1, 0.8, 0.1]
])
accuracy = Accuracy()
f1 = F1()
accuracy(torch.argmax(preds, -1), target), f1(preds, target)

(tensor(1.), tensor(1.))

In [51]:
from torchmetrics import Accuracy, F1
target = torch.tensor([0, 1, 2, -100])
target = map(lambda x: torch.tensor(3) if x == -100 else x, target)
target = torch.tensor(list(target))
preds = torch.tensor([
    [0.8, 0.1, 0.1, 0], 
    [0.1, 0.8, 0.1, 0],
    [0.1, 0.1, 0.8, 0],
    [0.1, 0.8, 0.1, 0.1]
])
accuracy = Accuracy(ignore_index=3)
f1 = F1(average="macro", num_classes=4, ignore_index=3)
accuracy(preds, target), f1(preds, target)

(tensor(1.), tensor(0.8889))

In [54]:
from torchmetrics import Accuracy, F1
target = torch.tensor([0, 1, 1, 0])
target = map(lambda x: torch.tensor(3) if x == -100 else x, target)
target = torch.tensor(list(target))
preds = torch.tensor([0.1, 0.8, 0.8, 0.1])
accuracy = Accuracy()
f1 = F1()
accuracy(preds, target), f1(preds, target)

(tensor(1.), tensor(1.))

In [110]:
b = torch.tensor([
    [0, 1, 2],
    [0, 2, -100],
    [-100, 2, 0]
])
b

tensor([[   0,    1,    2],
        [   0,    2, -100],
        [-100,    2,    0]])

In [104]:
c = torch.empty(b.shape).fill_(3)
c

tensor([[3., 3., 3.],
        [3., 3., 3.],
        [3., 3., 3.]])

In [112]:
torch.where(b == -100, 3, b)

tensor([[0, 1, 2],
        [0, 2, 3],
        [3, 2, 0]])

In [15]:
valid_idx2tag = torch.load("datasets/historical_events/valid/idx2tag.pkl")
train_idx2tag = torch.load("datasets/historical_events/train/idx2tag.pkl")
test_idx2tag = torch.load("datasets/historical_events/test/idx2tag.pkl")

In [16]:
test_idx2tag, valid_idx2tag, train_idx2tag

({0: 'B-hist', 1: 'B-not-hist', 2: 'I-hist', 3: 'I-not-hist', -100: 'MASK'},
 {0: 'B-hist', 1: 'B-not-hist', 2: 'I-hist', 3: 'I-not-hist', -100: 'MASK'},
 {0: 'B-hist', 1: 'B-not-hist', 2: 'I-hist', 3: 'I-not-hist', -100: 'MASK'})

In [78]:
a = torch.tensor(0.1)
a

tensor(0.1000)

In [83]:
a = a.int()
a

tensor(0, dtype=torch.int32)

In [145]:
from pytorch_lightning import Trainer

dm = WikiDataModule()
model = MultiTaskLearningModel()
trainer = Trainer()
trainer.fit(model, dm)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(





  | Name          | Type          | Params
------------------------------------------------
0 | base_model    | BertModel     | 108 M 
1 | seq_clf       | Sequential    | 591 K 
2 | tokens_clf    | Sequential    | 3.8 K 
3 | multi_loss    | MultiTaskLoss | 2     
4 | seqc_accuracy | Accuracy      | 0     
5 | tokc_accuracy | Accuracy      | 0     
6 | seqc_f1       | F1            | 0     
7 | tokc_f1       | F1            | 0     
------------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.622   Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


tokc_shape torch.Size([128, 512, 5]) tensor([[[-0.9654, -0.8255, -0.1501,  0.4268, -0.3652],
         [-1.2315, -0.2665,  0.0677,  0.5974,  0.2710],
         [ 0.6101, -0.0409,  0.4318,  1.9059,  0.6774],
         ...,
         [ 0.2781,  0.2673, -0.7821,  0.3398,  0.0379],
         [ 0.1401,  0.2081, -0.8498,  0.2469, -0.0970],
         [ 0.1687,  0.3741, -0.8724,  0.2815, -0.3217]],

        [[-0.9416, -0.9621, -0.1640,  0.7597, -0.4565],
         [-1.3807, -0.4784,  0.1636,  0.4127, -0.0963],
         [-0.1872, -0.4662,  0.2936,  0.5350,  0.2583],
         ...,
         [ 0.1968,  0.3323,  0.2181,  0.6491, -0.2826],
         [-0.2214,  0.5467, -0.1261,  0.4248, -0.4629],
         [-0.0019,  0.4358, -0.0523,  0.5346, -0.3377]],

        [[-1.0146, -0.5266, -0.5399,  0.3677, -0.3474],
         [-1.0747, -0.6257, -0.4775,  0.3459, -0.2881],
         [-0.4422,  0.4452,  0.1168,  0.3480, -0.3033],
         ...,
         [-0.0811,  0.6340, -0.1366,  0.5074,  0.2369],
         [-0.0967,  0

In [133]:
from torchmetrics import Accuracy, F1

tokc_out = torch.load("tokc_out.pkl")
tokens_labels = torch.load("tokens_labels.pkl")

In [127]:
tokc_out.shape, tokens_labels.shape

(torch.Size([128, 512, 5]), torch.Size([128, 512]))

In [140]:
torch.argmax(tokc_out, -1).shape, tokens_labels.shape

(torch.Size([128, 512]), torch.Size([128, 512]))

In [141]:
acc = Accuracy()
f1 = F1()
f1(torch.argmax(tokc_out, -1).view(-1), tokens_labels.view(-1))

tensor(0.5858)

In [143]:
targets = torch.tensor([
    [1, 2, 0],
    [0, 1, 1],
    [0, 0, 2]   
]
)

preds = torch.tensor([
    [1, 2, 0],
    [0, 1, 1],
    [0, 0, 2]   
])

f1(preds.view(-1), targets.view(-1))

tensor(1.)