#构建训练数据、自己构建类（继承torch.utils.data）,
#搭建pytorch模型训练词向量（不太用管这是什么东西，主要是看pytorch处理的流程，怎么使用的），
#最后训练过程中保存模型，在inference中运用保存的模型进行测试。

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
from torch.nn.parameter import Parameter

from collections import Counter #计数器，传入一个
import numpy as np
import pandas as pd
import random
import math
import scipy
import sklearn
from sklearn.metrics.pairwise import cosine_similarity
USE_CUDA = torch.cuda.is_available()


# Seed every RNG the notebook uses so runs are reproducible.
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)  # per the PyTorch docs this seeds the CPU and all CUDA devices
if USE_CUDA:
    # Explicit CUDA seeding (previously this branch just repeated torch.manual_seed).
    torch.cuda.manual_seed_all(123)
K = 100  # negative samples drawn per context word (used by WordEmbeddingDataset)
C = 3  # context words taken on EACH side of the center word
NUM_EPOCHES = 2  # passes over the training corpus
BATCH_SIZE = 128
MAX_VOCAB_SIZE = 30000  # vocabulary size, including the '<unk>' slot
LR = 0.2  # SGD learning rate
EMBEDDING_SIZE = 100  # dimensionality of every word vector
LOG_FILE = '/home/control/Desktop/word-embedding.log'
def word_tokenize(text):
    """Tokenize *text* by splitting on runs of whitespace.

    Leading/trailing whitespace is ignored, so an empty or blank string
    yields an empty list.
    """
    tokens = text.split(None)
    return tokens


In [2]:
# Load the training corpus, lower-case and tokenize it, then build the vocabulary.
# An explicit encoding keeps decoding independent of the machine's locale.
with open('/home/control/Desktop/text8/text8.train.txt', 'r', encoding='utf-8') as f:
    text = f.read()
text = word_tokenize(text.lower())  # already a fresh list; no extra copy needed
# Keep the MAX_VOCAB_SIZE - 1 most frequent words; the last slot is reserved for '<unk>'.
# Counter is a dict subclass; Counter.most_common(n) returns a list of
# (word, count) tuples ordered by descending count, which dict() turns back into a mapping.
# NOTE: do not rebind the name `dict` elsewhere in the notebook, or this call breaks.
vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE - 1))
# '<unk>' absorbs every occurrence of an out-of-vocabulary word
# (plain sum() keeps the count a Python int like all the other values).
vocab['<unk>'] = len(text) - sum(vocab.values())

In [3]:
# Index <-> word lookup tables. dict preserves insertion order, so index i is the
# i-th most frequent word and '<unk>' (inserted last) receives the final index.
idx_to_word = list(vocab)
word_to_idx = dict((word, position) for position, word in enumerate(idx_to_word))
word_counts = np.array(list(vocab.values()), dtype=np.float32)
# Negative-sampling distribution: unigram frequencies raised to the 3/4 power and
# renormalised, as in the word2vec negative-sampling paper.
word_freqs = word_counts / word_counts.sum()
word_freqs = np.power(word_freqs, 0.75)
word_freqs /= word_freqs.sum()
VOCAB_SIZE = len(idx_to_word)
print(VOCAB_SIZE)

30000


In [4]:
# Custom Dataset for skip-gram training; the official PyTorch
# "Datasets & DataLoaders" tutorial describes this __len__/__getitem__ pattern.
class WordEmbeddingDataset(tud.Dataset):
    """Skip-gram samples: (center word, context words, negative samples).

    Every position in ``text`` is one sample. Tokens missing from
    ``word_to_idx`` are encoded as the ``'<unk>'`` index (or, if ``'<unk>'``
    is not in the mapping, as the last index of ``idx_to_word``).

    Args:
        text: sequence of tokens forming the corpus.
        word_to_idx: token -> integer index.
        idx_to_word: list mapping index -> token.
        word_freqs: per-index weights used to draw negative samples.
        word_counts: per-index raw counts (stored on the instance, not used here).
        window_size: context words taken on EACH side of the center word;
            defaults to the module-level ``C``.
        num_negatives: negatives drawn per context word; defaults to the
            module-level ``K``.
    """

    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts,
                 window_size=None, num_negatives=None):
        super(WordEmbeddingDataset, self).__init__()
        self.window_size = C if window_size is None else window_size
        self.num_negatives = K if num_negatives is None else num_negatives
        unk_idx = word_to_idx.get('<unk>', len(idx_to_word) - 1)
        # Build the LongTensor directly instead of round-tripping through float32.
        self.text_encode = torch.tensor(
            [word_to_idx.get(t, unk_idx) for t in text], dtype=torch.long)
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.word_freqs = torch.Tensor(word_freqs)
        self.word_counts = torch.Tensor(word_counts)

    def __len__(self):
        """Number of samples: one per corpus position (required by Dataset)."""
        return len(self.text_encode)

    def __getitem__(self, idx):
        """Return ``(center_word, pos_words, neg_words)`` for corpus position ``idx``.

        ``pos_words`` holds the ``2 * window_size`` neighbours of ``idx``;
        positions past either end of the corpus wrap around (modulo the corpus
        length) so every index stays valid. ``neg_words`` holds
        ``num_negatives * len(pos_words)`` indices drawn with replacement from
        ``word_freqs`` (they may coincide with the center or context words).
        """
        corpus_len = len(self.text_encode)
        window = self.window_size
        center_word = self.text_encode[idx]
        pos_indices = list(range(idx - window, idx)) + list(range(idx + 1, idx + window + 1))
        pos_indices = [i % corpus_len for i in pos_indices]  # wrap at the corpus edges
        pos_words = self.text_encode[pos_indices]
        neg_words = torch.multinomial(
            self.word_freqs, self.num_negatives * pos_words.shape[0], True)
        return center_word, pos_words, neg_words
    

In [5]:
# Skip-gram dataset over the whole corpus, served in shuffled batches of
# BATCH_SIZE by 4 DataLoader worker processes.
dataset = WordEmbeddingDataset(
    text, word_to_idx, idx_to_word, word_freqs, word_counts
)
dataloader = tud.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=4,
    shuffle=True,
)

In [6]:
# from torch.autograd import Variable
# word_index={'hello':0,'world':1}
# embeds=nn.Embedding(2,5)
# #nn.Embedding:一个保存了固定字典和大小的简单查找表。这个模块常用来保存词嵌入和用下标检索它们。
# #模块的输入是一个下标的列表，输出是对应的词嵌入。
# #这是一个矩阵类，里面初始化了一个随机矩阵，矩阵的长是字典的大小，宽是用来表示字典中每个元素的属性向量，
# #向量的维度根据你想要表示的元素的复杂度而定。类实例化之后可以根据字典中元素的下标来查找元素对应的向量。
# hello_index=Variable(torch.LongTensor([word_index['hello']]))#必须是LongTensor (wrap the index in a list: torch.LongTensor(0) creates an empty tensor, not the index 0)
# hello_embed=embeds(hello_index)
# print(hello_embed)

In [6]:
# Skip-gram model trained with negative sampling.

class EmbeddingModel(nn.Module):
    """Two embedding tables trained with the negative-sampling objective.

    ``in_embed`` holds center-word vectors and ``out_embed`` holds context-word
    vectors; ``input_embeddings()`` exports ``in_embed`` as the word vectors.
    """

    def __init__(self, vocab_size, embed_size):
        super(EmbeddingModel, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size

        # Both tables start uniform in [-0.5/embed_size, 0.5/embed_size].
        # (out_embed is created first, as before, so seeded runs draw the same init.)
        initrange = 0.5 / self.embed_size
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.out_embed.weight.data.uniform_(-initrange, initrange)

        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.in_embed.weight.data.uniform_(-initrange, initrange)

    def forward(self, input_labels, pos_labels, neg_labels):
        """Return the per-sample negative-sampling loss.

        Implements the objective of Mikolov et al., "Distributed Representations
        of Words and Phrases and their Compositionality", eq. (4), negated so it
        can be minimised.

        Args:
            input_labels: (B,) center-word indices.
            pos_labels: (B, P) context-word indices.
            neg_labels: (B, N) negative-sample indices.

        Returns:
            (B,) tensor: -(sum_o log sigmoid(u_o . v_c) + sum_k log sigmoid(-u_k . v_c)).
        """
        input_embedding = self.in_embed(input_labels)  # (B, D)
        pos_embedding = self.out_embed(pos_labels)     # (B, P, D)
        neg_embedding = self.out_embed(neg_labels)     # (B, N, D)

        center = input_embedding.unsqueeze(2)          # (B, D, 1)
        # Batched matmul: (B, P, D) x (B, D, 1) -> (B, P, 1). Squeeze ONLY dim 2:
        # a bare .squeeze() would also drop the batch dim when B == 1 (e.g. a
        # short final batch) or P == 1, and the .sum(1) below would then fail.
        log_pos = torch.bmm(pos_embedding, center).squeeze(2)   # (B, P)
        log_neg = torch.bmm(neg_embedding, -center).squeeze(2)  # (B, N)

        # logsigmoid(x) = log(sigmoid(x)), computed in a numerically stable way.
        log_pos = F.logsigmoid(log_pos).sum(1)
        log_neg = F.logsigmoid(log_neg).sum(1)

        return -(log_pos + log_neg)

    def input_embeddings(self):
        """Return the center-word embedding matrix as a NumPy array on the CPU."""
        return self.in_embed.weight.data.cpu().numpy()

In [7]:
# Instantiate the skip-gram model and, when CUDA is available, move it to the GPU.
model = EmbeddingModel(VOCAB_SIZE, EMBEDDING_SIZE)
model = model.cuda() if USE_CUDA else model

In [8]:
# Model evaluation helpers.
def evaluate(filename, embedding_weights, vocab_index=None):
    """Score embeddings against human word-similarity judgements.

    Reads a table whose first three columns are (word1, word2, human score).
    The file is comma-separated if ``filename`` ends in ``.csv`` and
    tab-separated otherwise; its first line is read as the header. Pairs
    containing a word missing from the vocabulary are skipped. Each remaining
    pair is scored with the cosine similarity of the two embedding rows.

    Args:
        filename: path to the similarity table.
        embedding_weights: 2-D array with one row per vocabulary index.
        vocab_index: word -> row-index mapping; defaults to the module-level
            ``word_to_idx``.

    Returns:
        ``scipy.stats.spearmanr(human_scores, model_scores)`` (correlation and
        p-value). The result is degenerate (nan) when too few pairs could be scored.
    """
    import scipy.stats  # `import scipy` alone does not guarantee this submodule is loaded

    if vocab_index is None:
        # The original body referenced an undefined `word_to_index` here.
        vocab_index = word_to_idx
    sep = ',' if filename.endswith('.csv') else '\t'
    data = pd.read_csv(filename, sep=sep)

    human_similarity = []
    model_similarity = []
    for row in data.itertuples(index=False):
        word1, word2 = row[0], row[1]
        if word1 not in vocab_index or word2 not in vocab_index:
            continue
        vec1 = np.asarray(embedding_weights[vocab_index[word1]], dtype=np.float64)
        vec2 = np.asarray(embedding_weights[vocab_index[word2]], dtype=np.float64)
        norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        # Cosine similarity; a zero vector scores 0.0 (as sklearn's cosine_similarity does).
        model_similarity.append(float(np.dot(vec1, vec2) / norm) if norm else 0.0)
        human_similarity.append(float(row[2]))
    return scipy.stats.spearmanr(human_similarity, model_similarity)
def find_nearest(word, weights=None, vocab_index=None, vocab_words=None, topn=10):
    """Return the ``topn`` vocabulary words closest to ``word`` by cosine distance.

    Smaller cosine distance means greater similarity; the query word itself
    (distance 0) normally comes first.

    Args:
        word: query word; must be in the vocabulary (KeyError otherwise).
        weights: 2-D embedding matrix; defaults to the module-level
            ``embedding_weights``.
        vocab_index: word -> row mapping; defaults to ``word_to_idx``.
        vocab_words: row -> word list; defaults to ``idx_to_word``.
        topn: number of words to return (10, as before).
    """
    if weights is None:
        weights = embedding_weights
    if vocab_index is None:
        vocab_index = word_to_idx
    if vocab_words is None:
        vocab_words = idx_to_word

    matrix = np.asarray(weights, dtype=np.float64)
    query = matrix[vocab_index[word]]
    # One vectorized pass instead of a Python-level scipy call per vocabulary row.
    norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query)
    with np.errstate(divide='ignore', invalid='ignore'):
        cos_dis = 1.0 - (matrix @ query) / norms  # zero vectors give nan, sorted last
    return [vocab_words[i] for i in cos_dis.argsort()[:topn]]

In [10]:
# a=np.array([3,4,6,2,1,8])
# a.argsort()#按照最小排序顺序输出索引啊

In [11]:
# data=[['房子',40,1000],['别墅',30,2000],['小黑屋',20,3000]]
# data=pd.DataFrame(data)#默认的行、列索引是数字0,1,2，，等
# #data=pd.DataFrame(data,index=['lc','lb','bb'],columns=['home','age','income'])#指定了行、列索引
# print(data)
# print(data.iloc[:,1])#iloc索引是位置索引
# print(data.iloc[:,1].index)
# for i in data.iloc[:,1].index:
#     print(i)
# #print(data.loc[:,'home'])#loc是name索引

In [9]:
# Enable the cuDNN autotuner. Per the PyTorch docs it benchmarks convolution
# algorithms; this model has no convolutions, so it likely has little effect here.
setattr(torch.backends.cudnn, "benchmark", True)

In [10]:
# Training loop: minimise the negative-sampling loss, log progress, evaluate
# periodically, and save the embeddings and model weights after every epoch.

optimizer = torch.optim.SGD(model.parameters(), LR)
for epoch in range(NUM_EPOCHES):
    for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):
        input_labels = input_labels.long()
        pos_labels = pos_labels.long()
        neg_labels = neg_labels.long()

        if USE_CUDA:
            input_labels = input_labels.cuda()
            pos_labels = pos_labels.cuda()
            neg_labels = neg_labels.cuda()

        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()
        optimizer.step()  # it is the optimizer that steps, not the loss

        if i % 100 == 0:
            loss_value = loss.item()
            with open(LOG_FILE, 'a') as f:  # 'a' = append to the log
                f.write('epoch:{},iter:{},loss:{}\n'.format(epoch, i, loss_value))
                print('epoch:{},iter:{},loss:{}'.format(epoch, i, loss_value))
        if i % 2000 == 0:
            embedding_weights = model.input_embeddings()
            # NOTE(review): text8.test.txt appears to be raw corpus text (like
            # text8.train.txt, which is split into words above), not a
            # (word1, word2, score) similarity table. evaluate() then presumably
            # scores no pairs, which would explain every logged sim_sim being nan.
            # Point this at a word-similarity benchmark file instead.
            sim_sim = evaluate('/home/control/Desktop/text8/text8.test.txt', embedding_weights)
            nearest = find_nearest('monster')  # computed once, used for file and console
            message = 'epoch:{},iteration:{},sim_sim:{},nearest to monster:{}'.format(
                epoch, i, sim_sim, nearest)
            with open(LOG_FILE, 'a') as f:
                # Trailing newline was missing, so these entries ran into the next log line.
                f.write(message + '\n')
            print(message)
    # After every epoch, save the embeddings and the model parameters.
    embedding_weights = model.input_embeddings()
    np.save('/home/control/Desktop/text8/embedding-{}'.format(EMBEDDING_SIZE), embedding_weights)  # np.save appends .npy
    torch.save(model.state_dict(), '/home/control/Desktop/text8/embedding-{}.pth'.format(EMBEDDING_SIZE))  # trained parameters
                

epoch:0,iter:0,loss:420.04730224609375
epoch:0,iteration:0,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'infer', 'tame', 'keeshond', 'businessmen', 'norte', 'armenia', 'famine', 'perceive', 'andrey']
epoch:0,iter:100,loss:271.48870849609375
epoch:0,iter:200,loss:196.11825561523438
epoch:0,iter:300,loss:177.7938690185547
epoch:0,iter:400,loss:168.9546661376953
epoch:0,iter:500,loss:133.968017578125
epoch:0,iter:600,loss:110.83262634277344
epoch:0,iter:700,loss:114.25141906738281
epoch:0,iter:800,loss:94.58683776855469
epoch:0,iter:900,loss:106.93927764892578
epoch:0,iter:1000,loss:107.48808288574219
epoch:0,iter:1100,loss:91.71240234375
epoch:0,iter:1200,loss:86.25588989257812
epoch:0,iter:1300,loss:80.93287658691406
epoch:0,iter:1400,loss:76.32593536376953
epoch:0,iter:1500,loss:85.531005859375
epoch:0,iter:1600,loss:71.97907257080078
epoch:0,iter:1700,loss:56.93315124511719
epoch:0,iter:1800,loss:71.38660430908203
epoch:0,iter:1900,loss:74.979606

epoch:0,iter:15800,loss:32.81930923461914
epoch:0,iter:15900,loss:32.54316711425781
epoch:0,iter:16000,loss:34.55683517456055
epoch:0,iteration:16000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'officer', 'dream', 'beautiful', 'partner', 'hills', 'approved', 'orange', 'describing', 'hindu']
epoch:0,iter:16100,loss:32.77202224731445
epoch:0,iter:16200,loss:32.04798126220703
epoch:0,iter:16300,loss:32.68056869506836
epoch:0,iter:16400,loss:32.67103576660156
epoch:0,iter:16500,loss:32.825469970703125
epoch:0,iter:16600,loss:34.178955078125
epoch:0,iter:16700,loss:32.573726654052734
epoch:0,iter:16800,loss:32.49825668334961
epoch:0,iter:16900,loss:32.475242614746094
epoch:0,iter:17000,loss:33.910316467285156
epoch:0,iter:17100,loss:32.66007995605469
epoch:0,iter:17200,loss:32.190040588378906
epoch:0,iter:17300,loss:33.121788024902344
epoch:0,iter:17400,loss:32.450740814208984
epoch:0,iter:17500,loss:33.01435852050781
epoch:0,iter:17600,loss:33.337169

epoch:0,iter:31400,loss:31.15221405029297
epoch:0,iter:31500,loss:31.00267219543457
epoch:0,iter:31600,loss:31.745643615722656
epoch:0,iter:31700,loss:31.52621841430664
epoch:0,iter:31800,loss:31.4133358001709
epoch:0,iter:31900,loss:31.513397216796875
epoch:0,iter:32000,loss:31.801189422607422
epoch:0,iteration:32000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'mouse', 'keyboard', 'window', 'shield', 'leg', 'cult', 'bird', 'organisation', 'blade']
epoch:0,iter:32100,loss:31.664464950561523
epoch:0,iter:32200,loss:31.817832946777344
epoch:0,iter:32300,loss:31.888084411621094
epoch:0,iter:32400,loss:31.80487060546875
epoch:0,iter:32500,loss:31.036502838134766
epoch:0,iter:32600,loss:31.27777099609375
epoch:0,iter:32700,loss:31.962032318115234
epoch:0,iter:32800,loss:31.977624893188477
epoch:0,iter:32900,loss:31.871780395507812
epoch:0,iter:33000,loss:31.811279296875
epoch:0,iter:33100,loss:31.86081314086914
epoch:0,iter:33200,loss:31.0777969360351

epoch:0,iter:47100,loss:31.21398162841797
epoch:0,iter:47200,loss:31.05522918701172
epoch:0,iter:47300,loss:30.954774856567383
epoch:0,iter:47400,loss:31.528831481933594
epoch:0,iter:47500,loss:31.416275024414062
epoch:0,iter:47600,loss:31.03464126586914
epoch:0,iter:47700,loss:30.772079467773438
epoch:0,iter:47800,loss:30.994869232177734
epoch:0,iter:47900,loss:31.391660690307617
epoch:0,iter:48000,loss:31.167869567871094
epoch:0,iteration:48000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'mouse', 'crystal', 'bird', 'camera', 'leg', 'window', 'cult', 'pen', 'missile']
epoch:0,iter:48100,loss:31.539546966552734
epoch:0,iter:48200,loss:31.077733993530273
epoch:0,iter:48300,loss:31.33805274963379
epoch:0,iter:48400,loss:31.264358520507812
epoch:0,iter:48500,loss:32.01964569091797
epoch:0,iter:48600,loss:31.244997024536133
epoch:0,iter:48700,loss:30.897560119628906
epoch:0,iter:48800,loss:31.130638122558594
epoch:0,iter:48900,loss:31.43749237060547


epoch:0,iter:62800,loss:31.41981315612793
epoch:0,iter:62900,loss:31.28317642211914
epoch:0,iter:63000,loss:31.491233825683594
epoch:0,iter:63100,loss:31.33405303955078
epoch:0,iter:63200,loss:30.894474029541016
epoch:0,iter:63300,loss:30.871509552001953
epoch:0,iter:63400,loss:30.716106414794922
epoch:0,iter:63500,loss:30.666919708251953
epoch:0,iter:63600,loss:30.66132354736328
epoch:0,iter:63700,loss:31.278310775756836
epoch:0,iter:63800,loss:31.039413452148438
epoch:0,iter:63900,loss:31.236129760742188
epoch:0,iter:64000,loss:30.836265563964844
epoch:0,iteration:64000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'blade', 'shield', 'mouse', 'bird', 'window', 'pen', 'plate', 'camera', 'boat']
epoch:0,iter:64100,loss:30.712352752685547
epoch:0,iter:64200,loss:30.60208511352539
epoch:0,iter:64300,loss:30.90082550048828
epoch:0,iter:64400,loss:30.496692657470703
epoch:0,iter:64500,loss:30.73851776123047
epoch:0,iter:64600,loss:31.186283111572266
ep

epoch:0,iter:78500,loss:30.404129028320312
epoch:0,iter:78600,loss:30.998384475708008
epoch:0,iter:78700,loss:30.726593017578125
epoch:0,iter:78800,loss:30.743160247802734
epoch:0,iter:78900,loss:31.036296844482422
epoch:0,iter:79000,loss:31.33098602294922
epoch:0,iter:79100,loss:30.930673599243164
epoch:0,iter:79200,loss:30.96022605895996
epoch:0,iter:79300,loss:31.013507843017578
epoch:0,iter:79400,loss:30.260482788085938
epoch:0,iter:79500,loss:30.801044464111328
epoch:0,iter:79600,loss:30.945575714111328
epoch:0,iter:79700,loss:30.883453369140625
epoch:0,iter:79800,loss:31.054367065429688
epoch:0,iter:79900,loss:30.97827911376953
epoch:0,iter:80000,loss:30.768596649169922
epoch:0,iteration:80000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'blade', 'leg', 'pen', 'crystal', 'bird', 'shield', 'mouse', 'tank', 'shell']
epoch:0,iter:80100,loss:31.17371368408203
epoch:0,iter:80200,loss:31.51957130432129
epoch:0,iter:80300,loss:31.276025772094727
ep

epoch:0,iter:94300,loss:30.762001037597656
epoch:0,iter:94400,loss:30.67934799194336
epoch:0,iter:94500,loss:30.151138305664062
epoch:0,iter:94600,loss:30.74493408203125
epoch:0,iter:94700,loss:30.583518981933594
epoch:0,iter:94800,loss:30.479812622070312
epoch:0,iter:94900,loss:30.607688903808594
epoch:0,iter:95000,loss:30.69506072998047
epoch:0,iter:95100,loss:30.94977569580078
epoch:0,iter:95200,loss:30.524341583251953
epoch:0,iter:95300,loss:30.83545684814453
epoch:0,iter:95400,loss:30.858306884765625
epoch:0,iter:95500,loss:31.113323211669922
epoch:0,iter:95600,loss:30.801067352294922
epoch:0,iter:95700,loss:30.862842559814453
epoch:0,iter:95800,loss:30.723215103149414
epoch:0,iter:95900,loss:30.887653350830078
epoch:0,iter:96000,loss:30.60649871826172
epoch:0,iteration:96000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'boat', 'blade', 'pen', 'bird', 'mouse', 'camera', 'garden', 'leg', 'mine']
epoch:0,iter:96100,loss:30.62636947631836
epoch:

epoch:0,iter:110100,loss:30.60517692565918
epoch:0,iter:110200,loss:30.412818908691406
epoch:0,iter:110300,loss:30.511381149291992
epoch:0,iter:110400,loss:30.71621322631836
epoch:0,iter:110500,loss:31.07813262939453
epoch:0,iter:110600,loss:30.651123046875
epoch:0,iter:110700,loss:30.83493423461914
epoch:0,iter:110800,loss:30.716888427734375
epoch:0,iter:110900,loss:30.854145050048828
epoch:0,iter:111000,loss:30.730955123901367
epoch:0,iter:111100,loss:30.297290802001953
epoch:0,iter:111200,loss:30.840457916259766
epoch:0,iter:111300,loss:30.90220069885254
epoch:0,iter:111400,loss:31.162641525268555
epoch:0,iter:111500,loss:31.124467849731445
epoch:0,iter:111600,loss:30.719667434692383
epoch:0,iter:111700,loss:31.009632110595703
epoch:0,iter:111800,loss:31.112876892089844
epoch:0,iter:111900,loss:30.85529136657715
epoch:0,iter:112000,loss:31.082490921020508
epoch:0,iteration:112000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'blade', 'ghost', 'h

epoch:1,iter:6300,loss:30.86039924621582
epoch:1,iter:6400,loss:30.587650299072266
epoch:1,iter:6500,loss:30.612316131591797
epoch:1,iter:6600,loss:30.020017623901367
epoch:1,iter:6700,loss:30.4842586517334
epoch:1,iter:6800,loss:31.481605529785156
epoch:1,iter:6900,loss:30.90500831604004
epoch:1,iter:7000,loss:30.082054138183594
epoch:1,iter:7100,loss:29.981239318847656
epoch:1,iter:7200,loss:30.844589233398438
epoch:1,iter:7300,loss:30.531570434570312
epoch:1,iter:7400,loss:30.465015411376953
epoch:1,iter:7500,loss:30.8810977935791
epoch:1,iter:7600,loss:30.26763153076172
epoch:1,iter:7700,loss:30.7535343170166
epoch:1,iter:7800,loss:30.792987823486328
epoch:1,iter:7900,loss:31.012805938720703
epoch:1,iter:8000,loss:30.266571044921875
epoch:1,iteration:8000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'blade', 'ghost', 'boat', 'horn', 'mine', 'shield', 'bird', 'pen', 'angel']
epoch:1,iter:8100,loss:30.576637268066406
epoch:1,iter:8200,loss:30.58

epoch:1,iter:22200,loss:30.715560913085938
epoch:1,iter:22300,loss:30.881378173828125
epoch:1,iter:22400,loss:30.926464080810547
epoch:1,iter:22500,loss:30.650985717773438
epoch:1,iter:22600,loss:30.56364631652832
epoch:1,iter:22700,loss:30.636043548583984
epoch:1,iter:22800,loss:30.75058364868164
epoch:1,iter:22900,loss:30.490116119384766
epoch:1,iter:23000,loss:30.223796844482422
epoch:1,iter:23100,loss:30.669612884521484
epoch:1,iter:23200,loss:30.271526336669922
epoch:1,iter:23300,loss:30.277957916259766
epoch:1,iter:23400,loss:30.611858367919922
epoch:1,iter:23500,loss:30.95414924621582
epoch:1,iter:23600,loss:31.140771865844727
epoch:1,iter:23700,loss:30.269983291625977
epoch:1,iter:23800,loss:30.389888763427734
epoch:1,iter:23900,loss:30.95796775817871
epoch:1,iter:24000,loss:30.411577224731445
epoch:1,iteration:24000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'boat', 'mine', 'ghost', 'blade', 'shield', 'giant', 'hammer', 'horn', 'bird']


epoch:1,iter:38100,loss:30.37805938720703
epoch:1,iter:38200,loss:30.66231346130371
epoch:1,iter:38300,loss:30.168249130249023
epoch:1,iter:38400,loss:29.730915069580078
epoch:1,iter:38500,loss:30.7469482421875
epoch:1,iter:38600,loss:30.59576416015625
epoch:1,iter:38700,loss:30.538475036621094
epoch:1,iter:38800,loss:30.95449447631836
epoch:1,iter:38900,loss:30.740726470947266
epoch:1,iter:39000,loss:30.414318084716797
epoch:1,iter:39100,loss:30.918659210205078
epoch:1,iter:39200,loss:30.540706634521484
epoch:1,iter:39300,loss:30.659164428710938
epoch:1,iter:39400,loss:30.363147735595703
epoch:1,iter:39500,loss:30.91689109802246
epoch:1,iter:39600,loss:30.279651641845703
epoch:1,iter:39700,loss:30.31270408630371
epoch:1,iter:39800,loss:30.772098541259766
epoch:1,iter:39900,loss:29.830541610717773
epoch:1,iter:40000,loss:30.459339141845703
epoch:1,iteration:40000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'ghost', 'giant', 'shield', 'blade', 'pe

epoch:1,iter:54100,loss:30.627330780029297
epoch:1,iter:54200,loss:30.488882064819336
epoch:1,iter:54300,loss:30.425880432128906
epoch:1,iter:54400,loss:30.78284454345703
epoch:1,iter:54500,loss:30.329559326171875
epoch:1,iter:54600,loss:30.860849380493164
epoch:1,iter:54700,loss:30.55350112915039
epoch:1,iter:54800,loss:30.777734756469727
epoch:1,iter:54900,loss:30.703516006469727
epoch:1,iter:55000,loss:30.330577850341797
epoch:1,iter:55100,loss:30.374889373779297
epoch:1,iter:55200,loss:29.96446418762207
epoch:1,iter:55300,loss:29.94273567199707
epoch:1,iter:55400,loss:30.2362117767334
epoch:1,iter:55500,loss:30.50101089477539
epoch:1,iter:55600,loss:30.59259796142578
epoch:1,iter:55700,loss:30.709766387939453
epoch:1,iter:55800,loss:30.34781265258789
epoch:1,iter:55900,loss:30.3350772857666
epoch:1,iter:56000,loss:30.268383026123047
epoch:1,iteration:56000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'blade', 'mine', 'giant', 'hammer', 'ghost'

epoch:1,iter:70100,loss:30.21051025390625
epoch:1,iter:70200,loss:30.78095054626465
epoch:1,iter:70300,loss:30.19822120666504
epoch:1,iter:70400,loss:30.906967163085938
epoch:1,iter:70500,loss:30.31114959716797
epoch:1,iter:70600,loss:30.097675323486328
epoch:1,iter:70700,loss:30.972497940063477
epoch:1,iter:70800,loss:30.554183959960938
epoch:1,iter:70900,loss:30.54020881652832
epoch:1,iter:71000,loss:30.452972412109375
epoch:1,iter:71100,loss:30.46377182006836
epoch:1,iter:71200,loss:30.745302200317383
epoch:1,iter:71300,loss:30.498790740966797
epoch:1,iter:71400,loss:29.945402145385742
epoch:1,iter:71500,loss:30.28778839111328
epoch:1,iter:71600,loss:30.428443908691406
epoch:1,iter:71700,loss:30.507991790771484
epoch:1,iter:71800,loss:30.50565528869629
epoch:1,iter:71900,loss:30.77753448486328
epoch:1,iter:72000,loss:30.131874084472656
epoch:1,iteration:72000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'giant', 'bird', 'hammer', 'mine', 'blade

epoch:1,iter:86100,loss:30.6229248046875
epoch:1,iter:86200,loss:30.710458755493164
epoch:1,iter:86300,loss:30.389734268188477
epoch:1,iter:86400,loss:30.98797607421875
epoch:1,iter:86500,loss:30.3839168548584
epoch:1,iter:86600,loss:30.267261505126953
epoch:1,iter:86700,loss:30.65376853942871
epoch:1,iter:86800,loss:30.566932678222656
epoch:1,iter:86900,loss:30.37026023864746
epoch:1,iter:87000,loss:30.631908416748047
epoch:1,iter:87100,loss:29.769317626953125
epoch:1,iter:87200,loss:30.201290130615234
epoch:1,iter:87300,loss:30.305492401123047
epoch:1,iter:87400,loss:30.619487762451172
epoch:1,iter:87500,loss:30.138002395629883
epoch:1,iter:87600,loss:30.10323715209961
epoch:1,iter:87700,loss:30.794397354125977
epoch:1,iter:87800,loss:30.435169219970703
epoch:1,iter:87900,loss:30.677669525146484
epoch:1,iter:88000,loss:30.243457794189453
epoch:1,iteration:88000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'giant', 'hammer', 'triangle', 'bird', '

epoch:1,iter:102100,loss:29.949745178222656
epoch:1,iter:102200,loss:30.77664566040039
epoch:1,iter:102300,loss:30.27798843383789
epoch:1,iter:102400,loss:30.170265197753906
epoch:1,iter:102500,loss:30.557680130004883
epoch:1,iter:102600,loss:30.270912170410156
epoch:1,iter:102700,loss:30.334346771240234
epoch:1,iter:102800,loss:30.660491943359375
epoch:1,iter:102900,loss:30.001434326171875
epoch:1,iter:103000,loss:30.813701629638672
epoch:1,iter:103100,loss:30.70857048034668
epoch:1,iter:103200,loss:30.299846649169922
epoch:1,iter:103300,loss:30.81991958618164
epoch:1,iter:103400,loss:30.20003890991211
epoch:1,iter:103500,loss:30.276321411132812
epoch:1,iter:103600,loss:30.40438461303711
epoch:1,iter:103700,loss:30.186279296875
epoch:1,iter:103800,loss:30.304706573486328
epoch:1,iter:103900,loss:30.768939971923828
epoch:1,iter:104000,loss:30.399606704711914
epoch:1,iteration:104000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'giant', 'hammer', '

epoch:1,iter:117800,loss:29.770294189453125
epoch:1,iter:117900,loss:30.855987548828125
epoch:1,iter:118000,loss:29.936290740966797
epoch:1,iteration:118000,sim_sim:SpearmanrResult(correlation=nan, pvalue=nan),nearest to monster:['monster', 'giant', 'hammer', 'demon', 'clown', 'robot', 'rod', 'warrior', 'melody', 'tiger']
epoch:1,iter:118100,loss:29.871692657470703
epoch:1,iter:118200,loss:30.288227081298828
epoch:1,iter:118300,loss:30.43087387084961
epoch:1,iter:118400,loss:30.696880340576172
epoch:1,iter:118500,loss:30.628774642944336
epoch:1,iter:118600,loss:29.997852325439453
epoch:1,iter:118700,loss:30.153135299682617
epoch:1,iter:118800,loss:30.581192016601562
epoch:1,iter:118900,loss:30.2862548828125
epoch:1,iter:119000,loss:30.362186431884766
epoch:1,iter:119100,loss:30.263721466064453
epoch:1,iter:119200,loss:30.002391815185547
epoch:1,iter:119300,loss:29.851043701171875
epoch:1,iter:119400,loss:30.102832794189453
epoch:1,iter:119500,loss:30.521350860595703


In [None]:
#注意：加载模型的时候，格式如下
#model.load_state_dict(torch.load(path))