In [1]:
import torch
import re

In [2]:
class EncoderGru(torch.nn.Module):
    """Single-layer GRU encoder fed through an embedding table.

    Args:
        word_size: Number of rows in the embedding table (vocabulary size).
        hidden_size: Width of both the embedding vectors and the GRU state.
    """

    def __init__(self, word_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Embedding width equals the GRU input width, so the embedded token
        # can be handed to the GRU without a projection.
        self.embedding_layer = torch.nn.Embedding(word_size, hidden_size)
        self.gru = torch.nn.GRU(hidden_size, hidden_size)

    def forward(self, input_vector, hidden):
        """Embed ``input_vector`` and run the GRU starting from ``hidden``.

        A new axis is inserted at position 1 of the embedded tensor before it
        reaches the GRU.

        Returns:
            The ``(output, hidden)`` pair produced by the GRU.
        """
        gru_input = self.embedding_layer(input_vector).unsqueeze(1)
        return self.gru(gru_input, hidden)

    def init_hidden(self):
        """Return an all-zero tensor of shape ``(1, 1, hidden_size)``."""
        return torch.zeros(1, 1, self.hidden_size)

In [41]:
class AttenDecoder(torch.nn.Module):
    """GRU decoder with learned attention over a fixed number of encoder slots.

    Args:
        word_szie: Size of the target vocabulary (embedding rows and number of
            output logits). The spelling is kept so existing callers that pass
            it by keyword keep working.
        hidden_size: Width of the embeddings, the GRU state and each encoder
            output vector.
        max_length: Number of encoder slots the attention layer scores. The
            encoder-output tensor passed to ``forward`` must flatten to exactly
            this many rows. Defaults to 50, the value that was previously
            hard-coded.
    """

    def __init__(self, word_szie, hidden_size, max_length=50):
        super(AttenDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.max_length = max_length
        self.embedding_layer = torch.nn.Embedding(word_szie, hidden_size)
        self.atten_layer = torch.nn.Linear(hidden_size * 2, max_length)
        self.atten_combine_layer = torch.nn.Linear(hidden_size * 2, hidden_size)
        self.gru = torch.nn.GRU(hidden_size, hidden_size)
        self.last_layer = torch.nn.Linear(hidden_size, word_szie)

    def forward(self, input_vector, hidden, encoder_output):
        """Run one decoding step.

        Args:
            input_vector: 1-D LongTensor of token indices; its embedding must
                be 2-D so it can be concatenated with ``hidden[0]``.
            hidden: GRU state; ``hidden[0]`` is concatenated with the embedding
                to compute the attention scores.
            encoder_output: Encoder outputs; reshaped to
                ``(max_length, hidden_size)`` before being weighted.

        Returns:
            ``(out, hid)``: unnormalised vocabulary logits for this step and
            the updated GRU state.
        """
        embeded = self.embedding_layer(input_vector)
        contacted = torch.cat((embeded, hidden[0]), dim=1)
        # Normalise the scores so the context vector is a weighted average of
        # the encoder outputs rather than an unbounded linear combination.
        atten = torch.softmax(self.atten_layer(contacted), dim=1)
        # Use the configured hidden width instead of a literal 4 so the module
        # works for any hidden_size.
        atten_apply = torch.mm(atten, encoder_output.reshape(-1, self.hidden_size))

        atten_in = torch.cat((embeded, atten_apply), dim=1)

        # Add the leading sequence axis of length 1 that the GRU expects.
        gru_in = self.atten_combine_layer(atten_in).unsqueeze(0)

        out, hid = self.gru(gru_in, hidden)

        out = self.last_layer(out[0])
        return out, hid

In [45]:
# Build the encoder/decoder pair with a hidden width of 4 and give each model
# its own SGD optimizer (learning rate 0.001).
# NOTE(review): `english` and `chinese` are created in a cell that appears
# further down in this notebook, so that cell must be run before this one.
ecncoder=EncoderGru(english.n_word,4)
atten_decorder=AttenDecoder(chinese.n_word,4)
encoder_optimizer=torch.optim.SGD(ecncoder.parameters(),lr=0.001)
decoder_optimizer=torch.optim.SGD(atten_decorder.parameters(),lr=0.001)
# Cross-entropy loss object, applied to the decoder output elsewhere in the notebook.
loss_f=torch.nn.CrossEntropyLoss()

In [51]:
# Push one sentence pair through the encoder and the attention decoder and
# accumulate the cross-entropy loss over the target tokens.
inputs= sentence2tensor(pairs[5000][0],english)
outs= sentence2tensor(pairs[5000][1],chinese)
print(inputs)

# One slot per source position; unused slots stay zero. The 50 must match the
# number of positions the decoder's attention layer scores.
encoder_outs=torch.zeros((50,1,1,ecncoder.hidden_size),dtype=torch.float)
# Carry the hidden state from token to token. Restarting from zeros on every
# step would make each encoder output depend on a single word only.
hidden=ecncoder.init_hidden()
for i in range(inputs.shape[0]):
    encoder_outs[i],hidden=ecncoder(inputs[i],hidden)

loss=0
# The decoder starts from the encoder's final state and the SOS token (index 0).
de_hid=hidden
de_input=torch.tensor([0],dtype=torch.long)
for i in range(outs.shape[0]):
    out,de_hid=atten_decorder(de_input,de_hid,encoder_outs)
    loss+=loss_f(out,outs[i])
    # Teacher forcing: the next step is fed the current ground-truth token.
    de_input=outs[i]
print(loss)

tensor([[ 187],
        [  62],
        [  41],
        [  15],
        [2192],
        [  16]])
torch.Size([1, 4]) torch.Size([1, 4])
torch.Size([1, 4]) torch.Size([1, 4])
torch.Size([1, 4]) torch.Size([1, 4])
torch.Size([1, 4]) torch.Size([1, 4])
torch.Size([1, 4]) torch.Size([1, 4])
torch.Size([1, 4]) torch.Size([1, 4])
torch.Size([1, 4]) torch.Size([1, 4])
torch.Size([1, 4]) torch.Size([1, 4])
tensor(64.3966, grad_fn=<AddBackward0>)


In [4]:
# Training-loop skeleton: 10 passes over every sentence pair.
for i in range(10):
    for p in pairs:
        # Clear gradients left over from the previous pair.
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        # sentence2tensor requires the vocabulary for each side; calling it
        # with the sentence alone raises TypeError.
        inputs= sentence2tensor(p[0],english)
        target= sentence2tensor(p[1],chinese)
        # TODO: forward pass, loss.backward() and the optimizer steps are not
        # implemented in this cell yet.

In [6]:
def read_pairs(path='F:/Github/machine_learn_record/pytorch/data/cmn.txt'):
    """Read a tab-separated parallel corpus and normalise its first column.

    Each non-blank line is split on tabs into a list of fields. Field 0 is
    lower-cased and stripped, each of ``. ! ?`` gets a space inserted before
    it, and every run of other non-letter characters is collapsed to a single
    space. The remaining fields are returned untouched.

    Args:
        path: UTF-8 text file to read. The default is an absolute path on the
            original author's machine; pass your own path elsewhere.

    Returns:
        A list with one list of fields per non-blank line.
    """
    # The context manager closes the file even if reading fails.
    with open(path,encoding='utf-8') as file:
        content=file.read()

    pairs=[]
    for line in content.split('\n'):
        # Skip blank lines (normally just the empty string after the final
        # newline) instead of blindly popping the last entry, which would drop
        # a real pair when the file has no trailing newline.
        if not line.strip():
            continue
        pairs.append(line.split('\t'))

    # Normalise punctuation and symbols on the English side.
    for p in pairs:
        es=p[0].lower().strip()
        es= re.sub(r"([.!?])", r" \1", es)
        es = re.sub(r"[^a-zA-Z.!?]+", r" ", es)
        p[0]=es
    return pairs

In [7]:
# Load the sentence pairs from read_pairs' default corpus path.
pairs=read_pairs()

In [8]:
class Lang:
    """Vocabulary for one language: word/index maps plus occurrence counts.

    Indices 0 and 1 are reserved for the ``SOS`` and ``EOS`` markers, so the
    first real word gets index 2.

    Args:
        name: Language tag. ``'en'`` makes ``add_sentence`` split on single
            spaces; any other value makes it iterate character by character.
    """

    def __init__(self,name):
        self.name=name
        # The reserved markers belong in index -> word. Putting them in
        # word2index would insert the integers 0 and 1 as if they were words
        # and leave index2word without entries for them.
        self.index2word={0: "SOS", 1: "EOS"}
        self.word2index={}
        self.word2count={}
        self.n_word=2

    def add_sentence(self,sentence):
        """Register every token of ``sentence`` in the vocabulary."""
        if self.name=='en':
            # Must stay split(' ') so it tokenises exactly like sentence2tensor.
            for w in sentence.split(' '):
                self.add_word(w)
        else:
            for w in sentence:
                self.add_word(w)

    def add_word(self,word):
        """Assign the next free index to a new word, or bump its count."""
        if word not in self.word2index:
            self.index2word[self.n_word]=word
            self.word2index[word]=self.n_word
            self.word2count[word]=1
            self.n_word+=1
        else:
            self.word2count[word]+=1

In [13]:
# Create one vocabulary per language and fill both from every loaded pair.
chinese, english = Lang('cn'), Lang('en')
for pair in pairs:
    # Field 1 feeds the 'cn' vocabulary, field 0 the 'en' vocabulary.
    chinese.add_sentence(pair[1])
    english.add_sentence(pair[0])

In [18]:
def sentence2tensor(sentence,lang):
    """Convert a sentence to a column tensor of vocabulary indices.

    Args:
        sentence: Text to encode.
        lang: Object exposing ``name`` and ``word2index``. When ``name`` is
            ``'en'`` the sentence is split on single spaces; otherwise every
            character is treated as one token.

    Returns:
        A ``torch.long`` tensor reshaped with ``view(-1, 1)``, one row per token.

    Raises:
        KeyError: If a token is missing from ``lang.word2index``.
    """
    tokens = sentence.split(' ') if lang.name == 'en' else sentence
    indices = [lang.word2index[token] for token in tokens]
    return torch.tensor(indices, dtype=torch.long).view(-1, 1)

In [34]:
# Peek at one loaded pair (index 500) to check the normalisation.
pairs[500]

['please leave .', '請你離開。']

In [35]:
# Encode a sample English sentence with the English vocabulary.
# NOTE(review): the saved output below is 1-D, while sentence2tensor as defined
# in this notebook ends with view(-1, 1) — the output looks stale; re-run to confirm.
sentence2tensor('please leave .',english)

tensor([158, 100,   3])

In [36]:
# Encode a sample Chinese sentence (one token per character) with the Chinese vocabulary.
sentence2tensor('請你離開。',chinese)

tensor([536,   4, 537,  65,   3])