<a href="https://colab.research.google.com/github/graviraja/100-Days-of-NLP/blob/applications%2Fclassification/applications/classification/natural_language_inference/NLI%20with%20Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Natural Language Inference

The goal of natural language inference (NLI), a widely-studied natural language processing task, is to determine if one given statement (a premise) semantically entails another given statement (a hypothesis).


## Imports

In [81]:
import time
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext import data, datasets, vocab

In [78]:
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## Glove Embeddings

In [79]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [80]:
!unzip "./drive/My Drive/glove.6B.zip"

Archive:  ./drive/My Drive/glove.6B.zip
  inflating: glove.6B.100d.txt       
  inflating: glove.6B.200d.txt       
  inflating: glove.6B.300d.txt       
  inflating: glove.6B.50d.txt        


In [112]:
glove_file_path = "./glove.6B.100d.txt"

In [113]:
vectors = vocab.Vectors(glove_file_path, unk_init = torch.Tensor.normal_)

 99%|█████████▉| 397436/400001 [00:15<00:00, 24832.19it/s]

## Fields

In [114]:
TEXT = data.Field(tokenize = 'spacy', lower = True, include_lengths=True)
LABEL = data.LabelField()

## SNLI (Stanford Natural Language Inference) Dataset

In [115]:
train_data, valid_data, test_data = datasets.SNLI.splits(TEXT, LABEL)

In [116]:
print(f"Number of training examples: {len(train_data)}")
print(f"Number of validation examples: {len(valid_data)}")
print(f"Number of testing examples: {len(test_data)}")

Number of training examples: 549367
Number of validation examples: 9842
Number of testing examples: 9824


In [117]:
print(vars(train_data.examples[0]))

{'premise': ['a', 'person', 'on', 'a', 'horse', 'jumps', 'over', 'a', 'broken', 'down', 'airplane', '.'], 'hypothesis': ['a', 'person', 'is', 'training', 'his', 'horse', 'for', 'a', 'competition', '.'], 'label': 'neutral'}


## Building Vocabulary

In [118]:
MIN_FREQ = 10

TEXT.build_vocab(train_data, min_freq = MIN_FREQ, vectors=vectors)

LABEL.build_vocab(train_data)

In [119]:
print(f"Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}")

Unique tokens in TEXT vocabulary: 12193


In [120]:
print(LABEL.vocab.itos)

['entailment', 'contradiction', 'neutral']


In [121]:
print(LABEL.vocab.freqs.most_common())

[('entailment', 183416), ('contradiction', 183187), ('neutral', 182764)]


## Data Iterators

In [122]:
BATCH_SIZE = 128

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE,
    sort_key=lambda x: len(x.premise),
    sort_within_batch=True,
    device = device)

In [123]:
# sample check
sample = next(iter(valid_iterator))
prem, prem_lengths = sample.premise
hypo, hypo_lengths = sample.hypothesis
print(prem.shape, prem_lengths.shape)
print(hypo.shape, hypo_lengths.shape)
print(sample.label.shape)

torch.Size([6, 128]) torch.Size([128])
torch.Size([14, 128]) torch.Size([128])
torch.Size([128])


In [124]:
print(prem_lengths[:100])
print(hypo_lengths[:100])

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4], device='cuda:0')
tensor([10, 10,  8,  6,  4, 11,  7, 14,  9,  8,  8, 10,  7,  7,  7,  5,  8, 12,
         5,  7,  5,  4,  5,  8, 11,  8,  8,  9,  6,  7,  8,  7,  5,  6,  6,  5,
         9,  8,  7, 10, 12,  4,  7,  7,  5,  5, 11,  6,  7,  7,  7,  7, 10,  4,
         6,  4,  5,  5, 10,  7,  8,  6,  6,  5,  5,  6,  3,  3,  7, 10,  7,  8,
         7,  5,  7,  7, 13,  7,  7,  6,  8,  9, 13,  6,  5,  8,  9,  9,  8,  5,
         8,  9,  5,  4,  2, 11,  5, 11,  6,  5], device='cuda:0')


## Model
![](https://drive.google.com/uc?id=1vc_Bg0WSMEBZhdNx7JdJXbhYTbaRDbke)


In [125]:
class BiLSTMWithAttention(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, n_linear_layers, output_dim, dropout, pad_idx):
        super().__init__()

        self.pad_idx = pad_idx
    
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
        self.rnn = nn.LSTM(emb_dim, hid_dim, num_layers=n_layers, bidirectional=True, dropout=dropout)

        self.attn = nn.Linear(hid_dim * 2, hid_dim * 2)
        self.v = nn.Linear(hid_dim * 2, 1, bias=False)

        d_model = hid_dim * 4
        self.fcs = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_linear_layers)])
        self.layer_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_linear_layers)])
        
        self.out = nn.Linear(d_model, output_dim)
        self.dropout = nn.Dropout(dropout)
    
    def create_mask(self, seq):
        # seq => [seq_len, batch_size]
        
        mask = (seq != self.pad_idx)
        mask = mask.permute(1, 0)
        # mask => [batch_size, seq_len]

    def forward(self, premise, hypothesis, premise_lengths, hypothesis_lengths):
        # premise => [prem_seq_len, batch_size]
        # hypothesis => [hypo_seq_len, batch_size]
        # premise_lengths => [batch_size]
        # hypothesis_lengths => [batch_size]

        # create input masks
        prem_mask = self.create_mask(premise)
        # prem_mask => [batch_size, prem_seq_len]
        hypo_mask = self.create_mask(hypothesis)
        # hypo_mask => [batch_size, hypo_seq_len]

        embedded_prem = self.dropout(self.embedding(premise))
        # embedded_prem => [prem_seq_len, batch_size, emb_dim]

        embedded_hypo = self.dropout(self.embedding(hypothesis))
        # embedded_hypo => [hypo_seq_len, batch_size, emb_dim]
        
        packed_prem = nn.utils.rnn.pack_padded_sequence(embedded_prem, premise_lengths)
        packed_out_prem, (hidden_prem, cell_prem) = self.rnn(packed_prem)
        outputs_prem, _ = nn.utils.rnn.pad_packed_sequence(packed_out_prem)
        # outputs_prem => [prem_seq_len, batch_size, hid_dim * 2]
        # hidden_prem, cell_prem => [n_layers * num_dir, batch_size, hidden_dim]

        outputs_hypo, (hidden_hypo, cell_hypo) = self.rnn(embedded_hypo)
        # outputs_hypo => [hypo_seq_len, batch_size, hid_dim * 2]
        # hidden_hypo, cell_hypo => [n_layers * num_dir, batch_size, hidden_dim]
        
        # weighted representation through attention
        weighted_prem = self.dropout(self.attention(outputs_prem, prem_mask))
        weighted_hypo = self.dropout(self.attention(outputs_hypo, hypo_mask))
        # weighted => [batch_size, hid_dim * 2]

        hidden = torch.cat((weighted_prem, weighted_hypo), dim=-1)
        # hidden => [batch_size, hidden_dim * 4]
        #        => [batch_size, d_model]

        for fc, norm in zip(self.fcs, self.layer_norms):
            hidden_ = fc(hidden)
            hidden_ = self.dropout(hidden_)
            # residual connection
            hidden = hidden + F.relu(hidden_)
            # layer normalization
            hidden = norm(hidden)
        
        logits = self.out(hidden)
        # logits => [batch_size, output_dim]

        return logits
    
    def attention(self, outputs, mask=None):
        # outputs => [seq_len, batch_size, hidden_dim * 2]
        # mask => [batch_size, seq_len]

        batch_size, seq_len, _ = outputs.shape

        outputs_ = outputs.permute(1, 0, 2)
        # outputs_ => [batch_size, seq_len, hidden_dim * 2]

        energy = torch.tanh(self.attn(outputs_))
        # energy => [batch_size, seq_len, hidden_dim * 2]

        attention_energy = self.v(energy).squeeze(2)
        # attention_energy => [batch_size, seq_len]

        if mask is not None:
            attention_energy = attention_energy.masked_fill(mask == 0, -1e10)
            # attention_energy => [batch_size, seq_len]

        scores = F.softmax(attention_energy, dim=-1)
        # scores => [batch_size, seq_len]

        scores = scores.unsqueeze(1)
        # scores => [batch_size, 1, seq_len]

        outputs = outputs.permute(1, 0, 2)
        # outputs => [batch_size, seq_len, hidden_dim * 2]

        weighted = torch.bmm(scores, outputs)
        # weighted => [batch_size, 1, hidden_dim * 2]

        weighted = weighted.squeeze(1)
        # weighted => [batch_size, hidden_dim * 2]

        return weighted

In [150]:

INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 200
N_LSTM_LAYERS = 2
N_FC_LAYERS = 3
OUTPUT_DIM = len(LABEL.vocab)
DROPOUT = 0.3
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

model = BiLSTMWithAttention(
    INPUT_DIM,
    EMBEDDING_DIM,
    HIDDEN_DIM,
    N_LSTM_LAYERS,
    N_FC_LAYERS,
    OUTPUT_DIM,
    DROPOUT,
    PAD_IDX).to(device)

In [151]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 4,756,103 trainable parameters


In [152]:
pretrained_embeddings = TEXT.vocab.vectors
print(pretrained_embeddings.shape)

torch.Size([12193, 100])


In [153]:
model.embedding.weight.data.copy_(pretrained_embeddings)

tensor([[ 0.2837, -0.6263, -0.4435,  ...,  0.4368, -0.8261, -0.1570],
        [ 1.4773,  1.2373, -0.3034,  ..., -1.5434,  0.0221, -0.3314],
        [-0.2709,  0.0440, -0.0203,  ..., -0.4923,  0.6369,  0.2364],
        ...,
        [-0.3962, -0.0070,  0.4369,  ..., -0.3295,  0.1764,  0.0092],
        [ 0.0882, -0.3188,  0.4663,  ...,  0.8881,  0.5180, -0.1170],
        [-0.6644, -0.3045,  0.6151,  ...,  0.1404,  0.5788, -0.0333]],
       device='cuda:0')

In [154]:
model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)

In [155]:
print(model.embedding.weight.data)

tensor([[ 0.2837, -0.6263, -0.4435,  ...,  0.4368, -0.8261, -0.1570],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.2709,  0.0440, -0.0203,  ..., -0.4923,  0.6369,  0.2364],
        ...,
        [-0.3962, -0.0070,  0.4369,  ..., -0.3295,  0.1764,  0.0092],
        [ 0.0882, -0.3188,  0.4663,  ...,  0.8881,  0.5180, -0.1170],
        [-0.6644, -0.3045,  0.6151,  ...,  0.1404,  0.5788, -0.0333]],
       device='cuda:0')


In [162]:
model

BiLSTMWithAttention(
  (embedding): Embedding(12193, 100, padding_idx=1)
  (rnn): LSTM(100, 200, num_layers=2, dropout=0.3, bidirectional=True)
  (attn): Linear(in_features=400, out_features=400, bias=True)
  (v): Linear(in_features=400, out_features=1, bias=False)
  (fcs): ModuleList(
    (0): Linear(in_features=800, out_features=800, bias=True)
    (1): Linear(in_features=800, out_features=800, bias=True)
    (2): Linear(in_features=800, out_features=800, bias=True)
  )
  (layer_norms): ModuleList(
    (0): LayerNorm((800,), eps=1e-05, elementwise_affine=True)
    (1): LayerNorm((800,), eps=1e-05, elementwise_affine=True)
    (2): LayerNorm((800,), eps=1e-05, elementwise_affine=True)
  )
  (out): Linear(in_features=800, out_features=3, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)

## Optimizer & Loss Criterion

In [156]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

## Accuracy

In [157]:
def categorical_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability
    correct = max_preds.squeeze(1).eq(y)
    return correct.sum() / torch.FloatTensor([y.shape[0]])

## Train Loop

In [158]:
def train(model, iterator, optimizer, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in iterator:
        
        prem, prem_lengths = batch.premise
        hypo, hypo_lengths = batch.hypothesis
        labels = batch.label
        
        optimizer.zero_grad()
        
        predictions = model(prem, hypo, prem_lengths, hypo_lengths)
        
        # predictions => [batch size, output dim]
        # labels => [batch size]
    
        loss = criterion(predictions, labels)            
        acc = categorical_accuracy(predictions, labels)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

## Validation Loop

In [159]:
def evaluate(model, iterator, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in iterator:

            prem, prem_lengths = batch.premise
            hypo, hypo_lengths = batch.hypothesis
            labels = batch.label
                        
            predictions = model(prem, hypo, prem_lengths, hypo_lengths)
            
            loss = criterion(predictions, labels)
                
            acc = categorical_accuracy(predictions, labels)
            
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [160]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

## Training

In [161]:
N_EPOCHS = 10

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')


Epoch: 01 | Epoch Time: 1m 55s
	Train Loss: 0.804 | Train Acc: 63.64%
	 Val. Loss: 0.666 |  Val. Acc: 72.31%
Epoch: 02 | Epoch Time: 1m 55s
	Train Loss: 0.682 | Train Acc: 71.00%
	 Val. Loss: 0.596 |  Val. Acc: 75.56%
Epoch: 03 | Epoch Time: 1m 55s
	Train Loss: 0.618 | Train Acc: 74.49%
	 Val. Loss: 0.553 |  Val. Acc: 77.66%
Epoch: 04 | Epoch Time: 1m 55s
	Train Loss: 0.577 | Train Acc: 76.48%
	 Val. Loss: 0.546 |  Val. Acc: 78.58%
Epoch: 05 | Epoch Time: 1m 55s
	Train Loss: 0.549 | Train Acc: 77.75%
	 Val. Loss: 0.542 |  Val. Acc: 79.12%
Epoch: 06 | Epoch Time: 1m 55s
	Train Loss: 0.528 | Train Acc: 78.77%
	 Val. Loss: 0.524 |  Val. Acc: 79.77%
Epoch: 07 | Epoch Time: 1m 55s
	Train Loss: 0.511 | Train Acc: 79.57%
	 Val. Loss: 0.530 |  Val. Acc: 80.44%
Epoch: 08 | Epoch Time: 1m 55s
	Train Loss: 0.497 | Train Acc: 80.22%
	 Val. Loss: 0.531 |  Val. Acc: 80.42%
Epoch: 09 | Epoch Time: 1m 55s
	Train Loss: 0.485 | Train Acc: 80.74%
	 Val. Loss: 0.525 |  Val. Acc: 80.73%
Epoch: 10 | Epoch T

## Testing

In [163]:
model.load_state_dict(torch.load('model.pt'))

test_loss, test_acc = evaluate(model, test_iterator, criterion)

print(f'Test Loss: {test_loss:.3f} |  Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.532 |  Test Acc: 79.51%


## Inference

In [164]:
def inference(premise, hypothesis, text_field, label_field, model, device):
    
    model.eval()
    
    if isinstance(premise, str):
        premise = text_field.tokenize(premise)
    
    if isinstance(hypothesis, str):
        hypothesis = text_field.tokenize(hypothesis)
    
    if text_field.lower:
        premise = [t.lower() for t in premise]
        hypothesis = [t.lower() for t in hypothesis]

    # numericalize  
    premise = [text_field.vocab.stoi[t] for t in premise]
    premise_lengths = [len(premise)]
    hypothesis = [text_field.vocab.stoi[t] for t in hypothesis]
    hypothesis_lengths = [len(hypothesis)]
    
    # convert into tensors
    premise = torch.LongTensor(premise).unsqueeze(1).to(device)
    # premise => [prem_len, 1]
    premise_lengths = torch.LongTensor(premise_lengths).to(device)
    # premise_lengths => [1]
    hypothesis = torch.LongTensor(hypothesis).unsqueeze(1).to(device)
    # hypothesis => [hypo_len, 1]
    hypothesis_lengths = torch.LongTensor(hypothesis_lengths).to(device)
    # hypothesis_lengths => [1]

    prediction = model(premise, hypothesis, premise_lengths, hypothesis_lengths)
    prediction = prediction.argmax(dim=-1).item()

    return label_field.vocab.itos[prediction]

In [165]:
premise = 'A woman selling bamboo sticks talking to two men on a loading dock.'
hypothesis = 'There are at least three people on a loading dock.'

inference(premise, hypothesis, TEXT, LABEL, model, device)

'entailment'

In [166]:
premise = 'A woman selling bamboo sticks talking to two men on a loading dock.'
hypothesis = 'A woman is selling bamboo sticks to help provide for her family.'

inference(premise, hypothesis, TEXT, LABEL, model, device)

'neutral'

In [167]:
premise = 'A woman selling bamboo sticks talking to two men on a loading dock.'
hypothesis = ' A woman is not taking money for any of her sticks.'

inference(premise, hypothesis, TEXT, LABEL, model, device)

'contradiction'