In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from pprint import pprint
from functools import partial

from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from transformers import AutoModel, AlbertTokenizerFast, ElectraTokenizerFast

from data import SNLIDataset, SNLIBiEncoderBucketingDataset
from data.utils import get_data, collate_fn_buckets, collate_fn_biencoder

from model import NLIBiEncoder

# SNLI: 3-way classification

In [2]:
# Load the SNLI splits via the project helper `get_data('snli_1.0')`.
# NOTE(review): the unpacking order (train, test, val) must match what `get_data`
# returns — confirm against data.utils.get_data. `train` and `val` are not used by
# the evaluation cells below.
train, test, val = get_data('snli_1.0')

In [3]:
LABELS = ['entailment', 'contradiction', 'neutral']
NUM_LABELS = len(LABELS)

# Keep only pairs whose gold label is one of LABELS (presumably this drops SNLI's
# '-' "no consensus" label — confirm against the raw data). `.copy()` makes `test`
# an independent frame, so later column assignments on it do not raise pandas'
# SettingWithCopyWarning.
test = test[test.target.isin(LABELS)].copy()

In [4]:
# Peek at the first five filtered test pairs.
test.head(5)

Unnamed: 0,sentence1,sentence2,target
0,Two women are embracing while holding to go pa...,The sisters are hugging goodbye while holding ...,neutral
1,Two women are embracing while holding to go pa...,Two woman are holding packages.,entailment
2,Two women are embracing while holding to go pa...,The men are fighting outside a deli.,contradiction
3,"Two young children in blue jerseys, one with t...",Two kids in numbered jerseys wash their hands.,entailment
4,"Two young children in blue jerseys, one with t...",Two kids at a ballgame wash their hands.,neutral


In [5]:
target2idx = {label: idx for idx, label in enumerate(LABELS)}

# Encode string labels as class indices. The dtype guard makes re-running this cell
# a no-op: mapping already-encoded ints through `target2idx` would otherwise turn
# every label into NaN. `assign` returns a new frame instead of writing into `test`.
if not pd.api.types.is_integer_dtype(test.target):
    encoded_target = test.target.map(target2idx)
    assert encoded_target.notna().all(), 'test.target contains labels outside LABELS'
    test = test.assign(target=encoded_target)

In [6]:
def evaluate(model, test_loader):
    """Evaluate a classifier on a test loader and pretty-print its metrics.

    Every batch from ``test_loader`` is unpacked as ``(*inputs, labels)``:
    the leading elements are passed positionally to ``model`` and the last
    element holds the gold class indices. The predicted class is the argmax
    over the last dimension of the model's logits.

    Args:
        model: torch module returning class logits; switched to eval mode here.
        test_loader: iterable of batches laid out as described above.

    Prints a dict with accuracy, macro-F1 and micro-F1 (via ``pprint``) and
    returns ``None``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()

    all_preds, all_true = [], []
    for batch in tqdm(test_loader):
        *inputs, labels = batch

        # no_grad already prevents graph construction, so no .detach() is needed.
        with torch.no_grad():
            logits = model(*inputs)

        # Softmax is monotonic, so the argmax of the raw logits is the same
        # prediction without the extra exp/normalise pass.
        preds = logits.argmax(dim=-1)

        # .cpu() is a no-op for CPU tensors and avoids a failure in .numpy()
        # if the tensors are on another device.
        all_true.append(labels.cpu().numpy())
        all_preds.append(preds.cpu().numpy())

    if not all_true:
        # np.concatenate would otherwise fail with an unhelpful message.
        raise ValueError('test_loader yielded no batches; nothing to evaluate')

    all_true = np.concatenate(all_true)
    all_preds = np.concatenate(all_preds)

    pprint({
        'accuracy': accuracy_score(all_true, all_preds),
        'f1_macro': f1_score(all_true, all_preds, average='macro'),
        'f1_micro': f1_score(all_true, all_preds, average='micro'),
    })

## ALBERT

In [7]:
# ALBERT tokenizer + bucketing bi-encoder dataset over the SNLI test split,
# iterated one example at a time in the original order.
albert_tokenizer = AlbertTokenizerFast.from_pretrained('albert-base-v2')
test_dataset = SNLIBiEncoderBucketingDataset(
    albert_tokenizer,
    test['sentence1'].tolist(),
    test['sentence2'].tolist(),
    test['target'].tolist(),
    batch_size=1,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn_buckets,
)

100%|██████████| 19684/19684 [00:01<00:00, 11096.44it/s]


In [8]:
# ALBERT v1: pre-trained ALBERT backbone wrapped in NLIBiEncoder (NUM_LABELS classes),
# pooled via `pooler_output`, then restored from the v1 checkpoint mapped to CPU.
# The trailing load_state_dict call is the cell's displayed output.
albert_v1 = AutoModel.from_pretrained('albert-base-v2')
model_albert_v1 = NLIBiEncoder(
    albert_v1,
    NUM_LABELS,
    lambda encoder_out: encoder_out.pooler_output,
)
model_albert_v1.load_state_dict(
    torch.load('albert/v1/albert_snli_best.pt', map_location=torch.device('cpu'))
)

<All keys matched successfully>

In [9]:
# Score ALBERT v1 on the bucketed test loader.
evaluate(model=model_albert_v1, test_loader=test_loader)

100%|██████████| 9842/9842 [25:43<00:00,  6.38it/s]

{'accuracy': 0.8138589717537086,
 'f1_macro': 0.813392961971716,
 'f1_micro': 0.8138589717537086}





In [10]:
# Plain (non-bucketing) SNLI test dataset; tokenization happens in the collate function.
test_dataset = SNLIDataset(test.sentence1.tolist(), test.sentence2.tolist(), test.target.tolist())
# Bind the tokenizer under a new name instead of rebinding `collate_fn_biencoder`.
# Shadowing the imported function made this cell non-idempotent (each re-run wrapped
# the previous partial) and leaked the ALBERT-bound version into later cells.
albert_collate_fn = partial(collate_fn_biencoder, tokenizer=albert_tokenizer)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=albert_collate_fn)

In [11]:
# ALBERT v2: same architecture as v1 (pooler_output pooling), restored from the
# v2 checkpoint mapped to CPU. load_state_dict's result is the displayed output.
albert_v2 = AutoModel.from_pretrained('albert-base-v2')
model_albert_v2 = NLIBiEncoder(
    albert_v2,
    NUM_LABELS,
    lambda encoder_out: encoder_out.pooler_output,
)
model_albert_v2.load_state_dict(
    torch.load('albert/v2/albert_snli_best.pt', map_location=torch.device('cpu'))
)

<All keys matched successfully>

In [12]:
# Score ALBERT v2 on the plain bi-encoder test loader.
evaluate(model=model_albert_v2, test_loader=test_loader)

100%|██████████| 9842/9842 [24:57<00:00,  6.57it/s]

{'accuracy': 0.8065433854907539,
 'f1_macro': 0.8063974009722755,
 'f1_micro': 0.8065433854907539}





## ELECTRA

In [13]:
# ELECTRA tokenizer + bucketing bi-encoder dataset over the SNLI test split,
# iterated one example at a time in the original order.
electra_tokenizer = ElectraTokenizerFast.from_pretrained('google/electra-base-discriminator')
test_dataset = SNLIBiEncoderBucketingDataset(
    electra_tokenizer,
    test['sentence1'].tolist(),
    test['sentence2'].tolist(),
    test['target'].tolist(),
    batch_size=1,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn_buckets,
)

100%|██████████| 19684/19684 [00:01<00:00, 12552.16it/s]


In [14]:
# ELECTRA v1: ELECTRA discriminator wrapped in NLIBiEncoder, pooled by taking the
# first token's final hidden state; restored from the v1 checkpoint mapped to CPU.
# load_state_dict's result is the displayed output.
electra_v1 = AutoModel.from_pretrained('google/electra-base-discriminator')
model_electra_v1 = NLIBiEncoder(
    electra_v1,
    NUM_LABELS,
    lambda encoder_out: encoder_out.last_hidden_state[:, 0],
)
model_electra_v1.load_state_dict(
    torch.load('electra/v1/electra_snli_best.pt', map_location=torch.device('cpu'))
)

<All keys matched successfully>

In [15]:
# Score ELECTRA v1 on the bucketed test loader.
evaluate(model=model_electra_v1, test_loader=test_loader)

100%|██████████| 9842/9842 [27:07<00:00,  6.05it/s]

{'accuracy': 0.8194472668156879,
 'f1_macro': 0.8193417460422833,
 'f1_micro': 0.8194472668156879}





In [16]:
# Plain (non-bucketing) SNLI test dataset; tokenization happens in the collate function.
test_dataset = SNLIDataset(test.sentence1.tolist(), test.sentence2.tolist(), test.target.tolist())
# Bind the tokenizer under a new name instead of rebinding `collate_fn_biencoder`.
# Re-wrapping an already-shadowed partial only worked because functools.partial
# lets the newer `tokenizer=` keyword override the older one; it also made the
# cell non-idempotent.
electra_collate_fn = partial(collate_fn_biencoder, tokenizer=electra_tokenizer)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=electra_collate_fn)

In [17]:
# ELECTRA v2: same architecture as v1 (first-token pooling), restored from the
# v2 checkpoint mapped to CPU. load_state_dict's result is the displayed output.
electra_v2 = AutoModel.from_pretrained('google/electra-base-discriminator')
model_electra_v2 = NLIBiEncoder(
    electra_v2,
    NUM_LABELS,
    lambda encoder_out: encoder_out.last_hidden_state[:, 0],
)
model_electra_v2.load_state_dict(
    torch.load('electra/v2/electra_snli_best.pt', map_location=torch.device('cpu'))
)

<All keys matched successfully>

In [18]:
# Score ELECTRA v2 on the plain bi-encoder test loader.
evaluate(model=model_electra_v2, test_loader=test_loader)

100%|██████████| 9842/9842 [23:58<00:00,  6.84it/s]

{'accuracy': 0.826153220890063,
 'f1_macro': 0.8253713781247237,
 'f1_micro': 0.826153220890063}





### Conclusions

The last model (ELECTRA v2) produced the best results: 82.6% test accuracy, versus 81.9% for ELECTRA v1 and 81.4% / 80.7% for ALBERT v1 / v2. It is still not comparable to SOTA models (~92% accuracy). That level could probably be approached by building a richer feature vector than a plain concatenation of the two encoder outputs, and by training with a larger batch size. As I later found out, it is common to concatenate not just those vectors but also their element-wise absolute difference and element-wise product: $[u, v, |u - v|, u * v]$. One trick to afford a larger batch size is to cache the encoder's outputs for identical premises, because most examples in the SNLI dataset come in triplets that share the same premise.

# Quora Question Pairs

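Let's use our best model (ELECTRA v2) to try to detect paraphrases in the Quora Question Pairs dataset. Each question is embedded with the fine-tuned encoder (the first token's final hidden state), and each pair is scored by the cosine similarity of its two embeddings. Hopefully, fine-tuning on the NLI task has taught the encoder to produce decent sentence embeddings.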

In [19]:
# Only the first pairs of the training file are scored, to keep embedding time manageable.
QUORA_SAMPLE_SIZE = 10_000
# `nrows` stops parsing after the sample instead of reading the whole file and slicing.
# Missing questions (NaN floats) cannot be tokenized, so such rows are dropped explicitly.
quora = (
    pd.read_csv('quora/train.csv', nrows=QUORA_SAMPLE_SIZE)
    .dropna(subset=['question1', 'question2'])
)
quora.head()

Unnamed: 0,id,qid1,qid2,question1,question2,is_duplicate
0,0,1,2,What is the step by step guide to invest in sh...,What is the step by step guide to invest in sh...,0
1,1,3,4,What is the story of Kohinoor (Koh-i-Noor) Dia...,What would happen if the Indian government sto...,0
2,2,5,6,How can I increase the speed of my internet co...,How can Internet speed be increased by hacking...,0
3,3,7,8,Why am I mentally very lonely? How can I solve...,Find the remainder when [math]23^{24}[/math] i...,0
4,4,9,10,"Which one dissolve in water quikly sugar, salt...",Which fish would survive in salt water?,0


In [20]:
# Every distinct question in the sample; each is embedded once and reused across pairs.
questions = set(quora.question1).union(set(quora.question2))
len(questions)

19413

In [21]:
def get_sentence_embedding(sentence, tokenizer, model, pooler=None):
    """Embed a single sentence with the encoder of a fine-tuned bi-encoder.

    Args:
        sentence: the text to embed.
        tokenizer: Hugging Face-style tokenizer; called on a one-element batch
            with ``truncation=True`` and ``return_tensors='pt'``.
        model: object whose ``.encoder`` maps token ids to an output with a
            ``last_hidden_state`` attribute (as ``NLIBiEncoder.encoder`` does).
            Assumed to already be in eval mode (dropout off); this function
            does not call ``.eval()``.
        pooler: optional callable mapping the encoder output to a ``(1, dim)``
            tensor. Defaults to the first token's final hidden state, i.e. the
            ELECTRA pooling ``lambda x: x.last_hidden_state[:, 0, :]``.

    Returns:
        A 1-D numpy array holding the sentence embedding.
    """
    # Use the tokenizer that was passed in (previously the global
    # `electra_tokenizer` was used regardless of this argument).
    token_ids = tokenizer([sentence], truncation=True, return_tensors='pt').input_ids
    with torch.no_grad():
        encoder_output = model.encoder(token_ids)
        if pooler is None:
            pooled = encoder_output.last_hidden_state[:, 0, :]
        else:
            pooled = pooler(encoder_output)

    # Drop the batch dimension of the one-sentence batch.
    return pooled.squeeze(0).cpu().numpy()

In [22]:
# Embed each distinct question once with the ELECTRA v2 encoder, keyed by question text.
embeddings = {}
for question in tqdm(questions):
    embeddings[question] = get_sentence_embedding(question, electra_tokenizer, model_electra_v2)

100%|██████████| 19413/19413 [23:44<00:00, 13.63it/s]


In [23]:
# Sanity check: one question's embedding is a flat vector.
kohinoor_question = 'What would happen if the Indian government stole the Kohinoor (Koh-i-Noor) diamond back?'
embeddings[kohinoor_question].shape

(768,)

In [24]:
def get_score(row):
    """Cosine similarity between the embeddings of a row's two questions.

    Looks both questions up in the module-level `embeddings` dict, scales each
    vector to unit length, and returns their dot product.
    """
    first, second = (embeddings[question] for question in (row.question1, row.question2))
    unit_first = first / np.linalg.norm(first)
    unit_second = second / np.linalg.norm(second)
    return np.dot(unit_first, unit_second)

In [25]:
# Row-wise cosine similarity between each pair's question embeddings.
quora['score'] = quora.apply(get_score, axis=1)

In [26]:
# First five pairs alongside their similarity scores.
quora.head(5)

Unnamed: 0,id,qid1,qid2,question1,question2,is_duplicate,score
0,0,1,2,What is the step by step guide to invest in sh...,What is the step by step guide to invest in sh...,0,0.972247
1,1,3,4,What is the story of Kohinoor (Koh-i-Noor) Dia...,What would happen if the Indian government sto...,0,0.714898
2,2,5,6,How can I increase the speed of my internet co...,How can Internet speed be increased by hacking...,0,0.663342
3,3,7,8,Why am I mentally very lonely? How can I solve...,Find the remainder when [math]23^{24}[/math] i...,0,0.578619
4,4,9,10,"Which one dissolve in water quikly sugar, salt...",Which fish would survive in salt water?,0,0.511743


In [27]:
# How well the raw cosine similarity ranks duplicates above non-duplicates.
roc_auc_score(y_true=quora.is_duplicate, y_score=quora.score)

0.7058667147931963

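Without any fine-tuning on this dataset, scoring each pair by the cosine similarity of its question embeddings gives a ROC-AUC of about 0.71 on the first 10,000 training pairs, which is a decent zero-shot result.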