In [1]:
# https://smerity.com/articles/2016/google_nmt_arch.html

In [2]:
# -*- coding: utf-8 -*-
import math

import numpy as np
import tensorflow as tf

  from ._conv import register_converters as _register_converters


In [3]:
# Special vocabulary symbols. Their integer ids are their positions in
# extra_tokens, i.e. pad=0, start=1, end=2.
TOKEN_PAD="<p>"
TOKEN_START="<s>"
TOKEN_END="</s>"

extra_tokens = [TOKEN_PAD, TOKEN_START, TOKEN_END]
pad_token = extra_tokens.index(TOKEN_PAD)
start_token = extra_tokens.index(TOKEN_START)
end_token = extra_tokens.index(TOKEN_END)

# Vocabulary sizes (number of rows in each embedding matrix).
src_vocab_size = 200
tgt_vocab_size = 200

# Model dimensions: embedding width, LSTM width and the layer count used when
# stacking encoder/decoder cells.
embedding_size = 128
hidden_units = 128
layer_num = 8

beam_width = 3
# The batch size is baked into every placeholder shape below, so each fed
# batch must contain exactly this many rows.
batch_size = 12

# Offsets subtracted from layer_num to decide how many stacked cells are built
# with the default (residual) setting.
encoder_residual_start_idx = 3
decoder_residual_start_idx = 2

# Graph-construction switches: training graph vs. inference graph, and greedy
# vs. beam-search decoding inside the inference graph.
is_train_mode = True
use_beamsearch_decode = False

# Placeholder
'''
inputs: [batch_size, max_time_steps]
inputs_length: [batch_size]
'''
encoder_inputs = tf.placeholder(dtype=tf.int32, shape=(batch_size, None), name='encoder_inputs')
encoder_inputs_length = tf.placeholder(dtype=tf.int32, shape=(batch_size,), name='encoder_inputs_length')

decoder_inputs = tf.placeholder(dtype=tf.int32, shape=(batch_size, None), name='decoder_inputs')
decoder_inputs_length = tf.placeholder(dtype=tf.int32, shape=(batch_size,), name='decoder_inputs_length')
# One extra step per sequence, matching the extra column added to
# decoder_targets below.
decoder_inputs_length_train = decoder_inputs_length + 1

# [batch_size, 1] columns filled with the start / end token id.
decoder_start_token = tf.ones(shape=[batch_size, 1], dtype=tf.int32) * start_token
decoder_end_token = tf.ones(shape=[batch_size, 1], dtype=tf.int32) * end_token  
# Targets are the decoder inputs with one end-token column appended on the
# time axis.
# NOTE(review): the end token is appended after the last column of
# decoder_inputs, so for rows that are padded (presumably with pad_token --
# confirm against the data feeder) it lands after the padding rather than
# directly after the last real token.
decoder_targets = tf.concat([decoder_inputs, decoder_end_token], axis=1)

# Scalar keep probability; handed to DropoutWrapper(output_keep_prob=...) in
# build_single_cell.
keep_probability = tf.placeholder(dtype=tf.float32, shape=[], name='keep_probability')

# Embedding
with tf.variable_scope('embedding_layer'):
    
    # Both embedding matrices are initialised uniformly in [-sqrt(3), sqrt(3)].
    sqrt3 = math.sqrt(3)
    initializer = tf.random_uniform_initializer(-sqrt3, sqrt3, dtype=tf.float32)
    
    # Source-side lookup table and the embedded encoder inputs.
    embedding_encoder = tf.get_variable(name="embedding_encoder",
                                        shape=[src_vocab_size, embedding_size], 
                                        dtype=tf.float32,
                                        initializer=initializer,
                                        trainable=True)
    encoder_embeddding_inputs = tf.nn.embedding_lookup(params=embedding_encoder,
                                                       ids=encoder_inputs)
    
    # Target-side lookup table and the embedded decoder inputs.
    embedding_decoder = tf.get_variable(name="embedding_decoder",
                                        shape=[tgt_vocab_size, embedding_size], 
                                        dtype=tf.float32,
                                        initializer=initializer,
                                        trainable=True)
    decoder_embeddding_inputs = tf.nn.embedding_lookup(params=embedding_decoder,
                                                       ids=decoder_inputs)



def print_tuple_state(tuple_state):
    """Print the static shape of every per-layer state and return the input.

    Args:
        tuple_state: Sized iterable with one entry per layer. An entry is
            either an indexable container with more than one item, whose
            first two items expose ``get_shape()`` (printed as ``c`` and
            ``h``), or a single object exposing ``get_shape()``.

    Returns:
        ``tuple_state`` itself, unchanged, so the call can be used inline.
    """
    print('len(layer_num): ', len(tuple_state))
    for state in tuple_state:
        # A single tensor-like state may not support len() at all; treat that
        # as "not a (c, h) pair" instead of letting the TypeError escape.
        try:
            is_pair = len(state) > 1
        except TypeError:
            is_pair = False

        if is_pair:
            print('c:', state[0].get_shape())
            print('h:', state[1].get_shape())
        else:
            print(state.get_shape())
    print('\n')
    return tuple_state
    
def build_single_cell(hidden_units, keep_probability, use_residual=True):
    """Create one LSTM layer with output dropout and an optional residual link.

    Args:
        hidden_units: Number of units of the BasicLSTMCell.
        keep_probability: Value passed as ``output_keep_prob`` to the
            DropoutWrapper.
        use_residual: When true, the dropout-wrapped cell is additionally
            wrapped in a ResidualWrapper.

    Returns:
        The wrapped RNN cell.
    """
    lstm = tf.contrib.rnn.BasicLSTMCell(hidden_units)
    dropped = tf.contrib.rnn.DropoutWrapper(lstm,
                                            output_keep_prob=keep_probability,
                                            dtype=tf.float32)
    if not use_residual:
        return dropped
    return tf.contrib.rnn.ResidualWrapper(dropped)

def attn_decoder_input_fn(inputs, attention):
    """Merge the step input with the attention vector for an AttentionWrapper.

    The two tensors are concatenated on the last axis and projected by a
    dense layer named ``attn_decoder_input`` to ``hidden_units`` (module-level
    setting). The original author notes this is essential when
    ``use_residual=True`` -- presumably because the residual sum needs the
    cell input and output to have the same width.

    Args:
        inputs: Input of the current decoding step.
        attention: Attention tensor supplied by the AttentionWrapper.

    Returns:
        The projected tensor that is fed to the wrapped cell.
    """
    merged = tf.concat([inputs, attention], axis=-1)
    return tf.layers.dense(merged, hidden_units, name='attn_decoder_input')
    
# Projection layers: input_layer maps embedded tokens to hidden_units (it is
# applied to both encoder and decoder inputs), output_layer maps decoder
# outputs to target-vocabulary logits.
with tf.variable_scope('projection_layer'):
    input_layer = tf.layers.Dense(
        hidden_units, dtype=tf.float32, name='input_projection')
    output_layer = tf.layers.Dense(
        tgt_vocab_size, dtype=tf.float32, name='output_projection')

# Encoder: one bidirectional LSTM layer followed by a unidirectional stack.
with tf.variable_scope('encoder'):

    encoder_embeddding_inputs = input_layer(encoder_embeddding_inputs)

    # Bottom layer reads the sequence in both directions.
    forward_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_units)
    backward_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_units)

    bi_outputs, bi_state = tf.nn.bidirectional_dynamic_rnn(
        forward_cell,
        backward_cell,
        encoder_embeddding_inputs,
        sequence_length=encoder_inputs_length,
        dtype=tf.float32,
        time_major=False)
    # Forward and backward outputs are joined on the feature axis.
    bi_encoder_outputs = tf.concat(bi_outputs, -1)

    # Stack on top of the bidirectional layer: one cell without a residual
    # link, then (layer_num - encoder_residual_start_idx) cells built with the
    # default use_residual=True.
    encoder_cell_list = [
        build_single_cell(hidden_units, keep_probability, use_residual=False)
    ]
    encoder_cell_list += [
        build_single_cell(hidden_units, keep_probability)
        for _ in range(layer_num - encoder_residual_start_idx)
    ]

    encoder_cell = tf.contrib.rnn.MultiRNNCell(encoder_cell_list)

    encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
        cell=encoder_cell,
        inputs=bi_encoder_outputs,
        sequence_length=encoder_inputs_length,
        dtype=tf.float32,
        time_major=False)


# Decoder: a stack of attention-wrapped LSTM cells, driven either by teacher
# forcing (training) or by greedy / beam-search decoding (inference).
with tf.variable_scope('decoder'):

    # Beam search tracks beam_width hypotheses per example, so everything the
    # attention mechanism reads must be tiled before the mechanism is built.
    if not is_train_mode and use_beamsearch_decode:
        encoder_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
        # Tiled for consistency only: the decoder below starts from its zero
        # state and does not consume encoder_final_state.
        encoder_final_state = tf.contrib.framework.nest.map_structure(
            lambda state: tf.contrib.seq2seq.tile_batch(state, beam_width),
            encoder_final_state)
        encoder_inputs_length = tf.contrib.seq2seq.tile_batch(encoder_inputs_length, multiplier=beam_width)
        beam_batch_size = batch_size * beam_width
    else:
        beam_batch_size = batch_size

    # NOTE(review): the first decoder_residual_start_idx cells are also built
    # with use_residual=True, so every decoder layer is residual and the
    # start index currently has no effect. The encoder builds its pre-residual
    # cell with use_residual=False -- confirm which behaviour is intended.
    decoder_cell_list = [build_single_cell(hidden_units, keep_probability, use_residual=True)
                         for _ in range(decoder_residual_start_idx)]
    decoder_cell_list += [build_single_cell(hidden_units, keep_probability)
                          for _ in range(layer_num - decoder_residual_start_idx)]

    attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units=hidden_units,
                                                            memory=encoder_outputs,
                                                            memory_sequence_length=encoder_inputs_length)

    # Every layer gets its own AttentionWrapper around the shared mechanism.
    decoder_cell_list = [
        tf.contrib.seq2seq.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=hidden_units,
            cell_input_fn=attn_decoder_input_fn,
            alignment_history=False,
            name='Attention_Wrapper')
        for cell in decoder_cell_list
    ]

    decoder_cell = tf.contrib.rnn.MultiRNNCell(decoder_cell_list)

    decoder_initial_state = decoder_cell.zero_state(batch_size=beam_batch_size, dtype=tf.float32)
    decoder_initial_state = tuple(decoder_initial_state)

    print('decoder_cell:', decoder_cell)
    print('decoder_initial_state(AttentionWrapperState, cell_state, c or h)')
    print('decoder_initial_state length(layer_num): ', len(decoder_initial_state))
    print('decoder_initial_state[0][0][0].get_shape(): ', decoder_initial_state[0][0][0].get_shape())
    print('decoder_initial_state[0][0][1].get_shape(): ', decoder_initial_state[0][0][1].get_shape())

    # train
    if is_train_mode:
        # Teacher forcing: the token fed at step t is the target of step t-1,
        # so the inputs are the targets shifted right behind a start token.
        # This also makes inputs, targets and masks agree on
        # decoder_inputs_length + 1 time steps; without the start token the
        # logits were one step shorter than decoder_targets.
        decoder_inputs_train = tf.concat([decoder_start_token, decoder_inputs], axis=1)
        decoder_embeddding_inputs = input_layer(
            tf.nn.embedding_lookup(params=embedding_decoder, ids=decoder_inputs_train))

        training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_embeddding_inputs,
                                                            sequence_length=decoder_inputs_length_train,
                                                            time_major=False,
                                                            name='training_helper')

        training_decoder = tf.contrib.seq2seq.BasicDecoder(cell=decoder_cell,
                                                           helper=training_helper,
                                                           initial_state=decoder_initial_state,
                                                           output_layer=output_layer)

        max_decoder_length = tf.reduce_max(decoder_inputs_length_train)
        decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=training_decoder,
                                                                  output_time_major=False,
                                                                  impute_finished=True,
                                                                  maximum_iterations=max_decoder_length)

        # NOTE(review): logits span max(decoder_inputs_length) + 1 steps while
        # decoder_targets spans (padded width + 1); this assumes each batch is
        # padded exactly to its longest sequence -- confirm against the feeder.
        crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=decoder_targets,
                                                                  logits=decoder_outputs.rnn_output)
        # Zero out the loss on steps past each sequence's own length.
        masks = tf.sequence_mask(lengths=decoder_inputs_length_train,
                                 maxlen=max_decoder_length,
                                 dtype=tf.float32,
                                 name='masks')
        loss = (tf.reduce_sum(crossent * masks) / batch_size)
        tf.summary.scalar('loss', loss)

    # inference
    else:
        start_tokens = tf.fill([batch_size], start_token)

        def embed_and_project(ids):
            """Embed predicted ids exactly as the training inputs are embedded."""
            # Training feeds input_layer(embedding) to the cells; passing the
            # raw embedding matrix here would feed them unprojected vectors.
            return input_layer(tf.nn.embedding_lookup(embedding_decoder, ids))

        if not use_beamsearch_decode:
            decoding_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(start_tokens=start_tokens,
                                                                       end_token=end_token,
                                                                       embedding=embed_and_project)

            inference_decoder = tf.contrib.seq2seq.BasicDecoder(cell=decoder_cell,
                                                                helper=decoding_helper,
                                                                initial_state=decoder_initial_state,
                                                                output_layer=output_layer)
        else:
            inference_decoder = tf.contrib.seq2seq.BeamSearchDecoder(cell=decoder_cell,
                                                                     embedding=embed_and_project,
                                                                     start_tokens=start_tokens,
                                                                     end_token=end_token,
                                                                     initial_state=decoder_initial_state,
                                                                     beam_width=beam_width,
                                                                     output_layer=output_layer,)

        # NOTE(review): the step limit is derived from the decoder_inputs_length
        # placeholder, so that placeholder must be fed at inference time too.
        max_decoder_length = tf.reduce_max(decoder_inputs_length_train)
        decoder_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=inference_decoder,
                                                                  output_time_major=False,
                                                                  maximum_iterations=max_decoder_length)

        # Give both decoding modes a trailing axis: greedy yields one sequence
        # per example, beam search already yields one per beam.
        if not use_beamsearch_decode:
            decoder_pred_decode = tf.expand_dims(decoder_outputs.sample_id, -1)
        else:
            decoder_pred_decode = decoder_outputs.predicted_ids


decoder_cell: <tensorflow.python.ops.rnn_cell_impl.MultiRNNCell object at 0x123bdfa20>
decoder_initial_state(AttentionWrapperState, cell_state, c or h)
decoder_initial_state length(layer_num):  8
decoder_initial_state[0][0][0].get_shape():  (12, 128)
decoder_initial_state[0][0][1].get_shape():  (12, 128)


In [None]:
'''
The GNMT attention cell uses the Bahdanau style.
attention_layer_size = None => no attention layer is used (Bahdanau style)
output_attention = False => the cell output, not the attention, is used as the output (Bahdanau style)
'''
# NOTE(review): illustrative fragment, not runnable as written. The names
# `attention_cell` (used on its own right-hand side), `alignment_history`,
# `attention_architecture` and `cell_list` are not defined in the visible part
# of this file, and the indented `if` below does not follow a block header.
# It wraps a single bottom cell in an AttentionWrapper and hands it, together
# with the remaining cells, to GNMTAttentionMultiCell.
attention_cell = tf.contrib.seq2seq.AttentionWrapper(
        attention_cell,
        attention_mechanism,
        attention_layer_size=None,  # don't use attention layer.
        output_attention=False,
        alignment_history=alignment_history,
        name="attention")

    # "gnmt_v2" differs from "gnmt" only by passing use_new_attention=True.
    if attention_architecture == "gnmt":
      cell = GNMTAttentionMultiCell(
          attention_cell, cell_list)
    elif attention_architecture == "gnmt_v2":
      cell = GNMTAttentionMultiCell(
          attention_cell, cell_list, use_new_attention=True)

    
    
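# https://github.com/tensorflow/nmt/blob/master/nmt/gnmt_model.py
# Taken from the file above; used on the decoder side as the GNMT attention MultiCell.
# In GNMT only the bottom cell is wrapped in an AttentionWrapper; the remaining
# cells just copy its state and concatenate it with their input.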
class GNMTAttentionMultiCell(tf.nn.rnn_cell.MultiRNNCell):
  """A MultiCell with GNMT attention style.

  Only the bottom cell is an AttentionWrapper; its attention vector is
  concatenated to the input of every cell stacked above it.
  """

  def __init__(self, attention_cell, cells, use_new_attention=False):
    """Creates a GNMTAttentionMultiCell.

    Args:
      attention_cell: An instance of AttentionWrapper; becomes layer 0.
      cells: A list of RNNCell wrapped with AttentionInputWrapper; stacked
        above `attention_cell` in the given order.
      use_new_attention: Whether to use the attention generated from current
        step bottom layer's output. Default is False, which uses the attention
        carried in the incoming (previous-step) state.
    """
    cells = [attention_cell] + cells
    self.use_new_attention = use_new_attention
    super(GNMTAttentionMultiCell, self).__init__(cells, state_is_tuple=True)

  def __call__(self, inputs, state, scope=None):
    """Run the cell with bottom layer's attention copied to all upper layers.

    Args:
      inputs: Input of the current step, fed to the bottom attention cell.
      state: Sequence with one state per cell; `state[0]` is the state of the
        attention cell and must expose an `attention` attribute.
      scope: Optional variable scope; defaults to "multi_rnn_cell".

    Returns:
      A pair `(output, new_states)`: the output of the top cell and a tuple
      holding the new state of every cell, bottom first.

    Raises:
      ValueError: If `state` is not a sequence.
    """
    # `nest` is not imported in this file; use the copy exposed under
    # tf.contrib.framework, which the rest of the file already relies on.
    if not tf.contrib.framework.nest.is_sequence(state):
      raise ValueError(
          "Expected state to be a tuple of length %d, but received: %s"
          % (len(self.state_size), state))

    with tf.variable_scope(scope or "multi_rnn_cell"):
      new_states = []

      with tf.variable_scope("cell_0_attention"):
        attention_cell = self._cells[0]  # the AttentionWrapper cell
        attention_state = state[0]  # its state from the previous step (t-1)

        # Feeding the current input (t) and the previous attention state
        # yields the wrapper's output and its new state; the new state carries
        # the freshly computed attention in its `attention` field.
        cur_inp, new_attention_state = attention_cell(inputs, attention_state)
        new_states.append(new_attention_state)

      # Every cell stacked above the AttentionWrapper receives the output of
      # the layer below concatenated with the bottom layer's attention.
      for i in range(1, len(self._cells)):
        with tf.variable_scope("cell_%d" % i):

          cell = self._cells[i]
          cur_state = state[i]

          # Attention of the current step when requested, otherwise the one
          # carried over from the previous step.
          if self.use_new_attention:
            cur_inp = tf.concat([cur_inp, new_attention_state.attention], -1)
          else:
            cur_inp = tf.concat([cur_inp, attention_state.attention], -1)

          cur_inp, new_state = cell(cur_inp, cur_state)
          new_states.append(new_state)

    return cur_inp, tuple(new_states)
