In [None]:
import torch
from data import load_data_train
from attention import TransformerEncoder, TransformerDecoder
from seq2seq import EncoderDecoder, train_seq2seq, predict_seq2seq, bleu

def try_gpu(i=0):
    """
    Return gpu(i) if exists, otherwise return cpu()

    Args:
        i: Zero-based CUDA device index to look for. Defaults to 0.

    Returns:
        torch.device: ``cuda:{i}`` when ``i`` is non-negative and
        ``torch.cuda.device_count()`` reports at least ``i + 1`` devices;
        the ``cpu`` device otherwise.
    """
    # A negative index never names an existing GPU. Without the lower-bound
    # check, i = -1 on a machine with no GPUs satisfies
    # `device_count() >= i + 1` (0 >= 0), and the function would try to build
    # a 'cuda:-1' device instead of falling back to the CPU.
    if i >= 0 and torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')

# --- Hyperparameters and run configuration -----------------------------------
num_hiddens, num_blks, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
# `device` is whatever try_gpu() picks: the first CUDA device if present, else CPU.
lr, num_epochs, device = 0.005, 200, try_gpu()
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
# NOTE(review): ffn_num_input, key_size, query_size, value_size and norm_shape
# are not passed to any call visible in this notebook. They are kept unchanged
# in case code outside this view reads them — confirm before removing.
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]

# Fix torch's global RNG seed before any data loading or model construction so
# that a fresh "Restart & Run All" repeats the same run.
# NOTE(review): assumes the data pipeline, weight initialisation and dropout
# draw from torch's global RNG; Python's `random` and NumPy are not seeded
# here — confirm against data.py / seq2seq.py.
SEED = 42
torch.manual_seed(SEED)

train_iter, en_vocab, cn_vocab = load_data_train(batch_size, num_steps)

# Train an English -> Chinese translation model.

# The encoder and decoder receive the same size/regularisation arguments;
# only the vocabulary size (first positional argument) differs.
stack_args = (num_hiddens, ffn_num_hiddens, num_heads, num_blks, dropout)
encoder = TransformerEncoder(len(en_vocab), *stack_args)  # sized by the English vocab
decoder = TransformerDecoder(len(cn_vocab), *stack_args)  # sized by the Chinese vocab
transformer = EncoderDecoder(encoder, decoder)

# Hand the assembled model, the training iterator and the optimisation settings
# to train_seq2seq; the Chinese vocab is the one passed alongside them.
train_seq2seq(transformer, train_iter, lr, num_epochs, cn_vocab, device)

In [None]:
# Spot-check the model on a few English sentences, printing each prediction
# next to the expected Chinese sentence and a BLEU score.
engs = ['go .', 'i lost .', "he's calm .", "i'm home ."]
chis = ['我 走 了 。', '我 迷失 了 。', '他 很 冷 静 。', '我 回家 了 。']
for eng, chi in zip(engs, chis):
    translation = predict_seq2seq(eng, transformer, en_vocab, cn_vocab, num_steps, device)
    # bleu receives the expected sentence wrapped in a one-element list, with k=2;
    # the value is multiplied by 100 for display only.
    score = bleu(translation, [chi], k=2)
    print(f'{eng} => {translation}, expect {chi}, bleu {score * 100:.1f}')