In [1]:
import csv
import math
import string
import itertools
from io import open
from conllu import parse_incr
import torch
import torch.nn as nn
import torch.nn.functional as F
from nltk.corpus import wordnet as wn

In [2]:
# Scratch check of the PyTorch computation graph: a tensor obtained by
# indexing + reshaping a leaf tensor (`matrix`) keeps a grad_fn, so a loss on
# it back-propagates into the leaf (only into the indexed row 2, as the
# printed `matrix.grad` shows) and the optimizer then updates the leaf.
from torch.nn import CosineEmbeddingLoss
# called below as l(input1, input2, target)
l = CosineEmbeddingLoss()
import torch.optim as optim

# NOTE(review): sense_vec is never used below, and def_vec is only used for
# its size (to build test_vec) — leftovers from an earlier experiment?
sense_vec = torch.ones(3, 1)
sense_vec.requires_grad = True
def_vec = torch.randn(3, 1)
def_vec.requires_grad = True

# leaf tensor that the optimizer updates
matrix = torch.ones(3, 3)
matrix.requires_grad = True
# row 2 of `matrix` reshaped to (3, 1): a non-leaf view with a grad_fn
vec = matrix[2, :].view(3, 1)
print(vec.grad_fn)
print(vec.requires_grad)
# random second input to the loss (despite the name, it is not the target)
label = torch.randn(3, 1)

# Adam with default hyper-parameters over `matrix` only
opt = optim.Adam([matrix])
# all-ones target (1 = "the two inputs should be similar" for this loss).
# NOTE(review): shape is (3, 1); the PyTorch docs describe a (N,) target for
# (N, D) inputs — confirm the intended broadcasting.
test_vec = torch.ones(def_vec.size())

# a single optimisation step
for i in range(1):
    
    opt.zero_grad()
    loss = l(vec, label, test_vec)
    # print(loss)
    loss.backward()
    print(matrix.grad)
    # print(def_vec.grad)
    
    opt.step()

print(matrix)

<ViewBackward object at 0x1096156d8>
True
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 2.9802e-08]])
tensor([[1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 0.9999]], requires_grad=True)


In [3]:
# parse the WSD dataset first

'''
Copyright@
White, A. S., D. Reisinger, K. Sakaguchi, T. Vieira, S. Zhang, R. Rudinger, K. Rawlins, & B. Van Durme. 2016. 
[Universal decompositional semantics on universal dependencies]
(http://aswhite.net/media/papers/white_universal_2016.pdf). 
To appear in *Proceedings of the Conference on Empirical Methods in Natural Language Processing 2016*.
'''

def parse_wsd_data(path='data/wsd/wsd_eng_ud1.2_10262016.tsv', encoding='utf-8'):
    """Read the White et al. (2016) word-sense TSV into a list of row dicts.

    Parameters
    ----------
    path : str
        Location of the tab-separated file. The default is the path this
        notebook has always used, so ``parse_wsd_data()`` behaves as before.
    encoding : str
        Text encoding of the file. It is pinned explicitly rather than left
        to the platform default, which differs between machines.

    Returns
    -------
    list of dict
        One dict per data row, keyed by the column names in the header row.
    """
    # The csv module requires newline='' so that quoted fields containing
    # line breaks, and '\r\n' line endings, are parsed correctly.
    with open(path, mode='r', encoding=encoding, newline='') as wsd_file:
        wsd_data = list(csv.DictReader(wsd_file, delimiter='\t'))

    # make sure all data are parsed
    print('Parsed {} word sense data from White et. al., 2016.'.format(len(wsd_data)))

    return wsd_data

In [4]:
# get the raw wsd data: one dict per data row of the White et al. (2016) TSV
wsd_data = parse_wsd_data()

Parsed 439312 word sense data from White et. al., 2016.


In [5]:
def get_all_senses_and_definitions(wsd_data, lexname_fn=None):
    """Collect, for every lemma in the WSD data, its senses, their definitions,
    and the supersense groups they belong to.

    The lemma/synset/definition fields come from the White et al. (2016) WSD
    rows (which index into the EUD train/dev/test sentences).

    Parameters
    ----------
    wsd_data : list of dict
        Rows as returned by ``parse_wsd_data``. The keys 'Arg.Lemma',
        'Synset' and 'Sense.Definition' are read.
    lexname_fn : callable, optional
        Maps a synset name (e.g. 'spring.n.01') to its WordNet lexicographer
        name (e.g. 'noun.time'). Defaults to ``wn.synset(name).lexname()``.
        It can be injected to avoid the WordNet lookup, e.g. for testing.

    Returns
    -------
    all_senses : dict[str, list[str]]
        lemma -> distinct synset names, in first-seen order.
    all_definitions : dict[str, list[list[str]]]
        lemma -> definition of each sense, split on single spaces.
        ``all_definitions[lemma][k]`` is always the definition of
        ``all_senses[lemma][k]``.
    all_supersenses : dict[str, set[tuple[str, str]]]
        supersense (lexname with '.' replaced by '_', e.g. 'noun_time')
        -> set of (lemma, synset) pairs.
    """
    if lexname_fn is None:
        def lexname_fn(synset_name):
            return wn.synset(synset_name).lexname()

    all_senses = {}
    all_definitions = {}
    all_supersenses = {}

    for row in wsd_data:
        word_lemma = row.get('Arg.Lemma')
        word_sense = row.get('Synset')
        definition = row.get('Sense.Definition').split(' ')

        # supersense -> {(word_lemma, word_sense)}
        super_s = lexname_fn(word_sense).replace('.', '_')
        all_supersenses.setdefault(super_s, set()).add((word_lemma, word_sense))

        senses = all_senses.setdefault(word_lemma, [])
        definitions = all_definitions.setdefault(word_lemma, [])
        # Record the definition together with its sense, and only the first
        # time that sense is seen. De-duplicating the two lists independently
        # would let them drift out of alignment: two synsets can share a
        # definition string, or one synset can appear with different ones.
        if word_sense not in senses:
            senses.append(word_sense)
            definitions.append(definition)

    return all_senses, all_definitions, all_supersenses

In [6]:
# get all the senses, definitions and supersenses from the parsed WSD rows
all_senses, all_definitions, all_supersenses = get_all_senses_and_definitions(wsd_data)


In [7]:
# test for the WordNet NLTK API
# NOTE(review): everything below is a single string literal, so the example
# code inside it is never executed; the cell only echoes the string as output.
'''
The specific Synset method is lexname, e.g. wn.synsets('spring')[0].lexname(). 
That should make it really easy to get the suspersenses.
And if you have the synset name–e.g. 'spring.n.01'
you can access the supersense directly: wn.synset('spring.n.01').lexname().
Which returns 'noun.time'.
And wn.synset('spring.n.02').lexname() returns 'noun.artifact'

for idx, d in enumerate(all_definitions['spring']):
    print(d)
    print(wn.synset(all_senses['spring'][idx]).lexname())

for _ in wn.synsets('spring'):
    print(_.lexname())
'''

"\nThe specific Synset method is lexname, e.g. wn.synsets('spring')[0].lexname(). \nThat should make it really easy to get the suspersenses.\nAnd if you have the synset name–e.g. 'spring.n.01'\nyou can access the supersense directly: wn.synset('spring.n.01').lexname().\nWhich returns 'noun.time'.\nAnd wn.synset('spring.n.02').lexname() returns 'noun.artifact'\n\nfor idx, d in enumerate(all_definitions['spring']):\n    print(d)\n    print(wn.synset(all_senses['spring'][idx]).lexname())\n\nfor _ in wn.synsets('spring'):\n    print(_.lexname())\n"

In [8]:
# read the train, dev, test datasets from processed files
# check the 'data_loader.ipynb' for details
def _read_rows(path, delimiter=','):
    """Return every row of a delimited text file as a list of string lists."""
    # newline='' is what the csv module requires for its file objects
    with open(path, mode='r', newline='') as data_file:
        return list(csv.reader(data_file, delimiter=delimiter))


def read_file(data_dir='data'):
    """Load the pre-processed train/dev/test splits.

    For each split ``s`` in train, dev and test, three files are read from
    ``data_dir``:

    * ``{s}_X.tsv``        -- tab-delimited; one tokenised sentence per row.
    * ``{s}_Y.tsv``        -- comma-delimited (despite the extension); one row
                              of integer labels per sentence.
    * ``{s}_word_idx.tsv`` -- comma-delimited integers, flattened into one
                              list. All rows are concatenated; previously only
                              the last row was kept.

    Parameters
    ----------
    data_dir : str
        Directory containing the nine files (default 'data').

    Returns
    -------
    tuple
        ``(train_X, train_Y, test_X, test_Y, dev_X, dev_Y,
        train_word_idx, test_word_idx, dev_word_idx)``
    """
    splits = {}
    for split in ('train', 'dev', 'test'):
        X = _read_rows(f'{data_dir}/{split}_X.tsv', delimiter='\t')
        print(f'Parsed {len(X)} data points for {split}_X.')

        Y = [list(map(int, row)) for row in _read_rows(f'{data_dir}/{split}_Y.tsv')]
        print(f'Parsed {len(Y)} data points for {split}_Y.')

        word_idx = [int(value)
                    for row in _read_rows(f'{data_dir}/{split}_word_idx.tsv')
                    for value in row]
        print(f'Parsed {len(word_idx)} data points for {split}_word_idx.')

        splits[split] = (X, Y, word_idx)

    train_X, train_Y, train_word_idx = splits['train']
    dev_X, dev_Y, dev_word_idx = splits['dev']
    test_X, test_Y, test_word_idx = splits['test']

    return train_X, train_Y, test_X, test_Y, dev_X, dev_Y, train_word_idx, test_word_idx, dev_word_idx

In [9]:
train_X, train_Y, test_X, test_Y, dev_X, dev_Y, train_word_idx, test_word_idx, dev_word_idx = read_file()

# test on one word
word_choice = 'level'


def _select_word_examples(X, Y, word_idx, word, n_senses):
    """Keep only the examples whose target token is exactly `word`.

    Parameters
    ----------
    X : list of list of str
        Tokenised sentences.
    Y : list of list of int
        Per-sentence sense labels, aligned with X.
    word_idx : list of int
        Target token position of each sentence, aligned with X.
    word : str
        Token to match at the target position. The comparison uses ==, so
        inflected forms such as 'levels' are not selected.
    n_senses : int
        Length of the returned sense distribution.

    Returns
    -------
    (sel_X, sel_Y, sel_idx, distribution)
        The matching sentences, their labels and target indices, plus the
        number of selected examples with a truthy label for each sense.
    """
    sel_X, sel_Y, sel_idx = [], [], []
    distribution = [0] * n_senses
    for i, sentence in enumerate(X):
        target = word_idx[i]
        if sentence[target] != word:
            continue
        sel_idx.append(target)
        sel_X.append(sentence)
        sel_Y.append(Y[i])
        for k, response in enumerate(Y[i]):
            if response:
                distribution[k] += 1
    return sel_X, sel_Y, sel_idx, distribution


n_target_senses = len(all_senses[word_choice])

new_train_X, new_train_Y, new_train_idx, distri_train = _select_word_examples(
    train_X, train_Y, train_word_idx, word_choice, n_target_senses)
print('distri of train: {}'.format(distri_train))

new_test_X, new_test_Y, new_test_idx, distri_test = _select_word_examples(
    test_X, test_Y, test_word_idx, word_choice, n_target_senses)
print('distri of test: {}'.format(distri_test))

new_dev_X, new_dev_Y, new_dev_idx, distri_dev = _select_word_examples(
    dev_X, dev_Y, dev_word_idx, word_choice, n_target_senses)
print('distri of dev: {}'.format(distri_dev))

# sense inventory and definitions restricted to the chosen word
target_senses = all_senses[word_choice]
new_all_senses = {word_choice : target_senses}
target_def = all_definitions[word_choice]
new_all_def = {word_choice : target_def}

# supersense -> {(word_choice, synset)}, keeping only the chosen word's senses
new_all_supersenses = {}
for supersense, pairs in all_supersenses.items():
    for lemma, synset in pairs:
        if lemma == word_choice:
            new_all_supersenses.setdefault(supersense, set()).add((word_choice, synset))


Parsed 67398 data points for train_X.
Parsed 67398 data points for train_Y.
Parsed 67398 data points for train_word_idx.
Parsed 7333 data points for dev_X.
Parsed 7333 data points for dev_Y.
Parsed 7333 data points for dev_word_idx.
Parsed 7123 data points for test_X.
Parsed 7123 data points for test_Y.
Parsed 7123 data points for test_word_idx.
distri of train: [35, 43, 27, 11, 6, 3, 7, 5]
distri of test: [5, 5, 4, 1, 1, 1, 2, 1]
distri of dev: [4, 3, 0, 2, 0, 0, 0, 0]


In [10]:
# NOTE(review): star imports hide which names come from which module;
# `Trainer` (used in the next cell) presumably comes from `trainer` — confirm
# and replace these with explicit imports.
from model import *
from trainer import *

# ELMo embedder, handed to the Trainer as `elmo_class` in the next cell
from allennlp.commands.elmo import ElmoEmbedder
# NOTE(review): built with no arguments, presumably loading the default
# pretrained ELMo weights — confirm
elmo = ElmoEmbedder()


Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.
Device: cpu


In [11]:
# Trainer for the single-word experiment: it only sees the sense and
# supersense inventories of `word_choice`.
epochs = 50
trainer = Trainer(
    epochs = epochs,
    elmo_class = elmo,
    all_senses = new_all_senses,
    all_supersenses = new_all_supersenses,
)

In [12]:
# Train on the single-word subset and keep the returned losses for plotting.
# For a full-data run, pass train_X, train_Y, train_word_idx, dev_X, dev_Y,
# dev_word_idx instead (the earlier experiment also set development = False).
train_losses, dev_losses, dev_rs = trainer.train(
    new_train_X, new_train_Y, new_train_idx,
    new_dev_X, new_dev_Y, new_dev_idx,
)



#############   Model Parameters   ##############
layers.word_sense.0.weight torch.Size([512, 512])
layers.word_sense.0.bias torch.Size([512])
layers.word_sense.2.weight torch.Size([300, 512])
layers.word_sense.2.bias torch.Size([300])
dimension_reduction_MLP.weight torch.Size([256, 3072])
dimension_reduction_MLP.bias torch.Size([256])
wsd_lstm.weight_ih_l0 torch.Size([1024, 256])
wsd_lstm.weight_hh_l0 torch.Size([1024, 256])
wsd_lstm.bias_ih_l0 torch.Size([1024])
wsd_lstm.bias_hh_l0 torch.Size([1024])
wsd_lstm.weight_ih_l0_reverse torch.Size([1024, 256])
wsd_lstm.weight_hh_l0_reverse torch.Size([1024, 256])
wsd_lstm.bias_ih_l0_reverse torch.Size([1024])
wsd_lstm.bias_hh_l0_reverse torch.Size([1024])
wsd_lstm.weight_ih_l1 torch.Size([1024, 512])
wsd_lstm.weight_hh_l1 torch.Size([1024, 256])
wsd_lstm.bias_ih_l1 torch.Size([1024])
wsd_lstm.bias_hh_l1 torch.Size([1024])
wsd_lstm.weight_ih_l1_reverse torch.Size([1024, 512])
wsd_lstm.weight_hh_l1_reverse torch.Size([1024, 256])
wsd_lstm.bia

HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 1, Mean Training Loss: 9.340226578986508
Epoch: 1, Mean Dev Loss: 9.08190359388079
[Epoch: 2/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 2, Mean Training Loss: 9.299961024317248
Epoch: 2, Mean Dev Loss: 9.06714221409389
[Epoch: 3/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 3, Mean Training Loss: 9.29961591479422
Epoch: 3, Mean Dev Loss: 9.03428500039237
[Epoch: 4/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 4, Mean Training Loss: 9.276551158948877
Epoch: 4, Mean Dev Loss: 8.997142791748047
[Epoch: 5/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 5, Mean Training Loss: 9.26019111720995
Epoch: 5, Mean Dev Loss: 8.961427552359444
[Epoch: 6/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 6, Mean Training Loss: 9.240497583630441
Epoch: 6, Mean Dev Loss: 8.902856418064662
[Epoch: 7/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 7, Mean Training Loss: 9.233983938721405
Epoch: 7, Mean Dev Loss: 8.845237186976842
[Epoch: 8/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 8, Mean Training Loss: 9.217202219469794
Epoch: 8, Mean Dev Loss: 8.873808452061244
[Epoch: 9/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 9, Mean Training Loss: 9.204098970040508
Epoch: 9, Mean Dev Loss: 8.818570681980678
[Epoch: 10/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 10, Mean Training Loss: 9.169807784858792
Epoch: 10, Mean Dev Loss: 8.789998871939522
[Epoch: 11/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 11, Mean Training Loss: 9.166167873075638
Epoch: 11, Mean Dev Loss: 8.766666412353516
[Epoch: 12/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 12, Mean Training Loss: 9.15402227708663
Epoch: 12, Mean Dev Loss: 8.79857063293457
[Epoch: 13/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 13, Mean Training Loss: 9.185478248815427
Epoch: 13, Mean Dev Loss: 8.829522814069476
[Epoch: 14/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 14, Mean Training Loss: 9.202604776141287
Epoch: 14, Mean Dev Loss: 8.866665158952985
[Epoch: 15/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 15, Mean Training Loss: 9.247470543302338
Epoch: 15, Mean Dev Loss: 8.935236930847168
[Epoch: 16/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 16, Mean Training Loss: 9.271608517087738
Epoch: 16, Mean Dev Loss: 8.932856151035853
[Epoch: 17/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 17, Mean Training Loss: 9.314903555245236
Epoch: 17, Mean Dev Loss: 9.011904444013323
[Epoch: 18/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 18, Mean Training Loss: 9.333102752422464
Epoch: 18, Mean Dev Loss: 9.028570856366839
[Epoch: 19/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 19, Mean Training Loss: 9.342451588860873
Epoch: 19, Mean Dev Loss: 9.00666550227574
[Epoch: 20/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 20, Mean Training Loss: 9.342375020871216
Epoch: 20, Mean Dev Loss: 9.025713239397321
[Epoch: 21/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 21, Mean Training Loss: 9.348160140815823
Epoch: 21, Mean Dev Loss: 9.004285539899554
[Epoch: 22/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 22, Mean Training Loss: 9.370152736532278
Epoch: 22, Mean Dev Loss: 9.010952404567174
[Epoch: 23/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 23, Mean Training Loss: 9.352413374802161
Epoch: 23, Mean Dev Loss: 8.979046957833427
[Epoch: 24/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 24, Mean Training Loss: 9.339271534448383
Epoch: 24, Mean Dev Loss: 8.956665992736816
[Epoch: 25/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 25, Mean Training Loss: 9.308351615379596
Epoch: 25, Mean Dev Loss: 8.956665992736816
[Epoch: 26/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 26, Mean Training Loss: 9.290880126514654
Epoch: 26, Mean Dev Loss: 8.964760644095284
[Epoch: 27/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 27, Mean Training Loss: 9.29333280146807
Epoch: 27, Mean Dev Loss: 8.95428534916469
[Epoch: 28/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 28, Mean Training Loss: 9.296781211063779
Epoch: 28, Mean Dev Loss: 8.95428534916469
[Epoch: 29/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 29, Mean Training Loss: 9.296857691359246
Epoch: 29, Mean Dev Loss: 8.955713953290667
[Epoch: 30/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 30, Mean Training Loss: 9.267164153614264
Epoch: 30, Mean Dev Loss: 8.89618968963623
[Epoch: 31/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 31, Mean Training Loss: 9.271838242980255
Epoch: 31, Mean Dev Loss: 8.907618113926478
[Epoch: 32/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 32, Mean Training Loss: 9.276972430875931
Epoch: 32, Mean Dev Loss: 8.910476003374372
[Epoch: 33/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 33, Mean Training Loss: 9.254367674904309
Epoch: 33, Mean Dev Loss: 8.898571286882673
[Epoch: 34/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 34, Mean Training Loss: 9.263907690157836
Epoch: 34, Mean Dev Loss: 8.92095238821847
[Epoch: 35/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 35, Mean Training Loss: 9.270688879078833
Epoch: 35, Mean Dev Loss: 8.933332306998116
[Epoch: 36/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 36, Mean Training Loss: 9.281800171424603
Epoch: 36, Mean Dev Loss: 8.972855704171318
[Epoch: 37/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 37, Mean Training Loss: 9.293486003218026
Epoch: 37, Mean Dev Loss: 8.9880952835083
[Epoch: 38/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 38, Mean Training Loss: 9.312068347273202
Epoch: 38, Mean Dev Loss: 9.00237968989781
[Epoch: 39/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 39, Mean Training Loss: 9.31678067130604
Epoch: 39, Mean Dev Loss: 9.00237968989781
[Epoch: 40/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 40, Mean Training Loss: 9.314787996226343
Epoch: 40, Mean Dev Loss: 9.001903261457171
[Epoch: 41/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 41, Mean Training Loss: 9.323754376378552
Epoch: 41, Mean Dev Loss: 9.032380921500069
[Epoch: 42/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 42, Mean Training Loss: 9.308390485829321
Epoch: 42, Mean Dev Loss: 8.999523026602608
[Epoch: 43/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 43, Mean Training Loss: 9.303830864785732
Epoch: 43, Mean Dev Loss: 8.999046870640345
[Epoch: 44/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 44, Mean Training Loss: 9.297815498264356
Epoch: 44, Mean Dev Loss: 8.993332999093193
[Epoch: 45/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 45, Mean Training Loss: 9.292605049308689
Epoch: 45, Mean Dev Loss: 8.992857115609306
[Epoch: 46/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 46, Mean Training Loss: 9.284904326515637
Epoch: 46, Mean Dev Loss: 8.982857295445033
[Epoch: 47/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 47, Mean Training Loss: 9.293026233541555
Epoch: 47, Mean Dev Loss: 8.981902939932686
[Epoch: 48/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 48, Mean Training Loss: 9.284137396976865
Epoch: 48, Mean Dev Loss: 8.994285583496094
[Epoch: 49/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 49, Mean Training Loss: 9.2847121118129
Epoch: 49, Mean Dev Loss: 9.010476112365723
[Epoch: 50/50]


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Epoch: 50, Mean Training Loss: 9.297356265714798
Epoch: 50, Mean Dev Loss: 9.019047055925641


In [13]:
# save the loss curves to disk and load matplotlib for the plots that follow
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc

# One comma-separated row of loss values per file (the '.tsv' names are kept
# so existing consumers still find them). newline='' is what the csv module
# requires; without it the row terminator is written as '\r\r\n' on Windows.
for loss_path, losses in (('train_loss.tsv', train_losses), ('dev_loss.tsv', dev_losses)):
    with open(loss_path, mode = 'w', newline = '') as loss_file:
        csv.writer(loss_file).writerow(losses)

In [14]:
# Learning curve of `train_losses` (as returned by trainer.train), rendered
# with LaTeX serif text. NOTE: usetex needs a working LaTeX installation.
plt.figure(1)
rc('font', family='serif')
rc('text', usetex = True)
plt.grid(True, linestyle = '-.', alpha = 0.4)
plt.plot(train_losses, marker = 's', ms = 4, label = "Train Loss")
plt.legend(loc = "best")
title = f"Cosine Similarity Loss (number of examples: {len(new_train_X)})"
plt.title(title)
plt.xlabel('Number of Iteration')
plt.ylabel('Loss')
plt.tight_layout()
plt.savefig('train_loss.png')

In [15]:
# Learning curve of `dev_losses` (as returned by trainer.train), drawn in a
# separate figure. NOTE: usetex needs a working LaTeX installation.
plt.figure(2)
rc('font', family='serif')
rc('text', usetex = True)
plt.grid(True, linestyle = '-.', alpha = 0.4)
plt.plot(dev_losses, marker = 'o', ms = 4, label = "Dev Loss")
plt.legend(loc = "best")
title = f"Cosine Similarity Loss (number of examples: {len(new_dev_X)})"
plt.title(title)
plt.xlabel('Number of Iteration')
plt.ylabel('Loss')
plt.tight_layout()
plt.savefig('dev_loss.png')

In [16]:
# Evaluate on the held-out test examples of the chosen word. The predicted
# sense is the one whose definition embedding is most cosine-similar to the
# model output for the target token. A prediction counts as correct when
# that sense is labelled 1 for the sentence.
cos = nn.CosineSimilarity(dim = 0, eps = 1e-6)
correct_count = 0

# Inference only: no_grad avoids building an autograd graph per sentence.
# NOTE(review): the model is not switched to eval() here, so any train-mode
# behaviour inside it (e.g. dropout) stays active — confirm this is intended.
with torch.no_grad():
    for test_idx, test_sen in enumerate(new_test_X):
        target_idx = new_test_idx[test_idx]
        # the token at the target position doubles as the lemma key
        test_lemma = test_sen[target_idx]
        test_emb = trainer._model.forward(test_sen, target_idx)
        sense_definitions = trainer._model.definition_embeddings[test_lemma]

        all_similarity = []
        for k in range(len(new_all_senses[test_lemma])):
            definition_vec = sense_definitions[:, k].view(trainer._model.output_size, -1)
            all_similarity.append(cos(test_emb, definition_vec))

        # index of the first sense with the highest similarity
        test_result = all_similarity.index(max(all_similarity))
        if new_test_Y[test_idx][test_result] == 1:
            correct_count += 1

print('test size: {}'.format(len(new_test_X)))
if new_test_X:
    print('overall accuracy: {}'.format(correct_count / len(new_test_X)))
else:
    # guard against ZeroDivisionError when the chosen word has no test examples
    print('overall accuracy: n/a (empty test set)')


test size: 9
overall accuracy: 0.5555555555555556
