Remove previous state model variable, track by hand in StreamingState instead
reuben committed Jun 6, 2019
1 parent 751c7a8 commit 2dfce92
Showing 9 changed files with 203 additions and 196 deletions.
84 changes: 31 additions & 53 deletions DeepSpeech.py
@@ -574,12 +574,8 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
         # no state management since n_step is expected to be dynamic too (see below)
         previous_state = previous_state_c = previous_state_h = None
     else:
-        if tflite:
-            previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
-            previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
-        else:
-            previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
-            previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
+        previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
+        previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
 
         previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
 
@@ -605,7 +601,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     logits = tf.squeeze(logits, [1])
 
     # Apply softmax for CTC decoder
-    logits = tf.nn.softmax(logits)
+    logits = tf.nn.softmax(logits, name='logits')
 
     if batch_size <= 0:
         if tflite:
@@ -618,51 +614,31 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
                 'input_lengths': seq_length,
             },
             {
-                'outputs': tf.identity(logits, name='logits'),
+                'outputs': logits,
             },
             layers
         )
 
     new_state_c, new_state_h = layers['rnn_output_state']
-    if tflite:
-        logits = tf.identity(logits, name='logits')
-        new_state_c = tf.identity(new_state_c, name='new_state_c')
-        new_state_h = tf.identity(new_state_h, name='new_state_h')
-
-        inputs = {
-            'input': input_tensor,
-            'previous_state_c': previous_state_c,
-            'previous_state_h': previous_state_h,
-            'input_samples': input_samples,
-        }
-
-        if FLAGS.use_seq_length:
-            inputs.update({'input_lengths': seq_length})
-
-        outputs = {
-            'outputs': logits,
-            'new_state_c': new_state_c,
-            'new_state_h': new_state_h,
-            'mfccs': mfccs,
-        }
-    else:
-        zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
-        initialize_c = tf.assign(previous_state_c, zero_state)
-        initialize_h = tf.assign(previous_state_h, zero_state)
-        initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
-        with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
-            logits = tf.identity(logits, name='logits')
-
-        inputs = {
-            'input': input_tensor,
-            'input_lengths': seq_length,
-            'input_samples': input_samples,
-        }
-        outputs = {
-            'outputs': logits,
-            'initialize_state': initialize_state,
-            'mfccs': mfccs,
-        }
+    new_state_c = tf.identity(new_state_c, name='new_state_c')
+    new_state_h = tf.identity(new_state_h, name='new_state_h')
+
+    inputs = {
+        'input': input_tensor,
+        'previous_state_c': previous_state_c,
+        'previous_state_h': previous_state_h,
+        'input_samples': input_samples,
+    }
+
+    if FLAGS.use_seq_length:
+        inputs.update({'input_lengths': seq_length})
+
+    outputs = {
+        'outputs': logits,
+        'new_state_c': new_state_c,
+        'new_state_h': new_state_h,
+        'mfccs': mfccs,
+    }
 
     return inputs, outputs, layers
 
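With the state variables and the initialize_state op gone, the exported graph is stateless: callers feed the previous LSTM state through the previous_state_c/previous_state_h placeholders on every run and read the updated state back from new_state_c/new_state_h. A minimal sketch of that calling pattern against a graph frozen at this revision; the state and logits tensor names come from the diff above, while the model path, cell size, and input tensor names (input_node, input_lengths, per DeepSpeech convention) are assumptions:

import numpy as np
import tensorflow as tf

N_CELL_DIM = 2048  # assumed; must match Config.n_cell_dim of the exported model

graph_def = tf.GraphDef()
with open('output_graph.pb', 'rb') as fin:  # assumed model path
    graph_def.ParseFromString(fin.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

feature_batches = []  # (features, length) pairs from the MFCC pipeline

# State starts at zero for a fresh stream and is carried by hand between runs.
state_c = np.zeros([1, N_CELL_DIM], dtype=np.float32)
state_h = np.zeros([1, N_CELL_DIM], dtype=np.float32)

with tf.Session(graph=graph) as session:
    for features, features_len in feature_batches:
        logits, state_c, state_h = session.run(
            ['logits:0', 'new_state_c:0', 'new_state_h:0'],
            feed_dict={'input_node:0': features,
                       'input_lengths:0': [features_len],
                       'previous_state_c:0': state_c,
                       'previous_state_h:0': state_h})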
@@ -682,10 +658,12 @@ def export():
     output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
     output_names = ",".join(output_names_tensors + output_names_ops)
 
-    if not FLAGS.export_tflite:
-        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
-    else:
+    mapping = None
+    if FLAGS.export_tflite:
         # Create a saver using variables from the above newly created graph
+        # Training graph uses LSTMFusedCell, but the TFLite inference graph uses
+        # a static RNN with a normal cell, so we need to rewrite the names to
+        # match the training weights when restoring.
         def fixup(name):
             if name.startswith('rnn/lstm_cell/'):
                 return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
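The fixup renaming above is what lets a tf.train.Saver restore fused-cell training checkpoints into the plain-cell TFLite inference graph. A sketch of how such a mapping is typically handed to the Saver, assuming the surrounding export() code wires it up the same way the pre-change version did:

import tensorflow as tf

def fixup(name):
    # Inference-graph variable name -> name as stored in the training checkpoint.
    if name.startswith('rnn/lstm_cell/'):
        return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
    return name

# Keys are checkpoint names; values are variables of the new inference graph.
mapping = {fixup(v.op.name): v for v in tf.global_variables()}
saver = tf.train.Saver(mapping)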
@@ -710,7 +688,7 @@ def fixup(name):
     if not os.path.isdir(FLAGS.export_dir):
         os.makedirs(FLAGS.export_dir)
 
-    def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
+    def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
         return freeze_graph.freeze_graph_with_def_protos(
             input_graph_def=tf.get_default_graph().as_graph_def(),
             input_saver_def=saver.as_saver_def(),
@@ -724,7 +702,7 @@ def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklis
             initializer_nodes='')
 
     if not FLAGS.export_tflite:
-        frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
+        frozen_graph = do_graph_freeze(output_node_names=output_names)
         frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
 
         # Add a no-op node to the graph with metadata information to be loaded by the native client
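The metadata no-op mentioned in that comment can be attached by appending a NodeDef to the frozen GraphDef. A rough sketch; the node and attribute names here are assumptions for illustration, not taken from this diff:

# frozen_graph is the GraphDef returned by do_graph_freeze() above.
metadata = frozen_graph.node.add()
metadata.name = 'model_metadata'  # assumed node name
metadata.op = 'NoOp'
metadata.attr['sample_rate'].i = 16000   # assumed attribute names and values
metadata.attr['feature_win_len'].i = 32
metadata.attr['feature_win_step'].i = 20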
@@ -740,7 +718,7 @@ def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklis
         with open(output_graph_path, 'wb') as fout:
             fout.write(frozen_graph.SerializeToString())
     else:
-        frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
+        frozen_graph = do_graph_freeze(output_node_names=output_names)
         output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
 
         converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
12 changes: 1 addition & 11 deletions native_client/BUILD
@@ -114,34 +114,24 @@ tf_cc_shared_object(
### => Trying to be more fine-grained
### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
### CPU only build, libdeepspeech.so file size reduced by ~50%
"//tensorflow/core/kernels:dense_update_ops", # Assign
"//tensorflow/core/kernels:constant_op", # Const
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst
"//tensorflow/core/kernels:constant_op", # Placeholder
"//tensorflow/core/kernels:identity_op", # Identity
"//tensorflow/core/kernels:softmax_op", # Softmax
"//tensorflow/core/kernels:transpose_op", # Transpose
"//tensorflow/core/kernels:reshape_op", # Reshape
"//tensorflow/core/kernels:shape_ops", # Shape
"//tensorflow/core/kernels:concat_op", # ConcatV2
"//tensorflow/core/kernels:split_op", # Split
"//tensorflow/core/kernels:variable_ops", # VariableV2
"//tensorflow/core/kernels:relu_op", # Relu
"//tensorflow/core/kernels:bias_op", # BiasAdd
"//tensorflow/core/kernels:math", # Range, MatMul
"//tensorflow/core/kernels:control_flow_ops", # Enter
"//tensorflow/core/kernels:tile_ops", # Tile
"//tensorflow/core/kernels:gather_op", # Gather
"//tensorflow/core/kernels:mfcc_op", # Mfcc
"//tensorflow/core/kernels:spectrogram_op", # AudioSpectrogram
"//tensorflow/core/kernels:strided_slice_op", # StridedSlice
"//tensorflow/core/kernels:slice_op", # Slice, needed by StridedSlice
"//tensorflow/contrib/rnn:lstm_ops_kernels", # BlockLSTM
"//tensorflow/core/kernels:random_ops", # RandomGammaGrad
"//tensorflow/core/kernels:pack_op", # Pack
"//tensorflow/core/kernels:gather_nd_op", # GatherNd
#### Needed by production model produced without "--use_seq_length False"
#"//tensorflow/core/kernels:logging_ops", # Assert
#"//tensorflow/core/kernels:reverse_sequence_op", # ReverseSequence
],
}) + if_cuda([
"//tensorflow/core:core",
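The comments above derive this kernel list from the ops actually present in the exported graph. A quick way to enumerate those ops from a frozen graph, in the spirit of the bin/ops_in_graph.py helper mentioned there (a sketch; the helper's real interface may differ):

import tensorflow as tf

graph_def = tf.GraphDef()
with open('output_graph.pb', 'rb') as fin:  # assumed model path
    graph_def.ParseFromString(fin.read())

# Each distinct op type needs a kernel linked into libdeepspeech.so.
for op_type in sorted({node.op for node in graph_def.node}):
    print(op_type)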
18 changes: 12 additions & 6 deletions native_client/deepspeech.cc
@@ -67,6 +67,9 @@ struct StreamingState {
   vector<float> audio_buffer_;
   vector<float> mfcc_buffer_;
   vector<float> batch_buffer_;
+  vector<float> previous_state_c_;
+  vector<float> previous_state_h_;
+
   ModelState* model_;
   DecoderState* decoder_state_;
 
@@ -237,7 +240,13 @@ void
 StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
 {
   vector<float> logits;
-  model_->infer(buf.data(), n_steps, logits);
+  model_->infer(buf,
+                n_steps,
+                previous_state_c_,
+                previous_state_h_,
+                logits,
+                previous_state_c_,
+                previous_state_h_);
 
   const int cutoff_top_n = 40;
   const double cutoff_prob = 1.0;
@@ -330,11 +339,6 @@ DS_SetupStream(ModelState* aCtx,
 {
   *retval = nullptr;
 
-  int err = aCtx->initialize_state();
-  if (err != DS_ERR_OK) {
-    return err;
-  }
-
   std::unique_ptr<StreamingState> ctx(new StreamingState());
   if (!ctx) {
     std::cerr << "Could not allocate streaming state." << std::endl;
@@ -352,6 +356,8 @@ DS_SetupStream(ModelState* aCtx,
   ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_);
   ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f);
   ctx->batch_buffer_.reserve(aCtx->n_steps_ * aCtx->mfcc_feats_per_timestep_);
+  ctx->previous_state_c_.reserve(aCtx->state_size_);
+  ctx->previous_state_h_.reserve(aCtx->state_size_);
   ctx->model_ = aCtx;
 
   DecoderState *params = decoder_init(*aCtx->alphabet_, num_classes, aCtx->scorer_);
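On the native-client side the stream now owns the recurrent state outright: the buffers start at zero for a fresh stream and are overwritten with the network's new state after every batch, with no initialize_state op involved. A small Python model of that bookkeeping (illustrative only; infer stands in for ModelState::infer, whose real signature is shown in modelstate.h below):

import numpy as np

class StreamingState:
    """Python model of the by-hand state tracking in the C++ StreamingState."""

    def __init__(self, model, state_size):
        self.model = model
        # A fresh stream starts from the zero state.
        self.previous_state_c = np.zeros(state_size, dtype=np.float32)
        self.previous_state_h = np.zeros(state_size, dtype=np.float32)

    def process_batch(self, mfcc_batch, n_steps):
        # State goes in; logits and the updated state come back out.
        logits, self.previous_state_c, self.previous_state_h = self.model.infer(
            mfcc_batch, n_steps, self.previous_state_c, self.previous_state_h)
        return logits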
1 change: 1 addition & 0 deletions native_client/modelstate.cc
@@ -17,6 +17,7 @@ ModelState::ModelState()
   , sample_rate_(DEFAULT_SAMPLE_RATE)
   , audio_win_len_(DEFAULT_WINDOW_LENGTH)
   , audio_win_step_(DEFAULT_WINDOW_STEP)
+  , state_size_(-1)
 {
 }
 
11 changes: 8 additions & 3 deletions native_client/modelstate.h
@@ -28,6 +28,7 @@ struct ModelState {
   unsigned int sample_rate_;
   unsigned int audio_win_len_;
   unsigned int audio_win_step_;
+  unsigned int state_size_;
 
   ModelState();
   virtual ~ModelState();
@@ -38,8 +39,6 @@ struct ModelState {
                    const char* alphabet_path,
                    unsigned int beam_width);
 
-  virtual int initialize_state() = 0;
-
   virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;
 
   /**
@@ -52,7 +51,13 @@ struct ModelState {
   *
   * @param[out] output_logits Where to store computed logits.
   */
-  virtual void infer(const float* mfcc, unsigned int n_frames, std::vector<float>& logits_output) = 0;
+  virtual void infer(const std::vector<float>& mfcc,
+                     unsigned int n_frames,
+                     const std::vector<float>& previous_state_c,
+                     const std::vector<float>& previous_state_h,
+                     std::vector<float>& logits_output,
+                     std::vector<float>& state_c_output,
+                     std::vector<float>& state_h_output) = 0;
 
   /**
    * @brief Perform decoding of the logits, using basic CTC decoder or
(Diffs for the remaining four changed files are not shown here.)
