In [1]:
import numpy as np
import collections
from datetime import datetime

from scipy.stats import pearsonr
from scipy.stats import spearmanr
from sklearn.metrics.pairwise import cosine_similarity
from scipy import spatial
from scipy import stats

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
import torch.nn.functional as F
torch.manual_seed(1)
import torch.distributions as distb


import time
import datetime

from collections import defaultdict
from collections import Counter

from random import randint
import pickle



In [2]:
class BayesianSG(nn.Module):
    def __init__(self, vocab_size, embedding_size):
        super(BayesianSG, self).__init__()

        self.w_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.dense_weights = nn.Linear(embedding_size*2, embedding_size)
        self.muLinear = nn.Linear(embedding_size, embedding_size)
        self.sigmaLinear = nn.Linear(embedding_size, embedding_size)

        self.Location = nn.Embedding(vocab_size, embedding_size)
        self.Scale = nn.Embedding(vocab_size, embedding_size)
        self.fLinear = nn.Linear(embedding_size, vocab_size)

        self.std_mean = torch.zeros(embedding_size)
        self.std_cov = torch.diag(torch.ones(embedding_size))

     

    def forward(self, center_words, context_words):
    
        context_words = context_words.view(1, -1)
        context_size = context_words.size(1)

        # Inference model
        x_stacked = center_words.repeat(context_size, 1).transpose(0, 1)
        center_embedding = self.w_embeddings(x_stacked)
        context_embedding = self.w_embeddings(context_words)
        out = self.dense_weights(torch.cat([center_embedding, context_embedding], -1))
        emb = F.relu(out)
        emb_sum = emb.sum(1)

        mu_posterior = self.muLinear(emb_sum)
        sigma_posterior = F.softplus(self.sigmaLinear(emb_sum))

        dims = mu_posterior.shape
        #epsilon = torch.distributions.multivariate_normal.MultivariateNormal \
        #(torch.zeros(dims),torch.diag(torch.ones(dims)))
       
   
     
  
        epsilon = distb.MultivariateNormal(self.std_mean, self.std_cov)
        epsilon = epsilon.sample()

        z = mu_posterior +  epsilon * sigma_posterior


        logprobs = F.log_softmax(self.fLinear(z), dim=-1).squeeze(0)

        mu = self.Location(center_words)
        sigma = F.softplus(self.Scale(center_words))

        loss_probs = torch.zeros_like(context_words).type(torch.FloatTensor)
 
        sum_loss = 
        for i, context_word in enumerate(context_words):
            loss_probs[i] = logprobs[context_word]

        neg_loss = loss_probs.sum(-1)
        loss1 = torch.log(sigma/sigma_posterior)
        numerator = (sigma_posterior.pow(2) + (mu_posterior - sigma).pow(2))
        total_loss = (loss1 + numerator / (2*sigma.pow(2)) - 0.5).sum()


        final_loss = (total_loss - neg_loss).mean()
        return final_loss

In [3]:
import string

def stop_words(file_name, punctuation=True):
    with open(file_name) as f:
        stop_word_list = [line.strip() for line in f.readlines()]
    
    for p in list(string.punctuation):
        stop_word_list.append(p)
    
    return stop_word_list

stop_word_list = stop_words("data/en_stopwords.txt")

In [4]:
def sentences_reader(dataset_path, stop_word_list):
    sentence_list = []
    with open(dataset_path) as f:
        for line in f:
            line = line.split()
            line = [word.lower() for word in line]
            line = [word for word in line if word not in stop_word_list]
            sentence_list.append(line)
    
    return sentence_list

sentences = sentences_reader("data/wa/dev.en", stop_word_list )

In [5]:
from collections import defaultdict


def UnigramTable(sentences, max_size):
    table = {}
    frequency = collections.defaultdict(int)
    n = 0
    
    
    for sentence in sentences:
        for word in sentence:
            table[n] = word
            frequency[word] += 1
            n+= 1
    
    return table, frequency

In [6]:

def vocabulary_creation(sentences, max_size = 10000):
    special_tokens = {"$UNK$", "$EOS$", "$SOS$", "$PAD"}
    index  = {}

    sentence_count = len(sentences)

    table, frequency = UnigramTable(sentences, max_size)
    
    num_tokens =  sum(frequency.values())

    counts = list(frequency.items())
    
    counts.sort(key=lambda _: -_[1])
    
    most_freq = [w[0] for w in counts[: max_size - len(special_tokens)]]


    index = dict([ (w, i) for i, w in enumerate(most_freq)])
    
    for special_token in special_tokens:
        assert special_token not in index
        index[special_token] = len(index)

    inverse_index = dict([(v, k) for (k, v) in index.items()])

    N = len(index)
    
    return index, inverse_index, N

index, inverse_index, N = vocabulary_creation(sentences, max_size = 10000)
        

def one_hot_vector(word, index, N):
    if word not in index:
        word = "$UNK$"
    vector = np.zeros(N)
    vector[index[word]] = 1
    
    return vector
    
def unknown_check(sentence, index):
    output = []
    for word in sentence:
        if word not in index:
            word = "$UNK$"
        output.append(word)
    return output
    

def retrieve_contexts(sentence, index, context_window):
    n_ = context_window // 2
    context = set()
    for i in range(index - n_, index + n_ + 1):
        if i == index or i < 0 or i >= len(sentence):
            continue
        context.add(sentence[i])
    return context

In [7]:
embedding_dim = 100
vocab_size = 10000
window_size = 5
epochs = 10
dataset_path = "data/wa/dev.en"
model_name = "bayesian_skipgram"

sentences = sentences_reader(dataset_path, stop_word_list)

index, inverse_index, N = vocabulary_creation(sentences, max_size = vocab_size)


bsm = BayesianSG(N, embedding_dim)

optimizer = optim.Adam(bsm.parameters(), lr=1e-4)

epoch_losses = []
for epoch in range(1, epochs + 1):
    print("Running epoch: ", epoch)
    then = time.time()

    epoch_loss = 0
    count = 0
    for sentence_num, sentence in enumerate(sentences):            
        for center_idx, center_word in enumerate(sentence):
            if center_word not in index:
                continue
            center_vec = one_hot_vector(center_word, index, N)
            context_words = []
            context_idx = []
            for word in retrieve_contexts(unknown_check(sentence, index), center_idx, window_size):
                if word not in index:
                    continue
                context_idx.append(index[word])
                context_words.append(one_hot_vector(word, index, N))
            if len(context_words) == 0:
                continue
            optimizer.zero_grad()
            loss = bsm(torch.LongTensor(np.array([index[center_word]])), torch.LongTensor(context_idx))
         
            epoch_loss = epoch_loss + loss.item()
            count = count+1

            loss.backward()
            optimizer.step()
    
    now = time.time()
    epoch_loss_avg = epoch_loss / count
    epoch_losses.append(epoch_loss_avg)
    print("average loss: ", epoch_loss_avg, "time: ",now-then)

#TODO save model 


Running epoch:  1
average loss:  397.7042032006669 time:  1.8488249778747559
Running epoch:  2
average loss:  261.43158559929833 time:  1.1740000247955322
Running epoch:  3
average loss:  207.5242339095024 time:  1.139603853225708
Running epoch:  4


KeyboardInterrupt: 

In [None]:
def get_embeddings(model):
    
    return model.w_embeddings.weight.data, model.lin1.weight.data

def save_embeddings(embeds, file_name):
    
    with open(file_name, 'wb') as file:
        pickle.dump(embeds.numpy(), file)

In [None]:
wm, cm = get_embeddings(model)
save_embeddings(wm, 'wordvecs_bayesian_skipgram.pickle')


In [6]:
#### Evaluation Reading #####

class Sentence:
    def __init__(self, target,complete,sent_id,position,tokens):
        self.target = target
        self.complete = complete
        self.sent_id = sent_id
        self.position = position
        self.tokens = tokens

        
        
with open('data/lst/lst_test.preprocessed', 'r') as myfile:
    data = myfile.readlines()

    
sentence_list = []
    
for line in data:
    line = line.split()
    target = line[0].split(".")[0]
    complete = line[0]
    sent_id = line[1]
    position = line[2]
    tokens = line[3:]
    sent_example = Sentence(target,complete,sent_id,position,tokens)
    
    sentence_list.append(sent_example)



In [2]:
with open('data/lst/lst.gold.candidates', 'r') as myfile:
    data = myfile.readlines()

    
word_candidates = {}
    
for line in data:
    
    word = line.split(":")[0][:-2]
    #print(word[:-2])
    candidates = line[:-1].split(":")[2].split(";")
    #print(candidates)
    
    word_candidates[word] = candidates


In [3]:
word_candidates

{'about': ['here and there',
  'regarding',
  'around',
  'of',
  'concerning',
  'arise',
  'discussed',
  'dealing with',
  'approximately',
  'roughly',
  'cope with',
  'nearly',
  'somewhat',
  'more or less',
  'occur',
  'happen',
  'consider',
  'round',
  'concerned with'],
 'account': ['access',
  'balance',
  'description',
  'chronicle',
  'facility',
  'bank balance',
  'explanation',
  'ledger',
  'finance',
  'banking facility',
  'subscriber',
  'fund',
  'synopsis',
  'asset',
  'statement',
  'narrative',
  'report',
  'consideration',
  'subscription',
  'logon',
  'banking arrangement'],
 'acquire': ['amass',
  'purchase',
  'buy',
  'secure',
  'get',
  'receive',
  'gather',
  'procure',
  'obtain',
  'collect',
  'bring in',
  'gain',
  'learn',
  'find',
  'achieve'],
 'acute': ['heightened',
  'emergency',
  'sensitive',
  'sudden',
  'urgent',
  'severe',
  'critical',
  'serious',
  'sharp',
  'keen',
  'pn',
  'intense',
  'grave'],
 'apparently': ['supposed