<a href="https://colab.research.google.com/github/graviraja/100-Days-of-NLP/blob/applications%2Fclassification/applications/classification/natural_language_inference/NLI%20with%20Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Natural Language Inference

The goal of natural language inference (NLI), a widely-studied natural language processing task, is to determine if one given statement (a premise) semantically entails another given statement (a hypothesis).


## Imports

In [2]:
import time
import math
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from torchtext import data, datasets, vocab

In [3]:
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## Fields

In [4]:
TEXT = data.Field(tokenize = 'spacy', lower = True)
LABEL = data.LabelField()

## SNLI (Stanford Natural Language Inference) Dataset

In [5]:
train_data, valid_data, test_data = datasets.SNLI.splits(TEXT, LABEL)

In [6]:
print(f"Number of training examples: {len(train_data)}")
print(f"Number of validation examples: {len(valid_data)}")
print(f"Number of testing examples: {len(test_data)}")

Number of training examples: 549367
Number of validation examples: 9842
Number of testing examples: 9824


In [7]:
print(vars(train_data.examples[0]))

{'premise': ['a', 'person', 'on', 'a', 'horse', 'jumps', 'over', 'a', 'broken', 'down', 'airplane', '.'], 'hypothesis': ['a', 'person', 'is', 'training', 'his', 'horse', 'for', 'a', 'competition', '.'], 'label': 'neutral'}


## Building Vocabulary

In [8]:
MIN_FREQ = 10

TEXT.build_vocab(train_data, min_freq = MIN_FREQ)

LABEL.build_vocab(train_data)

In [9]:
print(f"Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}")

Unique tokens in TEXT vocabulary: 12193


In [10]:
print(LABEL.vocab.itos)

['entailment', 'contradiction', 'neutral']


In [11]:
print(LABEL.vocab.freqs.most_common())

[('entailment', 183416), ('contradiction', 183187), ('neutral', 182764)]


## Data Iterators

In [12]:
BATCH_SIZE = 128

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE,
    device = device)

In [13]:
# sample check
sample = next(iter(valid_iterator))
sample.premise.shape, sample.hypothesis.shape

(torch.Size([7, 128]), torch.Size([7, 128]))

## Model
![](https://drive.google.com/uc?id=1vc_Bg0WSMEBZhdNx7JdJXbhYTbaRDbke)


In [14]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=512):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [15]:
class TransformerModel(nn.Module):
    def __init__(self, input_dim, d_model, n_head, hid_dim, n_layers, n_linear_layers, output_dim, dropout, pad_idx):
        super().__init__()

        self.pad_idx = pad_idx
        self.d_model = d_model
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        self.embedding = nn.Embedding(input_dim, d_model, padding_idx=pad_idx)
        
        encoder_layers = TransformerEncoderLayer(d_model, n_head, hid_dim, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, n_layers)

        self.fcs = nn.ModuleList([nn.Linear(d_model * 2, d_model * 2) for _ in range(n_linear_layers)])
        self.layer_norms = nn.ModuleList([nn.LayerNorm(d_model * 2) for _ in range(n_linear_layers)])
        
        self.out = nn.Linear(d_model * 2, output_dim)
        self.dropout = nn.Dropout(dropout)
    
    def create_mask(self, seq):
        # seq => [seq_len, batch_size]
        
        mask = (seq == self.pad_idx)
        mask = mask.permute(1, 0)
        # mask => [batch_size, seq_len]

    def forward(self, premise, hypothesis):
        # premise => [prem_seq_len, batch_size]
        # hypothesis => [hypo_seq_len, batch_size]

        # create input masks
        prem_mask = self.create_mask(premise)
        # prem_mask => [batch_size, prem_seq_len]
        hypo_mask = self.create_mask(hypothesis)
        # hypo_mask => [batch_size, hypo_seq_len]

        embedded_prem = self.dropout(self.embedding(premise)) * math.sqrt(self.d_model)
        # embedded_prem => [prem_seq_len, batch_size, emb_dim]

        embedded_hypo = self.dropout(self.embedding(hypothesis)) * math.sqrt(self.d_model)
        # embedded_hypo => [hypo_seq_len, batch_size, emb_dim]
        
        embedded_prem = self.pos_encoder(embedded_prem)
        embedded_hypo = self.pos_encoder(embedded_hypo)

        outputs_prem = self.transformer_encoder(embedded_prem, src_key_padding_mask=prem_mask)
        # outputs_prem => [prem_seq_len, batch_size, d_model]

        outputs_hypo = self.transformer_encoder(embedded_hypo, src_key_padding_mask=hypo_mask)
        # outputs_hypo => [hypo_seq_len, batch_size, d_model]
        
        # add the representation through attention
        prem_representation = self.dropout(torch.sum(outputs_prem, dim=0))
        hypo_representation = self.dropout(torch.sum(outputs_hypo, dim=0))
        # representation => [batch_size, d_model]

        hidden = torch.cat((prem_representation, hypo_representation), dim=-1)
        # hidden => [batch_size, d_model * 2]

        for fc, norm in zip(self.fcs, self.layer_norms):
            hidden_ = fc(hidden)
            hidden_ = self.dropout(hidden_)
            # residual connection
            hidden = hidden + F.relu(hidden_)
            # layer normalization
            hidden = norm(hidden)
        
        logits = self.out(hidden)
        # logits => [batch_size, output_dim]

        return logits


In [35]:
INPUT_DIM = len(TEXT.vocab)
D_MODEL = 128
N_HEAD = 8
HIDDEN_DIM = 200
N_LAYERS = 3
N_FC_LAYERS = 3
OUTPUT_DIM = len(LABEL.vocab)
DROPOUT = 0.3
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

model = TransformerModel(
    INPUT_DIM,
    D_MODEL,
    N_HEAD,
    HIDDEN_DIM,
    N_LAYERS,
    N_FC_LAYERS,
    OUTPUT_DIM,
    DROPOUT,
    PAD_IDX).to(device)

In [36]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 2,114,651 trainable parameters


In [37]:
def init_weights(model):
    for name, param in model.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)

model.apply(init_weights)

# def init_weights(m):
#     for name, param in m.named_parameters():
#         nn.init.normal_(param.data, mean = 0, std = 0.1)
# model.apply(init_weights)

TransformerModel(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (embedding): Embedding(12193, 128, padding_idx=1)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=200, bias=True)
        (dropout): Dropout(p=0.3, inplace=False)
        (linear2): Linear(in_features=200, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.3, inplace=False)
        (dropout2): Dropout(p=0.3, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=128, out_features=128, bias=True)
        )
      

## Optimizer & Loss Criterion

In [38]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

## Accuracy

In [39]:
def categorical_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability
    correct = max_preds.squeeze(1).eq(y)
    return correct.sum() / torch.FloatTensor([y.shape[0]])

## Train Loop

In [40]:
def train(model, iterator, optimizer, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in iterator:
        
        prem = batch.premise
        hypo = batch.hypothesis
        labels = batch.label
        
        optimizer.zero_grad()
        
        predictions = model(prem, hypo)
        
        # predictions => [batch size, output dim]
        # labels => [batch size]
    
        loss = criterion(predictions, labels)            
        acc = categorical_accuracy(predictions, labels)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

## Validation Loop

In [41]:
def evaluate(model, iterator, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in iterator:

            prem = batch.premise
            hypo = batch.hypothesis
            labels = batch.label
                        
            predictions = model(prem, hypo)
            
            loss = criterion(predictions, labels)
                
            acc = categorical_accuracy(predictions, labels)
            
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [42]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

## Training

In [43]:
N_EPOCHS = 20

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')


Epoch: 01 | Epoch Time: 4m 32s
	Train Loss: 0.905 | Train Acc: 57.45%
	 Val. Loss: 0.802 |  Val. Acc: 64.25%
Epoch: 02 | Epoch Time: 4m 33s
	Train Loss: 0.797 | Train Acc: 64.58%
	 Val. Loss: 0.771 |  Val. Acc: 66.86%
Epoch: 03 | Epoch Time: 4m 33s
	Train Loss: 0.772 | Train Acc: 66.05%
	 Val. Loss: 0.776 |  Val. Acc: 66.96%
Epoch: 04 | Epoch Time: 4m 32s
	Train Loss: 0.759 | Train Acc: 66.82%
	 Val. Loss: 0.771 |  Val. Acc: 68.07%
Epoch: 05 | Epoch Time: 4m 32s
	Train Loss: 0.749 | Train Acc: 67.39%
	 Val. Loss: 0.766 |  Val. Acc: 68.45%
Epoch: 06 | Epoch Time: 4m 33s
	Train Loss: 0.741 | Train Acc: 67.84%
	 Val. Loss: 0.766 |  Val. Acc: 68.75%
Epoch: 07 | Epoch Time: 4m 33s
	Train Loss: 0.734 | Train Acc: 68.15%
	 Val. Loss: 0.775 |  Val. Acc: 69.00%
Epoch: 08 | Epoch Time: 4m 33s
	Train Loss: 0.729 | Train Acc: 68.49%
	 Val. Loss: 0.780 |  Val. Acc: 68.98%
Epoch: 09 | Epoch Time: 4m 33s
	Train Loss: 0.724 | Train Acc: 68.74%
	 Val. Loss: 0.765 |  Val. Acc: 69.01%
Epoch: 10 | Epoch T

## Testing

In [44]:
model.load_state_dict(torch.load('model.pt'))

test_loss, test_acc = evaluate(model, test_iterator, criterion)

print(f'Test Loss: {test_loss:.3f} |  Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.746 |  Test Acc: 69.61%


## Inference

In [45]:
def inference(premise, hypothesis, text_field, label_field, model, device):
    
    model.eval()
    
    if isinstance(premise, str):
        premise = text_field.tokenize(premise)
    
    if isinstance(hypothesis, str):
        hypothesis = text_field.tokenize(hypothesis)
    
    if text_field.lower:
        premise = [t.lower() for t in premise]
        hypothesis = [t.lower() for t in hypothesis]

    # numericalize  
    premise = [text_field.vocab.stoi[t] for t in premise]
    hypothesis = [text_field.vocab.stoi[t] for t in hypothesis]
    
    # convert into tensors
    premise = torch.LongTensor(premise).unsqueeze(1).to(device)
    # premise => [prem_len, 1]
    hypothesis = torch.LongTensor(hypothesis).unsqueeze(1).to(device)
    # hypothesis => [hypo_len, 1]

    prediction = model(premise, hypothesis)
    prediction = prediction.argmax(dim=-1).item()

    return label_field.vocab.itos[prediction]

In [46]:
premise = 'A woman selling bamboo sticks talking to two men on a loading dock.'
hypothesis = 'There are at least three people on a loading dock.'

inference(premise, hypothesis, TEXT, LABEL, model, device)

'entailment'

In [47]:
premise = 'A woman selling bamboo sticks talking to two men on a loading dock.'
hypothesis = 'A woman is selling bamboo sticks to help provide for her family.'

inference(premise, hypothesis, TEXT, LABEL, model, device)

'neutral'

In [48]:
premise = 'A woman selling bamboo sticks talking to two men on a loading dock.'
hypothesis = ' A woman is not taking money for any of her sticks.'

inference(premise, hypothesis, TEXT, LABEL, model, device)

'contradiction'