# Understanding RNN structure
- Unlike feedforward nets, RNNs handle "sequential" data well because they carry a "state" from one step to the next
- Thus, grasping the concepts of **"sequences"** and (hidden) **"states"** is crucial to understanding RNNs

<br>
<img src="http://karpathy.github.io/assets/rnn/charseq.jpeg" style="width: 500px"/>
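
To make the idea of a preserved "state" concrete, here is a minimal NumPy sketch of a single recurrent step. The function name `rnn_step` and the weights `W`, `U`, `b` are hypothetical names chosen for illustration; the `tanh` update is the textbook form of a simple RNN, not code taken from Keras.

```python
import numpy as np

def rnn_step(x_t, h_prev, W, U, b):
    """One recurrent step: mix the current input with the previous state."""
    return np.tanh(x_t @ W + h_prev @ U + b)

rng = np.random.default_rng(0)
input_dim, num_units = 3, 4
W = rng.normal(scale=0.5, size=(input_dim, num_units))   # input -> hidden weights
U = rng.normal(scale=0.5, size=(num_units, num_units))   # hidden -> hidden weights
b = np.zeros(num_units)

x_t = np.array([[0.1, 0.2, 0.3]])     # the same input is fed at both steps
h0 = np.zeros((1, num_units))         # initial state
h1 = rnn_step(x_t, h0, W, U, b)       # state after step 1
h2 = rnn_step(x_t, h1, W, U, b)       # state after step 2 also depends on h1

print(h1.shape, h2.shape)             # (1, 4) (1, 4)
print(np.allclose(h1, h2))            # identical inputs, yet the states differ
```

The input is identical at both steps, so any difference between `h1` and `h2` comes from the state carried over from the previous step.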

In [8]:
import numpy as np
from keras.models import Model, Sequential
from keras.layers import *

In [9]:
import keras
keras.__version__

'3.0.5'

## 1. SimpleRNN 

The input to SimpleRNN should be a 3D tensor => (batch_size, timesteps, input_dim)
- **batch_size**: omitted when creating the RNN instance (== None). Usually determined when fitting the model.
- **timesteps**: number of steps in each input sequence
- **input_dim**: number of features at each timestep

In [10]:
# for instance, consider below array
x = np.array([[
             [1,    # => input_dim 1
              2,    # => input_dim 2 
              3],   # => input_dim 3     # => timestep 1                            
             [4, 5, 6]                   # => timestep 2
             ],                                  # => batch 1
             [[7, 8, 9], [10, 11, 12]],          # => batch 2
             [[13, 14, 15], [16, 17, 18]]        # => batch 3
             ])

In [4]:
[[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]],[[13,14,15],[16,17,18]]]

[[[1, 2, 3], [4, 5, 6]],
 [[7, 8, 9], [10, 11, 12]],
 [[13, 14, 15], [16, 17, 18]]]

In [None]:
[[1, 2, 3], [4, 5, 6]]

In [1]:
[1, 2, 3]    # one timestep: a 1D vector with input_dim = 3

[1, 2, 3]

In [5]:
import numpy as np
np.array([1,2,3,4,5]).shape

(5,)

In [6]:
np.array([[1,2,3,4,5]]).shape

(1, 5)

In [7]:
np.array([[[1,2,3,4,5]]]).shape

(1, 1, 5)

In [11]:
print('(Batch size, timesteps, input_dim) = ',x.shape)

(Batch size, timesteps, input_dim) =  (3, 2, 3)


In [6]:
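# rnn = SimpleRNN(50)(Input(shape = (10,)))          => error: only 2D once the batch axis is added
# rnn = SimpleRNN(50)(Input(shape = (10, 30, 40)))   => error: 4D once the batch axis is added
rnn = SimpleRNN(50)(Input(shape = (10, 30)))         # OK: timesteps = 10, input_dim = 30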

**return_state** = **return_sequences** = **False** ====> output_shape = **(batch_size = None, num_units)**

In [15]:
rnn = SimpleRNN(50)(Input(shape = (10, 30)))
print(rnn.shape)

(None, 50)


**return_sequences = True** ====> output_shape = **(batch_size, timesteps, num_units)**

In [16]:
rnn = SimpleRNN(50, return_sequences = True)(Input(shape = (10, 30)))
print(rnn.shape)

(None, 10, 50)


return_state = True ===> returns a list of tensors: **[output, state]**
- if return_sequences == False     =>>    output_shape = (batch_size, num_units)
- if return_sequences == True      =>>    output_shape = (batch_size, timesteps, num_units)
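
The effect of the two flags can be sketched in plain NumPy. `toy_simple_rnn` below is a hypothetical stand-in written for illustration, with random weights and no bias; it is not the Keras implementation, it only mimics which tensors come back for each flag combination.

```python
import numpy as np

def toy_simple_rnn(x, num_units, return_sequences=False, return_state=False, seed=0):
    """Hypothetical NumPy stand-in that mimics the return flags of SimpleRNN."""
    rng = np.random.default_rng(seed)
    batch_size, timesteps, input_dim = x.shape
    W = rng.normal(scale=0.1, size=(input_dim, num_units))   # input weights
    U = rng.normal(scale=0.1, size=(num_units, num_units))   # recurrent weights
    h = np.zeros((batch_size, num_units))                    # initial state
    outputs = []
    for t in range(timesteps):
        h = np.tanh(x[:, t, :] @ W + h @ U)                  # update the state
        outputs.append(h)
    # return_sequences picks between all timesteps and the last one only
    output = np.stack(outputs, axis=1) if return_sequences else h
    # return_state additionally hands back the final state
    return [output, h] if return_state else output

x_demo = np.ones((3, 10, 30))        # (batch_size, timesteps, input_dim)
seq, last = toy_simple_rnn(x_demo, 50, return_sequences=True, return_state=True)
print(seq.shape, last.shape)                 # (3, 10, 50) (3, 50)
print(np.allclose(seq[:, -1, :], last))      # True
```

In this sketch the state is simply the output at the last timestep, which is why the two returned tensors have the same shape when return_sequences is False.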

In [18]:
rnn = SimpleRNN(50, return_sequences = False, return_state = True)(Input(shape = (10, 30)))
print(rnn[0].shape)         # shape of output
print(rnn[1].shape)         # shape of last state

(None, 50)
(None, 50)


In [20]:
rnn = SimpleRNN(50, return_sequences = True, return_state = True)(Input(shape = (10, 30)))
print(rnn[0].shape)         # shape of output
print(rnn[1].shape)         # shape of last state


(None, 10, 50)
(None, 50)


The output and the state can be unpacked as below

In [26]:
output, state = SimpleRNN(50, return_sequences = True, return_state = True)(Input(shape = (10, 30)))

In [22]:
print(output.shape)
print(state.shape)

(None, 10, 50)
(None, 50)


## 2. LSTM
- Outputs of LSTM are quite similar to those of SimpleRNN, but there are subtle differences
- Comparing the two diagrams below, LSTM passes one more type of "state" on to the next module

<br>
<img src="http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-SimpleRNN.png" style="width: 500px"/>

<center> Standard RNN </center>

<br>
<img src="http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-chain.png" style="width: 500px"/>

<center> LSTM </center>

In addition to the "hidden state (ht)" of the RNN, the LSTM structure has a "cell state (Ct)"

<br>
<img src="http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-focus-o.png" style="width: 500px"/>

<center> Hidden State </center>

<br>
<img src="http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-focus-C.png" style="width: 500px"/>

<center> Cell State </center>
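
How the two states interact can be sketched as a single LSTM step in NumPy. This follows the standard LSTM equations; the function `lstm_step`, the stacked weight layout, and the gate order are assumptions made for illustration and need not match how Keras stores its weights.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM step; W, U, b hold the four gates stacked along the last axis."""
    z = x_t @ W + h_prev @ U + b
    i, f, g, o = np.split(z, 4, axis=-1)       # input, forget, candidate, output
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    g = np.tanh(g)
    c_t = f * c_prev + i * g                   # cell state: keep some old, add some new
    h_t = o * np.tanh(c_t)                     # hidden state: gated view of the cell state
    return h_t, c_t

rng = np.random.default_rng(0)
batch_size, timesteps, input_dim, num_units = 3, 10, 30, 50
W = rng.normal(scale=0.1, size=(input_dim, 4 * num_units))
U = rng.normal(scale=0.1, size=(num_units, 4 * num_units))
b = np.zeros(4 * num_units)

h = np.zeros((batch_size, num_units))          # initial hidden state
c = np.zeros((batch_size, num_units))          # initial cell state
for t in range(timesteps):
    x_t = rng.normal(size=(batch_size, input_dim))
    h, c = lstm_step(x_t, h, c, W, U, b)       # both states are passed to the next step

print(h.shape, c.shape)                        # (3, 50) (3, 50)
```

Both states have shape (batch_size, num_units), but only the hidden state is squashed through the output gate, so the two tensors generally hold different values.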

In [28]:
lstm = LSTM(50)(Input(shape = (10, 30)))

In [29]:
print(lstm.shape)

(None, 50)


In [30]:
lstm = LSTM(50, return_sequences = False, return_state = True)(Input(shape = (10, 30)))
print(lstm[0].shape)         # shape of output
print(lstm[1].shape)         # shape of hidden state
print(lstm[2].shape)         # shape of cell state

(None, 50)
(None, 50)
(None, 50)


In [31]:
lstm = LSTM(50, return_sequences = True, return_state = True)(Input(shape = (10, 30)))
print(lstm[0].shape)         # shape of output
print(lstm[1].shape)         # shape of hidden state
print(lstm[2].shape)         # shape of cell state

(None, 10, 50)
(None, 50)
(None, 50)


In [50]:
output, hidden_state, cell_state = LSTM(50, return_sequences = False, return_state = True)(Input(shape = (10, 30)))

In [48]:
print(output.shape)
print(hidden_state.shape)
print(cell_state.shape)

(None, 50)
(None, 50)
(None, 50)


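In [None]:
# With return_state = False, LSTM returns only the output tensor,
# so it cannot be unpacked into (output, hidden_state, cell_state)
output = LSTM(50, return_sequences = False, return_state = False)(Input(shape = (10, 30)))

In [None]:
print(output.shape)
# hidden_state and cell_state are not returned when return_state = False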

## 3. GRU
- GRU, a popular variant of LSTM, does not have a cell state
- Hence, it has only a hidden state, like SimpleRNN
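
As a rough illustration, one GRU step can be written in NumPy with a single state tensor. This is one common textbook formulation; the name `gru_step`, the stacked weight layout, and the exact placement of the reset gate are assumptions, and the Keras GRU layer may differ in such details.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, W, U, b):
    """One GRU step; W, U, b hold update, reset and candidate parts side by side."""
    W_z, W_r, W_h = np.split(W, 3, axis=-1)
    U_z, U_r, U_h = np.split(U, 3, axis=-1)
    b_z, b_r, b_h = np.split(b, 3, axis=-1)
    z = sigmoid(x_t @ W_z + h_prev @ U_z + b_z)              # update gate
    r = sigmoid(x_t @ W_r + h_prev @ U_r + b_r)              # reset gate
    h_cand = np.tanh(x_t @ W_h + (r * h_prev) @ U_h + b_h)   # candidate state
    return z * h_prev + (1.0 - z) * h_cand                   # the only state carried forward

rng = np.random.default_rng(0)
batch_size, timesteps, input_dim, num_units = 3, 10, 30, 50
W = rng.normal(scale=0.1, size=(input_dim, 3 * num_units))
U = rng.normal(scale=0.1, size=(num_units, 3 * num_units))
b = np.zeros(3 * num_units)

h = np.zeros((batch_size, num_units))          # initial hidden state
for t in range(timesteps):
    x_t = rng.normal(size=(batch_size, input_dim))
    h = gru_step(x_t, h, W, U, b)              # no separate cell state to track

print(h.shape)                                 # (3, 50)
```

Compared with the LSTM, the step function takes and returns one state instead of two, which matches the two-element result when return_state is True.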

In [41]:
gru = GRU(50)(Input(shape = (10, 30)))

In [42]:
print(gru.shape)

(None, 50)


In [43]:
gru = GRU(50, return_sequences = False, return_state = True)(Input(shape = (10, 30)))
print(gru[0].shape)         # shape of output
print(gru[1].shape)         # shape of hidden state

(None, 50)
(None, 50)


In [44]:
gru = GRU(50, return_sequences = True, return_state = True)(Input(shape = (10, 30)))
print(gru[0].shape)         # shape of output
print(gru[1].shape)         # shape of hidden state

(None, 10, 50)
(None, 50)


In [45]:
output, hidden_state = GRU(50, return_sequences = True, return_state = True)(Input(shape = (10, 30)))

In [46]:
print(output.shape)
print(hidden_state.shape)

(None, 10, 50)
(None, 50)
