In [1]:
import csv
import math
import string
import itertools
from io import open
from conllu import parse_incr
import torch
import torch.nn as nn
import torch.nn.functional as F
from nltk.corpus import wordnet as wn

In [2]:
# tests for pytorch computation graph:
# check that a loss computed on a view of `matrix` sends gradients back into `matrix`
from torch.nn import CosineEmbeddingLoss
l = CosineEmbeddingLoss()
import torch.optim as optim

sense_vec = torch.ones(3, 1)
sense_vec.requires_grad = True
def_vec = torch.randn(3, 1)
def_vec.requires_grad = True

matrix = torch.ones(3, 3)
matrix.requires_grad = True
# row 2 of `matrix` as a (3, 1) view; it stays attached to the autograd graph
vec = matrix[2, :].view(3, 1)
print(vec.grad_fn)
print(vec.requires_grad)
label = torch.randn(3, 1)

opt = optim.Adam([matrix])
# CosineEmbeddingLoss takes (N, D) inputs and a 1-D target with N entries (+1 = "similar");
# vec / label are N = 3 rows of D = 1, so the target is a ones vector of shape (3,)
# NOTE(review): with D = 1 each row's cosine is just the sign agreement (+-1), so the
# gradient is ~0; use a (1, 3) view if a real 3-d cosine similarity was intended
test_vec = torch.ones(vec.size(0))

for i in range(1):

    opt.zero_grad()
    loss = l(vec, label, test_vec)
    loss.backward()
    # only row 2 of `matrix` took part in the loss, so rows 0 and 1 of the gradient stay zero
    print(matrix.grad)

    opt.step()

print(matrix)

<ViewBackward object at 0x120fea0b8>
True
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 2.9802e-08, 0.0000e+00]])
tensor([[1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000],
        [1.0000, 0.9999, 1.0000]], requires_grad=True)


In [3]:
# parse the WSD dataset first

'''
Copyright@
White, A. S., D. Reisinger, K. Sakaguchi, T. Vieira, S. Zhang, R. Rudinger, K. Rawlins, & B. Van Durme. 2016. 
[Universal decompositional semantics on universal dependencies]
(http://aswhite.net/media/papers/white_universal_2016.pdf). 
To appear in *Proceedings of the Conference on Empirical Methods in Natural Language Processing 2016*.
'''

def parse_wsd_data(path = 'data/wsd/wsd_eng_ud1.2_10262016.tsv'):
    '''
    Read the word-sense-disambiguation TSV released by White et al. (2016).

    Args:
        path: location of the tab-separated file; defaults to the copy under
            data/wsd/ used by this notebook.

    Returns:
        A list with one mapping per data row, keyed by the file's header
        columns (e.g. 'Arg.Lemma', 'Synset', 'Sense.Definition'); all values
        are strings, exactly as csv.DictReader yields them.
    '''
    # newline='' is what the csv module expects for the files it reads
    # (keeps quoted fields containing line breaks intact)
    with open(path, mode = 'r', newline = '') as wsd_file:
        wsd_data = list(csv.DictReader(wsd_file, delimiter = '\t'))

    # make sure all data are parsed
    print('Parsed {} word sense data from White et. al., 2016.'.format(len(wsd_data)))

    return wsd_data

In [4]:
# get the raw wsd data: one dict per row of the White et al. (2016) TSV,
# keyed by the file's header columns
wsd_data = parse_wsd_data()

Parsed 439312 word sense data from White et. al., 2016.


In [5]:
def get_all_senses_and_definitions(wsd_data, limit = 100):
    '''
    Collect, per lemma, the WordNet senses seen in the WSD data (White et al.,
    2016) together with their definitions, plus a supersense inventory.

    Only the first `limit` rows are read (a small slice for quick experiments);
    pass limit=None to read every row. Fewer rows than `limit` is fine.

    Args:
        wsd_data: sequence of row dicts as returned by parse_wsd_data(); the
            columns 'Arg.Lemma', 'Synset' and 'Sense.Definition' are read.
        limit: maximum number of leading rows to read (default 100).

    Returns:
        all_senses: lemma -> list of distinct synset names, in first-seen order.
        all_definitions: lemma -> list of definitions (each split on ' ' into
            tokens), aligned so that all_definitions[lemma][k] is the
            definition of all_senses[lemma][k].
        all_supersenses: WordNet lexname with '.' replaced by '_'
            (e.g. 'noun_time') -> set of (lemma, synset name) pairs.
    '''
    all_senses = {}
    all_definitions = {}
    all_supersenses = {}

    for row in wsd_data[:limit]:

        word_lemma = row.get('Arg.Lemma')
        word_sense = row.get('Synset')
        definition = row.get('Sense.Definition').split(' ')

        # supersense -> {(lemma, synset)}; lexname() gives e.g. 'noun.time'
        super_s = wn.synset(word_sense).lexname().replace('.', '_')
        all_supersenses.setdefault(super_s, set()).add((word_lemma, word_sense))

        # record each new sense together with its definition, so the two lists
        # stay index-aligned even when different senses share definition text
        senses = all_senses.setdefault(word_lemma, [])
        definitions = all_definitions.setdefault(word_lemma, [])
        if word_sense not in senses:
            senses.append(word_sense)
            definitions.append(definition)

    return all_senses, all_definitions, all_supersenses

In [6]:
# build the per-lemma sense / definition lists and the supersense inventory
all_senses, all_definitions, all_supersenses = get_all_senses_and_definitions(wsd_data)

# list which (lemma, synset) pairs fall under each supersense
for supersense, members in all_supersenses.items():
    print(f'{supersense} : {members}')

noun_time : {('rate', 'rate.n.01'), ('spring', 'spring.n.01')}
noun_artifact : {('house', 'house.n.12'), ('level', 'level.n.05'), ('level', 'floor.n.02'), ('puppet', 'puppet.n.01'), ('house', 'house.n.01'), ('house', 'theater.n.01'), ('spring', 'spring.n.02'), ('level', 'horizontal_surface.n.01'), ('puppet', 'puppet.n.03')}
noun_object : {('spring', 'spring.n.03')}
noun_location : {('place', 'place.n.04'), ('place', 'place.n.02'), ('place', 'position.n.01'), ('place', 'seat.n.01'), ('place', 'plaza.n.01'), ('spring', 'spring.n.04'), ('house', 'sign_of_the_zodiac.n.01'), ('place', 'topographic_point.n.01'), ('place', 'home.n.01')}
noun_attribute : {('advantage', 'advantage.n.01'), ('rate', 'pace.n.03'), ('advantage', 'advantage.n.03'), ('spring', 'give.n.01'), ('level', 'degree.n.01'), ('ambition', 'ambition.n.02'), ('level', 'level.n.04')}
noun_act : {('overthrow', 'overthrow.n.01'), ('spring', 'leap.n.01'), ('place', 'stead.n.01'), ('overthrow', 'upset.n.02'), ('place', 'position.n.06

In [7]:
# test for the WordNet NLTK API
# NOTE: everything below is one string literal, so none of the snippets inside it run;
# being the cell's last expression, it is only echoed back as the cell output.
'''
The specific Synset method is lexname, e.g. wn.synsets('spring')[0].lexname(). 
That should make it really easy to get the suspersenses.
And if you have the synset name–e.g. 'spring.n.01'
you can access the supersense directly: wn.synset('spring.n.01').lexname().
Which returns 'noun.time'.
And wn.synset('spring.n.02').lexname() returns 'noun.artifact'

for idx, d in enumerate(all_definitions['spring']):
    print(d)
    print(wn.synset(all_senses['spring'][idx]).lexname())

for _ in wn.synsets('spring'):
    print(_.lexname())
'''

"\nThe specific Synset method is lexname, e.g. wn.synsets('spring')[0].lexname(). \nThat should make it really easy to get the suspersenses.\nAnd if you have the synset name–e.g. 'spring.n.01'\nyou can access the supersense directly: wn.synset('spring.n.01').lexname().\nWhich returns 'noun.time'.\nAnd wn.synset('spring.n.02').lexname() returns 'noun.artifact'\n\nfor idx, d in enumerate(all_definitions['spring']):\n    print(d)\n    print(wn.synset(all_senses['spring'][idx]).lexname())\n\nfor _ in wn.synsets('spring'):\n    print(_.lexname())\n"

In [8]:
# read the train, dev, test datasets from processed files
# check the 'data_loader.ipynb' for details
def _read_rows(path, delimiter = ',', as_int = False):
    '''Return every csv row of `path` as a list (cells converted to int when `as_int`).'''
    with open(path, mode = 'r', newline = '') as data_file:
        rows = csv.reader(data_file, delimiter = delimiter)
        if as_int:
            return [list(map(int, row)) for row in rows]
        return list(rows)


def read_file(data_dir = 'data'):
    '''
    Load the processed train / dev / test splits.

    For each split, `data_dir` must contain:
      <split>_X.tsv         tab-separated; one row of string tokens per sentence
      <split>_Y.tsv         comma-separated integers; one row per sentence
      <split>_word_idx.tsv  comma-separated integers; only the LAST row is kept
    NOTE(review): the _Y and _word_idx files are parsed with the csv default ','
    delimiter despite their .tsv suffix (unchanged from the original reader);
    confirm against data_loader.ipynb.

    Args:
        data_dir: directory holding the files (default 'data').

    Returns:
        train_X, train_Y, test_X, test_Y, dev_X, dev_Y,
        train_word_idx, test_word_idx, dev_word_idx
    '''
    splits = {}

    # read (and report) in the order train, dev, test
    for split in ('train', 'dev', 'test'):

        X = _read_rows(f'{data_dir}/{split}_X.tsv', delimiter = '\t')
        print(f'Parsed {len(X)} data points for {split}_X.')

        Y = _read_rows(f'{data_dir}/{split}_Y.tsv', as_int = True)
        print(f'Parsed {len(Y)} data points for {split}_Y.')

        # the index file appears to store all target-word positions on one row;
        # if there are several rows only the last one survives, as before
        idx_rows = _read_rows(f'{data_dir}/{split}_word_idx.tsv', as_int = True)
        word_idx = idx_rows[-1] if idx_rows else []
        print(f'Parsed {len(word_idx)} data points for {split}_word_idx.')

        splits[split] = (X, Y, word_idx)

    train_X, train_Y, train_word_idx = splits['train']
    dev_X, dev_Y, dev_word_idx = splits['dev']
    test_X, test_Y, test_word_idx = splits['test']

    return train_X, train_Y, test_X, test_Y, dev_X, dev_Y, train_word_idx, test_word_idx, dev_word_idx

In [9]:
train_X, train_Y, test_X, test_Y, dev_X, dev_Y, train_word_idx, test_word_idx, dev_word_idx = read_file()

# subset sizes for quick runs on a laptop; set either to None to keep the full split
TRAIN_SUBSET_SIZE = 19
DEV_SUBSET_SIZE = 3

# cut a small training set (slicing with None keeps every item)
train_X = train_X[:TRAIN_SUBSET_SIZE]
train_Y = train_Y[:TRAIN_SUBSET_SIZE]
train_word_idx = train_word_idx[:TRAIN_SUBSET_SIZE]

# keep only dev sentences whose target word has a known sense inventory
new_dev_X = []
new_dev_y = []
new_dev_idx = []

for index, sen in enumerate(dev_X):

    if sen[dev_word_idx[index]] in all_senses:
        new_dev_idx.append(dev_word_idx[index])
        new_dev_X.append(sen)
        new_dev_y.append(dev_Y[index])

# then cut the filtered dev set down as well
new_dev_X = new_dev_X[:DEV_SUBSET_SIZE]
new_dev_y = new_dev_y[:DEV_SUBSET_SIZE]
new_dev_idx = new_dev_idx[:DEV_SUBSET_SIZE]


Parsed 67398 data points for train_X.
Parsed 67398 data points for train_Y.
Parsed 67398 data points for train_word_idx.
Parsed 7333 data points for dev_X.
Parsed 7333 data points for dev_Y.
Parsed 7333 data points for dev_word_idx.
Parsed 7123 data points for test_X.
Parsed 7123 data points for test_Y.
Parsed 7123 data points for test_word_idx.


In [10]:
# project-local modules.
# NOTE(review): the star imports hide which names each module provides; Trainer
# (used in the next cell) presumably comes from `trainer` - consider explicit imports.
from model import *
from trainer import *

# AllenNLP ELMo embedder, constructed with its default options; it is passed
# to Trainer as `elmo_class` in the next cell
from allennlp.commands.elmo import ElmoEmbedder
elmo = ElmoEmbedder()

# disabled scratch code: kept as a string literal, so it does not run
# (only echoed as the cell output)
'''
from torchviz import make_dot

model = Model(all_senses = all_senses, elmo_class = elmo)
sense_embedding = model.forward(train_X[0], train_word_idx[0])
'''

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.
Device: cpu


'\nfrom torchviz import make_dot\n\nmodel = Model(all_senses = all_senses, elmo_class = elmo)\nsense_embedding = model.forward(train_X[0], train_word_idx[0])\n'

In [11]:
# build the trainer from the ELMo embedder and the sense / supersense inventories
epochs = 50
trainer = Trainer(
    epochs = epochs,
    elmo_class = elmo,
    all_senses = all_senses,
    all_supersenses = all_supersenses,
)

In [12]:
# train and collect the loss curves used by the plotting cells below.
# For a full run pass dev_X, dev_Y, dev_word_idx instead of the reduced dev subset;
# here the small laptop-sized subsets built earlier are used.
train_losses, dev_losses, dev_rs = trainer.train(
    train_X, train_Y, train_word_idx,
    new_dev_X, new_dev_y, new_dev_idx,
    development = False,
)



#############   Model Parameters   ##############
layers.word_sense.0.weight torch.Size([512, 512])
layers.word_sense.0.bias torch.Size([512])
layers.word_sense.2.weight torch.Size([300, 512])
layers.word_sense.2.bias torch.Size([300])
dimension_reduction_MLP.weight torch.Size([256, 3072])
dimension_reduction_MLP.bias torch.Size([256])
wsd_lstm.weight_ih_l0 torch.Size([1024, 256])
wsd_lstm.weight_hh_l0 torch.Size([1024, 256])
wsd_lstm.bias_ih_l0 torch.Size([1024])
wsd_lstm.bias_hh_l0 torch.Size([1024])
wsd_lstm.weight_ih_l0_reverse torch.Size([1024, 256])
wsd_lstm.weight_hh_l0_reverse torch.Size([1024, 256])
wsd_lstm.bias_ih_l0_reverse torch.Size([1024])
wsd_lstm.bias_hh_l0_reverse torch.Size([1024])
wsd_lstm.weight_ih_l1 torch.Size([1024, 512])
wsd_lstm.weight_hh_l1 torch.Size([1024, 256])
wsd_lstm.bias_ih_l1 torch.Size([1024])
wsd_lstm.bias_hh_l1 torch.Size([1024])
wsd_lstm.weight_ih_l1_reverse torch.Size([1024, 512])
wsd_lstm.weight_hh_l1_reverse torch.Size([1024, 256])
wsd_lstm.bia

HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 1, Mean Training Loss: 5.69368428932993
Epoch: 1, Mean Dev Loss: 12.573333263397217
[Epoch: 2/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 2, Mean Training Loss: 5.635060046848498
Epoch: 2, Mean Dev Loss: 12.546665986378988
[Epoch: 3/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 3, Mean Training Loss: 5.616140415793971
Epoch: 3, Mean Dev Loss: 12.483330567677816
[Epoch: 4/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 4, Mean Training Loss: 5.607368193174663
Epoch: 4, Mean Dev Loss: 12.441109895706177
[Epoch: 5/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 5, Mean Training Loss: 5.612280669965242
Epoch: 5, Mean Dev Loss: 12.438888947168985
[Epoch: 6/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 6, Mean Training Loss: 5.615789400903802
Epoch: 6, Mean Dev Loss: 12.438888947168985
[Epoch: 7/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 7, Mean Training Loss: 5.605438533582185
Epoch: 7, Mean Dev Loss: 12.48222271601359
[Epoch: 8/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 8, Mean Training Loss: 5.599298050529079
Epoch: 8, Mean Dev Loss: 12.484444936116537
[Epoch: 9/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 9, Mean Training Loss: 5.609122903723466
Epoch: 9, Mean Dev Loss: 12.437777042388916
[Epoch: 10/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 10, Mean Training Loss: 5.607193030809102
Epoch: 10, Mean Dev Loss: 12.41333293914795
[Epoch: 11/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 11, Mean Training Loss: 5.609122577466462
Epoch: 11, Mean Dev Loss: 12.3811088403066
[Epoch: 12/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 12, Mean Training Loss: 5.612806696640818
Epoch: 12, Mean Dev Loss: 12.38110907872518
[Epoch: 13/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 13, Mean Training Loss: 5.6152628973910685
Epoch: 13, Mean Dev Loss: 12.344444354375204
[Epoch: 14/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 14, Mean Training Loss: 5.60964896804408
Epoch: 14, Mean Dev Loss: 12.377775112787882
[Epoch: 15/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 15, Mean Training Loss: 5.612105344471178
Epoch: 15, Mean Dev Loss: 12.357778787612915
[Epoch: 16/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 16, Mean Training Loss: 5.6159648016879435
Epoch: 16, Mean Dev Loss: 12.344444354375204
[Epoch: 17/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 17, Mean Training Loss: 5.601929739901894
Epoch: 17, Mean Dev Loss: 12.342222134272257
[Epoch: 18/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 18, Mean Training Loss: 5.603333272432026
Epoch: 18, Mean Dev Loss: 12.35555593172709
[Epoch: 19/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 19, Mean Training Loss: 5.600350768942582
Epoch: 19, Mean Dev Loss: 12.337778647740683
[Epoch: 20/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 20, Mean Training Loss: 5.595789507815712
Epoch: 20, Mean Dev Loss: 12.346666892369589
[Epoch: 21/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 21, Mean Training Loss: 5.587719340073435
Epoch: 21, Mean Dev Loss: 12.344443718592325
[Epoch: 22/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 22, Mean Training Loss: 5.58596500597502
Epoch: 22, Mean Dev Loss: 12.344443718592325
[Epoch: 23/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 23, Mean Training Loss: 5.584736899325722
Epoch: 23, Mean Dev Loss: 12.345554828643799
[Epoch: 24/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 24, Mean Training Loss: 5.583157940914757
Epoch: 24, Mean Dev Loss: 12.345554828643799
[Epoch: 25/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 25, Mean Training Loss: 5.583508842869809
Epoch: 25, Mean Dev Loss: 12.354444662729898
[Epoch: 26/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 26, Mean Training Loss: 5.580526063316746
Epoch: 26, Mean Dev Loss: 12.357778151830038
[Epoch: 27/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 27, Mean Training Loss: 5.575789250825581
Epoch: 27, Mean Dev Loss: 12.337777217229208
[Epoch: 28/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 28, Mean Training Loss: 5.577543597472341
Epoch: 28, Mean Dev Loss: 12.337777217229208
[Epoch: 29/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 29, Mean Training Loss: 5.575438574740761
Epoch: 29, Mean Dev Loss: 12.339999119440714
[Epoch: 30/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 30, Mean Training Loss: 5.570877225775468
Epoch: 30, Mean Dev Loss: 12.386665105819702
[Epoch: 31/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 31, Mean Training Loss: 5.569824645393773
Epoch: 31, Mean Dev Loss: 12.386665105819702
[Epoch: 32/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 32, Mean Training Loss: 5.567368281515021
Epoch: 32, Mean Dev Loss: 12.37333369255066
[Epoch: 33/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 33, Mean Training Loss: 5.564561229003103
Epoch: 33, Mean Dev Loss: 12.37000036239624
[Epoch: 34/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 34, Mean Training Loss: 5.56157858748185
Epoch: 34, Mean Dev Loss: 12.374444802602133
[Epoch: 35/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 35, Mean Training Loss: 5.568069997586702
Epoch: 35, Mean Dev Loss: 12.352221330006918
[Epoch: 36/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 36, Mean Training Loss: 5.567192717602379
Epoch: 36, Mean Dev Loss: 12.301111459732056
[Epoch: 37/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 37, Mean Training Loss: 5.567894596802561
Epoch: 37, Mean Dev Loss: 12.287776390711466
[Epoch: 38/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 38, Mean Training Loss: 5.567719083083303
Epoch: 38, Mean Dev Loss: 12.276667515436808
[Epoch: 39/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 39, Mean Training Loss: 5.565789228991458
Epoch: 39, Mean Dev Loss: 12.258888483047485
[Epoch: 40/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 40, Mean Training Loss: 5.562631431378816
Epoch: 40, Mean Dev Loss: 12.258888483047485
[Epoch: 41/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 41, Mean Training Loss: 5.5717541293094035
Epoch: 41, Mean Dev Loss: 12.257776578267416
[Epoch: 42/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 42, Mean Training Loss: 5.572455995961239
Epoch: 42, Mean Dev Loss: 12.265555302302042
[Epoch: 43/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 43, Mean Training Loss: 5.5710526265596085
Epoch: 43, Mean Dev Loss: 12.263332287470499
[Epoch: 44/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 44, Mean Training Loss: 5.57561373083215
Epoch: 44, Mean Dev Loss: 12.243332624435425
[Epoch: 45/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 45, Mean Training Loss: 5.562280529423764
Epoch: 45, Mean Dev Loss: 12.2433340549469
[Epoch: 46/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 46, Mean Training Loss: 5.56175442745811
Epoch: 46, Mean Dev Loss: 12.252220233281454
[Epoch: 47/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 47, Mean Training Loss: 5.554034885607268
Epoch: 47, Mean Dev Loss: 12.252220233281454
[Epoch: 48/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 48, Mean Training Loss: 5.547719227640252
Epoch: 48, Mean Dev Loss: 12.236666599909464
[Epoch: 49/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 49, Mean Training Loss: 5.549473737415514
Epoch: 49, Mean Dev Loss: 12.238887707392374
[Epoch: 50/50]


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


Epoch: 50, Mean Training Loss: 5.546315531981619
Epoch: 50, Mean Dev Loss: 12.216665665308634


In [13]:
# plotting imports for the learning-curve cells below, and save the raw loss curves
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc

# each file gets a single comma-separated row of losses (csv.writer's default
# delimiter, despite the .tsv name); newline='' is required by the csv module for
# files handed to csv.writer, otherwise extra blank lines appear on Windows
for loss_path, losses in (('train_loss.tsv', train_losses), ('dev_loss.tsv', dev_losses)):
    with open(loss_path, mode = 'w', newline = '') as loss_file:
        csv.writer(loss_file).writerow(losses)

In [14]:
# training-loss curve -> train_loss.png (figure number 1)
fig = plt.figure(1)
rc('text', usetex = True)
rc('font', family='serif')
# the axes are created only after the rc changes so their text picks up the LaTeX / serif settings
ax = fig.gca()
ax.grid(True, ls = '-.', alpha = 0.4)
ax.plot(train_losses, ms = 4, marker = 's', label = "Train Loss")
ax.legend(loc = "best")
title = f"Cosine Similarity Loss (number of examples: {len(train_X)})"
ax.set_title(title)
ax.set_ylabel('Loss')
ax.set_xlabel('Number of Iteration')
fig.tight_layout()
fig.savefig('train_loss.png')

In [15]:
# dev-loss curve -> dev_loss.png (figure number 2)
fig = plt.figure(2)
rc('text', usetex = True)
rc('font', family='serif')
# the axes are created only after the rc changes so their text picks up the LaTeX / serif settings
ax = fig.gca()
ax.grid(True, ls = '-.', alpha = 0.4)
ax.plot(dev_losses, ms = 4, marker = 'o', label = "Dev Loss")
ax.legend(loc = "best")
# NOTE(review): this dev-loss title reports the training-set size (len(train_X)); confirm that is intended
title = f"Cosine Similarity Loss (number of examples: {len(train_X)})"
ax.set_title(title)
ax.set_ylabel('Loss')
ax.set_xlabel('Number of Iteration')
fig.tight_layout()
fig.savefig('dev_loss.png')

In [48]:
# test the model
# evaluate on the test sentences whose target word has a known sense inventory
import numpy as np

new_test_X = []
new_test_y = []
new_test_idx = []
target_words = set()

for index, sen in enumerate(test_X):

    # only test known words
    target = sen[test_word_idx[index]]
    if target in all_senses:
        new_test_idx.append(test_word_idx[index])
        new_test_X.append(sen)
        new_test_y.append(test_Y[index])
        target_words.add(target)

cos = nn.CosineSimilarity(dim = 0, eps = 1e-6)
correct_count = 0

# per target word, one counter per sense k:
#   temp_base[word][k]: test items whose gold vector marks sense k (== 1)
#   temp_true[word][k]: items where sense k was predicted AND marked in the gold vector
temp_base = {word: [0] * len(all_senses[word]) for word in target_words}
temp_true = {word: [0] * len(all_senses[word]) for word in target_words}

# overall accuracy; evaluation only, so skip building the autograd graph
# NOTE(review): the model is not switched to eval mode here; if it uses dropout or
# similar train-time layers, consider trainer._model.eval() before this loop
with torch.no_grad():
    for test_idx, test_sen in enumerate(new_test_X):

        word_pos = new_test_idx[test_idx]
        gold = new_test_y[test_idx]
        test_lemma = test_sen[word_pos]
        test_emb = trainer._model.forward(test_sen, word_pos)
        all_similarity = []

        for k in range(len(all_senses[test_lemma])):
            # column k of the lemma's definition embeddings, viewed as (output_size, -1)
            definition_vec = trainer._model.definition_embeddings[test_lemma][:, k].view(trainer._model.output_size, -1)
            all_similarity.append(cos(test_emb, definition_vec))

            if gold[k] == 1:
                temp_base[test_lemma][k] += 1

        # predicted sense = first index with the highest cosine similarity
        test_result = all_similarity.index(max(all_similarity))
        if gold[test_result] == 1:
            correct_count += 1
            temp_true[test_lemma][test_result] += 1

print('test size: {}'.format(len(new_test_X)))
if new_test_X:
    print('overall accuracy: {}'.format(correct_count / len(new_test_X)))
else:
    print('overall accuracy: n/a (no test sentences with known target words)')

# per-sense true positive rate (recall); nan marks senses never gold in this subset.
# Words are printed in sorted order so the output is stable across runs.
for word in sorted(temp_true):
    base = np.asarray(temp_base[word])
    true = np.asarray(temp_true[word])

    # 0 / 0 -> nan is expected here; silence numpy's RuntimeWarning for it
    with np.errstate(divide = 'ignore', invalid = 'ignore'):
        true_pos = true / base
    print('word: {}, rates: {}'.format(word, true_pos))
    

test size: 157
overall accuracy: 0.25477707006369427
word: level, rates: [0. 0. 0. 0. 0. 0. 1. 0.]
word: rate, rates: [nan  0.  1.  0.]
word: house, rates: [ 0. nan nan nan nan nan nan nan nan nan nan  0.]
word: management, rates: [nan  1.]
word: advantage, rates: [ 1. nan nan]
word: place, rates: [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


