In [94]:
import torch
import re
import random

In [2]:
class EncoderGru(torch.nn.Module):
    def __init__(self,word_size,hidden_size):
        super(EncoderGru,self).__init__()
        self.hidden_size=hidden_size
        self.embedding_layer=torch.nn.Embedding(word_size,hidden_size)
        self.gru=torch.nn.GRU(hidden_size,hidden_size)
    
    def forward(self,input_vector,hidden):
        embedded=self.embedding_layer(input_vector)
        embedded=embedded.unsqueeze(1)
        out,hid=self.gru(embedded,hidden)
        return out,hid
    
    def init_hidden(self):
        return torch.zeros(1,1,self.hidden_size)

In [129]:
class AttenDecoder(torch.nn.Module):
    def __init__(self,word_szie,hidden_size):
        super(AttenDecoder,self).__init__()
        self.hidden_size=hidden_size
        self.embedding_layer=torch.nn.Embedding(word_szie,hidden_size)
        self.atten_layer=torch.nn.Linear(hidden_size*2,50)
        self.atten_combine_layer=torch.nn.Linear(hidden_size*2,hidden_size)
        self.gru=torch.nn.GRU(hidden_size,hidden_size)
        self.last_layer=torch.nn.Linear(hidden_size,word_szie)
        
    def forward(self,input_vector,hidden,encoder_output):
        embeded=self.embedding_layer(input_vector)
        contacted=torch.cat((embeded,hidden[0]),dim=1)
        atten=self.atten_layer(contacted)
        atten_apply=torch.mm(atten,encoder_output.view(-1,256))
        
        atten_in=torch.cat((embeded,atten_apply),dim=1)
        
        gru_in=self.atten_combine_layer(atten_in)
        gru_in=gru_in.unsqueeze(0)
        
        out,hid=self.gru(gru_in,hidden)
        
        out=self.last_layer(out[0])
        return out,hid

In [131]:
def read_pairs(path='F:/Github/machine_learn_record/pytorch/data/cmn.txt'):
    file=open(path,encoding='utf-8')
    content=file.read()
    pairs=[]
    for p in content.split('\n'):
        temp=p.split('\t')
        pairs.append(temp)
        
    #调整英语中的符号
    for p in pairs:
        es=p[0].lower().strip()
        es= re.sub(r"([.!?])", r" \1", es)
        es = re.sub(r"[^a-zA-Z.!?]+", r" ", es)
        p[0]=es
    #删除最后一行
    pairs.pop()
    return pairs

In [132]:
pairs=read_pairs()

In [133]:
class Lang:
    def __init__(self,name):
        self.name=name
        self.index2word={}
        self.word2index={0: "SOS", 1: "EOS"}
        self.word2count={}
        self.n_word=2
    
    def add_sentence(self,sentence):
        if self.name=='en':
            for w in sentence.split(' '):
                self.add_word(w)
        else:
            for w in sentence:
                self.add_word(w)
                
    def add_word(self,word):
        if word not in self.word2index:
            self.index2word[self.n_word]=word
            self.word2index[word]=self.n_word
            self.word2count[word]=1
            self.n_word+=1
        else:
            self.word2count[word]+=1

In [134]:
chinese=Lang('cn')
english=Lang('en')
for p in pairs:
    chinese.add_sentence(p[1])
    english.add_sentence(p[0])

In [162]:
def sentence2tensor(sentence,lang,device=torch.device("cuda")):
    idxs=[]
    if lang.name=='en':
        for w in sentence.split(' '):
            idxs.append(lang.word2index[w])
    else:
        for w in sentence:
            idxs.append(lang.word2index[w])
    idxs.append(1)
    tnr=torch.tensor(idxs,dtype=torch.long)
    return tnr.view(-1,1).to(device)

In [243]:
def train(encoder,decoder,encoder_optimizer,decoder_optimizer,loss_f,inputs,outs):
    encoder_outs=torch.zeros((50,1,1,256),dtype=torch.float).to(device)
    for i in range(inputs.shape[0]):
        encoder_outs[i],hidden=encoder.forward(inputs[i],encoder.init_hidden().to(device))

    loss=0
    
    decoder_in=torch.tensor([0],dtype=torch.long).to(device)
    for i in range(outs.shape[0]):
        out,hidden=atten_decorder.forward(decoder_in,hidden,encoder_outs)
        topv, topi = out.data.topk(1)
        decoder_in=topi.squeeze(0).detach() if random.random()>0.5 else outs[i]
        loss+=loss_f(out,outs[i])
    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()
    return loss.item()/outs.shape[0]

In [244]:
def model_train(encoder,atten_decorder):
    encoder_optimizer=torch.optim.SGD(ecncoder.parameters(),lr=0.001)
    decoder_optimizer=torch.optim.SGD(atten_decorder.parameters(),lr=0.001)
    loss_f=torch.nn.CrossEntropyLoss()
    datas=[random.choice(pairs) for i in range(15000)]
    loss1000=0
    for i,p in enumerate(datas):
        inputs= sentence2tensor(p[0],english)
        outs= sentence2tensor(p[1],chinese)
        loss=train(encoder,atten_decorder,encoder_optimizer,decoder_optimizer,loss_f,inputs,outs)
        loss1000+=loss
        print(loss1000)
        if i%1000==999:
            print('{} epoc avg loss is {}'.format(i,loss1000/1000))
            loss1000=0

In [245]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder=EncoderGru(english.n_word,256)
atten_decorder=AttenDecoder(chinese.n_word,256)
encoder.to(device)
atten_decorder.to(device)

AttenDecoder(
  (embedding_layer): Embedding(3439, 256)
  (atten_layer): Linear(in_features=512, out_features=50, bias=True)
  (atten_combine_layer): Linear(in_features=512, out_features=256, bias=True)
  (gru): GRU(256, 256)
  (last_layer): Linear(in_features=256, out_features=3439, bias=True)
)

In [251]:
model_train(encoder,atten_decorder)

999 epoc avg loss is 2714.9952819870914
1999 epoc avg loss is 2892.429995227221
2999 epoc avg loss is 3074.048898804599
3999 epoc avg loss is 3255.610831805999
4999 epoc avg loss is 3465.2557185724004
5999 epoc avg loss is 3646.1713126738805
6999 epoc avg loss is 3787.6926430907843
7999 epoc avg loss is 3983.5497674811104
8999 epoc avg loss is 4100.21869844775
9999 epoc avg loss is 4326.269555655717
10999 epoc avg loss is 4564.293657520129
11999 epoc avg loss is 4758.90573353196
12999 epoc avg loss is 4941.175668666164
13999 epoc avg loss is 5126.419606660099
14999 epoc avg loss is 5312.339017648663


In [261]:
def predict(p,ecncoder,atten_decorder):
    inputs= sentence2tensor(p[0],english)
    outs= sentence2tensor(p[1],chinese)
    encoder_outs=torch.zeros((50,1,1,256),dtype=torch.float).to(device)
    for i in range(inputs.shape[0]):
        encoder_outs[i],hidden=encoder.forward(inputs[i],encoder.init_hidden().to(device))

    loss=0
    decoder_in=torch.tensor([0],dtype=torch.long).to(device)
    for i in range(outs.shape[0]):
        out,hidden=atten_decorder.forward(decoder_in,hidden,encoder_outs)
        topv, topi = out.data.topk(1)
        decoder_in=topi.squeeze(0).detach()
        print(chinese.index2word[topi[0].item()])
    print(p[0],p[1])

predict(pairs[56],ecncoder,atten_decorder)


成
好
上
土
skip it . 不管它。


In [242]:
chinese.index2word[1070]

'收'