In [1]:
from pointer_net import PointerNetwork
import sys
import numpy as np
if int(sys.version[0]) == 2:
    from io import open


def read_data(path):
    """Return the full contents of the text file at `path`.

    The file is opened in read mode and decoded as UTF-8; the whole file
    is returned as a single string (line breaks included).
    """
    with open(path, mode='r', encoding='utf-8') as handle:
        contents = handle.read()
    return contents
# end function


def build_map(data):
    """Build index<->character lookup tables from newline-separated text.

    `data` is split on '\\n' and every distinct character found on any line
    becomes a vocabulary entry. The four special tokens come first
    (indices 0-3, in the order <GO>, <EOS>, <PAD>, <UNK>), followed by the
    characters in sorted order.

    Returns:
        (idx2char, char2idx): two dicts that are inverses of each other.
    """
    special_tokens = ['<GO>', '<EOS>', '<PAD>', '<UNK>']

    # Collect the distinct characters line by line; the newline separator
    # itself never enters the vocabulary because it is consumed by split().
    seen = set()
    for line in data.split('\n'):
        seen.update(line)

    vocabulary = special_tokens + sorted(seen)
    idx2char = dict(enumerate(vocabulary))
    char2idx = {}
    for position, symbol in enumerate(vocabulary):
        char2idx[symbol] = position
    return idx2char, char2idx
# end function


def preprocess_data(max_len,
                    source_path='temp/letters_source.txt',
                    target_path='temp/letters_target.txt'):
    """Load the source/target text files and encode them for the pointer model.

    Each source line is converted to vocabulary indices, terminated with
    <EOS> and right-padded with <PAD> to exactly `max_len` entries. Each
    target line is encoded with the *source* vocabulary in the same way and
    then rewritten as positions: every target symbol is replaced by the
    index of its first occurrence in the padded source row.

    Args:
        max_len: row width of the returned index arrays. A line may hold at
            most `max_len - 1` characters because one slot is reserved for
            <EOS>.
        source_path: file with one source sequence per line.
        target_path: file with one target sequence per line, aligned with
            `source_path` (lines are paired with zip, so surplus lines in
            the longer file are ignored).

    Returns:
        (X_indices, X_seq_len, Y_indices, Y_seq_len, X_char2idx, X_idx2char)
        where the first four are numpy arrays (one row/entry per line pair;
        the sequence lengths include the <EOS> slot) and the last two are
        the vocabulary dicts produced by build_map.

    Raises:
        ValueError: if a line does not fit into `max_len` slots, or if a
            target symbol cannot be located in its padded source row.
    """
    X_data = read_data(source_path)
    Y_data = read_data(target_path)

    X_idx2char, X_char2idx = build_map(X_data)
    print("==> Word Index Built")

    x_unk = X_char2idx['<UNK>']
    x_eos = X_char2idx['<EOS>']
    x_pad = X_char2idx['<PAD>']

    X_indices = []
    X_seq_len = []
    Y_indices = []
    Y_seq_len = []

    line_pairs = zip(X_data.split('\n'), Y_data.split('\n'))
    for line_no, (x_line, y_line) in enumerate(line_pairs, 1):
        # A negative pad count would silently yield an over-long row and a
        # ragged array further down, so reject such lines up front.
        if max(len(x_line), len(y_line)) > max_len - 1:
            raise ValueError(
                "line %d does not fit: at most %d characters are allowed "
                "for max_len=%d (one slot is reserved for <EOS>)"
                % (line_no, max_len - 1, max_len))

        x_chars = [X_char2idx.get(char, x_unk) for char in x_line]
        _x_chars = x_chars + [x_eos] + [x_pad] * (max_len - 1 - len(x_chars))

        # The target is deliberately encoded with the source vocabulary:
        # its symbols are looked up inside the source row below.
        y_chars = [X_char2idx.get(char, x_unk) for char in y_line]
        _y_chars = y_chars + [x_eos] + [x_pad] * (max_len - 1 - len(y_chars))
        try:
            # we are predicting the positions
            target = [_x_chars.index(y) for y in _y_chars]
        except ValueError:
            raise ValueError(
                "line %d: a target symbol (or its padding) has no matching "
                "entry in the padded source row" % line_no)

        X_indices.append(_x_chars)
        Y_indices.append(target)
        X_seq_len.append(len(x_chars) + 1)
        Y_seq_len.append(len(y_chars) + 1)

    X_indices = np.array(X_indices)
    Y_indices = np.array(Y_indices)
    X_seq_len = np.array(X_seq_len)
    Y_seq_len = np.array(Y_seq_len)
    print("==> Sequence Padded")

    return X_indices, X_seq_len, Y_indices, Y_seq_len, X_char2idx, X_idx2char
# end function


def train_test_split(X_indices, X_seq_len, Y_indices, Y_seq_len, BATCH_SIZE):
    """Hold out the first `BATCH_SIZE` samples as the test set.

    All four inputs are sliced the same way: entries [0, BATCH_SIZE) form
    the test portion and everything from BATCH_SIZE onwards forms the
    training portion. No shuffling is performed.

    Returns:
        ((X_train, X_train_len, Y_train, Y_train_len),
         (X_test, X_test_len, Y_test, Y_test_len))
    """
    columns = (X_indices, X_seq_len, Y_indices, Y_seq_len)
    test_part = tuple(column[:BATCH_SIZE] for column in columns)
    train_part = tuple(column[BATCH_SIZE:] for column in columns)
    return train_part, test_part
# end function


def main():
    """Prepare the data, train a PointerNetwork and run three sample inferences.

    The first 128 samples are held out and passed to `fit` as validation
    data; the remaining samples are used for training over 300 epochs.
    """
    batch_size, max_len = 128, 15

    X, X_len, Y, Y_len, char2idx, idx2char = preprocess_data(max_len)
    train_set, val_set = train_test_split(X, X_len, Y, Y_len, batch_size)

    model = PointerNetwork(max_len=max_len,
                           rnn_size=50,
                           X_word2idx=char2idx,
                           embedding_dim=15)

    # train_set unpacks to (X_train, X_train_len, Y_train, Y_train_len);
    # val_set is the matching 4-tuple for the held-out samples.
    model.fit(*train_set,
              val_data=val_set,
              batch_size=batch_size,
              n_epoch=300)

    for word in ('common', 'apple', 'zhedong'):
        model.infer(word, idx2char)
# end main


# Start training only when this file is executed as a script, so that
# importing it for its helper functions has no side effects.
if __name__ == '__main__':
    main()


==> Word Index Built
==> Sequence Padded
Epoch 1/300 | Batch 0/77 | train_loss: 2.702 | test_loss: 2.699
Epoch 1/300 | Batch 50/77 | train_loss: 2.312 | test_loss: 2.289
Epoch 2/300 | Batch 0/77 | train_loss: 2.047 | test_loss: 2.019
Epoch 2/300 | Batch 50/77 | train_loss: 1.798 | test_loss: 1.709
Epoch 3/300 | Batch 0/77 | train_loss: 1.645 | test_loss: 1.589
Epoch 3/300 | Batch 50/77 | train_loss: 1.398 | test_loss: 1.292
Epoch 4/300 | Batch 0/77 | train_loss: 1.124 | test_loss: 1.088
Epoch 4/300 | Batch 50/77 | train_loss: 0.877 | test_loss: 0.820
Epoch 5/300 | Batch 0/77 | train_loss: 0.761 | test_loss: 0.753
Epoch 5/300 | Batch 50/77 | train_loss: 0.687 | test_loss: 0.651
Epoch 6/300 | Batch 0/77 | train_loss: 0.629 | test_loss: 0.617
Epoch 6/300 | Batch 50/77 | train_loss: 0.592 | test_loss: 0.566
Epoch 7/300 | Batch 0/77 | train_loss: 0.548 | test_loss: 0.539
Epoch 7/300 | Batch 50/77 | train_loss: 0.519 | test_loss: 0.504
Epoch 8/300 | Batch 0/77 | train_loss: 0.489 | test_loss

Epoch 63/300 | Batch 50/77 | train_loss: 0.138 | test_loss: 0.156
Epoch 64/300 | Batch 0/77 | train_loss: 0.147 | test_loss: 0.152
Epoch 64/300 | Batch 50/77 | train_loss: 0.137 | test_loss: 0.154
Epoch 65/300 | Batch 0/77 | train_loss: 0.146 | test_loss: 0.153
Epoch 65/300 | Batch 50/77 | train_loss: 0.137 | test_loss: 0.157
Epoch 66/300 | Batch 0/77 | train_loss: 0.145 | test_loss: 0.152
Epoch 66/300 | Batch 50/77 | train_loss: 0.135 | test_loss: 0.151
Epoch 67/300 | Batch 0/77 | train_loss: 0.143 | test_loss: 0.148
Epoch 67/300 | Batch 50/77 | train_loss: 0.135 | test_loss: 0.154
Epoch 68/300 | Batch 0/77 | train_loss: 0.142 | test_loss: 0.150
Epoch 68/300 | Batch 50/77 | train_loss: 0.133 | test_loss: 0.152
Epoch 69/300 | Batch 0/77 | train_loss: 0.141 | test_loss: 0.148
Epoch 69/300 | Batch 50/77 | train_loss: 0.132 | test_loss: 0.151
Epoch 70/300 | Batch 0/77 | train_loss: 0.140 | test_loss: 0.147
Epoch 70/300 | Batch 50/77 | train_loss: 0.131 | test_loss: 0.148
Epoch 71/300 | Ba

Epoch 126/300 | Batch 0/77 | train_loss: 0.100 | test_loss: 0.112
Epoch 126/300 | Batch 50/77 | train_loss: 0.097 | test_loss: 0.115
Epoch 127/300 | Batch 0/77 | train_loss: 0.098 | test_loss: 0.116
Epoch 127/300 | Batch 50/77 | train_loss: 0.095 | test_loss: 0.115
Epoch 128/300 | Batch 0/77 | train_loss: 0.098 | test_loss: 0.110
Epoch 128/300 | Batch 50/77 | train_loss: 0.094 | test_loss: 0.114
Epoch 129/300 | Batch 0/77 | train_loss: 0.097 | test_loss: 0.113
Epoch 129/300 | Batch 50/77 | train_loss: 0.102 | test_loss: 0.113
Epoch 130/300 | Batch 0/77 | train_loss: 0.097 | test_loss: 0.113
Epoch 130/300 | Batch 50/77 | train_loss: 0.095 | test_loss: 0.114
Epoch 131/300 | Batch 0/77 | train_loss: 0.096 | test_loss: 0.110
Epoch 131/300 | Batch 50/77 | train_loss: 0.094 | test_loss: 0.112
Epoch 132/300 | Batch 0/77 | train_loss: 0.100 | test_loss: 0.113
Epoch 132/300 | Batch 50/77 | train_loss: 0.093 | test_loss: 0.114
Epoch 133/300 | Batch 0/77 | train_loss: 0.095 | test_loss: 0.116
Epo

Epoch 188/300 | Batch 0/77 | train_loss: 0.071 | test_loss: 0.102
Epoch 188/300 | Batch 50/77 | train_loss: 0.112 | test_loss: 0.097
Epoch 189/300 | Batch 0/77 | train_loss: 0.078 | test_loss: 0.107
Epoch 189/300 | Batch 50/77 | train_loss: 0.082 | test_loss: 0.107
Epoch 190/300 | Batch 0/77 | train_loss: 0.078 | test_loss: 0.114
Epoch 190/300 | Batch 50/77 | train_loss: 0.080 | test_loss: 0.105
Epoch 191/300 | Batch 0/77 | train_loss: 0.073 | test_loss: 0.103
Epoch 191/300 | Batch 50/77 | train_loss: 0.076 | test_loss: 0.118
Epoch 192/300 | Batch 0/77 | train_loss: 0.074 | test_loss: 0.104
Epoch 192/300 | Batch 50/77 | train_loss: 0.071 | test_loss: 0.128
Epoch 193/300 | Batch 0/77 | train_loss: 0.070 | test_loss: 0.114
Epoch 193/300 | Batch 50/77 | train_loss: 0.070 | test_loss: 0.114
Epoch 194/300 | Batch 0/77 | train_loss: 0.069 | test_loss: 0.114
Epoch 194/300 | Batch 50/77 | train_loss: 0.069 | test_loss: 0.115
Epoch 195/300 | Batch 0/77 | train_loss: 0.068 | test_loss: 0.110
Epo

Epoch 250/300 | Batch 0/77 | train_loss: 0.056 | test_loss: 0.110
Epoch 250/300 | Batch 50/77 | train_loss: 0.055 | test_loss: 0.102
Epoch 251/300 | Batch 0/77 | train_loss: 0.058 | test_loss: 0.118
Epoch 251/300 | Batch 50/77 | train_loss: 0.057 | test_loss: 0.102
Epoch 252/300 | Batch 0/77 | train_loss: 0.060 | test_loss: 0.139
Epoch 252/300 | Batch 50/77 | train_loss: 0.058 | test_loss: 0.104
Epoch 253/300 | Batch 0/77 | train_loss: 0.059 | test_loss: 0.152
Epoch 253/300 | Batch 50/77 | train_loss: 0.068 | test_loss: 0.110
Epoch 254/300 | Batch 0/77 | train_loss: 0.086 | test_loss: 0.133
Epoch 254/300 | Batch 50/77 | train_loss: 0.055 | test_loss: 0.124
Epoch 255/300 | Batch 0/77 | train_loss: 0.072 | test_loss: 0.187
Epoch 255/300 | Batch 50/77 | train_loss: 0.060 | test_loss: 0.102
Epoch 256/300 | Batch 0/77 | train_loss: 0.065 | test_loss: 0.117
Epoch 256/300 | Batch 50/77 | train_loss: 0.057 | test_loss: 0.103
Epoch 257/300 | Batch 0/77 | train_loss: 0.058 | test_loss: 0.116
Epo