In [None]:
from collections import defaultdict
import torch
from torch import nn
import random
import numpy as np
import torch.nn.functional as F

class WordEmbSkip(nn.Module):
    """Skip-gram scorer with separate target-word and context-word embeddings.

    Two sparse ``nn.Embedding`` tables of shape ``(nwords, emb_size)`` are
    held: ``word_emb`` for the centre word and ``context_emb`` for the words
    it is scored against. Both are re-initialised with Xavier-uniform values.
    """

    def __init__(self, nwords, emb_size):
        super().__init__()
        # Each table is created and then immediately re-initialised, so the
        # random-number stream is consumed in a fixed, reproducible order.
        self.word_emb = nn.Embedding(nwords, emb_size, sparse=True)
        nn.init.xavier_uniform_(self.word_emb.weight)
        self.context_emb = nn.Embedding(nwords, emb_size, sparse=True)
        nn.init.xavier_uniform_(self.context_emb.weight)

    def forward(self, words_pos, context_positions, negative_sample=False):
        """Return the summed negative log-sigmoid loss for word/context pairs.

        Args:
            words_pos: index tensor of target word ids.
            context_positions: index tensor of context word ids.
            negative_sample: when true the scores are negated first, so the
                loss rewards *low* scores (the pairs are treated as noise).

        Returns:
            A scalar tensor: ``-sum(logsigmoid(+/- score))`` over all
            target/context dot products.
        """
        target_vecs = self.word_emb(words_pos)
        context_vecs = self.context_emb(context_positions)
        # Dot product of every target vector with every context vector.
        pair_scores = target_vecs.matmul(torch.transpose(context_vecs, 0, 1))
        # Negative samples contribute through the negated score.
        signed_scores = -pair_scores if negative_sample else pair_scores
        return -F.logsigmoid(signed_scores).sum()

# --- Hyper-parameters ---
K = 3           # negative samples drawn per positive context position
N = 2           # context window radius (words taken on each side)
EMB_SIZE = 128  # embedding dimensionality

# --- Output files ---
embeddings_location = "embeddings.txt"
labels_location = "labels.txt"

# --- Vocabulary ---
# Looking up an unseen word assigns it the next free integer id.
w2i = defaultdict(lambda: len(w2i))
# Number of occurrences recorded for each word id.
word_counts = defaultdict(int)
# Reserve the first two ids, in this order: sentence boundary, then unknown.
S, UNK = w2i["<s>"], w2i["<unk>"]

def read_dataset(filename):
    """Lazily yield each line of *filename* as a list of word ids.

    Every line is stripped and split on single spaces. Each token is mapped
    to an id through the module-level ``w2i`` mapping, and its occurrence is
    tallied in the module-level ``word_counts``. Both mappings are mutated as
    a side effect while the generator is consumed.
    """
    with open(filename, "r") as source:
        for raw_line in source:
            sentence_ids = []
            for token in raw_line.strip().split(" "):
                token_id = w2i[token]
                word_counts[token_id] += 1
                sentence_ids.append(token_id)
            yield sentence_ids
# The vocabulary and the word counts are built from the training split only.
train = list(read_dataset("Demo/DataSets/train.txt"))

# Snapshot the vocabulary BEFORE the dev split is read. Once w2i is frozen
# below, every dev-only word is inserted into it as a new key that maps to
# UNK. Measuring len(w2i) afterwards would over-count the ids, and inverting
# the mapping afterwards would overwrite the "<unk>" label and leave the
# surplus ids without any entry.
nwords = len(w2i)
i2w = {word_id: word for word, word_id in w2i.items()}

# Noise distribution for negative sampling: count ** 0.75, normalised to
# sum to 1. It is computed here because read_dataset also increments
# word_counts while reading the dev file, and those counts must not leak in.
counts = np.array(list(word_counts.values())) ** .75
normalizing_constant = sum(counts)
word_probabilities = np.zeros(nwords)
for word_id, count in word_counts.items():
    word_probabilities[word_id] = count ** .75 / normalizing_constant

# Freeze the vocabulary: from here on unknown words resolve to UNK.
w2i = defaultdict(lambda: UNK, w2i)
dev = list(read_dataset("Demo/DataSets/valid.txt"))

# Write the vocabulary, one word per line; line i holds the word with id i.
with open(labels_location, 'w') as labels_file:
    labels_file.writelines(i2w[word_id] + "\n" for word_id in range(nwords))

# Skip-gram model over the whole vocabulary, trained with plain SGD.
model = WordEmbSkip(nwords, EMB_SIZE)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
def sent_loss(sent):
    """Return the negative-sampling loss summed over every word of *sent*.

    For each position the 2*N surrounding word ids (padded with ``S`` past
    either end of the sentence) are scored as positive contexts, and
    2*N*K ids drawn from ``word_probabilities`` are scored as negatives.

    Args:
        sent: sequence of word ids.

    Returns:
        A scalar tensor holding the total loss for the sentence.
    """
    negatives_per_word = 2 * N * K
    # All negatives for the sentence are drawn in a single call and then
    # handed out in consecutive, equally sized chunks.
    all_neg_words = np.random.choice(
        nwords,
        size=negatives_per_word * len(sent),
        replace=True,
        p=word_probabilities,
    )

    # Padding both ends with N boundary ids turns every window into two
    # plain slices: padded[j] corresponds to sent[j - N].
    boundary = [S] * N
    padded = boundary + list(sent) + boundary

    losses = []
    for position, word in enumerate(sent):
        left_context = padded[position:position + N]
        right_context = padded[position + N + 1:position + 2 * N + 1]
        context_ids = torch.tensor(left_context + right_context)

        chunk_start = position * negatives_per_word
        chunk_end = chunk_start + negatives_per_word
        negative_ids = torch.tensor(all_neg_words[chunk_start:chunk_end])

        target_id = torch.tensor([word])
        positive_term = model(target_id, context_ids)
        negative_term = model(target_id, negative_ids, negative_sample=True)
        losses.append(positive_term + negative_term)

    return torch.stack(losses).sum()
# One pass of SGD over the training sentences, followed by an evaluation
# pass over the dev sentences. Running totals are printed every 2000
# sentences and once more when each pass is finished.
for epoch in range(1):
    random.shuffle(train)
    train_words, train_loss = 0, 0.0
    model.train()
    for sent_id, sent in enumerate(train):
        my_loss = sent_loss(sent)
        optimizer.zero_grad()
        my_loss.backward()
        optimizer.step()
        train_loss += my_loss.item()
        train_words += len(sent)
        if sent_id % 2000 == 0:
            print("epoch=%r,allloss=%.4f,loss=%.4f"%(epoch,train_loss,train_loss/train_words))
    # max(..., 1) keeps the summary from dividing by zero on an empty split.
    print("epoch=%r,allloss=%.4f,loss=%.4f"%(epoch,train_loss,train_loss/max(train_words,1)))

    model.eval()
    dev_words, dev_loss = 0, 0.0
    # No parameter update happens here, so skip building autograd graphs.
    with torch.no_grad():
        for sent_id, sent in enumerate(dev):
            my_loss = sent_loss(sent)
            dev_loss += my_loss.item()
            dev_words += len(sent)
            if sent_id % 2000 == 0:
                print("epoch=%r,allloss=%.4f,loss=%.4f"%(epoch,dev_loss,dev_loss/dev_words))
    # The final dev summary is printed unconditionally, like the training
    # one; gating it on the last sent_id would usually suppress it.
    print("epoch=%r,allloss=%.4f,loss=%.4f"%(epoch,dev_loss,dev_loss/max(dev_words,1)))
            
# Dump the target-word embedding table: one row per word id, with the
# components of each vector separated by tabs.
with open(embeddings_location, "w") as embeddings_file:
    embedding_matrix = model.word_emb.weight.data.numpy()
    for word_id in range(nwords):
        components = [str(value) for value in embedding_matrix[word_id]]
        embeddings_file.write('\t'.join(components) + '\n')


epoch=0,allloss=166.3806,loss=11.0920
epoch=0,allloss=402754.0525,loss=9.6811
epoch=0,allloss=765802.5440,loss=9.1182
epoch=0,allloss=1114675.9577,loss=8.8611
epoch=0,allloss=1463813.5307,loss=8.6827
epoch=0,allloss=1807101.8429,loss=8.5718
epoch=0,allloss=2144485.2823,loss=8.4792
epoch=0,allloss=2480586.0877,loss=8.4105
epoch=0,allloss=2816309.2116,loss=8.3553
epoch=0,allloss=3158292.2935,loss=8.3367
epoch=0,allloss=3500108.7700,loss=8.3041
epoch=0,allloss=3838008.4798,loss=8.2757
epoch=0,allloss=4175995.7636,loss=8.2553
epoch=0,allloss=4516629.6000,loss=8.2361
epoch=0,allloss=4850940.0141,loss=8.2210
epoch=0,allloss=5190263.7328,loss=8.2089
epoch=0,allloss=5529253.1796,loss=8.2002
epoch=0,allloss=5870579.7614,loss=8.1937
epoch=0,allloss=6217268.5927,loss=8.1878
epoch=0,allloss=6565514.4718,loss=8.1890
epoch=0,allloss=6905893.7228,loss=8.1850
epoch=0,allloss=7248909.3811,loss=8.1806
epoch=0,allloss=7260258.0207,loss=8.1804
epoch=0,allloss=121.2186,loss=8.6585
epoch=0,allloss=343847.71