In [55]:
import torch
import torch.nn as nn
import torch.functional as F
from torch.autograd import Variable
import torch.optim as optim
import numpy as np

In [59]:
chars = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']
char_dic = {w:i for i,w in enumerate(chars)}
pairs = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]

#parameters
n_step = 5
n_hidden = 128
n_class = len(char_dic)
batch_size = len(pairs)


def make_batch(pairs):
    input_batch = []
    output_batch = []
    target_batch = []
    for pair in pairs:
        for i in range(2):
            pair[i] = pair[i]+'P'*(n_step-len(pair[i]))
        input = [char_dic[n] for n in pair[0]]
        output = [char_dic[n] for n in ('S'+pair[1])]
        target = [char_dic[n] for n in (pair[1]+'E')]
        input_batch.append(np.eye(n_class)[input])
        output_batch.append(np.eye(n_class)[output])
        target_batch.append(target)
        
    return Variable(torch.Tensor(input_batch)),Variable(torch.Tensor(output_batch)),Variable(torch.Tensor(target_batch).type(torch.LongTensor))

class Seq2Seq(nn.Module):
    def __init__(self):
        super(Seq2Seq,self).__init__()
        self.enc_cell = nn.RNN(input_size=n_class,hidden_size=n_hidden,dropout=0.5)
        self.dec_cell = nn.RNN(input_size=n_class,hidden_size=n_hidden,dropout=0.5)
        self.fc = nn.Linear(n_hidden,n_class)
    
    def forward(self,enc_input,enc_hidden,dec_input):
        enc_input = enc_input.transpose(0,1)#[batch_size,n_step,hidden_size]->[n_step,batch_size,hidden_size]
        dec_input = dec_input.transpose(0,1)#[n_step,batch_size,hidden_size]
        
        #enc_states : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
        _,enc_states = self.enc_cell(enc_input,enc_hidden)
        #dec_output : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
        dec_output,_ = self.dec_cell(dec_input,enc_states)
        
        model = self.fc(dec_output)
        return model  # model : [max_len+1(=6), batch_size, n_class]
    
def translate(word):
    input_batch,output_batch,_ = make_batch([[word,'P'*len(word)]])
    hidden = Variable(torch.zeros(1,1,n_hidden))
    output = model(input_batch,hidden,output_batch)
    predict = output.data.max(2,keepdim=True)[1]
    decoded = [chars[i] for i in predict]
    end = decoded.index('E')
    translated = ''.join(decoded[:end])
    return translated.replace('P', '')



In [57]:
model = Seq2Seq()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr = 0.01)

input_batch,output_batch,target_batch = make_batch(pairs)

for epoch in range(5000):
    optimizer.zero_grad()
    enc_hidden = Variable(torch.zeros(1,batch_size,n_hidden))
    output = model(input_batch,enc_hidden,output_batch)
    output = output.transpose(0,1)
    loss = 0
    for i in range(len(target_batch)):
        loss += criterion(output[i],target_batch[i])
    if (epoch+1)%1000 == 0:
        print('epoch:%d,loss:%f' % (epoch+1,loss))
    loss.backward()
    optimizer.step()

  "num_layers={}".format(dropout, num_layers))


epoch:100,loss:0.010268
epoch:200,loss:0.003431
epoch:300,loss:0.001938
epoch:400,loss:0.001276
epoch:500,loss:0.000915
epoch:600,loss:0.000694
epoch:700,loss:0.000547
epoch:800,loss:0.000443
epoch:900,loss:0.000367
epoch:1000,loss:0.000309
epoch:1100,loss:0.000265
epoch:1200,loss:0.000228
epoch:1300,loss:0.000199
epoch:1400,loss:0.000174
epoch:1500,loss:0.000154
epoch:1600,loss:0.000138
epoch:1700,loss:0.000125
epoch:1800,loss:0.000111
epoch:1900,loss:0.000100
epoch:2000,loss:0.000091
epoch:2100,loss:0.000083
epoch:2200,loss:0.000076
epoch:2300,loss:0.000070
epoch:2400,loss:0.000063
epoch:2500,loss:0.000058
epoch:2600,loss:0.000054
epoch:2700,loss:0.000050
epoch:2800,loss:0.000047
epoch:2900,loss:0.000044
epoch:3000,loss:0.000038
epoch:3100,loss:0.000036
epoch:3200,loss:0.000034
epoch:3300,loss:0.000032
epoch:3400,loss:0.000030
epoch:3500,loss:0.000028
epoch:3600,loss:0.000025
epoch:3700,loss:0.000024
epoch:3800,loss:0.000023
epoch:3900,loss:0.000021
epoch:4000,loss:0.000020
epoch:410

In [60]:
print('test')
print('man ->', translate('man'))
print('mans ->', translate('mans'))
print('king ->', translate('king'))
print('black ->', translate('black'))
print('upp ->', translate('upp'))

test
man -> women
mans -> women
king -> queen
black -> white
upp -> down
