In [1]:
# code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor
import torch
import numpy as np
import torch.nn as nn
import torch.utils.data as Data

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# S: Symbol that shows starting of decoding input
# E: Symbol that shows starting of decoding output
# ?: Symbol that will fill in blank sequence if current batch data size is short than n_step

In [2]:
letter = [c for c in 'SE?abcdefghijklmnopqrstuvwxyz']
letter2idx = {n: i for i, n in enumerate(letter)}

seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low'],['big','small']]

# Seq2Seq Parameter
n_step = max([max(len(i), len(j)) for i, j in seq_data]) # max_len(=5)
n_hidden = 128
n_class = len(letter2idx) # classfication problem
batch_size = 3

In [7]:
letter2idx

{'S': 0,
 'E': 1,
 '?': 2,
 'a': 3,
 'b': 4,
 'c': 5,
 'd': 6,
 'e': 7,
 'f': 8,
 'g': 9,
 'h': 10,
 'i': 11,
 'j': 12,
 'k': 13,
 'l': 14,
 'm': 15,
 'n': 16,
 'o': 17,
 'p': 18,
 'q': 19,
 'r': 20,
 's': 21,
 't': 22,
 'u': 23,
 'v': 24,
 'w': 25,
 'x': 26,
 'y': 27,
 'z': 28}

In [3]:
def make_data(seq_data):
    enc_input_all, dec_input_all, dec_output_all = [], [], []

    for seq in seq_data:
        for i in range(2):
            seq[i] = seq[i] + '?' * (n_step - len(seq[i])) # 'man??', 'women'

        enc_input = [letter2idx[n] for n in (seq[0]+'E')] # ['m', 'a', 'n', '?', '?', 'E']
        dec_input = [letter2idx[n] for n in ('S' + seq[1])] # ['S', 'w', 'o', 'm', 'e', 'n']
        dec_output = [letter2idx[n] for n in (seq[1] + 'E')] # ['w', 'o', 'm', 'e', 'n', 'E']

        enc_input_all.append(enc_input)
        dec_input_all.append(dec_input)
        dec_output_all.append(dec_output) # not one-hot

    # make tensor
    return torch.LongTensor(enc_input_all), torch.LongTensor(dec_input_all), torch.LongTensor(dec_output_all)

'''
enc_input_all: [6, n_step+1 (because of 'E'), n_class]
dec_input_all: [6, n_step+1 (because of 'S'), n_class]
dec_output_all: [6, n_step+1 (because of 'E')]
'''
enc_input_all, dec_input_all, dec_output_all = make_data(seq_data)

In [12]:
dec_input_all.shape

torch.Size([7, 6])

In [4]:
class TranslateDataSet(Data.Dataset):
    def __init__(self, enc_input_all, dec_input_all, dec_output_all):
        self.enc_input_all = enc_input_all
        self.dec_input_all = dec_input_all
        self.dec_output_all = dec_output_all
    
    def __len__(self): # return dataset size
        return len(self.enc_input_all)
    
    def __getitem__(self, idx):
        return self.enc_input_all[idx], self.dec_input_all[idx], self.dec_output_all[idx]

loader = Data.DataLoader(TranslateDataSet(enc_input_all, dec_input_all, dec_output_all), batch_size, True)

In [5]:
enc_input_all.shape

torch.Size([7, 6])

In [6]:
class Seq2SeqEncoder(nn.Module):
    def __init__(self,vocab_size,embed_size,num_hiddens):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.rnn = nn.GRU(embed_size,num_hiddens)
        
    def forward(self,X):
        X = self.embedding(X)
        X = X.permute(1,0,2)
        # output的形状:(num_steps,batch_size,num_hiddens)
        # state的形状:(num_layers,batch_size,num_hiddens)
        output,state = self.rnn(X)
        return output,state

In [7]:
class Seq2SeqDecoder(nn.Module):
    def __init__(self,vocab_size,embed_size,num_hiddens):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.rnn = nn.GRU(embed_size+num_hiddens,num_hiddens)
        self.dense = nn.Linear(num_hiddens,vocab_size)
        
    def init_state(self,enc_outputs):
        return enc_outputs[1]
    
    def forward(self,X,state):
        X = self.embedding(X).permute(1,0,2)
        context = state[-1].repeat(X.shape[0],1,1)
        X_and_context = torch.cat((X,context),dim=2)
        output,state = self.rnn(X_and_context,state)
        output = self.dense(output).permute(1,0,2)
        # output的形状:(batch_size,num_steps,vocab_size)
        # state的形状:(num_layers,batch_size,num_hiddens)
        return output,state

In [8]:
class EncoderDecoder(nn.Module):
    """编码器-解码器架构的基类"""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args) # 这里的outputs输出的是一个元组，1代表encoder的state
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

In [10]:
encoder = Seq2SeqEncoder(n_class,5,n_hidden)
decoder = Seq2SeqDecoder(n_class,5,n_hidden)

In [11]:
model = EncoderDecoder(encoder,decoder).to(device)

In [12]:
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

In [13]:
def xavier_init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)
    if type(m) == nn.GRU:
        for param in m._flat_weights_names:
            if "weight" in param:
                nn.init.xavier_uniform_(m._parameters[param])

In [14]:
model.apply(xavier_init_weights)
model.train()
for epoch in range(50000):
    for enc_input_batch, dec_input_batch, dec_output_batch in loader:
      # make hidden shape [num_layers * num_directions, batch_size, n_hidden]
        (enc_input_batch, dec_intput_batch, dec_output_batch) = (enc_input_batch.to(device), dec_input_batch.to(device), dec_output_batch.to(device))
      # enc_input_batch : [batch_size, n_step+1, n_class]
      # dec_intput_batch : [batch_size, n_step+1, n_class]
      # dec_output_batch : [batch_size, n_step+1], not one-hot
        pred,_ = model(enc_input_batch, dec_intput_batch)
      # pred : [n_step+1, batch_size, n_class]
        loss = 0
        for i in range(len(dec_output_batch)):
          # pred[i] : [n_step+1, n_class]
          # dec_output_batch[i] : [n_step+1]
            loss += criterion(pred[i], dec_output_batch[i])
    if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Epoch: 1000 cost = 3.433763
Epoch: 2000 cost = 3.353740
Epoch: 3000 cost = 3.415884
Epoch: 4000 cost = 3.327003
Epoch: 5000 cost = 3.415658
Epoch: 6000 cost = 3.383410
Epoch: 7000 cost = 3.383352
Epoch: 8000 cost = 3.343417
Epoch: 9000 cost = 3.353258
Epoch: 10000 cost = 3.415230
Epoch: 11000 cost = 3.347921
Epoch: 12000 cost = 3.347865
Epoch: 13000 cost = 3.353030
Epoch: 14000 cost = 3.326344
Epoch: 15000 cost = 3.432687
Epoch: 16000 cost = 3.432620
Epoch: 17000 cost = 3.414704
Epoch: 18000 cost = 3.347489
Epoch: 19000 cost = 3.352634
Epoch: 20000 cost = 3.414444
Epoch: 21000 cost = 3.352488
Epoch: 22000 cost = 3.352410
Epoch: 23000 cost = 3.343051
Epoch: 24000 cost = 3.414085
Epoch: 25000 cost = 3.431927
Epoch: 26000 cost = 3.347023
Epoch: 27000 cost = 3.352031
Epoch: 28000 cost = 3.351951
Epoch: 29000 cost = 3.413639
Epoch: 30000 cost = 3.351785
Epoch: 31000 cost = 3.413450
Epoch: 32000 cost = 3.351614
Epoch: 33000 cost = 3.325371
Epoch: 34000 cost = 3.346564
Epoch: 35000 cost = 3.4

In [18]:
# Test
def translate(word):
    #model.eval()
    enc_input, dec_input, _ = make_data([[word, '?' * n_step]])
    enc_input, dec_input = enc_input.to(device), dec_input.to(device)
    enc_outputs = model.encoder(enc_input)
    dec_state = model.decoder.init_state(enc_outputs)
    dec_input = torch.LongTensor([[0]]).to(device)
    print(dec_input)
    res = ''
    for _ in range(6):
        Y,dec_state = model.decoder(dec_input,dec_state)
        dec_input = Y.argmax(2)
        res+= letter[int(dec_input[0][0])]
    return res

        
print('test')
print('man ->', translate('man'))
print('mans ->', translate('mans'))
print('king ->', translate('king'))
print('black ->', translate('black'))
print('up ->', translate('up'))

test
tensor([[0]], device='cuda:0')
man -> sqiofj
tensor([[0]], device='cuda:0')
mans -> uqiuzx
tensor([[0]], device='cuda:0')
king -> sqioff
tensor([[0]], device='cuda:0')
black -> uqiiof
tensor([[0]], device='cuda:0')
up -> sqiofx
