<a href="https://colab.research.google.com/github/graviraja/100-Days-of-NLP/blob/applications%2Fclassification/applications/classification/natural_language_inference/NLI%20with%20Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Natural Language Inference

The goal of natural language inference (NLI), a widely-studied natural language processing task, is to determine if one given statement (a premise) semantically entails another given statement (a hypothesis).


# NLI with Distillation

**`Distillation`**: A technique you can use to compress a large model, called the `teacher`, into a smaller model, called the `student`.

- [Medium blog on Distillation by Victor Sanh (Must Read)](https://medium.com/huggingface/distilbert-8cf3380435b5)

## Imports

In [1]:
import time
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext import data, datasets, vocab

In [2]:
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## Glove Embeddings

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [4]:
!unzip "./drive/My Drive/glove.6B.zip"

Archive:  ./drive/My Drive/glove.6B.zip
  inflating: glove.6B.100d.txt       
  inflating: glove.6B.200d.txt       
  inflating: glove.6B.300d.txt       
  inflating: glove.6B.50d.txt        


In [5]:
glove_file_path = "./glove.6B.100d.txt"

In [6]:
vectors = vocab.Vectors(glove_file_path, unk_init = torch.Tensor.normal_)

100%|█████████▉| 398301/400001 [00:24<00:00, 17882.25it/s]

## Fields

In [18]:
TEXT = data.Field(tokenize = 'spacy', lower = True)
LABEL = data.LabelField()

## SNLI (Stanford Natural Language Inference) Dataset

In [19]:
train_data, valid_data, test_data = datasets.SNLI.splits(TEXT, LABEL)

In [20]:
print(f"Number of training examples: {len(train_data)}")
print(f"Number of validation examples: {len(valid_data)}")
print(f"Number of testing examples: {len(test_data)}")

Number of training examples: 549367
Number of validation examples: 9842
Number of testing examples: 9824


In [21]:
print(vars(train_data.examples[0]))

{'premise': ['a', 'person', 'on', 'a', 'horse', 'jumps', 'over', 'a', 'broken', 'down', 'airplane', '.'], 'hypothesis': ['a', 'person', 'is', 'training', 'his', 'horse', 'for', 'a', 'competition', '.'], 'label': 'neutral'}


## Building Vocabulary

In [22]:
MIN_FREQ = 10

TEXT.build_vocab(train_data, min_freq = MIN_FREQ, vectors=vectors)

LABEL.build_vocab(train_data)

In [23]:
print(f"Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}")

Unique tokens in TEXT vocabulary: 12193


In [24]:
print(LABEL.vocab.itos)

['entailment', 'contradiction', 'neutral']


In [25]:
print(LABEL.vocab.freqs.most_common())

[('entailment', 183416), ('contradiction', 183187), ('neutral', 182764)]


## Data Iterators

In [27]:
BATCH_SIZE = 128

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE,
    sort_key=lambda x: len(x.premise),
    device = device)

In [29]:
# sample check
sample = next(iter(valid_iterator))
sample.premise.shape, sample.hypothesis.shape

(torch.Size([6, 128]), torch.Size([14, 128]))

# Base Model

In [30]:
class LogisticRegressionModel(nn.Module):
    def __init__(self, input_dim, emb_dim, output_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.linear = nn.Linear(emb_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, premise, hypothesis):
        prem_embedded = self.dropout(self.embedding(premise))
        hypo_embedded = self.dropout(self.embedding(hypothesis))
        combined = torch.cat((prem_embedded, hypo_embedded), dim=0)
        final = torch.sum(combined, dim=0)
        outputs = self.linear(self.dropout(final))
        return outputs

In [31]:
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
OUTPUT_DIM = len(LABEL.vocab)
DROPOUT = 0.3
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

base_model = LogisticRegressionModel(
    INPUT_DIM,
    EMBEDDING_DIM,
    OUTPUT_DIM,
    DROPOUT).to(device)

In [32]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(base_model):,} trainable parameters')

The model has 1,219,603 trainable parameters


In [33]:
pretrained_embeddings = TEXT.vocab.vectors
print(pretrained_embeddings.shape)

torch.Size([12193, 100])


In [34]:
base_model.embedding.weight.data.copy_(pretrained_embeddings)

tensor([[ 0.2837, -0.6263, -0.4435,  ...,  0.4368, -0.8261, -0.1570],
        [-0.7534,  0.2218,  0.4468,  ...,  0.9045, -1.6214, -0.4485],
        [-0.2709,  0.0440, -0.0203,  ..., -0.4923,  0.6369,  0.2364],
        ...,
        [-0.3962, -0.0070,  0.4369,  ..., -0.3295,  0.1764,  0.0092],
        [ 0.0882, -0.3188,  0.4663,  ...,  0.8881,  0.5180, -0.1170],
        [-0.6644, -0.3045,  0.6151,  ...,  0.1404,  0.5788, -0.0333]],
       device='cuda:0')

In [35]:
base_model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)

In [36]:
print(base_model.embedding.weight.data)

tensor([[ 0.2837, -0.6263, -0.4435,  ...,  0.4368, -0.8261, -0.1570],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.2709,  0.0440, -0.0203,  ..., -0.4923,  0.6369,  0.2364],
        ...,
        [-0.3962, -0.0070,  0.4369,  ..., -0.3295,  0.1764,  0.0092],
        [ 0.0882, -0.3188,  0.4663,  ...,  0.8881,  0.5180, -0.1170],
        [-0.6644, -0.3045,  0.6151,  ...,  0.1404,  0.5788, -0.0333]],
       device='cuda:0')


In [37]:
base_model

LogisticRegressionModel(
  (embedding): Embedding(12193, 100)
  (linear): Linear(in_features=100, out_features=3, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)

### Optimizer & Loss Criterion

In [39]:
optimizer = optim.Adam(base_model.parameters())
criterion = nn.CrossEntropyLoss()

### Accuracy

In [40]:
def categorical_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability
    correct = max_preds.squeeze(1).eq(y)
    return correct.sum() / torch.FloatTensor([y.shape[0]])

### Training Method

In [41]:
def train(model, iterator, optimizer, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in iterator:
        
        prem = batch.premise
        hypo = batch.hypothesis
        labels = batch.label
        
        optimizer.zero_grad()
        
        predictions = model(prem, hypo)
        
        # predictions => [batch size, output dim]
        # labels => [batch size]
    
        loss = criterion(predictions, labels)            
        acc = categorical_accuracy(predictions, labels)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

### Validation Method

In [42]:
def evaluate(model, iterator, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in iterator:

            prem = batch.premise
            hypo = batch.hypothesis
            labels = batch.label
                        
            predictions = model(prem, hypo)
            
            loss = criterion(predictions, labels)
                
            acc = categorical_accuracy(predictions, labels)
            
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [43]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

### Base Model Training

In [45]:
N_EPOCHS = 10

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(base_model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(base_model, valid_iterator, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(base_model.state_dict(), 'base_model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

Epoch: 01 | Epoch Time: 0m 28s
	Train Loss: 1.023 | Train Acc: 49.20%
	 Val. Loss: 0.979 |  Val. Acc: 53.27%
Epoch: 02 | Epoch Time: 0m 27s
	Train Loss: 0.992 | Train Acc: 52.16%
	 Val. Loss: 0.972 |  Val. Acc: 54.57%
Epoch: 03 | Epoch Time: 0m 27s
	Train Loss: 0.977 | Train Acc: 53.45%
	 Val. Loss: 0.964 |  Val. Acc: 55.23%
Epoch: 04 | Epoch Time: 0m 27s
	Train Loss: 0.970 | Train Acc: 54.15%
	 Val. Loss: 0.960 |  Val. Acc: 55.46%
Epoch: 05 | Epoch Time: 0m 27s
	Train Loss: 0.964 | Train Acc: 54.59%
	 Val. Loss: 0.963 |  Val. Acc: 54.98%
Epoch: 06 | Epoch Time: 0m 27s
	Train Loss: 0.959 | Train Acc: 54.91%
	 Val. Loss: 0.959 |  Val. Acc: 55.63%
Epoch: 07 | Epoch Time: 0m 27s
	Train Loss: 0.957 | Train Acc: 55.08%
	 Val. Loss: 0.959 |  Val. Acc: 55.54%
Epoch: 08 | Epoch Time: 0m 28s
	Train Loss: 0.954 | Train Acc: 55.32%
	 Val. Loss: 0.960 |  Val. Acc: 55.71%
Epoch: 09 | Epoch Time: 0m 28s
	Train Loss: 0.952 | Train Acc: 55.46%
	 Val. Loss: 0.958 |  Val. Acc: 55.49%
Epoch: 10 | Epoch T

### Base Model - Test Performance

In [46]:
base_model.load_state_dict(torch.load('base_model.pt'))

test_loss, test_acc = evaluate(base_model, test_iterator, criterion)

print(f'Test Loss: {test_loss:.3f} |  Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.962 |  Test Acc: 55.14%


# Teacher Model

In [47]:
class BiLSTMWithAttention(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, n_linear_layers, output_dim, dropout, pad_idx):
        super().__init__()

        self.pad_idx = pad_idx
    
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
        self.rnn = nn.LSTM(emb_dim, hid_dim, num_layers=n_layers, bidirectional=True, dropout=dropout)

        self.attn = nn.Linear(hid_dim * 2, hid_dim * 2)
        self.v = nn.Linear(hid_dim * 2, 1, bias=False)

        d_model = hid_dim * 4
        self.fcs = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_linear_layers)])
        self.layer_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_linear_layers)])
        
        self.out = nn.Linear(d_model, output_dim)
        self.dropout = nn.Dropout(dropout)
    
    def create_mask(self, seq):
        # seq => [seq_len, batch_size]
        
        mask = (seq != self.pad_idx)
        mask = mask.permute(1, 0)
        # mask => [batch_size, seq_len]

    def forward(self, premise, hypothesis):
        # premise => [prem_seq_len, batch_size]
        # hypothesis => [hypo_seq_len, batch_size]

        # create input masks
        prem_mask = self.create_mask(premise)
        # prem_mask => [batch_size, prem_seq_len]
        hypo_mask = self.create_mask(hypothesis)
        # hypo_mask => [batch_size, hypo_seq_len]

        embedded_prem = self.dropout(self.embedding(premise))
        # embedded_prem => [prem_seq_len, batch_size, emb_dim]

        embedded_hypo = self.dropout(self.embedding(hypothesis))
        # embedded_hypo => [hypo_seq_len, batch_size, emb_dim]
        
        outputs_prem, (hidden_prem, cell_prem) = self.rnn(embedded_prem)
        # outputs_prem => [prem_seq_len, batch_size, hid_dim * 2]
        # hidden_prem, cell_prem => [n_layers * num_dir, batch_size, hidden_dim]

        outputs_hypo, (hidden_hypo, cell_hypo) = self.rnn(embedded_hypo)
        # outputs_hypo => [hypo_seq_len, batch_size, hid_dim * 2]
        # hidden_hypo, cell_hypo => [n_layers * num_dir, batch_size, hidden_dim]
        
        # weighted representation through attention
        weighted_prem = self.dropout(self.attention(outputs_prem, prem_mask))
        weighted_hypo = self.dropout(self.attention(outputs_hypo, hypo_mask))
        # weighted => [batch_size, hid_dim * 2]

        hidden = torch.cat((weighted_prem, weighted_hypo), dim=-1)
        # hidden => [batch_size, hidden_dim * 4]
        #        => [batch_size, d_model]

        for fc, norm in zip(self.fcs, self.layer_norms):
            hidden_ = fc(hidden)
            hidden_ = self.dropout(hidden_)
            # residual connection
            hidden = hidden + F.relu(hidden_)
            # layer normalization
            hidden = norm(hidden)
        
        logits = self.out(hidden)
        # logits => [batch_size, output_dim]

        return logits
    
    def attention(self, outputs, mask=None):
        # outputs => [seq_len, batch_size, hidden_dim * 2]
        # mask => [batch_size, seq_len]

        batch_size, seq_len, _ = outputs.shape

        outputs_ = outputs.permute(1, 0, 2)
        # outputs_ => [batch_size, seq_len, hidden_dim * 2]

        energy = torch.tanh(self.attn(outputs_))
        # energy => [batch_size, seq_len, hidden_dim * 2]

        attention_energy = self.v(energy).squeeze(2)
        # attention_energy => [batch_size, seq_len]

        if mask is not None:
            attention_energy = attention_energy.masked_fill(mask == 0, -1e10)
            # attention_energy => [batch_size, seq_len]

        scores = F.softmax(attention_energy, dim=-1)
        # scores => [batch_size, seq_len]

        scores = scores.unsqueeze(1)
        # scores => [batch_size, 1, seq_len]

        outputs = outputs.permute(1, 0, 2)
        # outputs => [batch_size, seq_len, hidden_dim * 2]

        weighted = torch.bmm(scores, outputs)
        # weighted => [batch_size, 1, hidden_dim * 2]

        weighted = weighted.squeeze(1)
        # weighted => [batch_size, hidden_dim * 2]

        return weighted

In [48]:
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 200
N_LSTM_LAYERS = 2
N_FC_LAYERS = 3
OUTPUT_DIM = len(LABEL.vocab)
DROPOUT = 0.3
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

teacher_model = BiLSTMWithAttention(
    INPUT_DIM,
    EMBEDDING_DIM,
    HIDDEN_DIM,
    N_LSTM_LAYERS,
    N_FC_LAYERS,
    OUTPUT_DIM,
    DROPOUT,
    PAD_IDX).to(device)

In [49]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(teacher_model):,} trainable parameters')

The model has 4,756,103 trainable parameters


In [50]:
pretrained_embeddings = TEXT.vocab.vectors
print(pretrained_embeddings.shape)

torch.Size([12193, 100])


In [51]:
teacher_model.embedding.weight.data.copy_(pretrained_embeddings)

tensor([[ 0.2837, -0.6263, -0.4435,  ...,  0.4368, -0.8261, -0.1570],
        [-0.7534,  0.2218,  0.4468,  ...,  0.9045, -1.6214, -0.4485],
        [-0.2709,  0.0440, -0.0203,  ..., -0.4923,  0.6369,  0.2364],
        ...,
        [-0.3962, -0.0070,  0.4369,  ..., -0.3295,  0.1764,  0.0092],
        [ 0.0882, -0.3188,  0.4663,  ...,  0.8881,  0.5180, -0.1170],
        [-0.6644, -0.3045,  0.6151,  ...,  0.1404,  0.5788, -0.0333]],
       device='cuda:0')

In [52]:
teacher_model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)

In [53]:
print(teacher_model.embedding.weight.data)

tensor([[ 0.2837, -0.6263, -0.4435,  ...,  0.4368, -0.8261, -0.1570],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.2709,  0.0440, -0.0203,  ..., -0.4923,  0.6369,  0.2364],
        ...,
        [-0.3962, -0.0070,  0.4369,  ..., -0.3295,  0.1764,  0.0092],
        [ 0.0882, -0.3188,  0.4663,  ...,  0.8881,  0.5180, -0.1170],
        [-0.6644, -0.3045,  0.6151,  ...,  0.1404,  0.5788, -0.0333]],
       device='cuda:0')


In [54]:
teacher_model

BiLSTMWithAttention(
  (embedding): Embedding(12193, 100, padding_idx=1)
  (rnn): LSTM(100, 200, num_layers=2, dropout=0.3, bidirectional=True)
  (attn): Linear(in_features=400, out_features=400, bias=True)
  (v): Linear(in_features=400, out_features=1, bias=False)
  (fcs): ModuleList(
    (0): Linear(in_features=800, out_features=800, bias=True)
    (1): Linear(in_features=800, out_features=800, bias=True)
    (2): Linear(in_features=800, out_features=800, bias=True)
  )
  (layer_norms): ModuleList(
    (0): LayerNorm((800,), eps=1e-05, elementwise_affine=True)
    (1): LayerNorm((800,), eps=1e-05, elementwise_affine=True)
    (2): LayerNorm((800,), eps=1e-05, elementwise_affine=True)
  )
  (out): Linear(in_features=800, out_features=3, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)

### Optimizer & Loss Criterion

In [55]:
optimizer = optim.Adam(teacher_model.parameters())
criterion = nn.CrossEntropyLoss()

### Accuracy

In [56]:
def categorical_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability
    correct = max_preds.squeeze(1).eq(y)
    return correct.sum() / torch.FloatTensor([y.shape[0]])

### Training Method

In [57]:
def train(model, iterator, optimizer, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in iterator:
        
        prem = batch.premise
        hypo = batch.hypothesis
        labels = batch.label
        
        optimizer.zero_grad()
        
        predictions = model(prem, hypo)
        
        # predictions => [batch size, output dim]
        # labels => [batch size]
    
        loss = criterion(predictions, labels)            
        acc = categorical_accuracy(predictions, labels)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

### Validation Method

In [58]:
def evaluate(model, iterator, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in iterator:

            prem = batch.premise
            hypo = batch.hypothesis
            labels = batch.label
                        
            predictions = model(prem, hypo)
            
            loss = criterion(predictions, labels)
                
            acc = categorical_accuracy(predictions, labels)
            
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [59]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

### Teacher Model Training

In [63]:
N_EPOCHS = 5

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(teacher_model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(teacher_model, valid_iterator, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(teacher_model.state_dict(), 'teacher_model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')


Epoch: 01 | Epoch Time: 7m 24s
	Train Loss: 0.573 | Train Acc: 76.66%
	 Val. Loss: 0.552 |  Val. Acc: 78.64%
Epoch: 02 | Epoch Time: 7m 25s
	Train Loss: 0.547 | Train Acc: 77.90%
	 Val. Loss: 0.541 |  Val. Acc: 79.03%
Epoch: 03 | Epoch Time: 7m 24s
	Train Loss: 0.527 | Train Acc: 78.86%
	 Val. Loss: 0.523 |  Val. Acc: 79.66%
Epoch: 04 | Epoch Time: 7m 24s
	Train Loss: 0.509 | Train Acc: 79.61%
	 Val. Loss: 0.531 |  Val. Acc: 79.71%
Epoch: 05 | Epoch Time: 7m 25s
	Train Loss: 0.495 | Train Acc: 80.29%
	 Val. Loss: 0.544 |  Val. Acc: 79.54%


### Teacher Model - Test Performance

In [64]:
teacher_model.load_state_dict(torch.load('teacher_model.pt'))

test_loss, test_acc = evaluate(teacher_model, test_iterator, criterion)

print(f'Test Loss: {test_loss:.3f} |  Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.540 |  Test Acc: 78.94%


# Distillation

In [65]:
class NLIWithDistillation(nn.Module):
    def __init__(self, input_dim, emb_dim, output_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.linear = nn.Linear(emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, premise, hypothesis):
        prem_embedded = self.dropout(self.embedding(premise))
        hypo_embedded = self.dropout(self.embedding(hypothesis))
        combined = torch.cat((prem_embedded, hypo_embedded), dim=0)
        final = torch.sum(combined, dim=0)
        logits = self.linear(self.dropout(final))
        return logits

In [66]:
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
OUTPUT_DIM = len(LABEL.vocab)
DROPOUT = 0.3
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

distil_model = NLIWithDistillation(
    INPUT_DIM,
    EMBEDDING_DIM,
    OUTPUT_DIM,
    DROPOUT).to(device)

In [67]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(distil_model):,} trainable parameters')

The model has 1,219,603 trainable parameters


In [68]:
pretrained_embeddings = TEXT.vocab.vectors
print(pretrained_embeddings.shape)

torch.Size([12193, 100])


In [69]:
distil_model.embedding.weight.data.copy_(pretrained_embeddings)

tensor([[ 0.2837, -0.6263, -0.4435,  ...,  0.4368, -0.8261, -0.1570],
        [-0.7534,  0.2218,  0.4468,  ...,  0.9045, -1.6214, -0.4485],
        [-0.2709,  0.0440, -0.0203,  ..., -0.4923,  0.6369,  0.2364],
        ...,
        [-0.3962, -0.0070,  0.4369,  ..., -0.3295,  0.1764,  0.0092],
        [ 0.0882, -0.3188,  0.4663,  ...,  0.8881,  0.5180, -0.1170],
        [-0.6644, -0.3045,  0.6151,  ...,  0.1404,  0.5788, -0.0333]],
       device='cuda:0')

In [70]:
distil_model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)

In [71]:
print(distil_model.embedding.weight.data)

tensor([[ 0.2837, -0.6263, -0.4435,  ...,  0.4368, -0.8261, -0.1570],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.2709,  0.0440, -0.0203,  ..., -0.4923,  0.6369,  0.2364],
        ...,
        [-0.3962, -0.0070,  0.4369,  ..., -0.3295,  0.1764,  0.0092],
        [ 0.0882, -0.3188,  0.4663,  ...,  0.8881,  0.5180, -0.1170],
        [-0.6644, -0.3045,  0.6151,  ...,  0.1404,  0.5788, -0.0333]],
       device='cuda:0')


In [72]:
distil_model

NLIWithDistillation(
  (embedding): Embedding(12193, 100)
  (linear): Linear(in_features=100, out_features=3, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)

### Optimizer & Loss Criterion

In [73]:
optimizer = optim.Adam(distil_model.parameters())
criterion = nn.CrossEntropyLoss()
KD_loss = nn.KLDivLoss(reduction='batchmean')

### Accuracy

In [74]:
def categorical_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability
    correct = max_preds.squeeze(1).eq(y)
    return correct.sum() / torch.FloatTensor([y.shape[0]])

### Training Method

In [75]:
def train(model, teacher_model, iterator, optimizer, criterion, temperature=2.0, alpha_ce=0.5, alpha_teacher=0.5):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in iterator:
        
        prem = batch.premise
        hypo = batch.hypothesis
        labels = batch.label
        
        optimizer.zero_grad()
        
        student_logits = model(prem, hypo)
        
        with torch.no_grad():
            teacher_logits = teacher_model(prem, hypo)
        
        # student_logits => [batch size, output dim]
        # teacher_logits => [batch size, output dim]
        # labels => [batch size]
    
        teacher_loss = KD_loss(input=F.log_softmax(student_logits/temperature, dim=-1),
                       target=F.softmax(teacher_logits/temperature, dim=-1))
        
        prediction_loss = criterion(student_logits, labels)            

        loss = alpha_ce * prediction_loss + alpha_teacher * teacher_loss
        
        acc = categorical_accuracy(student_logits, labels)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

### Validation Method

In [80]:
def evaluate(model, teacher_model, iterator, criterion, temperature=2.0, alpha_ce=0.5, alpha_teacher=0.5):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in iterator:

            prem = batch.premise
            hypo = batch.hypothesis
            labels = batch.label
            
            student_logits = model(prem, hypo)
        
            with torch.no_grad():
                teacher_logits = teacher_model(prem, hypo)
            
            # student_logits => [batch size, output dim]
            # teacher_logits => [batch size, output dim]
            # labels => [batch size]
        
            teacher_loss = KD_loss(input=F.log_softmax(student_logits/temperature, dim=-1),
                        target=F.softmax(teacher_logits/temperature, dim=-1)) * (temperature ** 2)
            
            prediction_loss = criterion(student_logits, labels)            

            loss = alpha_ce * prediction_loss + alpha_teacher * teacher_loss
            
            acc = categorical_accuracy(student_logits, labels)
            
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [81]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

#### Load the pre-trained teacher model

In [82]:
teacher_model.load_state_dict(torch.load('teacher_model.pt'))

<All keys matched successfully>

### Distillation - Training

In [83]:
N_EPOCHS = 10

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(distil_model, teacher_model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(distil_model, teacher_model, valid_iterator, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(distil_model.state_dict(), 'distil_model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

Epoch: 01 | Epoch Time: 2m 16s
	Train Loss: 0.638 | Train Acc: 55.26%
	 Val. Loss: 1.103 |  Val. Acc: 55.75%
Epoch: 02 | Epoch Time: 2m 16s
	Train Loss: 0.636 | Train Acc: 55.41%
	 Val. Loss: 1.108 |  Val. Acc: 55.63%
Epoch: 03 | Epoch Time: 2m 16s
	Train Loss: 0.634 | Train Acc: 55.53%
	 Val. Loss: 1.108 |  Val. Acc: 55.91%
Epoch: 04 | Epoch Time: 2m 16s
	Train Loss: 0.632 | Train Acc: 55.73%
	 Val. Loss: 1.104 |  Val. Acc: 55.77%
Epoch: 05 | Epoch Time: 2m 16s
	Train Loss: 0.631 | Train Acc: 55.81%
	 Val. Loss: 1.097 |  Val. Acc: 55.79%
Epoch: 06 | Epoch Time: 2m 16s
	Train Loss: 0.630 | Train Acc: 55.91%
	 Val. Loss: 1.105 |  Val. Acc: 55.74%
Epoch: 07 | Epoch Time: 2m 17s
	Train Loss: 0.630 | Train Acc: 55.93%
	 Val. Loss: 1.101 |  Val. Acc: 56.05%
Epoch: 08 | Epoch Time: 2m 16s
	Train Loss: 0.629 | Train Acc: 56.02%
	 Val. Loss: 1.107 |  Val. Acc: 55.81%
Epoch: 09 | Epoch Time: 2m 17s
	Train Loss: 0.628 | Train Acc: 56.13%
	 Val. Loss: 1.099 |  Val. Acc: 55.89%
Epoch: 10 | Epoch T

### Distillation - Test Performance

In [84]:
distil_model.load_state_dict(torch.load('distil_model.pt'))

test_loss, test_acc = evaluate(distil_model, teacher_model, test_iterator, criterion)

print(f'Test Loss: {test_loss:.3f} |  Test Acc: {test_acc*100:.2f}%')

Test Loss: 1.105 |  Test Acc: 55.70%


## Inference

In [85]:
def inference(premise, hypothesis, text_field, label_field, model, device):
    
    model.eval()
    
    if isinstance(premise, str):
        premise = text_field.tokenize(premise)
    
    if isinstance(hypothesis, str):
        hypothesis = text_field.tokenize(hypothesis)
    
    if text_field.lower:
        premise = [t.lower() for t in premise]
        hypothesis = [t.lower() for t in hypothesis]

    # numericalize  
    premise = [text_field.vocab.stoi[t] for t in premise]
    hypothesis = [text_field.vocab.stoi[t] for t in hypothesis]
    
    # convert into tensors
    premise = torch.LongTensor(premise).unsqueeze(1).to(device)
    # premise => [prem_len, 1]
    hypothesis = torch.LongTensor(hypothesis).unsqueeze(1).to(device)
    # hypothesis => [hypo_len, 1]

    prediction = model(premise, hypothesis)
    prediction = prediction.argmax(dim=-1).item()

    return label_field.vocab.itos[prediction]

In [86]:
premise = 'A woman selling bamboo sticks talking to two men on a loading dock.'
hypothesis = 'There are at least three people on a loading dock.'

inference(premise, hypothesis, TEXT, LABEL, distil_model, device)

'entailment'

In [87]:
premise = 'A woman selling bamboo sticks talking to two men on a loading dock.'
hypothesis = 'A woman is selling bamboo sticks to help provide for her family.'

inference(premise, hypothesis, TEXT, LABEL, distil_model, device)

'neutral'

In [88]:
premise = 'A woman selling bamboo sticks talking to two men on a loading dock.'
hypothesis = ' A woman is not taking money for any of her sticks.'

inference(premise, hypothesis, TEXT, LABEL, distil_model, device)

'contradiction'