In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
# The vocabulary is unwrapped with .item(), i.e. it is stored as a pickled
# Python object inside a 0-d array, which is why allow_pickle=True is required.
# NOTE(review): allow_pickle executes pickled code on load -- only use trusted files.
VOCAB = np.load("data/vocab_glove_100d.npy", allow_pickle=True).item()
# Pretrained embedding matrix; its shape is used below to size the embedding layer.
EMBEDDING_MATRIX = np.load("data/embedding_matrix_glove_100d.npy", allow_pickle=True)

In [4]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch import nn
import pytorch_lightning as pl
import torchmetrics
import torch.nn.functional as F

from relation_modeling_utils import MaxPool, AvgPool

class SWEMMultiClassifier(pl.LightningModule):
    """Multi-label classifier: pretrained embeddings -> pooling -> linear layer.

    The network is ``nn.Sequential(embedding, pool, linear)`` where the
    embedding is initialised from the module-level ``EMBEDDING_MATRIX`` and the
    pooling layer is ``MaxPool`` or ``AvgPool`` from ``relation_modeling_utils``
    (presumably pooling over the token dimension -- confirm in that module).
    Training uses ``BCEWithLogitsLoss``, so every class is an independent
    binary decision and targets are multi-hot vectors.

    Args:
        num_classes: Number of output logits; also passed to the
            precision/recall/F1 metrics.
        pooling: ``"max"`` selects ``MaxPool``; any other value selects
            ``AvgPool``.
        freeze_emb: Forwarded to ``nn.Embedding.from_pretrained(freeze=...)``;
            ``False`` fine-tunes the embedding weights.
        learning_rate: Learning rate handed to Adam.
    """

    def __init__(self, num_classes=3, pooling="max", freeze_emb=True, learning_rate=1e-3):
        super().__init__()
        # from_pretrained is a classmethod, so call it on the class directly.
        # Instantiating nn.Embedding first would allocate and randomly
        # initialise a full-size weight matrix only to throw it away.
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(EMBEDDING_MATRIX, dtype=torch.float32), freeze=freeze_emb
        )
        self.pool = MaxPool() if pooling == "max" else AvgPool()
        self.linear = nn.Linear(EMBEDDING_MATRIX.shape[1], num_classes)
        self.model = nn.Sequential(self.embedding, self.pool, self.linear)
        self.criterion = nn.BCEWithLogitsLoss()
        self.learning_rate = learning_rate
        # Separate metric objects per stage so train/val statistics never mix.
        # The class count follows the constructor argument instead of a literal 3.
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()
        self.train_precision = torchmetrics.Precision(num_classes=num_classes, average='weighted')
        self.val_precision = torchmetrics.Precision(num_classes=num_classes, average='weighted')
        self.train_recall = torchmetrics.Recall(num_classes=num_classes, average='weighted')
        self.val_recall = torchmetrics.Recall(num_classes=num_classes, average='weighted')
        self.train_f1 = torchmetrics.F1Score(num_classes=num_classes, average='weighted')
        self.val_f1 = torchmetrics.F1Score(num_classes=num_classes, average='weighted')
        self.save_hyperparameters()

    def forward(self, X):
        """Return per-class probabilities (sigmoid of the logits) for ``X``."""
        return torch.sigmoid(self.model(X))

    def _shared_step(self, batch, stage):
        """Compute the loss and update/log the metrics of ``stage``.

        ``stage`` is ``"train"`` or ``"val"`` and selects both the metric
        attributes (``<stage>_accuracy`` ...) and the logged key names.
        """
        X, y = batch
        # Single forward pass: the logits feed the loss, and their sigmoid
        # feeds the metrics.
        logits = self.model(X)
        loss = self.criterion(logits, y.float())
        # Metrics need no gradients, so detach before computing probabilities.
        probs = torch.sigmoid(logits.detach())
        self.log(f"{stage}_loss", loss, on_epoch=True)
        for metric_name in ("accuracy", "precision", "recall", "f1"):
            metric = getattr(self, f"{stage}_{metric_name}")
            metric(probs, y)
            self.log(f"{stage}_{metric_name}", metric, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch; returns the BCE-with-logits loss."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; returns the BCE-with-logits loss."""
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.learning_rate``."""
        return Adam(self.parameters(), lr=self.learning_rate)

In [7]:
# from relation_modeling_utils import load_data, HeadDataset

# train_df = load_data("data/atomic2020_data-feb2021/train.tsv", multi_label=True)
# dev_df = load_data("data/atomic2020_data-feb2021/dev.tsv", multi_label=True)
# train_data = HeadDataset(train_df, vocab=VOCAB)
# val_data = HeadDataset(dev_df, vocab=VOCAB)

In [8]:
# torch.save(train_data, "data/head_train_multi_with_pad.pt")
# torch.save(val_data, "data/head_val_multi_with_pad.pt")

In [7]:
EMBEDDING_MATRIX.shape

(400002, 100)

In [4]:
# Load the dataset objects that the commented-out cells above serialised with
# torch.save (building them from the TSV files is skipped on re-runs).
train_data = torch.load("data/head_train_multi_with_pad.pt")
val_data = torch.load("data/head_val_multi_with_pad.pt")

In [5]:
# Mini-batches of 128 examples; only the training loader shuffles.
train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=128)

In [7]:
from pytorch_lightning.loggers import WandbLogger
import wandb

# Fine-tune the embeddings (freeze_emb=False) with max pooling for 20 epochs,
# logging to the "kogito-relation-matcher" Weights & Biases project.
wandb_logger = WandbLogger(project="kogito-relation-matcher", name="swem_multi_label_finetune_max")
model = SWEMMultiClassifier(pooling="max", freeze_emb=False, learning_rate=1e-4)
trainer = pl.Trainer(max_epochs=20, logger=wandb_logger, accelerator="gpu", devices=1)
try:
    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
finally:
    # Close the W&B run even when fit() fails or is interrupted, so the run
    # is not left open for later cells in the same kernel.
    wandb.finish()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]



   | Name            | Type              | Params
-------------------------------------------------------
0  | embedding       | Embedding         | 40.0 M
1  | pool            | MaxPool           | 0     
2  | linear          | Linear            | 303   
3  | model           | Sequential        | 40.0 M
4  | criterion       | BCEWithLogitsLoss | 0     
5  | train_accuracy  | Accuracy          | 0     
6  | val_accuracy    | Accuracy          | 0     
7  | train_precision | Precision         | 0     
8  | val_precision   | Precision         | 0     
9  | train_recall    | Recall            | 0     
10 | val_recall      | Recall            | 0     
11 | train_f1        | F1Score           | 0     
12 | val_f1          | F1Score           | 0     
-------------------------------------------------------
40.0 M    Trainable params
0         Non-trainable params
40.0 M    Total params
160.002   Total estimated model params size (MB)


                                                              

  rank_zero_warn(
  rank_zero_warn(


Epoch 19: 100%|██████████| 313/313 [00:09<00:00, 31.38it/s, loss=0.197, v_num=3hsv]



0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
train_accuracy_epoch,▁▂▄▅▆▇▇▇████████████
train_accuracy_step,▁▁▂▃▃▃▄▆▆▆▇▇▇▇▇▇██▇██▇█▇███▇▇████▇███▇█▇
train_f1_epoch,▁▃▄▆▆▇▇█████████████
train_f1_step,▁▂▂▄▄▄▄▇▆▇▇▇██▇▇███████████▇█████▇█████▇
train_loss_epoch,█▇▇▆▅▄▄▃▃▂▂▂▂▂▁▁▁▁▁▁
train_loss_step,██▇▇▇▇▆▅▅▅▄▄▃▃▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▂▁▁▂▁▃
train_precision_epoch,▁▃▄▅▆▇▇▇████████████
train_precision_step,▁▃▃▄▄▄▄▇▆▇▇▇▇▇▇▇██▇███▇██▇▇▇▇███▇▇▇▇█▇█▆
train_recall_epoch,▁▄▅▆▆▇▇▇▇███████████

0,1
epoch,19.0
train_accuracy_epoch,0.92917
train_accuracy_step,0.86979
train_f1_epoch,0.93067
train_f1_step,0.87121
train_loss_epoch,0.20589
train_loss_step,0.31938
train_precision_epoch,0.90333
train_precision_step,0.84809
train_recall_epoch,0.961


In [8]:
from pytorch_lightning.loggers import WandbLogger
import wandb

# Same setup as the max-pooling run but with average pooling. This rebinds
# `wandb_logger`, `model` and `trainer`, so later cells see this model.
wandb_logger = WandbLogger(project="kogito-relation-matcher", name="swem_multi_label_finetune")
model = SWEMMultiClassifier(pooling="avg", freeze_emb=False, learning_rate=1e-4)
trainer = pl.Trainer(max_epochs=20, logger=wandb_logger, accelerator="gpu", devices=1)
try:
    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
finally:
    # Close the W&B run even when fit() fails or is interrupted, so the run
    # is not left open for later cells in the same kernel.
    wandb.finish()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]



   | Name            | Type              | Params
-------------------------------------------------------
0  | embedding       | Embedding         | 40.0 M
1  | pool            | AvgPool           | 0     
2  | linear          | Linear            | 303   
3  | model           | Sequential        | 40.0 M
4  | criterion       | BCEWithLogitsLoss | 0     
5  | train_accuracy  | Accuracy          | 0     
6  | val_accuracy    | Accuracy          | 0     
7  | train_precision | Precision         | 0     
8  | val_precision   | Precision         | 0     
9  | train_recall    | Recall            | 0     
10 | val_recall      | Recall            | 0     
11 | train_f1        | F1Score           | 0     
12 | val_f1          | F1Score           | 0     
-------------------------------------------------------
40.0 M    Trainable params
0         Non-trainable params
40.0 M    Total params
160.002   Total estimated model params size (MB)


                                                              

  rank_zero_warn(
  rank_zero_warn(


Epoch 19: 100%|██████████| 313/313 [00:09<00:00, 31.90it/s, loss=0.254, v_num=gvzf]



0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
train_accuracy_epoch,▁▄▆▇▇▇▇▇▇███████████
train_accuracy_step,▁▄▃▄▆▇▆▇▇▇▇▇▇▇█▇▇▇▇▇█▇▇▇▇██▇█▇█▇█▇███▇█▇
train_f1_epoch,▁▁▆▇▇▇▇▇▇▇██████████
train_f1_step,▂▄▁▂▅▇▆▇▆▇▇▇▇▇█▇▇▇▇▇█▇▇▇▇██▇▇▇███▇███▇█▇
train_loss_epoch,█▇▇▆▅▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁
train_loss_step,███▇▇▆▆▆▅▅▅▄▄▄▃▄▄▃▃▄▂▃▃▂▂▂▂▂▂▂▁▂▁▂▁▂▁▂▁▁
train_precision_epoch,▁▆██████████████████
train_precision_step,▁▃▅▆▇█▆▇▇▇▇▇███▇▇█▇▆█▇▇▇▇██▇█▇███▇█████▇
train_recall_epoch,▂▁▄▅▅▆▆▆▇▇▇▇▇███████

0,1
epoch,19.0
train_accuracy_epoch,0.92056
train_accuracy_step,0.90885
train_f1_epoch,0.92108
train_f1_step,0.91215
train_loss_epoch,0.24995
train_loss_step,0.24224
train_precision_epoch,0.89904
train_precision_step,0.88313
train_recall_epoch,0.94461


In [None]:
# Serialises the whole module object (not just its weights). Loading it back
# with torch.load needs the SWEMMultiClassifier class to be defined/importable.
# `model` is whatever the most recently executed training cell bound.
torch.save(model, "models/swem_multi_label_finetune_model.bin")

In [8]:
from relation_modeling_utils import load_data, HeadDataset
from torch.utils.data import DataLoader

# Build the test split from the raw TSV with the same vocabulary as training.
test_df = load_data("data/atomic2020_data-feb2021/test.tsv", multi_label=True)
test_data = HeadDataset(test_df, vocab=VOCAB)
# batch_size=len(test_data): the loader yields the entire test set as one batch.
test_dataloader = DataLoader(test_data, batch_size=len(test_data))

In [9]:
test_df.head()

Unnamed: 0,text,label
0,PersonX abuses PersonX's power,"[0, 1, 1]"
1,PersonX accepts PersonY's apology,"[0, 1, 1]"
2,PersonX accepts ___ in payment,"[0, 1, 1]"
3,PersonX accidentally kicked,"[0, 1, 1]"
4,PersonX accidentally kicked ___,"[0, 1, 1]"


In [10]:
len(test_data.texts), len(test_df)

(6569, 6569)

In [12]:
# Number of test rows whose first label entry (index 0) is 0.
len(test_df[test_df.label.apply(lambda l: l[0]) == 0])

4668

In [13]:
import torch
# Restores the full pickled module saved earlier; SWEMMultiClassifier must be
# defined in this kernel before this runs.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
model = torch.load('models/swem_multi_label_finetune_model.bin')

In [6]:
# Also store only the parameters (state_dict), which does not pickle the class itself.
torch.save(model.state_dict(), "models/swem_multi_label_finetune_state_dict.pth")

In [14]:
# The test loader yields a single batch that holds the whole test set.
X, y = next(iter(test_dataloader))
# Inference only: switch to eval mode and skip autograd bookkeeping so no
# graph is retained for the full-test-set batch.
model.eval()
with torch.no_grad():
    # Call the module itself (not .forward) so registered hooks still run.
    preds = model(X)



In [15]:
import torchmetrics
# Fresh metric objects for the test split, configured like the train/val ones.
test_accuracy = torchmetrics.Accuracy()
test_precision = torchmetrics.Precision(num_classes=3, average="weighted")
test_recall = torchmetrics.Recall(num_classes=3, average="weighted")
test_f1 = torchmetrics.F1Score(num_classes=3, average="weighted")
# Calling a metric on (preds, y) returns the value for that single batch.
print(
    f'Test accuracy={test_accuracy(preds, y).item():.3f}, '
    f'precision={test_precision(preds, y).item():.3f}, '
    f'recall={test_recall(preds, y).item():.3f}, '
    f'f1={test_f1(preds, y).item():.3f}'
)

Test accurayc=0.860, precision=0.829, recall=0.961, f1=0.878


In [16]:
# Multi-label confusion matrix: one 2x2 matrix per class.
# Judging by the recorded output (class 0, first row sums to the 4668 rows
# whose label_0 is 0), rows are the true value and columns the prediction,
# i.e. [[TN, FP], [FN, TP]] -- confirm against the torchmetrics version in use.
test_confusion = torchmetrics.ConfusionMatrix(num_classes=3, multilabel=True)
confusion_matrix = test_confusion(preds, y)
confusion_matrix

tensor([[[4564,  104],
         [  27, 1874]],

        [[1896,  254],
         [ 230, 4189]],

        [[1935, 2061],
         [  87, 2486]]])

In [17]:
import pandas as pd
# One row per test example: input text, gold multi-hot labels and the
# predicted per-class probabilities.
prediction_columns = {
    'texts': test_data.texts,
    'labels': test_data.labels.tolist(),
    'probs': preds.detach().tolist(),
}
pred_df = pd.DataFrame(prediction_columns)


def _binarize(probabilities):
    """Turn a list of probabilities into 0/1 decisions at a 0.5 threshold."""
    return (np.array(probabilities) >= 0.5).astype(int).tolist()


pred_df['preds'] = pred_df.probs.apply(_binarize)

In [18]:
pred_df.head()

Unnamed: 0,texts,labels,probs,preds
0,PersonX abuses PersonX's power,"[0, 1, 1]","[0.007279202342033386, 0.8754659295082092, 0.9...","[0, 1, 1]"
1,PersonX accepts PersonY's apology,"[0, 1, 1]","[0.004542194306850433, 0.9169185161590576, 0.9...","[0, 1, 1]"
2,PersonX accepts ___ in payment,"[0, 1, 1]","[0.006188894622027874, 0.6897541880607605, 0.9...","[0, 1, 1]"
3,PersonX accidentally kicked,"[0, 1, 1]","[0.3949142396450043, 0.4427388310432434, 0.498...","[0, 0, 0]"
4,PersonX accidentally kicked ___,"[0, 1, 1]","[0.01818978600203991, 0.6005459427833557, 0.95...","[0, 1, 1]"


In [19]:
# matches = number of classes that are both in the gold labels and predicted
# (element-wise product of the two 0/1 vectors, summed), i.e. true positives per row.
pred_df['matches'] = pred_df.apply(lambda row: (np.array(row.labels) * np.array(row.preds)).sum().tolist(), axis=1)

In [23]:
# Expand the list-valued `labels` / `preds` columns into one scalar column per
# class: label_0..label_2 first, then pred_0..pred_2.
for _prefix, _source in (('label', 'labels'), ('pred', 'preds')):
    for _class_idx in range(3):
        # Bind the index as a default argument so each lambda keeps its own value.
        pred_df[f'{_prefix}_{_class_idx}'] = pred_df[_source].apply(
            lambda values, _i=_class_idx: values[_i]
        )

In [24]:
pred_df.head()

Unnamed: 0,texts,labels,probs,preds,matches,label_0,label_1,label_2,pred_0,pred_1,pred_2
0,PersonX abuses PersonX's power,"[0, 1, 1]","[0.007279202342033386, 0.8754659295082092, 0.9...","[0, 1, 1]",2,0,1,1,0,1,1
1,PersonX accepts PersonY's apology,"[0, 1, 1]","[0.004542194306850433, 0.9169185161590576, 0.9...","[0, 1, 1]",2,0,1,1,0,1,1
2,PersonX accepts ___ in payment,"[0, 1, 1]","[0.006188894622027874, 0.6897541880607605, 0.9...","[0, 1, 1]",2,0,1,1,0,1,1
3,PersonX accidentally kicked,"[0, 1, 1]","[0.3949142396450043, 0.4427388310432434, 0.498...","[0, 0, 0]",0,0,1,1,0,0,0
4,PersonX accidentally kicked ___,"[0, 1, 1]","[0.01818978600203991, 0.6005459427833557, 0.95...","[0, 1, 1]",2,0,1,1,0,1,1


In [27]:
# Fraction (0-1, not a percentage) of false positives for the third class
# (label_2 == 0 but pred_2 == 1) whose gold labels include the second class
# (label_1 == 1). "Class 2 / class 3" in 1-based numbering.
false_pos_class2 = (pred_df.label_2 == 0) & (pred_df.pred_2 == 1)
len(pred_df[false_pos_class2 & (pred_df.label_1 == 1)]) / len(pred_df[false_pos_class2])

0.9936923823386705

In [28]:
# Fraction (0-1, not a percentage) of false positives for the third class
# (label_2 == 0 but pred_2 == 1) whose gold labels include the first class
# (label_0 == 1). "Class 1 / class 3" in 1-based numbering.
false_pos_third_class = (pred_df.label_2 == 0) & (pred_df.pred_2 == 1)
len(pred_df[false_pos_third_class & (pred_df.label_0 == 1)]) / len(pred_df[false_pos_third_class])

0.006307617661329452

In [29]:
# Fraction (0-1, not a percentage) of false positives for the second class
# (label_1 == 0 but pred_1 == 1) whose gold labels include the first class
# (label_0 == 1). "Class 1 / class 2" in 1-based numbering.
false_pos_second_class = (pred_df.label_1 == 0) & (pred_df.pred_1 == 1)
len(pred_df[false_pos_second_class & (pred_df.label_0 == 1)]) / len(pred_df[false_pos_second_class])

0.01968503937007874

In [15]:
# Number of rows with zero correctly predicted labels (matches == 0).
# Note: rows that get only some of their labels right are NOT counted here.
len(pred_df[pred_df['matches'] < 1])

201

In [16]:
pred_df[pred_df['matches'] < 1].head()

Unnamed: 0,texts,labels,probs,preds,matches
3,PersonX accidentally kicked,"[0, 1, 1]","[0.3949142396450043, 0.4427388310432434, 0.498...","[0, 0, 0]",0
5,PersonX accidentally poured,"[0, 1, 1]","[0.40726250410079956, 0.43522584438323975, 0.4...","[0, 0, 0]",0
126,PersonX blows bubbles,"[0, 1, 1]","[0.4730721414089203, 0.4032168984413147, 0.456...","[0, 0, 0]",0
419,PersonX donates plasma,"[0, 1, 1]","[0.5411490201950073, 0.3266826868057251, 0.371...","[1, 0, 0]",0
953,PersonX hats cats,"[0, 1, 1]","[0.42976856231689453, 0.3377617299556732, 0.45...","[0, 0, 0]",0


In [19]:
# Number of rows for which the model predicted no class at all
# (every thresholded prediction is 0).
len(pred_df[pred_df['preds'].apply(lambda p: np.sum(p)) == 0])

44

In [27]:
# Fraction (0-1) of rows with exactly two gold labels where fewer than two of
# those labels were correctly predicted (matches < 2).
two_label_mask = pred_df.labels.apply(lambda l: np.sum(l)) == 2
len(pred_df[two_label_mask & (pred_df.matches < 2)]) / len(pred_df[two_label_mask])

0.05584415584415584

In [28]:
# Fraction (0-1) of rows with exactly two gold labels where exactly one of
# those labels was correctly predicted (matches == 1).
two_label_mask = pred_df.labels.apply(lambda l: np.sum(l)) == 2
len(pred_df[two_label_mask & (pred_df.matches == 1)]) / len(pred_df[two_label_mask])

0.03506493506493506

In [26]:
# Number of rows whose gold labels include all three classes.
len(pred_df[pred_df.labels.apply(lambda l: np.sum(l)) == 3])

7