In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud

from collections import Counter
import numpy as np
import random
import math

import pandas as pd
import scipy
import sklearn
from sklearn.metrics.pairwise import cosine_similarity
from torch import optim

In [3]:
# Seed every RNG in play (Python, NumPy, PyTorch) so runs are reproducible.
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

# Hyperparameters for skip-gram with negative sampling.
C = 3                   # context window radius used when building (center, context) pairs
K = 15                  # negative samples drawn per positive context word
MAX_VOCAB_SIZE = 10000  # vocabulary cap, one slot of which is reserved for '<UNK>'
EMBEDDING_SIZE = 100    # dimensionality of each embedding vector
epochs = 2              # passes over the dataset
batch_size = 32         # examples per DataLoader batch
lr = 0.2                # SGD learning rate

In [4]:
# Read the entire corpus file into a single string.
with open('宋_4.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [5]:
# Drop full-width commas and periods so they do not become vocabulary tokens.
text = text.replace('，', '').replace('。', '')

In [6]:
text = [ch for ch in text]  # character-level tokens: one list element per character

In [7]:
# Keep the MAX_VOCAB_SIZE - 1 most frequent tokens (token -> count); the last slot is '<UNK>'.
vocab_dict = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1))
# Every token that did not make the cut is counted under '<UNK>'.
vocab_dict['<UNK>'] = len(text) - np.sum(list(vocab_dict.values()))

# id <-> token lookups, in descending-frequency order with '<UNK>' last.
idx2word = list(vocab_dict)
word2idx = {token: position for position, token in enumerate(idx2word)}

word_counts = np.array(list(vocab_dict.values()), dtype=np.float32)
# Unigram distribution raised to the 3/4 power, used as negative-sampling weights.
# The result is not renormalized; torch.multinomial accepts unnormalized weights.
word_freqs = (word_counts / word_counts.sum()) ** 0.75

In [8]:
class WordEmbeddingDataSet(tud.Dataset):
    """Skip-gram examples: (center word id, context word ids, negative sample ids)."""

    def __init__(self, text, word2idx, idx2word, word_freqs, word_counts,
                 window_size=None, num_neg=None):
        '''
        text: sequence of tokens (characters in this notebook)
        word2idx: token -> id mapping; must contain '<UNK>', used for unknown tokens
        idx2word: id -> token list
        word_freqs: per-id weights for drawing negative samples (need not sum to 1)
        word_counts: per-id raw counts (stored on the instance, not used for sampling)
        window_size: context positions taken on each side of the center word;
                     None means "use the module-level C at lookup time"
        num_neg: negatives drawn per context word; None means "use the module-level K"
        '''
        super(WordEmbeddingDataSet, self).__init__()
        unk_id = word2idx['<UNK>']
        self.text_encoded = torch.LongTensor([word2idx.get(word, unk_id) for word in text])
        self.word2idx = word2idx
        self.idx2word = idx2word
        self.word_freqs = torch.Tensor(word_freqs)
        self.word_counts = torch.Tensor(word_counts)
        self.window_size = window_size
        self.num_neg = num_neg

    def __len__(self):
        """Number of center positions, i.e. number of tokens in the corpus."""
        return len(self.text_encoded)

    def __getitem__(self, idx):
        """Return (center_word, pos_words, neg_words) for corpus position ``idx``.

        pos_words has 2 * window entries: window tokens to the left and window
        to the right of the center, excluding the center itself.
        neg_words has num_neg * len(pos_words) ids sampled with replacement.
        """
        # Resolve defaults per call so the globals C/K are looked up exactly as before.
        window = C if self.window_size is None else self.window_size
        n_neg = K if self.num_neg is None else self.num_neg

        n_tokens = len(self.text_encoded)
        center_word = self.text_encoded[idx]
        # Left context [idx - window, idx) and right context (idx, idx + window].
        # The right range must start at idx + 1, otherwise the center word is
        # treated as its own context.
        pos_indices = list(range(idx - window, idx)) + list(range(idx + 1, idx + window + 1))
        # Wrap around the ends of the corpus instead of padding.
        pos_indices = [i % n_tokens for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]

        # n_neg negatives per context word, drawn with replacement, weighted by word_freqs.
        neg_words = torch.multinomial(self.word_freqs, n_neg * pos_words.shape[0], True)
        return center_word, pos_words, neg_words

In [9]:
# Wrap the encoded corpus in a Dataset, then batch it with shuffled center positions.
dataset = WordEmbeddingDataSet(text=text, word2idx=word2idx, idx2word=idx2word,
                               word_freqs=word_freqs, word_counts=word_counts)
dataloader = tud.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
class EmbeddingModel(nn.Module):
    """Skip-gram with negative sampling (SGNS) model.

    Keeps two embedding tables: ``in_embed`` for center words and
    ``out_embed`` for context and negative words.
    """

    def __init__(self, vocab_size, embed_size):
        """
        vocab_size: rows in each embedding table (must exceed the largest id fed in)
        embed_size: dimensionality of each embedding vector
        """
        super(EmbeddingModel, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size)
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size)

    def forward(self, input_labels, pos_labels, neg_labels):
        """Compute the per-example SGNS loss.

        input_labels: 1-D LongTensor (batch,) of center-word ids
        pos_labels: 2-D LongTensor (batch, n_pos) of context-word ids
        neg_labels: 2-D LongTensor (batch, n_neg) of negative-sample ids
        Returns a (batch,) tensor of losses. Reduce it (e.g. ``.mean()``)
        before calling ``backward()``.
        """
        input_embedding = self.in_embed(input_labels)   # (batch, embed)
        # Context and negative words come from the output table. Looking them
        # up in in_embed would leave out_embed completely unused and untrained.
        pos_embedding = self.out_embed(pos_labels)      # (batch, n_pos, embed)
        neg_embedding = self.out_embed(neg_labels)      # (batch, n_neg, embed)

        input_embedding = input_embedding.unsqueeze(2)  # (batch, embed, 1)
        pos_dot = torch.bmm(pos_embedding, input_embedding).squeeze(2)   # (batch, n_pos)
        # Negating the center vector yields log sigma(-u.v) terms for negatives.
        neg_dot = torch.bmm(neg_embedding, -input_embedding).squeeze(2)  # (batch, n_neg)

        log_pos = F.logsigmoid(pos_dot).sum(1)
        log_neg = F.logsigmoid(neg_dot).sum(1)
        return -(log_pos + log_neg)

    def input_embeddings(self):
        """Return the center-word embedding matrix as a NumPy array (vocab_size, embed_size)."""
        return self.in_embed.weight.detach().cpu().numpy()

In [None]:
def train(model=None, loader=None, num_epochs=None, learning_rate=None,
          save_path=None, log_every=100, momentum=0.9):
    """Train the skip-gram embedding model, save its state_dict, and return the embeddings.

    Every argument is optional. When omitted, it falls back to the notebook config:
    a fresh ``EmbeddingModel(len(idx2word), EMBEDDING_SIZE)``, ``dataloader``,
    ``epochs``, ``lr`` and ``"embedding-{EMBEDDING_SIZE}.th"``.

    model: module whose forward(input, pos, neg) returns per-example losses and
           which has an ``in_embed`` embedding table
    loader: iterable of (input_labels, pos_labels, neg_labels) batches
    num_epochs: number of passes over ``loader``
    learning_rate: SGD learning rate
    save_path: file that ``model.state_dict()`` is written to
    log_every: print the loss every this many iterations
    momentum: SGD momentum (0.9, as in the original setup)

    Returns the center-word embedding matrix as a NumPy array.
    """
    if model is None:
        # Size the tables from the real vocabulary. A hard-coded 1000 rows
        # raises IndexError once any id >= 1000 appears.
        model = EmbeddingModel(len(idx2word), EMBEDDING_SIZE)
    if loader is None:
        loader = dataloader
    if num_epochs is None:
        num_epochs = epochs
    if learning_rate is None:
        learning_rate = lr
    if save_path is None:
        save_path = "embedding-{}.th".format(EMBEDDING_SIZE)

    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

    for e in range(num_epochs):
        for i, (input_labels, pos_labels, neg_labels) in enumerate(loader):
            input_labels = input_labels.long()
            pos_labels = pos_labels.long()
            neg_labels = neg_labels.long()

            optimizer.zero_grad()
            # The model returns one loss per example. backward() and item()
            # need a scalar, so average over the batch.
            loss = model(input_labels, pos_labels, neg_labels).mean()
            loss.backward()
            optimizer.step()
            if i % log_every == 0:
                print('epoch', e, 'iteration', i, loss.item())

    embedding_weights = model.in_embed.weight.detach().cpu().numpy()
    torch.save(model.state_dict(), save_path)
    return embedding_weights