In [1]:
import os
import time
import glob

import torch
import torch.optim as O
import torch.nn as nn

from torchtext import data
from torchtext import datasets

In [2]:
# ref: https://github.com/pytorch/text/blob/master/torchtext/datasets/snli.py
batch_size = 128
inputs = data.Field(lower=False)
answers = data.Field(sequential=False)
train, dev, test = datasets.SNLI.splits(inputs, answers)
train_iter, dev_iter, test_iter = data.BucketIterator.splits(
            (train, dev, test), batch_size=batch_size, device=0)

In [3]:
print(train.shape)
print(dev.shape)
print(test.shape)
print(len(train))
print(len(dev))
print(len(test))

<generator object Dataset.__getattr__ at 0x7f7600b08200>
<generator object Dataset.__getattr__ at 0x7f7600aecdb0>
<generator object Dataset.__getattr__ at 0x7f7600aecdb0>
549367
9842
9824


In [4]:
from IPython.display import Markdown, display

print(train[0])
# ref: https://stackoverflow.com/questions/2675028/list-attributes-of-an-object
print(train[0].__dict__)

def printmd(string):
    # ref: https://discuss.analyticsvidhya.com/t/how-to-make-a-text-bold-within-print-statement-in-ipython-notebook/14552/2
    display(Markdown(string))

def display_sentence(tokens):
    """
    ref: https://stackoverflow.com/questions/493386/how-to-print-without-newline-or-space
    """
    for i, token in enumerate(tokens):
        if i == len(tokens)-1:
            print(token, end='\n')
        else:
            print(token, end=' ')

for i in range(1, 10):
    print('='*80)
    display_sentence(train[i].premise)
    display_sentence(train[i].hypothesis)
    printmd('label: **%s**'%train[i].label)

<torchtext.data.example.Example object at 0x7f7621259a58>
{'premise': ['A', 'person', 'on', 'a', 'horse', 'jumps', 'over', 'a', 'broken', 'down', 'airplane.'], 'hypothesis': ['A', 'person', 'is', 'training', 'his', 'horse', 'for', 'a', 'competition.'], 'label': 'neutral'}
A person on a horse jumps over a broken down airplane.
A person is at a diner, ordering an omelette.


label: **contradiction**

A person on a horse jumps over a broken down airplane.
A person is outdoors, on a horse.


label: **entailment**

Children smiling and waving at camera
They are smiling at their parents


label: **neutral**

Children smiling and waving at camera
There are children present


label: **entailment**

Children smiling and waving at camera
The kids are frowning


label: **contradiction**

A boy is jumping on skateboard in the middle of a red bridge.
The boy skates down the sidewalk.


label: **contradiction**

A boy is jumping on skateboard in the middle of a red bridge.
The boy does a skateboarding trick.


label: **entailment**

A boy is jumping on skateboard in the middle of a red bridge.
The boy is wearing safety equipment.


label: **neutral**

An older man sits with his orange juice at a small table in a coffee shop while employees in bright colored shirts smile in the background.
An older man drinks his juice as he waits for his daughter to get off work.


label: **neutral**

In [5]:
from collections import defaultdict
word_count = defaultdict(float)
for i in train:
    #print(i.premise)
    #print(i.hypothesis)
    for j in i.premise:
        word_count[j] += 1
    for j in i.hypothesis:
        word_count[j] += 1
        
print( len(list(word_count.keys())) )

62996


In [6]:
def build_word_vocab(self, sentences):

        # Build vocabulary
        word_counts = collections.Counter(sentences)

        # Mapping from index to word
        idx_to_word = [x[0] for x in word_counts.most_common()]
        idx_to_word = list(sorted(idx_to_word)) + [self.pad_token, self.go_token, self.end_token]

        words_vocab_size = len(idx_to_word)

        # Mapping from word to index
        word_to_idx = {x: i for i, x in enumerate(idx_to_word)}

        return words_vocab_size, idx_to_word, word_to_idx

In [47]:
# Create a simplified dataset containing the entailment label only
import os, json

root_path = '.data/snli/snli_1.0_entail'
train = 'snli_1.0_train.jsonl'
dev = 'snli_1.0_dev.jsonl'
test = 'snli_1.0_test.jsonl'
train_new = 'snli_1.0_train_entail.jsonl'
dev_new = 'snli_1.0_dev_entail.jsonl'
test_new = 'snli_1.0_test_entail.jsonl'

org_files = [train, dev, test]
new_files = [train_new, dev_new, test_new]

for index, org_file in enumerate(org_files):
    print('#'*50)
    with open(os.path.join(root_path, org_file)) as f:
        lines = f.readlines()
        print(lines[0])

        lines_new = []
        for i, line in enumerate(lines):
            if i % 100000 == 0:
                print(i)
            data = json.loads(line)
            if data['gold_label'] == 'entailment':
                lines_new.append(line)

        print(lines_new[100])
        print(len(lines_new))
        with open(os.path.join(root_path, new_files[index]), 'w') as f_new:
            #f_new.write("\n".join(lines_new))
            f_new.write("".join(lines_new))
            
print('DONE.')

##################################################
{"annotator_labels": ["neutral"], "captionID": "3416050480.jpg#4", "gold_label": "neutral", "pairID": "3416050480.jpg#4r1n", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is training his horse for a competition.", "sentence2_binary_parse": "( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (VP (VBG training) (NP (PRP$ his) (NN horse)) (PP (IN for) (NP (DT a) (NN competition))))) (. .)))"}

0
100000
200000
300000
400000
500000
{"annotator_labels": ["entailment"], "captionID": "37

In [9]:
import re, collections, operator, pickle, json

def tokenize(sent):
    '''
    data_reader.tokenize('a#b')
    ['a', '#', 'b']
    '''
    return [x.strip().lower() for x in re.split('(\W+)?', sent) if x.strip()]

root_path = '.data/snli/snli_1.0_entail'
train_new = 'snli_1.0_train_entail.jsonl'
dev_new = 'snli_1.0_dev_entail.jsonl'
test_new = 'snli_1.0_test_entail.jsonl'
new_files = [train_new, dev_new, test_new]

word_counts = collections.defaultdict(float)

for index, new_file in enumerate(new_files):
    print('#'*50)
    with open(os.path.join(root_path, new_file)) as f:
        lines = f.readlines()
        print(lines[1])
    
        lines_new = []
        for i, line in enumerate(lines):
            if i % 100000 == 0:
                print(i)
            data = json.loads(line)
            if data['gold_label'] == 'entailment':
                tokens1 = tokenize(data['sentence1'])
                tokens2 = tokenize(data['sentence2'])
                for token in tokens1:
                    word_counts[token] += 1
                for token in tokens2:
                    word_counts[token] += 1
                #lines_new.append(line)
        

print( len(list(word_counts.keys())) )
sorted_counts = sorted(word_counts.items(), key=operator.itemgetter(1), reverse=True)
print(sorted_counts[:100])
with open(os.path.join(root_path, 'word_counts.dat'), 'wb') as f:
    pickle.dump(sorted_counts, f)

print('DONE.')

##################################################
{"annotator_labels": ["entailment"], "captionID": "2267923837.jpg#2", "gold_label": "entailment", "pairID": "2267923837.jpg#2r1e", "sentence1": "Children smiling and waving at camera", "sentence1_binary_parse": "( Children ( ( ( smiling and ) waving ) ( at camera ) ) )", "sentence1_parse": "(ROOT (NP (S (NP (NNP Children)) (VP (VBG smiling) (CC and) (VBG waving) (PP (IN at) (NP (NN camera)))))))", "sentence2": "There are children present", "sentence2_binary_parse": "( There ( ( are children ) present ) )", "sentence2_parse": "(ROOT (S (NP (EX There)) (VP (VBP are) (NP (NNS children)) (ADVP (RB present)))))"}

0


  return _compile(pattern, flags).split(string, maxsplit)


100000
##################################################
{"annotator_labels": ["entailment", "entailment", "entailment", "entailment", "entailment"], "captionID": "2407214681.jpg#0", "gold_label": "entailment", "pairID": "2407214681.jpg#0r1e", "sentence1": "Two young children in blue jerseys, one with the number 9 and one with the number 2 are standing on wooden steps in a bathroom and washing their hands in a sink.", "sentence1_binary_parse": "( ( ( Two ( young children ) ) ( in ( ( ( ( ( blue jerseys ) , ) ( one ( with ( the ( number 9 ) ) ) ) ) and ) ( one ( with ( the ( number 2 ) ) ) ) ) ) ) ( ( are ( ( ( standing ( on ( ( wooden steps ) ( in ( a bathroom ) ) ) ) ) and ) ( ( washing ( their hands ) ) ( in ( a sink ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (CD Two) (JJ young) (NNS children)) (PP (IN in) (NP (NP (JJ blue) (NNS jerseys)) (, ,) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN number) (CD 9)))) (CC and) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN num

In [69]:
import fileinput

with open(os.path.join(root_path, 'word_counts.dat'), 'rb') as f:
    word_count = pickle.load(f)
print( len(list(word_counts.keys())) )

sorted_counts = sorted(word_counts.items(), key=operator.itemgetter(1), reverse=False)
print(sorted_counts[:50])


UNK_TOKEN = 'UNK'
# ref: https://stackoverflow.com/questions/17140886/how-to-search-and-replace-text-in-a-file-using-python
for index, new_file in enumerate(new_files):
    print('#'*50)
    
    with open(os.path.join(root_path, new_file), 'r') as file :
        #filedata = file.read()
        lines = file.readlines()
        
    lines_reduced = []
    for i, line in enumerate(lines):
        if i % 100 == 0:
            print(i)
            
        data = json.loads(line)
        
        #print(type(json.dumps(data)))
        
        for j, token in enumerate(sorted_counts[:10000]):
            tokens1 = tokenize(data['sentence1'])
            tokens2 = tokenize(data['sentence2'])
            #if token[0] in tokens1 or token[0] in tokens2:
            #    data['sentence1'] = data['sentence1'].replace(token[0], UNK_TOKEN)
            #    data['sentence2'] = data['sentence2'].replace(token[0], UNK_TOKEN)
            
                
            replaced1 = [UNK_TOKEN if x!=UNK_TOKEN and token[0] == x else x for x in tokens1]
            replaced2 = [UNK_TOKEN if x!=UNK_TOKEN and token[0] == x else x for x in tokens2]
            data['sentence1'] = " ".join(replaced1)
            data['sentence2'] = " ".join(replaced2)
            #if token[0] in tokens1 or token[0] in tokens2:
            #    print(replaced1)
            #    print(replaced2)
            #    print(data['sentence1'])
            #    print(data['sentence2'])
                
                    
        lines_reduced.append(json.dumps(data))
            
    with open(os.path.join(root_path, '%s_reduced'%new_file), 'w') as f:
        f.write("\n".join(lines_reduced))
                         
        # Replace the target string
    #   filedata = filedata.replace(token[0], UNK_TOKEN)
    #   Write the file out again
    #with open(os.path.join(root_path, '%s_reduced'%new_file), 'w') as file:
    #    file.write(filedata)

21953
[('carryout', 1.0), ('cultures', 1.0), ('dealt', 1.0), ('fertilizing', 1.0), ('fertilizering', 1.0), ('laws', 1.0), ('humanitarian', 1.0), ('possing', 1.0), ('scultupres', 1.0), ('powell', 1.0), ('grafffiti', 1.0), ('foils', 1.0), ('tease', 1.0), ('shucker', 1.0), ('pwople', 1.0), ('airlifted', 1.0), ('mik', 1.0), ('ruckus', 1.0), ('88', 1.0), ('tokyo', 1.0), ('foundtain', 1.0), ('charts', 1.0), ('parasailed', 1.0), ('expectant', 1.0), ('pregnancy', 1.0), ('performnef', 1.0), ('costruction', 1.0), ('doord', 1.0), ('brige', 1.0), ('listener', 1.0), ('aquestrian', 1.0), ('unfunny', 1.0), ('bodily', 1.0), ('cleft', 1.0), ('paddleboarding', 1.0), ('functions', 1.0), ('regard', 1.0), ('accurately', 1.0), ('fail', 1.0), ('tangle', 1.0), ('readily', 1.0), ('sipped', 1.0), ('watersking', 1.0), ('teeshirt', 1.0), ('performance3', 1.0), ('crucial', 1.0), ('tripoli', 1.0), ('libya', 1.0), ('smilingly', 1.0), ('sterring', 1.0)]
##################################################
0


  return _compile(pattern, flags).split(string, maxsplit)


100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
10500
10600
10700
10800
10900
11000
11100
11200
11300
11400
11500
11600
11700
11800
11900
12000
12100
12200
12300
12400
12500
12600
12700
12800
12900
13000
13100
13200
13300
13400
13500
13600
13700
13800
13900
14000
14100
14200
14300
14400
14500
14600
14700
14800
14900
15000
15100
15200
15300
15400
15500
15600
15700
15800
15900
16000
16100
16200
16300
16400
16500
16600
16700
16800
16900
17000
17100
17200
17300
17400
17500
17600
17700
17800
17900
18000
18100
18200
18300
18400
1850

133000
133100
133200
133300
133400
133500
133600
133700
133800
133900
134000
134100
134200
134300
134400
134500
134600
134700
134800
134900
135000
135100
135200
135300
135400
135500
135600
135700
135800
135900
136000
136100
136200
136300
136400
136500
136600
136700
136800
136900
137000
137100
137200
137300
137400
137500
137600
137700
137800
137900
138000
138100
138200
138300
138400
138500
138600
138700
138800
138900
139000
139100
139200
139300
139400
139500
139600
139700
139800
139900
140000
140100
140200
140300
140400
140500
140600
140700
140800
140900
141000
141100
141200
141300
141400
141500
141600
141700
141800
141900
142000
142100
142200
142300
142400
142500
142600
142700
142800
142900
143000
143100
143200
143300
143400
143500
143600
143700
143800
143900
144000
144100
144200
144300
144400
144500
144600
144700
144800
144900
145000
145100
145200
145300
145400
145500
145600
145700
145800
145900
146000
146100
146200
146300
146400
146500
146600
146700
146800
146900
147000
147100
147200