In [1]:
# Leaving comments to explain what is occurring.
# Imported libraries: tensorflow (model), numpy (numerical arrays and math
# functions), matplotlib (plotting), os (environment configuration).
import os

# Silence TensorFlow's native (C++) logging.  This is set BEFORE importing
# tensorflow because the native library may read the variable while it loads;
# setting it only after the import can leave import-time messages unsuppressed.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Hyper parameters (shared by the data windows, the graph, and the optimizer)
TIME_STEP = 10    # time steps per training window / RNN unroll length
INPUT_SIZE = 1    # features per time step (also the dense layer's output width)
CELL_SIZE = 32    # hidden units in the BasicRNNCell
LR = 0.02         # Adam learning rate

# show data: plot the input (sin) and the target (cos) over one full period.
steps = np.linspace(0, np.pi * 2, 100, dtype=np.float32)
# float32 values; these preview arrays are only plotted, never fed to the model.
x_np = np.sin(steps)
y_np = np.cos(steps)
plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()

# tensorflow placeholders (batch dimension left as None)
# tf_x: input sin values, shape (batch, TIME_STEP, INPUT_SIZE)
tf_x = tf.placeholder(tf.float32, [None, TIME_STEP, INPUT_SIZE])
# tf_y: target cos values, same shape as tf_x
tf_y = tf.placeholder(tf.float32, [None, TIME_STEP, INPUT_SIZE])

# RNN: a single BasicRNNCell unrolled over the TIME_STEP axis of tf_x.
rnn_cell = tf.contrib.rnn.BasicRNNCell(num_units=CELL_SIZE)
# Very first hidden state: zeros sized to the batch actually fed to tf_x.
# The placeholders accept any batch size (None); a hard-coded batch of 1 here
# would make every batch > 1 fail.  A carried-over state can still be fed
# into init_s in place of this default.
init_s = rnn_cell.zero_state(batch_size=tf.shape(tf_x)[0], dtype=tf.float32)
outputs, final_s = tf.nn.dynamic_rnn(
    rnn_cell,                   # cell you have chosen
    tf_x,                       # input, (batch, TIME_STEP, INPUT_SIZE)
    initial_state=init_s,       # the initial hidden state
    time_major=False,           # False: (batch, time step, input); True: (time step, batch, input)
)
# Flatten (batch, TIME_STEP, CELL_SIZE) -> (batch * TIME_STEP, CELL_SIZE) so a
# single dense layer maps every time step's hidden state to an INPUT_SIZE output.
outs2D = tf.reshape(outputs, [-1, CELL_SIZE])
net_outs2D = tf.layers.dense(outs2D, INPUT_SIZE)
# Restore the (batch, TIME_STEP, INPUT_SIZE) layout to compare against tf_y.
outs = tf.reshape(net_outs2D, [-1, TIME_STEP, INPUT_SIZE])

# Cost: mean squared error between per-step predictions and the cos targets.
loss = tf.losses.mean_squared_error(predictions=outs, labels=tf_y)
# One Adam update of all trainable variables per run of train_op.
train_op = tf.train.AdamOptimizer(learning_rate=LR).minimize(loss)

# Session that owns the graph's variable values; initialise them once.
sess = tf.Session()
sess.run(fetches=tf.global_variables_initializer())

# Continuously plot: draw each window's target and prediction as it trains.
plt.figure(1, figsize=(12, 5))
plt.ion()

# Hidden state carried from one window to the next; None means "use the
# graph's zero initial state".  Initialised explicitly (rather than probing
# globals()) so re-running this cell does not resume from a stale state left
# over from a previous run.
final_s_ = None

for step in range(60):
    # Window `step` covers [step*pi, (step+1)*pi).  endpoint=False keeps the
    # window's last sample from duplicating the next window's first sample, so
    # the carried hidden state sees one continuous sequence.
    start, end = step * np.pi, (step + 1) * np.pi
    steps = np.linspace(start, end, TIME_STEP, endpoint=False)
    # use sin to predict cos; both shaped (1, TIME_STEP, 1) = (batch, time_step, input_size)
    x = np.sin(steps)[np.newaxis, :, np.newaxis]
    y = np.cos(steps)[np.newaxis, :, np.newaxis]

    feed_dict = {tf_x: x, tf_y: y}
    if final_s_ is not None:
        # Continue from the hidden state the previous window ended in.
        feed_dict[init_s] = final_s_
    # One training step; also fetch predictions and the final hidden state.
    _, pred_, final_s_ = sess.run([train_op, outs, final_s], feed_dict)

    # plotting
    plt.plot(steps, y.flatten(), 'r-')
    plt.plot(steps, pred_.flatten(), 'b-')
    plt.ylim((-1.2, 1.2))
    plt.draw()
    plt.pause(0.05)

plt.ioff()
plt.show()

<matplotlib.figure.Figure at 0x2254a61be10>

<matplotlib.figure.Figure at 0x225522b5128>

<matplotlib.figure.Figure at 0x22553527320>

<matplotlib.figure.Figure at 0x225537dd0f0>

<matplotlib.figure.Figure at 0x225539946d8>

<matplotlib.figure.Figure at 0x22553b25550>

<matplotlib.figure.Figure at 0x22553ca8320>

<matplotlib.figure.Figure at 0x22553e45550>

<matplotlib.figure.Figure at 0x22554004eb8>

<matplotlib.figure.Figure at 0x22553981668>

<matplotlib.figure.Figure at 0x225539b4d30>

<matplotlib.figure.Figure at 0x22553e573c8>

<matplotlib.figure.Figure at 0x225547c3278>

<matplotlib.figure.Figure at 0x225547f72b0>

<matplotlib.figure.Figure at 0x2255433bd68>

<matplotlib.figure.Figure at 0x22554b47940>

<matplotlib.figure.Figure at 0x22554e2c860>

<matplotlib.figure.Figure at 0x22554ccdac8>

<matplotlib.figure.Figure at 0x22554804f60>

<matplotlib.figure.Figure at 0x22554e4f588>

<matplotlib.figure.Figure at 0x225534df5f8>

<matplotlib.figure.Figure at 0x22554b1e2b0>

<matplotlib.figure.Figure at 0x225537b9588>

<matplotlib.figure.Figure at 0x22554cbc320>

<matplotlib.figure.Figure at 0x22554651668>

<matplotlib.figure.Figure at 0x22553afcc88>

<matplotlib.figure.Figure at 0x2255399ce48>

<matplotlib.figure.Figure at 0x225543495c0>

<matplotlib.figure.Figure at 0x2255400f3c8>

<matplotlib.figure.Figure at 0x22553ca0eb8>

<matplotlib.figure.Figure at 0x22553fde5c0>

<matplotlib.figure.Figure at 0x2255417cf98>

<matplotlib.figure.Figure at 0x2255400fcf8>

<matplotlib.figure.Figure at 0x2255369b2e8>

<matplotlib.figure.Figure at 0x22553750438>

<matplotlib.figure.Figure at 0x225549e71d0>

<matplotlib.figure.Figure at 0x22554a52080>

<matplotlib.figure.Figure at 0x225536eae48>

<matplotlib.figure.Figure at 0x22555780278>

<matplotlib.figure.Figure at 0x2255591ff60>

<matplotlib.figure.Figure at 0x22555a8d320>

<matplotlib.figure.Figure at 0x22555c2b518>

<matplotlib.figure.Figure at 0x22555acc160>

<matplotlib.figure.Figure at 0x225557564a8>

<matplotlib.figure.Figure at 0x22552325b70>

<matplotlib.figure.Figure at 0x22555a8d358>

<matplotlib.figure.Figure at 0x22555c2b320>

<matplotlib.figure.Figure at 0x22553523588>

<matplotlib.figure.Figure at 0x22555daff60>

<matplotlib.figure.Figure at 0x22554a86400>

<matplotlib.figure.Figure at 0x22554a18c88>

<matplotlib.figure.Figure at 0x225537d2e48>

<matplotlib.figure.Figure at 0x22553b04630>

<matplotlib.figure.Figure at 0x22554a86c88>

<matplotlib.figure.Figure at 0x22554a1ae80>

<matplotlib.figure.Figure at 0x225537606d8>

<matplotlib.figure.Figure at 0x22553c93588>

<matplotlib.figure.Figure at 0x225536a68d0>

<matplotlib.figure.Figure at 0x22554171668>

<matplotlib.figure.Figure at 0x22554640c88>

<matplotlib.figure.Figure at 0x22554b5fe48>