![Pointer network architecture](img/pointer_net.png)

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Hyperparameters shared by preprocessing, the model and the training loop.
# preprocess_data() adds the 'src_idx2char' / 'src_char2idx' vocab maps at runtime.
PARAMS = dict(
    max_len=15,       # padded sequence length; also the number of decoder steps
    embed_dims=15,    # width of the character embedding table
    rnn_size=50,      # GRU units; also the size of the attention projections
    clip_norm=5.0,    # global-norm threshold used when clipping gradients
    batch_size=128,   # examples per training batch
    n_epochs=100,     # number of passes over the training data
)

In [3]:
def read_data(path):
    """Return the full contents of the UTF-8 text file at ``path`` as one string."""
    with open(path, mode='r', encoding='utf-8') as text_file:
        contents = text_file.read()
    return contents

    
def build_map(data):
    """Build index<->character lookup tables for a newline-separated corpus.

    The four special tokens always take ids 0-3 (<PAD>, <GO>, <EOS>, <UNK>);
    every distinct non-newline character of ``data`` follows in sorted order.

    Returns:
        (idx2char, char2idx): two dicts that are inverses of each other.
    """
    special_tokens = ['<PAD>', '<GO>', '<EOS>', '<UNK>']
    # Newlines only separate lines, so they are not part of the vocabulary.
    alphabet = sorted(set(data) - {'\n'})
    idx2char = dict(enumerate(special_tokens + alphabet))
    char2idx = {}
    for position, token in idx2char.items():
        char2idx[token] = position
    return idx2char, char2idx


def _encode_line(line, char2idx, max_len):
    """Convert one text line to ids, append <EOS>, and right-pad with <PAD>.

    Args:
        line: the raw characters of a single example.
        char2idx: character -> id mapping that contains the special tokens.
        max_len: fixed length every encoded row must have.

    Returns:
        (padded_ids, length) where ``length`` counts <EOS> but not the padding.

    Raises:
        ValueError: if the line plus <EOS> does not fit in ``max_len``; without
            this check the row would silently be longer than the others.
    """
    unk_id = char2idx['<UNK>']
    ids = [char2idx.get(c, unk_id) for c in line] + [char2idx['<EOS>']]
    length = len(ids)
    if length > max_len:
        raise ValueError(
            'line of {} characters (+<EOS>) exceeds max_len={}: {!r}'.format(
                len(line), max_len, line))
    return ids + [char2idx['<PAD>']] * (max_len - length), length


def preprocess_data():
    """Load the source/target letter files and encode them for the pointer net.

    Side effect: stores the source vocabulary in ``PARAMS['src_idx2char']`` and
    ``PARAMS['src_char2idx']``.

    Each target token is rewritten as the position of its first occurrence in
    the (padded) source row, so labels are pointers into the input rather than
    vocabulary ids.

    Returns:
        Tuple of numpy arrays
        ``(src_indices, tgt_indices, src_seq_lens, tgt_seq_lens)``.

    Raises:
        ValueError: if a line does not fit in ``PARAMS['max_len']``, or if a
            target token (including padding) does not occur in its source row.
    """
    source = read_data('../temp/letters_source.txt')
    target = read_data('../temp/letters_target.txt')

    PARAMS['src_idx2char'], PARAMS['src_char2idx'] = build_map(source)
    char2idx = PARAMS['src_char2idx']
    max_len = PARAMS['max_len']

    src_indices, tgt_indices = [], []
    src_seq_lens, tgt_seq_lens = [], []

    for src_line, tgt_line in zip(source.split('\n'), target.split('\n')):
        src_idx, src_len = _encode_line(src_line, char2idx, max_len)
        tgt_ids, tgt_len = _encode_line(tgt_line, char2idx, max_len)

        # Pointer labels: index of the first matching token in the source row.
        tgt_idx = [src_idx.index(t) for t in tgt_ids]

        src_indices.append(src_idx)
        tgt_indices.append(tgt_idx)
        src_seq_lens.append(src_len)
        tgt_seq_lens.append(tgt_len)

    return (np.array(src_indices),
            np.array(tgt_indices),
            np.array(src_seq_lens),
            np.array(tgt_seq_lens))

In [4]:
def clip_grads(loss):
    """Return (gradient, variable) pairs for ``loss`` with global-norm clipping.

    Gradients are taken with respect to every trainable variable and clipped
    jointly to ``PARAMS['clip_norm']``; the result is suitable for
    ``Optimizer.apply_gradients``.
    """
    train_vars = tf.trainable_variables()
    raw_grads = tf.gradients(loss, train_vars)
    # The second return value (the pre-clip global norm) is not needed here.
    clipped = tf.clip_by_global_norm(raw_grads, PARAMS['clip_norm'])[0]
    return zip(clipped, train_vars)


def rnn_cell():
    """Create a GRU cell with ``PARAMS['rnn_size']`` units and an orthogonal kernel initializer."""
    num_units = PARAMS['rnn_size']
    orthogonal_init = tf.orthogonal_initializer()
    return tf.nn.rnn_cell.GRUCell(num_units, kernel_initializer=orthogonal_init)


def point(idx, batch_sz, enc_inp):
    """Select, for every batch row ``i``, the slice ``enc_inp[i, idx[i]]``.

    Args:
        idx: per-example positions (int64, one per batch row as produced by
            ``tf.argmax`` in ``forward``).
        batch_sz: number of rows in the batch.
        enc_inp: tensor to gather from; its first two axes are indexed by
            (batch row, position).
    """
    batch_rows = tf.to_int64(tf.expand_dims(tf.range(batch_sz), 1))
    picked_cols = tf.expand_dims(idx, 1)
    # Each row of `coords` is a (batch_row, position) pair for gather_nd.
    coords = tf.concat((batch_rows, picked_cols), 1)
    return tf.gather_nd(enc_inp, coords)


def attention(query, keys, masks, w1, w2, v):
    """Compute additive-attention pointer scores of ``query`` over ``keys``.

    The score for each key position is ``sum(v * tanh(w1(query) + w2(keys)))``
    over the last axis. Padded positions (``masks == 0``) receive a very large
    negative score so they can never win an argmax and get ~zero probability
    under a softmax.

    Args:
        query: decoder state; a length-1 axis is inserted at position 1 so it
            broadcasts against every key position.
        keys: encoder outputs, one vector per source position.
        masks: float tensor, 1.0 for real source positions and 0.0 for padding,
            aligned with the first two axes of ``keys``.
        w1, w2: callables (dense layers) projecting query and keys.
        v: scoring vector applied to the tanh activations.

    Returns:
        Unnormalised scores (logits) with one value per source position.
    """
    query = tf.expand_dims(query, 1)
    align = v * tf.tanh(w1(query) + w2(keys))
    align = tf.reduce_sum(align, [2])
    # Mask in logit space: multiplying by 0 would leave padding with score 0,
    # which still beats any real position whose score is negative.
    align += (1.0 - masks) * -1e9
    return align


def forward(features):
    """Build the pointer-network graph and return per-step pointer scores.

    Args:
        features: dict with 'src_idx' (padded source token ids, batch on the
            first axis) and 'src_seq_lens' (unpadded length of each row).

    Returns:
        Tensor stacking the attention scores of all ``PARAMS['max_len']``
        decoder steps along axis 1.
    """
    src_tokens = features['src_idx']
    src_lengths = features['src_seq_lens']
    batch_sz = tf.shape(src_tokens)[0]
    # sign() is 0 for id 0 (<PAD>) and 1 for every other (positive) token id.
    masks = tf.to_float(tf.sign(src_tokens))

    with tf.variable_scope('Encoder'):
        vocab_size = len(PARAMS['src_char2idx'])
        embedding = tf.get_variable('lookup_table', [vocab_size, PARAMS['embed_dims']])
        enc_inp = tf.nn.embedding_lookup(embedding, src_tokens)
        enc_rnn_out, enc_rnn_state = tf.nn.dynamic_rnn(
            rnn_cell(), enc_inp, src_lengths, dtype=tf.float32)

    with tf.variable_scope('Decoder'):
        step_scores = []

        dec_cell = rnn_cell()
        w1 = tf.layers.Dense(PARAMS['rnn_size'], use_bias=False)
        w2 = tf.layers.Dense(PARAMS['rnn_size'], use_bias=False)
        v = tf.get_variable('v', [PARAMS['rnn_size']])

        # The decoder starts from the encoder's final state and a <GO> embedding.
        dec_state = enc_rnn_state
        go_ids = tf.fill([batch_sz], PARAMS['src_char2idx']['<GO>'])
        dec_inp = tf.nn.embedding_lookup(embedding, go_ids)

        for _ in range(PARAMS['max_len']):
            _, dec_state = dec_cell(dec_inp, dec_state)
            scores = attention(dec_state, enc_rnn_out, masks, w1, w2, v)
            step_scores.append(scores)
            # Feed back the embedding of the source position the model pointed at
            # (its own prediction, in every mode).
            pointed = tf.argmax(scores, -1)
            dec_inp = point(pointed, batch_sz, enc_inp)

    return tf.stack(step_scores, 1)

In [5]:
def model_fn(features, labels, mode, params):
    """Estimator model function for the pointer network.

    Args:
        features: dict consumed by ``forward`` ('src_idx', 'src_seq_lens').
        labels: dict with 'tgt_idx' (pointer targets) and 'tgt_seq_lens'
            (unpadded target lengths); unused in PREDICT mode.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: required by the Estimator signature; not used (hyperparameters
            are read from the module-level ``PARAMS``).

    Returns:
        ``tf.estimator.EstimatorSpec`` for the requested mode: argmax pointers
        for PREDICT, the loss for EVAL, loss plus train op for TRAIN.
    """
    logits = forward(features)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=tf.argmax(logits, -1))

    # Loss is shared by TRAIN and EVAL; padded target steps get zero weight.
    loss_op = tf.contrib.seq2seq.sequence_loss(
        logits = logits,
        targets = labels['tgt_idx'],
        weights = tf.sequence_mask(labels['tgt_seq_lens'], PARAMS['max_len'], dtype=tf.float32))

    if mode == tf.estimator.ModeKeys.EVAL:
        # Previously EVAL fell through and returned None, breaking evaluate().
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss_op)

    train_op = tf.train.AdamOptimizer().apply_gradients(
        clip_grads(loss_op),
        global_step = tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss_op, train_op=train_op)

In [6]:
def infe_inps(str_li):
    """Encode raw strings into the feature dict expected at prediction time.

    Each string is mapped to vocabulary ids (3 = <UNK> for unseen characters),
    terminated with 2 (<EOS>), and post-padded to ``PARAMS['max_len']``.

    Args:
        str_li: list of input strings.

    Returns:
        dict with 'src_idx' (padded id matrix) and 'src_seq_lens' (length of
        each encoded string including <EOS>).
    """
    # NOTE(review): strings longer than max_len - 1 are truncated by
    # pad_sequences while 'src_seq_lens' still reports the untruncated length.
    char2idx = PARAMS['src_char2idx']
    x_inps = [[char2idx.get(c, 3) for c in s] + [2] for s in str_li]
    x_seq_lens = [len(x) for x in x_inps]
    return {'src_idx': tf.keras.preprocessing.sequence.pad_sequences(x_inps, PARAMS['max_len'],
                                                                     padding='post'),
            'src_seq_lens': np.array(x_seq_lens)}


def demo(xs, preds):
    """Print each input string next to the sequence the model's pointers select.

    Args:
        xs: the raw input strings.
        preds: per-string arrays of predicted source positions, in the same
            order as ``xs``.
    """
    for text, pointers in zip(xs, preds):
        print('\nIN: {}'.format(text))
        # Re-encode the string (with <EOS>) so pointers can index into it.
        token_ids = np.array([PARAMS['src_char2idx'].get(ch, 3) for ch in text] + [2])
        selected = token_ids[pointers[:len(token_ids)]]
        decoded = [PARAMS['src_idx2char'][i] for i in selected]
        print('OUT: {}'.format(' '.join(decoded)))
    

def main():
    """Preprocess the data, train the estimator, then print predictions for sample strings."""
    src_idx, tgt_idx, src_seq_lens, tgt_seq_lens = preprocess_data()

    test_strs = ['apple', 'common', 'zhedong']

    estimator = tf.estimator.Estimator(model_fn)

    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x = {'src_idx': src_idx, 'src_seq_lens': src_seq_lens},
        y = {'tgt_idx': tgt_idx, 'tgt_seq_lens': tgt_seq_lens},
        batch_size = PARAMS['batch_size'],
        num_epochs = PARAMS['n_epochs'],
        shuffle = True)
    estimator.train(train_input_fn)

    # Keep the input order so predictions line up with `test_strs`.
    predict_input_fn = tf.estimator.inputs.numpy_input_fn(
        x = infe_inps(test_strs),
        shuffle = False)
    preds = list(estimator.predict(predict_input_fn))

    demo(test_strs, preds)


# Entry point: run main() when this cell/script is executed directly.
if __name__ == '__main__':
    main()

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp31qmitb7', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x10f7fb080>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 1 into /var/folders/sx/

INFO:tensorflow:loss = 0.028184542, step = 6901 (4.488 sec)
INFO:tensorflow:global_step/sec: 22.2698
INFO:tensorflow:loss = 0.016577683, step = 7001 (4.491 sec)
INFO:tensorflow:global_step/sec: 22.6684
INFO:tensorflow:loss = 0.03540818, step = 7101 (4.411 sec)
INFO:tensorflow:global_step/sec: 21.9794
INFO:tensorflow:loss = 0.01682376, step = 7201 (4.550 sec)
INFO:tensorflow:global_step/sec: 22.8593
INFO:tensorflow:loss = 0.018445082, step = 7301 (4.375 sec)
INFO:tensorflow:global_step/sec: 22.6628
INFO:tensorflow:loss = 0.015392962, step = 7401 (4.412 sec)
INFO:tensorflow:global_step/sec: 22.1001
INFO:tensorflow:loss = 0.013443581, step = 7501 (4.525 sec)
INFO:tensorflow:global_step/sec: 22.7805
INFO:tensorflow:loss = 0.013857702, step = 7601 (4.390 sec)
INFO:tensorflow:global_step/sec: 22.4124
INFO:tensorflow:loss = 0.041313823, step = 7701 (4.462 sec)
INFO:tensorflow:global_step/sec: 23.8463
INFO:tensorflow:loss = 0.01792423, step = 7801 (4.194 sec)
INFO:tensorflow:Saving checkpoints