![Seq2seq model diagram](img/seq2seq.png)

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Model / training hyperparameters. preprocess_data() later adds the
# vocabulary maps ('src_idx2char', 'src_char2idx', 'tgt_idx2char',
# 'tgt_char2idx') to this same dict.
PARAMS = dict(
    embed_dims=15,    # width of the character embedding tables
    rnn_size=50,      # GRU units per layer; also the attention size
    num_layers=1,     # stacked GRU layers in encoder and decoder
    beam_width=5,     # hypotheses kept by the beam-search decoder
    clip_norm=5.0,    # threshold for global-norm gradient clipping
    batch_size=128,   # sequences per training batch
    n_epochs=60,      # passes over the training data
)

In [3]:
def read_data(path):
    """Return the whole contents of the UTF-8 text file at ``path`` as one string."""
    with open(path, encoding='utf-8') as handle:
        text = handle.read()
    return text

    
def build_map(data):
    """Build index<->character lookup tables for the text ``data``.

    Ids 0-3 are reserved for '<PAD>', '<GO>', '<EOS>' and '<UNK>'. Every
    distinct character of ``data`` (newlines excluded) follows in sorted
    order, so the same text always produces the same mapping.

    Returns:
        (idx2char, char2idx): the two inverse dictionaries.
    """
    specials = ['<PAD>', '<GO>', '<EOS>', '<UNK>']
    # sorted(): iteration order of a set of str depends on hash randomisation
    # (PYTHONHASHSEED), which would otherwise reshuffle ids between runs and
    # make saved checkpoints incompatible with a freshly built vocabulary.
    chars = sorted(set(data) - {'\n'})
    idx2char = dict(enumerate(specials + chars))
    char2idx = {char: idx for idx, char in idx2char.items()}
    return idx2char, char2idx


def preprocess_data(source_path='../temp/letters_source.txt',
                    target_path='../temp/letters_target.txt'):
    """Load the parallel letters corpus and convert it to id sequences.

    Side effect: stores the vocabularies in PARAMS under 'src_idx2char',
    'src_char2idx', 'tgt_idx2char' and 'tgt_char2idx'.

    Args:
        source_path: text file with one source sequence per line.
        target_path: text file whose line i is the target for source line i.

    Returns:
        (src_idx, tgt_idx): parallel lists of id lists. Unknown characters map
        to '<UNK>'; every target sequence is terminated with '<EOS>'. Pairs
        whose source line is empty are skipped.

    Raises:
        ValueError: if the two files contain a different number of lines.
    """
    source = read_data(source_path)
    target = read_data(target_path)

    PARAMS['src_idx2char'], PARAMS['src_char2idx'] = build_map(source)
    PARAMS['tgt_idx2char'], PARAMS['tgt_char2idx'] = build_map(target)
    src_map = PARAMS['src_char2idx']
    tgt_map = PARAMS['tgt_char2idx']
    src_unk = src_map['<UNK>']
    tgt_unk = tgt_map['<UNK>']
    eos = tgt_map['<EOS>']

    # splitlines() does not emit a spurious '' after a trailing newline
    # (split('\n') does) and also strips '\r' from CRLF files.
    src_lines = source.splitlines()
    tgt_lines = target.splitlines()
    if len(src_lines) != len(tgt_lines):
        raise ValueError('source has {} lines but target has {}'.format(
            len(src_lines), len(tgt_lines)))

    src_idx, tgt_idx = [], []
    for src_line, tgt_line in zip(src_lines, tgt_lines):
        # An empty source gives the attention mechanism a zero-length memory,
        # which can yield NaN scores and poison the whole batch's loss.
        if not src_line:
            continue
        src_idx.append([src_map.get(char, src_unk) for char in src_line])
        tgt_idx.append([tgt_map.get(char, tgt_unk) for char in tgt_line] + [eos])

    return src_idx, tgt_idx

In [4]:
def pad_sent_batch(sent_batch):
    """Right-pad each sequence with 0 up to the longest length in the batch.

    Returns new lists (inputs are left untouched). An empty batch raises
    ValueError because there is no maximum length.
    """
    width = max(map(len, sent_batch))
    return [sent + [0] * (width - len(sent)) for sent in sent_batch]


def next_train_batch(src_idx, tgt_idx):
    """Yield (padded_src, padded_tgt) batches of PARAMS['batch_size'] rows in file order.

    Source and target sides are padded independently, each to the longest
    sequence of its own slice.
    """
    step = PARAMS['batch_size']
    for start in range(0, len(src_idx), step):
        stop = start + step
        yield pad_sent_batch(src_idx[start:stop]), pad_sent_batch(tgt_idx[start:stop])

        
def train_input_fn(src_idx, tgt_idx):
    """Estimator input_fn streaming padded int32 batches for PARAMS['n_epochs'] epochs.

    Returns the (features, labels) tensors of a one-shot iterator; both are
    rank-2 with dynamic batch and time dimensions.
    """
    def batches():
        return next_train_batch(src_idx, tgt_idx)

    types = (tf.int32, tf.int32)
    shapes = (tf.TensorShape([None, None]), tf.TensorShape([None, None]))
    dataset = tf.data.Dataset.from_generator(batches, types, shapes)
    return dataset.repeat(PARAMS['n_epochs']).make_one_shot_iterator().get_next()

In [5]:
def clip_grads(loss):
    """Differentiate ``loss`` w.r.t. every trainable variable and clip by global norm.

    Returns an iterator of (clipped_gradient, variable) pairs for
    ``Optimizer.apply_gradients``; the threshold is PARAMS['clip_norm'].
    """
    params = tf.trainable_variables()
    clipped, _global_norm = tf.clip_by_global_norm(
        tf.gradients(loss, params), PARAMS['clip_norm'])
    return zip(clipped, params)


def rnn_cell():
    """Stack PARAMS['num_layers'] GRU cells (PARAMS['rnn_size'] units each) in a MultiRNNCell.

    Every GRU kernel uses an orthogonal initializer.
    """
    layers = []
    for _ in range(PARAMS['num_layers']):
        layers.append(tf.nn.rnn_cell.GRUCell(
            PARAMS['rnn_size'], kernel_initializer=tf.orthogonal_initializer()))
    return tf.nn.rnn_cell.MultiRNNCell(layers)


def dec_cell(enc_out, enc_seq_len):
    """Wrap a fresh rnn_cell() stack with Bahdanau attention over ``enc_out``.

    Args:
        enc_out: encoder outputs used as the attention memory.
        enc_seq_len: per-row valid lengths of the memory (memory_sequence_length).
    """
    units = PARAMS['rnn_size']
    mechanism = tf.contrib.seq2seq.BahdanauAttention(
        num_units=units, memory=enc_out, memory_sequence_length=enc_seq_len)
    return tf.contrib.seq2seq.AttentionWrapper(
        cell=rnn_cell(), attention_mechanism=mechanism, attention_layer_size=units)


def dec_input(labels):
    """Shift target ids one step right, prepending <GO> to every row.

    The last column of ``labels`` is dropped, so the result keeps its width.
    """
    go_id = PARAMS['tgt_char2idx']['<GO>']
    batch_size = tf.shape(labels)[0]
    go_column = tf.to_int32(tf.fill([batch_size, 1], go_id))
    shifted = labels[:, :-1]
    return tf.concat([go_column, shifted], axis=1)


def _decode_train(enc_out, enc_state, enc_seq_len, labels, batch_sz, embedding, output_proj):
    """Teacher-forced decoding of ``labels``; returns the projected per-step logits."""
    dec_seq_len = tf.count_nonzero(labels, 1, dtype=tf.int32)
    cell = dec_cell(enc_out, enc_seq_len)
    # Attention state starts at zero; only the wrapped RNN state is seeded from the encoder.
    init_state = cell.zero_state(batch_sz, tf.float32).clone(cell_state=enc_state)
    helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=tf.nn.embedding_lookup(embedding, dec_input(labels)),
        sequence_length=dec_seq_len)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=cell,
        helper=helper,
        initial_state=init_state,
        output_layer=output_proj)
    decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder=decoder,
        maximum_iterations=tf.reduce_max(dec_seq_len))
    return decoder_output.rnn_output


def _decode_infer(enc_out, enc_state, enc_seq_len, batch_sz, embedding, output_proj):
    """Beam-search decoding; returns the predicted ids of the top-ranked beam."""
    beam_width = PARAMS['beam_width']
    # The beam decoder runs batch_sz * beam_width hypotheses, so every
    # encoder-side tensor is repeated beam_width times along the batch axis.
    enc_out_t = tf.contrib.seq2seq.tile_batch(enc_out, beam_width)
    enc_state_t = tf.contrib.seq2seq.tile_batch(enc_state, beam_width)
    enc_seq_len_t = tf.contrib.seq2seq.tile_batch(enc_seq_len, beam_width)

    cell = dec_cell(enc_out_t, enc_seq_len_t)
    init_state = cell.zero_state(batch_sz * beam_width, tf.float32).clone(
        cell_state=enc_state_t)
    decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        cell=cell,
        embedding=embedding,
        start_tokens=tf.tile(tf.constant([PARAMS['tgt_char2idx']['<GO>']], tf.int32),
                             [batch_sz]),
        end_token=PARAMS['tgt_char2idx']['<EOS>'],
        initial_state=init_state,
        beam_width=beam_width,
        output_layer=output_proj)
    # Without a cap, decoding only stops once every beam has emitted <EOS>,
    # which an undertrained model may never do. Twice the longest source is
    # an assumed generous bound -- TODO tune if targets can be much longer.
    decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder=decoder,
        maximum_iterations=2 * tf.reduce_max(enc_seq_len))
    return decoder_output.predicted_ids[:, :, 0]


def forward(inputs, labels, reuse, is_training):
    """Build the encoder plus either the training or the beam-search decoder.

    Args:
        inputs: zero-padded source ids, shape [batch, src_time].
        labels: zero-padded target ids (each ending in <EOS>); only read when
            ``is_training`` is True, so it may be None for inference.
        reuse: passed to the 'Encoder' / 'Decoder' variable scopes so a second
            call can share the weights created by the first.
        is_training: Python bool. True returns teacher-forced logits; False
            returns the top beam's predicted ids.
    """
    enc_seq_len = tf.count_nonzero(inputs, 1, dtype=tf.int32)
    batch_sz = tf.shape(inputs)[0]

    with tf.variable_scope('Encoder', reuse=reuse):
        embedding = tf.get_variable('lookup_table',
                                    [len(PARAMS['src_char2idx']), PARAMS['embed_dims']])
        x = tf.nn.embedding_lookup(embedding, inputs)
        enc_out, enc_state = tf.nn.dynamic_rnn(rnn_cell(), x, enc_seq_len, dtype=tf.float32)

    with tf.variable_scope('Decoder', reuse=reuse):
        # Decoder inputs are *target* ids, so they need a table sized to the
        # target vocabulary; the encoder table is sized to the source one.
        dec_embedding = tf.get_variable('lookup_table',
                                        [len(PARAMS['tgt_char2idx']), PARAMS['embed_dims']])
        output_proj = tf.layers.Dense(len(PARAMS['tgt_char2idx']))
        # is_training is a Python bool: select the graph in Python rather than
        # via tf.cond, which would also build the beam-tiling ops for training.
        if is_training:
            return _decode_train(enc_out, enc_state, enc_seq_len, labels,
                                 batch_sz, dec_embedding, output_proj)
        return _decode_infer(enc_out, enc_state, enc_seq_len,
                             batch_sz, dec_embedding, output_proj)

In [6]:
def model_fn(features, labels, mode, params):
    """Estimator model_fn for the attention seq2seq model.

    PREDICT builds only the beam-search graph and returns the top beam's ids.
    TRAIN and EVAL build the teacher-forced graph and report the padding-masked
    sequence cross-entropy; TRAIN adds a clipped-gradient Adam update.
    ``params`` is unused: hyperparameters come from the module-level PARAMS.
    """
    if mode == tf.estimator.ModeKeys.PREDICT:
        # forward() is always given a labels tensor; there are none when
        # predicting, so pass an unfed int32 stand-in. The beam-search path
        # does not depend on it.
        if labels is None:
            labels = tf.placeholder(tf.int32, [None, None])
        pred_ids = forward(features, labels, reuse=False, is_training=False)
        return tf.estimator.EstimatorSpec(mode, predictions=pred_ids)

    logits = forward(features, labels, reuse=False, is_training=True)
    # tf.sign(labels) is 1 for real tokens and 0 for <PAD> (id 0), which
    # removes padding positions from the loss.
    loss_op = tf.contrib.seq2seq.sequence_loss(logits=logits,
                                               targets=labels,
                                               weights=tf.to_float(tf.sign(labels)))

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss_op)

    train_op = tf.train.AdamOptimizer().apply_gradients(
        clip_grads(loss_op),
        global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss_op, train_op=train_op)

In [7]:
def infe_inps(str_li):
    """Encode raw strings as a zero-padded matrix of source-vocabulary ids.

    Characters absent from PARAMS['src_char2idx'] map to id 3 (the '<UNK>'
    slot reserved by build_map). Rows are padded at the end to the longest
    string. An empty list raises ValueError.
    """
    longest = max(len(text) for text in str_li)
    vocab = PARAMS['src_char2idx']
    encoded = []
    for text in str_li:
        encoded.append([vocab.get(ch, 3) for ch in text])
    return tf.keras.preprocessing.sequence.pad_sequences(encoded, longest, padding='post')


def demo(xs, preds, idx2char=None, eos_token='<EOS>'):
    """Print each input string next to its decoded prediction.

    Decoding stops at the first ``eos_token``, so trailing end tokens from
    finished beams are not printed. Ids missing from the vocabulary render as
    '<UNK>' instead of raising KeyError.

    Args:
        xs: the raw input strings.
        preds: one sequence of predicted ids per input.
        idx2char: id -> token map; defaults to PARAMS['tgt_idx2char'].
        eos_token: token that terminates a prediction.
    """
    if idx2char is None:
        idx2char = PARAMS['tgt_idx2char']
    for x, pred in zip(xs, preds):
        tokens = []
        for i in pred:
            token = idx2char.get(int(i), '<UNK>')
            if token == eos_token:
                break
            tokens.append(token)
        print('\nIN: {}'.format(x))
        print('OUT: {}'.format(' '.join(tokens)))
    

def main():
    """Train on the letters corpus, then print beam-search outputs for a few sample words."""
    src_idx, tgt_idx = preprocess_data()
    estimator = tf.estimator.Estimator(model_fn)

    def input_fn():
        return train_input_fn(src_idx, tgt_idx)

    estimator.train(input_fn)

    samples = ['apple', 'common', 'zhedong']
    predict_input_fn = tf.estimator.inputs.numpy_input_fn(
        x=infe_inps(samples), shuffle=False)
    preds = list(estimator.predict(predict_input_fn))

    demo(samples, preds)


# Entry point: train the model, then run the demo predictions.
if __name__ == '__main__':
    main()

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpeiu7hwdc', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x11d90ce10>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 1 into /var/folders/sx/