# https://medium.com/@erikhallstrm/using-the-tensorflow-lstm-api-3-7-5f2b97ca6b73

# Using the LSTM API in TensorFlow (3/7)
In the previous post we modified our code to use the TensorFlow native RNN API.
Now we will build a modification of the RNN called a “Recurrent Neural Network with Long Short-Term Memory”, or RNN-LSTM.
This architecture was pioneered by Jürgen Schmidhuber among others.
One problem with the plain RNN when learning long time dependencies (when `truncated_backprop_length` is large) is the “vanishing gradient problem”: http://neuralnetworksanddeeplearning.com/chap5.html
One way to counter this is to use a state that is “protected” and “selective”.
The RNN-LSTM remembers, forgets and chooses what to pass on and output depending on the current state and input.
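
To make the vanishing-gradient argument concrete, here is a small NumPy illustration that is not part of the original tutorial. It assumes a scalar tanh RNN, `h_t = tanh(w * h_{t-1} + x_t)`, where the gradient of the last state with respect to the first is a product of per-step factors `w * (1 - h_t**2)`. For comparison, it computes the derivative along an LSTM's direct cell-state path, `c_t = f_t * c_{t-1} + i_t * g_t`, which is just the product of the forget-gate values (the indirect paths through the gates are ignored). The weight `0.9` and forget-gate value `0.97` are arbitrary illustrative choices.

```python
import numpy as np

def rnn_gradient(w, xs, h0=0.0):
    """d h_T / d h_0 for the scalar RNN h_t = tanh(w * h_{t-1} + x_t)."""
    h, grad = h0, 1.0
    for x in xs:
        h = np.tanh(w * h + x)
        grad *= w * (1.0 - h ** 2)  # chain rule through one time step
    return grad

def lstm_cell_path_gradient(forget_gates):
    """d c_T / d c_0 along the direct path c_t = f_t * c_{t-1} + i_t * g_t."""
    return float(np.prod(forget_gates))

rng = np.random.default_rng(0)
T = 15  # same as truncated_backprop_length in this tutorial
xs = rng.normal(size=T)

print("tanh RNN:   ", rnn_gradient(w=0.9, xs=xs))
print("LSTM c-path:", lstm_cell_path_gradient(np.full(T, 0.97)))
```

Every RNN factor has magnitude at most `0.9`, so the product shrinks geometrically with sequence length, while a forget gate close to 1 lets the cell state carry information (and gradient) much further.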

Since this is primarily a practical tutorial, I won’t go into more detail about the theory. I recommend reading this article again and continuing with the section “Modern RNN architectures”: https://arxiv.org/pdf/1506.00019.pdf
After you have done that, read and look at the figures on this page.
Notice that the last-mentioned resource uses vector concatenation in its calculations.
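
The concatenated formulation can be sketched as a single LSTM step in NumPy: the input and the previous hidden state are concatenated and multiplied by one weight matrix that produces all four gate pre-activations at once. The gate order `(i, j, f, o)` and the `forget_bias` of 1.0 are assumptions loosely modeled on TensorFlow's `BasicLSTMCell`; the helper name `lstm_step` and the random weights are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, c_prev, h_prev, W, b, forget_bias=1.0):
    """One LSTM time step (hypothetical helper, not a TensorFlow API).

    x: (batch, input_size), c_prev/h_prev: (batch, state_size)
    W: (input_size + state_size, 4 * state_size), b: (4 * state_size,)
    """
    concat = np.concatenate([x, h_prev], axis=1)      # vector concatenation [x, h]
    i, j, f, o = np.split(concat @ W + b, 4, axis=1)  # four gate pre-activations
    c = c_prev * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)  # forget old, write new
    h = np.tanh(c) * sigmoid(o)                                      # choose what to output
    return c, h

batch_size, input_size, state_size = 5, 1, 4
rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(input_size + state_size, 4 * state_size))
b = np.zeros(4 * state_size)
x = rng.integers(0, 2, size=(batch_size, input_size)).astype(float)
c0 = np.zeros((batch_size, state_size))
h0 = np.zeros((batch_size, state_size))

c1, h1 = lstm_step(x, c0, h0, W, b)
print(c1.shape, h1.shape)
```

The `f` gate scales the old cell state (forgetting), `i` and `j` decide what is written into it (remembering), and `o` decides how much of the cell state is passed on as the hidden state (output).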

In the previous article we didn’t have to allocate the internal weight matrix and bias ourselves; TensorFlow did that automatically “under the hood”.
An LSTM RNN has many more “moving parts”, but by using the native API it will still be very simple.
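
To put a number on the extra “moving parts”, the arithmetic below counts the trainable parameters inside the cell for this tutorial's settings (`state_size = 4`, one input value per time step). It assumes the vanilla cell uses one kernel of shape `(input_size + state_size, state_size)` plus a bias, and that `BasicLSTMCell` uses one kernel of shape `(input_size + state_size, 4 * state_size)` plus a bias of length `4 * state_size`. The output layer `W2`, `b2` is identical in both models and is left out.

```python
def rnn_params(input_size, state_size):
    # one kernel over the concatenated [x, h] plus one bias vector
    return (input_size + state_size) * state_size + state_size

def lstm_params(input_size, state_size):
    # four gates (i, j, f, o), each with its own slice of the kernel and bias
    return 4 * rnn_params(input_size, state_size)

print("vanilla RNN cell:", rnn_params(1, 4))   # -> 24
print("LSTM cell:       ", lstm_params(1, 4))  # -> 96
```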

# Different state
An LSTM has a “cell state” and a “hidden state”. To account for this, you need to remove `_current_state` on line 79 in the previous script and replace it with this:
```
_current_cell_state = np.zeros((batch_size, state_size))
_current_hidden_state = np.zeros((batch_size, state_size))
```

TensorFlow uses a data structure called `LSTMStateTuple` internally for its LSTMs, where the first element in the tuple is the cell state and the second is the hidden state.
So you need to change line 28, where the `init_state` placeholder is declared, to these lines:
```
cell_state = tf.placeholder(tf.float32, [batch_size, state_size])
hidden_state = tf.placeholder(tf.float32, [batch_size, state_size])
init_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)
```
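
`LSTMStateTuple` is a named tuple with the fields `c` and `h`, in that order (the printed `init_state` further down shows `LSTMStateTuple(c=..., h=...)`), which is why it can later be unpacked directly into the cell state and the hidden state. The sketch below mimics that behavior with a plain `collections.namedtuple` so it runs without TensorFlow; `StateTuple` is a hypothetical stand-in, not the TensorFlow class.

```python
from collections import namedtuple
import numpy as np

# Stand-in for tf.nn.rnn_cell.LSTMStateTuple: field order is (c, h)
StateTuple = namedtuple("StateTuple", ("c", "h"))

batch_size, state_size = 5, 4
state = StateTuple(c=np.zeros((batch_size, state_size)),
                   h=np.ones((batch_size, state_size)))

# Positional unpacking follows the field order: cell state first, hidden state second
_current_cell_state, _current_hidden_state = state
print(_current_cell_state is state.c, _current_hidden_state is state.h)
```
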
Changing the forward pass is now straightforward: you just change the function call to create an LSTM and supply the initial state tuple on lines 38–39.
```
cell = tf.nn.rnn_cell.BasicLSTMCell(state_size, state_is_tuple=True)
states_series, current_state = tf.contrib.rnn.static_rnn(cell, inputs_series, init_state)
```
<u>The `states_series` will be a list of hidden states as tensors, and `current_state` will be an `LSTMStateTuple` holding both the cell state and the hidden state of the last time step, as shown below:</u>

![Outputs of the previous states and the last LSTMStateTuple](pics/1_74HCrPbTjstECrQXZzJtDQ.png)
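
The relation between `states_series` and `current_state` can be reproduced with a NumPy stand-in for `static_rnn`: apply the cell once per time step, collect each hidden state in a list, and return the final `(c, h)` pair. This is a sketch under the same assumptions as a `BasicLSTMCell`-style step (gate order `i, j, f, o`, forget bias 1.0); `unrolled_lstm` and `lstm_step` are hypothetical helpers. The last entry of the list is the same array as the final hidden state.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, c, h, W, b, forget_bias=1.0):
    i, j, f, o = np.split(np.concatenate([x, h], axis=1) @ W + b, 4, axis=1)
    c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    return c, np.tanh(c) * sigmoid(o)

def unrolled_lstm(inputs_series, init_state, W, b):
    """Mimics the outputs of static_rnn: (list of hidden states, final (c, h))."""
    c, h = init_state
    states_series = []
    for x in inputs_series:           # one (batch, 1) column per time step
        c, h = lstm_step(x, c, h, W, b)
        states_series.append(h)       # only the hidden state is emitted per step
    return states_series, (c, h)

batch_size, state_size, T = 5, 4, 15
rng = np.random.default_rng(1)
W = rng.normal(scale=0.5, size=(1 + state_size, 4 * state_size))
b = np.zeros(4 * state_size)
batchX = rng.integers(0, 2, size=(batch_size, T)).astype(float)
inputs_series = np.split(batchX, T, axis=1)   # like tf.split(batchX, T, 1)
zeros = np.zeros((batch_size, state_size))

states_series, current_state = unrolled_lstm(inputs_series, (zeros, zeros), W, b)
print(len(states_series), states_series[-1] is current_state[1])
```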

So `current_state` returns the cell state and the hidden state in a tuple.
They should be separated after each run and fed to the placeholders in the run function on line 90.
```
_total_loss, _train_step, _current_state, _predictions_series = sess.run(
    [total_loss, train_step, current_state, predictions_series],
    feed_dict={
        batchX_placeholder: batchX,
        batchY_placeholder: batchY,
        cell_state: _current_cell_state,
        hidden_state: _current_hidden_state
    })

_current_cell_state, _current_hidden_state = _current_state
```

```python
# Requires TensorFlow 1.x (tf.placeholder, tf.Session and tf.contrib are not available in 2.x)
from __future__ import print_function, division
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

num_epochs = 5  # 100
total_series_length = 50000
truncated_backprop_length = 15
state_size = 4
num_classes = 2
echo_step = 3
batch_size = 5
num_batches = total_series_length//batch_size//truncated_backprop_length

def generateData():
    # Random binary series; the target is the same series delayed by echo_step steps
    x = np.array(np.random.choice(2, total_series_length, p=[0.5, 0.5]))
    y = np.roll(x, echo_step)
    y[0:echo_step] = 0

    x = x.reshape((batch_size, -1))  # The first index changing slowest, subseries as rows
    y = y.reshape((batch_size, -1))

    return (x, y)

batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])

# The LSTM state is fed as two placeholders wrapped in an LSTMStateTuple (c, h)
cell_state = tf.placeholder(tf.float32, [batch_size, state_size])
hidden_state = tf.placeholder(tf.float32, [batch_size, state_size])
init_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)

print("cell_state: ", cell_state)
print("hidden_state: ", hidden_state)
print("init_state: ", init_state)

# Output layer: maps each hidden state to class logits
W2 = tf.Variable(np.random.rand(state_size, num_classes), dtype=tf.float32)
b2 = tf.Variable(np.zeros((1, num_classes)), dtype=tf.float32)

# Unpack columns
inputs_series = tf.split(batchX_placeholder, truncated_backprop_length, 1)
labels_series = tf.unstack(batchY_placeholder, axis=1)

# Forward passes
cell = tf.nn.rnn_cell.BasicLSTMCell(state_size, state_is_tuple=True)
states_series, current_state = tf.contrib.rnn.static_rnn(cell, inputs_series, init_state)

print("states_series: ", states_series)
print("current_state: ", current_state)

logits_series = [tf.matmul(state, W2) + b2 for state in states_series]  # Broadcasted addition
predictions_series = [tf.nn.softmax(logits) for logits in logits_series]

losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
          for logits, labels in zip(logits_series, labels_series)]
total_loss = tf.reduce_mean(losses)

train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)

def plot(loss_list, predictions_series, batchX, batchY):
    # Panel 1: loss curve; panels 2-6: input (blue), target (red), prediction (green)
    plt.subplot(2, 3, 1)
    plt.cla()
    plt.plot(loss_list)

    for batch_series_idx in range(5):
        one_hot_output_series = np.array(predictions_series)[:, batch_series_idx, :]
        single_output_series = np.array([(1 if out[0] < 0.5 else 0) for out in one_hot_output_series])

        plt.subplot(2, 3, batch_series_idx + 2)
        plt.cla()
        plt.axis([0, truncated_backprop_length, 0, 2])
        left_offset = range(truncated_backprop_length)
        plt.bar(left_offset, batchX[batch_series_idx, :], width=1, color="blue")
        plt.bar(left_offset, batchY[batch_series_idx, :] * 0.5, width=1, color="red")
        plt.bar(left_offset, single_output_series * 0.3, width=1, color="green")

    plt.draw()
    plt.pause(0.0001)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        x, y = generateData()
        # Reset both parts of the LSTM state at the start of each epoch
        _current_cell_state = np.zeros((batch_size, state_size))
        _current_hidden_state = np.zeros((batch_size, state_size))

        print("New data, epoch", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:, start_idx:end_idx]
            batchY = y[:, start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    batchX_placeholder: batchX,
                    batchY_placeholder: batchY,
                    cell_state: _current_cell_state,
                    hidden_state: _current_hidden_state
                })

            # Carry the final (c, h) over as the initial state of the next batch
            _current_cell_state, _current_hidden_state = _current_state

            loss_list.append(_total_loss)

            if batch_idx % 100 == 0:
                print("Step", batch_idx, "Batch loss", _total_loss)
                plot(loss_list, _predictions_series, batchX, batchY)

plt.ioff()
plt.show()
```

```
cell_state:  Tensor("Placeholder_2:0", shape=(5, 4), dtype=float32)
hidden_state:  Tensor("Placeholder_3:0", shape=(5, 4), dtype=float32)
init_state:  LSTMStateTuple(c=<tf.Tensor 'Placeholder_2:0' shape=(5, 4) dtype=float32>, h=<tf.Tensor 'Placeholder_3:0' shape=(5, 4) dtype=float32>)
states_series:  [<tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_2:0' shape=(5, 4) dtype=float32>, <tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_5:0' shape=(5, 4) dtype=float32>, <tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_8:0' shape=(5, 4) dtype=float32>, <tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_11:0' shape=(5, 4) dtype=float32>, <tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_14:0' shape=(5, 4) dtype=float32>, <tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_17:0' shape=(5, 4) dtype=float32>, <tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_20:0' shape=(5, 4) dtype=float32>, <tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_23:0' shape=(5, 4) dtype=float32>, <tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_26:0' shape=(5, 4) dtype=float32>, <tf.Tensor 'rnn/rnn/basic

New data, epoch 0
Step 0 Batch loss 0.6847398
Step 100 Batch loss 0.6903279
Step 200 Batch loss 0.5376254
Step 300 Batch loss 0.3745834
Step 400 Batch loss 0.26797488
Step 500 Batch loss 0.2107292
Step 600 Batch loss 0.015902335
New data, epoch 1
Step 0 Batch loss 0.6586295
Step 100 Batch loss 0.006981566
Step 200 Batch loss 0.004477961
Step 300 Batch loss 0.0035909899
Step 400 Batch loss 0.0024970921
Step 500 Batch loss 0.0022026335
Step 600 Batch loss 0.0024576653
New data, epoch 2
Step 0 Batch loss 0.8107532
Step 100 Batch loss 0.00191311
Step 200 Batch loss 0.0014569368
Step 300 Batch loss 0.0012985464
Step 400 Batch loss 0.0015577151
Step 500 Batch loss 0.0010756091
Step 600 Batch loss 0.0009408433
New data, epoch 3
Step 0 Batch loss 0.7937352
Step 100 Batch loss 0.0010361391
Step 200 Batch loss 0.001043925
Step 300 Batch loss 0.000841864
Step 400 Batch loss 0.00083522894
Step 500 Batch loss 0.0007012048
Step 600 Batch loss 0.0006273109
New data, epoch 4
Step 0 Batch loss 0.62024146
Step 100 Batch loss 0.000759213
Step 200 Batch loss 0.0007472509
Step 300 Batch loss 0.00067188643
Step 400 Batch loss 0.00056284555
Step 500 Batch loss 0.00048423384
Step 600 Batch loss 0.0005423832
```
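
The echo task and the batch slicing in the script can be checked without TensorFlow. The sketch below reuses the script's `generateData` logic with an explicit seeded generator (the seed is an addition for reproducibility) and checks that each target is the input shifted by `echo_step`, and that `num_batches` works out to 666 with these settings, which is why the last step printed in each epoch is 600.

```python
import numpy as np

total_series_length = 50000
truncated_backprop_length = 15
echo_step = 3
batch_size = 5
num_batches = total_series_length // batch_size // truncated_backprop_length

def generateData(rng):
    # Same logic as the script, but with an explicit random generator
    x = rng.choice(2, total_series_length, p=[0.5, 0.5])
    y = np.roll(x, echo_step)
    y[0:echo_step] = 0
    return x.reshape((batch_size, -1)), y.reshape((batch_size, -1))

x, y = generateData(np.random.default_rng(0))
print("num_batches:", num_batches)           # 666
print("row length:", x.shape[1])             # 10000

# The target at time t is the input at time t - echo_step (within the flat series)
assert np.array_equal(y.ravel()[echo_step:], x.ravel()[:-echo_step])

batchX = x[:, 0:truncated_backprop_length]
print("batch shape:", batchX.shape)          # (5, 15)
```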