In [1]:
import csv
import math
import string
import itertools
from io import open
from conllu import parse_incr
import torch
import torch.nn as nn
from nltk.corpus import wordnet as wn
import numpy as np

In [2]:
from torch.nn import CosineEmbeddingLoss
l = torch.nn.CosineEmbeddingLoss()
cs = torch.nn.CosineSimilarity(dim = 1)
v1 = torch.randn(1, 5)
v2 = torch.randn(1, 5)

print(v1)
print(v1.view(5, 1))

def ang(a, b):
    return torch.dot(a, b) / (torch.norm(a) * torch.norm(b))

print(1 - ang(v1[0], v2[0]))
y = torch.ones(1)
print(l(v1, v2, y).item())
print(1 - cs(v1, v2))

tensor([[-1.3061, -0.5907,  0.9686,  1.8659, -1.0710]])
tensor([[-1.3061],
        [-0.5907],
        [ 0.9686],
        [ 1.8659],
        [-1.0710]])
tensor(0.5320)
0.5319682359695435
tensor([0.5320])


In [3]:
# parse the WSD dataset first
# and retrieve all sentences from the EUD

'''
Copyright@
White, A. S., D. Reisinger, K. Sakaguchi, T. Vieira, S. Zhang, R. Rudinger, K. Rawlins, & B. Van Durme. 2016. 
[Universal decompositional semantics on universal dependencies]
(http://aswhite.net/media/papers/white_universal_2016.pdf). 
To appear in *Proceedings of the Conference on Empirical Methods in Natural Language Processing 2016*.
'''

def parse_data():

    # parse the EUD-EWT conllu files and retrieve the sentences
    # remove all punctuation?
    train_file = open("data/UD_English-EWT/en_ewt-ud-train.conllu", "r", encoding="utf-8")
    train_data = list(parse_incr(train_file))
    # train_data = [[''.join(c for c in word.get('lemma') if c not in string.punctuation) for word in token_list] for token_list in train_data]
    # train_data = [[word for word in s if word] for s in train_data]
    print('Parsed {} training data from UD_English-EWT/en_ewt-ud-train.conllu.'.format(len(train_data)))

    test_file = open("data/UD_English-EWT/en_ewt-ud-test.conllu", "r", encoding="utf-8")
    test_data = list(parse_incr(test_file))
    # test_data = [[''.join(c for c in word.get('lemma') if c not in string.punctuation) for word in token_list] for token_list in test_data]
    # test_data = [[word for word in s if word] for s in test_data]
    print('Parsed {} testing data from UD_English-EWT/en_ewt-ud-test.conllu.'.format(len(test_data)))

    dev_file = open("data/UD_English-EWT/en_ewt-ud-dev.conllu", "r", encoding="utf-8")
    dev_data = list(parse_incr(dev_file))
    # dev_data = [[''.join(c for c in word.get('lemma') if c not in string.punctuation) for word in token_list] for token_list in dev_data]
    # dev_data = [[word for word in s if word] for s in dev_data]
    print('Parsed {} dev data from UD_English-EWT/en_ewt-ud-dev.conllu.'.format(len(dev_data)))

    # parse the WSD dataset
    wsd_data = []

    # read in tsv by White et. al., 2016
    with open('data/wsd/wsd_eng_ud1.2_10262016.tsv', mode = 'r') as wsd_file:

        tsv_reader = csv.DictReader(wsd_file, delimiter = '\t')      

        # store the data: ordered dict row
        for row in tsv_reader:                                
                            
            # each data vector
            wsd_data.append(row)

        # make sure all data are parsed
        print('Parsed {} word sense data from White et. al., 2016.'.format(len(wsd_data)))

    return wsd_data, train_data, test_data, dev_data

In [4]:
# get the raw wsd data
wsd_data, train_data, test_data, dev_data = parse_wsd_data()

Parsed 439312 word sense data from White et. al., 2016.


In [5]:
'''
return: 
all senses for each word 
all definitions for each word
all supersenses
from the EUD for train, test, and dev dataset
index provided by WSD dataset by White et. al.
'''
# get all the senses and definitions for each word from WSD dataset
# order of senses and definitions are in order
def get_all_senses_and_definitions(wsd_data, train_data, test_data, dev_data):

    # all senses for each word in train and dev
    # supersense is shared 
    all_senses = {}
    all_definitions = {}
    all_supersenses = {}

    # all senses for each word in test
    all_test_senses = {}
    all_test_definitions = {}
    
    # only get the senses for train and dev set
    for i in range(len(wsd_data)):
        
        # get the original sentence from EUD
        sentence_id = wsd_data[i].get('Sentence.ID')
        
        # get the definitions for the target word from EUD
        definition = wsd_data[i].get('Sense.Definition').split(' ')
        
        # the index in EUD is 1-based!!!
        sentence_number = int(sentence_id.split(' ')[-1]) - 1
        word_index = int(wsd_data[i].get('Arg.Token')) - 1
        
        word_lemma = wsd_data[i].get('Arg.Lemma')
        word_sense = wsd_data[i].get('Synset')
        response = wsd_data[i].get('Sense.Response')
        
        # add a under score to avoid name conflict with pytorch build-in attributes
        # get the original word
        # in case of errors in the dataset
        # correct it to the original word the annotator saw
        # add a under score to avoid name conflict with pytorch build-in attributes
        if wsd_data[i].get('Split') == 'train':
            sentence = train_data[sentence_number]
            word_lemma = '____' + [word.get('lemma') for word in sentence][word_index]
        elif wsd_data[i].get('Split') == 'test':
            sentence = test_data[sentence_number]
            word_lemma = '____' + [word.get('lemma') for word in sentence][word_index]
        else:
            sentence = dev_data[sentence_number]
            word_lemma = '____' + [word.get('lemma') for word in sentence][word_index]

        # senses for train and dev
        # preserve unknown words
        if wsd_data[i].get('Split') != 'test':
        
            # supersense-> (word_lemma, word_sense) dictionary
            super_s = wn.synset(word_sense).lexname().replace('.', '_')
            if all_supersenses.get(super_s, 'not_exist') != 'not_exist':
                all_supersenses[super_s].add((word_lemma, word_sense))
            else:
                all_supersenses[super_s] = {(word_lemma, word_sense)}

            # if the word already exits: add the new sense to the list
            # else: creata a new list for the word
            if all_senses.get(word_lemma, 'not_exist') != 'not_exist':
                if word_sense not in all_senses[word_lemma]:
                    all_senses[word_lemma].append(word_sense)
            else:
                all_senses[word_lemma] = []
                all_senses[word_lemma].append(word_sense)            
            
            if all_definitions.get(word_lemma,'not_exist') != 'not_exist':
                if definition not in all_definitions[word_lemma]: 
                    all_definitions[word_lemma].append(definition)
            else:
                all_definitions[word_lemma] = []
                all_definitions[word_lemma].append(definition)
                
        else:

            # all the senses and definitions for test words
            if all_test_senses.get(word_lemma, 'not_exist') != 'not_exist':
                if word_sense not in all_test_senses[word_lemma]:
                    all_test_senses[word_lemma].append(word_sense)
            else:
                all_test_senses[word_lemma] = []
                all_test_senses[word_lemma].append(word_sense)            
            
            if all_test_definitions.get(word_lemma,'not_exist') != 'not_exist':
                if definition not in all_test_definitions[word_lemma]: 
                    all_test_definitions[word_lemma].append(definition)
            else:
                all_test_definitions[word_lemma] = []
                all_test_definitions[word_lemma].append(definition)            
        
    return all_senses, all_definitions, all_supersenses, all_test_senses, all_test_definitions

In [6]:
# get all the senses and definitions
all_senses, all_definitions, all_supersenses, all_test_senses, all_test_definitions = get_all_senses_and_definitions(wsd_data, train_data, test_data, dev_data)


In [7]:
# print(all_test_senses['data'])
# print(all_senses['data'])

In [8]:
# test for the WordNet NLTK API
'''
The specific Synset method is lexname, e.g. wn.synsets('spring')[0].lexname(). 
That should make it really easy to get the suspersenses.
And if you have the synset name–e.g. 'spring.n.01'
you can access the supersense directly: wn.synset('spring.n.01').lexname().
Which returns 'noun.time'.
And wn.synset('spring.n.02').lexname() returns 'noun.artifact'

for idx, d in enumerate(all_definitions['spring']):
    print(d)
    print(wn.synset(all_senses['spring'][idx]).lexname())

for _ in wn.synsets('spring'):
    print(_.lexname())
'''

"\nThe specific Synset method is lexname, e.g. wn.synsets('spring')[0].lexname(). \nThat should make it really easy to get the suspersenses.\nAnd if you have the synset name–e.g. 'spring.n.01'\nyou can access the supersense directly: wn.synset('spring.n.01').lexname().\nWhich returns 'noun.time'.\nAnd wn.synset('spring.n.02').lexname() returns 'noun.artifact'\n\nfor idx, d in enumerate(all_definitions['spring']):\n    print(d)\n    print(wn.synset(all_senses['spring'][idx]).lexname())\n\nfor _ in wn.synsets('spring'):\n    print(_.lexname())\n"

In [9]:
# read the train, dev, test datasets from processed files
# check the 'data_loader.ipynb' for details
def read_file():
    
    train_X = []
    train_X_num = 0
    train_Y = []
    train_Y_num = 0
    test_X = []
    test_X_num = 0
    test_Y = []
    test_Y_num = 0
    dev_X = []
    dev_X_num = 0
    dev_Y = []
    dev_Y_num = 0
    
    train_word_idx = []
    test_word_idx = []
    dev_word_idx = []
    
    # read in csv
    with open('data/train_X.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file, delimiter = '\t')

        # store the data
        for row in csv_reader:

            train_X.append(row)
            train_X_num += 1

        # make sure all data are parsed
        print(f'Parsed {train_X_num} data points for train_X.')

    with open('data/train_Y.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file)

        # store the data
        for row in csv_reader:

            row = list(map(int, row))
            train_Y.append(row)
            train_Y_num += 1

        # make sure all data are parsed
        print(f'Parsed {train_Y_num} data points for train_Y.')
        
    with open('data/train_word_idx.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file)

        # store the data
        for row in csv_reader:

            row = list(map(int, row))
            train_word_idx = (row)

        # make sure all data are parsed
        print(f'Parsed {len(train_word_idx)} data points for train_word_idx.')

    with open('data/dev_X.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file, delimiter = '\t')

        # store the data
        for row in csv_reader:

            dev_X.append(row)
            dev_X_num += 1

        # make sure all data are parsed
        print(f'Parsed {dev_X_num} data points for dev_X.')

    with open('data/dev_Y.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file)

        # store the data
        for row in csv_reader:

            row = list(map(int, row))
            dev_Y.append(row)
            dev_Y_num += 1

        # make sure all data are parsed
        print(f'Parsed {dev_Y_num} data points for dev_Y.')
        
    with open('data/dev_word_idx.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file)

        # store the data
        for row in csv_reader:

            row = list(map(int, row))
            dev_word_idx = (row)

        # make sure all data are parsed
        print(f'Parsed {len(dev_word_idx)} data points for dev_word_idx.')
        
    with open('data/test_X.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file, delimiter = '\t')

        # store the data
        for row in csv_reader:

            test_X.append(row)
            test_X_num += 1

        # make sure all data are parsed
        print(f'Parsed {test_X_num} data points for test_X.')

    with open('data/test_Y.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file)

        # store the data
        for row in csv_reader:

            row = list(map(int, row))
            test_Y.append(row)
            test_Y_num += 1

        # make sure all data are parsed
        print(f'Parsed {test_Y_num} data points for test_Y.')
        
    with open('data/test_word_idx.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file)

        # store the data
        for row in csv_reader:

            row = list(map(int, row))
            test_word_idx = (row)

        # make sure all data are parsed
        print(f'Parsed {len(test_word_idx)} data points for test_word_idx.')    
        
    return train_X, train_Y, test_X, test_Y, dev_X, dev_Y, train_word_idx, test_word_idx, dev_word_idx

In [10]:
# get all the structured data
train_X, train_Y, test_X, test_Y, dev_X, dev_Y, train_word_idx, test_word_idx, dev_word_idx = read_file()

# test on one word
'''
word_choice = 'level'

new_train_X = []
new_train_Y = []
new_train_idx = []
distri_train = np.zeros(len(all_test_senses[word_choice]))
# stst = 0
for index, sen in enumerate(train_X):
    
    if sen[train_word_idx[index]] == word_choice:
        new_train_idx.append(train_word_idx[index])
        new_train_X.append(sen)
        new_train_Y.append(train_Y[index])
        distri_train += np.asarray(train_Y[index])
        # summ = train_Y[index][0] + train_Y[index][1]
        # if summ != 2:
            # stst += 1
# print('stst: {}'.format(stst))        
print('distri of train: {}'.format(distri_train))
        
new_test_X = []
new_test_Y = []
new_test_idx = []
distri_test = np.zeros(len(all_test_senses[word_choice]))
for index, sen in enumerate(test_X):
    
    if sen[test_word_idx[index]] == word_choice:
        new_test_idx.append(test_word_idx[index])
        new_test_X.append(sen)
        new_test_Y.append(test_Y[index])
        # print(test_Y[index])
        distri_test += np.asarray(test_Y[index])
print('distri of test: {}'.format(distri_test))

new_dev_X = []
new_dev_Y = []
new_dev_idx = []
distri_dev = np.zeros(len(all_test_senses[word_choice]))
for index, sen in enumerate(dev_X):
        
    if sen[dev_word_idx[index]] == word_choice:
        new_dev_idx.append(dev_word_idx[index])
        new_dev_X.append(sen)
        new_dev_Y.append(dev_Y[index])
        distri_dev += np.asarray(dev_Y[index])
print('distri of dev: {}'.format(distri_dev))

target_senses = all_senses[word_choice]
new_all_senses = {word_choice : target_senses}
target_def = all_definitions[word_choice]
new_all_def = {word_choice : target_def}

# limit the supersense to only the test word
new_all_supersenses = {}
for supersense in all_supersenses.keys():
    for tuples in all_supersenses[supersense]:
        
        if tuples[0] == word_choice:
            if new_all_supersenses.get(supersense, 'e') != 'e':
                new_all_supersenses[supersense].add((word_choice, tuples[1]))
            else:
                new_all_supersenses[supersense] = {(word_choice, tuples[1])}
'''

Parsed 67398 data points for train_X.
Parsed 67398 data points for train_Y.
Parsed 67398 data points for train_word_idx.
Parsed 7332 data points for dev_X.
Parsed 7332 data points for dev_Y.
Parsed 7332 data points for dev_word_idx.
Parsed 7122 data points for test_X.
Parsed 7122 data points for test_Y.
Parsed 7122 data points for test_word_idx.


"\nword_choice = 'level'\n\nnew_train_X = []\nnew_train_Y = []\nnew_train_idx = []\ndistri_train = np.zeros(len(all_test_senses[word_choice]))\n# stst = 0\nfor index, sen in enumerate(train_X):\n    \n    if sen[train_word_idx[index]] == word_choice:\n        new_train_idx.append(train_word_idx[index])\n        new_train_X.append(sen)\n        new_train_Y.append(train_Y[index])\n        distri_train += np.asarray(train_Y[index])\n        # summ = train_Y[index][0] + train_Y[index][1]\n        # if summ != 2:\n            # stst += 1\n# print('stst: {}'.format(stst))        \nprint('distri of train: {}'.format(distri_train))\n        \nnew_test_X = []\nnew_test_Y = []\nnew_test_idx = []\ndistri_test = np.zeros(len(all_test_senses[word_choice]))\nfor index, sen in enumerate(test_X):\n    \n    if sen[test_word_idx[index]] == word_choice:\n        new_test_idx.append(test_word_idx[index])\n        new_test_X.append(sen)\n        new_test_Y.append(test_Y[index])\n        # print(test_Y[ind

In [11]:
from model import *
from trainer import *

from allennlp.commands.elmo import ElmoEmbedder
elmo = ElmoEmbedder()

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.
Device: cpu


In [12]:
# trainer
epochs = 0

# test on one word
trainer = Trainer(epochs = epochs, elmo_class = elmo, all_senses = all_senses, all_supersenses = all_supersenses)
# trainer = Trainer(epochs = epochs, elmo_class = elmo, all_senses = new_all_senses, all_supersenses = new_all_supersenses)



In [13]:
# train the model
train_losses, dev_losses, dev_rs = trainer.train(train_X, train_Y, train_word_idx, dev_X, dev_Y, dev_word_idx)

# small test on only one word
# train_losses, dev_losses, dev_rs = trainer.train(new_train_X, new_train_Y, new_train_idx, new_dev_X, new_dev_Y, new_dev_idx)


#############   Model Parameters   ##############
layers.word_sense.0.weight torch.Size([512, 512])
layers.word_sense.0.bias torch.Size([512])
layers.word_sense.2.weight torch.Size([300, 512])
layers.word_sense.2.bias torch.Size([300])
dimension_reduction_MLP.weight torch.Size([256, 3072])
dimension_reduction_MLP.bias torch.Size([256])
wsd_lstm.weight_ih_l0 torch.Size([1024, 256])
wsd_lstm.weight_hh_l0 torch.Size([1024, 256])
wsd_lstm.bias_ih_l0 torch.Size([1024])
wsd_lstm.bias_hh_l0 torch.Size([1024])
wsd_lstm.weight_ih_l0_reverse torch.Size([1024, 256])
wsd_lstm.weight_hh_l0_reverse torch.Size([1024, 256])
wsd_lstm.bias_ih_l0_reverse torch.Size([1024])
wsd_lstm.bias_hh_l0_reverse torch.Size([1024])
wsd_lstm.weight_ih_l1 torch.Size([1024, 512])
wsd_lstm.weight_hh_l1 torch.Size([1024, 256])
wsd_lstm.bias_ih_l1 torch.Size([1024])
wsd_lstm.bias_hh_l1 torch.Size([1024])
wsd_lstm.weight_ih_l1_reverse torch.Size([1024, 512])
wsd_lstm.weight_hh_l1_reverse torch.Size([1024, 256])
wsd_lstm.bia

definition_embeddings.____failure torch.Size([300, 7])
definition_embeddings.____fair torch.Size([300, 4])
definition_embeddings.____faith torch.Size([300, 4])
definition_embeddings.____fall torch.Size([300, 12])
definition_embeddings.____fallacy torch.Size([300, 1])
definition_embeddings.____fallibility torch.Size([300, 1])
definition_embeddings.____fallout torch.Size([300, 2])
definition_embeddings.____fame torch.Size([300, 2])
definition_embeddings.____family torch.Size([300, 8])
definition_embeddings.____fan torch.Size([300, 3])
definition_embeddings.____fancy torch.Size([300, 3])
definition_embeddings.____fantasy torch.Size([300, 3])
definition_embeddings.____faq torch.Size([300, 1])
definition_embeddings.____far torch.Size([300, 1])
definition_embeddings.____farce torch.Size([300, 2])
definition_embeddings.____fare torch.Size([300, 4])
definition_embeddings.____farm torch.Size([300, 1])
definition_embeddings.____farmer torch.Size([300, 3])
definition_embeddings.____farrier torch.

definition_embeddings.____posting torch.Size([300, 3])
definition_embeddings.____posture torch.Size([300, 4])
definition_embeddings.____pot torch.Size([300, 9])
definition_embeddings.____potato torch.Size([300, 2])
definition_embeddings.____potential torch.Size([300, 2])
definition_embeddings.____potentiality torch.Size([300, 2])
definition_embeddings.____potty torch.Size([300, 2])
definition_embeddings.____pound torch.Size([300, 14])
definition_embeddings.____poverty torch.Size([300, 1])
definition_embeddings.____powder torch.Size([300, 3])
definition_embeddings.____power torch.Size([300, 9])
definition_embeddings.____pr torch.Size([300, 3])
definition_embeddings.____practice torch.Size([300, 5])
definition_embeddings.____praise torch.Size([300, 2])
definition_embeddings.____prayer torch.Size([300, 5])
definition_embeddings.____prc torch.Size([300, 1])
definition_embeddings.____preacher torch.Size([300, 1])
definition_embeddings.____precaution torch.Size([300, 3])
definition_embedding

definition_embeddings.____volcano torch.Size([300, 2])
definition_embeddings.____volume torch.Size([300, 6])
definition_embeddings.____vote torch.Size([300, 5])
definition_embeddings.____voter torch.Size([300, 1])
definition_embeddings.____vulnerability torch.Size([300, 2])
definition_embeddings.____wagon torch.Size([300, 5])
definition_embeddings.____wait torch.Size([300, 2])
definition_embeddings.____waiter torch.Size([300, 2])
definition_embeddings.____waiting torch.Size([300, 1])
definition_embeddings.____waitress torch.Size([300, 1])
definition_embeddings.____waiver torch.Size([300, 1])
definition_embeddings.____wake torch.Size([300, 4])
definition_embeddings.____walk torch.Size([300, 7])
definition_embeddings.____wall torch.Size([300, 8])
definition_embeddings.____wallet torch.Size([300, 1])
definition_embeddings.____wanton torch.Size([300, 1])
definition_embeddings.____war torch.Size([300, 4])
definition_embeddings.____warehouse torch.Size([300, 1])
definition_embeddings.____war

In [14]:
# plot the learning curve
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc

with open('train_loss.tsv', mode = 'w') as loss_file:
        
    csv_writer = csv.writer(loss_file)
    csv_writer.writerow(train_losses)

    
with open('dev_loss.tsv', mode = 'w') as loss_file:
        
    csv_writer = csv.writer(loss_file)
    csv_writer.writerow(dev_losses)

In [15]:
plt.figure(1)
# rc('text', usetex = True)
rc('font', family='serif')
plt.grid(True, ls = '-.',alpha = 0.4)
plt.plot(train_losses, ms = 4, marker = 's', label = "Train Loss")
plt.legend(loc = "best")
title = "Cosine Similarity Loss (number of examples: " + str(len(train_X)) + ")"
plt.title(title)
plt.ylabel('Loss')
plt.xlabel('Number of Iteration')
plt.tight_layout()
plt.savefig('train_loss.png')

In [16]:
plt.figure(2)
# rc('text', usetex = True)
rc('font', family='serif')
plt.grid(True, ls = '-.',alpha = 0.4)
plt.plot(dev_losses, ms = 4, marker = 'o', label = "Dev Loss")
plt.legend(loc = "best")
title = "Cosine Similarity Loss (number of examples: " + str(len(dev_X)) + ")"
plt.title(title)
plt.ylabel('Loss')
plt.xlabel('Number of Iteration')
plt.tight_layout()
plt.savefig('dev_loss.png')

In [19]:
# clean up the test set
# may contains plural
# only 2 examples and a wired phone number

new_test_X = []
new_test_Y = []
new_test_idx = []
print('train')
for test_idx, test_sen in enumerate(train_X):
    test_lemma = '____' + test_sen[train_word_idx[test_idx]]
    if all_senses.get(test_lemma, 'e') == 'e':
        print(test_lemma)
        print(test_sen)
        
print('test')
for test_idx, test_sen in enumerate(test_X):
    
    test_lemma = '____' + test_sen[test_word_idx[test_idx]]
    
    if all_test_senses.get(test_lemma, 'e') != 'e':
        new_test_idx.append(test_word_idx[test_idx])
        new_test_X.append(test_sen)
        new_test_Y.append(test_Y[test_idx])
        
    else:
        print(test_lemma)
        print(test_sen)


train
____stamina
['it', 'be', 'a', 'psycho-spiritual', 'exercise', 'and', 'benefit', 'the', 'mind', 'by', 'develop', 'stamina', 'and', 'strength', '.']
____stamina
['it', 'be', 'a', 'psycho-spiritual', 'exercise', 'and', 'benefit', 'the', 'mind', 'by', 'develop', 'stamina', 'and', 'strength', '.']
____due
['unfortunately', ',', 'due', 'to', 'Mr.', 'Lay', "'s", 'schedule', 'he', 'will', 'not', 'be', 'able', 'to', 'participate', '.']
____due
['unfortunately', ',', 'due', 'to', 'Mr.', 'Lay', "'s", 'schedule', 'he', 'will', 'not', 'be', 'able', 'to', 'participate', '.']
____reviews
['we', 'read', 'the', 'good', 'reviews', 'before', 'go', 'and', 'have', 'high', 'hope', '..', 'but', 'to', 'we', 'dismay', 'it', 'do', 'not', 'turn', 'out', 'that', 'way', '!']
____reviews
['we', 'read', 'the', 'good', 'reviews', 'before', 'go', 'and', 'have', 'high', 'hope', '..', 'but', 'to', 'we', 'dismay', 'it', 'do', 'not', 'turn', 'out', 'that', 'way', '!']
____,
['<<', '"', 'well', 'then', ',', '"', 'say

In [18]:
# test the model
# modify to test only one word for now
cos = nn.CosineSimilarity(dim = 1, eps = 1e-6).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
correct_count = 0
known_test_size = 0
unknown_test_size = 0
unknown_correct_count = 0

embds = []

# overall accuracy
for test_idx, test_sen in enumerate(new_test_X):
    
    test_lemma = '____' + test_sen[new_test_idx[test_idx]]
        
    # print(test_sen)
    test_emb = trainer._model.forward(test_sen, new_test_idx[test_idx]).view(1, -1).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    # print(test_emb)
    all_similarity = []
    # embds.append(test_emb)
    
    # if it is a new word
    # only test on the supersense
    if all_senses.get(test_lemma, 'e') == 'e':
        
        unknown_test_size += 1
        test_result = ''
        best_sim = -float('inf')
        
        for n, new_s in enumerate(all_test_senses[test_lemma]):
            
            new_super = wn.synset(new_s).lexname().replace('.', '_')
            super_vec = trainer._model.supersense_embeddings[new_super].view(1, -1).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
            cos_sim = cos(test_emb, super_vec)
            
            if cos_sim > best_sim:
                test_result = new_super
                best_sim = cos_sim
                
        correct_super = []
        for q, respon in enumerate(new_test_Y[test_idx]):
            if respon:
                correct_s = wn.synset(all_test_senses[test_lemma][q]).lexname().replace('.', '_')
                correct_super.append(correct_s)            
        if test_result in correct_super:
            unknown_correct_count += 1
        
    else:
            
        # if it is a known word
        known_test_size += 1
        
        for k, sense in enumerate(all_senses[test_lemma]):
            definition_vec = trainer._model.definition_embeddings[test_lemma][:, k].view(1, -1).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
            cos_sim = cos(test_emb, definition_vec)
            all_similarity.append(cos_sim)
        # print(all_similarity)
        test_result = all_similarity.index(max(all_similarity))
        # print("result index: {}".format(test_result))
        if new_test_Y[test_idx][test_result] == 1:
            correct_count += 1

print('test size for known words: {}'.format(known_test_size))
print('accuracy for known words: {}'.format(correct_count / known_test_size))

print('test size for unknown words: {}'.format(unknown_test_size))
print('accuracy for unknown words: {}'.format(unknown_correct_count / unknown_test_size))


KeyboardInterrupt: 

In [None]:
# debug

print(embds[0])
print(embds[1])

for q in range(len(embds)):
    for p in range(len(embds)):
        
        print("q: {}, p: {}\ncos: {}".format(q, p, cos(embds[q], embds[p])))