In [18]:
# Automatically reload imported modules (e.g. relation_modeling_utils) before each cell
# runs, so edits to project helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
# Split identifier: selects the data/atomic_ood/<DATASET_TYPE>/ files and tags the W&B run / checkpoint names.
DATASET_TYPE: str = "f1"

In [20]:
import numpy as np

# GloVe 100-d vocabulary (stored as a pickled Python object, unwrapped with .item()) and
# the embedding matrix that SWEMClassifier uses as its pretrained embedding weights.
# allow_pickle=True runs the pickle loader -- only use with trusted local files.
VOCAB = np.load("data/vocab_glove_100d.npy", allow_pickle=True).item()
EMBEDDING_MATRIX = np.load("data/embedding_matrix_glove_100d.npy", allow_pickle=True)

In [21]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch import nn
import pytorch_lightning as pl
import torchmetrics
import torch.nn.functional as F

from relation_modeling_utils import MaxPool, AvgPool, Evaluator

class SWEMClassifier(pl.LightningModule, Evaluator):
    """SWEM (simple word-embedding model) multi-label classifier.

    Architecture: pretrained embedding lookup -> token pooling -> linear layer emitting
    one logit per class. Training uses ``BCEWithLogitsLoss``, so each class is scored as
    an independent binary decision; ``forward`` returns sigmoid probabilities.

    Args:
        num_classes: number of output labels (width of the final linear layer).
        pooling: ``"max"`` selects ``MaxPool``; any other value selects ``AvgPool``.
        freeze_emb: if True, the pretrained embedding weights are not trained.
        learning_rate: learning rate for the ``Adam`` optimizer.
    """

    def __init__(self, num_classes=3, pooling="avg", freeze_emb=False, learning_rate=1e-4):
        super().__init__()
        # ``from_pretrained`` is a classmethod that builds a new Embedding from the given
        # weights. Calling it on the class (rather than on a freshly constructed instance)
        # avoids allocating and randomly initialising a throw-away vocab x dim table.
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(EMBEDDING_MATRIX, dtype=torch.float32), freeze=freeze_emb
        )
        self.pool = MaxPool() if pooling == "max" else AvgPool()
        self.linear = nn.Linear(EMBEDDING_MATRIX.shape[1], num_classes)
        # The same submodules are registered a second time under ``model``; kept so that
        # existing checkpoint / state_dict keys continue to load.
        self.model = nn.Sequential(self.embedding, self.pool, self.linear)
        self.criterion = nn.BCEWithLogitsLoss()
        self.learning_rate = learning_rate
        self.save_hyperparameters()

    def forward(self, X):
        """Return per-class probabilities (sigmoid of the logits) for a batch of token ids ``X``."""
        return torch.sigmoid(self.model(X))

    def _shared_step(self, batch, stage):
        """Compute the loss on the logits, log loss and metrics under ``stage``, return the loss."""
        X, y = batch
        logits = self.model(X)
        loss = self.criterion(logits, y.float())
        # Derive probabilities from the same logits instead of running a second forward pass.
        preds = torch.sigmoid(logits)
        self.log(f"{stage}_loss", loss, on_epoch=True)
        self.log_metrics(preds, y, type=stage)
        return loss

    def training_step(self, batch, batch_idx):
        """Training step; logs ``train_loss`` and train metrics."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Validation step; logs ``val_loss`` and val metrics."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Test step; logs ``test_loss`` and test metrics."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Adam over all module parameters with the configured learning rate."""
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [22]:
from relation_modeling_utils import SWEMHeadDataset, load_fdata, load_data
import pandas as pd

# Train/test come from data/atomic_ood/<DATASET_TYPE>/; validation uses the
# ATOMIC-2020 dev.tsv loaded in multi-label mode.
train_df = load_fdata(f"data/atomic_ood/{DATASET_TYPE}/train_{DATASET_TYPE}.csv")
val_df = load_data("data/atomic2020_data-feb2021/dev.tsv", multi_label=True)
test_df = load_fdata(f"data/atomic_ood/{DATASET_TYPE}/test_{DATASET_TYPE}.csv")

# Wrap every split with the shared GloVe vocabulary.
train_data, val_data, test_data = (
    SWEMHeadDataset(split_df, vocab=VOCAB) for split_df in (train_df, val_df, test_df)
)

In [24]:
BATCH_SIZE = 128

# Only the training split is shuffled; validation/test keep dataset order.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader, test_dataloader = (
    DataLoader(split, batch_size=BATCH_SIZE) for split in (val_data, test_data)
)

In [25]:
from pytorch_lightning.loggers import WandbLogger
from relation_modeling_utils import get_timestamp
import wandb

NUM_EPOCHS = 20
LR_RATE = 1e-4
# Seeds python/numpy/torch RNGs so weight init and DataLoader shuffling are reproducible.
SEED = 42

pl.seed_everything(SEED)
timestamp = get_timestamp()

wandb_logger = WandbLogger(project="kogito-relation-matcher", name=f"swem_finetune_{DATASET_TYPE}")
try:
    wandb_logger.experiment.config["epochs"] = NUM_EPOCHS
    wandb_logger.experiment.config["seed"] = SEED
    model = SWEMClassifier(pooling="avg", freeze_emb=False, learning_rate=LR_RATE)
    trainer = pl.Trainer(max_epochs=NUM_EPOCHS, logger=wandb_logger, accelerator="gpu", devices=[0])
    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
    trainer.test(model, dataloaders=test_dataloader)
    trainer.save_checkpoint(f"models/swem/swem_finetune_{DATASET_TYPE}_{timestamp}.ckpt", weights_only=True)
finally:
    # Close the W&B run even if training/testing raises; otherwise the run stays
    # open in this kernel after a failed cell.
    wandb.finish()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name            | Type              | Params
-------------------------------------------------------
0  | embedding       | Embedding         | 40.0 M
1  | pool            | AvgPool           | 0     
2  | linear          | Linear            | 303   
3  | model           | Sequential        | 40.0 M
4  | criterion       | BCEWithLogitsLoss | 0     
5  | train_accuracy  | Accuracy          | 0     
6  | val_accuracy    | Accuracy          | 0     
7  | train_precision | Precision         | 0     
8  | val_precision   | Precision         | 0     
9  | train_recall    | Recall            | 0     
10 | val_recall      | Recall            | 0     
11 | train_f1        | F1Score           | 0     
12 | val_f1          | F1Score           | 0     
13 | test_accuracy   | Accuracy          | 0     
14 | test_precision  | Precision        

Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_accuracy': 0.8215733766555786,
 'test_f1': 0.781292200088501,
 'test_loss': 0.43175986409187317,
 'test_precision': 0.763730525970459,
 'test_recall': 0.8030611872673035}
--------------------------------------------------------------------------------



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
test_accuracy,▁
test_f1,▁
test_loss,▁
test_precision,▁
test_recall,▁
train_accuracy_epoch,▁▂▅▇▇▇▇▇████████████
train_accuracy_step,▁▃▄▃▄▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇▇█▇▇██▇███▇████
train_f1_epoch,▁▁▄▇▇▇▇▇████████████
train_f1_step,▁▂▂▂▂▅▇▇▇▇▇▇▇▇▆▇▆▇▇▇▇▇██▇▇█▇▇██▇███▇███▇

0,1
epoch,19.0
test_accuracy,0.82157
test_f1,0.78129
test_loss,0.43176
test_precision,0.76373
test_recall,0.80306
train_accuracy_epoch,0.92632
train_accuracy_step,0.91667
train_f1_epoch,0.92844
train_f1_step,0.9143


In [8]:
from relation_modeling_utils import load_data, HeadDataset
from torch.utils.data import DataLoader

# ATOMIC-2020 test split (multi-label) wrapped with the GloVe vocabulary.
test_df = load_data("data/atomic2020_data-feb2021/test.tsv", multi_label=True)
test_data = HeadDataset(test_df, vocab=VOCAB)

# A single batch holding the whole split, so one forward pass scores everything.
n_test_examples = len(test_data)
test_dataloader = DataLoader(test_data, batch_size=n_test_examples)

In [9]:
# Peek at the first rows of the loaded test split.
test_df.head(n=5)

Unnamed: 0,text,label
0,PersonX abuses PersonX's power,"[0, 1, 1]"
1,PersonX accepts PersonY's apology,"[0, 1, 1]"
2,PersonX accepts ___ in payment,"[0, 1, 1]"
3,PersonX accidentally kicked,"[0, 1, 1]"
4,PersonX accidentally kicked ___,"[0, 1, 1]"


In [10]:
# Sanity check: the dataset wrapper should keep exactly one text per row of test_df.
(len(test_data.texts), len(test_df))

(6569, 6569)

In [12]:
# Number of test rows whose first label entry (class index 0) is 0.
int((test_df.label.apply(lambda labels: labels[0]) == 0).sum())

4668

In [13]:
import torch

# The file holds a fully pickled model object (not a state_dict); torch.load unpickles it,
# so only load trusted files. map_location="cpu" remaps any CUDA tensors so the load works
# on machines without a GPU and matches the CPU-side inference done on the test batch.
model = torch.load('models/swem_multi_label_finetune_model.bin', map_location="cpu")

In [6]:
# Persist only the parameter tensors so they can be reloaded without the pickled model class.
model_weights = model.state_dict()
torch.save(model_weights, "models/swem_multi_label_finetune_state_dict.pth")

In [14]:
# Score the whole test split as one batch. This is pure inference, so switch to eval mode
# and disable autograd to avoid building a computation graph over every example.
X, y = next(iter(test_dataloader))
model.eval()
with torch.no_grad():
    preds = model(X)



In [15]:
import torchmetrics

# Metrics over the 3 labels; `preds` are probabilities, which torchmetrics thresholds
# (0.5 by default) before comparing against the 0/1 targets in `y`.
test_accuracy = torchmetrics.Accuracy()
test_precision = torchmetrics.Precision(num_classes=3, average="weighted")
test_recall = torchmetrics.Recall(num_classes=3, average="weighted")
test_f1 = torchmetrics.F1Score(num_classes=3, average="weighted")

test_scores = {
    "accuracy": test_accuracy(preds, y).item(),
    "precision": test_precision(preds, y).item(),
    "recall": test_recall(preds, y).item(),
    "f1": test_f1(preds, y).item(),
}
print("Test " + ", ".join(f"{name}={score:.3f}" for name, score in test_scores.items()))

Test accurayc=0.860, precision=0.829, recall=0.961, f1=0.878


In [16]:
# Multilabel confusion matrix: one 2x2 matrix per class, laid out [[TN, FP], [FN, TP]].
test_confusion = torchmetrics.ConfusionMatrix(num_classes=3, multilabel=True)
test_confusion.update(preds, y)
confusion_matrix = test_confusion.compute()
confusion_matrix

tensor([[[4564,  104],
         [  27, 1874]],

        [[1896,  254],
         [ 230, 4189]],

        [[1935, 2061],
         [  87, 2486]]])

In [17]:
import pandas as pd

# One row per test example: text, gold label vector, predicted probabilities,
# and 0/1 predictions obtained by thresholding each probability at 0.5.
pred_df = pd.DataFrame(
    {
        'texts': test_data.texts,
        'labels': test_data.labels.tolist(),
        'probs': preds.detach().tolist(),
    }
)
pred_df['preds'] = pred_df['probs'].apply(lambda row_probs: [int(p >= 0.5) for p in row_probs])

In [18]:
# Inspect the first prediction rows.
pred_df.head(n=5)

Unnamed: 0,texts,labels,probs,preds
0,PersonX abuses PersonX's power,"[0, 1, 1]","[0.007279202342033386, 0.8754659295082092, 0.9...","[0, 1, 1]"
1,PersonX accepts PersonY's apology,"[0, 1, 1]","[0.004542194306850433, 0.9169185161590576, 0.9...","[0, 1, 1]"
2,PersonX accepts ___ in payment,"[0, 1, 1]","[0.006188894622027874, 0.6897541880607605, 0.9...","[0, 1, 1]"
3,PersonX accidentally kicked,"[0, 1, 1]","[0.3949142396450043, 0.4427388310432434, 0.498...","[0, 0, 0]"
4,PersonX accidentally kicked ___,"[0, 1, 1]","[0.01818978600203991, 0.6005459427833557, 0.95...","[0, 1, 1]"


In [19]:
# Per-row number of true positives: positions where both the gold label and the prediction are 1.
pred_df['matches'] = [
    sum(gold * guess for gold, guess in zip(row_labels, row_preds))
    for row_labels, row_preds in zip(pred_df['labels'], pred_df['preds'])
]

In [23]:
# Split the 3-element label/prediction vectors into one scalar column per class,
# creating all label_* columns first, then all pred_* columns.
for source_col, prefix in (('labels', 'label'), ('preds', 'pred')):
    for class_idx in range(3):
        pred_df[f'{prefix}_{class_idx}'] = pred_df[source_col].apply(lambda vec, i=class_idx: vec[i])

In [24]:
# Inspect the first rows again, now including the per-class columns.
pred_df.head(n=5)

Unnamed: 0,texts,labels,probs,preds,matches,label_0,label_1,label_2,pred_0,pred_1,pred_2
0,PersonX abuses PersonX's power,"[0, 1, 1]","[0.007279202342033386, 0.8754659295082092, 0.9...","[0, 1, 1]",2,0,1,1,0,1,1
1,PersonX accepts PersonY's apology,"[0, 1, 1]","[0.004542194306850433, 0.9169185161590576, 0.9...","[0, 1, 1]",2,0,1,1,0,1,1
2,PersonX accepts ___ in payment,"[0, 1, 1]","[0.006188894622027874, 0.6897541880607605, 0.9...","[0, 1, 1]",2,0,1,1,0,1,1
3,PersonX accidentally kicked,"[0, 1, 1]","[0.3949142396450043, 0.4427388310432434, 0.498...","[0, 0, 0]",0,0,1,1,0,0,0
4,PersonX accidentally kicked ___,"[0, 1, 1]","[0.01818978600203991, 0.6005459427833557, 0.95...","[0, 1, 1]",2,0,1,1,0,1,1


In [27]:
# Among class-3 false positives (label_2 == 0, pred_2 == 1), the fraction of rows labelled class 2 (label_1 == 1).
fp_class_3 = pred_df[(pred_df.label_2 == 0) & (pred_df.pred_2 == 1)]
len(fp_class_3[fp_class_3.label_1 == 1]) / len(fp_class_3)

0.9936923823386705

In [28]:
# Among class-3 false positives (label_2 == 0, pred_2 == 1), the fraction of rows labelled class 1 (label_0 == 1).
fp_class_3 = pred_df[(pred_df.label_2 == 0) & (pred_df.pred_2 == 1)]
len(fp_class_3[fp_class_3.label_0 == 1]) / len(fp_class_3)

0.006307617661329452

In [29]:
# Among class-2 false positives (label_1 == 0, pred_1 == 1), the fraction of rows labelled class 1 (label_0 == 1).
fp_class_2 = pred_df[(pred_df.label_1 == 0) & (pred_df.pred_1 == 1)]
len(fp_class_2[fp_class_2.label_0 == 1]) / len(fp_class_2)

0.01968503937007874

In [None]:
# Among class-2 false negatives (label_1 == 1 but pred_1 == 0), the fraction of rows that
# are not labelled class 1 (label_0 == 0). The denominator must be the same population the
# numerator is drawn from (the false negatives), not the class-2 false positives.
missed_class_2 = pred_df[(pred_df.label_1 == 1) & (pred_df.pred_1 == 0)]
len(missed_class_2[missed_class_2.label_0 == 0]) / len(missed_class_2)

In [34]:
# Pairwise error analysis: for every ordered pair of distinct classes (class_x, class_y) and
# every gold combination (label_x, label_y), print the fraction of those rows where the model
# got BOTH classes wrong. 'x' marks the class that is not part of the pair.
for class_x in [0, 1, 2]:
    for class_y in [0, 1, 2]:
        if class_x == class_y:
            continue
        for label_x in [0, 1]:
            for label_y in [0, 1]:
                # With binary labels, the only prediction wrong on both classes is the flip.
                pred_x, pred_y = 1 - label_x, 1 - label_y
                whole = pred_df[(pred_df[f'label_{class_x}'] == label_x) & (pred_df[f'label_{class_y}'] == label_y)]
                sub = whole[(whole[f'pred_{class_x}'] == pred_x) & (whole[f'pred_{class_y}'] == pred_y)]
                label = ['x', 'x', 'x']
                pred = ['x', 'x', 'x']
                label[class_x] = label_x
                label[class_y] = label_y
                pred[class_x] = pred_x
                pred[class_y] = pred_y
                # A gold combination may not occur in the split at all; avoid ZeroDivisionError.
                percentage = len(sub) / len(whole) if len(whole) else float('nan')
                print(f"label={label}, pred={pred}, percentage={percentage}")

Label=[0, 0, 'x'], pred=[0, 0, 'x'], percentage=0.03597122302158273
Label=[0, 0, 'x'], pred=[0, 1, 'x'], percentage=0.89568345323741
Label=[0, 0, 'x'], pred=[1, 0, 'x'], percentage=0.0683453237410072
Label=[0, 0, 'x'], pred=[1, 1, 'x'], percentage=0.0
Label=[0, 1, 'x'], pred=[0, 0, 'x'], percentage=0.02642369020501139
Label=[0, 1, 'x'], pred=[0, 1, 'x'], percentage=0.9542141230068337
Label=[0, 1, 'x'], pred=[1, 0, 'x'], percentage=0.0193621867881549
Label=[0, 1, 'x'], pred=[1, 1, 'x'], percentage=0.0
Label=[1, 0, 'x'], pred=[0, 0, 'x'], percentage=0.011752136752136752
Label=[1, 0, 'x'], pred=[0, 1, 'x'], percentage=0.002670940170940171
Label=[1, 0, 'x'], pred=[1, 0, 'x'], percentage=0.9855769230769231
Label=[1, 0, 'x'], pred=[1, 1, 'x'], percentage=0.0
Label=[1, 1, 'x'], pred=[0, 0, 'x'], percentage=0.0
Label=[1, 1, 'x'], pred=[0, 1, 'x'], percentage=0.0
Label=[1, 1, 'x'], pred=[1, 0, 'x'], percentage=1.0
Label=[1, 1, 'x'], pred=[1, 1, 'x'], percentage=0.0
Label=[0, 'x', 0], pred=[0, '

In [30]:
# Fraction of all test rows labelled with both class 2 and class 3 (label_1 and label_2).
has_both = (pred_df.label_1 == 1) & (pred_df.label_2 == 1)
int(has_both.sum()) / len(pred_df)

0.34830263358197594

In [31]:
# Fraction of all test rows labelled with both class 1 and class 3 (label_0 and label_2).
has_both = (pred_df.label_0 == 1) & (pred_df.label_2 == 1)
int(has_both.sum()) / len(pred_df)

0.0021312224082813214

In [32]:
# Fraction of all test rows labelled with both class 1 and class 2 (label_0 and label_1).
has_both = (pred_df.label_0 == 1) & (pred_df.label_1 == 1)
int(has_both.sum()) / len(pred_df)

0.004414674988582737

In [33]:
# Rows whose gold labels are exactly [0, 1, 0] (class 2 only).
only_class_2 = (pred_df.label_0 == 0) & (pred_df.label_1 == 1) & (pred_df.label_2 == 0)
pred_df.loc[only_class_2]

Unnamed: 0,texts,labels,probs,preds,matches,label_0,label_1,label_2,pred_0,pred_1,pred_2
3200,airplane engine,"[0, 1, 0]","[0.9113984704017639, 0.10558881610631943, 0.06...","[1, 0, 0]",0,0,1,0,1,0,0
3201,alcoholism,"[0, 1, 0]","[0.9269378781318665, 0.10149876028299332, 0.06...","[1, 0, 0]",0,0,1,0,1,0,0
3202,argument,"[0, 1, 0]","[0.8820464015007019, 0.1297539323568344, 0.100...","[1, 0, 0]",0,0,1,0,1,0,0
3203,cancerous tumor,"[0, 1, 0]","[0.941831111907959, 0.07380104064941406, 0.046...","[1, 0, 0]",0,0,1,0,1,0,0
3204,car chase,"[0, 1, 0]","[0.8344407677650452, 0.15622617304325104, 0.12...","[1, 0, 0]",0,0,1,0,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...
6503,PersonX provides PersonY basis,"[0, 1, 0]","[0.040776558220386505, 0.7636806964874268, 0.9...","[0, 1, 1]",1,0,1,0,0,1,1
6504,PersonX makes PersonY changes,"[0, 1, 0]","[0.007372543681412935, 0.8869695067405701, 0.9...","[0, 1, 1]",1,0,1,0,0,1,1
6505,PersonX gets PersonY idea,"[0, 1, 0]","[0.011564677581191063, 0.8723880052566528, 0.9...","[0, 1, 1]",1,0,1,0,0,1,1
6506,PersonX uses PersonX's eyes,"[0, 1, 0]","[0.0021263044327497482, 0.9422444105148315, 0....","[0, 1, 1]",1,0,1,0,0,1,1


In [15]:
# Rows where none of the gold labels was predicted (matches == 0). False positives are not
# counted here, so this is not the total number of mistakes.
int((pred_df['matches'] < 1).sum())

201

In [16]:
# First few rows where no gold label was recovered.
pred_df.loc[pred_df['matches'] < 1].head(n=5)

Unnamed: 0,texts,labels,probs,preds,matches
3,PersonX accidentally kicked,"[0, 1, 1]","[0.3949142396450043, 0.4427388310432434, 0.498...","[0, 0, 0]",0
5,PersonX accidentally poured,"[0, 1, 1]","[0.40726250410079956, 0.43522584438323975, 0.4...","[0, 0, 0]",0
126,PersonX blows bubbles,"[0, 1, 1]","[0.4730721414089203, 0.4032168984413147, 0.456...","[0, 0, 0]",0
419,PersonX donates plasma,"[0, 1, 1]","[0.5411490201950073, 0.3266826868057251, 0.371...","[1, 0, 0]",0
953,PersonX hats cats,"[0, 1, 1]","[0.42976856231689453, 0.3377617299556732, 0.45...","[0, 0, 0]",0


In [19]:
# Rows for which the model predicted no label at all.
int((pred_df['preds'].apply(lambda row_preds: np.sum(row_preds)) == 0).sum())

44

In [27]:
# Among rows with exactly two gold labels, the fraction where fewer than two of them were
# predicted (`matches` counts correctly predicted gold labels).
two_label_rows = pred_df[pred_df.labels.apply(lambda l: np.sum(l)) == 2]
len(two_label_rows[two_label_rows.matches < 2]) / len(two_label_rows)

0.05584415584415584

In [28]:
# Among rows with exactly two gold labels, the fraction where exactly one of them was predicted.
two_label_rows = pred_df[pred_df.labels.apply(lambda l: np.sum(l)) == 2]
len(two_label_rows[two_label_rows.matches == 1]) / len(two_label_rows)

0.03506493506493506

In [26]:
# Number of test rows carrying all three gold labels.
int((pred_df.labels.apply(lambda l: np.sum(l)) == 3).sum())

7