# Natural Language Inference with PyTorch and Transformers

In this notebook I'm showing how to use [PyTorch](https://pytorch.org/) and [Huggingface Transformers](https://github.com/huggingface/transformers) to fine-tune a pre-trained transformers model to do natural language inference (NLI). In NLI the aim is to model the inferential relationship between two or more given sentences. In particular, given two sentences - the premise `p` and the hypothesis `h` - the task is to determine whether `h` is entailed by `p`, whether the sentences are in contradiction with each other or whether there is no inferential relationship between the sentences (neutral).

So let's get started! First we need to install the python libraries using the following command.

In [1]:
!pip3 install pandas torch transformers datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.4-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m89.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.11.0-py3-none-any.whl (468 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.7/468.7 KB[0m [31m53.4 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m26.5 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m113.1 MB/s[0m eta [36m0:00:00[0m
Collecti

We will then import the needed libraries. We are using [DistilBERT](https://medium.com/huggingface/distilbert-8cf3380435b5) model for this task so we need to import the relevant DistilBERT model designed for sequence classification task and the corresponding tokeniser.

In [2]:
import pandas as pd
import re
import torch
from torch.utils.data import DataLoader
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, AdamW, logging
import datasets
from tqdm import tqdm
import numpy as np

In [3]:
logging.set_verbosity_error()

Let's load the [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) dataset using the Huggingface [Datasets](https://github.com/huggingface/datasets) library. For this demonstration we are using only the training and validation data. We are also further limiting the training data to just 20,000 sentence pairs. This will not allow us to train a good quality model, but it speeds up the demonstration. You can change the values here or use the whole dataset. However, be aware that fine tuning the model will take a lot of time.

In [4]:
nli_data = datasets.load_dataset("multi_nli")

train_data = nli_data['train']
train_labels = train_data['label']

dev_data = nli_data['validation_matched']
val_labels = dev_data['label']

mnli_noun = "https://raw.githubusercontent.com/msainio/thesis-2023/main/MNLI-NOUN-int.tsv"
mnli_noun_data = pd.read_csv(mnli_noun, sep='\t')
mnli_noun_labels = list(mnli_noun_data['gold_label'])

mnli_subset = "https://raw.githubusercontent.com/msainio/thesis-2023/main/MNLI-NOUN-int-subset.tsv"
mnli_subset_data = pd.read_csv(mnli_subset, sep='\t')
mnli_subset_labels = list(mnli_subset_data['gold_label'])

Downloading builder script:   0%|          | 0.00/5.14k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.88k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/8.67k [00:00<?, ?B/s]

Downloading and preparing dataset multi_nli/default to /root/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39...


Downloading data:   0%|          | 0.00/227M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/392702 [00:00<?, ? examples/s]

Generating validation_matched split:   0%|          | 0/9815 [00:00<?, ? examples/s]

Generating validation_mismatched split:   0%|          | 0/9832 [00:00<?, ? examples/s]

Dataset multi_nli downloaded and prepared to /root/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Next we will initialise the tokeniser and tokenise our training and validation data. Notice that we are two lists of sentences to both the training and validation set. This is because in NLI we are classifying pairs of sentences: the premise and the hypothesis.

In [5]:
tokeniser = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
train_encodings = tokeniser(train_data['premise'], train_data['hypothesis'], truncation=True, padding=True)
val_encodings = tokeniser(dev_data['premise'], dev_data['hypothesis'], truncation=True, padding=True)
mnli_noun_encodings = tokeniser(list(mnli_noun_data['sentence1']), list(mnli_noun_data['sentence2']), truncation=True, padding=True)
mnli_subset_encodings = tokeniser(list(mnli_subset_data['sentence1']), list(mnli_subset_data['sentence2']), truncation=True, padding=True)

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Once the data has been tokenised we will create a `NLIDataset` object for our data. Here we are creating a subclass that inherits the `torch.utils.data.Dataset` class.

In [6]:
class NLIDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings.input_ids)

Once we've defined our dataset class we can initialise the training and validation datasets with our tokenised sentence pairs and labels. We will then create `DataLoader` objects for the training and validation data. 

In [7]:
train_dataset = NLIDataset(train_encodings, train_labels)
val_dataset = NLIDataset(val_encodings, val_labels)
mnli_noun_dataset = NLIDataset(mnli_noun_encodings, mnli_noun_labels)
mnli_subset_dataset = NLIDataset(mnli_subset_encodings, mnli_subset_labels)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=True)
mnli_noun_loader = DataLoader(mnli_noun_dataset, batch_size=16, shuffle=True)
mnli_subset_loader = DataLoader(mnli_subset_dataset, batch_size=16, shuffle=True)

Now, before we can start training, we need to import our model and optimiser to be used in training. We first set the device and use `cuda` if GPU is available. We then get the pre-trained DistilBERT model specifying the number of classes we are classifying to.

In [8]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=3)
model.to(device)
model.train()
optim = AdamW(model.parameters(), lr=5e-5)

Downloading pytorch_model.bin:   0%|          | 0.00/268M [00:00<?, ?B/s]



Now we are ready to train the model. In this demonstration we are fine-tuning for just three epochs, but you can change the value to something more meaningful if you like. Note that you could also use the Transformers `Trainer` class to fine-tune the model but I've chosen to use native PyTorch instead.

In [9]:
epochs = 3
for epoch in range(epochs):
    all_losses = []

    for batch in tqdm(train_loader, total=len(train_loader), desc="Epoch: {}/{}".format(epoch+1, epochs)):
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        loss.backward()
        optim.step()
        all_losses.append(loss.item())
        
    print("\nMean loss: {:<.4f}".format(np.mean(all_losses)))

Epoch: 1/3: 100%|██████████| 24544/24544 [4:13:45<00:00,  1.61it/s]



Mean loss: 0.6197


Epoch: 2/3: 100%|██████████| 24544/24544 [4:13:46<00:00,  1.61it/s]



Mean loss: 0.4754


Epoch: 3/3: 100%|██████████| 24544/24544 [4:13:58<00:00,  1.61it/s]


Mean loss: 0.3877





Once the model has been trained we can evaluate it to get the validation accuracy for our model.

In [16]:
model.eval()
with torch.no_grad():
    eval_preds = []
    eval_labels = []
    eval_pairs = []

    for batch in tqdm(val_loader, total=len(val_loader)):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        preds = model(input_ids, attention_mask=attention_mask, labels=labels)
        preds = preds[1].argmax(dim=-1)
        eval_preds.append(preds.cpu().numpy())
        eval_labels.append(batch['labels'].cpu().numpy())

        decoded_ids = [tokeniser.decode(id) for id in input_ids]
        separated_strings = [id.split('[SEP]') for id in decoded_ids]
        normalized_pairs = [(re.sub("\[\w+\]","",p),re.sub("\[\w+\]","",h)) for
                            [p,h,end] in separated_strings]
        for pair in normalized_pairs:
            eval_pairs.append(pair)

print("\nValidation accuracy: {:6.2f}".format(round(100 * (np.concatenate(eval_labels) == np.concatenate(eval_preds)).mean()), 2))

with open("results.txt", "a") as results:
    results.write("Original MNLI accuracy{:6.2f}\n".format(round(100 * (np.concatenate(eval_labels) == np.concatenate(eval_preds)).mean()), 2))

model_performance = []
            
for i in range(len(eval_pairs)):
    model_performance.append([eval_pairs[i][0], eval_pairs[i][1],
                               str(np.concatenate(eval_labels)[i]),
                               str(np.concatenate(eval_preds)[i])])
    
with open("original_results.tsv", "a") as original_results:
    original_results.write('sentence1\tsentence2\tgold_label\tpredicted_label\n')
    for entry in model_performance:
        original_results.write('\t'.join(entry)+'\n')

100%|██████████| 614/614 [01:04<00:00,  9.52it/s]



Validation accuracy:  79.00


Now we are all done. As you can see the results are far from state of the art if you use just a fraction of the training data.

Hope you enjoyed this demo. Feel free to contact me if you have any questions.

In [17]:
model.eval()
with torch.no_grad():
    eval_preds = []
    eval_labels = []
    eval_pairs = []

    for batch in tqdm(mnli_noun_loader, total=len(mnli_noun_loader)):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        preds = model(input_ids, attention_mask=attention_mask, labels=labels)
        preds = preds[1].argmax(dim=-1)
        eval_preds.append(preds.cpu().numpy())
        eval_labels.append(batch['labels'].cpu().numpy())
        
        decoded_ids = [tokeniser.decode(id) for id in input_ids]
        separated_strings = [id.split('[SEP]') for id in decoded_ids]
        normalized_pairs = [(re.sub("\[\w+\]","",p),re.sub("\[\w+\]","",h)) for
                            [p,h,end] in separated_strings]
        for pair in normalized_pairs:
            eval_pairs.append(pair)
            
print("\nValidation accuracy: {:6.2f}".format(round(100 * (np.concatenate(eval_labels) == np.concatenate(eval_preds)).mean()), 2))
            
with open("results.txt", "a") as results:
    results.write("Noun accuracy{:6.2f}\n".format(round(100 * (np.concatenate(eval_labels) == np.concatenate(eval_preds)).mean()), 2))

model_performance = []
            
for i in range(len(eval_pairs)):
    model_performance.append([eval_pairs[i][0], eval_pairs[i][1],
                               str(np.concatenate(eval_labels)[i]),
                               str(np.concatenate(eval_preds)[i])])
    
with open("noun_results.tsv", "a") as noun_results:
    noun_results.write('sentence1\tsentence2\tgold_label\tpredicted_label\n')
    for entry in model_performance:
        noun_results.write('\t'.join(entry)+'\n')

100%|██████████| 611/611 [01:05<00:00,  9.33it/s]



Validation accuracy:  68.00


In [18]:
model.eval()
with torch.no_grad():
    eval_preds = []
    eval_labels = []
    eval_pairs = []

    for batch in tqdm(mnli_subset_loader, total=len(mnli_subset_loader)):
        input_ids = batch['input_ids'].to(device)        
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        preds = model(input_ids, attention_mask=attention_mask, labels=labels)
        preds = preds[1].argmax(dim=-1)
        eval_preds.append(preds.cpu().numpy())
        eval_labels.append(batch['labels'].cpu().numpy())
        
        decoded_ids = [tokeniser.decode(id) for id in input_ids]
        separated_strings = [id.split('[SEP]') for id in decoded_ids]
        normalized_pairs = [(re.sub("\[\w+\]","",p),re.sub("\[\w+\]","",h)) for
                            [p,h,end] in separated_strings]
        for pair in normalized_pairs:
            eval_pairs.append(pair)
            
print("\nValidation accuracy: {:6.2f}".format(round(100 * (np.concatenate(eval_labels) == np.concatenate(eval_preds)).mean()), 2))

with open("results.txt", "a") as results:
    results.write("Subset accuracy{:6.2f}\n".format(round(100 * (np.concatenate(eval_labels) == np.concatenate(eval_preds)).mean()), 2))

model_performance = []
            
for i in range(len(eval_pairs)):
    model_performance.append([eval_pairs[i][0], eval_pairs[i][1],
                               str(np.concatenate(eval_labels)[i]),
                               str(np.concatenate(eval_preds)[i])])
    
with open("subset_results.tsv", "a") as subset_results:
    subset_results.write('sentence1\tsentence2\tgold_label\tpredicted_label\n')
    for entry in model_performance:
        subset_results.write('\t'.join(entry)+'\n')

100%|██████████| 4/4 [00:00<00:00, 27.38it/s]


Validation accuracy:  70.00



