In [23]:
import csv
import math
import string
import itertools
from io import open
from conllu import parse_incr
from nltk.corpus import wordnet as wn
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import Iterable, defaultdict
import random

In [2]:
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [3]:
# parse the WSD dataset first
# and retrieve all sentences from the EUD

'''
Copyright@
White, A. S., D. Reisinger, K. Sakaguchi, T. Vieira, S. Zhang, R. Rudinger, K. Rawlins, & B. Van Durme. 2016. 
[Universal decompositional semantics on universal dependencies]
(http://aswhite.net/media/papers/white_universal_2016.pdf). 
To appear in *Proceedings of the Conference on Empirical Methods in Natural Language Processing 2016*.
'''

def parse_wsd_data():

    # parse the EUD-EWT conllu files and retrieve the sentences
    # remove all punctuation?
    train_file = open("data/UD_English-EWT/en_ewt-ud-train.conllu", "r", encoding="utf-8")
    train_data = list(parse_incr(train_file))
    # train_data = [[''.join(c for c in word.get('lemma') if c not in string.punctuation) for word in token_list] for token_list in train_data]
    # train_data = [[word for word in s if word] for s in train_data]
    print('Parsed {} training data from UD_English-EWT/en_ewt-ud-train.conllu.'.format(len(train_data)))

    test_file = open("data/UD_English-EWT/en_ewt-ud-test.conllu", "r", encoding="utf-8")
    test_data = list(parse_incr(test_file))
    # test_data = [[''.join(c for c in word.get('lemma') if c not in string.punctuation) for word in token_list] for token_list in test_data]
    # test_data = [[word for word in s if word] for s in test_data]
    print('Parsed {} testing data from UD_English-EWT/en_ewt-ud-test.conllu.'.format(len(test_data)))

    dev_file = open("data/UD_English-EWT/en_ewt-ud-dev.conllu", "r", encoding="utf-8")
    dev_data = list(parse_incr(dev_file))
    # dev_data = [[''.join(c for c in word.get('lemma') if c not in string.punctuation) for word in token_list] for token_list in dev_data]
    # dev_data = [[word for word in s if word] for s in dev_data]
    print('Parsed {} dev data from UD_English-EWT/en_ewt-ud-dev.conllu.'.format(len(dev_data)))

    # parse the WSD dataset
    wsd_data = []

    # read in tsv by White et. al., 2016
    with open('data/wsd/wsd_eng_ud1.2_10262016.tsv', mode = 'r') as wsd_file:

        tsv_reader = csv.DictReader(wsd_file, delimiter = '\t')      

        # store the data: ordered dict row
        for row in tsv_reader:                                
                            
            # each data vector
            wsd_data.append(row)

        # make sure all data are parsed
        print('Parsed {} word sense data from White et. al., 2016.'.format(len(wsd_data)))

    return wsd_data, train_data, test_data, dev_data

In [4]:
# get the raw wsd data
wsd_data, train_data, test_data, dev_data = parse_wsd_data()

Parsed 12543 training data from UD_English-EWT/en_ewt-ud-train.conllu.
Parsed 2077 testing data from UD_English-EWT/en_ewt-ud-test.conllu.
Parsed 2002 dev data from UD_English-EWT/en_ewt-ud-dev.conllu.
Parsed 439312 word sense data from White et. al., 2016.


In [5]:
'''
return: 
all senses for each word 
all definitions for each word
from the EUD for train, test, and dev dataset
index provided by WSD dataset by White et. al.
'''
# get all the senses and definitions for each word from WSD dataset
# order of senses and definitions are in order
def get_all_senses_and_definitions(wsd_data, train_data, test_data, dev_data):

    # all senses for each word in train and dev
    all_senses = {}
    all_definitions = {}

    # all senses for each word in test
    all_test_senses = {}
    all_test_definitions = {}
    
    # only get the senses for train and dev set
    for i in range(len(wsd_data)):
        
        # get the original sentence from EUD
        sentence_id = wsd_data[i].get('Sentence.ID')
        
        # get the definitions for the target word from EUD
        definition = wsd_data[i].get('Sense.Definition').split(' ')
        
        # the index in EUD is 1-based!!!
        sentence_number = int(sentence_id.split(' ')[-1]) - 1
        word_index = int(wsd_data[i].get('Arg.Token')) - 1
        
        word_lemma = wsd_data[i].get('Arg.Lemma')
        word_sense = wsd_data[i].get('Synset')
        response = wsd_data[i].get('Sense.Response')
        
        # add a under score to avoid name conflict with pytorch build-in attributes
        # get the original word
        # in case of errors in the dataset
        # correct it to the original word the annotator saw
        # add a under score to avoid name conflict with pytorch build-in attributes
        old = '____' + word_lemma
        if wsd_data[i].get('Split') == 'train':
            sentence = train_data[sentence_number]
            word_lemma = '____' + [word.get('lemma') for word in sentence][word_index]
        elif wsd_data[i].get('Split') == 'test':
            sentence = test_data[sentence_number]
            word_lemma = '____' + [word.get('lemma') for word in sentence][word_index]
        else:
            sentence = dev_data[sentence_number]
            word_lemma = '____' + [word.get('lemma') for word in sentence][word_index]
        
        # index error in UD: some sentences start with '<<'
        # have wrong index
        if [word.get('lemma') for word in sentence][0] == '<<' and [word.get('lemma') for word in sentence][-1] != '>>':
            if '____' + [word.get('lemma') for word in sentence][word_index] != old:
                word_lemma = old
                
        # senses for train and dev
        # preserve unknown words
        if wsd_data[i].get('Split') != 'test':

            # if the word already exits: add the new sense to the list
            # else: creata a new list for the word
            if all_senses.get(word_lemma, 'not_exist') != 'not_exist':
                if word_sense not in all_senses[word_lemma]:
                    all_senses[word_lemma].append(word_sense)
            else:
                all_senses[word_lemma] = []
                all_senses[word_lemma].append(word_sense)            
            
            if all_definitions.get(word_lemma,'not_exist') != 'not_exist':
                if definition not in all_definitions[word_lemma]: 
                    all_definitions[word_lemma].append(definition)
            else:
                all_definitions[word_lemma] = []
                all_definitions[word_lemma].append(definition)
                
        else:

            # all the senses and definitions for test words
            if all_test_senses.get(word_lemma, 'not_exist') != 'not_exist':
                if word_sense not in all_test_senses[word_lemma]:
                    all_test_senses[word_lemma].append(word_sense)
            else:
                all_test_senses[word_lemma] = []
                all_test_senses[word_lemma].append(word_sense)            
            
            if all_test_definitions.get(word_lemma,'not_exist') != 'not_exist':
                if definition not in all_test_definitions[word_lemma]: 
                    all_test_definitions[word_lemma].append(definition)
            else:
                all_test_definitions[word_lemma] = []
                all_test_definitions[word_lemma].append(definition)            
    
    return all_senses, all_definitions, all_test_senses, all_test_definitions

In [15]:
# get all the senses and definitions
all_senses, all_definitions, all_test_senses, all_test_definitions = get_all_senses_and_definitions(wsd_data, train_data, test_data, dev_data)

# debug: make sure every word is in the WN
error_set = set()
for se in all_senses.keys():
    se = se.split('____')[-1]
    if len(wn.synsets(se)) == 0:
        error_set.add(se)

In [16]:
# read the train, dev, test datasets from processed files
# check the 'data_loader.ipynb' for details
def read_file():
    
    train_X = []
    train_X_num = 0
    train_Y = []
    train_Y_num = 0
    test_X = []
    test_X_num = 0
    test_Y = []
    test_Y_num = 0
    dev_X = []
    dev_X_num = 0
    dev_Y = []
    dev_Y_num = 0
    
    train_word_idx = []
    test_word_idx = []
    dev_word_idx = []
    
    # debug: make sure every word is in the WordNet
    WN_set = set(wn.all_synsets())
    
    # read in csv
    with open('data/train_X.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file, delimiter = '\t')

        # store the data
        for row in csv_reader:

            train_X.append(row)
            train_X_num += 1

        # make sure all data are parsed
        print(f'Parsed {train_X_num} data points for train_X.')

    with open('data/train_Y.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file)

        # store the data
        for row in csv_reader:

            row = list(map(int, row))
            train_Y.append(row)
            train_Y_num += 1

        # make sure all data are parsed
        print(f'Parsed {train_Y_num} data points for train_Y.')
        
    with open('data/train_word_idx.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file)

        # store the data
        for row in csv_reader:

            row = list(map(int, row))
            train_word_idx = (row)

        # make sure all data are parsed
        print(f'Parsed {len(train_word_idx)} data points for train_word_idx.')

    with open('data/dev_X.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file, delimiter = '\t')

        # store the data
        for row in csv_reader:

            dev_X.append(row)
            dev_X_num += 1

        # make sure all data are parsed
        print(f'Parsed {dev_X_num} data points for dev_X.')

    with open('data/dev_Y.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file)

        # store the data
        for row in csv_reader:

            row = list(map(int, row))
            dev_Y.append(row)
            dev_Y_num += 1

        # make sure all data are parsed
        print(f'Parsed {dev_Y_num} data points for dev_Y.')
        
    with open('data/dev_word_idx.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file)

        # store the data
        for row in csv_reader:

            row = list(map(int, row))
            dev_word_idx = (row)

        # make sure all data are parsed
        print(f'Parsed {len(dev_word_idx)} data points for dev_word_idx.')
        
    with open('data/test_X.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file, delimiter = '\t')

        # store the data
        for row in csv_reader:

            test_X.append(row)
            test_X_num += 1

        # make sure all data are parsed
        print(f'Parsed {test_X_num} data points for test_X.')

    with open('data/test_Y.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file)

        # store the data
        for row in csv_reader:

            row = list(map(int, row))
            test_Y.append(row)
            test_Y_num += 1

        # make sure all data are parsed
        print(f'Parsed {test_Y_num} data points for test_Y.')
        
    with open('data/test_word_idx.tsv', mode = 'r') as data_file:
        
        csv_reader = csv.reader(data_file)

        # store the data
        for row in csv_reader:

            row = list(map(int, row))
            test_word_idx = (row)

        # make sure all data are parsed
        print(f'Parsed {len(test_word_idx)} data points for test_word_idx.')    
        
    return train_X, train_Y, test_X, test_Y, dev_X, dev_Y, train_word_idx, test_word_idx, dev_word_idx

In [None]:
# get all the structured data
train_X, train_Y, test_X, test_Y, dev_X, dev_Y, train_word_idx, test_word_idx, dev_word_idx = read_file()

# debug: make sure every word is in the WN
for i, j in enumerate(train_X):
    if j[train_word_idx[i]] in error_set:
        print(j[train_word_idx[i]])
        del train_X[i]
        del train_Y[i]
        del train_word_idx[i]
# debug: make sure every word is in the WN
for i, j in enumerate(test_X):
    if j[test_word_idx[i]] in error_set:
        print(j[test_word_idx[i]])
        del test_X[i]
        del test_Y[i]
        del test_word_idx[i]
# debug: make sure every word is in the WN
for i, j in enumerate(dev_X):
    if j[dev_word_idx[i]] in error_set:
        print(j[dev_word_idx[i]])
        del dev_X[i]
        del dev_Y[i]
        del dev_word_idx[i]

In [18]:
# test on small subset
train_X = train_X[:10000]
train_Y = train_Y[:10000]
train_word_idx = train_word_idx[:10000]

# small vocab from the first 10000 sentences
vocab = set()
for index, sen in enumerate(train_X):
    word = sen[train_word_idx[index]]
    if word not in vocab and ('____' + word) in all_senses.keys():
        vocab.add(word)
print('number of known words: {}'.format(len(vocab)))
print('number of train data: {}'.format(len(train_X)))

# filter the dev set
new_dev_X = []
new_dev_Y = []
new_dev_idx = []
for index, sen in enumerate(dev_X):
    word = sen[dev_word_idx[index]]
    if word in vocab and ('____' + word) in all_senses.keys():
        new_dev_idx.append(dev_word_idx[index])
        new_dev_X.append(sen)
        new_dev_Y.append(dev_Y[index])
new_dev_X = new_dev_X[:2000]
new_dev_Y = new_dev_Y[:2000]
new_dev_idx = new_dev_idx[:2000]

print('number of dev data: {}'.format(len(new_dev_X)))

# filter the test set
new_test_idx = []
new_test_X = []
new_test_Y = []
for index, sen in enumerate(test_X):
    word = sen[test_word_idx[index]]
    if word in vocab and ('____' + word) in all_senses.keys():
        new_test_idx.append(test_word_idx[index])
        new_test_X.append(sen)
        new_test_Y.append(test_Y[index]) 
    elif word not in all_senses.keys():
        new_test_idx.append(test_word_idx[index])
        new_test_X.append(sen)
        new_test_Y.append(test_Y[index])        
            
print('number of test data: {}'.format(len(new_test_X)))


Parsed 67619 data points for train_X.
Parsed 67619 data points for train_Y.
Parsed 67619 data points for train_word_idx.
Parsed 7332 data points for dev_X.
Parsed 7332 data points for dev_Y.
Parsed 7332 data points for dev_word_idx.
Parsed 7118 data points for test_X.
Parsed 7118 data points for test_Y.
Parsed 7118 data points for test_word_idx.
entrée
we
we
of
my
entrée
my
of
we
entrée
my
we
we
of
£
£
number of known words: 1872
number of train data: 10000
number of dev data: 2000
number of test data: 7118


In [20]:
from allennlp.commands.elmo import ElmoEmbedder
elmo = ElmoEmbedder()

from encoder import *
from decoder import *
from seq2seq_model import *

# get the decoder vocab
with open('./data/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

In [21]:
decoder = Decoder(vocab_size = vocab.idx)
encoder = Encoder(all_senses = all_senses, elmo_class = elmo)
seq2seq_model = Seq2Seq_Model(encoder, decoder, all_senses = all_senses)

# randomly initialize the weights
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)       
seq2seq_model.apply(init_weights)

Seq2Seq_Model(
  (encoder): Encoder(
    (layers): ModuleDict(
      (word_sense): ModuleList(
        (0): Linear(in_features=512, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=300, bias=True)
      )
    )
    (mlp_dropout): Dropout(p=0)
    (dimension_reduction_MLP): Linear(in_features=3072, out_features=256, bias=True)
    (lstm): LSTM(256, 256, num_layers=2, bidirectional=True)
  )
  (decoder): Decoder(
    (lstm): LSTM(300, 300)
    (linear): Linear(in_features=300, out_features=49036, bias=True)
  )
  (embed): Embedding(49036, 300)
  (dropout): Dropout(p=0)
)

In [27]:
# training hyperparameters
epochs = 0
optimizer = optim.Adam(seq2seq_model.parameters())

PAD_IDX = vocab('<pad>')
print(PAD_IDX)
loss = nn.CrossEntropyLoss(ignore_index = PAD_IDX)


0


In [None]:
# train the model
# train_losses, dev_losses, dev_rs = trainer.train(train_X, train_Y, train_word_idx, dev_X, dev_Y, dev_word_idx)

# small vocab
train_losses, dev_losses, dev_rs = trainer.train(train_X, train_Y, train_word_idx, new_dev_X, new_dev_Y, new_dev_idx)


In [None]:
# plot the learning curve
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc

with open('train_loss.tsv', mode = 'w') as loss_file:
        
    csv_writer = csv.writer(loss_file)
    csv_writer.writerow(train_losses)

    
with open('dev_loss.tsv', mode = 'w') as loss_file:
        
    csv_writer = csv.writer(loss_file)
    csv_writer.writerow(dev_losses)

In [None]:
plt.figure(1)
# rc('text', usetex = True)
rc('font', family='serif')
plt.grid(True, ls = '-.',alpha = 0.4)
plt.plot(train_losses, ms = 4, marker = 's', label = "Train Loss")
plt.legend(loc = "best")
title = "Cosine Similarity Loss (number of examples: " + str(len(train_X)) + ")"
plt.title(title)
plt.ylabel('Loss')
plt.xlabel('Number of Iteration')
plt.tight_layout()
plt.savefig('train_loss.png')

In [None]:
plt.figure(2)
# rc('text', usetex = True)
rc('font', family='serif')
plt.grid(True, ls = '-.',alpha = 0.4)
plt.plot(dev_losses, ms = 4, marker = 'o', label = "Dev Loss")
plt.legend(loc = "best")
title = "Cosine Similarity Loss (number of examples: " + str(len(dev_X)) + ")"
plt.title(title)
plt.ylabel('Loss')
plt.xlabel('Number of Iteration')
plt.tight_layout()
plt.savefig('dev_loss.png')

In [None]:
# debug

print('test debug')
for test_idx, test_sen in enumerate(test_X):
    test_lemma = '____' + test_sen[test_word_idx[test_idx]]
    emb_length = len(all_test_senses.get(test_lemma))
    y = len(test_Y[test_idx])
    
    if emb_length != y:
        print('lemma: {}, y: {}, emb: {}'.format(test_lemma, y, emb_length))


In [None]:
# debug

print('dev debug')
for test_idx, test_sen in enumerate(dev_X):
    test_lemma = '____' + test_sen[dev_word_idx[test_idx]]
    emb_length = len(all_senses.get(test_lemma))
    y = len(dev_Y[test_idx])
    
    if emb_length != y:
        print('lemma: {}, y: {}, emb: {}'.format(test_lemma, y, emb_length))


In [None]:
# debug
# should print nothing 
print('train debug')
for test_idx, test_sen in enumerate(train_X):
    test_lemma = '____' + test_sen[train_word_idx[test_idx]]
    if all_senses.get(test_lemma, 'e') == 'e':
        print(test_lemma)
        print(test_sen)


In [None]:
# debug
# should print nothing 
print('test debug')
for test_idx, test_sen in enumerate(test_X):
    test_lemma = '____' + test_sen[test_word_idx[test_idx]]
    if all_test_senses.get(test_lemma, 'e') == 'e':
        print(test_lemma)
        print(test_sen)

In [None]:
# test the model
cos = nn.CosineSimilarity(dim = 1, eps = 1e-6).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
correct_count = 0
known_test_size = 0
unknown_test_size = 0
unknown_correct_count = 0

embds = []

new_test_idx = new_test_idx[:100]
new_test_X = new_test_X[:100]
new_test_Y = new_test_Y[:100]

# overall accuracy
for test_idx, test_sen in enumerate(new_test_X):
    
    test_lemma = '____' + test_sen[new_test_idx[test_idx]]
        
    # print(test_sen)
    test_emb = trainer._model.forward(test_sen, new_test_idx[test_idx]).view(1, -1).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    # print(test_emb)
    all_similarity = []
    # embds.append(test_emb)
    
    # if it is a new word
    # only test on the supersense
    if all_senses.get(test_lemma, 'e') == 'e':
        
        unknown_test_size += 1
        test_result = ''
        best_sim = -float('inf')
        
        for n, new_s in enumerate(all_test_senses[test_lemma]):
            
            new_super = wn.synset(new_s).lexname().replace('.', '_')
            super_vec = trainer._model.supersense_embeddings[new_super].view(1, -1).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
            cos_sim = cos(test_emb, super_vec)
            
            if cos_sim > best_sim:
                test_result = new_super
                best_sim = cos_sim
                
        correct_super = []
        for q, respon in enumerate(new_test_Y[test_idx]):
            if respon:
                correct_s = wn.synset(all_test_senses[test_lemma][q]).lexname().replace('.', '_')
                correct_super.append(correct_s)            
        if test_result in correct_super:
            unknown_correct_count += 1
        
    else:
            
        # if it is a known word
        known_test_size += 1
        
        for k, sense in enumerate(all_senses[test_lemma]):
            definition_vec = trainer._model.definition_embeddings[test_lemma][:, k].view(1, -1).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
            cos_sim = cos(test_emb, definition_vec)
            all_similarity.append(cos_sim)
        # print(all_similarity)
        test_result = all_similarity.index(max(all_similarity))
        # print("result index: {}".format(test_result))
        if new_test_Y[test_idx][test_result] == 1:
            correct_count += 1

print('test size for known words: {}'.format(known_test_size))
print('accuracy for known words: {}'.format(correct_count / known_test_size))

print('test size for unknown words: {}'.format(unknown_test_size))
print('accuracy for unknown words: {}'.format(unknown_correct_count / unknown_test_size))
