In [1]:
import json
import numpy as np
import torch
from torch import nn
from pprint import pprint
import jsonlines
import os
import re
from tqdm.notebook import tqdm
import pickle
import string
import random


from typing import *

# Utils

In [2]:
def save_pickle(data: dict, path: str) -> None:
    with open(path, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

def load_pickle(path: str) -> dict:
    with open(path, 'rb') as f:
        return pickle.load(f)

In [2]:
# Saving / loading models
class Checkpoint:
    def __init__(self, path: str, resume=False):
        self.path = path
        os.makedirs(path, exist_ok=True)
        self.resume = resume

    def load(self, model, optimizer, id_path=""):
        if (not self.resume) and id_path == "":
            raise RuntimeError()
        if self.resume:
            id_path = sorted(os.listdir(self.path))[-1]
        self.checkpoint = torch.load(
            os.path.join(self.path, id_path), map_location=lambda storage, loc: storage
        )
        if self.checkpoint == None:
            raise RuntimeError("Checkpoint empty.")
        epoch = self.checkpoint["epoch"]
        model.load_state_dict(self.checkpoint["model_state_dict"])
        optimizer.load_state_dict(self.checkpoint["optimizer_state_dict"])
        losses = self.checkpoint["losses"]
        accuracies = self.checkpoint["accuracies"]
        return (model, optimizer, epoch, losses, accuracies)

    def save(self, model, optimizer, epoch, losses, accuracies):
        model_checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "losses": losses,
            "accuracies": accuracies
        }
        checkpoint_name = "{}.pth".format(str(epoch).zfill(3))
        complete_path = os.path.join(self.path, checkpoint_name)
        torch.save(model_checkpoint, complete_path)
        return

    def load_just_model(self, model, id_path=""):
        if self.resume:
            id_path = sorted(os.listdir(self.path))[-1]
        self.checkpoint = torch.load(
            os.path.join(self.path, id_path), map_location=lambda storage, loc: storage
        )
        if self.checkpoint == None:
            raise RuntimeError("Checkpoint empty.")
        model.load_state_dict(self.checkpoint["model_state_dict"])
        return model

In [3]:
def preprocess(sentence: str) -> str:
    # lowercase sentence
    sentence = sentence.lower()
    # remove punctuation
    sentence = re.sub('[^\w\s]', ' ', sentence)
    # replace multiple adjacent spaces with one single space
    sentence = re.sub(' +', ' ', sentence).strip()
    return sentence

In [4]:
def embeddings_dictionary(path: str) -> Dict[str, torch.Tensor]:
    word_vectors = dict()
    with open(path) as f:
        for i, line in tqdm(enumerate(f)):

            word, *vector = line.strip().split(' ')
            vector = torch.tensor([float(c) for c in vector])

            word_vectors[word] = vector
    return word_vectors

# Create word embedding with GloVe

In [5]:
def custom_tokenizer(sentence: str, marker: str) -> List[str]:
    tokens = sentence.split()
    for i, tk in enumerate(tokens):
        if marker in tk:
            target_position = i
            tokens[i] = tk[20:]
    return tokens, target_position

In [7]:
class TokensEmbedder(object):
    def __init__(self, word_vectors: Dict[str, torch.Tensor]) -> None:
        self.word_vectors = word_vectors
    
    def compute_embeddings(self, tokens: List[str]) -> List[torch.Tensor]:
        word_embeddings = []
        for w in tokens:
            word_embeddings.append(self.word_vectors[w] if w in self.word_vectors else self.word_vectors['<unk>'])
        
        return word_embeddings
    
    def aggregate_embeddings(self, tokens: List[str], target_position: int) -> torch.Tensor:
        pass
    
    def __call__(self, tokens: List[str], target_position: int) -> torch.Tensor:
        return self.aggregate_embeddings(tokens, target_position)

class AverageEmbedder(TokensEmbedder):
    def __init__(self, word_vectors: Dict[str, torch.Tensor]) -> None:
        self.word_vectors = word_vectors
    
    def aggregate_embeddings(self, tokens: List[str], target_position: int) -> torch.Tensor:
        embeddings = torch.stack(self.compute_embeddings(tokens))
        return torch.mean(embeddings, dim=0)

class WeightedAverageEmbedder(TokensEmbedder):
    def __init__(self, word_vectors: Dict[str, torch.Tensor], max_weight: float = 1.0, min_weight: float = 0.0) -> None:
        self.word_vectors = word_vectors
        self.max_weight = max_weight
        self.min_weight = min_weight
    
    def aggregate_embeddings(self, tokens: List[str], target_position: int) -> torch.Tensor:
        embeddings = torch.stack(self.compute_embeddings(tokens))
        # aliases for readibility
        n = len(embeddings)
        t = target_position

        # weights from 1 to 0
        weights = torch.linspace(self.max_weight, self.min_weight, n).unsqueeze(1)
        # weights = torch.exp(reversed(torch.arange(n, dtype=torch.float32))).unsqueeze(1)

        # weighted vector
        new_vectors = embeddings

        # weighted average

        # right of the target word
        new_vectors[t:] = new_vectors[t:] * weights[:n - t]
        # left of the target word
        new_vectors[:t] = new_vectors[:t] * reversed(weights[1:t + 1])

        # denominator (sum of the weights)
        weights_sum = weights[:n - t].sum() + weights[1:t + 1].sum()

        return new_vectors.sum(dim=0) / weights_sum

In [6]:
embedding_path = 'embeddings/glove.6B.300d.txt'

word_vectors = embeddings_dictionary(embedding_path)

# <unk> token
# https://stackoverflow.com/questions/49239941/what-is-unk-in-the-pretrained-glove-vector-files-e-g-glove-6b-50d-txt
# unk_embedding = '0.22418134 -0.28881392 0.13854356 0.00365387 -0.12870757 0.10243822 0.061626635 0.07318011 -0.061350107 -1.3477012 0.42037755 -0.063593924 -0.09683349 0.18086134 0.23704372 0.014126852 0.170096 -1.1491593 0.31497982 0.06622181 0.024687296 0.076693475 0.13851812 0.021302193 -0.06640582 -0.010336159 0.13523154 -0.042144544 -0.11938788 0.006948221 0.13333307 -0.18276379 0.052385733 0.008943111 -0.23957317 0.08500333 -0.006894406 0.0015864656 0.063391194 0.19177166 -0.13113557 -0.11295479 -0.14276934 0.03413971 -0.034278486 -0.051366422 0.18891625 -0.16673574 -0.057783455 0.036823478 0.08078679 0.022949161 0.033298038 0.011784158 0.05643189 -0.042776518 0.011959623 0.011552498 -0.0007971594 0.11300405 -0.031369694 -0.0061559738 -0.009043574 -0.415336 -0.18870236 0.13708843 0.005911723 -0.113035575 -0.030096142 -0.23908928 -0.05354085 -0.044904727 -0.20228513 0.0065645403 -0.09578946 -0.07391877 -0.06487607 0.111740574 -0.048649278 -0.16565254 -0.052037314 -0.078968436 0.13684988 0.0757494 -0.006275573 0.28693774 0.52017444 -0.0877165 -0.33010918 -0.1359622 0.114895485 -0.09744406 0.06269521 0.12118575 -0.08026362 0.35256687 -0.060017522 -0.04889904 -0.06828978 0.088740796 0.003964443 -0.0766291 0.1263925 0.07809314 -0.023164088 -0.5680669 -0.037892066 -0.1350967 -0.11351585 -0.111434504 -0.0905027 0.25174105 -0.14841858 0.034635577 -0.07334565 0.06320108 -0.038343467 -0.05413284 0.042197507 -0.090380974 -0.070528865 -0.009174437 0.009069661 0.1405178 0.02958134 -0.036431845 -0.08625681 0.042951006 0.08230793 0.0903314 -0.12279937 -0.013899368 0.048119213 0.08678239 -0.14450377 -0.04424887 0.018319942 0.015026873 -0.100526 0.06021201 0.74059093 -0.0016333034 -0.24960588 -0.023739101 0.016396184 0.11928964 0.13950661 -0.031624354 -0.01645025 0.14079992 -0.0002824564 -0.08052984 -0.0021310581 -0.025350995 0.086938225 0.14308536 0.17146006 -0.13943303 0.048792403 0.09274929 -0.053167373 0.031103406 0.012354865 0.21057427 0.32618305 0.18015954 -0.15881181 0.15322933 -0.22558987 -0.04200665 0.0084689725 0.038156632 0.15188617 0.13274793 0.113756925 -0.095273495 -0.049490947 -0.10265804 -0.27064866 -0.034567792 -0.018810693 -0.0010360252 0.10340131 0.13883452 0.21131058 -0.01981019 0.1833468 -0.10751636 -0.03128868 0.02518242 0.23232952 0.042052146 0.11731903 -0.15506615 0.0063580726 -0.15429358 0.1511722 0.12745973 0.2576985 -0.25486213 -0.0709463 0.17983761 0.054027 -0.09884228 -0.24595179 -0.093028545 -0.028203879 0.094398156 0.09233813 0.029291354 0.13110267 0.15682974 -0.016919162 0.23927948 -0.1343307 -0.22422817 0.14634751 -0.064993896 0.4703685 -0.027190214 0.06224946 -0.091360025 0.21490277 -0.19562101 -0.10032754 -0.09056772 -0.06203493 -0.18876675 -0.10963594 -0.27734384 0.12616494 -0.02217992 -0.16058226 -0.080475815 0.026953284 0.110732645 0.014894041 0.09416802 0.14299914 -0.1594008 -0.066080004 -0.007995227 -0.11668856 -0.13081996 -0.09237365 0.14741232 0.09180138 0.081735 0.3211204 -0.0036552632 -0.047030564 -0.02311798 0.048961394 0.08669574 -0.06766279 -0.50028914 -0.048515294 0.14144728 -0.032994404 -0.11954345 -0.14929578 -0.2388355 -0.019883996 -0.15917352 -0.052084364 0.2801028 -0.0029121689 -0.054581646 -0.47385484 0.17112483 -0.12066923 -0.042173345 0.1395337 0.26115036 0.012869649 0.009291686 -0.0026459037 -0.075331464 0.017840583 -0.26869613 -0.21820338 -0.17084768 -0.1022808 -0.055290595 0.13513643 0.12362477 -0.10980586 0.13980341 -0.20233242 0.08813751 0.3849736 -0.10653763 -0.06199595 0.028849555 0.03230154 0.023856193 0.069950655 0.19310954 -0.077677034 -0.144811'
# unk_embedding = unk_embedding.strip().split()
# word_vectors['<unk>'] = torch.tensor([float(c) for c in unk_embedding])
# word_vectors['<unk>'] = torch.rand(300)

0it [00:00, ?it/s]

In [16]:
save_pickle(word_vectors, 'embeddings/vocabulary_tensors.pkl')

# Dataset class using GloVe

In [18]:
class WiCDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path: str, marker: str, embedder: TokensEmbedder):
        self.data = []
        self.marker = marker
        self.embedder = embedder

        self.create_dataset(dataset_path)
        
    
    def create_dataset(self, dataset_path: str) -> None:
        with jsonlines.open(dataset_path, 'r') as f:
            for i, line in enumerate(f.iter()):
                # load sentences
                start1 = int(line['start1'])
                start2 = int(line['start2'])
                s1 = line['sentence1']
                s2 = line['sentence2']
                # insert special characters to locate target word after preprocessing
                s1 = s1[:start1] + self.marker + s1[start1:]
                s2 = s2[:start2] + self.marker + s2[start2:]
                
                # preprocessing
                s1 = preprocess(s1)
                s2 = preprocess(s2)
                
                # tokenization
                t1, target_position1 = custom_tokenizer(s1, self.marker)
                t2, target_position2 = custom_tokenizer(s2, self.marker)
                
                # convert tokens to embeddings and aggregate
                v1 = self.embedder(t1, target_position1)
                v2 = self.embedder(t2, target_position2)
                
                # concatenate vectors
                sentence_vector = torch.cat((v1, v2))
                
                label = torch.tensor(1, dtype=torch.float32) if line['label'] == 'True' else torch.tensor(0, dtype=torch.float32)
                self.data.append((sentence_vector, label))


    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data[idx]

# Model Class

In [19]:
class MLP(nn.Module):
    def __init__(
        self,
        n_features: int,
        num_layers: int,
        hidden_dim: int,
        activation: Callable[[torch.Tensor], torch.Tensor],
    ) -> None:
        super().__init__()

        self.first_layer = nn.Linear(in_features=n_features, out_features=hidden_dim)

        self.layers = (
            nn.ModuleList()
        )

        for i in range(num_layers):
            self.layers.append(
                nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
            )
            
        self.activation = activation
        
        self.last_layer = nn.Linear(in_features=hidden_dim, out_features=1)
        
        self.sigmoid = nn.Sigmoid()
        
        self.dropout = nn.Dropout(0.3)

    def forward(self, meshgrid: torch.Tensor) -> torch.Tensor:
        out = meshgrid

        out = self.first_layer(
            out
        )  # First linear layer, transforms the hidden dimensions from `n_features` (embedding dimension) to `hidden_dim`
        for layer in self.layers:  # Apply `k` (linear, activation) layer
            out = layer(out)
            out = self.activation(out)
            out = self.dropout(out)
        out = self.last_layer(
            out
        )  # Last linear layer to bring the `hiddem_dim` features to a binary space (`True`/`False`)
        
        out = self.sigmoid(out)
        return out.squeeze(-1)


# Training process

In [20]:
def fit(epochs: int,
        model: nn.Module,
        criterion: nn.Module,
        opt: torch.optim.Optimizer,
        train_dl: torch.utils.data.DataLoader,
        valid_dl: torch.utils.data.DataLoader,
        checkpoint: Checkpoint = None
       ) -> None:

    losses = {'train': [], 'val': []}
    accuracies = {'train': [], 'val': []}

    for epoch in tqdm(range(epochs)):
        losses_train = 0
        d_train, n_train = 0, 0

        model.train()
        for xb, yb in train_dl:
            xb = xb.to(device)
            yb = yb.to(device)
            
            pred = model(xb)
            
            loss_train = criterion(pred, yb)
            loss_train.backward()
            opt.step()
            opt.zero_grad()
            losses_train += loss_train.item()
            
            pred = torch.round(pred)
            # number of predictions
            d_train += pred.shape[0]
            # number of correct predictions
            n_train += (yb == pred).int().sum()
            
            
            
        model.eval()
        with torch.no_grad():
            losses_val = 0
            d_val, n_val = 0, 0
            for xb, yb in valid_dl:
                xb = xb.to(device)
                yb = yb.to(device)
                
                pred = model(xb)

                loss_val = criterion(pred, yb)
                losses_val += loss_val.item()
                

                pred = torch.round(pred)
                # number of predictions
                d_val += pred.shape[0]
                # number of correct predictions
                n_val += (yb == pred).int().sum().item()
                
        loss_train = losses_train / d_train
        loss_val = losses_val / d_val
        
        acc_train = n_train / d_train
        acc_val = n_val / d_val
        
        losses['train'].append(loss_train)
        losses['val'].append(loss_val)
        
        accuracies['train'].append(acc_train)
        accuracies['val'].append(acc_val)

        if checkpoint:
            checkpoint.save(model, opt, epoch, losses, accuracies)

        print(
            f"Epoch {epoch} \t T. Loss = {loss_train:.4f}, V. Loss = {loss_val:.4f}, T. Accuracy {acc_train:.2f}, V. Accuracy {acc_val:.2f}."
        )

In [21]:
train_path = 'data/train.jsonl'
dev_path = 'data/dev.jsonl'

random.seed(42)
marker = ''.join(random.choices(string.ascii_lowercase, k=20))

embedder = WeightedAverageEmbedder(word_vectors, 1, 0)

train_dataset = WiCDataset(train_path, marker, embedder)
val_dataset = WiCDataset(dev_path, marker, embedder)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=128, shuffle=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = nn.BCELoss()
model = MLP(n_features=600,
            num_layers=5, 
            hidden_dim=150, 
            activation=torch.nn.functional.relu).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.00001)

checkpoint = Checkpoint(path='checkpoints')

fit(50, model, criterion, optimizer, train_loader, val_loader, checkpoint)

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 0 	 T. Loss = 0.0055, V. Loss = 0.0055, T. Accuracy 0.50, V. Accuracy 0.50.
Epoch 1 	 T. Loss = 0.0055, V. Loss = 0.0055, T. Accuracy 0.50, V. Accuracy 0.50.
Epoch 2 	 T. Loss = 0.0055, V. Loss = 0.0055, T. Accuracy 0.51, V. Accuracy 0.52.
Epoch 3 	 T. Loss = 0.0055, V. Loss = 0.0055, T. Accuracy 0.50, V. Accuracy 0.57.
Epoch 4 	 T. Loss = 0.0055, V. Loss = 0.0055, T. Accuracy 0.50, V. Accuracy 0.53.
Epoch 5 	 T. Loss = 0.0055, V. Loss = 0.0055, T. Accuracy 0.51, V. Accuracy 0.58.
Epoch 6 	 T. Loss = 0.0054, V. Loss = 0.0055, T. Accuracy 0.53, V. Accuracy 0.57.
Epoch 7 	 T. Loss = 0.0054, V. Loss = 0.0055, T. Accuracy 0.56, V. Accuracy 0.57.
Epoch 8 	 T. Loss = 0.0054, V. Loss = 0.0054, T. Accuracy 0.56, V. Accuracy 0.57.
Epoch 9 	 T. Loss = 0.0053, V. Loss = 0.0054, T. Accuracy 0.58, V. Accuracy 0.57.
Epoch 10 	 T. Loss = 0.0053, V. Loss = 0.0054, T. Accuracy 0.59, V. Accuracy 0.57.
Epoch 11 	 T. Loss = 0.0052, V. Loss = 0.0054, T. Accuracy 0.61, V. Accuracy 0.59.
Epoch 12 	 T. 

# Reccurent Neural Networks

In [7]:
from collections import defaultdict

word_index = dict()
vectors_store = []

# pad token, index = 0
vectors_store.append(torch.rand(300))

# unk token, index = 1
vectors_store.append(torch.rand(300))

# save index for each word
for word, vector in word_vectors.items():
    word_index[word] = len(vectors_store)
    vectors_store.append(vector)

word_index = defaultdict(lambda: 1, word_index)  # default dict returns 1 (unk token) when unknown word
vectors_store = torch.stack(vectors_store)

In [8]:
def review2indices(review: str) -> torch.Tensor:
    return torch.tensor([word_index[word] for word in review.split(' ')], dtype=torch.long)

In [9]:
def rnn_collate_fn(
    data_elements: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] # list of (x, y) pairs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

    X1 = [el[0] for el in data_elements]  # list of index tensors
    X2 = [el[1] for el in data_elements]  # list of index tensors
    

    # sizes of the sentences
    # to implement the many-to-one strategy
    X1_lengths = torch.tensor([x.size(0) for x in X1], dtype=torch.long)
    X2_lengths = torch.tensor([x.size(0) for x in X2], dtype=torch.long)
    

    X1 = torch.nn.utils.rnn.pad_sequence(X1, batch_first=True, padding_value=0)  #  shape (batch_size x max_seq_len)
    X2 = torch.nn.utils.rnn.pad_sequence(X2, batch_first=True, padding_value=0)  #  shape (batch_size x max_seq_len)
    

    y = [el[2] for el in data_elements]
    y = torch.tensor(y)

    return X1, X1_lengths, X2, X2_lengths, y

In [120]:
strings1 = ['the cat is on the table', 'hello old friend', 'hi my name is michele what is yours', 'brother']
labels1 = [torch.tensor(0) for _ in range(len(strings))]

X1, X1_length, y1 = rnn_collate_fn(list(zip([review2indices(string) for string in strings1], labels1)))

In [126]:
strings2 = ['my father is happy', 'hello man ', 'sister and up', "cause because chicken"]
labels2 = [torch.tensor(0) for _ in range(len(strings))]

X2, X2_length, y2 = rnn_collate_fn(list(zip([review2indices(string) for string in strings2], labels2)))

In [23]:
class LSTMClassifier(nn.Module):
    def __init__(
        self,
        vectors_store: torch.Tensor,
        n_hidden: int
    ) -> None:
        super().__init__()

        # embedding layer
        self.embedding = torch.nn.Embedding.from_pretrained(vectors_store)

        # recurrent layer
        self.rnn = torch.nn.LSTM(input_size=vectors_store.size(1), hidden_size=n_hidden, num_layers=1, batch_first=True)

        # classification head
        self.lin1 = torch.nn.Linear(2 * n_hidden, 2 * n_hidden)
        self.lin2 = torch.nn.Linear(2 * n_hidden, 1)
        
        self.sigmoid = nn.Sigmoid()
        
        
    def _compute_embedding(self, X, X_length):
        # embedding words from indices
        embedding_out = self.embedding(X)

        # recurrent encoding
        recurrent_out = self.rnn(embedding_out)[0]
        # here we utilize the sequences length to retrieve the last token 
        # output for each sequence
        batch_size, seq_len, hidden_size = recurrent_out.shape

        # we flatten the recurrent output
        # now I have a long sequence of batch x seq_len vectors 
        flattened_out = recurrent_out.reshape(-1, hidden_size)
        
        # and we use a simple trick to compute a tensor of the indices 
        # of the last token in each batch element
        last_word_relative_indices = X_length - 1
        # tensor of the start offsets of each element in the batch
        sequences_offsets = torch.arange(batch_size, device='cuda:0') * seq_len
        # e.g. (0, 5, 10, 15, ) + ( 3, 2, 1, 4 ) = ( 3, 7, 11, 19 )
        summary_vectors_indices = sequences_offsets + last_word_relative_indices

        # finally we retrieve the vectors that should summarize every review.
        # (i.e. the last token in the sequence)
        summary_vectors = flattened_out[summary_vectors_indices]
        return summary_vectors
    
    
    def forward(
        self, 
        X1: torch.Tensor,
        X1_length: torch.Tensor,
        X2: torch.Tensor,
        X2_length: torch.Tensor,
    ) -> torch.Tensor:
        
        summary_vectors_1 = self._compute_embedding(X1, X1_length)
        summary_vectors_2 = self._compute_embedding(X2, X2_length)
        
        summary_vectors = torch.cat((summary_vectors_1, summary_vectors_2), dim=1)

        # now we can classify the reviews with a feedforward pass on the summary
        # vectors
        out = self.lin1(summary_vectors)
        out = torch.relu(out)
        out = self.lin2(out).squeeze(1)

        # compute logits (which are simply the out variable)
        # and the actual probability distribution (pred, as it is the predicted distribution)
        logits = out
        pred = self.sigmoid(logits)

        return pred

In [24]:
class EmbeddingDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path: str):
        self.data = []

        self.create_dataset(dataset_path)
        
    
    def create_dataset(self, dataset_path: str) -> None:
        with jsonlines.open(dataset_path, 'r') as f:
            for i, line in enumerate(f.iter()):
                # load sentences
                start1 = int(line['start1'])
                start2 = int(line['start2'])
                s1 = line['sentence1']
                s2 = line['sentence2']

                # preprocessing
                s1 = preprocess(s1)
                s2 = preprocess(s2)
                
                # sentences to indices
                i1 = review2indices(s1)
                i2 = review2indices(s2)

                label = torch.tensor(1, dtype=torch.float32) if line['label'] == 'True' else torch.tensor(0, dtype=torch.float32)
                self.data.append((i1, i2, label))


    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data[idx]

In [25]:
def fit_rnn(epochs: int,
            model: nn.Module,
            criterion: nn.Module,
            opt: torch.optim.Optimizer,
            train_dl: torch.utils.data.DataLoader,
            valid_dl: torch.utils.data.DataLoader,
            checkpoint: Checkpoint = None
           ) -> None:

    losses = {'train': [], 'val': []}
    accuracies = {'train': [], 'val': []}

    for epoch in tqdm(range(epochs)):
        losses_train = 0
        d_train, n_train = 0, 0

        model.train()
        for x1, x1_length, x2, x2_length, yb in train_dl:
            x1 = x1.to(device)
            x2 = x2.to(device)
            x1_length = x1_length.to(device)
            x2_length = x2_length.to(device)
            
            yb = yb.to(device)

            pred = model(x1, x1_length, x2, x2_length)
            
            loss_train = criterion(pred, yb)
            loss_train.backward()
            opt.step()
            opt.zero_grad()
            losses_train += loss_train.item()
            
            pred = torch.round(pred)
            # number of predictions
            d_train += pred.shape[0]
            # number of correct predictions
            n_train += (yb == pred).int().sum()
            
            
            
        model.eval()
        with torch.no_grad():
            losses_val = 0
            d_val, n_val = 0, 0
            for x1, x1_length, x2, x2_length, yb in train_dl:
                x1 = x1.to(device)
                x2 = x2.to(device)
                x1_length = x1_length.to(device)
                x2_length = x2_length.to(device)
                yb = yb.to(device)
                
                pred = model(x1, x1_length, x2, x2_length)

                loss_val = criterion(pred, yb)
                losses_val += loss_val.item()
                

                pred = torch.round(pred)
                # number of predictions
                d_val += pred.shape[0]
                # number of correct predictions
                n_val += (yb == pred).int().sum().item()
                
        loss_train = losses_train / d_train
        loss_val = losses_val / d_val
        
        acc_train = n_train / d_train
        acc_val = n_val / d_val
        
        losses['train'].append(loss_train)
        losses['val'].append(loss_val)
        
        accuracies['train'].append(acc_train)
        accuracies['val'].append(acc_val)

        if checkpoint:
            checkpoint.save(model, opt, epoch, losses, accuracies)

        print(
            f"Epoch {epoch} \t T. Loss = {loss_train:.4f}, V. Loss = {loss_val:.4f}, T. Accuracy {acc_train:.2f}, V. Accuracy {acc_val:.2f}."
        )

In [26]:
train_path = 'data/train.jsonl'
dev_path = 'data/dev.jsonl'

train_dataset = EmbeddingDataset(train_path)
val_dataset = EmbeddingDataset(dev_path)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=rnn_collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=128, shuffle=False, collate_fn=rnn_collate_fn)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = nn.BCELoss()

model = LSTMClassifier(vectors_store, 100).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.00001)

checkpoint = Checkpoint(path='checkpoints/rnn')

fit_rnn(50, model, criterion, optimizer, train_loader, val_loader, checkpoint)

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 0 	 T. Loss = 0.0055, V. Loss = 0.0055, T. Accuracy 0.50, V. Accuracy 0.53.
Epoch 1 	 T. Loss = 0.0055, V. Loss = 0.0054, T. Accuracy 0.53, V. Accuracy 0.56.
Epoch 2 	 T. Loss = 0.0054, V. Loss = 0.0054, T. Accuracy 0.55, V. Accuracy 0.57.
Epoch 3 	 T. Loss = 0.0054, V. Loss = 0.0054, T. Accuracy 0.56, V. Accuracy 0.57.
Epoch 4 	 T. Loss = 0.0054, V. Loss = 0.0054, T. Accuracy 0.57, V. Accuracy 0.59.
Epoch 5 	 T. Loss = 0.0054, V. Loss = 0.0054, T. Accuracy 0.58, V. Accuracy 0.59.
Epoch 6 	 T. Loss = 0.0054, V. Loss = 0.0053, T. Accuracy 0.58, V. Accuracy 0.60.
Epoch 7 	 T. Loss = 0.0053, V. Loss = 0.0053, T. Accuracy 0.59, V. Accuracy 0.59.
Epoch 8 	 T. Loss = 0.0053, V. Loss = 0.0053, T. Accuracy 0.60, V. Accuracy 0.61.
Epoch 9 	 T. Loss = 0.0052, V. Loss = 0.0052, T. Accuracy 0.60, V. Accuracy 0.60.
Epoch 10 	 T. Loss = 0.0052, V. Loss = 0.0051, T. Accuracy 0.61, V. Accuracy 0.63.
Epoch 11 	 T. Loss = 0.0051, V. Loss = 0.0050, T. Accuracy 0.63, V. Accuracy 0.64.
Epoch 12 	 T. 