In [22]:
import torch
import sys
sys.path.append('../..')
# from nn_with_transformers.function import optim, nn, random
from torch import nn, optim
import random
from generator.function.Seq2SeqBasic import EncoderRNN, DecoderRNN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SOS_token = 0
EOS_token = 1

In [23]:
def sentence2tensor(lang, sentence):
    indexes = [lang(word) for word in sentence.split()]
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)

def pair2tensor(pair):
    input_tensor = sentence2tensor(lan1, pair[0])
    target_tensor = sentence2tensor(lan2, pair[1])
    return (input_tensor, target_tensor)

def translate(s):
    t = [lan1(i) for i in s.split()]
    t.append(EOS_token)
    f = encoder.sample(t)
    s = decoder.sample(f)
    r = [lan2.idx2word[i] for i in s]
    return ' '.join(r)


In [24]:
class Vocabulary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {0: "<SOS>", 1: "<EOS>", -1: "<unk>"}
        self.idx = 2 # Count SOS and EOS

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def add_sentence(self, sentence):
        for word in sentence.split():
            self.add_word(word)

    def __call__(self, word):
        if not word in self.word2idx:
            return -1
        return self.word2idx[word]

    def __len__(self):
        return self.idx

In [25]:

lan1 = Vocabulary()
lan2 = Vocabulary()

data = [['你 很 聪明 。', 'you are very wise .'],
        ['我们 一起 打 游戏 。', 'let us play game together .'],
        ['你 太 刻薄 了 。', 'you are so mean .'],
        ['你 完全 正确 。', 'you are perfectly right .'],
        ['我 坚决 反对 妥协 。', 'i am strongly opposed to a compromise .'],
        ['他们 正在 看 电影 。', 'they are watching a movie .'],
        ['他 正在 看着 你 。', 'he is looking at you .'],
        ['我 怀疑 他 是否 会 来', 'i am doubtful whether he will come .']]


data = [['你 很 聪 明 。', 'you are very wise .'],
        ['我 们 一 起 打 游 戏 。', 'let us play game together .'],
        ['你 太 刻 薄 了 。', 'you are so mean .'],
        ['你 完 全 正 确 。', 'you are perfectly right .'],
        ['我 坚 决 反 对 妥 协 。', 'i am strongly opposed to a compromise .'],
        ['他 们 正 在 看 电 影 。', 'they are watching a movie .'],
        ['他 正 在 看 着 你 。', 'he is looking at you .'],
        ['我 怀 疑 他 是 否 会 来', 'i am doubtful whether he will come .']]

for i,j in data:
    lan1.add_sentence(i)
    lan2.add_sentence(j)
learning_rate = 0.001
hidden_size = 256

encoder = EncoderRNN(len(lan1),hidden_size).to(device)
decoder = DecoderRNN(hidden_size,len(lan2)).to(device)
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = optim.Adam(params, lr=learning_rate)

loss = 0
criterion = nn.NLLLoss()
turns = 200
print_every = 20
print_loss_total = 0
training_pairs = [pair2tensor(random.choice(data)) for pair in range(turns)]

for turn in range(turns):
    optimizer.zero_grad()
    loss = 0
    x,y = training_pairs[turn]
    input_length = x.size(0)
    target_length = y.size(0)
    h = encoder.initHidden()
    for i in range(input_length):
        h = encoder(x[i],h)
    decoder_input = torch.LongTensor([SOS_token]).to(device)
    for i in range(target_length):
        decoder_output,h = decoder(decoder_input,h)
        topv, topi = decoder_output.topk(1)
        decoder_input = topi.squeeze().detach()
        loss += criterion(decoder_output, y[i])
        if decoder_input.item() == EOS_token:break   
    print_loss_total += loss.item()/target_length
    if (turn+1) % print_every == 0 :
        print("loss:{loss:,.4f}".format(loss=print_loss_total/print_every))
        print_loss_total = 0    
    loss.backward()
    optimizer.step()

for pr in data:
    print('>>',pr[0])
    print('==',pr[1])
    print('result:',translate(pr[0]))
    print()

loss:2.7567
loss:2.0884
loss:1.6270
loss:1.0138
loss:0.6485
loss:0.4325
loss:0.2443
loss:0.2057
loss:0.0727
loss:0.0530
>> 你 很 聪 明 。
== you are very wise .
result: <SOS> you are very wise . <EOS>

>> 我 们 一 起 打 游 戏 。
== let us play game together .
result: <SOS> let us play game together . <EOS>

>> 你 太 刻 薄 了 。
== you are so mean .
result: <SOS> you are so mean . <EOS>

>> 你 完 全 正 确 。
== you are perfectly right .
result: <SOS> you are perfectly right . <EOS>

>> 我 坚 决 反对 妥协 。
== i am strongly opposed to a compromise .
result: <SOS> i am strongly opposed to a compromise . <EOS>

>> 他 们 正 在 看 电 影 。
== they are watching a movie .
result: <SOS> they are watching a movie . <EOS>

>> 他 正 在 看 着 你 。
== he is looking at you .
result: <SOS> he is looking at you . <EOS>

>> 我 怀 疑 他 是 否 会 来
== i am doubtful whether he will come .
result: <SOS> i am doubtful whether he will come . <EOS>

