In [3]:
import numpy as np
import torch
from torch import nn, optim
import random
from collections import Counter
import matplotlib.pyplot as plt
import re

In [38]:
# training data is from http://mattmahoney.net/dc/text8.zip
EMBEDDING_DIM = embedding_dim = 128 
PRINT_EVERY = 100 # print frequency
EPOCHES = 1000 
BATCH_SIZE = 5 
N_SAMPLES = n_samples = 3 
WINDOW_SIZE = 5 
FREQ = 5 #threshold for word's frequency
DELETE_WORDS = False #delete words with high frequent
VOCABULARY_SIZE = 50000
re_pats = [r'[?|!|\'|"#]',r'[.|,|)|(|\|/]']
t=1e-5 #used in remove high frequent words

In [1]:
with open("text8.train.txt", "r") as f:
    data = f.read()

In [51]:
def preprocess(text):
    text = text.lower()
    for re_pattern in re_pats:
        text = re.sub(re_pattern, " ", text)
    words = text.split()
    words_counts = Counter(words)
    trimmed_words = [word for word in words if words_counts[word] > FREQ]
    return trimmed_words

In [52]:
words = preprocess(data)

In [57]:
words[0:5]


['anarchism', 'originated', 'as', 'a', 'term']

prepare dictionary, frequency

In [69]:
vocab = set(words)
vocab2int = {w: c for c, w in enumerate(vocab)}
int2vocab = {c: w for c, w in enumerate(vocab)}
int_words = [vocab2int[w] for w in words]
int_word_counts = Counter(int_words)
total_count = len(int_words)
word_freqs = {w: c/total_count for w, c in int_word_counts.items() }
prob_drop = {w: 1-np.sqrt(t/word_freqs[w]) for w in int_word_counts}
train_words = [w for w in int_words if random.random()<(1-prob_drop[w])]

collect neighbors

In [98]:
def get_target(words, idx):
    target_window = np.random.randint(1, WINDOW_SIZE + 1)
    start_point = idx-target_window if (idx - target_window) > 0 else 0
    end_point = idx + target_window
    targets = set(words[start_point:idx] + words[idx+1:end_point+1])
    return list(targets)

generate batch

In [111]:
def get_batch(words):
    n_batches = len(words)//BATCH_SIZE
    words = words[:n_batches*BATCH_SIZE]
    for idx in range(0, len(words), BATCH_SIZE):
        batch_x, batch_y = [], []
        batch = words[idx:idx+BATCH_SIZE]
        for i in range(len(batch)):
            x = batch[i]
            y = get_target(batch, i)
            batch_x.extend([x] * len(y))
            batch_y.extend(y)
            yield batch_x, batch_y

network and train

In [109]:
class SkipGramNeg(nn.Module):
    def __init__(self, n_vocab, n_embed, noise_dist = None):
        super().__init__()
        self.n_vocab = n_vocab
        self.n_embed = n_embed
        self.noise_dist = noise_dist
        
        # in for center words, out for neighbors words
        self.in_embed = nn.Embedding(n_vocab, n_embed)
        self.out_embed = nn.Embedding(n_vocab, n_embed)
        self.in_embed.weight.data.uniform_(-1,1)
        self.out_embed.weight.data.uniform_(-1,1)
    def forward_input(self, input_words):
        input_vector = self.in_embed(input_words)
        return input_vector
    def forward_output(self, output_words):
        output_vector = self.out_embed(output_words)
        return output_vector
    # do negative sampling based on given noise_dist
    def forward_noise(self,var_batch_size):
        if self.noise_dist is None:
            noise_dist = torch.ones(self.n_vocab)
        else:
            noise_dist = self.noise_dist
        noise_words = torch.multinomial(noise_dist, var_batch_size * n_samples, replacement=True)
        noise_vectors = self.out_embed(noise_words).view(var_batch_size, n_samples, self.n_embed)
        return noise_vectors

In [73]:
# calculate noise dist used in negative sampling
word_freqs_array = np.array(list(word_freqs.values()))
unigram_dist = word_freqs_array / word_freqs_array.sum()
noise_dist = torch.from_numpy(unigram_dist ** (0.75) / np.sum(unigram_dist ** (0.75)))

In [74]:
noise_dist

tensor([3.5665e-05, 5.2315e-05, 3.1533e-03,  ..., 1.8826e-06, 2.3360e-06,
        2.3360e-06], dtype=torch.float64)

make loss function

In [113]:
class NegativeSamplingLoss(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, input_vectors, output_vectors, noise_vectors):
        batch_size, embed_size = input_vectors.shape
        input_vectors = input_vectors.view(batch_size, embed_size, 1)
        output_vectors = output_vectors.view(batch_size, 1, embed_size)
        out_loss = torch.bmm(output_vectors, input_vectors).sigmoid().log()
        out_loss = out_loss.squeeze()
        
        noise_loss = torch.bmm(noise_vectors.neg(), input_vectors).sigmoid().log()
        noise_loss = noise_loss.squeeze().sum(1)
        return -(out_loss + noise_loss).mean()
        

training model

In [114]:
model = SkipGramNeg(len(vocab2int), EMBEDDING_DIM, noise_dist=noise_dist)
criterion = NegativeSamplingLoss()
optimizer = optim.Adam(model.parameters(), lr = 0.003)
steps = 0
for e in range(EPOCHES):
    for input_words, target_words in get_batch(train_words):
        steps +=1
        inputs, targets = torch.LongTensor(input_words), torch.LongTensor(target_words)
        input_vectors = model.forward_input(inputs)
        output_vectors = model.forward_output(targets)
        noise_vectors = model.forward_noise(len(input_vectors))
        print(input_vectors.size(), output_vectors.size(), noise_vectors.size())
        
        loss = criterion(input_vectors, output_vectors, noise_vectors)
        if steps // PRINT_EVERY == 0:
            print("{} is current loss".format(loss))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        

torch.Size([3, 128]) torch.Size([3, 128]) torch.Size([3, 3, 128])
6.597312927246094 is current loss
torch.Size([7, 128]) torch.Size([7, 128]) torch.Size([7, 3, 128])
5.9126410484313965 is current loss
torch.Size([11, 128]) torch.Size([11, 128]) torch.Size([11, 3, 128])
6.1255412101745605 is current loss
torch.Size([15, 128]) torch.Size([15, 128]) torch.Size([15, 3, 128])
6.536654949188232 is current loss
torch.Size([17, 128]) torch.Size([17, 128]) torch.Size([17, 3, 128])
6.674886703491211 is current loss
torch.Size([2, 128]) torch.Size([2, 128]) torch.Size([2, 3, 128])
7.731934547424316 is current loss
torch.Size([6, 128]) torch.Size([6, 128]) torch.Size([6, 3, 128])
4.529255390167236 is current loss
torch.Size([10, 128]) torch.Size([10, 128]) torch.Size([10, 3, 128])
6.140489101409912 is current loss
torch.Size([14, 128]) torch.Size([14, 128]) torch.Size([14, 3, 128])
5.356362342834473 is current loss
torch.Size([16, 128]) torch.Size([16, 128]) torch.Size([16, 3, 128])
6.346285343170

KeyboardInterrupt: 