In [None]:
import torch
import re
import random
import matplotlib.pyplot as plt
import numpy as np

In [None]:
class EncoderGru(torch.nn.Module):
    """Embedding layer followed by a single GRU, used as the sequence encoder.

    Attributes:
        hidden_size: width shared by the embedding vectors and the GRU state.
        embedding_layer: lookup table with ``word_size`` rows of ``hidden_size``.
        gru: ``torch.nn.GRU`` mapping ``hidden_size`` inputs to ``hidden_size`` state.
    """

    def __init__(self,word_size,hidden_size):
        super().__init__()
        self.hidden_size=hidden_size
        self.embedding_layer=torch.nn.Embedding(num_embeddings=word_size,embedding_dim=hidden_size)
        self.gru=torch.nn.GRU(input_size=hidden_size,hidden_size=hidden_size)

    def forward(self,input_vector,hidden):
        """Embed ``input_vector`` and push it through the GRU.

        Args:
            input_vector: tensor of token indices fed to the embedding layer.
            hidden: initial GRU state (``init_hidden`` produces a matching one).

        Returns:
            The ``(output, hidden)`` pair produced by the GRU.
        """
        # A length-1 axis is inserted at position 1 before the GRU call.
        gru_input=self.embedding_layer(input_vector).unsqueeze(1)
        return self.gru(gru_input,hidden)

    def init_hidden(self):
        """Return an all-zero state of shape ``(1, 1, hidden_size)`` (created on the CPU)."""
        return torch.zeros((1,1,self.hidden_size))

In [None]:
class AttenDecoder(torch.nn.Module):
    """GRU decoder with a learned attention over a fixed number of encoder slots.

    The attention weights are computed from the current input embedding and the
    previous hidden state only (a linear layer followed by a softmax), then used
    to mix the encoder outputs into a single context vector.

    Args:
        word_szie: size of the output vocabulary (spelling kept so existing
            keyword callers keep working).
        hidden_size: width of the embeddings, GRU state and encoder outputs.
        max_length: number of encoder slots the attention layer scores.
            ``encoder_output`` passed to ``forward`` must hold exactly this
            many vectors. Defaults to 50, the value previously hard-coded.
    """

    def __init__(self,word_szie,hidden_size,max_length=50):
        super(AttenDecoder,self).__init__()
        self.hidden_size=hidden_size
        self.max_length=max_length
        self.embedding_layer=torch.nn.Embedding(word_szie,hidden_size)
        self.atten_layer=torch.nn.Linear(hidden_size*2,max_length)
        self.atten_combine_layer=torch.nn.Linear(hidden_size*2,hidden_size)
        self.gru=torch.nn.GRU(hidden_size,hidden_size)
        self.last_layer=torch.nn.Linear(hidden_size,word_szie)

    def forward(self,input_vector,hidden,encoder_output):
        """Run one decoding step.

        Args:
            input_vector: index tensor for the previous output token.
            hidden: previous GRU state; ``hidden[0]`` is concatenated with the
                embedding, so it must have the same number of rows.
            encoder_output: tensor reshaped to ``(max_length, hidden_size)``.

        Returns:
            ``(out, hid, atten)``: unnormalised vocabulary scores, the new GRU
            state, and the softmax attention weights over the encoder slots.
        """
        embeded=self.embedding_layer(input_vector)
        contacted=torch.cat((embeded,hidden[0]),dim=1)
        atten=torch.nn.functional.softmax(self.atten_layer(contacted),dim=1)
        # Use the configured width instead of a literal 256 so that any
        # hidden_size works, not just the one the notebook happens to use.
        atten_apply=torch.mm(atten,encoder_output.reshape(-1,self.hidden_size))

        atten_in=torch.cat((embeded,atten_apply),dim=1)

        # NOTE(review): the original applied ReLU to `atten` here, which is a
        # no-op on softmax output and has been dropped. The canonical attention
        # decoder applies ReLU to the combined GRU input instead; that is left
        # unchanged because it would alter the behaviour of trained weights.
        gru_in=self.atten_combine_layer(atten_in).unsqueeze(0)

        out,hid=self.gru(gru_in,hidden)

        out=self.last_layer(out[0])
        return out,hid,atten

In [None]:
def read_pairs(path='F:/Github/machine_learn_record/pytorch/data/cmn.txt'):
    """Load tab-separated sentence pairs and normalise the English column.

    Each non-blank line is split on tabs; the first field is treated as the
    English sentence and is lower-cased, has ``.``, ``!`` and ``?`` separated
    from the preceding word by a space, and has every other run of non-letter
    characters collapsed into a single space. Remaining fields are kept as-is.

    Args:
        path: location of the UTF-8 encoded pairs file.

    Returns:
        A list with one list of fields per non-blank line.
    """
    # `with` closes the handle even if reading fails.
    with open(path,encoding='utf-8') as file:
        content=file.read()

    pairs=[]
    for line in content.split('\n'):
        # Skip blank lines (including the empty string after a trailing
        # newline) instead of unconditionally dropping the last entry, which
        # lost a real pair whenever the file did not end with a newline.
        if not line.strip():
            continue
        fields=line.split('\t')

        # Normalise punctuation in the English sentence.
        es=fields[0].lower().strip()
        es=re.sub(r"([.!?])", r" \1", es)
        es=re.sub(r"[^a-zA-Z.!?]+", r" ", es)
        fields[0]=es
        pairs.append(fields)
    return pairs

In [None]:
# Load every sentence pair from the default path baked into read_pairs.
pairs=read_pairs()
# Bare expression: the notebook displays this entry as a quick sanity check.
pairs[1]

In [None]:
class Lang:
    """Vocabulary for one language: token <-> id tables plus token counts.

    Ids 0 and 1 are reserved for the ``SOS`` and ``EOS`` markers, so the first
    real token receives id 2.

    Attributes:
        name: language tag; ``'en'`` sentences are split on spaces, any other
            name is tokenised character by character.
        index2word: id -> token, pre-seeded with the two reserved markers.
        word2index: token -> id for every token seen so far.
        word2count: token -> number of times it was added.
        n_word: next free id (also the vocabulary size including markers).
    """

    def __init__(self,name):
        self.name=name
        # The reserved markers belong in the id -> token table. They used to be
        # placed in word2index, which left index2word without ids 0 and 1 and
        # made decoding those ids raise KeyError.
        self.index2word={0: "SOS", 1: "EOS"}
        self.word2index={}
        self.word2count={}
        self.n_word=2

    def add_sentence(self,sentence):
        """Add every token of ``sentence`` to the vocabulary."""
        tokens=sentence.split(' ') if self.name=='en' else sentence
        for w in tokens:
            self.add_word(w)

    def add_word(self,word):
        """Register ``word`` under the next free id, or bump its count if known."""
        if word not in self.word2index:
            self.index2word[self.n_word]=word
            self.word2index[word]=self.n_word
            self.word2count[word]=1
            self.n_word+=1
        else:
            self.word2count[word]+=1

In [None]:
# Build one vocabulary per language from the loaded pairs:
# column 0 of each pair feeds the English table, column 1 the Chinese one.
english=Lang('en')
chinese=Lang('cn')
for p in pairs:
    english.add_sentence(p[0])
    chinese.add_sentence(p[1])

In [None]:
def sentence2tensor(sentence,lang,device=None):
    """Convert a sentence into a column tensor of token ids ending with EOS.

    Args:
        sentence: text to encode. For ``lang.name == 'en'`` it is split on
            single spaces; otherwise every character is one token.
        lang: vocabulary object exposing ``name`` and ``word2index``.
        device: target device. When omitted, CUDA is used if available and the
            CPU otherwise, so the function also works on CPU-only machines.

    Returns:
        A ``torch.long`` tensor of shape ``(n_tokens + 1, 1)``; the final row
        is id 1 (EOS).

    Raises:
        KeyError: if a token is missing from ``lang.word2index``.
    """
    if device is None:
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokens=sentence.split(' ') if lang.name=='en' else sentence
    idxs=[lang.word2index[w] for w in tokens]
    idxs.append(1)  # id 1 is the EOS marker
    return torch.tensor(idxs,dtype=torch.long,device=device).view(-1,1)

In [None]:
def train(encoder,decoder,encoder_optimizer,decoder_optimizer,loss_f,inputs,outs):
    """Run one optimisation step on a single (source, target) sentence pair.

    Args:
        encoder: module exposing ``hidden_size``, ``init_hidden()`` and a
            forward taking ``(token, hidden)``.
        decoder: module whose forward takes ``(token, hidden, encoder_outs)``
            and returns ``(scores, hidden, attention)``.
        encoder_optimizer: optimizer stepping the encoder parameters.
        decoder_optimizer: optimizer stepping the decoder parameters.
        loss_f: criterion called as ``loss_f(scores, target_token)``.
        inputs: source token ids, one token per row.
        outs: target token ids, one token per row.

    Returns:
        The summed loss divided by the number of target tokens, as a float.
    """
    # Work on the device the encoder weights live on rather than depending on
    # a notebook-level `device` global.
    device=next(encoder.parameters()).device

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Fixed 50 slots: sources longer than that do not fit.
    encoder_outs=torch.zeros((50,1,1,encoder.hidden_size),dtype=torch.float,device=device)
    for i in range(inputs.shape[0]):
        # NOTE(review): the hidden state is re-initialised for every token, so
        # each encoder output only depends on its own token. Carrying the
        # state across steps would have to be changed in train/predict/
        # translate together; left as-is here.
        encoder_outs[i],hidden=encoder(inputs[i],encoder.init_hidden().to(device))

    loss=0

    decoder_in=torch.tensor([0],dtype=torch.long,device=device)  # id 0 = SOS
    for i in range(outs.shape[0]):
        # Use the decoder that was passed in; this used to reach for the
        # global `atten_decorder` and silently ignore the argument.
        out,hidden,atten=decoder(decoder_in,hidden,encoder_outs)
        topv, topi = out.data.topk(1)
        # Feed back the model's own top prediction about half of the time,
        # otherwise the ground-truth token.
        decoder_in=topi.squeeze(0).detach() if random.random()>0.5 else outs[i]
        loss+=loss_f(out,outs[i])
    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()
    return loss.item()/outs.shape[0]

In [None]:
def model_train(encoder,atten_decorder,n_samples=15000,lr=0.0001,log_every=100):
    """Train on randomly drawn sentence pairs and print a running average loss.

    Fresh Adam optimizers are created on every call. Pairs are sampled with
    replacement from the notebook-level ``pairs`` list and encoded with the
    ``english`` / ``chinese`` vocabularies.

    Args:
        encoder: encoder module to optimise.
        atten_decorder: attention decoder module to optimise.
        n_samples: number of pairs drawn for this call.
        lr: learning rate for both optimizers.
        log_every: number of training steps averaged per printed line.
    """
    # Build the tensors where the model lives so CPU-only runs work too.
    device=next(encoder.parameters()).device
    encoder_optimizer=torch.optim.Adam(encoder.parameters(),lr=lr)
    decoder_optimizer=torch.optim.Adam(atten_decorder.parameters(),lr=lr)
    loss_f=torch.nn.CrossEntropyLoss()
    datas=[random.choice(pairs) for _ in range(n_samples)]
    running_loss=0
    # Count steps from 1 so that every report averages exactly `log_every`
    # losses; reporting at step 0 used to divide a single loss by 100.
    for i,p in enumerate(datas,start=1):
        inputs= sentence2tensor(p[0],english,device)
        outs= sentence2tensor(p[1],chinese,device)
        running_loss+=train(encoder,atten_decorder,encoder_optimizer,decoder_optimizer,loss_f,inputs,outs)
        if i%log_every==0:
            print('{} epoc avg loss is {}'.format(i,running_loss/log_every))
            running_loss=0

In [None]:
# Prefer the GPU when torch can see one; otherwise everything stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Both networks use a width of 256; vocabulary sizes come from the Lang tables.
# Module.to returns the module itself, so the encoder can be moved inline.
encoder=EncoderGru(english.n_word,256).to(device)
atten_decorder=AttenDecoder(chinese.n_word,256)
# Kept as a bare expression so the notebook still displays the decoder layout.
atten_decorder.to(device)

In [None]:
# Repeat the sampled training pass 100 times; model_train builds its own
# optimizers and draws its own random batch of pairs on every call.
for i in range(100):
    model_train(encoder,atten_decorder)

In [None]:
def predict(p,ecncoder,atten_decorder):
    """Greedily decode the source sentence of a pair and collect attention rows.

    Args:
        p: sentence pair; only ``p[0]`` (the English sentence) is read.
        ecncoder: encoder module (parameter spelling kept for compatibility).
        atten_decorder: attention decoder module.

    Returns:
        ``(rs, attens)``: the list of decoded target tokens (EOS excluded) and
        a NumPy array holding one attention row per decoded token.
    """
    # Follow the model's device instead of a notebook-level `device` global.
    device=next(ecncoder.parameters()).device
    inputs= sentence2tensor(p[0],english,device)

    rs=[]
    attens=[]
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        encoder_outs=torch.zeros((50,1,1,ecncoder.hidden_size),dtype=torch.float,device=device)
        for i in range(inputs.shape[0]):
            # Use the encoder that was passed in; this used to read the global
            # `encoder` and ignore the argument. The hidden state is reset per
            # token to match how the model is trained.
            encoder_outs[i],hidden=ecncoder(inputs[i],ecncoder.init_hidden().to(device))

        decoder_in=torch.tensor([0],dtype=torch.long,device=device)  # id 0 = SOS
        for i in range(50):
            out,hidden,atten=atten_decorder(decoder_in,hidden,encoder_outs)
            topv, topi = out.topk(1)
            decoder_in=topi.squeeze(0).detach()
            token_id=topi[0].item()
            if token_id==1:  # id 1 = EOS: stop decoding
                break
            attens.append(atten.cpu().detach().numpy()[0])
            rs.append(chinese.index2word[token_id])
    return rs,np.asarray(attens)

In [None]:
# Decode 50 randomly chosen pairs, plot their attention and print the result.
# NOTE: showAttention is defined in a later cell of this notebook; that cell
# has to be executed before this one on a fresh kernel.
for i in range(50):
    # Sample over the pairs that actually exist; a fixed upper bound of 10000
    # can index past the end of `pairs` or ignore most of a larger dataset.
    a=random.randrange(len(pairs))
    rs,atten=predict(pairs[a],encoder,atten_decorder)
    # With no decoded token the attention array is empty and 1-D, so there is
    # nothing to slice or draw.
    if rs:
        showAttention(pairs[a][0],rs,atten[:,:len(pairs[a][0].split(' '))])
    print("input: {}\ntarget: {}\npredict: {}".format(pairs[a][0],pairs[a][1],''.join(rs)))

In [None]:
def showAttention(input_sentence, output_words, attentions):
    """Draw an attention matrix with source words on x and output words on y.

    Args:
        input_sentence: source sentence; split on single spaces for x labels,
            followed by an ``<EOS>`` label if the matrix has a column for it.
        output_words: decoded tokens, one per matrix row.
        attentions: 2-D array-like with one row per output token and one
            column per source position.

    Returns:
        None. The figure is shown with ``plt.show()``. Nothing is drawn when
        ``attentions`` is empty or not two-dimensional.
    """
    attentions = np.asarray(attentions)
    if attentions.ndim != 2 or attentions.size == 0:
        return
    n_rows, n_cols = attentions.shape

    # Set up figure with colorbar
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions, cmap='bone')
    fig.colorbar(cax)

    # Set up axes. Tick positions are pinned to one per matrix cell before the
    # labels are assigned; with the automatic locator the labels do not line
    # up with the cells. Labels are trimmed to the matrix size so the ticks
    # never extend the axes beyond the image.
    x_labels = (input_sentence.split(' ') + ['<EOS>'])[:n_cols]
    y_labels = list(output_words)[:n_rows]
    ax.set_xticks(list(range(len(x_labels))))
    ax.set_xticklabels(x_labels, rotation=90)
    ax.set_yticks(list(range(len(y_labels))))
    ax.set_yticklabels(y_labels)

    plt.show()

In [None]:
def translate(sentence,ecncoder,atten_decorder):
    """Greedily translate an English sentence and print the result.

    Args:
        sentence: space-separated English tokens; every token must already be
            in the ``english`` vocabulary, otherwise encoding raises KeyError.
        ecncoder: encoder module (parameter spelling kept for compatibility).
        atten_decorder: attention decoder module.

    Returns:
        None. The source sentence and the decoded text are printed.
    """
    # Follow the model's device instead of a notebook-level `device` global.
    device=next(ecncoder.parameters()).device
    inputs= sentence2tensor(sentence,english,device)

    rs=[]
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        encoder_outs=torch.zeros((50,1,1,ecncoder.hidden_size),dtype=torch.float,device=device)
        for i in range(inputs.shape[0]):
            # Use the encoder that was passed in; this used to read the global
            # `encoder` and ignore the argument. The hidden state is reset per
            # token to match how the model is trained.
            encoder_outs[i],hidden=ecncoder(inputs[i],ecncoder.init_hidden().to(device))

        decoder_in=torch.tensor([0],dtype=torch.long,device=device)  # id 0 = SOS
        for i in range(50):
            out,hidden,atten=atten_decorder(decoder_in,hidden,encoder_outs)
            topv, topi = out.data.topk(1)
            decoder_in=topi.squeeze(0).detach()
            token_id=topi[0].item()
            if token_id==1:  # id 1 = EOS: stop decoding
                break
            rs.append(chinese.index2word[token_id])
    print("in------>",sentence,"predict--->","".join(rs))

In [337]:
# Smoke test on a hand-written sentence; translate prints the source and the
# decoded text itself, so the call needs no surrounding print.
translate('tom dead .',encoder,atten_decorder)

in------> tom dead . predict---> 汤姆死了。
