In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
from collections import Counter
import numpy as np
import random
import math

import pandas as pd
import scipy
import sklearn
from  sklearn.metrics.pairwise import cosine_similarity

In [2]:
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
C=3  # context window
K=100 #number of negative samples
NUM_EPOCHS=2
MAX_VOCAB_SIZE=30000
BATCH_SIZE=128
LEARNING_RATE=0.2
EMSEDDING_SIZE=100

def word_tokenize(text):
    return text.split()

In [3]:
with open("text8/text8.train.txt","r") as fin:
    text=fin.read()
text=text.split()
vocab=dict(Counter(text).most_common(MAX_VOCAB_SIZE-1))
vocab["<unk>"]=len(text)-np.sum(list(vocab.values()))

In [4]:
idx_to_word=[word for word in vocab.keys()]
word_to_idx={word:i for i,word in enumerate(idx_to_word)} #按出现次数进行了排序

In [6]:
word_counts=np.array([count for count in vocab.values()],dtype=np.float32)
word_freqs=word_counts/np.sum(word_counts)
word_freqs=word_freqs**(3./4.)
word_freqs=word_counts/np.sum(word_counts)
VOCAB_SIZE=len(idx_to_word)
VOCAB_SIZE

30000

In [5]:
list(word_to_idx.items())[:100]

[('the', 0),
 ('of', 1),
 ('and', 2),
 ('one', 3),
 ('in', 4),
 ('a', 5),
 ('to', 6),
 ('zero', 7),
 ('nine', 8),
 ('two', 9),
 ('is', 10),
 ('as', 11),
 ('eight', 12),
 ('for', 13),
 ('s', 14),
 ('five', 15),
 ('three', 16),
 ('was', 17),
 ('by', 18),
 ('that', 19),
 ('four', 20),
 ('six', 21),
 ('seven', 22),
 ('with', 23),
 ('on', 24),
 ('are', 25),
 ('it', 26),
 ('from', 27),
 ('or', 28),
 ('his', 29),
 ('an', 30),
 ('be', 31),
 ('this', 32),
 ('he', 33),
 ('at', 34),
 ('which', 35),
 ('not', 36),
 ('also', 37),
 ('have', 38),
 ('were', 39),
 ('has', 40),
 ('but', 41),
 ('other', 42),
 ('their', 43),
 ('its', 44),
 ('first', 45),
 ('they', 46),
 ('had', 47),
 ('some', 48),
 ('more', 49),
 ('all', 50),
 ('can', 51),
 ('most', 52),
 ('been', 53),
 ('such', 54),
 ('who', 55),
 ('many', 56),
 ('new', 57),
 ('there', 58),
 ('used', 59),
 ('after', 60),
 ('american', 61),
 ('when', 62),
 ('time', 63),
 ('into', 64),
 ('these', 65),
 ('only', 66),
 ('see', 67),
 ('may', 68),
 ('than', 69)

In [18]:
vocab["<unk>"]

617111

In [30]:
word_to_idx

{'the': 0,
 'of': 1,
 'and': 2,
 'one': 3,
 'in': 4,
 'a': 5,
 'to': 6,
 'zero': 7,
 'nine': 8,
 'two': 9,
 'is': 10,
 'as': 11,
 'eight': 12,
 'for': 13,
 's': 14,
 'five': 15,
 'three': 16,
 'was': 17,
 'by': 18,
 'that': 19,
 'four': 20,
 'six': 21,
 'seven': 22,
 'with': 23,
 'on': 24,
 'are': 25,
 'it': 26,
 'from': 27,
 'or': 28,
 'his': 29,
 'an': 30,
 'be': 31,
 'this': 32,
 'he': 33,
 'at': 34,
 'which': 35,
 'not': 36,
 'also': 37,
 'have': 38,
 'were': 39,
 'has': 40,
 'but': 41,
 'other': 42,
 'their': 43,
 'its': 44,
 'first': 45,
 'they': 46,
 'had': 47,
 'some': 48,
 'more': 49,
 'all': 50,
 'can': 51,
 'most': 52,
 'been': 53,
 'such': 54,
 'who': 55,
 'many': 56,
 'new': 57,
 'there': 58,
 'used': 59,
 'after': 60,
 'american': 61,
 'when': 62,
 'time': 63,
 'into': 64,
 'these': 65,
 'only': 66,
 'see': 67,
 'may': 68,
 'than': 69,
 'i': 70,
 'world': 71,
 'b': 72,
 'd': 73,
 'would': 74,
 'no': 75,
 'however': 76,
 'between': 77,
 'about': 78,
 'over': 79,
 'states':

In [42]:
class WordEmbeddingDataset(tud.Dataset):
    def __init__(self,text,word_to_idx,idx_to_word,word_freqs,word_counts):
        super(WordEmbeddingDataset,self).__init__()
        self.text_encoded=[word_to_idx.get(word,word_to_idx["<unk>"]) for word in text]
        self.text_encoded=torch.LongTensor(self.text_encoded)
        self.word_to_idx=word_to_idx
        self.idx_to_word=idx_to_word
        self.word_feqs=torch.Tensor(word_freqs)
        self.word_counts=torch.Tensor(word_counts)
    def __len__(self):
        return len(self.text_encoded)
    def __getitem__(self, idx):
        center_word=self.text_encoded[idx]
        pos_indices=list(range(idx-C,idx))+list(range(idx+1,idx+C+1))  #window内单词的index
        pos_indices={i % len(self.text_encoded) for i in pos_indices}  #取余，放置超出text长度
        pos_words=self.text_encoded[pos_indices] #周围单词
        neg_words=torch.multinomial(self.word_feqs,K*pos_indices.shape[0],True)  #负例采样
        return center_word,pos_words,neg_words

In [43]:
dataset=WordEmbeddingDataset(text,word_to_idx,idx_to_word,word_freqs,word_counts)
dataLoader=tud.DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=4)

In [44]:
dataset.text_encoded[:100]

tensor([ 5266,  3106,    11,     5,   194,     1,  3079,    45,    59,   154,
          127,   744,   457, 10524,   135,     0, 25748,     1,     0,   100,
          866,     2,     0, 16271, 29999,     1,     0,   153,   866,  3539,
            0,   194,    10,   186,    59,     4,     5, 10679,   215,     6,
         1354,   104,   429,    19,    59,  2846,   359,     6,  3658,     0,
          725,     1,   366,    26,    40,    37,    53,   527,    97,    11,
            5,  1425,  2980,    18,   565,   712,  7095,     0,   252,  5266,
           10,  1039,    27,     0,   312,   247, 29999,  2947,   780,   179,
         5266,    11,     5,   199,   575,    10,     0,  1105,    19,  2558,
           25,  8777,     2,   275,    31,  3989,   142,    58,    25,  6525])

In [53]:
class EmbeddingModel(nn.Module):
    def __init__(self):
        super(EmbeddingModel,self).__init__()
        self.vocab_size=vocab_size
        self.embed_size=embed_size
        
        self.in_embed=nn.Embedding(self.vocab_size,self.embed_size)
        self.out_embed=nn.Embedding(self.vocab_size,self.embed_size)
    def forward(self,input_labels,pos_labels,neg_labels):
        input_embedding=self.in_embed(input_labels)
        pos_emdedding=self.in_embed(pos_labels)
        neg_emdedding=self.in_embed(input_labels)
        
        input_embedding=input_embedding.unsquuze(2)
        pos_dot=torch.bmm(pos_emdedding,input_embedding).squeeze(2)
        neg_dot=torch.bmm(neg_emdedding,-input_embedding).squeeze(2)
        
        log_pos=F.logsigmoid(pos_dot)
        log_neg = torch.bmm(neg_embedding, -input_embedding.unsqueeze(2)).squeeze() # B * (2*C*K)

        log_pos = F.logsigmoid(log_pos).sum(1)
        log_neg = F.logsigmoid(log_neg).sum(1) # batch_size
        loss = log_pos + log_neg
        
        return -loss
    def input_embeddings(self):
        return self.in_embed.weight.data.cpu().numpy()

tensor([[[-1.6091, -1.1660, -1.0748],
         [ 1.3168, -0.6818, -0.2223],
         [-0.5710,  0.0135,  0.1578],
         [-0.7735,  0.1991,  0.0457],
         [-0.5710,  0.0135,  0.1578],
         [ 1.6871,  0.2284,  0.4676],
         [ 1.3168, -0.6818, -0.2223],
         [-0.6298,  2.4070,  0.2786]]], grad_fn=<EmbeddingBackward>)