In [1]:
import csv
import math
import string
import itertools
from io import open
from conllu import parse_incr
import torch
import torch.nn as nn
import torch.nn.functional as F
from nltk.corpus import wordnet as wn

In [2]:
# tests for pytorch computation graph
from torch.nn import CosineEmbeddingLoss
l = CosineEmbeddingLoss()
import torch.optim as optim

sense_vec = torch.ones(3, 1)
sense_vec.requires_grad = True
def_vec = torch.randn(3, 1)
def_vec.requires_grad = True

matrix = torch.ones(3, 3)
matrix.requires_grad = True
vec = matrix[2, :].view(3, 1)
print(vec.grad_fn)
print(vec.requires_grad)
label = torch.randn(3, 1)

opt = optim.Adam([matrix])
test_vec = torch.ones(def_vec.size())

for i in range(1):
    
    opt.zero_grad()
    loss = l(vec, label, test_vec)
    # print(loss)
    loss.backward()
    print(matrix.grad)
    # print(def_vec.grad)
    
    opt.step()

print(matrix)

<ViewBackward object at 0x12413fa90>
True
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00, -2.9802e-08]])
tensor([[1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0001]], requires_grad=True)


In [3]:
'''
return: 
all senses for each word 
all definitions for each word
all supersenses
from the EUD for train, test, and dev dataset
index provided by WSD dataset by White et. al.
'''
# get all the senses and definitions for each word from WSD dataset
# order of senses and definitions are in order
def get_all_senses_and_definitions(wsd_data):

    # all senses for each 
    all_senses = {}
    all_definitions = {}
    all_supersenses = {}
    
    # for test purpose: only load specific amount of data
    for i in range(100):

        # get the original sentence from EUD
        sentence_id = wsd_data[i].get('Sentence.ID')
        
        # get the definitions for the target word from EUD
        definition = wsd_data[i].get('Sense.Definition').split(' ')
        
        # the index in EUD is 1-based!!!
        sentence_number = int(sentence_id.split(' ')[-1]) - 1
        word_index = int(wsd_data[i].get('Arg.Token')) - 1
        word_lemma = wsd_data[i].get('Arg.Lemma')
        word_sense = wsd_data[i].get('Synset')
        response = wsd_data[i].get('Sense.Response')
        
        # supersense-> (word_lemma, word_sense) dictionary
        super_s = wn.synset(word_sense).lexname().replace('.', '_')
        if all_supersenses.get(super_s, 'not_exist') != 'not_exist':
            all_supersenses[super_s].add((word_lemma, word_sense))
        else:
            all_supersenses[super_s] = {(word_lemma, word_sense)}

        # if the word already exits: add the new sense to the list
        # else: creata a new list for the word
        if all_senses.get(word_lemma, 'not_exist') != 'not_exist':
            if word_sense not in all_senses[word_lemma]:
                all_senses[word_lemma].append(word_sense)
        else:
            all_senses[word_lemma] = []
            all_senses[word_lemma].append(word_sense)            
            
        if all_definitions.get(word_lemma,'not_exist') != 'not_exist':
            if definition not in all_definitions[word_lemma]: 
                all_definitions[word_lemma].append(definition)
        else:
            all_definitions[word_lemma] = []
            all_definitions[word_lemma].append(definition)
        
    return all_senses, all_definitions, all_supersenses

In [4]:
# parse the WSD dataset first
# and retrieve all sentences from the EUD

'''
Copyright@
White, A. S., D. Reisinger, K. Sakaguchi, T. Vieira, S. Zhang, R. Rudinger, K. Rawlins, & B. Van Durme. 2016. 
[Universal decompositional semantics on universal dependencies]
(http://aswhite.net/media/papers/white_universal_2016.pdf). 
To appear in *Proceedings of the Conference on Empirical Methods in Natural Language Processing 2016*.
'''

# parse the WSD dataset and construct X_Y tensors
def parse_data():

    # parse the EUD-EWT conllu files and retrieve the sentences
    # remove all punctuation?
    train_file = open("data/UD_English-EWT/en_ewt-ud-train.conllu", "r", encoding="utf-8")
    train_data = list(parse_incr(train_file))
    # train_data = [[''.join(c for c in word.get('lemma') if c not in string.punctuation) for word in token_list] for token_list in train_data]
    # train_data = [[word for word in s if word] for s in train_data]
    print('Parsed {} training data from UD_English-EWT/en_ewt-ud-train.conllu.'.format(len(train_data)))

    test_file = open("data/UD_English-EWT/en_ewt-ud-test.conllu", "r", encoding="utf-8")
    test_data = list(parse_incr(test_file))
    # test_data = [[''.join(c for c in word.get('lemma') if c not in string.punctuation) for word in token_list] for token_list in test_data]
    # test_data = [[word for word in s if word] for s in test_data]
    print('Parsed {} testing data from UD_English-EWT/en_ewt-ud-test.conllu.'.format(len(test_data)))

    dev_file = open("data/UD_English-EWT/en_ewt-ud-dev.conllu", "r", encoding="utf-8")
    dev_data = list(parse_incr(dev_file))
    # dev_data = [[''.join(c for c in word.get('lemma') if c not in string.punctuation) for word in token_list] for token_list in dev_data]
    # dev_data = [[word for word in s if word] for s in dev_data]
    print('Parsed {} dev data from UD_English-EWT/en_ewt-ud-dev.conllu.'.format(len(dev_data)))

    # parse the WSD dataset
    wsd_data = []

    # read in tsv by White et. al., 2016
    with open('data/wsd/wsd_eng_ud1.2_10262016.tsv', mode = 'r') as wsd_file:

        tsv_reader = csv.DictReader(wsd_file, delimiter = '\t')      

        # store the data: ordered dict row
        for row in tsv_reader:                                

            # each data vector
            wsd_data.append(row)

        # make sure all data are parsed
        print('Parsed {} word sense data from White et. al., 2016.'.format(len(wsd_data)))

    return wsd_data, train_data, test_data, dev_data

In [5]:
# parse the data
wsd_data, train_data, test_data, dev_data = parse_data()

# return the raw sentences from the EUD for train, test, and dev
# test small amount of sentences
all_senses, all_definitions, all_supersenses = get_all_senses_and_definitions(wsd_data)

for key in all_supersenses.keys():
    print('{} : {}'.format(key, all_supersenses[key]))

Parsed 12543 training data from UD_English-EWT/en_ewt-ud-train.conllu.
Parsed 2077 testing data from UD_English-EWT/en_ewt-ud-test.conllu.
Parsed 2002 dev data from UD_English-EWT/en_ewt-ud-dev.conllu.
Parsed 439312 word sense data from White et. al., 2016.
noun_time : {('rate', 'rate.n.01'), ('spring', 'spring.n.01')}
noun_artifact : {('puppet', 'puppet.n.03'), ('level', 'level.n.05'), ('level', 'horizontal_surface.n.01'), ('puppet', 'puppet.n.01'), ('spring', 'spring.n.02'), ('house', 'house.n.12'), ('level', 'floor.n.02'), ('house', 'theater.n.01'), ('house', 'house.n.01')}
noun_object : {('spring', 'spring.n.03')}
noun_location : {('place', 'place.n.04'), ('place', 'place.n.02'), ('place', 'position.n.01'), ('place', 'seat.n.01'), ('place', 'plaza.n.01'), ('spring', 'spring.n.04'), ('place', 'home.n.01'), ('place', 'topographic_point.n.01'), ('house', 'sign_of_the_zodiac.n.01')}
noun_attribute : {('spring', 'give.n.01'), ('rate', 'pace.n.03'), ('level', 'level.n.04'), ('advantage',

In [6]:
# test for the WordNet NLTK API
'''
The specific Synset method is lexname, e.g. wn.synsets('spring')[0].lexname(). 
That should make it really easy to get the suspersenses.
And if you have the synset name–e.g. 'spring.n.01'
you can access the supersense directly: wn.synset('spring.n.01').lexname().
Which returns 'noun.time'.
And wn.synset('spring.n.02').lexname() returns 'noun.artifact'
'''
for idx, d in enumerate(all_definitions['spring']):
    print(d)
    print(wn.synset(all_senses['spring'][idx]).lexname())

for _ in wn.synsets('spring'):
    print(_.lexname())

['the', 'season', 'of', 'growth']
noun.time
['a', 'metal', 'elastic', 'device', 'that', 'returns', 'to', 'its', 'shape', 'or', 'position', 'when', 'pushed', 'or', 'pulled', 'or', 'pressed']
noun.artifact
['a', 'natural', 'flow', 'of', 'ground', 'water']
noun.object
['a', 'point', 'at', 'which', 'water', 'issues', 'forth']
noun.location
['the', 'elasticity', 'of', 'something', 'that', 'can', 'be', 'stretched', 'and', 'returns', 'to', 'its', 'original', 'length']
noun.attribute
['a', 'light,', 'self-propelled', 'movement', 'upwards', 'or', 'forwards']
noun.act
noun.time
noun.artifact
noun.object
noun.location
noun.attribute
noun.act
verb.motion
verb.stative
verb.motion
verb.body
verb.communication


In [7]:
'''
Construct the X and Y for train, dev, and test from White et. al., 2016
For each anonator and each word, on pair of data and label will be created
Warning: code here is hard to read LMAO
'''
def construct_X_Y(all_senses, all_definitions, train_data, dev_data, test_data):
    
    wsd_data = []
    
    with open('data/wsd/wsd_eng_ud1.2_10262016.tsv', mode = 'r') as wsd_file:
        
        tsv_reader = csv.DictReader(wsd_file, delimiter = '\t')
        
        # same annotator and same sentence number will generate on pair of X_Y
        # manually set the first sentence from White et. al., 2016
        current_annotator = '0'
        current_sentence_num = '1364'
        current_Y = [0 for _ in range(len(all_senses['spring']))]           
        sentence = train_data[1363]
        
        # word from the EUD is a ordered dict for word properties
        # use key 'lemma' to get the literal representations
        current_X = [word.get('lemma') for word in sentence]
        current_idx = 12
        
        # lists X and Y
        train_X, test_X, dev_X = ([] for i in range(3))
        train_Y, test_Y, dev_Y = ([] for i in range(3))
        train_word_idx, test_word_idx, dev_word_idx = ([] for i in range(3))

        for idx, row in enumerate(tsv_reader):
                        
            # training set; only test first 30 training sentences for now
            if idx < 100 and row['Split'] == 'train':
                
                # if still is the same annotatior, word index, target word
                # modify Y with the sense reponse
                if current_annotator == row['Annotator.ID'] and current_idx == int(row['Arg.Token']) - 1 and current_sentence_num == row['Sentence.ID'].split(' ')[-1]:
                    
                    sense_idx = all_senses[row['Arg.Lemma']].index(row['Synset'])
                    if row['Sense.Response'] == 'True':
                        current_Y[sense_idx] = 1
                    else:
                        current_Y[sense_idx] = 0
                
                # if switch annotator or target word
                # append the Y and X from the last annotator and word
                # start a new Y and X for the current annotator and target
                else:
                    # print('h2: {}'.format(idx))
                    # sentence
                    train_X.append(current_X)
                    # annotator responses, e.g., [1, 0, ...]
                    train_Y.append(current_Y)
                    train_word_idx.append(current_idx)
                    
                    current_annotator = row['Annotator.ID']
                    current_sentence_num = row['Sentence.ID'].split(' ')[-1]
                    
                    current_idx = int(row['Arg.Token']) - 1
                    current_Y = [0 for _ in range(len(all_senses[row['Arg.Lemma']]))]
                    sense_idx = all_senses[row['Arg.Lemma']].index(row['Synset'])
                    if row['Sense.Response'] == 'True':
                        current_Y[sense_idx] = 1
                    else:
                        current_Y[sense_idx] = 0
                    
                    sentence_id = row['Sentence.ID']
                    sentence_number = int(sentence_id.split(' ')[-1]) - 1
                    sentence = train_data[sentence_number]
                    current_X = [word.get('lemma') for word in sentence]
                    
            # testing set
            elif idx < 100 and row['Split'] == 'test':
                
                if current_annotator == row['Annotator.ID'] and current_idx == int(row['Arg.Token']) - 1 and current_sentence_num == row['Sentence.ID'].split(' ')[-1]:
                    
                    sense_idx = all_senses[row['Arg.Lemma']].index(row['Synset'])
                    if row['Sense.Response'] == 'True':
                        current_Y[sense_idx] = 1
                    else:
                        current_Y[sense_idx] = 0
                else:
                    test_X.append(current_X)
                    test_Y.append(current_Y)
                    test_word_idx.append(current_idx)

                    current_annotator = row['Annotator.ID']
                    current_sentence_num = row['Sentence.ID'].split(' ')[-1]

                    current_idx = int(row['Arg.Token']) - 1
                    current_Y = [0 for _ in range(len(all_senses[row['Arg.Lemma']]))]
                    sense_idx = all_senses[row['Arg.Lemma']].index(row['Synset'])
                    if row['Sense.Response'] == 'True':
                        current_Y[sense_idx] = 1
                    else:
                        current_Y[sense_idx] = 0
                    
                    sentence_id = row['Sentence.ID']
                    sentence_number = int(sentence_id.split(' ')[-1]) - 1
                    sentence = test_data[sentence_number]
                    current_X = [word.get('lemma') for word in sentence]
                    
            # dev set       
            elif idx < 100:
                if current_annotator == row['Annotator.ID'] and current_idx == int(row['Arg.Token']) - 1 and current_sentence_num == row['Sentence.ID'].split(' ')[-1]:
                    
                    sense_idx = all_senses[row['Arg.Lemma']].index(row['Synset'])
                    if row['Sense.Response'] == 'True':
                        current_Y[sense_idx] = 1
                    else:
                        current_Y[sense_idx] = 0
                else:
                    dev_X.append(current_X)
                    dev_Y.append(current_Y)
                    dev_word_idx.append(current_idx)

                    current_annotator = row['Annotator.ID']
                    current_sentence_num = row['Sentence.ID'].split(' ')[-1]

                    current_idx = int(row['Arg.Token']) - 1
                    current_Y = [0 for _ in range(len(all_senses[row['Arg.Lemma']]))]
                    sense_idx = all_senses[row['Arg.Lemma']].index(row['Synset'])
                    if row['Sense.Response'] == 'True':
                        current_Y[sense_idx] = 1
                    else:
                        current_Y[sense_idx] = 0
                    
                    sentence_id = row['Sentence.ID']
                    sentence_number = int(sentence_id.split(' ')[-1]) - 1
                    sentence = dev_data[sentence_number]
                    current_X = [word.get('lemma') for word in sentence]
        
        print('\n******************* Data Example ***********************')
        print('Sentence: {}'.format(train_X[0]))
        print('Annotator Response, i.e., true label: {}'.format(train_Y[0]))
        print('Target Word Index: {}'.format(train_word_idx[0]))
        print('All senses for the target word: {}'.format(all_senses[train_X[0][train_word_idx[0]]]))
        print('All definitions (in order of its senses from WordNet): {}'.format(all_definitions[train_X[0][train_word_idx[0]]]))
        print('********************************************************')
        
        return train_X, test_X, dev_X, train_Y, test_Y, dev_Y, train_word_idx, test_word_idx, dev_word_idx
                    

In [8]:
# get the training, dev, and testing data
train_X, test_X, dev_X, train_Y, test_Y, dev_Y, train_word_idx, test_word_idx, dev_word_idx = construct_X_Y(all_senses, all_definitions, train_data, dev_data, test_data)



******************* Data Example ***********************
Sentence: ['on', 'August', '9', ',', '2004', ',', 'it', 'be', 'announce', 'that', 'in', 'the', 'spring', 'of', '2001', ',', 'a', 'man', 'name', 'El', '-', 'Shukrijumah', ',', 'also', 'know', 'as', 'Jafar', 'the', 'Pilot', ',', 'who', 'be', 'part', 'of', 'a', '"', 'second', 'wave', ',', '"', 'have', 'be', 'case', 'New', 'York', 'City', 'helicopter', '.']
Annotator Response, i.e., true label: [1, 0, 0, 0, 0, 0]
Target Word Index: 12
All senses for the target word: ['spring.n.01', 'spring.n.02', 'spring.n.03', 'spring.n.04', 'give.n.01', 'leap.n.01']
All definitions (in order of its senses from WordNet): [['the', 'season', 'of', 'growth'], ['a', 'metal', 'elastic', 'device', 'that', 'returns', 'to', 'its', 'shape', 'or', 'position', 'when', 'pushed', 'or', 'pulled', 'or', 'pressed'], ['a', 'natural', 'flow', 'of', 'ground', 'water'], ['a', 'point', 'at', 'which', 'water', 'issues', 'forth'], ['the', 'elasticity', 'of', 'something',

In [9]:
from model import *
from trainer import *

from allennlp.commands.elmo import ElmoEmbedder
elmo = ElmoEmbedder()

'''
from torchviz import make_dot

model = Model(all_senses = all_senses, elmo_class = elmo)
sense_embedding = model.forward(train_X[0], train_word_idx[0])
'''

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.
Device: cpu


'\nfrom torchviz import make_dot\n\nmodel = Model(all_senses = all_senses, elmo_class = elmo)\nsense_embedding = model.forward(train_X[0], train_word_idx[0])\n'

In [12]:
# trainer
epochs = 100
trainer = Trainer(epochs = epochs, elmo_class = elmo, all_senses = all_senses, all_supersenses = all_supersenses)

In [13]:
# get the results
train_losses, dev_losses, dev_rs = trainer.train(train_X, train_Y, train_word_idx, dev_X, dev_Y, dev_word_idx, development = False)



#############   Model Parameters   ##############
layers.word_sense.0.weight torch.Size([512, 512])
layers.word_sense.0.bias torch.Size([512])
layers.word_sense.2.weight torch.Size([300, 512])
layers.word_sense.2.bias torch.Size([300])
dimension_reduction_MLP.weight torch.Size([256, 3072])
dimension_reduction_MLP.bias torch.Size([256])
wsd_lstm.weight_ih_l0 torch.Size([1024, 256])
wsd_lstm.weight_hh_l0 torch.Size([1024, 256])
wsd_lstm.bias_ih_l0 torch.Size([1024])
wsd_lstm.bias_hh_l0 torch.Size([1024])
wsd_lstm.weight_ih_l0_reverse torch.Size([1024, 256])
wsd_lstm.weight_hh_l0_reverse torch.Size([1024, 256])
wsd_lstm.bias_ih_l0_reverse torch.Size([1024])
wsd_lstm.bias_hh_l0_reverse torch.Size([1024])
wsd_lstm.weight_ih_l1 torch.Size([1024, 512])
wsd_lstm.weight_hh_l1 torch.Size([1024, 256])
wsd_lstm.bias_ih_l1 torch.Size([1024])
wsd_lstm.bias_hh_l1 torch.Size([1024])
wsd_lstm.weight_ih_l1_reverse torch.Size([1024, 512])
wsd_lstm.weight_hh_l1_reverse torch.Size([1024, 256])
wsd_lstm.bia

HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 1, Mean Training Loss: 5.5710016300803735
[Epoch: 2/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 2, Mean Training Loss: 5.546315544529965
[Epoch: 3/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 3, Mean Training Loss: 5.529647977728593
[Epoch: 4/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 4, Mean Training Loss: 5.514912228835256
[Epoch: 5/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 5, Mean Training Loss: 5.513333094747443
[Epoch: 6/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 6, Mean Training Loss: 5.504736912877936
[Epoch: 7/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 7, Mean Training Loss: 5.506315959127326
[Epoch: 8/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 8, Mean Training Loss: 5.512280539462441
[Epoch: 9/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 9, Mean Training Loss: 5.505964781108656
[Epoch: 10/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 10, Mean Training Loss: 5.504912250920346
[Epoch: 11/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 11, Mean Training Loss: 5.507017298748619
[Epoch: 12/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 12, Mean Training Loss: 5.504385747407612
[Epoch: 13/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 13, Mean Training Loss: 5.505087614059448
[Epoch: 14/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 14, Mean Training Loss: 5.504385935632806
[Epoch: 15/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 15, Mean Training Loss: 5.50105253018831
[Epoch: 16/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 16, Mean Training Loss: 5.500350707455685
[Epoch: 17/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 17, Mean Training Loss: 5.494035206342998
[Epoch: 18/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 18, Mean Training Loss: 5.494736633802715
[Epoch: 19/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 19, Mean Training Loss: 5.488947328768279
[Epoch: 20/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 20, Mean Training Loss: 5.487192969573171
[Epoch: 21/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 21, Mean Training Loss: 5.4901752346440365
[Epoch: 22/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 22, Mean Training Loss: 5.487192844089709
[Epoch: 23/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 23, Mean Training Loss: 5.489122955422652
[Epoch: 24/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 24, Mean Training Loss: 5.487543720948069
[Epoch: 25/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 25, Mean Training Loss: 5.481929553182502
[Epoch: 26/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 26, Mean Training Loss: 5.4819298166977735
[Epoch: 27/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 27, Mean Training Loss: 5.484034964912816
[Epoch: 28/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 28, Mean Training Loss: 5.486140326449745
[Epoch: 29/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 29, Mean Training Loss: 5.485438522539641
[Epoch: 30/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 30, Mean Training Loss: 5.483157785315263
[Epoch: 31/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 31, Mean Training Loss: 5.485438547636333
[Epoch: 32/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 32, Mean Training Loss: 5.4917542307000415
[Epoch: 33/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 33, Mean Training Loss: 5.493859717720433
[Epoch: 34/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 34, Mean Training Loss: 5.490701675415039
[Epoch: 35/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 35, Mean Training Loss: 5.492456047158492
[Epoch: 36/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 36, Mean Training Loss: 5.506666622663799
[Epoch: 37/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 37, Mean Training Loss: 5.502631400760851
[Epoch: 38/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 38, Mean Training Loss: 5.507719253238879
[Epoch: 39/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 39, Mean Training Loss: 5.5075438022613525
[Epoch: 40/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 40, Mean Training Loss: 5.506140282279567
[Epoch: 41/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 41, Mean Training Loss: 5.505788225876658
[Epoch: 42/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 42, Mean Training Loss: 5.506666660308838
[Epoch: 43/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 43, Mean Training Loss: 5.503684407786319
[Epoch: 44/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 44, Mean Training Loss: 5.495964803193745
[Epoch: 45/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 45, Mean Training Loss: 5.499473571777344
[Epoch: 46/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 46, Mean Training Loss: 5.5031577285967375
[Epoch: 47/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 47, Mean Training Loss: 5.508771971652382
[Epoch: 48/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 48, Mean Training Loss: 5.514912128448486
[Epoch: 49/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 49, Mean Training Loss: 5.510877182609157
[Epoch: 50/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 50, Mean Training Loss: 5.518772012309024
[Epoch: 51/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 51, Mean Training Loss: 5.5175440311431885
[Epoch: 52/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 52, Mean Training Loss: 5.5143858508059855
[Epoch: 53/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 53, Mean Training Loss: 5.512631491610878
[Epoch: 54/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 54, Mean Training Loss: 5.505613828960218
[Epoch: 55/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 55, Mean Training Loss: 5.505613828960218
[Epoch: 56/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 56, Mean Training Loss: 5.502104985086541
[Epoch: 57/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 57, Mean Training Loss: 5.500701690974989
[Epoch: 58/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 58, Mean Training Loss: 5.499824172572086
[Epoch: 59/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 59, Mean Training Loss: 5.497894713753148
[Epoch: 60/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 60, Mean Training Loss: 5.497894713753148
[Epoch: 61/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 61, Mean Training Loss: 5.497543811798096
[Epoch: 62/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 62, Mean Training Loss: 5.501754321550068
[Epoch: 63/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 63, Mean Training Loss: 5.4999998494198445
[Epoch: 64/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 64, Mean Training Loss: 5.5007016156849105
[Epoch: 65/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 65, Mean Training Loss: 5.50210522350512
[Epoch: 66/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 66, Mean Training Loss: 5.500175363139102
[Epoch: 67/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 67, Mean Training Loss: 5.499999937258269
[Epoch: 68/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 68, Mean Training Loss: 5.499999937258269
[Epoch: 69/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 69, Mean Training Loss: 5.500000213321886
[Epoch: 70/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 70, Mean Training Loss: 5.502455949783325
[Epoch: 71/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 71, Mean Training Loss: 5.496666694942274
[Epoch: 72/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 72, Mean Training Loss: 5.497543824346442
[Epoch: 73/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 73, Mean Training Loss: 5.505789355227821
[Epoch: 74/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 74, Mean Training Loss: 5.499298108251471
[Epoch: 75/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 75, Mean Training Loss: 5.498596354534752
[Epoch: 76/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 76, Mean Training Loss: 5.498596354534752
[Epoch: 77/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 77, Mean Training Loss: 5.498596354534752
[Epoch: 78/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 78, Mean Training Loss: 5.491403441680105
[Epoch: 79/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 79, Mean Training Loss: 5.484561330393741
[Epoch: 80/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 80, Mean Training Loss: 5.48403495236447
[Epoch: 81/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 81, Mean Training Loss: 5.4838594762902515
[Epoch: 82/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 82, Mean Training Loss: 5.483508599431891
[Epoch: 83/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 83, Mean Training Loss: 5.483157722573531
[Epoch: 84/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 84, Mean Training Loss: 5.481052561810142
[Epoch: 85/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 85, Mean Training Loss: 5.4778947830200195
[Epoch: 86/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 86, Mean Training Loss: 5.475263056002166
[Epoch: 87/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 87, Mean Training Loss: 5.4789471500798275
[Epoch: 88/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 88, Mean Training Loss: 5.4712280850661426
[Epoch: 89/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 89, Mean Training Loss: 5.470877208207783
[Epoch: 90/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 90, Mean Training Loss: 5.470877208207783
[Epoch: 91/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 91, Mean Training Loss: 5.470877208207783
[Epoch: 92/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 92, Mean Training Loss: 5.473157895238776
[Epoch: 93/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 93, Mean Training Loss: 5.478245609684994
[Epoch: 94/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 94, Mean Training Loss: 5.477543855968275
[Epoch: 95/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 95, Mean Training Loss: 5.477368430087441
[Epoch: 96/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 96, Mean Training Loss: 5.480526196329217
[Epoch: 97/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 97, Mean Training Loss: 5.477017590874119
[Epoch: 98/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 98, Mean Training Loss: 5.480000069266872
[Epoch: 99/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 99, Mean Training Loss: 5.4743858638562655
[Epoch: 100/100]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Epoch: 100, Mean Training Loss: 5.477719131268953


In [14]:
# plot the learning curve
import matplotlib
import matplotlib.pyplot as plt
ite = [e for e in range(epochs)]
plt.plot(train_losses, label = "Cosine Similarity Loss")
plt.legend(loc = "best")
title = "Learning Curve (# of training examples " + str(len(train_X)) + ")"
plt.title(title)
plt.ylabel('Loss')
plt.xlabel('Number of Iteration')
plt.savefig('loss.png')
