[toc]

# Word2vec SkipGram PyTorch 实现

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import time
from collections import Counter
from sklearn.metrics.pairwise import cosine_distances
import numpy as np
import os

In [2]:
# --- Sampling / model hyper-parameters ---
negative_sample_size = 100  # negative samples drawn per positive (context) word
window_size = 5  # number of context words taken on each side of the center word
embedding_size = 100  # dimensionality of each word vector
max_vocab_size = 30000  # vocabulary cap, including the catch-all 'UNK' entry

# --- Optimisation hyper-parameters ---
n_epochs = 1
batch_size = 128
learning_rate = 0.05

# Seed torch's global RNG so weight initialisation, DataLoader shuffling and
# torch.multinomial negative sampling are repeatable across kernel restarts.
seed = 42
torch.manual_seed(seed)

# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Location of the text8 training split. Set the TEXT8_PATH environment variable
# to point somewhere else; the fallback is the path this notebook was written against.
file_path = os.environ.get('TEXT8_PATH', '/Users/ed/Downloads/text8/text8.train.txt')

# Explicit encoding so the read does not depend on the platform's locale default.
with open(file_path, encoding='utf-8') as f:
    text = f.read()

# Lower-case the corpus and split on whitespace into a flat list of tokens.
text = text.lower().split()

# Keep the (max_vocab_size - 1) most frequent words; every remaining occurrence
# is counted under the single 'UNK' entry so the counts still sum to len(text).
vocab_dict = dict(Counter(text).most_common(max_vocab_size - 1))
vocab_dict['UNK'] = len(text) - sum(vocab_dict.values())

# Indices follow vocab_dict's insertion order: most frequent first, 'UNK' last.
word2idx = {word: idx for idx, word in enumerate(vocab_dict)}
idx2word = {idx: word for word, idx in word2idx.items()}

In [4]:
# Occurrence count of every vocabulary entry, in the same order as word2idx.
word_counts = np.array(list(vocab_dict.values()))
# Relative frequency of each entry raised to the 3/4 power (not renormalised).
word_freqs = np.power(word_counts / word_counts.sum(), 3 / 4)

In [None]:
class WordEmbeddingDataSet(Dataset):
    """Skip-gram training samples over a tokenised corpus.

    Each item corresponds to one position in the corpus and yields the center
    word, its surrounding context words and a set of negative samples.

    Args:
        text: sequence of word tokens.
        word2idx: mapping word -> vocabulary index; must contain 'UNK'.
        idx2word: mapping vocabulary index -> word.
        word_freqs: per-index sampling weights used to draw negative samples
            (array-like, one non-negative weight per vocabulary index).
        word_counts: per-index raw occurrence counts (kept for reference).
        negative_sample_size: negatives drawn per context word.
        window_size: number of context words taken on each side of the center.
    """

    def __init__(self, text, word2idx, idx2word, word_freqs, word_counts, negative_sample_size=5, window_size=2):
        super().__init__()
        # Words missing from the vocabulary are encoded as 'UNK'.
        unk_idx = word2idx['UNK']
        self.text_encode = torch.LongTensor([word2idx.get(word, unk_idx) for word in text])

        self.window_size = window_size
        self.negative_sample_size = negative_sample_size
        self.word2idx = word2idx
        self.idx2word = idx2word
        self.word_freqs = word_freqs
        self.word_counts = torch.Tensor(word_counts)
        # torch.multinomial needs a float tensor; sample negatives from the
        # supplied word_freqs weights rather than from the raw counts.
        self._sampling_weights = torch.as_tensor(word_freqs, dtype=torch.float)

    def __len__(self):
        """Number of samples: one per token position in the corpus."""
        return len(self.text_encode)

    def __getitem__(self, idx):
        """Return ``(center_word, pos_words, neg_words)`` for position ``idx``.

        Returns:
            center_words: 0-d LongTensor, the word index at ``idx``.
            pos_words: LongTensor of length ``2 * window_size``.
            neg_words: LongTensor of length
                ``2 * window_size * negative_sample_size``.
        """
        center_words = self.text_encode[idx]
        n_tokens = len(self.text_encode)
        offsets = list(range(-self.window_size, 0)) + list(range(1, self.window_size + 1))
        # Wrap around both ends of the corpus so every sample has exactly
        # 2 * window_size context words; dropping out-of-range positions
        # would produce ragged samples that cannot be stacked into a batch.
        pos_indices = [(idx + offset) % n_tokens for offset in offsets]
        pos_words = self.text_encode[pos_indices]
        neg_words = torch.multinomial(self._sampling_weights,
                                      self.negative_sample_size * pos_words.shape[0],
                                      replacement=True)
        # All three outputs are LongTensors; one sample is produced per word position.
        return center_words, pos_words, neg_words

In [None]:
# Wrap the tokenised corpus so that every position yields one training sample.
dataset = WordEmbeddingDataSet(
    text=text,
    word2idx=word2idx,
    idx2word=idx2word,
    word_freqs=word_freqs,
    word_counts=word_counts,
    negative_sample_size=negative_sample_size,
    window_size=window_size,
)

# Serve the samples in shuffled mini-batches.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
class SkipGram(nn.Module):
    """Skip-gram model trained with a negative-sampling loss.

    Args:
        vocab_size: number of rows in each embedding table.
        embedding_size: dimensionality of each word vector.
    """

    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size

        # Every word has two embeddings: u is looked up for center words,
        # v for context and negative words.
        self.u_embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        self.v_embedding = nn.Embedding(self.vocab_size, self.embedding_size)

    def forward(self, input_labels, pos_labels, neg_labels):
        """Compute the negative-sampling loss for one batch.

        Args:
            input_labels: LongTensor ``(batch,)`` of center-word indices.
            pos_labels: LongTensor ``(batch, n_pos)`` of context-word indices.
            neg_labels: LongTensor ``(batch, n_neg)`` of negative-sample indices.

        Returns:
            0-d tensor: the negated mean over the batch of
            ``sum(logsigmoid(v_pos . u)) + sum(logsigmoid(-v_neg . u))``.
        """
        center_embedding = self.u_embedding(input_labels)  # (batch, emb)
        pos_embedding = self.v_embedding(pos_labels)  # (batch, n_pos, emb)
        neg_embedding = self.v_embedding(neg_labels)  # (batch, n_neg, emb)

        # Column vector per sample so bmm yields one dot product per context word.
        center_embedding = center_embedding.unsqueeze(2)  # (batch, emb, 1)
        pos_dot = torch.bmm(pos_embedding, center_embedding).squeeze(2)  # (batch, n_pos)

        # Negated center vector: negatives are pushed towards a low dot product.
        neg_dot = torch.bmm(neg_embedding, -center_embedding).squeeze(2)  # (batch, n_neg)

        log_pos = F.logsigmoid(pos_dot).sum(dim=1)
        log_neg = F.logsigmoid(neg_dot).sum(dim=1)
        loss = (log_pos + log_neg).mean()

        # The objective is maximised, so return its negation for the optimizer.
        return -loss

    def input_embedding(self):
        """Return the center-word (u) embedding matrix as a NumPy array.

        The weights are moved to the CPU first, because ``Tensor.numpy()``
        raises for tensors that live on a CUDA device.
        """
        return self.u_embedding.weight.detach().cpu().numpy()

In [None]:
# Build the model directly on the selected device, then let Adam optimise
# all of its parameters.
model = SkipGram(vocab_size=max_vocab_size, embedding_size=embedding_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Train for n_epochs passes over the dataloader, logging and checkpointing
# every 500 optimisation steps.
epoch = 0
while epoch < n_epochs:
    time_start = time.time()
    for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader, start=1):
        input_labels = input_labels.to(device)
        pos_labels = pos_labels.to(device)
        neg_labels = neg_labels.to(device)

        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels)
        loss.backward()
        optimizer.step()

        if i % 500 == 0:
            print("Epoch: {}/{} Step: {} Loss: {} time_used: {}".format(epoch, n_epochs, i, loss.item(),
                                                                        time.time() - time_start))
            time_start = time.time()

            # Write the new checkpoint first so that a failed save never
            # leaves the directory without any checkpoint.
            checkpoint_name = "embedding-{}.pth".format(i)
            torch.save(model.state_dict(), checkpoint_name)

            # Then drop older checkpoints written by this loop. Only files
            # named "embedding-*.pth" are touched; other .pth files in the
            # working directory are left alone.
            for file in os.listdir():
                if file.startswith("embedding-") and file.endswith(".pth") and file != checkpoint_name:
                    os.remove(file)
            print("Model saved successfully!")
    epoch += 1
    

Epoch: 0/1 Step: 500 Loss: 152.97000122070312 time_used: 240.0837619304657
Model saved successfully!
Epoch: 0/1 Step: 1000 Loss: 181.14987182617188 time_used: 228.9447259902954
Model saved successfully!
Epoch: 0/1 Step: 1500 Loss: 152.96710205078125 time_used: 236.4238359928131
Model saved successfully!
Epoch: 0/1 Step: 2000 Loss: 105.79800415039062 time_used: 231.19964408874512
Model saved successfully!
Epoch: 0/1 Step: 2500 Loss: 163.25140380859375 time_used: 246.37053894996643
Model saved successfully!
Epoch: 0/1 Step: 3000 Loss: 98.78641510009766 time_used: 263.79660415649414
Model saved successfully!


In [None]:
# Restore trained weights. map_location lets a checkpoint saved on a GPU be
# loaded on a CPU-only machine (and vice versa).
# NOTE(review): the checkpoint name is hard-coded; confirm it matches the
# file the training loop actually left on disk.
model.load_state_dict(torch.load("embedding-14500.pth", map_location=device))
embedding_weights = model.input_embedding()


def find_nearest(word, top_k=10):
    """Return the ``top_k`` vocabulary words closest to ``word``.

    Closeness is the cosine distance between center-word embeddings. The
    query word itself is part of the candidate set, so it normally appears
    first in the result.

    Args:
        word: query word; must be a key of ``word2idx``.
        top_k: number of neighbours to return (default 10).

    Returns:
        List of words ordered from nearest to farthest.

    Raises:
        KeyError: if ``word`` is not in the vocabulary.
    """
    index = word2idx[word]
    embedding = embedding_weights[index]
    cos_dis = cosine_distances(embedding.reshape(1, -1), embedding_weights).squeeze()
    return [idx2word[i] for i in cos_dis.argsort()[:top_k]]


for word in ["two", "america", "computer"]:
    print(word, find_nearest(word))

# References
1. [PyTorch实现Word2Vec - mathor](https://wmathor.com/index.php/archives/1435/)