<a href="https://colab.research.google.com/github/graviraja/100-Days-of-NLP/blob/applications%2Fclassification/applications/classification/natural_language_inference/NLI%20with%20BiLSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Natural Language Inference

The goal of natural language inference (NLI), a widely-studied natural language processing task, is to determine if one given statement (a premise) semantically entails another given statement (a hypothesis).


## Imports

In [None]:
import time
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext import data, datasets

In [None]:
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## Fields

In [None]:
TEXT = data.Field(tokenize = 'spacy', lower = True)
LABEL = data.LabelField()

## SNLI (Stanford Natural Language Inference) Dataset

In [None]:
train_data, valid_data, test_data = datasets.SNLI.splits(TEXT, LABEL)

downloading snli_1.0.zip


snli_1.0.zip: 100%|██████████| 94.6M/94.6M [00:44<00:00, 2.11MB/s]


extracting


In [None]:
print(f"Number of training examples: {len(train_data)}")
print(f"Number of validation examples: {len(valid_data)}")
print(f"Number of testing examples: {len(test_data)}")

Number of training examples: 549367
Number of validation examples: 9842
Number of testing examples: 9824


In [None]:
print(vars(train_data.examples[0]))

{'premise': ['a', 'person', 'on', 'a', 'horse', 'jumps', 'over', 'a', 'broken', 'down', 'airplane', '.'], 'hypothesis': ['a', 'person', 'is', 'training', 'his', 'horse', 'for', 'a', 'competition', '.'], 'label': 'neutral'}


## Building Vocabulary

In [58]:
MIN_FREQ = 10

TEXT.build_vocab(train_data, min_freq = MIN_FREQ)

LABEL.build_vocab(train_data)

In [59]:
print(f"Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}")

Unique tokens in TEXT vocabulary: 12193


In [60]:
print(LABEL.vocab.itos)

['entailment', 'contradiction', 'neutral']


In [61]:
print(LABEL.vocab.freqs.most_common())

[('entailment', 183416), ('contradiction', 183187), ('neutral', 182764)]


## Data Iterators

In [62]:
BATCH_SIZE = 128

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE,
    device = device)

## Model
![](https://drive.google.com/uc?id=1DfMfzLXSpbTRITVt3yuOPeIhHBo4-lxg)

In [None]:
class BiLSTM(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, n_linear_layers, output_dim, dropout, pad_idx):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
        self.rnn = nn.LSTM(emb_dim, hid_dim, num_layers=n_layers, bidirectional=True, dropout=dropout)

        d_model = hid_dim * 4
        self.fcs = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_linear_layers)])
        self.layer_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_linear_layers)])
        self.out = nn.Linear(d_model, output_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, premise, hypothesis):
        # premise => [prem_seq_len, batch_size]
        # hypothesis => [hypo_seq_len, batch_size]

        embedded_prem = self.dropout(self.embedding(premise))
        # embedded_prem => [prem_seq_len, batch_size, emb_dim]

        embedded_hypo = self.dropout(self.embedding(hypothesis))
        # embedded_hypo => [hypo_seq_len, batch_size, emb_dim]

        outputs_prem, (hidden_prem, cell_prem) = self.rnn(embedded_prem)
        # outputs_prem => [prem_seq_len, batch_size, hid_dim * 2]
        # hidden_prem, cell_prem => [n_layers * num_dir, batch_size, hidden_dim]

        outputs_hypo, (hidden_hypo, cell_hypo) = self.rnn(embedded_hypo)
        # outputs_hypo => [hypo_seq_len, batch_size, hid_dim * 2]
        # hidden_hypo, cell_hypo => [n_layers * num_dir, batch_size, hidden_dim]
        
        # combine the final hidden states
        prem_representation = torch.cat((hidden_prem[-1], hidden_prem[-2]), dim=-1)
        hypo_representation = torch.cat((hidden_hypo[-1], hidden_hypo[-2]), dim=-1)
        # representation => [batch_size, hid_dim * 2]

        hidden = torch.cat((prem_representation, hypo_representation), dim=-1)
        # hidden => [batch_size, hidden_dim * 4]
        #        => [batch_size, d_model]

        for fc, norm in zip(self.fcs, self.layer_norms):
            hidden_ = fc(hidden)
            hidden_ = self.dropout(hidden_)
            # residual connection
            hidden = hidden + F.relu(hidden_)
            # layer normalization
            hidden = norm(hidden)
        
        logits = self.out(hidden)
        # logits => [batch_size, output_dim]

        return logits

In [63]:
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 50
HIDDEN_DIM = 100
N_LSTM_LAYERS = 2
N_FC_LAYERS = 3
OUTPUT_DIM = len(LABEL.vocab)
DROPOUT = 0.3
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

model = BiLSTM(
    INPUT_DIM,
    EMBEDDING_DIM,
    HIDDEN_DIM,
    N_LSTM_LAYERS,
    N_FC_LAYERS,
    OUTPUT_DIM,
    DROPOUT,
    PAD_IDX).to(device)

In [64]:
def init_weights(model):
    for name, param in model.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)

model.apply(init_weights)

BiLSTM(
  (embedding): Embedding(12193, 50, padding_idx=1)
  (rnn): LSTM(50, 100, num_layers=2, dropout=0.3, bidirectional=True)
  (fcs): ModuleList(
    (0): Linear(in_features=400, out_features=400, bias=True)
    (1): Linear(in_features=400, out_features=400, bias=True)
    (2): Linear(in_features=400, out_features=400, bias=True)
  )
  (layer_norms): ModuleList(
    (0): LayerNorm((400,), eps=1e-05, elementwise_affine=True)
    (1): LayerNorm((400,), eps=1e-05, elementwise_affine=True)
    (2): LayerNorm((400,), eps=1e-05, elementwise_affine=True)
  )
  (out): Linear(in_features=400, out_features=3, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)

In [65]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 1,457,653 trainable parameters


## Optimizer & Loss Criterion

In [66]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

## Accuracy

In [67]:
def categorical_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability
    correct = max_preds.squeeze(1).eq(y)
    return correct.sum() / torch.FloatTensor([y.shape[0]])

## Train Loop

In [68]:
def train(model, iterator, optimizer, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in iterator:
        
        prem = batch.premise
        hypo = batch.hypothesis
        labels = batch.label
        
        optimizer.zero_grad()
        
        predictions = model(prem, hypo)
        
        # predictions => [batch size, output dim]
        # labels => [batch size]
    
        loss = criterion(predictions, labels)            
        acc = categorical_accuracy(predictions, labels)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

## Validation Loop

In [69]:
def evaluate(model, iterator, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in iterator:

            prem = batch.premise
            hypo = batch.hypothesis
            labels = batch.label
                        
            predictions = model(prem, hypo)
            
            loss = criterion(predictions, labels)
                
            acc = categorical_accuracy(predictions, labels)
            
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [70]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

## Training

In [71]:
N_EPOCHS = 10

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')


Epoch: 01 | Epoch Time: 4m 19s
	Train Loss: 0.812 | Train Acc: 62.66%
	 Val. Loss: 0.734 |  Val. Acc: 68.38%
Epoch: 02 | Epoch Time: 4m 20s
	Train Loss: 0.701 | Train Acc: 70.09%
	 Val. Loss: 0.662 |  Val. Acc: 72.34%
Epoch: 03 | Epoch Time: 4m 20s
	Train Loss: 0.651 | Train Acc: 72.70%
	 Val. Loss: 0.630 |  Val. Acc: 74.13%
Epoch: 04 | Epoch Time: 4m 20s
	Train Loss: 0.620 | Train Acc: 74.31%
	 Val. Loss: 0.632 |  Val. Acc: 74.74%
Epoch: 05 | Epoch Time: 4m 21s
	Train Loss: 0.598 | Train Acc: 75.30%
	 Val. Loss: 0.616 |  Val. Acc: 75.72%
Epoch: 06 | Epoch Time: 4m 21s
	Train Loss: 0.582 | Train Acc: 76.12%
	 Val. Loss: 0.602 |  Val. Acc: 75.94%
Epoch: 07 | Epoch Time: 4m 20s
	Train Loss: 0.567 | Train Acc: 76.84%
	 Val. Loss: 0.600 |  Val. Acc: 76.40%
Epoch: 08 | Epoch Time: 4m 19s
	Train Loss: 0.555 | Train Acc: 77.39%
	 Val. Loss: 0.601 |  Val. Acc: 76.82%
Epoch: 09 | Epoch Time: 4m 20s
	Train Loss: 0.544 | Train Acc: 77.90%
	 Val. Loss: 0.589 |  Val. Acc: 76.99%
Epoch: 10 | Epoch T

## Testing

In [72]:
model.load_state_dict(torch.load('model.pt'))

test_loss, test_acc = evaluate(model, test_iterator, criterion)

print(f'Test Loss: {test_loss:.3f} |  Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.584 |  Test Acc: 76.84%


## Inference

In [73]:
def inference(premise, hypothesis, text_field, label_field, model, device):
    
    model.eval()
    
    if isinstance(premise, str):
        premise = text_field.tokenize(premise)
    
    if isinstance(hypothesis, str):
        hypothesis = text_field.tokenize(hypothesis)
    
    if text_field.lower:
        premise = [t.lower() for t in premise]
        hypothesis = [t.lower() for t in hypothesis]

    # numericalize  
    premise = [text_field.vocab.stoi[t] for t in premise]
    hypothesis = [text_field.vocab.stoi[t] for t in hypothesis]
    
    # convert into tensors
    premise = torch.LongTensor(premise).unsqueeze(1).to(device)
    # premise => [prem_len, 1]
    hypothesis = torch.LongTensor(hypothesis).unsqueeze(1).to(device)
    # hypothesis => [hypo_len, 1]

    prediction = model(premise, hypothesis)
    prediction = prediction.argmax(dim=-1).item()

    return label_field.vocab.itos[prediction]

In [74]:
premise = 'A woman selling bamboo sticks talking to two men on a loading dock.'
hypothesis = 'There are at least three people on a loading dock.'

inference(premise, hypothesis, TEXT, LABEL, model, device)

'entailment'

In [75]:
premise = 'A woman selling bamboo sticks talking to two men on a loading dock.'
hypothesis = 'A woman is selling bamboo sticks to help provide for her family.'

inference(premise, hypothesis, TEXT, LABEL, model, device)

'neutral'

In [76]:
premise = 'A woman selling bamboo sticks talking to two men on a loading dock.'
hypothesis = ' A woman is not taking money for any of her sticks.'

inference(premise, hypothesis, TEXT, LABEL, model, device)

'contradiction'