In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats
import seaborn as sns
import torch
import nltk
nltk.download("punkt")
import torch
import pickle
import regrFuncs as rF
import testFuncs as tF


[nltk_data] Downloading package punkt to /Users/cocolab/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# --- Input resources ---------------------------------------------------------
# Word-vector file handed to model.set_glove_path() in the model-loading cell.
GLOVE_PATH = './Downloads/glove.840B.300d.txt'
# Pickled model object read with torch.load() in the model-loading cell.
MODEL_PATH = './Downloads/infersent.allnli.pickle'
# NOTE(review): DATA_PATH and EMBED_STORE are not referenced in the visible
# cells -- presumably consumed elsewhere; confirm before removing.
DATA_PATH = './Downloads/SNLI/true/'
EMBED_STORE = None

# --- Output / artefact directories -------------------------------------------
# Directory prefix from which the pickled classifiers are loaded.
REGR_MODEL_PATH = './models/'
TEST_OUT_PATH = './regout/'

# Bundle of the two directories above, keyed by constant name.
outpaths = dict(REGR_MODEL_PATH=REGR_MODEL_PATH, TEST_OUT_PATH=TEST_OUT_PATH)

# Class index <-> label name. label2id is the exact inverse of id2label.
id2label = {
    0: 'CONTRADICTION',
    1: 'NEUTRAL',
    2: 'ENTAILMENT',
}
label2id = {label: idx for idx, label in id2label.items()}

In [3]:
# Load the pickled model object from MODEL_PATH. The map_location callable
# returns each storage unchanged, so tensors are left where they were
# deserialized (CPU) instead of being moved to their originally saved device.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoint files obtained from a trusted source.
model = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)
# Attribute set on the loaded model -- presumably disables its GPU code paths
# (TODO confirm against the model class).
model.use_cuda = False
# Point the model at the word-vector file, then build its vocabulary with
# K=100000 (the recorded cell output reports "Vocab size : 100000").
# The path must be set before the vocabulary is built.
model.set_glove_path(GLOVE_PATH)
model.build_vocab_k_words(K=100000)

Vocab size : 100000


In [4]:
# Encoder names and classifier names; every (name, classifier) combination is
# looked up on disk as the file REGR_MODEL_PATH + name + classifier.
names = ['InferSent', 'BOW']
classifiers = [ 'LogReg']

# all_regs maps the concatenated key (e.g. 'InferSentLogReg') to the object
# unpickled from the matching file.
all_regs = {}
for name in names:
    for classifier in classifiers:
        reg_key = name + classifier
        reg_path = '{0}{1}'.format(outpaths['REGR_MODEL_PATH'], reg_key)
        # Use a context manager so the file handle is closed deterministically
        # (the previous pickle.load(open(...)) form leaked it).
        # NOTE(review): pickle.load can execute arbitrary code -- only load
        # model files produced by a trusted training run.
        with open(reg_path, 'rb') as reg_file:
            all_regs[reg_key] = pickle.load(reg_file)



In [52]:
def print_preds(sents_a, sents_b, verbose = True):
    """Predict a label for each (sents_a[i], sents_b[i]) pair with every loaded model.

    For every combination of the module-level ``names`` and ``classifiers``,
    both sentence lists are embedded with ``rF.embed(model, sents, 1, name)``
    and passed to ``tF.predict`` together with ``all_regs[name + classifier]``.

    Parameters
    ----------
    sents_a, sents_b : sequences of str, indexed in parallel
        The i-th elements form one pair. The function now reads these
        arguments; it previously ignored them and read the notebook globals
        ``sent_a`` / ``sent_b`` instead, so results depended on hidden state.
    verbose : bool, default True
        When true, print a banner per model and, for each pair, the two
        sentences, the predicted label and its confidence.

    Returns
    -------
    dict
        ``{name + classifier: {'pred': [label, ...], 'conf': [score, ...]}}``
        where ``label`` is ``id2label[pred[i]]`` and ``score`` is
        ``conf[i][pred[i]] * 100``.
    """
    vals = {}
    for name in names:
        for classifier in classifiers:
            key = name + classifier
            # The third argument to rF.embed is passed through unchanged from
            # the original code -- presumably a batch size (TODO confirm).
            A = rF.embed(model, sents_a, 1, name)
            B = rF.embed(model, sents_b, 1, name)
            pred, conf = tF.predict(A, B, all_regs[key])
            if verbose:
                print('*'*20)
                print(name, classifier)
                print('*'*20, '\n')

            labels = []
            scores = []
            for i in range(len(A)):
                label = id2label[pred[i]]
                # Score of the predicted class, scaled by 100 -- presumably a
                # probability expressed as a percentage (TODO confirm).
                score = conf[i][pred[i]]*100
                if verbose:
                    print('A: ', sents_a[i], '\t B: ', sents_b[i])
                    print(label, score)
                    print('\n')
                labels.append(label)
                scores.append(score)

            if verbose:
                print('\n\n')
            vals[key] = {'pred': labels, 'conf': scores}
    return vals

In [53]:
# Probe 1: each pair differs only in "less" (sent_a) vs "more" (sent_b).
# Successive pairs add one more adjective to the subject noun phrase
# ("the boy" -> "the tall pale young blonde boy").
sent_a = ['the boy is less cheerful than the fat man .', 
          'the tall boy is less cheerful than the fat man .', 
          'the tall pale boy is less cheerful than the fat man .',
         'the tall pale young boy is less cheerful than the fat man .',
          'the tall pale young blonde boy is less cheerful than the fat man .']

sent_b = ['the boy is more cheerful than the fat man .', 
          'the tall boy is more cheerful than the fat man .', 
          'the tall pale boy is more cheerful than the fat man .',
         'the tall pale young boy is more cheerful than the fat man .',
          'the tall pale young blonde boy is more cheerful than the fat man .']
# Prints per-model predictions and keeps the returned dict in `vals`.
vals = print_preds(sent_a, sent_b)

********************
InferSent LogReg
******************** 

A:  the boy is less cheerful than the fat man . 	 B:  the boy is more cheerful than the fat man .
CONTRADICTION 95.73731422424316


A:  the tall boy is less cheerful than the fat man . 	 B:  the tall boy is more cheerful than the fat man .
CONTRADICTION 66.7090117931366


A:  the tall pale boy is less cheerful than the fat man . 	 B:  the tall pale boy is more cheerful than the fat man .
CONTRADICTION 68.93193125724792


A:  the tall pale young boy is less cheerful than the fat man . 	 B:  the tall pale young boy is more cheerful than the fat man .
ENTAILMENT 51.049256324768066


A:  the tall pale young blonde boy is less cheerful than the fat man . 	 B:  the tall pale young blonde boy is more cheerful than the fat man .
ENTAILMENT 77.26410031318665





********************
BOW LogReg
******************** 

A:  the boy is less cheerful than the fat man . 	 B:  the boy is more cheerful than the fat man .
NEUTRAL 92.1831071376

In [54]:
# Probe 2: same "less" vs "more" contrast, now lengthening the sentence with a
# prepositional phrase ("with the hat" / "with the big hat") attached either to
# the object ("the fat man") or to the subject ("the boy").
# NOTE(review): unlike the other probe cells, the final period here is not
# separated from the last word by a space ("man." vs "man .") -- confirm this
# tokenization difference is intentional.
sent_a = ['the boy is less cheerful than the fat man.',
          'the boy is less cheerful than the fat man with the hat.',
          'the boy is less cheerful than the fat man with the big hat.',
          'the boy with the hat is less cheerful than the fat man.',
          'the boy with the big hat is less cheerful than the fat man.']

sent_b = ['the boy is more cheerful than the fat man.',
          'the boy is more cheerful than the fat man with the hat.',
          'the boy is more cheerful than the fat man with the big hat.',
          'the boy with the hat is more cheerful than the fat man.',
          'the boy with the big hat is more cheerful than the fat man.']
# Prints per-model predictions and keeps the returned dict in `vals`.
vals = print_preds(sent_a, sent_b)

********************
InferSent LogReg
******************** 

A:  the boy is less cheerful than the fat man. 	 B:  the boy is more cheerful than the fat man.
CONTRADICTION 93.67058277130127


A:  the boy is less cheerful than the fat man with the hat. 	 B:  the boy is more cheerful than the fat man with the hat.
CONTRADICTION 95.35847306251526


A:  the boy is less cheerful than the fat man with the big hat. 	 B:  the boy is more cheerful than the fat man with the big hat.
CONTRADICTION 81.98242783546448


A:  the boy with the hat is less cheerful than the fat man. 	 B:  the boy with the hat is more cheerful than the fat man.
CONTRADICTION 65.6223475933075


A:  the boy with the big hat is less cheerful than the fat man. 	 B:  the boy with the big hat is more cheerful than the fat man.
ENTAILMENT 65.0149405002594





********************
BOW LogReg
******************** 

A:  the boy is less cheerful than the fat man. 	 B:  the boy is more cheerful than the fat man.
NEUTRAL 92.183107137

In [55]:
# Probe 3: sent_a states that the girl shouts but the boy does NOT; sent_b
# states that the boy does. Successive pairs add adjectives to "boy" in both
# sentences ("the boy" -> "the tall pale young boy").
sent_a = ['the girl does shout loudly , however the boy does not shout loudly .',
         'the girl does shout loudly , however the tall boy does not shout loudly .',
         'the girl does shout loudly , however the tall pale boy does not shout loudly .',
         'the girl does shout loudly , however the tall pale young boy does not shout loudly .']

sent_b = ['the boy does shout loudly .',
         'the tall boy does shout loudly .',
         'the tall pale boy does shout loudly .',
         'the tall pale young boy does shout loudly .']


# Prints per-model predictions and keeps the returned dict in `vals`.
vals = print_preds(sent_a, sent_b)

********************
InferSent LogReg
******************** 

A:  the girl does shout loudly , however the boy does not shout loudly . 	 B:  the boy does shout loudly .
CONTRADICTION 99.96054768562317


A:  the girl does shout loudly , however the tall boy does not shout loudly . 	 B:  the tall boy does shout loudly .
CONTRADICTION 98.7440824508667


A:  the girl does shout loudly , however the tall pale boy does not shout loudly . 	 B:  the tall pale boy does shout loudly .
CONTRADICTION 82.65823721885681


A:  the girl does shout loudly , however the tall pale young boy does not shout loudly . 	 B:  the tall pale young boy does shout loudly .
ENTAILMENT 73.01568388938904





********************
BOW LogReg
******************** 

A:  the girl does shout loudly , however the boy does not shout loudly . 	 B:  the boy does shout loudly .
ENTAILMENT 99.84913468360901


A:  the girl does shout loudly , however the tall boy does not shout loudly . 	 B:  the tall boy does shout loudly .
ENTA

In [56]:
# Probe 4: same negation contrast as the previous cell, but the adjectives are
# added to "girl" in sent_a only; sent_b is the same sentence for every pair.
sent_a = ['the girl does shout loudly , however the boy does not shout loudly .',
         'the tall girl does shout loudly , however the boy does not shout loudly .',
         'the tall pale girl does shout loudly , however the boy does not shout loudly .',
         'the tall pale young girl does shout loudly , however the boy does not shout loudly .',]

sent_b = ['the boy does shout loudly .',
          'the boy does shout loudly .',
          'the boy does shout loudly .',
          'the boy does shout loudly .']

# Prints per-model predictions and keeps the returned dict in `vals`.
vals = print_preds(sent_a, sent_b)

********************
InferSent LogReg
******************** 

A:  the girl does shout loudly , however the boy does not shout loudly . 	 B:  the boy does shout loudly .
CONTRADICTION 99.96054768562317


A:  the tall girl does shout loudly , however the boy does not shout loudly . 	 B:  the boy does shout loudly .
CONTRADICTION 99.45414662361145


A:  the tall pale girl does shout loudly , however the boy does not shout loudly . 	 B:  the boy does shout loudly .
CONTRADICTION 96.6947615146637


A:  the tall pale young girl does shout loudly , however the boy does not shout loudly . 	 B:  the boy does shout loudly .
CONTRADICTION 89.04892802238464





********************
BOW LogReg
******************** 

A:  the girl does shout loudly , however the boy does not shout loudly . 	 B:  the boy does shout loudly .
ENTAILMENT 99.84913468360901


A:  the tall girl does shout loudly , however the boy does not shout loudly . 	 B:  the boy does shout loudly .
ENTAILMENT 99.61853623390198


A:  th

In [57]:
# Probe 5: sent_b swaps the subject and object of sent_a around "overtakes",
# so both sentences of a pair contain the same words in a different order.
# Successive pairs add adjectives to "boy".
sent_a = ['the boy overtakes the old woman .',
           'the tall boy overtakes the old woman .',
           'the tall thin boy overtakes the old woman .',
           'the tall thin pale boy overtakes the old woman .']

sent_b = ['the old woman overtakes the boy .',
           'the old woman overtakes the tall boy .',
           'the old woman overtakes the tall thin boy .',
           'the old woman overtakes the tall thin pale boy .']

# Prints per-model predictions and keeps the returned dict in `vals`.
vals = print_preds(sent_a, sent_b)         

********************
InferSent LogReg
******************** 

A:  the boy overtakes the old woman . 	 B:  the old woman overtakes the boy .
CONTRADICTION 92.18960404396057


A:  the tall boy overtakes the old woman . 	 B:  the old woman overtakes the tall boy .
CONTRADICTION 90.90285897254944


A:  the tall thin boy overtakes the old woman . 	 B:  the old woman overtakes the tall thin boy .
ENTAILMENT 57.74509906768799


A:  the tall thin pale boy overtakes the old woman . 	 B:  the old woman overtakes the tall thin pale boy .
ENTAILMENT 82.39078521728516





********************
BOW LogReg
******************** 

A:  the boy overtakes the old woman . 	 B:  the old woman overtakes the boy .
ENTAILMENT 92.98434853553772


A:  the tall boy overtakes the old woman . 	 B:  the old woman overtakes the tall boy .
ENTAILMENT 89.67492580413818


A:  the tall thin boy overtakes the old woman . 	 B:  the old woman overtakes the tall thin boy .
ENTAILMENT 70.89486122131348


A:  the tall thin pale

In [62]:
# Probe 6: short adjective phrases instead of full sentences.
adjs = ['cheerful', 'happy', 'tired', 'unhappy', 'big', 'excited']

# Prefixes for side A; position i is paired with position i of sent_b0.
sent_a0 = ['', 'more ', 'less ', 'not ']

# Prefixes for side B, giving the pairs
# (adj, not adj), (more adj, less adj), (less adj, more adj), (not adj, adj).
sent_b0 = ['not ', 'less ', 'more ', '',]

# adj_vals maps each adjective to the dict returned by print_preds.
adj_vals = {}
for adj in adjs:
    # Build the four phrases for this adjective by prepending each prefix.
    # NOTE(review): these assignments overwrite the notebook-level sent_a /
    # sent_b lists defined by the earlier probe cells.
    sent_a = [x + adj for x in sent_a0]
    sent_b = [x + adj for x in sent_b0]
    # verbose = False: collect the results without printing them.
    adj_vals[adj] = print_preds(sent_a, sent_b, verbose = False)         

In [63]:
# Display the collected per-adjective predictions and confidences.
adj_vals

{'big': {'BOWLogReg': {'conf': [57.25385546684265,
    99.49661493301392,
    89.60018157958984,
    94.13952231407166],
   'pred': ['CONTRADICTION', 'CONTRADICTION', 'NEUTRAL', 'CONTRADICTION']},
  'InferSentLogReg': {'conf': [50.8463978767395,
    99.69869256019592,
    99.8491644859314,
    96.91717028617859],
   'pred': ['ENTAILMENT', 'CONTRADICTION', 'CONTRADICTION', 'ENTAILMENT']}},
 'cheerful': {'BOWLogReg': {'conf': [91.12583994865417,
    40.27191996574402,
    89.72070217132568,
    51.44846439361572],
   'pred': ['ENTAILMENT', 'NEUTRAL', 'NEUTRAL', 'ENTAILMENT']},
  'InferSentLogReg': {'conf': [91.288822889328,
    98.98456931114197,
    97.59145379066467,
    95.07464170455933],
   'pred': ['CONTRADICTION', 'CONTRADICTION', 'CONTRADICTION', 'ENTAILMENT']}},
 'excited': {'BOWLogReg': {'conf': [99.94516372680664,
    87.2475266456604,
    58.1783652305603,
    99.92314577102661],
   'pred': ['ENTAILMENT', 'ENTAILMENT', 'ENTAILMENT', 'ENTAILMENT']},
  'InferSentLogReg': {'conf