In [1]:
from __future__ import print_function, division
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

num_epochs = 100
total_series_length = 50000
truncated_backprop_length = 15
state_size = 4
num_classes = 2
echo_step = 3
batch_size = 5
num_batches = total_series_length//batch_size//truncated_backprop_length

def generateData():
    x = np.array(np.random.choice(2, total_series_length, p=[0.5, 0.5]))
    y = np.roll(x, echo_step)
    y[0:echo_step] = 0

    x = x.reshape((batch_size, -1))  
    y = y.reshape((batch_size, -1))

    return (x, y)

batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])

init_state = tf.placeholder(tf.float32, [batch_size, state_size])

W = tf.Variable(np.random.rand(state_size+1, state_size), dtype=tf.float32)
b = tf.Variable(np.zeros((1,state_size)), dtype=tf.float32)

W2 = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)

# Unpack columns
inputs_series = tf.unstack(batchX_placeholder, axis=1)
labels_series = tf.unstack(batchY_placeholder, axis=1)

# Forward pass
current_state = init_state
states_series = []
for current_input in inputs_series:
    current_input = tf.reshape(current_input, [batch_size, 1])
    input_and_state_concatenated = tf.concat([current_input, current_state],1)  # Increasing number of columns

    next_state = tf.tanh(tf.matmul(input_and_state_concatenated, W) + b)  # Broadcasted addition
    states_series.append(next_state)
    current_state = next_state

logits_series = [tf.matmul(state, W2) + b2 for state in states_series] #Broadcasted addition
predictions_series = [tf.nn.softmax(logits) for logits in logits_series]

losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) for logits, labels in zip(logits_series,labels_series)]
total_loss = tf.reduce_mean(losses)

train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)

def plot(loss_list, predictions_series, batchX, batchY):
    plt.subplot(2, 3, 1)
    plt.cla()
    plt.plot(loss_list)

    for batch_series_idx in range(5):
        one_hot_output_series = np.array(predictions_series)[:, batch_series_idx, :]
        single_output_series = np.array([(1 if out[0] < 0.5 else 0) for out in one_hot_output_series])

        plt.subplot(2, 3, batch_series_idx + 2)
        plt.cla()
        plt.axis([0, truncated_backprop_length, 0, 2])
        left_offset = range(truncated_backprop_length)
        plt.bar(left_offset, batchX[batch_series_idx, :], width=1, color="blue")
        plt.bar(left_offset, batchY[batch_series_idx, :] * 0.5, width=1, color="red")
        plt.bar(left_offset, single_output_series * 0.3, width=1, color="green")

    plt.draw()
    plt.pause(0.0001)


with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        x,y = generateData()
        _current_state = np.zeros((batch_size, state_size))

        print("New data, epoch", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:,start_idx:end_idx]
            batchY = y[:,start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    batchX_placeholder:batchX,
                    batchY_placeholder:batchY,
                    init_state:_current_state
                })

            loss_list.append(_total_loss)

            if batch_idx%100 == 0:
                print("Step",batch_idx, "Loss", _total_loss)
                plot(loss_list, _predictions_series, batchX, batchY)

plt.ioff()
plt.show()

Instructions for updating:
Use `tf.global_variables_initializer` instead.


<matplotlib.figure.Figure at 0x2f581cb2be0>

New data, epoch 0
Step 0 Loss 0.710251
Step 100 Loss 0.704808
Step 200 Loss 0.695629
Step 300 Loss 0.694527
Step 400 Loss 0.691322
Step 500 Loss 0.502575
Step 600 Loss 0.0857413
New data, epoch 1
Step 0 Loss 0.208692
Step 100 Loss 0.0157477
Step 200 Loss 0.0414585
Step 300 Loss 0.00704834
Step 400 Loss 0.00426842
Step 500 Loss 0.00373628
Step 600 Loss 0.00370912
New data, epoch 2
Step 0 Loss 0.225374
Step 100 Loss 0.00299744
Step 200 Loss 0.00209781
Step 300 Loss 0.00254153
Step 400 Loss 0.00211675
Step 500 Loss 0.00164919
Step 600 Loss 0.00129451
New data, epoch 3
Step 0 Loss 0.187772
Step 100 Loss 0.00170089
Step 200 Loss 0.00161829
Step 300 Loss 0.00130566
Step 400 Loss 0.00145548
Step 500 Loss 0.00101379
Step 600 Loss 0.0010898
New data, epoch 4
Step 0 Loss 0.196488
Step 100 Loss 0.00089545
Step 200 Loss 0.000975538
Step 300 Loss 0.00300007
Step 400 Loss 0.000966896
Step 500 Loss 0.00110357
Step 600 Loss 0.00117069
New data, epoch 5
Step 0 Loss 0.291039
Step 100 Loss 0.00100692
Ste

Step 200 Loss 8.73922e-05
Step 300 Loss 8.27371e-05
Step 400 Loss 6.98498e-05
Step 500 Loss 0.000135144
Step 600 Loss 7.61056e-05
New data, epoch 43
Step 0 Loss 0.277445
Step 100 Loss 9.29355e-05
Step 200 Loss 8.88546e-05
Step 300 Loss 9.14383e-05
Step 400 Loss 0.000103485
Step 500 Loss 7.13092e-05
Step 600 Loss 7.77212e-05
New data, epoch 44
Step 0 Loss 0.343808
Step 100 Loss 7.55546e-05
Step 200 Loss 0.000123602
Step 300 Loss 8.99594e-05
Step 400 Loss 0.000100689
Step 500 Loss 7.33083e-05
Step 600 Loss 8.41318e-05
New data, epoch 45
Step 0 Loss 0.293065
Step 100 Loss 0.000155479
Step 200 Loss 0.000139428
Step 300 Loss 0.000116907
Step 400 Loss 0.000117299
Step 500 Loss 0.000128284
Step 600 Loss 0.000112349
New data, epoch 46
Step 0 Loss 0.229069
Step 100 Loss 0.000144691
Step 200 Loss 0.000145038
Step 300 Loss 0.000157077
Step 400 Loss 9.71064e-05
Step 500 Loss 0.000132467
Step 600 Loss 0.000120946
New data, epoch 47
Step 0 Loss 0.297689
Step 100 Loss 9.33578e-05
Step 200 Loss 7.4760

Step 200 Loss 6.0954e-05
Step 300 Loss 4.24333e-05
Step 400 Loss 5.12711e-05
Step 500 Loss 5.00495e-05
Step 600 Loss 8.02527e-05
New data, epoch 85
Step 0 Loss 0.190613
Step 100 Loss 5.6002e-05
Step 200 Loss 7.92371e-05
Step 300 Loss 8.47421e-05
Step 400 Loss 6.02309e-05
Step 500 Loss 6.38687e-05
Step 600 Loss 4.27447e-05
New data, epoch 86
Step 0 Loss 0.454235
Step 100 Loss 8.49203e-05
Step 200 Loss 6.27888e-05
Step 300 Loss 6.66744e-05
Step 400 Loss 7.74815e-05
Step 500 Loss 6.121e-05
Step 600 Loss 6.65923e-05
New data, epoch 87
Step 0 Loss 0.19067
Step 100 Loss 5.59917e-05
Step 200 Loss 6.14952e-05
Step 300 Loss 5.85194e-05
Step 400 Loss 7.10454e-05
Step 500 Loss 6.59102e-05
Step 600 Loss 7.03498e-05
New data, epoch 88
Step 0 Loss 0.128535
Step 100 Loss 7.1153e-05
Step 200 Loss 6.34734e-05
Step 300 Loss 6.04718e-05
Step 400 Loss 5.50855e-05
Step 500 Loss 5.53733e-05
Step 600 Loss 5.38586e-05
New data, epoch 89
Step 0 Loss 0.238254
Step 100 Loss 6.59162e-05
Step 200 Loss 8.15278e-05


<matplotlib.figure.Figure at 0x2f5fe09f160>