In [None]:
import torch
import torch.nn as nn
import torchtext
from torchtext import data
from torchtext.data import Field, Iterator
from torchtext import datasets
import torch.nn.functional as F
import random
import re
import time
import numpy as np
import spacy
from spacy.tokenizer import Tokenizer

In [None]:
# Fetch the spaCy English model that `spacy.load("en")` loads in a later cell.
# NOTE(review): the `en` shortcut name presumably requires spaCy 2.x (newer
# releases expect a full package name such as `en_core_web_sm`) — confirm the
# installed version before re-running.
!python -m spacy download en

In [None]:
# Mount Google Drive under /content/drive; the dataset, GloVe and checkpoint
# paths used by other cells all live below that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Load the spaCy English pipeline, then build a standalone Tokenizer that is
# given only the pipeline's vocab.
# NOTE(review): no prefix/suffix/infix rules are passed, so this tokenizer
# presumably splits on whitespace only — confirm against the spaCy docs.
nlp = spacy.load("en")
tokenizer = Tokenizer(nlp.vocab)

In [None]:
def spacy_tokenize(x):
    """Clean a raw sentence and split it into a list of token strings.

    The input is coerced to ``str``. Every character listed in the first
    character class is replaced by a space, runs of spaces are squeezed to a
    single space, and repeated ``!``, ``,`` and ``?`` are collapsed to one
    occurrence. The cleaned text is handed to the module-level ``tokenizer``
    and tokens whose text is exactly one space are discarded.
    """
    text = re.sub(
        r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ",
        str(x))
    text = re.sub(r"[ ]+", " ", text)
    # `!` has already been blanked by the character class above, so the first
    # pair never matches; it is kept so the cleaning steps stay unchanged.
    for pattern, single in ((r"\!+", "!"), (r"\,+", ","), (r"\?+", "?")):
        text = re.sub(pattern, single, text)

    tokens = []
    for tok in tokenizer(text):
        if tok.text != " ":
            tokens.append(tok.text)
    return tokens

In [None]:
# Single seed shared by every random number generator used in the notebook.
SEED = 1234

# Seed Python's, NumPy's and PyTorch's generators so runs are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN to pick deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [None]:
# Directory on the mounted Drive that holds the MNLI JSONL splits.
path = ('/content/drive/My Drive/Colab Notebooks/NLI/NLI_Datasets/')

# Premise and hypothesis share one TEXT field: lower-cased, tokenized by
# spacy_tokenize, and batched with the batch dimension first.
TEXT = data.Field(lower=True, tokenize = spacy_tokenize, batch_first = True)
# Non-sequential target field for the gold label.
LABEL = data.LabelField(sequential=False, is_target = True)

# JSON key -> (attribute name on each example/batch, field that processes it).
fields = {'sentence1': ('premise', TEXT),
          'sentence2': ('hypothesis', TEXT),
          'gold_label': ('label', LABEL)}

# NOTE(review): the name `train` is rebound later in this notebook by
# `def train(...)`; cells that need the training dataset must run before that
# definition, and re-running them afterwards will receive the function instead.
train, dev, test = data.TabularDataset.splits(
    path=path, 
    train=('mnli_train.jsonl'),
    validation=('mnli_dev.jsonl'),
    test=('mnli_test.jsonl'),
    format='json', 
    fields=fields
)

In [None]:
# Sanity check: report how many examples each split contains.
print(f'Number of training examples: {len(train)}')
print(f'Number of valid examples: {len(dev)}')
print(f'Number of testing examples: {len(test)}')

In [None]:
# Show the attributes stored on the first training example.
print(vars(train.examples[0]))

In [None]:
# Show the attributes stored on the first test example.
print(vars(test.examples[0]))

In [None]:
# Build the token vocabulary from the train and dev splits, keeping tokens
# that occur at least twice, and attach vectors read from the GloVe file on
# Drive (presumably 300-dimensional, per the file name — confirm). Tokens
# absent from the vector file are initialised with torch.Tensor.normal_.
# NOTE(review): dev is included when building the vocabulary — confirm that
# this is intended rather than building from train only.
TEXT.build_vocab(train, dev, min_freq=2, vectors=torchtext.vocab.Vectors('/content/drive/My Drive/Colab Notebooks/NLI/Glove/glove.6B.300d.txt', unk_init=torch.Tensor.normal_))

In [None]:
# Build the label vocabulary from the training split only.
LABEL.build_vocab(train)

In [None]:
# Report the sizes of the token and label vocabularies.
print(f"Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}")
print(f"Unique tokens in LABEL vocabulary: {len(LABEL.vocab)}")

In [None]:
# Ten most frequent tokens together with their counts.
print(TEXT.vocab.freqs.most_common(10))

In [None]:
# First ten entries of the index-to-string table.
print(TEXT.vocab.itos[:10])

In [None]:
# Label strings in index order (the index is the class id the model predicts).
print(LABEL.vocab.itos)

In [None]:
# Label counts in the data the label vocabulary was built from.
print(LABEL.vocab.freqs.most_common())

In [None]:
# Number of examples per batch for every split.
BATCH_SIZE = 64

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One iterator per split; batches are created on `device`, and sort=False is
# passed for all three splits.
train_iter, dev_iter, test_iter = data.BucketIterator.splits(
    (train, dev, test), batch_size=BATCH_SIZE, sort=False, device=device)

In [None]:
# Hyperparameters read as globals by the BiLSTM class and the optimizer cell.
INPUT_DIM = len(TEXT.vocab)      # rows of the embedding table
EMBEDDING_DIM = 300              # embedding width; presumably must match the GloVe vectors — confirm
HIDDEN_DIM = 200                 # width of the projection, each LSTM direction, and the MLP hidden layers
OUTPUT_DIM = len(LABEL.vocab)    # number of classes
DP_RATIO = 0.2                   # dropout probability (between LSTM layers and in the classifier)
LEARN_RATE = 0.001               # learning rate handed to Adam

In [None]:
class BiLSTM(nn.Module):
    """Sentence-pair classifier built on a shared bidirectional LSTM encoder.

    Premise and hypothesis go through the same embedding, ReLU projection and
    2-layer bidirectional LSTM. The top layer's final forward and backward
    hidden states of each sentence are concatenated into one vector ``u`` / ``v``
    and the classifier receives ``[u, v, |u - v|, u * v]``.

    Sizes are read from the notebook-level constants ``INPUT_DIM``,
    ``EMBEDDING_DIM``, ``HIDDEN_DIM``, ``OUTPUT_DIM`` and ``DP_RATIO``.
    """

    def __init__(self):
        super().__init__()
        self.embed_dim = EMBEDDING_DIM
        self.hidden_size = HIDDEN_DIM
        self.directions = 2   # bidirectional LSTM
        self.num_layers = 2
        self.concat = 4       # u, v, |u - v|, u * v
        # Kept for backward compatibility; forward() no longer depends on it.
        self.device = device
        self.embedding = nn.Embedding(INPUT_DIM, EMBEDDING_DIM)
        self.projection = nn.Linear(self.embed_dim, self.hidden_size)
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size, self.num_layers,
                                    bidirectional = True, batch_first = True, dropout = DP_RATIO)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p = DP_RATIO)

        self.lin1 = nn.Linear(self.hidden_size * self.directions * self.concat, self.hidden_size)
        self.lin2 = nn.Linear(self.hidden_size, self.hidden_size)
        self.lin3 = nn.Linear(self.hidden_size, OUTPUT_DIM)

        for lin in [self.lin1, self.lin2, self.lin3]:
            nn.init.xavier_uniform_(lin.weight)
            nn.init.zeros_(lin.bias)

        self.out = nn.Sequential(
            self.lin1,
            self.relu,
            self.dropout,
            self.lin2,
            self.relu,
            self.dropout,
            self.lin3
        )

    def _encode(self, tokens):
        """Encode a batch of token-id sequences into one vector per sentence.

        Args:
            tokens: integer tensor of token ids with the batch dimension first.

        Returns:
            Tensor of shape ``(batch, directions * hidden_size)``.
        """
        projected = self.relu(self.projection(self.embedding(tokens)))
        # No explicit (h0, c0): nn.LSTM starts from zeros created on the same
        # device and dtype as its input, so the encoder follows the module
        # wherever it is moved instead of being pinned to the global `device`.
        # NOTE(review): sequences are not packed, so for padded batches the
        # final state presumably also covers pad positions — confirm whether
        # lengths should be used.
        _, (final_hidden, _) = self.lstm(projected)
        # final_hidden is (num_layers * directions, batch, hidden); the last
        # `directions` slices belong to the top layer.
        top_layer = final_hidden[-self.directions:]
        return top_layer.transpose(0, 1).contiguous().view(final_hidden.size(1), -1)

    def forward(self, batch):
        """Return unnormalised class scores for a batch.

        Args:
            batch: object exposing ``premise`` and ``hypothesis`` token-id
                tensors (batch dimension first).

        Returns:
            Tensor of shape ``(batch, OUTPUT_DIM)``.
        """
        premise = self._encode(batch.premise)
        hypothesis = self._encode(batch.hypothesis)

        combined = torch.cat((premise, hypothesis, torch.abs(premise - hypothesis), premise * hypothesis), 1)
        return self.out(combined)

In [None]:
# Instantiate the classifier; it is moved to `device` in a later cell.
model = BiLSTM()

In [None]:
def count_parameters(model):
    """Return the total number of elements in the parameters that require gradients."""
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Report the number of trainable parameters with thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')

In [None]:
# Matrix of pretrained vectors aligned with the TEXT vocabulary indices.
pretrained_embeddings = TEXT.vocab.vectors
print(pretrained_embeddings.shape)

# Initialise the model's embedding table from the pretrained vectors. Without
# this copy the vectors loaded by build_vocab are never used and the embedding
# layer keeps its random initialisation.
with torch.no_grad():
    model.embedding.weight.copy_(pretrained_embeddings)

In [None]:
# Adam over every model parameter.
optimizer = torch.optim.Adam(model.parameters(), lr = LEARN_RATE)
# The loss is summed (not averaged) over each batch; train/validate divide the
# accumulated total by the number of examples to report a per-example loss.
criterion = nn.CrossEntropyLoss(reduction = 'sum')

In [None]:
# Move the model's parameters and the loss module to the selected device.
model = model.to(device)
criterion = criterion.to(device)

In [None]:
def train(model, iterator, optimizer, criterion):
    """Run one training epoch over ``iterator``.

    Args:
        model: module called as ``model(batch)`` that returns class scores.
        iterator: iterable of batches exposing ``label`` and ``batch_size``,
            with an ``init_epoch()`` method.
        optimizer: optimizer stepping the model's parameters.
        criterion: loss called as ``criterion(scores, labels)``.

    Returns:
        ``(train_loss, train_acc)``: accumulated loss divided by the number of
        examples, and accuracy in percent. The loss is a true per-example mean
        only when ``criterion`` sums over the batch, as the notebook's
        ``reduction='sum'`` criterion does.

    Raises:
        ZeroDivisionError: if ``iterator`` yields no examples.
    """
    model.train()
    # Use the iterator that was passed in rather than a notebook global, so
    # the function trains on whatever split the caller supplies.
    iterator.init_epoch()
    n_correct, n_total, n_loss = 0, 0, 0
    for batch in iterator:
        optimizer.zero_grad()
        answer = model(batch)
        loss = criterion(answer, batch.label)

        predictions = answer.argmax(dim=1).view(batch.label.size())
        n_correct += (predictions == batch.label).sum().item()
        n_total += batch.batch_size
        n_loss += loss.item()

        loss.backward()
        optimizer.step()
    train_loss = n_loss / n_total
    train_acc = 100. * n_correct / n_total
    return train_loss, train_acc

In [None]:
def validate(model, iterator, criterion):
    """Evaluate ``model`` on ``iterator`` without updating parameters.

    Args:
        model: module called as ``model(batch)`` that returns class scores.
        iterator: iterable of batches exposing ``label`` and ``batch_size``,
            with an ``init_epoch()`` method.
        criterion: loss called as ``criterion(scores, labels)``.

    Returns:
        ``(val_loss, val_acc)``: accumulated loss divided by the number of
        examples, and accuracy in percent.

    Raises:
        ZeroDivisionError: if ``iterator`` yields no examples.
    """
    model.eval()
    # Evaluate on the iterator that was passed in rather than a notebook
    # global, so the caller decides which split is measured.
    iterator.init_epoch()
    n_correct, n_total, n_loss = 0, 0, 0
    with torch.no_grad():
        for batch in iterator:
            answer = model(batch)
            loss = criterion(answer, batch.label)

            predictions = answer.argmax(dim=1).view(batch.label.size())
            n_correct += (predictions == batch.label).sum().item()
            n_total += batch.batch_size
            n_loss += loss.item()

    val_loss = n_loss / n_total
    val_acc = 100. * n_correct / n_total
    return val_loss, val_acc

In [None]:
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps (in seconds) into minutes and seconds.

    Returns a ``(minutes, seconds)`` tuple of ints, both truncated toward zero.
    """
    total = end_time - start_time
    whole_minutes = int(total / 60)
    leftover_seconds = int(total - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

In [None]:
N_EPOCHS = 5

# Where the best checkpoint (lowest validation loss so far) is written.
CHECKPOINT_PATH = '/content/drive/My Drive/Colab Notebooks/NLI/Models/bilstm-mnli-model.pt'

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()

    train_loss, train_acc = train(model, train_iter, optimizer, criterion)
    # Pass the dev iterator: checkpoint selection should be driven by held-out
    # dev data so the test split stays untouched for a final evaluation.
    valid_loss, valid_acc = validate(model, dev_iter, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Keep only the weights of the epoch with the lowest validation loss.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), CHECKPOINT_PATH)

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc:.2f}%')
    