In [1]:
# jesseLiu2000
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tkinter import _flatten


class Seq2Seq(nn.Module):
    def __init__(self):
        super(Seq2Seq, self).__init__()
        self.embedding = nn.Embedding(n_seq, embedding_size)
        self.encoder = nn.RNN(input_size=embedding_size, hidden_size=hidden_size, num_layers=6, dropout=0.1)
        self.decoder = nn.RNN(input_size=embedding_size, hidden_size=hidden_size, num_layers=6, dropout=0.1)
        self.fc1 = nn.Linear(hidden_size, n_seq)

    def forward(self, encoder_input, hidden_state, decoder_input):
        encoder_embedding = self.embedding(encoder_input) # [batch_size, seq_length, embedding_size]
        encoder_embedding = encoder_embedding.permute(1, 0, 2) # [seq_length, batch_size, embedding_size]
        decoder_embedding = self.embedding(decoder_input)
        decoder_embedding = decoder_embedding.permute(1, 0, 2)

        _, encoder_output = self.encoder(encoder_embedding, hidden_state) # [num_layers * num_directions, batch_size, hidden_size]
        decoder_output, _ = self.decoder(decoder_embedding, encoder_output) # [max_seq_length, batch_size, num_directions * hidden_size]

        output = self.fc1(decoder_output) # [max_seq_length, batch_size, n_class]
        return output



if __name__ == '__main__':
    max_length = 7

    alphabet = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']
    word2index = {w: i for i, w in enumerate(alphabet)}
    index2word = {i: w for i, w in enumerate(alphabet)}
    file_lst = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'],
                ['high', 'low']]

    # [batch_size, seg_length]
    en_input = [[word2index[word] for word in (sent[0]+("P"*(max_length-len(sent[0]))))] for sent in file_lst]
    de_input = [[word2index[word] for word in (('S'+sent[1])+("P"*(max_length-len('S'+sent[1]))))] for sent in file_lst]
    targets = [[word2index[word] for word in (sent[1]+("P"*(max_length-len(sent[1]+'E')))+'E')] for sent in file_lst]
    en_input = torch.LongTensor(en_input)
    de_input = torch.LongTensor(de_input)
    targets = torch.LongTensor(targets)
    state = None

    n_seq = len(word2index)
    batch_size = len(file_lst)
    embedding_size = 64
    hidden_size = 128

    lr = 1e-3
    epoch = 5000

    model = Seq2Seq()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    #Train
    for ep in range(epoch):
        optimizer.zero_grad()

        output = model(en_input, state, de_input)
        output = output.permute(1, 0, 2) # [batch_size, max_seq_length, n_class]
        # print(output.size())
        # print(targets.size())
        loss = 0
        for i in range(len(targets)):
            loss += criterion(output[i], targets[i])

        if (ep + 1) % 1000 == 0:
            print('Epoch:', '%04d' % (ep + 1), 'cost =', '{:.6f}'.format(loss))

        loss.backward()
        optimizer.step()

    #Test
    test_lst = ['men', 'black']
    state = None
    de_input = torch.LongTensor([[0, 2, 2, 2, 2, 2, 2], [0, 2, 2, 2, 2, 2, 2]])
    en_input = [[word2index[word] for word in (sent+(max_length-len(sent))*'P')] for sent in test_lst]
    en_input = torch.LongTensor(en_input)
    output = model(en_input, state, de_input)
    predict = output.data.max(2, keepdim=True)[1]
    # map_pre = [alphabet[i] for i in predict]
    predict = predict.permute(1, 0, 2)
    for i in range(len(predict)):
        pre_str = [index2word[int(lt)] for word in predict[i] for lt in word ]
        # pre_str = _flatten(pre_str)
        pre_str = ''.join(pre_str)
        print(pre_str.replace('E','').replace('P', ''))






Epoch: 1000 cost = 0.002546
Epoch: 2000 cost = 0.000730
Epoch: 3000 cost = 0.000323
Epoch: 4000 cost = 0.000165
Epoch: 5000 cost = 0.000091
woww
whitw
