# 一些简单的demo，显示tf里面rnn/gru和lstm cell的区别

In [1]:
import logging
import time
import numpy as np
import tensorflow as tf

In [23]:
# Cell type and network dimensions.
model = 'rnn'  # one of 'rnn', 'lstm', 'gru' (see the cell_fn selection below)
num_layers = 2
hidden_size = 128
embedding_size = 64
vocab_size = 26

# Batch dimensions.
batch_size = 16
num_unrollings = 10

# Dropout rates (0.0 means no dropout).
dropout = 0.0
input_dropout = 0.0

# Optimisation settings and run mode.
learning_rate = 0.005
max_grad_norm = 5.0
is_training = True

In [24]:
# Create a TensorFlow InteractiveSession and bind it to `sess`.
# NOTE(review): the session is not closed in any of the visible cells --
# confirm that `sess.close()` is called further down the notebook.
sess = tf.InteractiveSession()

In [25]:
# Pick the RNN cell class that matches the configured `model` string.
if model == 'rnn':
    cell_fn = tf.nn.rnn_cell.BasicRNNCell
elif model == 'lstm':
    cell_fn = tf.nn.rnn_cell.BasicLSTMCell
elif model == 'gru':
    cell_fn = tf.nn.rnn_cell.GRUCell
else:
    # Fail here with a clear message instead of leaving `cell_fn` undefined
    # (or stale from a previous run) and hitting a confusing error later.
    raise ValueError(
        "Unsupported model %r; expected 'rnn', 'lstm' or 'gru'" % (model,))

In [26]:
# Keyword arguments passed to every cell constructor.
params = {}
if model == 'lstm':
    params['forget_bias'] = 1.0  # 1.0 is default value

# Bottom layer first, then one freshly constructed cell per additional layer.
cell = cell_fn(hidden_size, **params)
cells = [cell]
cells.extend(cell_fn(hidden_size, **params) for _ in range(num_layers - 1))

# Stack the per-layer cells into a single multi-layer cell.
multi_cell = tf.nn.rnn_cell.MultiRNNCell(cells)

### TensorFlow 内置的 cell 都有一个 `zero_state` 方法

In [27]:
# Display the zero state of the stacked cell. With model == 'rnn' the recorded
# output is a tuple holding one (batch_size, hidden_size) float32 tensor per layer.
multi_cell.zero_state(batch_size, tf.float32)

(<tf.Tensor 'MultiRNNCellZeroState_2/BasicRNNCellZeroState/zeros:0' shape=(16, 128) dtype=float32>,
 <tf.Tensor 'MultiRNNCellZeroState_2/BasicRNNCellZeroState_1/zeros:0' shape=(16, 128) dtype=float32>)

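* zero\_state if the rnn\_cell is 'rnn' or 'gru' (the example below is from 'gru'; 'rnn' gives the same structure with `BasicRNNCellZeroState` in the tensor names, as in the output above)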
```python
(<tf.Tensor 'MultiRNNCellZeroState/GRUCellZeroState/zeros:0' shape=(16, 128) dtype=float32>,
 <tf.Tensor 'MultiRNNCellZeroState/GRUCellZeroState_1/zeros:0' shape=(16, 128) dtype=float32>)
```

* zero\_state if the rnn\_cell is 'lstm'
```python
(LSTMStateTuple(c=<tf.Tensor 'MultiRNNCellZeroState/BasicLSTMCellZeroState/zeros:0' shape=(16, 128) dtype=float32>, h=<tf.Tensor 'MultiRNNCellZeroState_3/BasicLSTMCellZeroState/zeros_1:0' shape=(16, 128) dtype=float32>),
 LSTMStateTuple(c=<tf.Tensor 'MultiRNNCellZeroState/BasicLSTMCellZeroState_1/zeros:0' shape=(16, 128) dtype=float32>, h=<tf.Tensor 'MultiRNNCellZeroState_3/BasicLSTMCellZeroState_1/zeros_1:0' shape=(16, 128) dtype=float32>))
```

### 为 initial_state 创建 placeholder，并生成可以 feed 给它的全零状态（zero_state）

In [29]:

with tf.name_scope('initial_state'):
    # All-zero state tensors for the stacked cell (one entry per layer).
    zero_state = multi_cell.zero_state(batch_size, tf.float32)

    # Build one placeholder per state tensor, mirroring the structure of
    # multi_cell.state_size. Iterating over state_size itself (rather than
    # range(num_layers)) keeps the placeholder count equal to the number of
    # layers actually stacked in multi_cell.
    if model == 'rnn' or model == 'gru':
        # Each layer's state_size is a plain int: one state tensor per layer.
        initial_state = tuple(
            tf.placeholder(tf.float32,
                           [batch_size, layer_size],
                           name='initial_state_' + str(idx + 1))
            for idx, layer_size in enumerate(multi_cell.state_size))
    elif model == 'lstm':
        # Each layer's state_size is an LSTMStateTuple(c, h): two tensors per
        # layer. The c and h placeholders get distinct names so they can be
        # told apart in the graph.
        initial_state = tuple(
            tf.nn.rnn_cell.LSTMStateTuple(
                c=tf.placeholder(tf.float32,
                                 [batch_size, layer_size.c],
                                 name='initial_lstm_state_c_' + str(idx + 1)),
                h=tf.placeholder(tf.float32,
                                 [batch_size, layer_size.h],
                                 name='initial_lstm_state_h_' + str(idx + 1)))
            for idx, layer_size in enumerate(multi_cell.state_size))
    else:
        # Without this branch `initial_state` would be left undefined.
        raise ValueError(
            "Unsupported model %r; expected 'rnn', 'lstm' or 'gru'" % (model,))

* state\_size if the rnn\_cell is 'rnn' or 'gru'
```python
(128, 128)
```
* state\_size if the rnn\_cell is 'lstm'
```python
(LSTMStateTuple(c=128, h=128), LSTMStateTuple(c=128, h=128))
```