# Lab 12-7 sequence to sequence with attention
### simple neural machine translation training 
* sequence to sequence
* variable input sequence length
* variable output sequence length
* Luong attention
  
### Reference
* [Sequence to Sequence Learning with Neural Networks](https://arxiv.org/abs/1409.3215)
* [Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025)

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.sequence import pad_sequences
from pprint import pprint

# Short alias for the seq2seq module; the attention mechanism, helpers,
# decoders and sequence loss used below are all taken from it.
# NOTE(review): tf.contrib is a TensorFlow 1.x namespace — confirm the runtime version.
s2s = tf.contrib.seq2seq

### Preparing dataset

In [2]:
# Toy parallel corpus: sources[i] is a tokenised English sentence and
# targets[i] is the tokenised Korean sentence paired with it.
sources = [
    ['I', 'feel', 'hungry'],
    ['tensorflow', 'is', 'very', 'difficult'],
    ['tensorflow', 'is', 'a', 'framework', 'for', 'deep', 'learning'],
    ['tensorflow', 'is', 'very', 'fast', 'changing'],
]
targets = [
    ['나는', '배가', '고프다'],
    ['텐서플로우는', '매우', '어렵다'],
    ['텐서플로우는', '딥러닝을', '위한', '프레임워크이다'],
    ['텐서플로우는', '매우', '빠르게', '변화한다'],
]

In [3]:
# Source-side vocabulary: every distinct source token in sorted order, with
# '<pad>' placed first so that it receives index 0.
s_vocab = ['<pad>'] + sorted({token for sentence in sources for token in sentence})

# Lookup tables in both directions (token -> index, index -> token).
source2idx = {word: idx for idx, word in enumerate(s_vocab)}
idx2source = dict(enumerate(s_vocab))

pprint(source2idx)

{'<pad>': 0,
 'I': 1,
 'a': 2,
 'changing': 3,
 'deep': 4,
 'difficult': 5,
 'fast': 6,
 'feel': 7,
 'for': 8,
 'framework': 9,
 'hungry': 10,
 'is': 11,
 'learning': 12,
 'tensorflow': 13,
 'very': 14}


In [4]:
# Target-side vocabulary: the special tokens '<pad>', '<bos>', '<eos>' take
# indices 0, 1, 2; the distinct target tokens follow in sorted order.
t_vocab = ['<pad>', '<bos>', '<eos>'] + sorted(
    {token for sentence in targets for token in sentence}
)

# Lookup tables in both directions (token -> index, index -> token).
target2idx = {word: idx for idx, word in enumerate(t_vocab)}
idx2target = dict(enumerate(t_vocab))

pprint(target2idx)

{'<bos>': 1,
 '<eos>': 2,
 '<pad>': 0,
 '고프다': 3,
 '나는': 4,
 '딥러닝을': 5,
 '매우': 6,
 '배가': 7,
 '변화한다': 8,
 '빠르게': 9,
 '어렵다': 10,
 '위한': 11,
 '텐서플로우는': 12,
 '프레임워크이다': 13}


### Preprocessing dataset

In [5]:
def preprocess(sequences, max_len, dic, mode = 'source'):
    """Turn tokenised sentences into fixed-width arrays of token indices.

    Args:
        sequences: list of sentences, each a list of string tokens.
        max_len: width passed to ``pad_sequences`` as ``maxlen``; sentences
            are padded or truncated at the end ('post') to this width.
        dic: mapping from token to integer index. For ``mode='target'`` it is
            also looked up for '<bos>' and '<eos>'.
        mode: 'source' for encoder inputs, 'target' for decoder inputs/outputs.

    Returns:
        For ``mode='source'``: ``(s_len, s_input)``.
        For ``mode='target'``: ``(t_len, t_input, t_output)``, where
        ``t_input`` is every sentence prefixed with '<bos>' and ``t_output``
        is every sentence suffixed with '<eos>'.
        The length lists hold one int per sentence, clipped to ``max_len`` so
        that a length never exceeds the width of the returned array.

    Raises:
        AssertionError: if ``mode`` is neither 'source' nor 'target'.
    """
    assert mode in ['source', 'target'], 'source와 target 중에 선택해주세요.'

    def to_indices(sentences):
        # NOTE(review): dic.get yields None for a token missing from dic —
        # confirm that callers only pass in-vocabulary tokens.
        return [[dic.get(token) for token in sentence] for sentence in sentences]

    def clipped_len(sentence):
        # Truncation happens inside pad_sequences, so the reported length has
        # to be capped here to stay consistent with the padded array.
        if max_len is None:
            return len(sentence)
        return min(len(sentence), max_len)

    def pad(indexed):
        return pad_sequences(sequences = indexed, maxlen = max_len,
                             padding = 'post', truncating = 'post')

    if mode == 'source':
        # preprocessing for source (encoder)
        s_input = to_indices(sequences)
        s_len = [clipped_len(sentence) for sentence in s_input]
        return s_len, pad(s_input)

    # preprocessing for target (decoder)
    # input: '<bos>' followed by the sentence
    t_input = to_indices([['<bos>'] + sentence for sentence in sequences])
    t_len = [clipped_len(sentence) for sentence in t_input]

    # output: the sentence followed by '<eos>'
    t_output = to_indices([sentence + ['<eos>'] for sentence in sequences])

    return t_len, pad(t_input), pad(t_output)

In [6]:
# Index and pad the source sentences to a fixed width of 10 for the encoder.
s_max_len = 10
s_len, s_input = preprocess(sources, s_max_len, source2idx, mode='source')
print(s_len, s_input)

[3, 4, 7, 5] [[ 1  7 10  0  0  0  0  0  0  0]
 [13 11 14  5  0  0  0  0  0  0]
 [13 11  2  9  8  4 12  0  0  0]
 [13 11 14  6  3  0  0  0  0  0]]


In [7]:
# Index and pad the target sentences to a fixed width of 12; this yields the
# decoder inputs ('<bos>' first) and the decoder outputs ('<eos>' last).
t_max_len = 12
t_len, t_input, t_output = preprocess(targets, t_max_len, target2idx, mode='target')
print(t_len, t_input, t_output)

[4, 4, 5, 5] [[ 1  4  7  3  0  0  0  0  0  0  0  0]
 [ 1 12  6 10  0  0  0  0  0  0  0  0]
 [ 1 12  5 11 13  0  0  0  0  0  0  0]
 [ 1 12  6  9  8  0  0  0  0  0  0  0]] [[ 4  7  3  2  0  0  0  0  0  0  0  0]
 [12  6 10  2  0  0  0  0  0  0  0  0]
 [12  5 11 13  2  0  0  0  0  0  0  0]
 [12  6  9  8  2  0  0  0  0  0  0  0]]


### Creating graph

In [8]:
# hyper-parameters
epochs = 100
batch_size = 2
lr = 0.3

# Input pipeline: slice the five aligned arrays example by example, shuffle
# them, and group them into mini-batches.
data = (
    tf.data.Dataset
    .from_tensor_slices((s_len, s_input, t_len, t_input, t_output))
    .shuffle(buffer_size=10)
    .batch(batch_size=batch_size)
)

# An initializable iterator has to be initialised (via iterator.initializer)
# before get_next() can produce mini-batches.
iterator = data.make_initializable_iterator()
s_mb_len, s_mb_input, t_mb_len, t_mb_input, t_mb_output = iterator.get_next()

In [9]:
## encoder
# Sizes used by the encoder cell, the decoder cell and the output projection.
n_of_classes = len(target2idx)
enc_hidden_dim = 10
dec_hidden_dim = 5

# One-hot "embedding": a non-trainable identity matrix with one row per
# source vocabulary entry.
s_embedding = tf.get_variable(
    name='source_embedding',
    initializer=tf.eye(num_rows=len(source2idx)),
    trainable=False,
)

# Replace every token id in the mini-batch with its one-hot row.
s_mb_batch = tf.nn.embedding_lookup(params=s_embedding, ids=s_mb_input)

# LSTM encoder. Only the outputs are kept; the returned state is discarded.
enc_cell = tf.nn.rnn_cell.LSTMCell(num_units=enc_hidden_dim, dtype=tf.float32)
enc_outputs, _ = tf.nn.dynamic_rnn(
    cell=enc_cell,
    inputs=s_mb_batch,
    sequence_length=s_mb_len,
    dtype=tf.float32,
)

In [10]:
## decoder
# Non-trainable identity matrix used as a one-hot lookup table for target tokens.
t_embedding = tf.get_variable(
    name='target_embedding',
    initializer=tf.eye(num_rows=len(target2idx)),
    trainable=False,
)

# Replace every decoder-input token id with its one-hot row.
t_mb_batch = tf.nn.embedding_lookup(params=t_embedding, ids=t_mb_input)

# Size of the current mini-batch, computed in-graph by summing a vector of ones.
# NOTE: from here on the name `batch_size` refers to this tensor.
batch_size = tf.reduce_sum(tf.ones_like(tensor=s_mb_len, dtype=tf.int32))
# One entry per example: t_max_len (sequence lengths for the training helper)...
tr_tokens = tf.tile(input=[t_max_len], multiples=[batch_size])
# ...and the '<bos>' id (start tokens for the translation helper).
trans_tokens = tf.tile(input=[target2idx.get('<bos>')], multiples=[batch_size])

# LSTM decoder cell wrapped with Luong attention over the encoder outputs;
# the source lengths are passed as memory_sequence_length.
dec_cell = tf.nn.rnn_cell.LSTMCell(num_units=dec_hidden_dim, dtype=tf.float32)
luong_attn = s2s.LuongAttention(
    num_units=dec_hidden_dim,
    memory=enc_outputs,
    memory_sequence_length=s_mb_len,
)
dec_attn_cell = s2s.AttentionWrapper(cell=dec_cell, attention_mechanism=luong_attn)
# The decoder starts from an all-zero state of the wrapped cell.
dec_attn_init_state = dec_attn_cell.zero_state(batch_size=batch_size, dtype=tf.float32)

In [11]:
# Will probably be replaced by keras.layers.Dense later; the
# tf.contrib.seq2seq package does not support tf.keras.layers yet.
output_layer = tf.layers.Dense(units=n_of_classes)

# Training decoder: reads the one-hot decoder inputs (t_mb_batch) and runs for
# at most t_max_len steps.
tr_helper = s2s.TrainingHelper(inputs=t_mb_batch, sequence_length=tr_tokens)
tr_decoder = s2s.BasicDecoder(
    cell=dec_attn_cell,
    helper=tr_helper,
    initial_state=dec_attn_init_state,
    output_layer=output_layer,
)
tr_outputs, _, _ = s2s.dynamic_decode(
    decoder=tr_decoder,
    impute_finished=True,
    maximum_iterations=t_max_len,
)

# Translation decoder: same cell and output layer, but driven by a greedy
# helper that starts from '<bos>' and stops at '<eos>' (or after
# 2 * t_max_len steps).
trans_helper = s2s.GreedyEmbeddingHelper(
    embedding=t_embedding,
    start_tokens=trans_tokens,
    end_token=target2idx.get('<eos>'),
)
trans_decoder = s2s.BasicDecoder(
    cell=dec_attn_cell,
    helper=trans_helper,
    initial_state=dec_attn_init_state,
    output_layer=output_layer,
)
trans_outputs, _, _ = s2s.dynamic_decode(
    decoder=trans_decoder,
    impute_finished=True,
    maximum_iterations=t_max_len * 2,
)

In [12]:
## loss
# Per-position weights built from the target lengths and used by the sequence
# loss, so positions beyond a sentence's length are weighted differently from
# real tokens.
masking = tf.sequence_mask(lengths=t_mb_len, maxlen=t_max_len, dtype=tf.float32)
loss = s2s.sequence_loss(
    logits=tr_outputs.rnn_output,
    targets=t_mb_output,
    weights=masking,
)

In [13]:
## training
# Adam optimizer with learning rate `lr`; training_op performs one update step
# that minimises `loss`.
opt = tf.train.AdamOptimizer(learning_rate=lr)
training_op = opt.minimize(loss=loss)

### Training

In [14]:
# Session configured with the GPU option allow_growth enabled.
sess_config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(allow_growth=True),
)
sess = tf.Session(config=sess_config)
# Initialise every global variable in the graph before training starts.
sess.run(tf.global_variables_initializer())

In [15]:
# Mean training loss of every epoch, in order.
tr_loss_hist = []

for epoch in range(epochs):
    # Re-initialise the dataset iterator at the start of each epoch.
    sess.run(iterator.initializer)
    avg_tr_loss = 0
    tr_step = 0

    # Run one update per mini-batch until the iterator is exhausted.
    while True:
        try:
            _, tr_loss = sess.run(fetches=[training_op, loss])
        except tf.errors.OutOfRangeError:
            break
        avg_tr_loss += tr_loss
        tr_step += 1

    # Turn the accumulated loss into the mean over this epoch's mini-batches.
    avg_tr_loss /= tr_step
    tr_loss_hist.append(avg_tr_loss)

    # Report progress every 10th epoch.
    if not (epoch + 1) % 10:
        print('epoch : {:3}, tr_loss : {:.3f}'.format(epoch + 1, avg_tr_loss))

epoch :  10, tr_loss : 1.122
epoch :  20, tr_loss : 0.416
epoch :  30, tr_loss : 0.317
epoch :  40, tr_loss : 0.265
epoch :  50, tr_loss : 0.322
epoch :  60, tr_loss : 0.210
epoch :  70, tr_loss : 0.114
epoch :  80, tr_loss : 0.117
epoch :  90, tr_loss : 0.082
epoch : 100, tr_loss : 0.006


### Accuracy

In [16]:
# Translate all source sentences in one run: the iterator's source tensors are
# overridden through feed_dict with the full preprocessed arrays.
t_output_hat = sess.run(
    trans_outputs.sample_id,
    feed_dict={s_mb_len: s_len, s_mb_input: s_input},
)

In [17]:
# Map every predicted id back to its target token.
[[idx2target.get(token) for token in sentence] for sentence in t_output_hat]

[['나는', '배가', '고프다', '<eos>', '<pad>'],
 ['텐서플로우는', '매우', '어렵다', '<eos>', '<pad>'],
 ['텐서플로우는', '딥러닝을', '위한', '프레임워크이다', '<eos>'],
 ['텐서플로우는', '매우', '빠르게', '변화한다', '<eos>']]