In [1]:
import torch
import torch.nn as nn
import torch.utils.data as Data
import json
import collections
import re
import time
import random
import math
import sys
import numpy as np
import math
 
# Raise the interpreter's recursion limit to 5000 frames.
# NOTE(review): nothing visible in this notebook recurses deeply — confirm this is still needed.
sys.setrecursionlimit(5000) 

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else'cpu'

In [2]:
class My_data(Data.Dataset):
    """Parallel Chinese/English corpus: raw-text loading, vocabularies and id encoding.

    Chinese sentences are kept as strings and tokenised per character; English
    sentences are lower-cased and split into word/punctuation tokens.  Ids 0 and 1
    are always used as the begin/end markers of an encoded sentence, and unknown
    tokens are mapped to the id of '<UNK>'.

    Args:
        max_len: Chinese sentences with more than this many characters are dropped
            (together with their English counterpart) when encoding.
        min_en_count: minimum corpus frequency for an English token to enter the vocabulary.
        min_cn_count: minimum corpus frequency for a Chinese character to enter the vocabulary.
    """

    # English clean-up: detach punctuation from words, then collapse every run of
    # characters outside the allowed set into a single space.
    _EN_PUNCT = re.compile(r"([,'.!?])")
    _EN_DROP = re.compile(r"[^a-zA-Z\u4e00-\u9fa5.,'!?]+")
    # Chinese clean-up: same idea, but full-width punctuation and quotes are kept.
    _CN_DROP = re.compile(r"[^a-zA-Z\u4e00-\u9fa5.,'‘’“”!?，。？！]+")

    def __init__(self,max_len=50,min_en_count=0,min_cn_count=0):
        self.max_len = max_len
        self.min_en_count = min_en_count
        self.min_cn_count = min_cn_count
        self.counter = None
        self.cn_itos = ['<SOS>','<EOS>','<UNK>']
        self.cn_stoi = {'<SOS>':0,'<EOS>':1,'<UNK>':2}
        self.cn_data = []
        self.en_itos = ['<SOS>','<EOS>','<UNK>']
        self.en_stoi = {'<SOS>':0,'<EOS>':1,'<UNK>':2}
        self.en_data = []
        self.num_cn_vocab = 3
        self.num_en_vocab = 3
        self.len = 0

    @classmethod
    def _clean_pair(cls, cn, en):
        """Normalise one sentence pair; returns (cleaned Chinese string, English token list)."""
        en = en.lower()
        en = cls._EN_PUNCT.sub(r" \1", en)
        en = cls._EN_DROP.sub(r" ", en)
        cn = cls._CN_DROP.sub(r" ", cn)
        return cn, en.split()

    def get_raw_data_cn_en(self,file,js=False, divide=1, choose=1):
        """Read and clean sentence pairs from ``file``.

        Args:
            file: path of a UTF-8 text file.  With ``js`` false each line is
                ``chinese<TAB>english``; with ``js`` true each line is a JSON object
                with the keys 'chinese' and 'english'.
            js: selects the JSON-lines format.
            divide: number of equal chunks the file is cut into.
            choose: 1-based index of the chunk to read.  Lines left over after the
                integer division (``len % divide``) belong to no chunk.

        Returns:
            ``(all_cn, all_en)``: a list of cleaned Chinese strings and a parallel
            list of English token lists.
        """
        # Explicit encoding: the platform default is not UTF-8 everywhere, and the
        # files contain Chinese text.
        with open(file, encoding='utf-8') as f:
            lines = f.readlines()
        data_len = len(lines)
        k = int(data_len//divide)
        start = max((choose-1)*k,0)
        end = min(choose*k,data_len)

        all_cn = []
        all_en = []
        for line in lines[start:end]:
            if js:
                record = json.loads(line)
                cn = record['chinese']
                en = record['english']
            else:
                cn,en = line.strip().split('\t')
            cn, en = self._clean_pair(cn, en)
            all_cn.append(cn)
            all_en.append(en)
        return all_cn,all_en

    @staticmethod
    def _extend_vocab(itos, counter, min_count):
        """Append frequent-enough new tokens to ``itos`` (most common first); return the matching stoi."""
        known = set(itos)
        for tk, count in counter.most_common():
            # Skipping known tokens keeps ids stable when the method is called again
            # and prevents a corpus token from shadowing a special token.
            if count >= min_count and tk not in known:
                itos.append(tk)
                known.add(tk)
        return {tk: idx for idx, tk in enumerate(itos)}

    def get_cn_en_stoi_itos(self,raw_cn,raw_en):
        """Build (or extend) both vocabularies from raw sentences.

        ``raw_cn`` is iterated character by character, ``raw_en`` token by token.
        Tokens rarer than ``min_cn_count`` / ``min_en_count`` are left out.
        """
        cn_counter = collections.Counter(tk for line in raw_cn for tk in line)
        en_counter = collections.Counter(tk for line in raw_en for tk in line)

        self.cn_stoi = self._extend_vocab(self.cn_itos, cn_counter, self.min_cn_count)
        self.en_stoi = self._extend_vocab(self.en_itos, en_counter, self.min_en_count)

        self.num_cn_vocab = len(self.cn_stoi)
        self.num_en_vocab = len(self.en_stoi)

    @staticmethod
    def _read_vocab(path):
        """Read one token per line; only the newline is removed so whitespace tokens survive."""
        with open(path, encoding='utf-8') as f:
            return [line.replace('\n','') for line in f]

    def get_from_vocab(self,cn_file,en_file):
        """Replace both vocabularies with the ones stored in two UTF-8 files (one token per line).

        The line number (0-based) is the token id, so the files are expected to
        list the special tokens first.
        """
        self.cn_itos = self._read_vocab(cn_file)
        self.en_itos = self._read_vocab(en_file)

        self.cn_stoi = {tk: idx for idx,tk in enumerate(self.cn_itos)}
        self.en_stoi = {tk: idx for idx,tk in enumerate(self.en_itos)}

        self.num_cn_vocab = len(self.cn_stoi)
        self.num_en_vocab = len(self.en_stoi)

    @staticmethod
    def _encode(tokens, stoi):
        """Map tokens to ids, wrapping the sentence in id 0 ... id 1."""
        # Dictionary membership instead of scanning the itos list for every token.
        ids = [stoi[tk] if tk in stoi else stoi['<UNK>'] for tk in tokens]
        return [0] + ids + [1]

    def get_data(self,raw_cn,raw_en):
        """Discard any previously encoded pairs, then encode ``raw_cn`` / ``raw_en``."""
        if len(raw_cn) != len(raw_en):
            raise ValueError('raw_cn and raw_en must have the same number of sentences')
        self.cn_data = []
        self.en_data = []
        self.append_data(raw_cn,raw_en)

    def append_data(self,raw_cn,raw_en):
        """Encode sentence pairs and append them to the stored data.

        A pair is skipped when its Chinese side is longer than ``max_len``; the
        English side is not length-checked.

        Raises:
            ValueError: if the two lists differ in length (the pairs would be misaligned).
        """
        if len(raw_cn) != len(raw_en):
            raise ValueError('raw_cn and raw_en must have the same number of sentences')
        for cn_line, en_line in zip(raw_cn, raw_en):
            if len(cn_line) > self.max_len:
                continue
            self.cn_data.append(self._encode(cn_line, self.cn_stoi))
            self.en_data.append(self._encode(en_line, self.en_stoi))
        self.len = len(self.cn_data)

    def do_all(self,file,js=False):
        """Load ``file``, build the vocabularies from it and encode it."""
        cn,en = self.get_raw_data_cn_en(file=file,js=js)
        self.get_cn_en_stoi_itos(cn,en)
        self.get_data(cn,en)

    def get_data_pair(self,idx):
        """Return the encoded ``(chinese_ids, english_ids)`` pair at ``idx``."""
        return self.cn_data[idx],self.en_data[idx]

    def __getitem__(self,idx):
        return self.get_data_pair(idx)

    def __len__(self):
        return self.len
        

In [3]:
class Encoder_2(nn.Module):
    """Transformer encoder over source-token ids with a fixed additive position table.

    Args:
        num_vocab: size of the source vocabulary.
        d_model: embedding / model width.
        nhead: number of attention heads per layer.
        num_layers: number of stacked encoder layers.
        dropout: dropout probability inside each layer.
        max_len: number of rows in the position table, i.e. the longest sequence
            ``forward`` accepts.
    """

    def __init__(self,num_vocab,d_model=512,nhead=2,num_layers=2,dropout=0.1,max_len=100):
        super(Encoder_2,self).__init__()
        self.embedding = nn.Embedding(num_vocab,d_model)
        # The template layer stays a registered attribute, so its parameters appear
        # in state_dict(); removing it would change the checkpoint key set.
        self.MH = nn.TransformerEncoderLayer(d_model=d_model,nhead=nhead,dropout=dropout)
        self.all_MH = nn.TransformerEncoder(self.MH,num_layers)
        # Non-persistent buffer: follows .to()/.cuda() with the module instead of
        # being pinned to a module-level device, and stays out of state_dict() so
        # checkpoints saved without it still load.
        self.register_buffer('position_embedding',
                             self.get_position_embedding(d_model,max_len),
                             persistent=False)

    def forward(self,myinput,mask=None):
        """Encode a batch of token ids.

        Args:
            myinput: integer tensor of shape (batch, length).
            mask: optional key-padding mask passed through as ``src_key_padding_mask``.

        Returns:
            Tensor of shape (length, batch, d_model).

        Raises:
            ValueError: if ``length`` exceeds the position table.
        """
        X = self.embedding(myinput)
        seq_len = X.shape[1]
        if seq_len > self.position_embedding.shape[0]:
            raise ValueError('sequence length %d exceeds the position table (%d rows)'
                             % (seq_len, self.position_embedding.shape[0]))
        # Out-of-place add: leaves the embedding output untouched.
        X = X + self.position_embedding[:seq_len]
        # The transformer layers expect (length, batch, d_model).
        X = X.permute(1,0,2)
        return self.all_MH(X,src_key_padding_mask=mask)

    def get_position_embedding(self,d_model,max_len):
        """Return the (max_len, d_model) float position table.

        Entry ``[p, i]`` is ``sin(p / 10000 ** (i / d_model))``.

        NOTE: every column uses sin, whereas the textbook sinusoidal encoding uses
        cos on odd columns.  This is kept on purpose: the table is not stored in
        checkpoints, so weights trained with it only work with this exact table.
        """
        position = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)
        exponent = torch.arange(d_model, dtype=torch.float64).unsqueeze(0) / d_model
        # Angles are computed in double precision and rounded to float32 before the
        # sine, as the element-wise construction did.
        angle = (position / torch.pow(10000.0, exponent)).float()
        return torch.sin(angle)
        

In [4]:
class Decoder_2(nn.Module):
    """Transformer decoder over target-token ids with a fixed additive position table.

    Args:
        num_vocab: size of the target vocabulary (also the width of the output logits).
        d_model: embedding / model width.
        nhead: number of attention heads per layer.
        num_layers: number of stacked decoder layers.
        dropout: dropout probability inside each layer.
        max_len: number of rows in the position table, i.e. the longest target
            prefix ``forward`` accepts.
    """

    def __init__(self,num_vocab,d_model=512,nhead=2,num_layers=2,dropout=0.1,max_len=100):
        super(Decoder_2,self).__init__()
        self.num_vocab = num_vocab
        self.embedding = nn.Embedding(num_vocab, d_model)
        # The template layer stays a registered attribute, so its parameters appear
        # in state_dict(); removing it would change the checkpoint key set.
        self.MH = nn.TransformerDecoderLayer(d_model=d_model,nhead=nhead,dropout=dropout)
        self.all_MH = nn.TransformerDecoder(self.MH,num_layers)
        self.dense = nn.Linear(d_model,num_vocab)
        # Non-persistent buffer: follows .to()/.cuda() with the module instead of
        # being pinned to a module-level device, and stays out of state_dict() so
        # checkpoints saved without it still load.
        self.register_buffer('position_embedding',
                             self.get_position_embedding(d_model,max_len),
                             persistent=False)

    def forward(self,target,memory,t_mask=None,m_mask=None,tgt_mask=None):
        """Compute next-token logits for every position of ``target``.

        Args:
            target: integer tensor of shape (batch, length).
            memory: encoder output, passed unchanged to the decoder stack.
            t_mask: optional ``tgt_key_padding_mask``.
            m_mask: optional ``memory_key_padding_mask``.
            tgt_mask: optional attention mask over target positions.

        Returns:
            Logits of shape (batch, length, num_vocab).

        Raises:
            ValueError: if ``length`` exceeds the position table.
        """
        X = self.embedding(target)
        seq_len = X.shape[1]
        if seq_len > self.position_embedding.shape[0]:
            raise ValueError('sequence length %d exceeds the position table (%d rows)'
                             % (seq_len, self.position_embedding.shape[0]))
        # Out-of-place add: leaves the embedding output untouched.
        X = X + self.position_embedding[:seq_len]
        # The transformer layers expect (length, batch, d_model).
        X = X.permute(1,0,2)
        Y = self.all_MH(X,memory,tgt_key_padding_mask=t_mask,memory_key_padding_mask=m_mask,tgt_mask=tgt_mask)
        # Back to batch-first before projecting onto the vocabulary.
        return self.dense(Y.permute(1,0,2))

    def get_position_embedding(self,d_model,max_len):
        """Return the (max_len, d_model) float position table.

        Entry ``[p, i]`` is ``sin(p / 10000 ** (i / d_model))``.

        NOTE: every column uses sin, whereas the textbook sinusoidal encoding uses
        cos on odd columns.  This is kept on purpose: the table is not stored in
        checkpoints, so weights trained with it only work with this exact table.
        """
        position = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)
        exponent = torch.arange(d_model, dtype=torch.float64).unsqueeze(0) / d_model
        # Angles are computed in double precision and rounded to float32 before the
        # sine, as the element-wise construction did.
        angle = (position / torch.pow(10000.0, exponent)).float()
        return torch.sin(angle)

In [5]:
def get_decoder_mask(L):
    """Build the causal attention mask for a target sequence of length ``L``.

    Returns a CPU boolean tensor of shape (L, L) that is True strictly above the
    main diagonal (the positions a query must not attend to) and False elsewhere.
    """
    strictly_upper = torch.triu(torch.ones(L, L), diagonal=1)
    return strictly_upper.bool()

In [6]:
# Corpus locations. The disabled code below reads `trainfile` as JSON lines (js=True)
# and `trainfile2` as tab-separated text (js=False).
trainfile = 'translation2019zh/translation2019zh_train.json'
trainfile2 = 'cn-eng.txt'


# The count thresholds only matter when a vocabulary is built from raw text
# (get_cn_en_stoi_itos); get_from_vocab below does not read them.
trainset = My_data(min_en_count=8,min_cn_count=0)
# Disabled alternative: build the vocabularies from the raw corpora instead of
# loading the saved vocabulary files.
# trainset.do_all(trainfile, js=True)
# raw_cn, raw_en = trainset.get_raw_data_cn_en(trainfile,js=True,divide=50,choose=1)
# raw_cn2, raw_en2 = trainset.get_raw_data_cn_en(trainfile2,js=False)
# raw_cn3 = raw_cn+raw_cn2
# raw_en3 = raw_en+raw_en2
# trainset.get_cn_en_stoi_itos(raw_cn3,raw_en3)


# Load the saved vocabularies: one token per line, the line index is the token id.
cn_file = 'cn_vocab.txt'
en_file = 'en_vocab.txt'
trainset.get_from_vocab(cn_file,en_file)

# Sanity checks: stoi and itos sizes should agree per language; len(trainset) is 0
# here because no sentence pairs were encoded, only the vocabularies were loaded.
print(len(trainset.cn_stoi))
print(len(trainset.en_stoi))
print(len(trainset.cn_itos))
print(len(trainset.en_itos))
print(len(trainset))
print(trainset.cn_itos[7])
print(trainset.en_stoi['bother'])

6457
25769
6457
25769
0
一
5606


In [7]:

# Build the encoder over the Chinese vocabulary and the decoder over the English one.
# These sizes have to match the checkpoints loaded in the next cell.
encoder = Encoder_2(trainset.num_cn_vocab, d_model=512, nhead=4, num_layers=4).to(device)
decoder = Decoder_2(trainset.num_en_vocab, d_model=512, nhead=4, num_layers=4).to(device)


In [8]:
# Restore the trained weights; map_location places the tensors on `device`
# regardless of where they were saved.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
encoder.load_state_dict(torch.load('TRANS07en.pth',map_location=device))
decoder.load_state_dict(torch.load('TRANS07de.pth',map_location=device))

<All keys matched successfully>

In [9]:
def _top_next_tokens(decoder, memory, prefix, num_beam):
    """Decode one prefix and return (probabilities, token ids) of its ``num_beam`` best next tokens."""
    target = torch.tensor(prefix).long().view(1, -1).to(memory.device)
    tgt_mask = get_decoder_mask(len(prefix)).to(memory.device)
    logits = decoder(target, memory, tgt_mask=tgt_mask).to('cpu')
    # Only the distribution after the last prefix token is needed.
    probs = torch.nn.functional.softmax(logits.view(len(prefix), -1)[-1], dim=-1)
    return torch.topk(probs, k=num_beam, dim=0)


def Beam_search(decoder,memory,num_beam=10,max_len=200):
    """Beam-search decoding of a single source sentence.

    A hypothesis is scored by the product of its token probabilities.  It is
    finished once it ends in EOS; a hypothesis ending in the PERIOD token gets EOS
    appended and is finished as well.  Finished hypotheses keep their slot and
    score while the remaining slots are filled with the best expansions of the
    unfinished ones.

    Args:
        decoder: module called as ``decoder(target_ids, memory, tgt_mask=...)`` and
            returning logits of shape (1, length, vocab).
        memory: encoder output for one sentence; the target tensors are created on
            its device.
        num_beam: beam width; must not exceed the vocabulary size.
        max_len: maximum number of decoding steps.  The prefix fed to the decoder
            grows by one token per step, so this must stay within the decoder's
            position table.

    Returns:
        ``(hypotheses, possibility)``: ``num_beam`` lists of int token ids (each
        starting with id 0), sorted by decreasing score, and the tensor of their scores.
    """
    SOS = 0
    EOS = 1
    # presumably the id of '.' in the English vocabulary — confirm against en_vocab.txt
    PERIOD = 3

    with torch.no_grad():
        beam = [[SOS] for _ in range(num_beam)]
        scores = torch.ones(num_beam,requires_grad=False)

        # First step: every beam is just [SOS], so decode once and hand the k best
        # tokens out to the k beams (beam[0] also exists when num_beam == 1).
        probs, token_ids = _top_next_tokens(decoder, memory, beam[0], num_beam)
        for idx in range(num_beam):
            # Plain ints keep the hypotheses free of 0-dim tensors.
            beam[idx].append(token_ids[idx].item())
            scores[idx] = probs[idx].item()

        for _ in range(1,max_len):
            next_beam = []
            next_scores = []
            live = []
            # Finished hypotheses are carried over first, in their current order.
            for m, prefix in enumerate(beam):
                last = prefix[-1]
                if last == EOS:
                    next_beam.append(prefix[:])
                elif last == PERIOD:
                    next_beam.append(prefix + [EOS])
                else:
                    live.append(m)
                    continue
                next_scores.append(scores[m].item())

            if live:
                # Expand only unfinished hypotheses; the decoder is not run on
                # finished ones because their expansions would be discarded.
                expansions = []
                expansion_ids = []
                for m in live:
                    probs, token_ids = _top_next_tokens(decoder, memory, beam[m], num_beam)
                    expansions.append(probs * scores[m])
                    expansion_ids.append(token_ids)
                best_scores, best_flat = torch.topk(torch.cat(expansions),
                                                    k=num_beam - len(next_beam), dim=-1)
                for score, flat in zip(best_scores.tolist(), best_flat.tolist()):
                    slot, rank = divmod(flat, num_beam)
                    parent = live[slot]
                    next_beam.append(beam[parent] + [expansion_ids[slot][rank].item()])
                    next_scores.append(score)

            beam = next_beam
            scores = torch.tensor(next_scores, dtype=scores.dtype)
            if not live:
                # Every hypothesis is finished.
                break

    possibility,final_idx = torch.topk(scores, k=num_beam, dim=-1)

    return [beam[idx] for idx in final_idx.tolist()],possibility
          

In [10]:
def _polish_capitalize_pairs(words):
    """Turn 'tom mary' into [(' tom ', ' Tom '), (' mary ', ' Mary ')]."""
    return [(' ' + w + ' ', ' ' + w.capitalize() + ' ') for w in words.split()]


# Ordered (old, new) substitutions applied by polish().  Word rules only match
# words with a space on both sides; the punctuation rules come last because the
# earlier rules rely on those spaces still being present.
_POLISH_RULES = tuple(
    # given names and surnames
    _polish_capitalize_pairs(
        'tom lincoln mary anna alice emma helen eva lisa james john robert linda '
        'william david richard jason jose paul maria jerry jack louis joe justin '
        'mike henry benjamin betty smith steve susan jane sally julie')
    # titles
    + _polish_capitalize_pairs('mrs mr ms miss')
    + [(' wi fi ', ' Wi-Fi '), (' u. s. ', ' U.S. ')]
    # countries, nationalities and continents
    + _polish_capitalize_pairs('america canada chinese china japanese japan france french')
    + [(' u. k. ', ' U.K. ')]
    + _polish_capitalize_pairs(
        'english british england britain ireland scotland russia egypt greece '
        'germany german finland sweden norway iceland denmark poland austria '
        'switzerland monaco italy korea singapore indonesia iran mexico greenland '
        'australia australian brazil asia africa antarctica europe oceania')
    # contractions: re-attach a detached suffix, with or without its apostrophe
    + [(' ' + s + ' ', "'" + s + ' ') for s in ('m', 't', 's', 'd', 've', 'll', 're')]
    + [(" '" + s + ' ', "'" + s + ' ') for s in ('m', 't', 's', 'd', 've', 'll', 're')]
    # stand-alone letters
    + _polish_capitalize_pairs('u i z c b f g h j k l o p q r w x y')
    # remove the space in front of punctuation
    + [(' ?', '?'), (' .', '.'), (' ,', ','), (' !', '!')]
)


def polish(line):
    """Turn a space-joined, lower-case token string into display text.

    The line is stripped and its first letter capitalised, then the rules in
    ``_POLISH_RULES`` are applied in order: known proper nouns are capitalised,
    contraction suffixes are re-attached and the space before ``? . , !`` is
    removed.

    Args:
        line: tokens joined by single spaces.

    Returns:
        The polished string.
    """
    line = line.strip().capitalize()
    for old, new in _POLISH_RULES:
        line = line.replace(old, new)
    return line
    

In [11]:
def evaluate(st,max_len = 50):
    """Translate one Chinese sentence with greedy decoding.

    Uses the notebook-level ``encoder``, ``decoder``, ``trainset`` and ``device``.

    Args:
        st: the source sentence; it is iterated token by token (character by
            character for a string).  Tokens missing from the Chinese vocabulary
            are mapped to '<UNK>'.
        max_len: maximum number of tokens to generate.

    Returns:
        The generated English tokens joined by spaces, including the token for the
        leading id 0 and, if it was produced, the token for the closing id 1.
    """
    encoder.eval()
    decoder.eval()

    stoi = trainset.cn_stoi
    # Dictionary lookup instead of scanning the itos list for every token.
    t = [stoi[tk] if tk in stoi else stoi['<UNK>'] for tk in st]
    # Wrap in ids 0 ... 1 unless the caller already did; an empty input is wrapped too.
    if not t or t[0]!=0 :
        t = [0]+t+[1]

    out = [0]
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        t = torch.tensor(t).long().view(1,-1).to(device)
        memory = encoder(t)

        for _ in range(max_len):
            tgt_mask = get_decoder_mask(len(out)).to(device)
            ans = decoder(torch.tensor(out).long().view(1,-1).to(device),memory,tgt_mask=tgt_mask).to('cpu')
            # Most probable token after the last generated position.
            q = int(torch.argmax(ans.view(-1,ans.shape[-1])[-1],dim=-1))
            out.append(q)
            if q == 1 :
                break

    return ' '.join(trainset.en_itos[idx] for idx in out)
        

In [12]:
def beam_search_evaluate(st,max_len = 50,long_st=False,ret=False,beam_width=10):
    """Translate one Chinese sentence with beam search.

    Uses the notebook-level ``encoder``, ``decoder``, ``trainset`` and ``device``.

    Args:
        st: the source sentence; it is iterated token by token (character by
            character for a string).  Tokens missing from the Chinese vocabulary
            are mapped to '<UNK>'.
        max_len: maximum number of decoding steps handed to ``Beam_search``.
        long_st: when true (and ``ret`` is false) return the best hypothesis for
            the long-sentence pipeline instead of printing it.
        ret: when true, return the polished best translation as a string.
        beam_width: number of beams.

    Returns:
        * ``ret`` true: the polished best translation (str).
        * otherwise, ``long_st`` false: prints the polished best translation and
          returns every hypothesis as a list of tokens (first and last token included).
        * otherwise: ``(tokens, score)`` of the best hypothesis without its first
          and last token.
    """
    encoder.eval()
    decoder.eval()

    stoi = trainset.cn_stoi
    # Dictionary lookup instead of scanning the itos list for every token.
    t = [stoi[tk] if tk in stoi else stoi['<UNK>'] for tk in st]
    # Wrap in ids 0 ... 1 unless the caller already did; an empty input is wrapped too.
    if not t or t[0]!=0 :
        t = [0]+t+[1]

    # Inference only: skip autograd bookkeeping for the encoder pass.
    with torch.no_grad():
        t = torch.tensor(t).long().view(1,-1).to(device)
        memory = encoder(t)

    # max_len is forwarded rather than fixed at 50; the default is unchanged.
    outs,scores = Beam_search(decoder,memory,num_beam=beam_width,max_len=max_len)

    if ret:
        best = [trainset.en_itos[idx] for idx in outs[0][1:-1]]
        return polish(' '.join(best))

    if not long_st:
        outs = [[trainset.en_itos[idx] for idx in out]for out in outs]
        # Only the best hypothesis is printed; all of them are returned.
        if outs:
            print(polish(' '.join(outs[0][1:-1])))
        return outs

    best = [trainset.en_itos[idx] for idx in outs[0]]
    return best[1:-1],scores[0]

In [13]:
def long_st_beam_search(st,split1=False):
    """Translate a multi-sentence Chinese text one sentence at a time.

    The text is cut after every full-width terminator (！ 。 ？ ；); each piece is
    translated with ``beam_search_evaluate`` and polished.

    Args:
        st: the Chinese text (str).  If it does not end with a terminator, '。' is
            appended so the last piece is translated as well.
        split1: when true return the translations as a list, one entry per
            sentence; otherwise return them joined by single spaces.

    Returns:
        A list of strings if ``split1`` is true, else a single string.  Empty
        input gives ``[]`` / ``''``.
    """
    terminators = ('！', '。', '？', '；')

    if not st:
        return [] if split1 else ''
    # '；' counts as a terminator here too, so text ending in it does not get an
    # extra sentence consisting only of '。'.
    if st[-1] not in terminators:
        st = st+'。'

    # Cut after each terminator, keeping the terminator with its sentence.
    sentences = []
    head = 0
    for end, tk in enumerate(st, start=1):
        if tk in terminators:
            sentences.append(st[head:end])
            head = end

    translated = []
    for sentence in sentences:
        tokens, _score = beam_search_evaluate(sentence,long_st=True)
        translated.append(polish(' '.join(tokens)))

    if split1:
        return translated
    return ' '.join(translated)
    

In [22]:
# Translate a single sentence with the default settings; the function itself
# prints the polished best hypothesis.
beam_search_evaluate('今天是周一而不是周二。')
# beam_search_evaluate('树新的蜜蜂。')
# The call above is not the cell's last expression, so the list it returns is not echoed.
print()

Today is monday and not tuesday.



In [27]:
# Translate a multi-sentence text; split1=True returns one translation per
# sentence, which are printed one per line.
k = long_st_beam_search('只要事先将大的模型分解成一些小的模块,把每一个小模块都做好了,大模型就差不多了。当然每一个小的模型有时也会有问题,需要重新做,但此时仅仅需要修改这一部分,其他的部分并不受影响,此过程大概需要一天半的时间',split1=True)
for w in k:
    print(w)

So long as you break the big model in advance into smaller modules, do every small module well and the model is about the same.
Of course, every small model will sometimes have problems needing to be <unk>, but at this point only need to modify this section and the rest is not affected.


So long as you break the big model in advance into smaller modules, do every small module well and the model is about the same.
Of course, every small model will sometimes have problems needing to be <unk>, but at this point only need to modify this section and the rest is not affected.


In [56]:
# vocab03 = []
# with open('/Users/feisen/PycharmProjects/NLP03/vocab.txt') as f:
#     for line in f.readlines():
#         line = line.strip()
#         if line in trainset.en_stoi:
#             vocab03.append(trainset.en_stoi[line])
#         else:
#             vocab03.append(-1)

In [57]:
# wordvec = torch.zeros(len(vocab03),512)
# for idx,i in enumerate(vocab03):
#     if idx >= 0:
#         wordvec[idx] = decoder.embedding.weight.data[idx]

In [58]:
# wordvec = np.array(wordvec)
# with open('wordvec.txt','w') as f:
#     for line in wordvec:
#         for w in line:
#             f.write(str(w))
#             f.write(' ')
#         f.write('\n')
#     f.close()

In [68]:
# com300 = []
# with open('combined.csv') as f:
#     for line in f.readlines()[1:]:
#         line=line.strip().split(',')
#         com300.append(line)

In [69]:
# scores1 = []
# scores2 = []
# for tk1, tk2, score in com300:
#     if tk1.lower() in trainset.en_stoi and tk2.lower() in trainset.en_stoi:
#         v1 = decoder.embedding.weight.data[trainset.en_stoi[tk1.lower()]]
#         v2 = decoder.embedding.weight.data[trainset.en_stoi[tk2.lower()]]
#         s = torch.sum(v1*v2)/torch.sum(torch.sqrt(v1**2*v2**2))
#         scores1.append(s)
#         scores2.append(score)
    

In [70]:
# array = np.array(scores1,dtype=float)
# data2 = np.array(scores2,dtype=float)

# xy = np.sum(array*data2)/len(array)
# x = np.sum(array)/len(array)
# y = np.sum(data2)/len(array)
# x2 = np.sum(array*array)/len(array)
# y2 = np.sum(data2*data2)/len(array)

# cov = xy-x*y
# varX = x2-x*x
# varY = y2-y*y
# Correlation_coefficient = cov/np.sqrt(varX*varY)
# print(Correlation_coefficient)

0.1852576552300105


IndexError: index 9861 is out of bounds for dimension 0 with size 6457

In [71]:
# decoder.embedding.weight.data.shape

torch.Size([25769, 512])