## 1.导包

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

from collections import Counter
import numpy as np
import random

import math
import scipy
from sklearn.metrics.pairwise import cosine_similarity

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# 文本窗口大小
C = 5

# 负样本的数目
K = 15

MAX_VOCAB_SIZE = 10000
EMBEDDING_SIZE = 100

## 2.读取文本数据并处理

In [8]:
with open('./data/text8.train.txt') as f:
    text = f.read()

# 分割文本为单词列表
text = text.lower().split()
# 得到单词字典表，key是单词、value是出现的次数
vocab_dict = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1))
# 将不常用单词都编码为"<UNK>"
vocab_dict['<UNK>'] = len(text) - np.sum(list(vocab_dict.values()))

word2idx = {word: idx for idx, word in enumerate(vocab_dict.keys())}
idx2word = {idx: word for idx, word in enumerate(vocab_dict.keys())}

word_counts = np.array([count for count in vocab_dict.values()], dtype=np.float32)
word_freqs = word_counts / np.sum(word_counts)
# 论文里将单词的频率变为原来的0.75次方
word_freqs = word_freqs ** (3. / 4.)

## 3. 构建数据集

In [9]:
class SkipGramDataset(Dataset):
    def __init__(self, text, word2idx, word_freqs):
        """
        :param text: 单词列表
        :param word2idx: 从单词到索引的字典
        :param word_freqs: 每个单词出现的频率
        """
        super(SkipGramDataset, self).__init__()
        # 将单词数字化表示，若不在字典中，表示为unk的数字化结果
        self.text_encoded = [word2idx.get(word, word2idx['<UNK>']) for word in text]
        self.text_encoded = torch.LongTensor(self.text_encoded)
        self.word2idx = word2idx
        self.word_freqs = torch.Tensor(word_freqs)

    def __len__(self):
        return len(self.text_encoded)

    def __getitem__(self, item):
        """
        :param item: 索引
        :return:
            - 中心词
            - 中心词附近2C个postive word
            - 随机采样的K个单词作为negative word
        """
        # 获得中心词
        center_word = self.text_encoded[item]
        # 先取得中心词左右各C个词的索引
        pos_indices = list(range(item - C, item)) + list(range(item + 1, item + C + 1))
        # 为了避免索引越界，所以进行取余处理
        pos_indices = [i % len(self.text_encoded) for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]
        # 每采样一个正确的单词(positive word)，就采样K个错误的单词(negative word)
        # 取样方式采用有放回的采样，并且self.word_freqs数值越大，取样概率越大
        neg_words = torch.multinomial(self.word_freqs, num_samples=K * len(pos_words), replacement=True)
        # 保证 neg_words中不能包含positive word
        while len(set(pos_indices) & set(neg_words.numpy().tolist())) > 0:
            neg_words = torch.multinomial(self.word_freqs, K * len(pos_words), True)

        return center_word, pos_words, neg_words


dataset = SkipGramDataset(
    text=text,
    word2idx=word2idx,
    word_freqs=word_freqs
)
dataset[0]

(tensor(4813),
 tensor([  10,  419,   50, 9999,  393, 3139,   11,    5,  194,    1]),
 tensor([ 338, 1930,  608, 1466,   34,  126, 2807, 6407, 3822,   74,   13, 4649,
         9999,  425, 1385,  285,  236, 6396, 2119, 5071, 1651,   41,  274,  878,
          220,  317,   18,  177,  626,  703, 2405, 3659, 5951,   33, 5391,  139,
         1622,   18,   10, 4686,  137,   60,  664,   13, 1483, 3363,  673,   29,
         9999,   19, 4841, 9999, 4520, 9999, 2588, 6291,  122, 5034,  650,    9,
         1581, 9999,  427,  614,   77, 5074, 2753,   12,   91, 2980, 5395,  569,
         2098, 1172,  193,    0, 9603, 3606, 5405, 3304, 5036, 4687,  278, 8863,
          275, 7728, 1361,   40, 1784, 3220, 6374, 9999, 1791, 9999,  390, 4027,
         4840,   45, 9935, 1312, 4440, 5552, 1819,  133, 2358,  733, 1047, 3110,
         1822, 1553, 1211,   54,    0, 4867, 3098,  705, 1100, 6854, 5142, 9999,
         5419, 4450,  747, 3879, 9470,  254,  334, 6300, 8656,   44, 3390,  284,
          678, 9999,   

## 4. 构建Skip-Gram模型

In [20]:
class SkipGramModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(SkipGramModel, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size)
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size)

    def forward(self, input_labels, pos_labels, neg_labels):
        """
        :param input_labels: center words, [batch_size]
        :param pos_labels: positive words, [batch_size, (window_size * 2)]
        :param neg_labels: negative words, [batch_size, (window_size * 2 * K)]
        :return: loss, [batch_size]
        """
        input_embedding = self.in_embed(input_labels)
        pos_embedding = self.out_embed(pos_labels)
        neg_embedding = self.out_embed(neg_labels)

        # [batch_size, embed_size] -> [batch_size, embed_size, 1]
        input_embedding = input_embedding.unsqueeze(2)

        pos_dot = torch.bmm(pos_embedding, input_embedding)
        pos_dot = pos_dot.squeeze(2)

        neg_dot = torch.bmm(neg_embedding, -input_embedding)
        neg_dot = neg_dot.squeeze(2)

        log_pos = F.logsigmoid(pos_dot).sum(1)
        log_neg = F.logsigmoid(neg_dot).sum(1)

        loss = log_pos + log_neg
        return -loss

    def input_embedding(self):
        return self.in_embed.weight.detach().numpy()

model = SkipGramModel(
    vocab_size=MAX_VOCAB_SIZE,
    embed_size=EMBEDDING_SIZE
)

## 5. 模型训练

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = model.to(device)

epochs = 1
batch_size = 1024

dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True
)

optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
lf = lambda x: ((1 + math.cos(x * math.pi / epochs))/ 2) * (1 - 1e-4) + 1e-4
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

for epoch in range(epochs):
    for idx, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):
        input_labels = input_labels.long().to(device)
        pos_labels = pos_labels.long().to(device)
        neg_labels = neg_labels.long().to(device)

        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()

        optimizer.step()

        if (idx+1) % 500 == 0:
            print('Epoch {}, iteration {}, loss: {}'.format(epoch+1, idx+1, loss.item()))
    scheduler.step()

torch.save(model.state_dict(), 'embedding-{}.pth'.format(EMBEDDING_SIZE))

Epoch 1, iteration 500, loss: 32.6262092590332
Epoch 1, iteration 1000, loss: 32.87359619140625
Epoch 1, iteration 1500, loss: 32.77399444580078
Epoch 1, iteration 2000, loss: 32.819480895996094
Epoch 1, iteration 2500, loss: 32.727874755859375
Epoch 1, iteration 3000, loss: 32.81598663330078
Epoch 1, iteration 3500, loss: 32.71991729736328
Epoch 1, iteration 4000, loss: 32.64341735839844
Epoch 1, iteration 4500, loss: 32.81446838378906
Epoch 1, iteration 5000, loss: 32.635986328125
Epoch 1, iteration 5500, loss: 32.823448181152344
Epoch 1, iteration 6000, loss: 32.686309814453125
Epoch 1, iteration 6500, loss: 32.77131271362305
Epoch 1, iteration 7000, loss: 32.748016357421875
Epoch 1, iteration 7500, loss: 33.0582275390625
Epoch 1, iteration 8000, loss: 32.71543502807617
Epoch 1, iteration 8500, loss: 32.72041320800781
Epoch 1, iteration 9000, loss: 32.8436279296875
Epoch 1, iteration 9500, loss: 32.858367919921875
Epoch 1, iteration 10000, loss: 32.91303253173828
Epoch 1, iteration 

## 6.词向量应用

In [22]:
model = SkipGramModel(
    vocab_size=MAX_VOCAB_SIZE,
    embed_size=EMBEDDING_SIZE
)

model.load_state_dict(torch.load('embedding-100.pth'))

embedding_weights = model.input_embedding()

def find_nearest(word):
    """找出与某个词相近的一些词"""
    index = word2idx[word]
    embedding = embedding_weights[index]
    cos_dis = np.array([scipy.spatial.distance.cosine(e, embedding) for e in embedding_weights])
    return [idx2word[i] for i in cos_dis.argsort()[:10]]

for word in ["two", "america", "computer"]:
    print(word, find_nearest(word))

two ['two', 'three', 'four', 'five', 'six', 'one', 'zero', 'seven', 'eight', 'nine']
america ['america', 'europe', 'americas', 'africa', 'caribbean', 'australia', 'atlantic', 'united', 'pacific', 'north']
computer ['computer', 'computers', 'hardware', 'graphics', 'video', 'computing', 'software', 'computation', 'console', 'digital']
