In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import torch.multiprocessing as mp

In [2]:
class SkipGram(nn.Module):
    def __init__(self, vocab_size, word_dim, use_gpu=False):
        super(SkipGram, self).__init__()
        self.word_dim = word_dim
        self.u_embeddings = nn.Embedding(vocab_size, word_dim, sparse=True)
        self.v_embeddings = nn.Embedding(vocab_size, word_dim, sparse=True)
        self.init_emb()

    def init_emb(self):
        initrange = 0.5 / self.word_dim
        self.u_embeddings.weight.data.uniform_(-initrange, initrange)
        self.v_embeddings.weight.data.zero_()

    def forward(self, pos_u, pos_v, neg_v):
        emb_u = self.u_embeddings(pos_u)
        emb_v = self.v_embeddings(pos_v)
        score = torch.mul(emb_u, emb_v).squeeze()
        score = torch.sum(score, dim=1)
        score = F.logsigmoid(score)
        neg_emb_v = self.v_embeddings(neg_v)
        neg_score = torch.bmm(neg_emb_v, emb_u.unsqueeze(2)).squeeze()
        neg_score = F.logsigmoid(-1 * neg_score)
        return -1 * (torch.sum(score)+torch.sum(neg_score))

    def save_embedding(self, word2id, file_name, use_gpu=False):
        if use_gpu:
            embedding = self.u_embeddings.weight.cpu().data.numpy()
        else:
            embedding = self.u_embeddings.weight.data.numpy()
            
        fout = open(file_name, 'w')
        fout.write('%d %d\n' % (len(word2id), self.word_dim))
        for word, word_id in word2id.items():
            e = embedding[word_id]
            e = ' '.join(map(lambda x: str(x), e))
            fout.write('%s %s\n' % (word, e))


In [57]:
class PreProcessText():
    def __init__(self, file_path, min_count):
        from collections import deque
        self.file_path = file_path
        self.sentence_length = 0
        self.build_vocab(min_count)
        self.word_pair_catch = deque()
        self.init_neg_sample_table()
        print('Vocab size: %d' % self.vocab_size)
        print('Sentence Length: %d' % self.sentence_length)

    def get_clean_word(self, file, init=0):
        """
        sentence_count
        sentence_length
        """
        import re
        lines = file.readlines()
        lines = [line for line in lines if len(line) > 1]
        if init:
            self.sentence_count = len(lines)
        r = re.compile("[!-/:-@[-`{-~]")
        for line in map(lambda x: re.sub(r, '', x.lower().strip()).split(),
                        lines):
            if init:
                self.sentence_length += len(line)
            for word in line:
                yield word

    def build_vocab(self, min_count):
        """
        word2id
        word_list
        word_count
        """
        from collections import Counter
        vocab = Counter()
        for word in self.get_clean_word(open(self.file_path), init=1):
            vocab[word] += 1

        self.freq = {k: v for k, v in vocab.items() if v >= min_count}
        self.word_count = sum(self.freq.values())
        word_list = sorted(self.freq, key=self.freq.get, reverse=True)
        self.word2id = {w: i for i, w in enumerate(word_list)}
        self.vocab_size = len(self.word2id)

    def init_neg_sample_table(self):
        self.neg_sample_table = []
        neg_sample_table_size = 1e8
        pow_frequency = np.array(list(self.freq.values()))**0.75
        words_pow = sum(pow_frequency)
        ratio = pow_frequency / words_pow
        count = np.round(ratio * neg_sample_table_size)
        for idx, c in enumerate(count):
            self.neg_sample_table += [idx] * int(c)
        self.neg_sample_table = np.array(self.neg_sample_table)

    def get_batch_pairs(self, batch_size, window_size):
        while len(self.word_pair_catch) < batch_size:
            word_ids = []
            for word in self.get_clean_word(open(self.file_path)):
                try:
                    word_ids.append(self.word2id[word])
                except:
                    continue

            for i, u in enumerate(word_ids):
                for j, v in enumerate(
                        word_ids[max(i - window_size, 0):i + window_size]):
                    assert u < self.vocab_size
                    assert v < self.vocab_size
                    if i == j:
                        continue
                    self.word_pair_catch.append((u, v))
        batch_pairs = []
        for _ in range(batch_size):
            batch_pairs.append(self.word_pair_catch.popleft())
        return batch_pairs

    # @profile
    def get_neg_v(self, batch_size, negative_sample_size):
        neg_v = np.random.choice(
            self.neg_sample_table, size=(batch_size,
                                         negative_sample_size)).tolist()
        return neg_v

    def evaluate_pair_count(self, window_size):
        return self.sentence_length * (2 * window_size - 1) - (
            self.sentence_count - 1) * (1 + window_size) * window_size

In [58]:
data = PreProcessText('./data/alice.txt', 10)
data

Vocab size: 410
Sentence Length: 29384


<__main__.PreProcessText at 0x7ffb50acab70>

In [95]:
emb_dim = 128
batch_size = 50
window_size = 5
negative_sample_size = 5
iteration = 5
initial_lr = 0.0025
emb_size = len(data.word2id)

model = SkipGram(emb_size, emb_dim)
optimizer = optim.SparseAdam(model.parameters(), lr=initial_lr)

pair_count = data.evaluate_pair_count(window_size)
batch_count = iteration * pair_count / batch_size

In [96]:
from tqdm import tnrange

In [97]:
process_bar = tnrange(int(batch_count))
for i in process_bar:
    pos_pairs = data.get_batch_pairs(batch_size, window_size)
    neg_v = data.get_neg_v(batch_size, negative_sample_size)
    pos_u = [pair[0] for pair in pos_pairs]
    pos_v = [pair[1] for pair in pos_pairs]

    pos_u = Variable(torch.LongTensor(pos_u))
    pos_v = Variable(torch.LongTensor(pos_v))
    neg_v = Variable(torch.LongTensor(neg_v))

    optimizer.zero_grad()
    loss = model.forward(pos_u, pos_v, neg_v)
    loss.backward()
    optimizer.step()

#     process_bar.set_description("Loss: %0.8f, lr: %0.6f" %
#                                 (loss.data[0],
#                                  optimizer.param_groups[0]['lr']))
    if i * batch_size % 10000 == 0:
        process_bar.set_description("Loss: %0.8f, lr: %0.6f" %
                            (loss.data[0],
                             optimizer.param_groups[0]['lr']))
        lr = initial_lr * (1.0 - 1.0 * i / batch_count)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr





In [98]:
model.save_embedding(data.word2id, './data/skipgram.w')

In [99]:
from gensim.models import KeyedVectors

In [100]:
word_vectors = KeyedVectors.load_word2vec_format('./data/skipgram.w', binary=False)
word_vectors.most_similar(positive=['queen'])

[('dormouse', 0.8835446834564209),
 ('his', 0.8777092099189758),
 ('king', 0.8710049986839294),
 ('who', 0.861481249332428),
 ('turtle', 0.8438349366188049),
 ('hatter', 0.8383009433746338),
 ('gryphon', 0.8233548402786255),
 ('gloves', 0.8195096254348755),
 ('mock', 0.8153071999549866),
 ('white', 0.8116865158081055)]