In [2]:
import tensorflow as tf
import numpy as np
from PntrNetwork import PointerNetwork

In [3]:
def load_data(path):
    """
    Loads TSP dataset cited in PointerNetwork Paper

    Args:
        path: Path to dataset
    Returns:
        enc_input: Input sequence
        dec_input: Target sequence
        ans: Target sequence shifted by 1
    """
    with open(path, 'r') as f:
        data = f.readlines()
    enc_input = []
    dec_input = []
    ans = []
    for line in data:
        i, o = line.strip().split(' output ')
        enc_input.append(list(map(float, i.split(' '))))
        dec_input.append(list(map(float, o.split(' '))))
        ans.append(list(map(float, o.split(' '))))
    enc_input = np.array(enc_input).squeeze().reshape(len(enc_input), 5, 2).astype('float32')
    dec_input = np.subtract(np.array(dec_input).squeeze().reshape(len(dec_input), 6, 1), 1)
    ans = np.subtract(np.array(ans).squeeze().reshape(len(ans), 6, 1), 1)
    size = dec_input.shape[1]
    return enc_input, dec_input[:,0:size-1,:], ans[:,1:,:]

def decode_seq(model, enc_input):
    """
    Decode LSTM output sequence

    Args:
        model: PointerNetwork model
        enc_input: Encoder input sequence
    Returns:
        decoded_seq: Decoded sequence
    """
    dec_input = np.zeros((enc_input.shape[0], seq_len, 1))
    for i in range(enc_input.shape[1] - 1):
        pred = model.predict([enc_input, dec_input])
        dec_input[:,i+1,:] = tf.argmax(pred[:,i,:], axis=-1).numpy()[:,np.newaxis]
    return dec_input

In [4]:
enc_input, dec_input, output = load_data('../data/tsp5.txt')
test_input, _, targets = load_data('../data/tsp5_test.txt')
BUFFER_SIZE = enc_input.shape[0]
BATCH_SIZE = 64
units = 256
seq_len = 5

In [5]:
mod = PointerNetwork(seq_len, units, BATCH_SIZE)
mod.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy())

In [6]:
mod.fit([enc_input, dec_input], output, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
 2336/31250 [=>............................] - ETA: 11:35 - loss: 0.2068

In [76]:
decode_seq(mod, test_input)

array([[[3.],
        [4.],
        [2.],
        [3.],
        [0.]],

       [[3.],
        [4.],
        [2.],
        [3.],
        [0.]],

       [[3.],
        [4.],
        [2.],
        [3.],
        [0.]]])

In [47]:
mod.predict([enc_input[0:2,:,:], np.zeros((2, seq_len, 1))])

array([[[4.0110268e-04, 2.5860542e-01, 2.4755125e-01, 2.6400182e-01,
         2.2944036e-01],
        [1.6250326e-09, 4.4321314e-06, 4.2065501e-02, 3.4119797e-01,
         6.1673212e-01],
        [7.9299003e-09, 1.5056122e-07, 1.6678314e-05, 7.1853591e-04,
         9.9926466e-01],
        [1.4585952e-04, 1.3692783e-03, 2.8396944e-02, 1.2854522e-01,
         8.4154260e-01],
        [3.6734942e-02, 9.2481176e-04, 4.1595466e-02, 2.5929287e-02,
         8.9481550e-01]],

       [[3.9562522e-04, 2.5597182e-01, 2.4870425e-01, 2.6555920e-01,
         2.2936915e-01],
        [1.6105764e-09, 4.6196101e-06, 4.3197222e-02, 3.4161666e-01,
         6.1518145e-01],
        [8.0101668e-09, 1.5500275e-07, 1.7419703e-05, 7.4412505e-04,
         9.9923825e-01],
        [1.4678562e-04, 1.3891507e-03, 2.9671118e-02, 1.3042437e-01,
         8.3836859e-01],
        [3.6434706e-02, 9.4254210e-04, 4.3243770e-02, 2.6303645e-02,
         8.9307529e-01]]], dtype=float32)

In [None]:
pred1 = mod([enc_input[0:2,:,:], dec_input[0:2,:,:]], training = False)

In [None]:
pred1[:,1,:]