In [1]:
import os
import re
import pickle
import string
import random
import json
import jsonlines
from collections import defaultdict
from tqdm.notebook import tqdm
import numpy as np
import torch
from torch import nn

from typing import *

In [7]:
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

nltk.download('stopwords')

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/michele/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

# Utils

In [2]:
def save_pickle(data: dict, path: str) -> None:
    with open(path, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

def load_pickle(path: str) -> dict:
    with open(path, 'rb') as f:
        return pickle.load(f)

In [3]:
# Saving / loading models
class Checkpoint:
    def __init__(self, path: str, resume=False):
        self.path = path
        os.makedirs(path, exist_ok=True)
        self.resume = resume

    def load(self, model, optimizer, id_path=""):
        if (not self.resume) and id_path == "":
            raise RuntimeError()
        if self.resume:
            id_path = sorted(os.listdir(self.path))[-1]
        self.checkpoint = torch.load(
            os.path.join(self.path, id_path), map_location=lambda storage, loc: storage
        )
        if self.checkpoint == None:
            raise RuntimeError("Checkpoint empty.")
        epoch = self.checkpoint["epoch"]
        model.load_state_dict(self.checkpoint["model_state_dict"])
        optimizer.load_state_dict(self.checkpoint["optimizer_state_dict"])
        losses = self.checkpoint["losses"]
        accuracies = self.checkpoint["accuracies"]
        return (model, optimizer, epoch, losses, accuracies)

    def save(self, model, optimizer, epoch, losses, accuracies):
        model_checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "losses": losses,
            "accuracies": accuracies
        }
        checkpoint_name = "{}.pth".format(str(epoch).zfill(3))
        complete_path = os.path.join(self.path, checkpoint_name)
        torch.save(model_checkpoint, complete_path)
        return

    def load_just_model(self, model, id_path=""):
        if self.resume:
            id_path = sorted(os.listdir(self.path))[-1]
        self.checkpoint = torch.load(
            os.path.join(self.path, id_path), map_location=lambda storage, loc: storage
        )
        if self.checkpoint == None:
            raise RuntimeError("Checkpoint empty.")
        model.load_state_dict(self.checkpoint["model_state_dict"])
        return model

In [20]:
set_stopwords = set(stopwords.words())

def preprocess(sentence: str, target_word=None) -> str:
    # lowercase sentence
    sentence = sentence.lower()
    # remove punctuation
    sentence = re.sub('[^\w\s]', ' ', sentence)
    # replace multiple adjacent spaces with one single space
    sentence = re.sub(' +', ' ', sentence).strip()
    
    tokens = sentence.split()
    tokens_sw = [word for word in tokens if (not word in set_stopwords or word == target_word)]
    
    return ' '.join(tokens_sw)

In [21]:
def embeddings_dictionary(path: str) -> Dict[str, torch.Tensor]:
    word_vectors = dict()
    with open(path) as f:
        for i, line in tqdm(enumerate(f)):

            word, *vector = line.strip().split(' ')
            vector = torch.tensor([float(c) for c in vector])

            word_vectors[word] = vector
    return word_vectors

# Create word embedding with GloVe

In [22]:
def custom_tokenizer(sentence: str, marker: str) -> List[str]:
    tokens = sentence.split()
    for i, tk in enumerate(tokens):
        if marker in tk:
            target_position = i
            tokens[i] = tk[20:]
    return tokens, target_position

In [23]:
class TokensEmbedder(object):
    def __init__(self, word_vectors: Dict[str, torch.Tensor]) -> None:
        self.word_vectors = word_vectors
    
    def compute_embeddings(self, tokens: List[str]) -> List[torch.Tensor]:
        word_embeddings = []
        for w in tokens:
            word_embeddings.append(self.word_vectors[w] if w in self.word_vectors else self.word_vectors['<unk>'])
        
        return word_embeddings
    
    def aggregate_embeddings(self, tokens: List[str], target_position: int) -> torch.Tensor:
        pass
    
    def __call__(self, tokens: List[str], target_position: int) -> torch.Tensor:
        return self.aggregate_embeddings(tokens, target_position)

class AverageEmbedder(TokensEmbedder):
    def __init__(self, word_vectors: Dict[str, torch.Tensor]) -> None:
        self.word_vectors = word_vectors
    
    def aggregate_embeddings(self, tokens: List[str], target_position: int) -> torch.Tensor:
        embeddings = torch.stack(self.compute_embeddings(tokens))
        return torch.mean(embeddings, dim=0)

class WeightedAverageEmbedder(TokensEmbedder):
    def __init__(self, word_vectors: Dict[str, torch.Tensor], max_weight: float = 1.0, min_weight: float = 0.0) -> None:
        self.word_vectors = word_vectors
        self.max_weight = max_weight
        self.min_weight = min_weight
    
    def aggregate_embeddings(self, tokens: List[str], target_position: int) -> torch.Tensor:
        embeddings = torch.stack(self.compute_embeddings(tokens))
        # aliases for readibility
        n = len(embeddings)
        t = target_position

        # weights from 1 to 0
        weights = torch.linspace(self.max_weight, self.min_weight, n).unsqueeze(1)
        # weights = torch.exp(reversed(torch.arange(n, dtype=torch.float32))).unsqueeze(1)

        # weighted vector
        new_vectors = embeddings

        # weighted average

        # right of the target word
        new_vectors[t:] = new_vectors[t:] * weights[:n - t]
        # left of the target word
        new_vectors[:t] = new_vectors[:t] * reversed(weights[1:t + 1])

        # denominator (sum of the weights)
        weights_sum = weights[:n - t].sum() + weights[1:t + 1].sum()

        return new_vectors.sum(dim=0) / weights_sum

In [24]:
embedding_path = 'embeddings/glove.6B.300d.txt'

word_vectors = embeddings_dictionary(embedding_path)

# <unk> token
# https://stackoverflow.com/questions/49239941/what-is-unk-in-the-pretrained-glove-vector-files-e-g-glove-6b-50d-txt
# unk_embedding = '0.22418134 -0.28881392 0.13854356 0.00365387 -0.12870757 0.10243822 0.061626635 0.07318011 -0.061350107 -1.3477012 0.42037755 -0.063593924 -0.09683349 0.18086134 0.23704372 0.014126852 0.170096 -1.1491593 0.31497982 0.06622181 0.024687296 0.076693475 0.13851812 0.021302193 -0.06640582 -0.010336159 0.13523154 -0.042144544 -0.11938788 0.006948221 0.13333307 -0.18276379 0.052385733 0.008943111 -0.23957317 0.08500333 -0.006894406 0.0015864656 0.063391194 0.19177166 -0.13113557 -0.11295479 -0.14276934 0.03413971 -0.034278486 -0.051366422 0.18891625 -0.16673574 -0.057783455 0.036823478 0.08078679 0.022949161 0.033298038 0.011784158 0.05643189 -0.042776518 0.011959623 0.011552498 -0.0007971594 0.11300405 -0.031369694 -0.0061559738 -0.009043574 -0.415336 -0.18870236 0.13708843 0.005911723 -0.113035575 -0.030096142 -0.23908928 -0.05354085 -0.044904727 -0.20228513 0.0065645403 -0.09578946 -0.07391877 -0.06487607 0.111740574 -0.048649278 -0.16565254 -0.052037314 -0.078968436 0.13684988 0.0757494 -0.006275573 0.28693774 0.52017444 -0.0877165 -0.33010918 -0.1359622 0.114895485 -0.09744406 0.06269521 0.12118575 -0.08026362 0.35256687 -0.060017522 -0.04889904 -0.06828978 0.088740796 0.003964443 -0.0766291 0.1263925 0.07809314 -0.023164088 -0.5680669 -0.037892066 -0.1350967 -0.11351585 -0.111434504 -0.0905027 0.25174105 -0.14841858 0.034635577 -0.07334565 0.06320108 -0.038343467 -0.05413284 0.042197507 -0.090380974 -0.070528865 -0.009174437 0.009069661 0.1405178 0.02958134 -0.036431845 -0.08625681 0.042951006 0.08230793 0.0903314 -0.12279937 -0.013899368 0.048119213 0.08678239 -0.14450377 -0.04424887 0.018319942 0.015026873 -0.100526 0.06021201 0.74059093 -0.0016333034 -0.24960588 -0.023739101 0.016396184 0.11928964 0.13950661 -0.031624354 -0.01645025 0.14079992 -0.0002824564 -0.08052984 -0.0021310581 -0.025350995 0.086938225 0.14308536 0.17146006 -0.13943303 0.048792403 0.09274929 -0.053167373 0.031103406 0.012354865 0.21057427 0.32618305 0.18015954 -0.15881181 0.15322933 -0.22558987 -0.04200665 0.0084689725 0.038156632 0.15188617 0.13274793 0.113756925 -0.095273495 -0.049490947 -0.10265804 -0.27064866 -0.034567792 -0.018810693 -0.0010360252 0.10340131 0.13883452 0.21131058 -0.01981019 0.1833468 -0.10751636 -0.03128868 0.02518242 0.23232952 0.042052146 0.11731903 -0.15506615 0.0063580726 -0.15429358 0.1511722 0.12745973 0.2576985 -0.25486213 -0.0709463 0.17983761 0.054027 -0.09884228 -0.24595179 -0.093028545 -0.028203879 0.094398156 0.09233813 0.029291354 0.13110267 0.15682974 -0.016919162 0.23927948 -0.1343307 -0.22422817 0.14634751 -0.064993896 0.4703685 -0.027190214 0.06224946 -0.091360025 0.21490277 -0.19562101 -0.10032754 -0.09056772 -0.06203493 -0.18876675 -0.10963594 -0.27734384 0.12616494 -0.02217992 -0.16058226 -0.080475815 0.026953284 0.110732645 0.014894041 0.09416802 0.14299914 -0.1594008 -0.066080004 -0.007995227 -0.11668856 -0.13081996 -0.09237365 0.14741232 0.09180138 0.081735 0.3211204 -0.0036552632 -0.047030564 -0.02311798 0.048961394 0.08669574 -0.06766279 -0.50028914 -0.048515294 0.14144728 -0.032994404 -0.11954345 -0.14929578 -0.2388355 -0.019883996 -0.15917352 -0.052084364 0.2801028 -0.0029121689 -0.054581646 -0.47385484 0.17112483 -0.12066923 -0.042173345 0.1395337 0.26115036 0.012869649 0.009291686 -0.0026459037 -0.075331464 0.017840583 -0.26869613 -0.21820338 -0.17084768 -0.1022808 -0.055290595 0.13513643 0.12362477 -0.10980586 0.13980341 -0.20233242 0.08813751 0.3849736 -0.10653763 -0.06199595 0.028849555 0.03230154 0.023856193 0.069950655 0.19310954 -0.077677034 -0.144811'
# unk_embedding = unk_embedding.strip().split()
# word_vectors['<unk>'] = torch.tensor([float(c) for c in unk_embedding])
word_vectors['<unk>'] = torch.rand(300)

0it [00:00, ?it/s]

In [9]:
# save_pickle(word_vectors, 'embeddings/vocabulary_tensors.pkl')

# Dataset class using GloVe

In [25]:
class EmbeddedDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path: str, marker: str, embedder: TokensEmbedder):
        self.data = []
        self.marker = marker
        self.embedder = embedder

        self.create_dataset(dataset_path)
        
    
    def create_dataset(self, dataset_path: str) -> None:
        with jsonlines.open(dataset_path, 'r') as f:
            for i, line in enumerate(f.iter()):
                # load sentences
                start1 = int(line['start1'])
                start2 = int(line['start2'])
                s1 = line['sentence1']
                s2 = line['sentence2']
                # insert special characters to locate target word after preprocessing
                s1 = s1[:start1] + self.marker + s1[start1:]
                s2 = s2[:start2] + self.marker + s2[start2:]
                
                # preprocessing
                s1 = preprocess(s1)
                s2 = preprocess(s2)
                
                # tokenization
                t1, target_position1 = custom_tokenizer(s1, self.marker)
                t2, target_position2 = custom_tokenizer(s2, self.marker)
                
                # convert tokens to embeddings and aggregate
                v1 = self.embedder(t1, target_position1)
                v2 = self.embedder(t2, target_position2)
                
                # concatenate vectors
                sentence_vector = torch.cat((v1, v2))
                
                label = torch.tensor(1, dtype=torch.float32) if line['label'] == 'True' else torch.tensor(0, dtype=torch.float32)
                self.data.append((sentence_vector, label))


    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data[idx]

# Model Class

In [26]:
class MLP(nn.Module):
    def __init__(
        self,
        n_features: int,
        num_layers: int,
        hidden_dim: int,
        activation: Callable[[torch.Tensor], torch.Tensor],
    ) -> None:
        super().__init__()

        self.first_layer = nn.Linear(in_features=n_features, out_features=hidden_dim)

        self.layers = (
            nn.ModuleList()
        )

        for i in range(num_layers):
            self.layers.append(
                nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
            )
            
        self.activation = activation
        
        self.last_layer = nn.Linear(in_features=hidden_dim, out_features=1)
        
        self.sigmoid = nn.Sigmoid()
        
        self.dropout = nn.Dropout(0.3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x

        out = self.first_layer(
            out
        )  # First linear layer, transforms the hidden dimensions from `n_features` (embedding dimension) to `hidden_dim`
        for layer in self.layers:  # Apply `k` (linear, activation) layer
            out = layer(out)
            out = self.activation(out)
            out = self.dropout(out)
        out = self.last_layer(
            out
        )  # Last linear layer to bring the `hiddem_dim` features to a binary space (`True`/`False`)
        
        out = self.sigmoid(out)
        return out.squeeze(-1)


# Training process

In [27]:
def batch_to_device(batch: List[torch.Tensor], device: str) -> List[torch.Tensor]:
    return [x.to(device) for x in batch]

def fit(
    epochs: int,
    model: nn.Module,
    criterion: nn.Module,
    opt: torch.optim.Optimizer,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    checkpoint: Checkpoint = None,
) -> None:

    # keep track of losses and accuracies
    losses = {"train": [], "val": []}
    accuracies = {"train": [], "val": []}

    # training loop
    for epoch in range(1, epochs + 1):
        # instantiate loss and accuracy each epoch
        losses_train = 0
        d_train, n_train = 0, 0

        model.train()
        train_iterator = tqdm(
            train_dl, desc=f"Epoch {epoch}/{epochs} (TRAIN)", leave=False
        )
        for batch in train_iterator:
            # send batch to device
            batch = batch_to_device(batch, device)
            batch_x = batch[:-1]
            batch_y = batch[-1]

            pred = model(*batch_x)

            # compute loss and backprop
            loss_train = criterion(pred, batch_y)
            loss_train.backward()
            opt.step()
            opt.zero_grad()
            losses_train += loss_train.item()

            pred = torch.round(pred)
            # number of predictions
            d_train += pred.shape[0]
            # number of correct predictions
            n_train += (batch_y == pred).int().sum()

        model.eval()
        with torch.no_grad():
            losses_val = 0
            d_val, n_val = 0, 0

            valid_iterator = tqdm(
                valid_dl, desc=f"Epoch {epoch}/{epochs} (VALID)", leave=False
            )
            for batch in valid_iterator:
                # send batch to device
                batch = batch_to_device(batch, device)
                batch_x = batch[:-1]
                batch_y = batch[-1]

                # compute predictions
                pred_val = model(*batch_x)

                # compute loss (validation step => no backprop)
                loss_val = criterion(pred_val, batch_y)
                losses_val += loss_val.item()

                pred_val = torch.round(pred_val)
                # number of predictions
                d_val += pred_val.shape[0]
                # number of correct predictions
                n_val += (batch_y == pred_val).int().sum().item()

        # compute accuracy (train + val)
        loss_train = losses_train / d_train
        loss_val = losses_val / d_val

        acc_train = n_train / d_train
        acc_val = n_val / d_val

        # log losses and accuracies
        losses["train"].append(loss_train)
        losses["val"].append(loss_val)

        accuracies["train"].append(acc_train)
        accuracies["val"].append(acc_val)

        # store checkpoint
        if checkpoint:
            checkpoint.save(model, opt, epoch, losses, accuracies)

        print(
            f"Epoch {epoch} \t T. Loss = {loss_train:.4f}, V. Loss = {loss_val:.4f}, T. Accuracy {acc_train:.2f}, V. Accuracy {acc_val:.2f}."
        )

In [28]:
train_path = 'data/train.jsonl'
dev_path = 'data/dev.jsonl'

random.seed(42)
marker = ''.join(random.choices(string.ascii_lowercase, k=20))

embedder = WeightedAverageEmbedder(word_vectors, 1, 0.2)

train_dataset = EmbeddedDataset(train_path, marker, embedder)
val_dataset = EmbeddedDataset(dev_path, marker, embedder)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=128, shuffle=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = nn.BCELoss()
model = MLP(n_features=600,
            num_layers=5, 
            hidden_dim=200, 
            activation=torch.nn.functional.relu).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.00001)

checkpoint = Checkpoint(path='checkpoints/mlp')

epochs = 50

fit(epochs, model, criterion, optimizer, train_loader, val_loader, checkpoint)

Epoch 1/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 1 	 T. Loss = 0.0055, V. Loss = 0.0055, T. Accuracy 0.50, V. Accuracy 0.50.


Epoch 2/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 2/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 2 	 T. Loss = 0.0055, V. Loss = 0.0055, T. Accuracy 0.50, V. Accuracy 0.50.


Epoch 3/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 3/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 3 	 T. Loss = 0.0055, V. Loss = 0.0055, T. Accuracy 0.51, V. Accuracy 0.58.


Epoch 4/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 4/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 4 	 T. Loss = 0.0055, V. Loss = 0.0055, T. Accuracy 0.52, V. Accuracy 0.58.


Epoch 5/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 5/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 5 	 T. Loss = 0.0054, V. Loss = 0.0055, T. Accuracy 0.54, V. Accuracy 0.57.


Epoch 6/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 6/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 6 	 T. Loss = 0.0054, V. Loss = 0.0054, T. Accuracy 0.57, V. Accuracy 0.59.


Epoch 7/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 7/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 7 	 T. Loss = 0.0053, V. Loss = 0.0053, T. Accuracy 0.60, V. Accuracy 0.60.


Epoch 8/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 8/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 8 	 T. Loss = 0.0051, V. Loss = 0.0053, T. Accuracy 0.62, V. Accuracy 0.61.


Epoch 9/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 9/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 9 	 T. Loss = 0.0050, V. Loss = 0.0053, T. Accuracy 0.64, V. Accuracy 0.61.


Epoch 10/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 10/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 10 	 T. Loss = 0.0049, V. Loss = 0.0053, T. Accuracy 0.66, V. Accuracy 0.62.


Epoch 11/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 11/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 11 	 T. Loss = 0.0048, V. Loss = 0.0052, T. Accuracy 0.68, V. Accuracy 0.62.


Epoch 12/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 12/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 12 	 T. Loss = 0.0047, V. Loss = 0.0052, T. Accuracy 0.69, V. Accuracy 0.63.


Epoch 13/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 13/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 13 	 T. Loss = 0.0045, V. Loss = 0.0052, T. Accuracy 0.71, V. Accuracy 0.64.


Epoch 14/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 14/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 14 	 T. Loss = 0.0043, V. Loss = 0.0052, T. Accuracy 0.72, V. Accuracy 0.63.


Epoch 15/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 15/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 15 	 T. Loss = 0.0042, V. Loss = 0.0054, T. Accuracy 0.74, V. Accuracy 0.63.


Epoch 16/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 16/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 16 	 T. Loss = 0.0040, V. Loss = 0.0054, T. Accuracy 0.75, V. Accuracy 0.64.


Epoch 17/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 17/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 17 	 T. Loss = 0.0039, V. Loss = 0.0054, T. Accuracy 0.77, V. Accuracy 0.65.


Epoch 18/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 18/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 18 	 T. Loss = 0.0037, V. Loss = 0.0055, T. Accuracy 0.79, V. Accuracy 0.65.


Epoch 19/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 19/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 19 	 T. Loss = 0.0036, V. Loss = 0.0057, T. Accuracy 0.80, V. Accuracy 0.64.


Epoch 20/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 20/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 20 	 T. Loss = 0.0034, V. Loss = 0.0059, T. Accuracy 0.81, V. Accuracy 0.64.


Epoch 21/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 21/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 21 	 T. Loss = 0.0032, V. Loss = 0.0060, T. Accuracy 0.82, V. Accuracy 0.64.


Epoch 22/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 22/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 22 	 T. Loss = 0.0031, V. Loss = 0.0062, T. Accuracy 0.83, V. Accuracy 0.64.


Epoch 23/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 23/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 23 	 T. Loss = 0.0029, V. Loss = 0.0064, T. Accuracy 0.85, V. Accuracy 0.63.


Epoch 24/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 24/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 24 	 T. Loss = 0.0028, V. Loss = 0.0063, T. Accuracy 0.85, V. Accuracy 0.64.


Epoch 25/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 25/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 25 	 T. Loss = 0.0026, V. Loss = 0.0068, T. Accuracy 0.86, V. Accuracy 0.63.


Epoch 26/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 26/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 26 	 T. Loss = 0.0026, V. Loss = 0.0067, T. Accuracy 0.87, V. Accuracy 0.64.


Epoch 27/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 27/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 27 	 T. Loss = 0.0024, V. Loss = 0.0069, T. Accuracy 0.87, V. Accuracy 0.63.


Epoch 28/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 28/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 28 	 T. Loss = 0.0022, V. Loss = 0.0076, T. Accuracy 0.90, V. Accuracy 0.62.


Epoch 29/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 29/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 29 	 T. Loss = 0.0021, V. Loss = 0.0077, T. Accuracy 0.90, V. Accuracy 0.62.


Epoch 30/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 30/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 30 	 T. Loss = 0.0021, V. Loss = 0.0077, T. Accuracy 0.90, V. Accuracy 0.63.


Epoch 31/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 31/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 31 	 T. Loss = 0.0020, V. Loss = 0.0078, T. Accuracy 0.90, V. Accuracy 0.63.


Epoch 32/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 32/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 32 	 T. Loss = 0.0018, V. Loss = 0.0081, T. Accuracy 0.92, V. Accuracy 0.63.


Epoch 33/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 33/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 33 	 T. Loss = 0.0018, V. Loss = 0.0086, T. Accuracy 0.92, V. Accuracy 0.62.


Epoch 34/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 34/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 34 	 T. Loss = 0.0017, V. Loss = 0.0088, T. Accuracy 0.92, V. Accuracy 0.63.


Epoch 35/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 35/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 35 	 T. Loss = 0.0016, V. Loss = 0.0089, T. Accuracy 0.92, V. Accuracy 0.61.


Epoch 36/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 36/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 36 	 T. Loss = 0.0015, V. Loss = 0.0096, T. Accuracy 0.92, V. Accuracy 0.61.


Epoch 37/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 37/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 37 	 T. Loss = 0.0014, V. Loss = 0.0095, T. Accuracy 0.93, V. Accuracy 0.63.


Epoch 38/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 38/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 38 	 T. Loss = 0.0014, V. Loss = 0.0093, T. Accuracy 0.93, V. Accuracy 0.61.


Epoch 39/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 39/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 39 	 T. Loss = 0.0012, V. Loss = 0.0100, T. Accuracy 0.94, V. Accuracy 0.62.


Epoch 40/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 40/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 40 	 T. Loss = 0.0012, V. Loss = 0.0101, T. Accuracy 0.94, V. Accuracy 0.63.


Epoch 41/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 41/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 41 	 T. Loss = 0.0011, V. Loss = 0.0106, T. Accuracy 0.95, V. Accuracy 0.62.


Epoch 42/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 42/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 42 	 T. Loss = 0.0011, V. Loss = 0.0107, T. Accuracy 0.95, V. Accuracy 0.61.


Epoch 43/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 43/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 43 	 T. Loss = 0.0010, V. Loss = 0.0113, T. Accuracy 0.95, V. Accuracy 0.61.


Epoch 44/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 44/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 44 	 T. Loss = 0.0010, V. Loss = 0.0109, T. Accuracy 0.96, V. Accuracy 0.62.


Epoch 45/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 45/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 45 	 T. Loss = 0.0010, V. Loss = 0.0110, T. Accuracy 0.95, V. Accuracy 0.62.


Epoch 46/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 46/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 46 	 T. Loss = 0.0010, V. Loss = 0.0109, T. Accuracy 0.96, V. Accuracy 0.62.


Epoch 47/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 47/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 47 	 T. Loss = 0.0009, V. Loss = 0.0117, T. Accuracy 0.96, V. Accuracy 0.61.


Epoch 48/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 48/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 48 	 T. Loss = 0.0008, V. Loss = 0.0120, T. Accuracy 0.96, V. Accuracy 0.61.


Epoch 49/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 49/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 49 	 T. Loss = 0.0008, V. Loss = 0.0115, T. Accuracy 0.96, V. Accuracy 0.63.


Epoch 50/50 (TRAIN):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 50/50 (VALID):   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 50 	 T. Loss = 0.0008, V. Loss = 0.0121, T. Accuracy 0.97, V. Accuracy 0.61.


# Reccurent Neural Networks

In [29]:
def index_dictionary(word_vectors):
    word_index = dict()
    vectors_store = []

    # pad token, index = 0
    vectors_store.append(torch.rand(300))

    # unk token, index = 1
    vectors_store.append(torch.rand(300))

    # save index for each word
    for word, vector in word_vectors.items():
        # skip unk token if present
        if word == '<unk>':
            continue
        word_index[word] = len(vectors_store)
        vectors_store.append(vector)

    word_index = defaultdict(lambda: 1, word_index)  # default dict returns 1 (unk token) when unknown word
    vectors_store = torch.stack(vectors_store)
    return word_index, vectors_store

In [30]:
def review2indices(review: str) -> torch.Tensor:
    return torch.tensor([word_index[word] for word in review.split(' ')], dtype=torch.long)

In [31]:
def rnn_collate_fn(
    data_elements: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] # list of (x, y) pairs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

    X1 = [el[0] for el in data_elements]  # list of index tensors
    X2 = [el[1] for el in data_elements]  # list of index tensors
    

    # sizes of the sentences
    # to implement the many-to-one strategy
    X1_lengths = torch.tensor([x.size(0) for x in X1], dtype=torch.long)
    X2_lengths = torch.tensor([x.size(0) for x in X2], dtype=torch.long)
    

    X1 = torch.nn.utils.rnn.pad_sequence(X1, batch_first=True, padding_value=0)  #  shape (batch_size x max_seq_len)
    X2 = torch.nn.utils.rnn.pad_sequence(X2, batch_first=True, padding_value=0)  #  shape (batch_size x max_seq_len)
    

    y = [el[2] for el in data_elements]
    y = torch.tensor(y)

    return X1, X1_lengths, X2, X2_lengths, y

In [32]:
class LSTMClassifier(nn.Module):
    def __init__(
        self,
        vectors_store: torch.Tensor,
        n_hidden: int
    ) -> None:
        super().__init__()

        # embedding layer
        self.embedding = torch.nn.Embedding.from_pretrained(vectors_store)

        # recurrent layer
        self.rnn = torch.nn.LSTM(input_size=vectors_store.size(1), hidden_size=n_hidden, num_layers=1, batch_first=True, bidirectional=True)

        # classification head

        self.lin1 = torch.nn.Linear(2 * (2 * n_hidden), 2 * (2 * n_hidden))
        self.lin2 = torch.nn.Linear(2 * (2 * n_hidden), 1)

        self.dropout = nn.Dropout(0.3)

        self.sigmoid = nn.Sigmoid()

        
    def _compute_embedding(self, X, X_length):
        # embedding words from indices
        embedding_out = self.embedding(X)

        # recurrent encoding
        recurrent_out = self.rnn(embedding_out)[0]
        # here we utilize the sequences length to retrieve the last token 
        # output for each sequence
        batch_size, seq_len, hidden_size = recurrent_out.shape

        # we flatten the recurrent output
        # now I have a long sequence of batch x seq_len vectors 
        flattened_out = recurrent_out.reshape(-1, hidden_size)
        
        # and we use a simple trick to compute a tensor of the indices 
        # of the last token in each batch element
        last_word_relative_indices = X_length - 1
        # tensor of the start offsets of each element in the batch
        sequences_offsets = torch.arange(batch_size, device=X.device) * seq_len
        # e.g. (0, 5, 10, 15, ) + ( 3, 2, 1, 4 ) = ( 3, 7, 11, 19 )
        summary_vectors_indices = sequences_offsets + last_word_relative_indices

        # finally we retrieve the vectors that should summarize every review.
        # (i.e. the last token in the sequence)
        summary_vectors = flattened_out[summary_vectors_indices]
        return summary_vectors
    
    
    def forward(
        self, 
        X1: torch.Tensor,
        X1_length: torch.Tensor,
        X2: torch.Tensor,
        X2_length: torch.Tensor,
    ) -> torch.Tensor:
        
        summary_vectors_1 = self._compute_embedding(X1, X1_length)
        summary_vectors_2 = self._compute_embedding(X2, X2_length)
        
        summary_vectors = torch.cat((summary_vectors_1, summary_vectors_2), dim=1)
        
        # return summary_vectors
        
        # now we can classify the reviews with a feedforward pass on the summary
        # vectors
        out = self.lin1(summary_vectors)
        out = torch.relu(out)
        out = self.dropout(out)
        
        out = self.lin2(out).squeeze(1)
        logits = out
        pred = self.sigmoid(logits)

        return pred

In [43]:
batch = next(iter(train_loader))

In [44]:
batch_x = batch[:-1]
batch_y = batch[-1]

In [60]:
model = LSTMClassifier(vectors_store, 300).to(device)


In [52]:
device = 'cpu'

In [61]:
model(*batch_x).shape

torch.Size([16])

In [33]:
class IndicesDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path: str):
        self.data = []

        self.create_dataset(dataset_path)
        
    
    def create_dataset(self, dataset_path: str) -> None:
        with jsonlines.open(dataset_path, 'r') as f:
            for i, line in enumerate(f.iter()):
                # load sentences
                start1 = int(line['start1'])
                start2 = int(line['start2'])
                end1 = int(line['end1'])
                end2 = int(line['end2'])
                s1 = line['sentence1']
                s2 = line['sentence2']
                lemma1 = s1[start1:end1]
                lemma2 = s2[start2:end2]
                

                # preprocessing
                s1 = preprocess(s1, lemma1)
                s2 = preprocess(s2, lemma2)
                
                # sentences to indices
                i1 = review2indices(s1)
                i2 = review2indices(s2)

                label = torch.tensor(1, dtype=torch.float32) if line['label'] == 'True' else torch.tensor(0, dtype=torch.float32)
                self.data.append((i1, i2, label))


    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data[idx]

In [37]:
train_path = 'data/train.jsonl'
dev_path = 'data/dev.jsonl'

word_index, vectors_store = index_dictionary(word_vectors)

train_dataset = IndicesDataset(train_path)
val_dataset = IndicesDataset(dev_path)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=rnn_collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=rnn_collate_fn)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = nn.BCELoss()

model = LSTMClassifier(vectors_store, 150).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.000)

checkpoint = Checkpoint(path='checkpoints/rnn')

epochs = 20

fit(epochs, model, criterion, optimizer, train_loader, val_loader, None)

Epoch 1/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 1/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1 	 T. Loss = 0.0433, V. Loss = 0.0435, T. Accuracy 0.51, V. Accuracy 0.55.


Epoch 2/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 2 	 T. Loss = 0.0430, V. Loss = 0.0432, T. Accuracy 0.55, V. Accuracy 0.56.


Epoch 3/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 3/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 3 	 T. Loss = 0.0421, V. Loss = 0.0425, T. Accuracy 0.58, V. Accuracy 0.59.


Epoch 4/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 4/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 4 	 T. Loss = 0.0399, V. Loss = 0.0404, T. Accuracy 0.64, V. Accuracy 0.63.


Epoch 5/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 5/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 5 	 T. Loss = 0.0371, V. Loss = 0.0408, T. Accuracy 0.68, V. Accuracy 0.64.


Epoch 6/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 6/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 6 	 T. Loss = 0.0347, V. Loss = 0.0426, T. Accuracy 0.72, V. Accuracy 0.62.


Epoch 7/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 7/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 7 	 T. Loss = 0.0323, V. Loss = 0.0444, T. Accuracy 0.75, V. Accuracy 0.61.


Epoch 8/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 8/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 8 	 T. Loss = 0.0298, V. Loss = 0.0446, T. Accuracy 0.77, V. Accuracy 0.62.


Epoch 9/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 9/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 9 	 T. Loss = 0.0270, V. Loss = 0.0484, T. Accuracy 0.81, V. Accuracy 0.60.


Epoch 10/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 10/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 10 	 T. Loss = 0.0241, V. Loss = 0.0502, T. Accuracy 0.83, V. Accuracy 0.62.


Epoch 11/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 11/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 11 	 T. Loss = 0.0213, V. Loss = 0.0547, T. Accuracy 0.86, V. Accuracy 0.60.


Epoch 12/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 12/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 12 	 T. Loss = 0.0182, V. Loss = 0.0590, T. Accuracy 0.88, V. Accuracy 0.60.


Epoch 13/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 13/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 13 	 T. Loss = 0.0153, V. Loss = 0.0620, T. Accuracy 0.90, V. Accuracy 0.59.


Epoch 14/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 14/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 14 	 T. Loss = 0.0125, V. Loss = 0.0701, T. Accuracy 0.93, V. Accuracy 0.59.


Epoch 15/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 15/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 15 	 T. Loss = 0.0099, V. Loss = 0.0753, T. Accuracy 0.95, V. Accuracy 0.59.


Epoch 16/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 16/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 16 	 T. Loss = 0.0078, V. Loss = 0.0831, T. Accuracy 0.96, V. Accuracy 0.59.


Epoch 17/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 17/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 17 	 T. Loss = 0.0057, V. Loss = 0.0913, T. Accuracy 0.98, V. Accuracy 0.58.


Epoch 18/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 18/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 18 	 T. Loss = 0.0042, V. Loss = 0.1016, T. Accuracy 0.98, V. Accuracy 0.59.


Epoch 19/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 19/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 19 	 T. Loss = 0.0031, V. Loss = 0.1099, T. Accuracy 0.99, V. Accuracy 0.59.


Epoch 20/20 (TRAIN):   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 20/20 (VALID):   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 20 	 T. Loss = 0.0022, V. Loss = 0.1192, T. Accuracy 0.99, V. Accuracy 0.59.
