Commit 57da516: wip tflite
reuben committed Aug 30, 2018
1 parent 9fffafd commit 57da516
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions DeepSpeech.py
@@ -435,16 +435,22 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
# Both of which have inputs of length `n_cell_dim` and bias `1.0` for the forget gate of the LSTM.

# Forward direction cell:
-fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_cell_dim, reuse=reuse)
-layers['fw_cell'] = fw_cell
+# fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_cell_dim, reuse=reuse)
+# layers['fw_cell'] = fw_cell
+
+fw_cell = tf.nn.rnn_cell.LSTMCell(n_cell_dim, reuse=reuse)

# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
layer_3 = tf.reshape(layer_3, [n_steps, batch_size, n_hidden_3])
+layer_3 = tf.unstack(layer_3, n_steps)

lissyx (Collaborator) commented on Oct 22, 2018:

    this still makes TF Lite choke :(


# We parametrize the RNN implementation as the training and inference graph
# need to do different things here.
-output, output_state = fw_cell(inputs=layer_3, dtype=tf.float32, sequence_length=seq_length, initial_state=previous_state)
+# output, output_state = fw_cell(inputs=layer_3, dtype=tf.float32, sequence_length=seq_length, initial_state=previous_state)
+
+output, output_state = tf.nn.static_rnn(fw_cell, layer_3, previous_state, tf.float32, seq_length)
+output = tf.concat(output, 0)

# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
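
The swap in this hunk is the core of the change: LSTMBlockFusedCell is a single monolithic kernel that TF Lite has no equivalent for, while tf.nn.static_rnn unrolls a plain LSTMCell into one per-step subgraph of primitive ops (MatMul, BiasAdd, Sigmoid, Tanh). Below is a minimal standalone sketch of the two constructions; the shapes are illustrative assumptions, not values from DeepSpeech.py. Note that tf.unstack/static_rnn still emit constructs the converter rejects, per the review comment above.

import tensorflow as tf

n_steps, batch_size, n_input, n_cell_dim = 16, 1, 494, 2048  # assumed shapes

inputs = tf.placeholder(tf.float32, [n_steps, batch_size, n_input])

# Fused form: one custom op spans the whole sequence. Fast, but opaque to
# converters that only understand primitive TensorFlow ops.
# fused_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_cell_dim)
# outputs, final_state = fused_cell(inputs, dtype=tf.float32)

# Unrolled form: a Python list of per-step tensors driven through LSTMCell.
cell = tf.nn.rnn_cell.LSTMCell(n_cell_dim)
step_inputs = tf.unstack(inputs, n_steps)            # n_steps tensors of [batch_size, n_input]
step_outputs, final_state = tf.nn.static_rnn(cell, step_inputs, dtype=tf.float32)
outputs = tf.concat(step_outputs, 0)                 # [n_steps*batch_size, n_cell_dim]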
@@ -1759,8 +1765,10 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, n_input + 2*n_input*n_context], name='input_node')
seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')

-previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)
-previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, n_cell_dim], initializer=None)
+# previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)
+# previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, n_cell_dim], initializer=None)
+previous_state_c = tf.placeholder(tf.float32, [batch_size, n_cell_dim], name='previous_state_c')
+previous_state_h = tf.placeholder(tf.float32, [batch_size, n_cell_dim], name='previous_state_h')
previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)

logits, layers = BiRNN(batch_x=input_tensor,
@@ -1773,15 +1781,9 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
new_state_c, new_state_h = layers['rnn_output_state']

# Initial zero state
-zero_state = tf.zeros([batch_size, n_cell_dim], tf.float32)
-
-initialize_c = tf.assign(previous_state_c, zero_state)
-initialize_h = tf.assign(previous_state_h, zero_state)
-
-initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
+zero_state = tf.zeros([batch_size, n_cell_dim], tf.float32, name='zero_state')

-with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
-    logits = tf.identity(logits, name='logits')
+logits = tf.identity(logits, name='logits')

return (
{
@@ -1790,7 +1792,6 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
},
{
'outputs': logits,
-'initialize_state': initialize_state,
}
)
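
Turning previous_state_c/h into placeholders moves the recurrent state out of the graph entirely: the caller owns it and must feed it back on every run. A hypothetical driver loop under that contract is sketched below; the new_state_c/new_state_h output names are assumptions (this WIP commit computes the new state but does not yet expose it in the outputs dict), and feature_windows stands in for whatever produces input chunks.

import numpy as np
import tensorflow as tf

batch_size, n_cell_dim = 1, 2048      # assumed, must match the exported graph
feature_windows = []                  # hypothetical [(features, length)] chunks

with tf.Session() as session:
    # ... graph imported and weights restored elsewhere ...
    state_c = np.zeros([batch_size, n_cell_dim], np.float32)  # replaces the old
    state_h = np.zeros([batch_size, n_cell_dim], np.float32)  # initialize_state op
    for features, length in feature_windows:
        logits, state_c, state_h = session.run(
            ['logits:0', 'new_state_c:0', 'new_state_h:0'],   # output names assumed
            feed_dict={'input_node:0': features,
                       'input_lengths:0': [length],
                       'previous_state_c:0': state_c,
                       'previous_state_h:0': state_h})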

@@ -1808,7 +1809,12 @@ def export():
inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps)

# Create a saver using variables from the above newly created graph
-mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
+def fixup(name):
+    if name.startswith('rnn/lstm_cell/'):
+        return name.replace('rnn/lstm_cell/', 'lstm_block_wrapper/')
+    return name
+
+mapping = {fixup(v.op.name): v for v in tf.global_variables()}
saver = tf.train.Saver(mapping)

# Restore variables from training checkpoint
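
The fixup indirection above works because tf.train.Saver accepts a dict mapping names as stored in the checkpoint to variables in the current graph, so the inference graph rebuilt with LSTMCell (variables under rnn/lstm_cell/) can still load weights the fused training cell saved under lstm_block_wrapper/. A toy illustration, with a made-up shape and checkpoint path:

import tensorflow as tf

# Variable as named in the rebuilt inference graph.
kernel = tf.Variable(tf.zeros([2, 4]), name='rnn/lstm_cell/kernel')

# Key is the name the training checkpoint stored it under.
saver = tf.train.Saver({'lstm_block_wrapper/kernel': kernel})

with tf.Session() as session:
    saver.restore(session, '/tmp/ckpt/model')  # hypothetical checkpoint path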
@@ -1822,6 +1828,10 @@ def export():
try:
output_graph_path = os.path.join(FLAGS.export_dir, 'output_graph.pb')

+with tf.gfile.FastGFile(output_graph_path + 'txt', 'w') as fout:
+    from google.protobuf import text_format
+    fout.write(text_format.MessageToString(session.graph_def))

if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
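
The FastGFile block above additionally dumps the GraphDef as text protobuf (output_graph_path + 'txt', i.e. output_graph.pbtxt), which is handy for grepping the ops the converter will see. A small sketch of reading that dump back, assuming only that the export has run:

import tensorflow as tf
from google.protobuf import text_format

graph_def = tf.GraphDef()
with open('output_graph.pbtxt') as fin:
    text_format.Parse(fin.read(), graph_def)

# List the distinct op types, e.g. to spot ones TF Lite has no kernel for.
print(sorted({node.op for node in graph_def.node}))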

@@ -1830,13 +1840,12 @@ def export():
input_graph_def=session.graph_def,
input_saver_def=saver.as_saver_def(),
input_checkpoint=checkpoint_path,
-output_node_names='logits,initialize_state',
+output_node_names='logits',
restore_op_name=None,
filename_tensor_name=None,
output_graph=output_graph_path,
clear_devices=False,
-initializer_nodes='',
-variable_names_blacklist='previous_state_c,previous_state_h')
+initializer_nodes='')

log_info('Models exported at %s' % (FLAGS.export_dir))
except RuntimeError as e:
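For reference, the conversion step this WIP is building toward would look roughly like the sketch below with the contrib-era API of the time (tf.contrib.lite.TocoConverter, later renamed TFLiteConverter). The array names follow the name= arguments in create_inference_graph; none of this code is in the commit, and per the review comment it is exactly this step that still fails on the unrolled RNN ops.

import tensorflow as tf

converter = tf.contrib.lite.TocoConverter.from_frozen_graph(
    'output_graph.pb',
    input_arrays=['input_node', 'input_lengths',
                  'previous_state_c', 'previous_state_h'],
    output_arrays=['logits'])

tflite_model = converter.convert()  # chokes here on unsupported ops

with open('output_graph.tflite', 'wb') as fout:
    fout.write(tflite_model)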
