# NLI Example

Based on https://talman.io/nli/pytorch/demo/2020/12/11/natural-language-inference-with-pytorch-and-transformers.html

In [1]:
!export CUDA_VISIBLE_DEVICES=""

In [2]:
import torch
from torch.utils.data import DataLoader
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, AdamW, logging
import datasets
from tqdm import tqdm
import numpy as np

In [3]:
nli_data = datasets.load_dataset("multi_nli")

train_dataset = nli_data['train'].select(range(20000)) 
# limiting the training set size to 20,000 for demo purposes
dev_dataset = nli_data['validation_matched']

Reusing dataset multi_nli (/home/jmperez/.cache/huggingface/datasets/multi_nli/plain_text/1.0.0/9969e1448f410fe7c6c688a84bfcb61312d0a3f2741d57341c26ef99f28a5451)


In [4]:
num = 0

# Clearer if it is a dict
label_name = {
    0: "entailment",
    1: "neutral", 
    2: "contradiction",
}

for num in range(170, 185):
    print("="*80, "\n")
    print("Hypothesis: ", train_dataset["hypothesis"][num])
    print("Premise   : ", train_dataset["premise"][num])
    print("Label     : ", label_name[train_dataset["label"][num]])


Hypothesis:  Workers carve sculptures and paint scrolls with great enthusiasm.
Premise   :  The individual artisans' shops are no longer here, but you can visit a silk-weaving factory, a ceramics plant, and the Foshan Folk Art Studio, where you can observe workers making Chinese lanterns, carving sculptures, painting scrolls, and cutting intricate designs in paper.
Label     :  neutral

Hypothesis:  Sir Ernest bent his head slightly, and continued.
Premise   :  Really, Sir Ernest, protested the judge, "these questions are not relevant." Sir Ernest bowed, and having shot his arrow proceeded. 
Label     :  entailment

Hypothesis:  The house is very large and boasts over ten bedrooms, a huge kitchen, and a full sized olympic pool.
Premise   :  The house is surprisingly small and simple, with one bedroom, a tiny kitchen, and a couple of social rooms.
Label     :  contradiction

Hypothesis:  Hiding things is just dirty, whereas there is glory in fiction
Premise   :  Fiction has its glories

In [5]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer.model_max_length = 256

def tokenize(batch):
    return tokenizer(batch['premise'], batch['hypothesis'], padding='max_length', truncation=True)

batch_size = 32
eval_batch_size = 8

train_dataset = train_dataset.map(tokenize, batched=True, batch_size=batch_size)
dev_dataset = dev_dataset.map(tokenize, batched=True, batch_size=eval_batch_size)



Loading cached processed dataset at /home/jmperez/.cache/huggingface/datasets/multi_nli/plain_text/1.0.0/9969e1448f410fe7c6c688a84bfcb61312d0a3f2741d57341c26ef99f28a5451/cache-02a5acd1417d35e4.arrow
Loading cached processed dataset at /home/jmperez/.cache/huggingface/datasets/multi_nli/plain_text/1.0.0/9969e1448f410fe7c6c688a84bfcb61312d0a3f2741d57341c26ef99f28a5451/cache-a38577b45d228e73.arrow


You can check the tokenizer has added a `[SEP]` token



In [6]:
lens = {sum(example["attention_mask"]) for example in train_dataset}



In [7]:
max(lens)

256

In [8]:
len([l for l in lens if l >= 256])

1

Uso 256!

In [9]:
example = train_dataset[0]

print(example.keys())
print("Premise    :", example["premise"])
print("Hypothesis :", example["hypothesis"])

tokenizer.decode(example["input_ids"])

dict_keys(['attention_mask', 'hypothesis', 'input_ids', 'label', 'premise', 'token_type_ids'])
Premise    : Conceptually cream skimming has two basic dimensions - product and geography.
Hypothesis : Product and geography are what make cream skimming work. 


'[CLS] conceptually cream skimming has two basic dimensions - product and geography. [SEP] product and geography are what make cream skimming work. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] 

In [10]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
model.to(device)
model.train()
optim = AdamW(model.parameters(), lr=5e-5)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [11]:
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

def compute_metrics(pred):
    """
    Compute metrics for Trainer
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    #_, _, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
    
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        #'macro f1': macro_f1,
        'precision': precision,
        'recall': recall
    }

In [12]:
from transformers import Trainer, TrainingArguments
epochs = 5

total_steps = (epochs * len(train_dataset)) // batch_size
warmup_steps = total_steps // 10
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=eval_batch_size,
    warmup_steps=warmup_steps,
    evaluation_strategy="epoch",
    do_eval=True,
    weight_decay=0.01,
    logging_dir='./logs',
)

results = []

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
)

trainer.train()


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall,Runtime,Samples Per Second
1,0.8513,0.621105,0.745899,0.744487,0.745393,0.744387,84.7457,115.817
2,0.5101,0.602508,0.764646,0.763833,0.764221,0.763898,84.3775,116.322
3,0.322,0.795992,0.76434,0.762898,0.763188,0.763042,84.1692,116.61
4,0.0889,1.134876,0.761386,0.760541,0.761556,0.760934,83.0934,118.12
5,0.0406,1.34091,0.761793,0.760963,0.761822,0.761356,83.3127,117.809


TrainOutput(global_step=3125, training_loss=0.31832323486328123, metrics={'train_runtime': 2956.2523, 'train_samples_per_second': 1.057, 'total_flos': 16816826419200000, 'epoch': 5.0})

In [13]:
trainer.evaluate(dev_dataset)

{'eval_loss': 1.3409099578857422,
 'eval_accuracy': 0.7617931737137035,
 'eval_f1': 0.7609629841162301,
 'eval_precision': 0.7618218391035612,
 'eval_recall': 0.7613561321275197,
 'eval_runtime': 84.1848,
 'eval_samples_per_second': 116.589,
 'epoch': 5.0}