In [None]:
import os 
import re
import numpy as np
import string
import time
import matplotlib.pyplot as plt

import torch
from torch import distributions
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.data
from torch.utils.data import DataLoader


from nltk.corpus import stopwords
from collections import Counter
from collections import defaultdict


In [None]:
# Number of positions scanned on each side of a centre word when the
# (centre, context) training pairs are built.
WINDOW_SIZE = 2
# NOTE(review): presumably the embedding width handed to the model; it is not
# referenced by name in the visible cells -- confirm where it is consumed.
EMBEDDING_DIM = 128
# NOTE(review): presumably the number of examples per training batch -- confirm
# against the code that actually forms batches.
BATCH_SIZE = 10000
# NOTE(review): not referenced in the visible cells; presumably the intended
# number of passes over the data -- confirm.
NUM_EPOCHS = 100
# NOTE(review): not referenced in the visible cells; presumably a cap on the
# vocabulary size that is not applied yet -- confirm.
vocab_size_limit = 10000

# Path of the corpus file; it is opened and read line by line when the corpus
# is built.
filename = "wa/dev.en"

In [None]:
# Load the stop-word list: every whitespace-separated token in the local
# "english" file becomes one entry of `english_stop_words`.
# The file handle is deliberately NOT named `stopwords`: that name is already
# bound to the `nltk.corpus.stopwords` import and must not be shadowed.
with open("english") as stop_words_file:
    english_stop_words = stop_words_file.read().split()

In [None]:
# Read the corpus: one sentence per line, punctuation stripped, lower-cased and
# split on whitespace. `corpus` ends up as a list of token lists.
# (The stop-word list is already loaded by the preceding cell, so it is not
# re-read here.)
corpus = []
translator = str.maketrans('', '', string.punctuation)

with open(filename) as f:
    for sentence in f:
        # Strip punctuation first, then lower-case, then tokenise.
        clean_sentence = sentence.translate(translator).lower()
        corpus.append(clean_sentence.split())

print(len(corpus))
flat_list = [item for sublist in corpus for item in sublist]
corpus_set = set(flat_list)

# Frequency of every token that is not a stop word. Counter keeps first-seen
# order, so the ids assigned below follow the order of first appearance.
# A set makes the stop-word membership test O(1) per token.
stop_word_set = set(english_stop_words)
vocabulary = Counter(word for word in flat_list if word not in stop_word_set)

# Token <-> id lookup tables. "UNK" takes the next free id and stands in for
# every token that is not in the vocabulary (stop words included).
word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}
unk_idx = len(word2idx)
word2idx["UNK"] = unk_idx
# Build the reverse map from word2idx itself so both tables always agree,
# including on the UNK id.
idx2word = {idx: w for (w, idx) in word2idx.items()}

# The corpus as id sequences; out-of-vocabulary tokens become UNK.
text = [[word2idx.get(word, unk_idx) for word in sentence] for sentence in corpus]

# Count UNK as well, so an embedding table with `vocab_size` rows can be
# indexed by every id produced above.
vocab_size = len(word2idx)
# Build the skip-gram training data.
#
#   pairs           : centre word -> list of context words seen around it
#   center, context : flat, aligned lists of (centre id, context id) pairs
#   list_of_center  : one list per centre position, the centre id repeated
#                     once per context word
#   list_of_context : one list per centre position, the ids of its context
#                     words (aligned with list_of_center)
#
# Tokens missing from word2idx (e.g. stop words) are mapped to the UNK id.
unk_idx = word2idx["UNK"]

pairs = {}
center = []
context = []
list_of_center = []
list_of_context = []

for sentence in corpus:
    sentence_ids = [word2idx.get(word, unk_idx) for word in sentence]

    for position, center_id in enumerate(sentence_ids):
        # Window bounds clipped to the sentence; index 0 is a valid context
        # position, so the lower bound is inclusive.
        first = max(0, position - WINDOW_SIZE)
        last = min(len(sentence), position + WINDOW_SIZE + 1)
        neighbours = [j for j in range(first, last) if j != position]

        # Accumulate across occurrences instead of resetting the entry.
        pairs.setdefault(sentence[position], []).extend(sentence[j] for j in neighbours)

        if not neighbours:
            # One-word sentence: there is nothing to predict.
            continue

        context_ids = [sentence_ids[j] for j in neighbours]
        center_ids = [center_id] * len(context_ids)

        center.extend(center_ids)
        context.extend(context_ids)
        list_of_center.append(center_ids)
        list_of_context.append(context_ids)

In [None]:
class bayesian_skipgram(nn.Module):
    """Bayesian skip-gram model whose forward pass returns the training loss.

    For every example the centre-word embedding is concatenated with each
    context-word embedding, projected through `M` + ReLU and summed over the
    context positions. The result parameterises a diagonal Gaussian posterior
    (`lambda_mu`, softplus(`lambda_sigma`)) from which a latent vector is drawn
    with the reparameterisation trick. The latent is decoded by `out` into a
    distribution over the vocabulary. The loss is

        mean(KL(posterior || prior)) - mean(sum of context log-likelihoods)

    where the prior is a per-centre-word diagonal Gaussian given by
    `mu_prior` and softplus(`sigma_prior`).

    Args:
        vocab_size: number of rows in every embedding table and width of the
            output layer; it must exceed the largest word id fed to the model.
        embed_dim: dimensionality of the embeddings and of the latent vector.
    """

    def __init__(self, vocab_size, embed_dim):
        super(bayesian_skipgram, self).__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.mu_prior = nn.Embedding(vocab_size, embed_dim)
        self.sigma_prior = nn.Embedding(vocab_size, embed_dim)
        self.M = nn.Linear(embed_dim*2, embed_dim)
        self.lambda_mu = nn.Linear(embed_dim, embed_dim)
        self.lambda_sigma = nn.Linear(embed_dim, embed_dim)
        self.softplus = nn.Softplus()
        # An explicit dim avoids the implicit-dimension deprecation warning.
        self.softmax = nn.Softmax(dim=-1)
        # Legacy misspelled alias, kept so existing references keep working.
        self.sofmax = self.softmax
        self.relu = nn.ReLU()
        # Decoder from the latent vector to vocabulary logits.
        self.out = nn.Linear(embed_dim, vocab_size)

    def forward(self, word_idx, context_idx):
        """Compute the negative ELBO for a batch of examples.

        Args:
            word_idx: LongTensor of centre-word ids, shape (batch, n_context)
                with the centre id repeated along dim 1, or (batch, 1). A 1-D
                tensor is treated as a single example.
            context_idx: LongTensor of context-word ids, shape
                (batch, n_context). A 1-D tensor is treated as a single
                example.

        Returns:
            Scalar tensor: mean KL term minus mean log-likelihood term.
        """
        # Accept a single un-batched example.
        if word_idx.dim() == 1:
            word_idx = word_idx.unsqueeze(0)
        if context_idx.dim() == 1:
            context_idx = context_idx.unsqueeze(0)
        # Broadcast a (batch, 1) centre column over the context positions.
        word_idx = word_idx.expand_as(context_idx)

        # Encoder: (batch, n_context, 2*embed_dim) -> (batch, embed_dim)
        word_emb = self.embedding(word_idx)
        context_emb = self.embedding(context_idx)
        WC = torch.cat((word_emb, context_emb), dim=2)

        m = self.relu(self.M(WC))
        h = torch.sum(m, dim=1)

        # Posterior parameters; softplus keeps sigma strictly positive.
        mu = self.lambda_mu(h)
        sigma = self.softplus(self.lambda_sigma(h))

        # Reparameterisation trick. randn_like draws independent noise for
        # every example on the same device and dtype as mu.
        eps = torch.randn_like(mu)
        z = mu + sigma * eps

        # Log-probability of every vocabulary word given z; log_softmax is
        # numerically stabler than log(softmax(x) + eps).
        log_f = F.log_softmax(self.out(z), dim=-1)
        # Sum of log-probabilities of the observed context words per example.
        likelihood_terms = log_f.gather(1, context_idx).sum(dim=1)

        # Prior of each example's centre word (column 0 holds the centre id).
        center_idx = word_idx[:, 0]
        mu_prior = self.mu_prior(center_idx)
        sigma_prior = self.softplus(self.sigma_prior(center_idx))
        KL_div_terms = self.KL_div(mu_prior, sigma_prior, mu, sigma)

        total_loss = torch.mean(KL_div_terms) - torch.mean(likelihood_terms)

        return total_loss

    def KL_div(self, mu_p, sigma_p, mu, sigma):
        """KL( N(mu, sigma^2) || N(mu_p, sigma_p^2) ) for diagonal Gaussians.

        The element-wise divergences are summed over the last dimension, so
        1-D inputs give a scalar and (batch, dim) inputs give a (batch,) tensor.
        """
        div = (torch.log(sigma_p + 1e-8) - torch.log(sigma + 1e-8)
               + (sigma**2 + (mu - mu_p)**2) / (2*sigma_p**2) - 0.5)
        return div.sum(dim=-1)

        

In [None]:
def KL_div(self,  mu_p, sigma_p, mu, sigma):
    """Summed element-wise KL divergence between two diagonal Gaussians.

    Evaluates the closed form
        log(sigma_p) - log(sigma) + (sigma^2 + (mu - mu_p)^2) / (2 sigma_p^2) - 1/2
    for every element and adds all elements into a single scalar tensor.
    `self` is accepted but never read.
    """
    tiny = 1e-8  # keeps the logarithms finite when a sigma is zero
    log_ratio = torch.log(sigma_p + tiny) - torch.log(sigma + tiny)
    scaled_spread = (sigma**2 + (mu - mu_p)**2) / (2*sigma_p**2)
    per_element = log_ratio + scaled_spread - 0.5
    return per_element.sum()

In [None]:
# One training pass over the grouped (centre, context) examples.
# `epoch_losses` collects the mean per-example loss of each pass.
epoch_losses = []
epoch_loss = 0

start = time.time()

iterations = len(list_of_context)

batch = 0

# Create the index tensors on whatever device the model already lives on,
# so the cell runs on both CPU-only and GPU machines.
device = next(bsg_model.parameters()).device

# Loop variables get their own names so the module-level `center` / `context`
# pair lists are not overwritten.
for center_ids, context_ids in zip(list_of_center, list_of_context):
    batch_start = time.time()

    # Each group is a single example; add the leading batch dimension that
    # forward() concatenates and reduces over.
    center_batch = torch.tensor(center_ids, dtype=torch.long, device=device).unsqueeze(0)
    context_batch = torch.tensor(context_ids, dtype=torch.long, device=device).unsqueeze(0)

    optimizer.zero_grad()

    # bayesian_skipgram.forward returns the scalar loss directly
    # (mean KL term minus mean log-likelihood term).
    loss = bsg_model(center_batch, context_batch)
    epoch_loss += loss.item()
    loss.backward()
    optimizer.step()

    # Wall-clock duration of the step that just finished.
    epoch_time = time.time() - batch_start
    batch += 1

# Guard against an empty training set.
if iterations:
    epoch_loss /= iterations

epoch_losses.append(epoch_loss)