In [1]:
import tensorflow as tf
import numpy as np
import datetime, time

In [6]:
def genObs(length, full_length, dict_size=3):
    """Generate one random, zero-padded token sequence.

    Args:
        length: number of real (non-padding) tokens to draw.
        full_length: total length of the returned sequence.
        dict_size: number of distinct tokens; drawn tokens lie in 1..dict_size.

    Returns:
        A 1-D float array of `full_length` entries: `length` random tokens
        followed by `full_length - length` zeros (0 is the padding token).
    """
    # Shift by one so that 0 stays reserved for padding.
    tokens = np.random.choice(range(dict_size), size=length) + 1
    padding = np.zeros(full_length - length)
    return np.concatenate([tokens, padding])

def getTarget(x):
    """Build the training target for a zero-padded sequence.

    The tokens before the first padding zero are sorted in ascending order;
    everything from the first zero onwards is kept as is.

    Args:
        x: 1-D sequence of tokens where 0 marks padding.

    Returns:
        A 1-D array of the same length as `x` with the non-padding prefix
        sorted. A sequence without any padding is sorted in full, and an
        empty sequence yields an empty array.
    """
    x = np.asarray(x)
    # Find the padding boundary explicitly. Using argmin here would point at
    # the smallest *token* whenever the sequence carries no padding at all,
    # leaving the result only partially sorted.
    pad_positions = np.flatnonzero(x == 0)
    s = pad_positions[0] if pad_positions.size else len(x)
    return np.concatenate([np.sort(x[:s]), x[s:]])

def genSample(num, length=20):
    """Generate `num` random (input, target) sequence pairs.

    Each input is a zero-padded sequence of total length `length` whose number
    of real tokens is drawn by `np.random.randint(length)`, so at least one
    padding zero is always present. The target is the input with its
    non-padding prefix sorted.

    Returns:
        A tuple `(x, y)` of arrays holding the inputs and their targets.
    """
    inputs = []
    targets = []
    for _ in range(num):
        observation = genObs(np.random.randint(length), full_length=length)
        inputs.append(observation)
        targets.append(getTarget(observation))
    return np.array(inputs), np.array(targets)

def randomBatch(tensorTuple, batchSize=64):
    """Draw one random batch (with replacement) from several aligned arrays.

    Args:
        tensorTuple: sequence of arrays sharing the same first dimension.
        batchSize: number of rows to draw.

    Returns:
        A generator yielding, for each array in `tensorTuple`, the rows at
        the same randomly chosen indices.
    """
    row_count = tensorTuple[0].shape[0]
    # The same indices are applied to every array so the rows stay aligned.
    chosen = np.random.choice(range(row_count), batchSize)
    return (tensor[chosen,] for tensor in tensorTuple)

def shuffleBatches(tensorTuple, batchSize=64):
    """Iterate once over the data in shuffled mini-batches.

    Args:
        tensorTuple: either a single array, or a list/tuple (including
            subclasses such as namedtuples) of arrays that share the same
            first dimension.
        batchSize: maximum number of rows per batch; the final batch may be
            smaller.

    Yields:
        For a single array: an array holding the rows of the batch.
        For a list/tuple: a tuple with one array per input, all indexed with
        the same shuffled rows so they stay aligned.
    """
    is_group = isinstance(tensorTuple, (list, tuple))
    row_count = (tensorTuple[0] if is_group else tensorTuple).shape[0]

    order = list(range(row_count))
    np.random.shuffle(order)

    for start in range(0, row_count, batchSize):
        rows = order[start:start + batchSize]  # slicing clamps at the end
        if is_group:
            # Materialise the batch right away. A lazy generator expression
            # would read the loop variables only when consumed, so batches
            # collected first and unpacked later would all resolve to the
            # last slice.
            yield tuple(np.array(x[rows,]) for x in tensorTuple)
        else:
            yield np.array(tensorTuple[rows,])

In [7]:
# Independent random training and validation sets of padded length-6
# sequences and their sorted targets.
train_x, train_y = genSample(num=100000, length=6)
valid_x, valid_y = genSample(num=10000, length=6)

# Show the first ten validation inputs next to their targets.
(valid_x[:10], valid_y[:10])

(array([[ 3.,  0.,  0.,  0.,  0.,  0.],
        [ 2.,  2.,  3.,  2.,  1.,  0.],
        [ 2.,  2.,  3.,  0.,  0.,  0.],
        [ 1.,  0.,  0.,  0.,  0.,  0.],
        [ 2.,  3.,  3.,  2.,  0.,  0.],
        [ 2.,  3.,  3.,  3.,  3.,  0.],
        [ 1.,  1.,  3.,  1.,  2.,  0.],
        [ 2.,  1.,  0.,  0.,  0.,  0.],
        [ 3.,  1.,  1.,  0.,  0.,  0.],
        [ 1.,  0.,  0.,  0.,  0.,  0.]]),
 array([[ 3.,  0.,  0.,  0.,  0.,  0.],
        [ 1.,  2.,  2.,  2.,  3.,  0.],
        [ 2.,  2.,  3.,  0.,  0.,  0.],
        [ 1.,  0.,  0.,  0.,  0.,  0.],
        [ 2.,  2.,  3.,  3.,  0.,  0.],
        [ 2.,  3.,  3.,  3.,  3.,  0.],
        [ 1.,  1.,  1.,  2.,  3.,  0.],
        [ 1.,  2.,  0.,  0.,  0.,  0.],
        [ 1.,  1.,  3.,  0.,  0.,  0.],
        [ 1.,  0.,  0.,  0.,  0.,  0.]]))

In [13]:
# --- Hyper-parameters --------------------------------------------------------
IN_LEN = 6                # time steps fed to the encoder
OUT_LEN = 6               # time steps produced by the decoder
DICT_SIZE = 3             # number of real tokens; one extra class is added for token 0
ENCODER_RNN_SIZE = [10]   # units of each stacked encoder GRU layer
DECODER_RNN_SIZE = [4]    # units of each stacked decoder GRU layer
FC_STATE = 4              # width of the dense layer used as the decoder's initial state;
                          # it is passed as the state of the single decoder layer, so it
                          # has to equal DECODER_RNN_SIZE[0]
FC_INPUT = 4              # width of the dense layer used as the decoder's per-step input

EncoderCell = lambda n: tf.nn.rnn_cell.GRUCell(num_units=n, activation=tf.nn.elu)
DecoderCell = lambda n: tf.nn.rnn_cell.GRUCell(num_units=n, activation=tf.nn.elu)

# Start from an empty graph so re-running this cell does not pile up nodes.
tf.reset_default_graph()


# --- Inputs ------------------------------------------------------------------
tfi_x = tf.placeholder(shape=(None, IN_LEN), dtype=tf.int32)
tfi_y = tf.placeholder(shape=(None, OUT_LEN), dtype=tf.int32)

# One-hot encode tokens; depth is DICT_SIZE + 1 to make room for token 0.
tfX = tf.one_hot(tfi_x, DICT_SIZE + 1, dtype=tf.float32)
tfY = tf.one_hot(tfi_y, DICT_SIZE + 1, dtype=tf.float32)

# --- Encoder -----------------------------------------------------------------
with tf.name_scope(name='ENCODER'):
    rnnEncoderCell = tf.nn.rnn_cell.MultiRNNCell([EncoderCell(s) for s in ENCODER_RNN_SIZE], state_is_tuple=True)
    _, tfEncodedState0 = tf.nn.dynamic_rnn(rnnEncoderCell, inputs=tfX, dtype=tf.float32, time_major=False, scope='ENCODER')
    tfEncodedState = tfEncodedState0[-1]  # final state of the last stacked layer

# --- Bridge from encoder state to decoder ------------------------------------
tfFC_State = tf.layers.dense(tfEncodedState, FC_STATE, activation=tf.nn.elu)
tfFC_Input0 = tf.layers.dense(tfEncodedState, FC_INPUT, activation=tf.nn.elu)
# Feed the same FC_INPUT-wide vector to the decoder at every output step:
# (batch, FC_INPUT) -> (batch, 1, FC_INPUT) -> (batch, OUT_LEN, FC_INPUT).
# Tiling along the feature axis and then reshaping to (batch, FC_INPUT, OUT_LEN)
# followed by a transpose does NOT do this: it interleaves the features, so
# each step would receive a scrambled subset of the vector.
tfFC_Input = tf.tile(tf.expand_dims(tfFC_Input0, axis=1), [1, OUT_LEN, 1])


# --- Decoder -----------------------------------------------------------------
with tf.name_scope(name='DECODER'):
    rnnDecoderCell = tf.nn.rnn_cell.MultiRNNCell([DecoderCell(s) for s in DECODER_RNN_SIZE], state_is_tuple=True)
    tfDecodedSeq, _ = tf.nn.dynamic_rnn(rnnDecoderCell, inputs=tfFC_Input, initial_state=(tfFC_State,), time_major=False, scope='DECODER')

# Per-step class scores over the DICT_SIZE + 1 tokens (no activation).
tfOutputSeqLogits = tf.layers.dense(tfDecodedSeq, DICT_SIZE + 1)

# Predicted token per step; `axis` replaces the deprecated `dimension` argument.
tfOutputSeq = tf.argmax(tfOutputSeqLogits, axis=2)
tfOutputHS = tfFC_State
tfOutputHI = tfFC_Input0

# --- Loss and optimizer ------------------------------------------------------
# Cross-entropy per (sample, step), averaged over all of them; padded steps
# are included in the mean.
tfHLoss0 = tf.nn.softmax_cross_entropy_with_logits(labels=tfY, logits=tfOutputSeqLogits)
tfLoss = tf.reduce_mean(tfHLoss0)

tfTrain = tf.train.AdamOptimizer(1e-3).minimize(tfLoss)

In [14]:
# Timestamp meant for naming log directories; only the disabled summary
# writer below consumes it.
dt_now = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
#tffw = tf.summary.FileWriter('D:/Jupyter/Logs/12RNN07-{0}'.format(dt_now), tf.get_default_graph())

batch_size = 1000    # rows per mini-batch
num_steps  = 10      # optimizer steps taken on each mini-batch
num_epochs = 100     # passes over the training set
checkpoints = 500    # epoch interval for the disabled checkpoint saver below

fmtstr = 'Epoch {0} ({1:1.3} sec): \t\tVL:{2:1.3f}'
valid_batch = {tfi_x: valid_x, tfi_y: valid_y}

with tf.Session() as tfs:
    tfs.run(tf.global_variables_initializer())

    for epoch in range(num_epochs):
        te0 = time.perf_counter()

        for mini_x, mini_y in shuffleBatches((train_x, train_y), batchSize=batch_size):
            train_batch = {tfi_x: mini_x, tfi_y: mini_y}

            # Several optimizer steps on the same mini-batch.
            for _ in range(num_steps):
                tfs.run(tfTrain, feed_dict=train_batch)

            # Training loss on the mini-batch after its updates.
            l1 = tfs.run(tfLoss, feed_dict=train_batch)

        # Stop the clock before validation so only training time is reported.
        te1 = time.perf_counter()

        # Validation loss and predicted sequences on the full validation set;
        # valid_r stays available to later cells.
        lv, valid_r = tfs.run([tfLoss, tfOutputSeq], feed_dict=valid_batch)

        # Disabled summary / checkpoint hooks:
        #tffw.add_summary(summary, epoch)
        #if epoch%checkpoints == 0 and epoch > 0:
        #    p = tfsSaver.save(tfs, 'D:/Jupyter/mltest/Models-12RNN07/model-{0:04d}.ckpt'.format(epoch))
        #    print('Model saved at checkpoint: {0}'.format(p))

        print(fmtstr.format(epoch, te1 - te0, lv))

Epoch 0 (16.7 sec): 		VL:0.068
Epoch 1 (16.4 sec): 		VL:0.013
Epoch 2 (16.3 sec): 		VL:0.004
Epoch 3 (16.6 sec): 		VL:0.002
Epoch 4 (16.1 sec): 		VL:0.001


KeyboardInterrupt: 

In [15]:
# Expected targets next to the model's predictions for the first ten
# validation sequences (valid_r comes from the last completed epoch).
(valid_y[:10], valid_r[:10])

(array([[ 3.,  0.,  0.,  0.,  0.,  0.],
        [ 1.,  2.,  2.,  2.,  3.,  0.],
        [ 2.,  2.,  3.,  0.,  0.,  0.],
        [ 1.,  0.,  0.,  0.,  0.,  0.],
        [ 2.,  2.,  3.,  3.,  0.,  0.],
        [ 2.,  3.,  3.,  3.,  3.,  0.],
        [ 1.,  1.,  1.,  2.,  3.,  0.],
        [ 1.,  2.,  0.,  0.,  0.,  0.],
        [ 1.,  1.,  3.,  0.,  0.,  0.],
        [ 1.,  0.,  0.,  0.,  0.,  0.]]), array([[3, 0, 0, 0, 0, 0],
        [1, 2, 2, 2, 3, 0],
        [2, 2, 3, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [2, 2, 3, 3, 0, 0],
        [2, 3, 3, 3, 3, 0],
        [1, 1, 1, 2, 3, 0],
        [1, 2, 0, 0, 0, 0],
        [1, 1, 3, 0, 0, 0],
        [1, 0, 0, 0, 0, 0]], dtype=int64))