In [89]:
import torch
from torch import nn 
from torch.optim import Adam

In [51]:
class RNN(nn.Module):

    def __init__(self,dim,output_size):
        super().__init__()
        self.dim = dim
        self.Wx = nn.Linear(dim,dim)
        self.Wh = nn.Linear(dim,dim)
        self.sigmoid = nn.Sigmoid()
        self.Wy = nn.Linear(dim,output_size)
        
    def hiddend_state(self):
        return torch.zeros(1, self.dim)
        
    def forward(self,x,h):
        h = self.sigmoid(self.Wx(x) + self.Wh(h))
        logits = self.Wy(h)
        return h,logits

In [53]:
import csv

In [180]:
path = 'WMT-Chinese-to-English-Machine-Translation-newstest/damo_mt_testsets_zh2en_news_wmt18.csv'

In [181]:
import pandas as pd

In [182]:
df = pd.read_csv(path)

In [184]:
chs = df['0'].values
ens = df['1'].values

In [187]:
chs

array(['声明补充说，沃伦的同事都深感震惊，并且希望他能够投案自首。', '不光改变硬件，软件也要跟上',
       '“这不是我们习以为常的、可以称之为典型的谋杀事件。”', ...,
       '在被唐纳德·特朗普总统任命之前，普鲁特是俄克拉何马州总检察长，长期以来反对环境法规的严格化。',
       '在功能手机时代，手机的基本功能就是打电话、发短信、简单的备忘录，各种手机在功能上差距是不大的。',
       '11月份全国热点城市房价趋稳，其中京沪深止涨，未来部分城市房价水平将继续回落。'], dtype=object)

In [188]:
ens

array(["The statement added that Warren's colleagues were shocked and want him to turn himself in.",
       'We should not only change the hardware, but the software must also keep up.',
       '"\'This isn\'t the type of murder that we\'ve become used to and can call typical.\'',
       ...,
       "Pruitt, who was Oklahoma's state attorney general prior to his appointment by President Donald Trump, has long served as a reliable opponent of stricter environmental regulations.",
       'In a feature phone era, the basic functions of a mobile phone would be making a call, sending short text messages, and simple memos. The difference in the functions among mobile phones is quite small.',
       'Residential property prices in popular cities nationwide stabilized in November. In particular, prices stopped rising in Beijing, Shanghai and Shenzhen. Going forward, the level of residential property prices in some cities will continue to decline.'],
      dtype=object)

In [71]:
class Tokenizer:
    
    def __init__(self,sentences):
        self.chars = set()
        for sentence in sentences:
            self.chars.update(set(sentence))
        self.chars = list(self.chars)
        self.bos = '<bos>'
        self.eos = '<eos>'
        self.chars.append(self.bos)
        self.chars.append(self.eos)
        self.encode = {c:i for i,c in enumerate(self.chars)}
        self.decode = {i:c for i,c in enumerate(self.chars)}
        self.vocab_size = len(self.encode)
    
    def to_ids(self,text):
        return [self.encode[t] for t in text]
    
    def to_tokens(self,ids):
        return [self.decode[_id] for _id in ids]

In [None]:
ch_tokenizer = Tokenizer(chs)

In [None]:
en_tokenizer = Tokenizer(ens)

In [None]:
ch_tokenizer.to_ids('声明补充说')

In [None]:
en_tokenizer.to_tokens(en_tokenizer.to_ids('The statement'))

In [194]:
class Tokenizer:
    
    def __init__(self,sentences):
        self.chars = set()
        for sentence in sentences:
            self.chars.update(set(sentence))
        self.chars = list(self.chars)
        self.bos = '<bos>'
        self.eos = '<eos>'
        self.chars.append(self.bos)
        self.chars.append(self.eos)
        self.encode = {c:i for i,c in enumerate(self.chars)}
        self.decode = {i:c for i,c in enumerate(self.chars)}
        self.bos_index = self.encode[self.bos]
        self.eos_index = self.encode[self.eos]
        self.vocab_size = len(self.decode)
    
    def to_ids(self,text):
        return [self.encode[t] for t in text]
    
    def to_tokens(self,ids):
        return [self.decode[_id] for _id in ids]

In [195]:
ch_tokenizer = Tokenizer(chs)

In [196]:
en_tokenizer = Tokenizer(ens)

In [197]:
ch_tokenizer.to_ids('声明补充说')

[1909, 396, 1710, 139, 2877]

In [198]:
en_tokenizer.to_tokens(en_tokenizer.to_ids('The statement'))

['T', 'h', 'e', ' ', 's', 't', 'a', 't', 'e', 'm', 'e', 'n', 't']

In [135]:
class Seq2Seq(nn.Module):

    def __init__(self,dim,input_size,output_size):
        super().__init__()
        self.encoder = RNN(dim,input_size)   
        self.decoder = RNN(dim,output_size)   
        self.input_embedding = nn.Embedding(input_size,dim)
        self.output_embedding = nn.Embedding(output_size,dim)
        self.criterion = nn.CrossEntropyLoss()
    
    def _forward(self,model,embedding,x,h):
        h_seq = []
        logits_seq = []
        for xi in x:
            xi = embedding(xi)
            h,logits = model(xi,h)
            h_seq.append(h)
            logits_seq.append(logits)
        h_seq = torch.cat(h_seq)
        logits_seq = torch.cat(logits_seq)
        return h_seq,logits_seq
    
    def encode(self,x):
        h_seq,logits_seq = self._forward(self.encoder,self.input_embedding,x,self.encoder.hiddend_state())
        return h_seq[-1].unsqueeze(0),logits_seq[-1]
    
    def decode(self,x,h):
        h_seq,logits_seq = self._forward(self.decoder,self.output_embedding,x,h)
        return h_seq,logits_seq
    
    def forward(self,x,y):
        h,logits = self.encode(x)
        y_input = y[:-1]
        y_target = y[1:]
        h_seq,logits_seq = self.decode(y_input,h)
        loss = self.criterion(logits_seq,y_target)
        return loss

In [189]:
import random

In [199]:
def get_random_tokens(chs,ens,length=64):
    indexs = list(range(len(chs)))
    random.shuffle(indexs)
    indexs = indexs[:64]
    ch_tokens = [torch.LongTensor(ch_tokenizer.to_ids(chs[i])) for i in indexs]
    en_tokens = [torch.LongTensor([en_tokenizer.bos_index]+en_tokenizer.to_ids(ens[i])+[en_tokenizer.eos_index]) for i in indexs]
    return ch_tokens,en_tokens

In [203]:
def get_random_x_y(tokens,seq_len):
    i = random.randint(0,len(tokens)-seq_len-2)
    seq = tokens[i:i+seq_len+1]
    x = seq[:-1]
    y = seq[1:]
    return x,y 

In [204]:
model = Seq2Seq(dim,ch_tokenizer.vocab_size,en_tokenizer.vocab_size)
optimizer = Adam(model.parameters(),lr=3e-4)

In [213]:
for epoch in range(100000):
    count = 0
    loss = 0
    ch_tokens,en_tokens = get_random_tokens(chs,ens,length=64)
    for xi,yi in zip(ch_tokens,en_tokens):
        lossi = model(xi,yi)
        loss += lossi
        count += 1
    loss = loss / count
    if epoch % 100 == 0:
        print(loss)
    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1)
    optimizer.step()

tensor(1.9531, grad_fn=<DivBackward0>)
tensor(2.0128, grad_fn=<DivBackward0>)
tensor(1.9718, grad_fn=<DivBackward0>)
tensor(2.0313, grad_fn=<DivBackward0>)
tensor(2.0086, grad_fn=<DivBackward0>)
tensor(1.9663, grad_fn=<DivBackward0>)
tensor(1.9844, grad_fn=<DivBackward0>)
tensor(1.9719, grad_fn=<DivBackward0>)
tensor(1.9349, grad_fn=<DivBackward0>)
tensor(1.9897, grad_fn=<DivBackward0>)
tensor(1.9703, grad_fn=<DivBackward0>)
tensor(1.9948, grad_fn=<DivBackward0>)
tensor(1.9552, grad_fn=<DivBackward0>)
tensor(1.9794, grad_fn=<DivBackward0>)
tensor(1.9674, grad_fn=<DivBackward0>)
tensor(1.9696, grad_fn=<DivBackward0>)
tensor(1.9698, grad_fn=<DivBackward0>)
tensor(1.9706, grad_fn=<DivBackward0>)
tensor(1.9608, grad_fn=<DivBackward0>)
tensor(1.9649, grad_fn=<DivBackward0>)
tensor(1.9818, grad_fn=<DivBackward0>)
tensor(1.9339, grad_fn=<DivBackward0>)
tensor(1.9488, grad_fn=<DivBackward0>)
tensor(1.9628, grad_fn=<DivBackward0>)
tensor(1.9528, grad_fn=<DivBackward0>)
tensor(1.9384, grad_fn=<D

KeyboardInterrupt: 

In [211]:
def predict(model,x):
    x = torch.LongTensor(ch_tokenizer.to_ids(x))
    h,logits = model.encode(x)
    start = torch.LongTensor([en_tokenizer.bos_index])
    for _ in range(100):
        h_seq,logits_seq = model.decode(start,h)
        logits = logits_seq[-1]
        indexs = logits.argmax(dim=-1)
        if indexs.item() == en_tokenizer.eos_index:
            print('end')
            break
        indexs = indexs.unsqueeze(0)
        start = torch.cat((start,indexs),dim=-1)
    print(''.join(en_tokenizer.to_tokens(xi.item() for xi in start)))
    

In [212]:
predict(model,'手机的基本功能就是打电话')

<bos>and the starting the starting the starting the starting the starting the starting the starting the s


In [148]:
''.join(id2char[xi.item()] for xi in x)

'ou should consider getting a taste of research as an undergradua'

In [149]:
''.join(id2char[xi.item()] for xi in logits.argmax(dim=-1))

'nrwhould bomsider iot ing t Phnt rsf teaearch i  a dander ram rt'