In [1]:
import os
import torch
from torch import nn
from torch.utils.data import Dataset


import pickle
import numpy as np
import pandas as pd

from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_recall_fscore_support

from NLIModel import SiameseNLI
from NLIDataset import NLIDataset, collate_batch
from sklearn.metrics import accuracy_score

In [2]:
model = SiameseNLI(input_size=768,num_layers=4,hidden_size=512)
model.load_state_dict(torch.load('models/nli_siamese_bert_emb/nli_siamese_epoch_3.pt'))
model.eval()

SiameseNLI(
  (lstm): LSTM(768, 512, num_layers=4, bidirectional=True)
  (w1): Linear(in_features=2048, out_features=128, bias=True)
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (w2): Linear(in_features=128, out_features=3, bias=True)
)

In [3]:
test_data_loc = '../ENLP_NLI/snli_1.0/snli_1.0_dev.txt'

In [4]:
dev_dataset = NLIDataset(test_data_loc,len_sample=10000,prefix='dev')
devloader = torch.utils.data.DataLoader(dev_dataset, batch_size=64, shuffle=False, num_workers=2, collate_fn=collate_batch)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
CLASS_WEIGHTS = torch.Tensor(np.ones(3))
LOSS_FN = nn.CrossEntropyLoss(weight=CLASS_WEIGHTS)
INPUT_SIZE = 768
sftmx = nn.Softmax(dim=1)

In [6]:
with torch.no_grad():
    curr_loss_dev = 0.0

    all_preds_dev = torch.tensor([])
    all_targets_dev = torch.tensor([])

    for j, data_dev in enumerate(devloader):
        sents1_dev, sents2_dev, targets_1h_dev, targets_dev  = data_dev
        #print(sents1_dev)
        #print(sents1_dev.size())

        sents1_dev = torch.reshape(sents1_dev,(len(sents1_dev), -1,INPUT_SIZE))
        sents2_dev = torch.reshape(sents2_dev,(len(sents2_dev), -1,INPUT_SIZE))


        outputs_dev = model(sents1_dev.float(),sents2_dev.float())
        loss_dev = LOSS_FN(outputs_dev, targets_1h_dev.float())

        preds_dev = torch.argmax(sftmx(outputs_dev.detach()),dim=1)

        curr_loss_dev += loss_dev.item()

        all_preds_dev = torch.cat((all_preds_dev, preds_dev))
        all_targets_dev = torch.cat((all_targets_dev, targets_dev.detach()))

    curr_loss_dev = curr_loss_dev / (j+1)
    print('Epoch dev loss:',curr_loss_dev)

    acc_dev = accuracy_score(all_targets_dev.numpy(), all_preds_dev.numpy())
    print('Dev accuracy:', acc_dev)

Epoch dev loss: 0.8413872984564228
Dev accuracy: 0.6157


In [10]:
print(len(all_preds_dev))
all_preds_dev

10000


tensor([1., 2., 0.,  ..., 2., 0., 2.])

In [12]:
print(len(all_targets_dev))
all_targets_dev

10000


tensor([1., 2., 0.,  ..., 2., 0., 1.])

In [13]:
from sklearn.metrics import confusion_matrix

In [14]:
confusion_matrix(all_targets_dev.numpy(), all_preds_dev.numpy())

array([[2006,  712,  615],
       [ 705, 2091,  537],
       [ 699,  575, 2060]])

In [15]:
from sklearn.metrics import precision_recall_fscore_support

In [20]:
precision_recall_fscore_support(all_targets_dev.numpy(), all_preds_dev.numpy(), average=None)

(array([0.58826979, 0.61900533, 0.64134496]),
 array([0.60186019, 0.62736274, 0.61787642]),
 array([0.59498739, 0.62315601, 0.629392  ]),
 array([3333, 3333, 3334]))

In [18]:
3333+ 3333+ 3334

10000

In [23]:
all_preds_dev.numpy()

array([1., 2., 0., ..., 2., 0., 2.], dtype=float32)

In [25]:
dev_data = pd.read_csv('../ENLP_NLI/snli_1.0/snli_1.0_dev.txt', sep='\t')

In [26]:
dev_data

Unnamed: 0,gold_label,sentence1_binary_parse,sentence2_binary_parse,sentence1_parse,sentence2_parse,sentence1,sentence2,captionID,pairID,label1,label2,label3,label4,label5
0,neutral,( ( Two women ) ( ( are ( embracing ( while ( ...,( ( The sisters ) ( ( are ( ( hugging goodbye ...,(ROOT (S (NP (CD Two) (NNS women)) (VP (VBP ar...,(ROOT (S (NP (DT The) (NNS sisters)) (VP (VBP ...,Two women are embracing while holding to go pa...,The sisters are hugging goodbye while holding ...,4705552913.jpg#2,4705552913.jpg#2r1n,neutral,entailment,neutral,neutral,neutral
1,entailment,( ( Two women ) ( ( are ( embracing ( while ( ...,( ( Two woman ) ( ( are ( holding packages ) )...,(ROOT (S (NP (CD Two) (NNS women)) (VP (VBP ar...,(ROOT (S (NP (CD Two) (NN woman)) (VP (VBP are...,Two women are embracing while holding to go pa...,Two woman are holding packages.,4705552913.jpg#2,4705552913.jpg#2r1e,entailment,entailment,entailment,entailment,entailment
2,contradiction,( ( Two women ) ( ( are ( embracing ( while ( ...,( ( The men ) ( ( are ( fighting ( outside ( a...,(ROOT (S (NP (CD Two) (NNS women)) (VP (VBP ar...,(ROOT (S (NP (DT The) (NNS men)) (VP (VBP are)...,Two women are embracing while holding to go pa...,The men are fighting outside a deli.,4705552913.jpg#2,4705552913.jpg#2r1c,contradiction,contradiction,contradiction,contradiction,contradiction
3,entailment,( ( ( Two ( young children ) ) ( in ( ( ( ( ( ...,( ( ( Two kids ) ( in ( numbered jerseys ) ) )...,(ROOT (S (NP (NP (CD Two) (JJ young) (NNS chil...,(ROOT (S (NP (NP (CD Two) (NNS kids)) (PP (IN ...,"Two young children in blue jerseys, one with t...",Two kids in numbered jerseys wash their hands.,2407214681.jpg#0,2407214681.jpg#0r1e,entailment,entailment,entailment,entailment,entailment
4,neutral,( ( ( Two ( young children ) ) ( in ( ( ( ( ( ...,( ( ( Two kids ) ( at ( a ballgame ) ) ) ( ( w...,(ROOT (S (NP (NP (CD Two) (JJ young) (NNS chil...,(ROOT (S (NP (NP (CD Two) (NNS kids)) (PP (IN ...,"Two young children in blue jerseys, one with t...",Two kids at a ballgame wash their hands.,2407214681.jpg#0,2407214681.jpg#0r1n,neutral,neutral,neutral,entailment,entailment
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,-,( ( ( A ( small girl ) ) ( wearing ( a ( pink ...,( ( The girl ) ( ( is ( sitting ( on ( ( a ( c...,(ROOT (S (NP (NP (DT A) (JJ small) (NN girl)) ...,(ROOT (S (NP (DT The) (NN girl)) (VP (VBZ is) ...,A small girl wearing a pink jacket is riding o...,The girl is sitting on a carved horse made of ...,77063034.jpg#3,77063034.jpg#3r1n,neutral,contradiction,entailment,contradiction,
9996,contradiction,( ( ( A ( small girl ) ) ( wearing ( a ( pink ...,( ( The girl ) ( ( is ( moving ( at ( ( the sp...,(ROOT (S (NP (NP (DT A) (JJ small) (NN girl)) ...,(ROOT (S (NP (DT The) (NN girl)) (VP (VBZ is) ...,A small girl wearing a pink jacket is riding o...,The girl is moving at the speed of light.,77063034.jpg#3,77063034.jpg#3r1c,contradiction,contradiction,contradiction,contradiction,contradiction
9997,entailment,( ( ( A ( young girl ) ) ( with ( ( ( ( ( ( bl...,( People ( in ( a ( water fountain ) ) ) ),(ROOT (NP (NP (DT A) (JJ young) (NN girl)) (PP...,(ROOT (NP (NP (NNS People)) (PP (IN in) (NP (D...,A young girl with blue and pink ribbons in her...,People in a water fountain,4805835848.jpg#0,4805835848.jpg#0r1e,entailment,entailment,entailment,entailment,entailment
9998,contradiction,( ( ( A ( young girl ) ) ( with ( ( ( ( ( ( bl...,( ( ( A ( young girl ) ) knits ) ( a sweater ) ),(ROOT (NP (NP (DT A) (JJ young) (NN girl)) (PP...,(ROOT (NP (NP (DT A) (JJ young) (NN girl)) (NP...,A young girl with blue and pink ribbons in her...,A young girl knits a sweater,4805835848.jpg#0,4805835848.jpg#0r1c,contradiction,contradiction,entailment,contradiction,neutral


In [28]:
def map_y_vals(y_val):
    y_dict = {'entailment':2, 'neutral':1, 'contradiction':0}
    return y_dict[y_val]

In [31]:
dev_data['preds'] = all_preds_dev.numpy()
dev_data['act'] = dev_data['label1'].apply(map_y_vals)

In [32]:
def get_wrong(row):
    if row['preds'] != row['act']:
        return 1
    else:
        return 0

In [34]:
dev_data['model_wrong'] = dev_data.apply(get_wrong,axis=1)

In [53]:
x = 9
print(dev_data[dev_data['model_wrong'] == 1].iloc[x]['sentence1'])
print(dev_data[dev_data['model_wrong'] == 1].iloc[x]['sentence2'])
print(dev_data[dev_data['model_wrong'] == 1].iloc[x]['label1'])
print(dev_data[dev_data['model_wrong'] == 1].iloc[x]['preds'])

Two men on bicycles competing in a race.
Men are riding bicycles on the street.
neutral
0.0


In [None]:
Epoch dev loss: 1.120858532607935
Dev accuracy: 0.3334

Epoch dev loss: 1.099487037415717
Dev accuracy: 0.3333

Epoch dev loss: 1.1003777912467907
Dev accuracy: 0.3334

In [None]:
Epoch dev loss: 0.884068538570404
Dev accuracy: 0.5948

Epoch dev loss: 0.8711921903133393
Dev accuracy: 0.6019

Epoch dev loss: 0.8515633422374725
Dev accuracy: 0.6133