[View in Colaboratory](https://colab.research.google.com/github/muik/notebooks/blob/master/tensorflow/use_cudnn_lstm_on_cpu.ipynb)

# TensorFlow cuDNN LSTM example

Use a model trained with TensorFlow's cuDNN LSTM (`CudnnLSTM`) on CPU by restoring its checkpoint into `CudnnCompatibleLSTMCell` cells.

In [0]:
import tensorflow as tf
print(tf.__version__)  # the recorded output of this cell shows 1.7.0

# NOTE(review): these are private `tensorflow.contrib` / `tensorflow.python`
# module paths. tf.contrib was removed in TensorFlow 2.x, so this notebook
# presumably needs a 1.x release (the recorded run used 1.7.0) — confirm.
from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.python.ops import rnn_cell

# Checkpoint prefix: written by the training cell, read by both restore cells.
MODEL_SAVE_PATH = './model'

# Dimensions shared by build_graph and build_compat_graph.
num_layers = 2  # number of stacked LSTM layers
num_dirs = 2  # directions per layer (forward + backward); must stay 2 because
              # both graphs are hard-wired to be bidirectional
num_units = 3  # LSTM hidden size per direction
batch_size = 5
time_len = 12  # sequence length (number of time steps)
input_size = 4  # feature size of each input time step

def build_graph(is_training=False):
  """Build the bidirectional cuDNN LSTM graph used for training/inference.

  Feeds a seeded random batch through a ``CudnnLSTM`` with ``num_layers``
  layers in both directions. Every layer/direction slice starts from the
  same seeded random cell state and a zero hidden state.

  Args:
    is_training: passed straight through as the layer's ``training`` argument.

  Returns:
    ``(outputs, last_c_state, last_h_state)``. ``outputs`` is batch-major,
    ``[batch_size, time_len, num_dirs * num_units]``. Each state is the last
    two slices of the layer's final state joined along axis 1. These are
    presumably the top layer's forward and backward states; this assumes
    cuDNN orders state slices layer by layer, forward before backward.
  """
  cudnn_layer = cudnn_rnn.CudnnLSTM(
      num_layers=num_layers, num_units=num_units,
      direction=cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION)

  num_states = num_layers * num_dirs  # one state slice per layer and direction

  # Seeded batch-major input, transposed to the time-major layout fed to
  # the layer: [time_len, batch_size, input_size].
  time_major_inputs = tf.transpose(
      tf.random_uniform((batch_size, time_len, input_size), seed=1),
      [1, 0, 2])

  # Replicate one seeded random cell state across all slices; h starts at 0.
  c_slice = tf.expand_dims(
      tf.random_uniform((batch_size, num_units), seed=2), 0)
  c_init = tf.tile(c_slice, [num_states, 1, 1])
  h_init = tf.zeros([num_states, batch_size, num_units])

  # CudnnLSTM takes and returns its state ordered as (h, c).
  time_major_outputs, (final_h, final_c) = cudnn_layer(
      time_major_inputs, initial_state=(h_init, c_init),
      training=is_training)

  # [time_len, batch_size, num_dirs * num_units] -> [batch_size, time_len, ...]
  batch_major_outputs = tf.transpose(time_major_outputs, [1, 0, 2])

  last_h = tf.concat([final_h[-2], final_h[-1]], 1)
  last_c = tf.concat([final_c[-2], final_c[-1]], 1)
  return batch_major_outputs, last_c, last_h

def build_compat_graph():
  """Build the CPU-compatible counterpart of ``build_graph``.

  Stacks ``num_layers`` ``CudnnCompatibleLSTMCell`` cells per direction in
  ``stack_bidirectional_dynamic_rnn``. It uses the same input shape and op
  seeds as ``build_graph``, but keeps the input batch-major.

  Returns:
    ``(outputs, last_c_state, last_h_state)``. ``outputs`` is batch-major.
    Each state joins the final forward and backward states of the top
    (last) layer along axis 1.
  """
  cell_cls = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell
  fw_cells = [cell_cls(num_units) for _ in range(num_layers)]
  bw_cells = [cell_cls(num_units) for _ in range(num_layers)]

  batch_major_inputs = tf.random_uniform(
      (batch_size, time_len, input_size), seed=1)

  # Every layer in both directions starts from the same state: a seeded
  # random c and a zero h. LSTMStateTuple is ordered (c, h).
  start_state = rnn_cell.LSTMStateTuple(
      tf.random_uniform((batch_size, num_units), seed=2),
      tf.zeros([batch_size, num_units]))
  start_states = [start_state] * num_layers

  # NOTE: the scope name is what lets this graph restore the checkpoint saved
  # from build_graph's CudnnLSTM (presumably it must match the names the cuDNN
  # layer saves its weights under) — do not change it.
  outputs, fw_final, bw_final = contrib_rnn.stack_bidirectional_dynamic_rnn(
      fw_cells, bw_cells, batch_major_inputs, dtype=tf.float32,
      initial_states_fw=start_states, initial_states_bw=start_states,
      time_major=False, scope='cudnn_lstm/stack_bidirectional_rnn')

  top_fw, top_bw = fw_final[-1], bw_final[-1]
  last_h_state = tf.concat([top_fw.h, top_bw.h], 1)
  last_c_state = tf.concat([top_fw.c, top_bw.c], 1)
  return outputs, last_c_state, last_h_state

# 1) cuDNN graph with training=True: randomly initialize the LSTM weights,
#    run one forward pass (there is no optimizer here, so nothing is learned)
#    and save the weights to MODEL_SAVE_PATH.
with tf.Graph().as_default() as graph:
  ops = build_graph(is_training=True)
  saver = tf.train.Saver()
  
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    outputs, c_state, h_state = sess.run(ops)
    # First batch element: output at the first time step, final c and h.
    print(outputs[0][0], c_state[0], h_state[0])
    saver.save(sess, MODEL_SAVE_PATH)

# 2) Fresh cuDNN graph in inference mode with weights restored from the
#    checkpoint. The inputs use fixed op seeds, so the printed values should
#    reproduce those from step 1.
with tf.Graph().as_default() as graph:
  ops = build_graph(is_training=False)
  saver = tf.train.Saver()

  with tf.Session() as sess:
    # NOTE(review): likely redundant, because saver.restore() overwrites the
    # saved variables; it is harmless either way.
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, MODEL_SAVE_PATH)
    outputs, c_state, h_state = sess.run(ops)
    print(outputs[0][0], c_state[0], h_state[0])

# 3) CPU-compatible graph (CudnnCompatibleLSTMCell) restored from the same
#    checkpoint. Its printed values should match steps 1-2 up to float
#    rounding.
with tf.Graph().as_default() as graph:
  ops = build_compat_graph()
  saver = tf.train.Saver()
  
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, MODEL_SAVE_PATH)
    outputs, c_state, h_state = sess.run(ops)
    print(outputs[0][0], c_state[0], h_state[0])

1.7.0
(array([ 0.16085729,  0.12783323,  0.07822829, -0.00185133, -0.05733256,
       -0.04276047], dtype=float32), array([-0.0851168 ,  0.12208229,  0.00152305, -0.00415992, -0.12166481,
       -0.08929548], dtype=float32), array([-0.04304137,  0.0662911 ,  0.00084107, -0.00185133, -0.05733256,
       -0.04276047], dtype=float32))
INFO:tensorflow:Restoring parameters from ./model
(array([ 0.16085729,  0.12783323,  0.07822829, -0.00185133, -0.05733256,
       -0.04276047], dtype=float32), array([-0.0851168 ,  0.12208229,  0.00152305, -0.00415992, -0.12166481,
       -0.08929548], dtype=float32), array([-0.04304137,  0.0662911 ,  0.00084107, -0.00185133, -0.05733256,
       -0.04276047], dtype=float32))
INFO:tensorflow:Restoring parameters from ./model
(array([ 0.16085729,  0.12783323,  0.07822829, -0.00185133, -0.05733256,
       -0.04276047], dtype=float32), array([-0.0851168 ,  0.12208231,  0.00152306, -0.00415991, -0.12166481,
       -0.08929548], dtype=float32), array([-0.04304137,