In [1]:
import csv
import math
import string
import itertools
from io import open
import nltk
from nltk.corpus import wordnet as wn
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import Iterable, defaultdict
import random

In [2]:
# set determinstic results
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [3]:
from allennlp.commands.elmo import ElmoEmbedder
elmo = ElmoEmbedder()

from encoder import *
from decoder import *
from emb2seq_model import *

# get the decoder vocab
with open('./data/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)
    print("Size of vocab: {}".format(vocab.idx))

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.
Size of vocab: 49036


In [4]:
decoder = Decoder(vocab_size = vocab.idx)
encoder = Encoder(elmo_class = elmo)
emb2seq_model = Emb2Seq_Model(encoder, decoder, vocab = vocab)

# randomly initialize the weights
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)       
emb2seq_model.apply(init_weights)

Emb2Seq_Model(
  (encoder): Encoder(
    (dimension_reduction): Linear(in_features=3072, out_features=512, bias=True)
    (lstm): LSTM(512, 256, num_layers=2, bidirectional=True)
    (mlp): Sequential(
      (0): Linear(in_features=512, out_features=300, bias=True)
      (1): ReLU()
      (2): Dropout(p=0)
      (3): Linear(in_features=300, out_features=256, bias=True)
      (4): Dropout(p=0)
    )
  )
  (decoder): Decoder(
    (lstm_cell): LSTMCell(512, 256)
    (linear): Linear(in_features=256, out_features=49036, bias=True)
  )
  (embed): Embedding(49036, 256, padding_idx=0)
  (dropout): Dropout(p=0)
)

In [5]:
#cuda
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device: {}'.format(device))
print(torch.cuda.device_count())
emb2seq_model.to(device)

# training hyperparameters
optimizer = optim.Adam(emb2seq_model.parameters(), lr = 0.1)
PAD_IDX = vocab('<pad>')
print('PAD_IDX: {}'.format(PAD_IDX))
criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX).to(device)

Device: cpu
0
PAD_IDX: 0


In [6]:
# utility function
# turn the given definition into its index list form
def def2idx(definition, max_length, vocab):
    
    # definition is given by the WN NLTK API in a string
    def_tokens = nltk.tokenize.word_tokenize(definition.lower())
    
    # limit the length if too long, trim
    if len(def_tokens) > (max_length - 2):
        def_tokens = def_tokens[0:(max_length - 2)]
        
        # add the start and end symbol
        def_tokens = ['<start>'] + def_tokens + ['<end>']
    
    # if the length is too short, pad
    elif len(def_tokens) < (max_length - 2):
        
        # add the start and end symbol
        def_tokens = ['<start>'] + def_tokens + ['<end>']
        
        pad = ['<pad>'] * (max_length - len(def_tokens))
        def_tokens = def_tokens + pad
        
    else:
        def_tokens = ['<start>'] + def_tokens + ['<end>']
            
    # get the index for each element in the token list
    def_idx_list = [vocab(token) for token in def_tokens]
    
    return def_idx_list
  

In [7]:
# parse the SemCor training data
import xml.etree.ElementTree as ET
tree = ET.parse('../WSD_Evaluation_Framework/Training_Corpora/SemCor/semcor.data.xml')
corpus = tree.getroot()

# parse the target sense tag 
target_file = open("../WSD_Evaluation_Framework/Training_Corpora/SemCor/semcor.gold.key.txt", "r")

# small sets of SemCor
small_train_size = 1
small_dev_size = 2

In [12]:
# the training function
def train(model, optimizer, corpus, criterion, clip):
    
    model.train()
    epoch_loss = 0
    sentence_num = 0
    
    for sub_corpus in corpus[0:small_train_size]:
    
        for sent in sub_corpus[0:20]:

            optimizer.zero_grad()
            
            # get the plain text sentence
            sentence = [word.text for word in sent]
            
            # get the tagged ambiguous words
            tagged_sent = [instance for instance in sent if instance.tag == 'instance']
            # print(sentence)
            # print(tagged_sent)
            
            # only use sentence with at least one tagged word
            if len(tagged_sent) > 0:
                
                sentence_num += 1
                
                # get all-word definitions, batch_size is the sentence length
                # [batch_size, self.max_length]
                definitions = []
                for instance in tagged_sent:
                    
                    # get the sense from the WN
                    # senses are in-order already
                    key = target_file.readline().replace('\n', '').split(' ')[-1]
                    definition = wn.lemma_from_key(key).synset().definition()
                    # print(definition)
                    def_idx_list = def2idx(definition, model.max_length, vocab)
                    # print(def_idx_list)
                    definitions.append(def_idx_list)

                # get the encoder-decoder result
                # (self.max_length, batch_size, vocab_size)
                output, result = model(sentence, tagged_sent, definitions, teacher_forcing_ratio = 0.4)

                # adjust dimension for loss calculation
                # (self.max_length * batch_size, vocab_size)
                output = output.view(-1, output.shape[-1])
                target = torch.tensor(definitions, dtype = torch.long).to(device)
                # (self.max_length * batch_size)
                target = torch.transpose(target, 0, 1).contiguous().view(-1)

                loss = criterion(output, target)
                loss.backward()

                # add clip for gradient boost
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

                optimizer.step()
                epoch_loss += loss.item()
        
    return epoch_loss / sentence_num, result

In [13]:
# evaluate the model
def evaluate(model, corpus, criterion):
    
    model.eval()
    epoch_loss = 0
    sentence_num = 0
        
    with torch.no_grad():
    
        for sub_corpus in corpus[small_train_size:small_dev_size]:
    
            for sent in sub_corpus[0:5]:
            
                sentence = [word.text for word in sent]
                
                # get the tagged ambiguous words
                tagged_sent = [instance for instance in sent if instance.tag == 'instance']
                # print(sentence)
                # print(tagged_sent)

                # only use sentence with at least one tagged word
                if len(tagged_sent) > 0:
                    sentence_num += 1

                    # get all-word definitions, batch_size is the sentence length
                    # [batch_size, self.max_length]
                    definitions = []
                    for instance in tagged_sent:

                        # get the sense from the WN
                        # senses are in-order already
                        key = target_file.readline().replace('\n', '').split(' ')[-1]
                        definition = wn.lemma_from_key(key).synset().definition()                 
                        def_idx_list = def2idx(definition, model.max_length, vocab)
                        definitions.append(def_idx_list)

                    # get the encoder-decoder result
                    # (self.max_length, batch_size, vocab_size)
                    output, result = model(sentence, tagged_sent, definitions, teacher_forcing_ratio = 0.4)

                    # adjust dimension for loss calculation
                    # (self.max_length * batch_size, vocab_size)
                    output = output.view(-1, output.shape[-1])
                    target = torch.tensor(definitions, dtype = torch.long).to(device)
                    # (self.max_length * batch_size)
                    target = torch.transpose(target, 0, 1).contiguous().view(-1)

                    loss = criterion(output, target).to(device)         
                    epoch_loss += loss.item()
                            
    return epoch_loss / sentence_num , result

In [14]:
# time used by each epoch
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [15]:
# train and evaluate
import time

N_EPOCHS = 50
CLIP = 1
best_valid_loss = float('inf')
train_losses = []
dev_losses = []

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss, _ = train(emb2seq_model, optimizer, corpus, criterion, CLIP)
    train_losses.append(train_loss)
    
    valid_loss, result = evaluate(emb2seq_model, corpus, criterion)
    dev_losses.append(valid_loss)
        
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # visualize the results
    all_results = []
    for n in range(len(result[0])):
        sense = ''
        for m in range(len(result)):
            w = ' '+ vocab.idx2word.get(int(result[m][n]))
            sense += w
        all_results.append(sense)
    with open('result.txt', 'w') as f:
        for item in all_results:
            f.write("%s\n" % item)
            
    # save the best model based on the dev set
    '''
    if valid_loss <= best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(seq2seq_model.state_dict(), 'best_model.pth')
    '''
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

Epoch: 01 | Time: 1m 0s
	Train Loss: 23.222 | Train PPL: 12162902411.537
	 Val. Loss: 27.805 |  Val. PPL: 1190340441616.401
Epoch: 02 | Time: 1m 39s
	Train Loss: 31.529 | Train PPL: 49326004956399.141
	 Val. Loss: 32.466 |  Val. PPL: 125835492093673.000
Epoch: 03 | Time: 2m 1s
	Train Loss: 30.688 | Train PPL: 21266155056851.684
	 Val. Loss: 34.732 |  Val. PPL: 1212907356108927.000
Epoch: 04 | Time: 1m 33s
	Train Loss: 36.324 | Train PPL: 5961098447404419.000
	 Val. Loss: 33.669 |  Val. PPL: 418935217405265.812
Epoch: 05 | Time: 1m 9s
	Train Loss: 39.864 | Train PPL: 205376922670022624.000
	 Val. Loss: 39.097 |  Val. PPL: 95421091567842448.000
Epoch: 06 | Time: 1m 0s
	Train Loss: 44.386 | Train PPL: 18901590758859714560.000
	 Val. Loss: 55.753 |  Val. PPL: 1633886382668839169032192.000
Epoch: 07 | Time: 1m 0s
	Train Loss: 55.741 | Train PPL: 1615112241973982652268544.000
	 Val. Loss: 58.669 |  Val. PPL: 30183425303361889945780224.000
Epoch: 08 | Time: 0m 59s
	Train Loss: 55.588 | Train 

KeyboardInterrupt: 

In [None]:
vocab('<pad>')

In [None]:
# plot the learning curve
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc

with open('train_loss.tsv', mode = 'w') as loss_file:
    csv_writer = csv.writer(loss_file)
    csv_writer.writerow(train_losses)

with open('dev_loss.tsv', mode = 'w') as loss_file: 
    csv_writer = csv.writer(loss_file)
    csv_writer.writerow(dev_losses)

In [None]:
plt.figure(1)
# rc('text', usetex = True)
rc('font', family='serif')
plt.grid(True, ls = '-.',alpha = 0.4)
plt.plot(train_losses, ms = 4, marker = 's', label = "Train Loss")
plt.legend(loc = "best")
title = "CrossEntropy Loss"
plt.title(title)
plt.ylabel('Loss')
plt.xlabel('Number of Iteration')
plt.tight_layout()
plt.savefig('train_loss.png')

In [None]:
plt.figure(2)
# rc('text', usetex = True)
rc('font', family='serif')
plt.grid(True, ls = '-.',alpha = 0.4)
plt.plot(dev_losses, ms = 4, marker = 'o', label = "Dev Loss")
plt.legend(loc = "best")
title = "CrossEntropy Loss"
plt.title(title)
plt.ylabel('Loss')
plt.xlabel('Number of Iteration')
plt.tight_layout()
plt.savefig('dev_loss.png')