In [None]:
from tqdm import tqdm

def create_emb_matrix(embedding_dim=100):
    glove = pd.read_csv(f'data/glove/glove.6B.{embedding_dim}d.txt', sep=" ", quoting=3, header=None, index_col=0)
    vocab = {'<pad>': 0, '<unk>': 1}
    embeddings = np.zeros((len(glove) + 2, embedding_dim))
    embeddings[0] = np.zeros(embedding_dim)
    embeddings[1] = np.zeros(embedding_dim)

    for index, (key, val) in tqdm(enumerate(glove.T.items())):
        vocab[key] = index + 2
        embeddings[index+2] = val.to_numpy()

    return vocab, embeddings

In [None]:
vocab, emb_matrix = create_emb_matrix()

In [63]:
from kogito.core.relation import PHYSICAL_RELATIONS, SOCIAL_RELATIONS, EVENT_RELATIONS
import pandas as pd
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.optim import Adam
from tqdm import tqdm
from torch import nn
from torchtext.vocab import GloVe
import wandb
import spacy
from torch.nn.utils.rnn import pad_sequence

def load_data(datapath):
    data = []
    head_label_set = set()

    with open(datapath) as f:
        for line in f:
            try:
                head, relation, _ = line.split('\t')

                label = 0 

                if relation in EVENT_RELATIONS:
                    label = 1
                elif relation in SOCIAL_RELATIONS:
                    label = 2

                if (head, label) not in head_label_set:
                    data.append((head, label))
                    head_label_set.add((head, label))
            except:
                pass

    return pd.DataFrame(data, columns=['text', 'label'])
    

class HeadDataset(Dataset):
    def __init__(self, df, vocab):
        nlp = spacy.load("en_core_web_sm")
        self.labels = df['label'].to_numpy()
        self.texts = pad_sequence([torch.tensor([vocab.get(token.text, 1) for token in nlp(text)], dtype=torch.int) for text in df['text']], batch_first=True)

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.texts[idx], self.labels[idx]



class MaxPool(nn.Module):
    def forward(self, X):
        values, _ = torch.max(X, dim=1)
        return values


class AvgPool(nn.Module):
    def forward(self, X):
        return torch.mean(X, dim=1)


class SWEMClassifier(nn.Module):

    def __init__(self, hidden_dim=100, num_classes=3, pooling="max", embedding_matrix=None):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=embedding_matrix.shape[0],
                                      embedding_dim=embedding_matrix.shape[1]).from_pretrained(torch.tensor(embedding_matrix, dtype=torch.float32), freeze=True)
        self.pool = MaxPool() if pooling == "max" else AvgPool()
        self.linear = nn.Linear(hidden_dim, num_classes)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, X):
        outputs = self.embedding(X)
        outputs = self.pool(outputs)
        outputs = self.linear(outputs)
        outputs = self.softmax(outputs)

        return outputs
    
    def save_pretrained(self, path):
        torch.save(self, path)


def train(model, train_dataset, val_dataset, learning_rate=1e-3, epochs=10, batch_size=8):
    # wandb.init(project="kogito-relation-matcher", config={"learning_rate": learning_rate, "epochs": epochs, "batch_size": batch_size})

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    criterion = nn.NLLLoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    if use_cuda:
        print("Using CUDA")
        model = model.to(device)
        criterion = criterion.to(device)

    for epoch_num in range(epochs):

        total_acc_train = 0
        total_loss_train = 0

        for train_input, train_label in tqdm(train_dataloader):
            model.zero_grad()

            train_label = train_label.to(device)
            X = train_input.to(device)

            output = model(X)
            
            batch_loss = criterion(output, train_label)
            total_loss_train += batch_loss.item()
            
            acc = (output.argmax(dim=1) == train_label).sum().item()
            total_acc_train += acc

            batch_loss.backward()
            optimizer.step()
        
        total_acc_val = 0
        total_loss_val = 0

        with torch.no_grad():

            for val_input, val_label in val_dataloader:

                val_label = val_label.to(device)
                X = val_input.to(device)

                output = model(X)

                batch_loss = criterion(output, val_label)
                total_loss_val += batch_loss.item()
                
                acc = (output.argmax(dim=1) == val_label).sum().item()
                total_acc_val += acc
        
        train_loss = total_loss_train / len(train_data)
        train_acc = total_acc_train / len(train_data)
        val_loss = total_loss_val / len(val_data)
        val_acc = total_acc_val / len(val_data)

        print(
            f'Epochs: {epoch_num + 1} | Train Loss: {train_loss: .3f} \
            | Train Accuracy: {train_acc: .3f} \
            | Val Loss: {val_loss: .3f} \
            | Val Accuracy: {val_acc: .3f}')
        
        # wandb.log({"train_loss": train_loss, "train_accuracy": train_acc, "val_loss": val_loss, "val_accuracy": val_acc})
        # model.save_pretrained(f"./models/checkpoint_{epoch_num}.pth")

In [65]:
train_df = load_data("data/atomic2020_data-feb2021/train.tsv")
dev_df = load_data("data/atomic2020_data-feb2021/dev.tsv")
train_data = HeadDataset(train_df, vocab=vocab)
val_data = HeadDataset(dev_df, vocab=vocab)
# model.save_pretrained("./models/final_model.pth")

100%|██████████| 843/843 [00:01<00:00, 580.67it/s]


Epochs: 1 | Train Loss:  0.014             | Train Accuracy:  0.563             | Val Loss:  0.014             | Val Accuracy:  0.502


100%|██████████| 843/843 [00:01<00:00, 574.14it/s]


Epochs: 2 | Train Loss:  0.013             | Train Accuracy:  0.594             | Val Loss:  0.014             | Val Accuracy:  0.502


In [66]:
model = SWEMClassifier(embedding_matrix=emb_matrix)
train(model=model, train_dataset=train_data, val_dataset=val_data, epochs=20, batch_size=64)

100%|██████████| 843/843 [00:01<00:00, 442.92it/s]


Epochs: 1 | Train Loss:  0.014             | Train Accuracy:  0.565             | Val Loss:  0.014             | Val Accuracy:  0.498


100%|██████████| 843/843 [00:01<00:00, 659.38it/s]


Epochs: 2 | Train Loss:  0.013             | Train Accuracy:  0.592             | Val Loss:  0.014             | Val Accuracy:  0.499


100%|██████████| 843/843 [00:01<00:00, 634.24it/s]


Epochs: 3 | Train Loss:  0.012             | Train Accuracy:  0.598             | Val Loss:  0.014             | Val Accuracy:  0.504


100%|██████████| 843/843 [00:01<00:00, 604.72it/s]


Epochs: 4 | Train Loss:  0.012             | Train Accuracy:  0.598             | Val Loss:  0.013             | Val Accuracy:  0.503


100%|██████████| 843/843 [00:01<00:00, 679.39it/s]


Epochs: 5 | Train Loss:  0.012             | Train Accuracy:  0.600             | Val Loss:  0.013             | Val Accuracy:  0.510


100%|██████████| 843/843 [00:01<00:00, 695.61it/s]


Epochs: 6 | Train Loss:  0.012             | Train Accuracy:  0.602             | Val Loss:  0.013             | Val Accuracy:  0.506


100%|██████████| 843/843 [00:01<00:00, 544.75it/s]


Epochs: 7 | Train Loss:  0.012             | Train Accuracy:  0.602             | Val Loss:  0.013             | Val Accuracy:  0.505


100%|██████████| 843/843 [00:01<00:00, 692.23it/s]


Epochs: 8 | Train Loss:  0.012             | Train Accuracy:  0.603             | Val Loss:  0.013             | Val Accuracy:  0.503


100%|██████████| 843/843 [00:01<00:00, 626.27it/s]


Epochs: 9 | Train Loss:  0.012             | Train Accuracy:  0.602             | Val Loss:  0.013             | Val Accuracy:  0.507


100%|██████████| 843/843 [00:01<00:00, 583.68it/s]


Epochs: 10 | Train Loss:  0.012             | Train Accuracy:  0.603             | Val Loss:  0.013             | Val Accuracy:  0.505


100%|██████████| 843/843 [00:01<00:00, 659.08it/s]


Epochs: 11 | Train Loss:  0.012             | Train Accuracy:  0.603             | Val Loss:  0.013             | Val Accuracy:  0.499


100%|██████████| 843/843 [00:01<00:00, 664.26it/s]


Epochs: 12 | Train Loss:  0.012             | Train Accuracy:  0.604             | Val Loss:  0.013             | Val Accuracy:  0.508


100%|██████████| 843/843 [00:01<00:00, 689.57it/s]


Epochs: 13 | Train Loss:  0.012             | Train Accuracy:  0.604             | Val Loss:  0.013             | Val Accuracy:  0.511


100%|██████████| 843/843 [00:01<00:00, 667.03it/s]


Epochs: 14 | Train Loss:  0.012             | Train Accuracy:  0.604             | Val Loss:  0.013             | Val Accuracy:  0.505


100%|██████████| 843/843 [00:01<00:00, 753.66it/s]


Epochs: 15 | Train Loss:  0.012             | Train Accuracy:  0.603             | Val Loss:  0.013             | Val Accuracy:  0.510


100%|██████████| 843/843 [00:01<00:00, 701.73it/s]


Epochs: 16 | Train Loss:  0.012             | Train Accuracy:  0.603             | Val Loss:  0.013             | Val Accuracy:  0.510


100%|██████████| 843/843 [00:01<00:00, 702.34it/s]


Epochs: 17 | Train Loss:  0.012             | Train Accuracy:  0.604             | Val Loss:  0.013             | Val Accuracy:  0.503


100%|██████████| 843/843 [00:01<00:00, 643.47it/s]


Epochs: 18 | Train Loss:  0.012             | Train Accuracy:  0.605             | Val Loss:  0.013             | Val Accuracy:  0.507


100%|██████████| 843/843 [00:01<00:00, 648.25it/s]


Epochs: 19 | Train Loss:  0.012             | Train Accuracy:  0.604             | Val Loss:  0.013             | Val Accuracy:  0.513


100%|██████████| 843/843 [00:01<00:00, 585.39it/s]


Epochs: 20 | Train Loss:  0.012             | Train Accuracy:  0.603             | Val Loss:  0.013             | Val Accuracy:  0.509
