Commit 51cac10: Address review comments

reuben committed Aug 3, 2018 · 1 parent 75091a8
Showing 15 changed files with 177 additions and 173 deletions.
10 changes: 0 additions & 10 deletions DeepSpeech.py
@@ -131,7 +131,6 @@ def create_flags():
    # Initialization

    tf.app.flags.DEFINE_integer ('random_seed', 4567, 'default random seed that is used to initialize variables')
-    tf.app.flags.DEFINE_float ('default_stddev', 0.046875, 'default standard deviation to use when initialising weights and biases')

    # Early Stopping

@@ -163,9 +162,6 @@ def create_flags():

    tf.app.flags.DEFINE_string ('initialize_from_frozen_model', '', 'path to frozen model to initialize from. This behaves like a checkpoint, loading the weights from the frozen model and starting training with those weights. The optimizer parameters aren\'t restored, so remember to adjust the learning rate.')

-    for var in ['b1', 'h1', 'b2', 'h2', 'b3', 'h3', 'b5', 'h5', 'b6', 'h6']:
-        tf.app.flags.DEFINE_float('%s_stddev' % var, None, 'standard deviation to use when initialising %s' % var)
-
FLAGS = tf.app.flags.FLAGS

def initialize_globals():
@@ -285,12 +281,6 @@ def initialize_globals():
    global n_hidden_6
    n_hidden_6 = n_character

-    # Assign default values for standard deviation
-    for var in ['b1', 'h1', 'b2', 'h2', 'b3', 'h3', 'b5', 'h5', 'b6', 'h6']:
-        val = getattr(FLAGS, '%s_stddev' % var)
-        if val is None:
-            setattr(FLAGS, '%s_stddev' % var, FLAGS.default_stddev)
-
    # Queues that are used to gracefully stop parameter servers.
    # Each queue stands for one ps. A finishing worker sends a token to each queue before joining/quitting.
    # Each ps will dequeue as many tokens as there are workers before joining/quitting.
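
A note on the deleted block: it backfilled each per-variable flag from the shared default by reflecting over FLAGS with getattr/setattr. The idiom is general; a standalone sketch with a toy flags object (not TensorFlow's):

    class Flags(object):
        pass

    FLAGS = Flags()
    FLAGS.default_stddev = 0.046875
    FLAGS.b1_stddev = None   # not set on the command line
    FLAGS.h1_stddev = 0.1    # explicitly set

    # Backfill any unset per-variable flag from the shared default.
    for var in ['b1', 'h1']:
        if getattr(FLAGS, '%s_stddev' % var) is None:
            setattr(FLAGS, '%s_stddev' % var, FLAGS.default_stddev)

    print(FLAGS.b1_stddev, FLAGS.h1_stddev)  # 0.046875 0.1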
2 changes: 1 addition & 1 deletion bin/ops_in_graph.py
@@ -6,6 +6,6 @@

with tf.gfile.FastGFile(sys.argv[1], 'rb') as fin:
    graph_def = tf.GraphDef()
-    graph_def.MergeFromString(fin.read())
+    graph_def.ParseFromString(fin.read())

print('\n'.join(sorted(set(n.op for n in graph_def.node))))
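
ParseFromString() is the right call here: it clears the message before reading, whereas MergeFromString() merges into whatever the message already contains (repeated fields such as GraphDef.node get appended rather than replaced). With a freshly constructed GraphDef the two happen to behave the same, so this is a clarity fix more than a bug fix. A minimal sketch of the difference, assuming two hypothetical serialized graphs a.pb and b.pb:

    import tensorflow as tf

    # MergeFromString accumulates: after both reads, graph_def holds
    # the nodes of a.pb AND b.pb concatenated.
    graph_def = tf.GraphDef()
    for path in ['a.pb', 'b.pb']:
        with tf.gfile.FastGFile(path, 'rb') as fin:
            graph_def.MergeFromString(fin.read())

    # ParseFromString resets first: graph_def ends up holding only b.pb.
    graph_def = tf.GraphDef()
    for path in ['a.pb', 'b.pb']:
        with tf.gfile.FastGFile(path, 'rb') as fin:
            graph_def.ParseFromString(fin.read())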
30 changes: 16 additions & 14 deletions evaluate.py
@@ -84,26 +84,28 @@ def preprocess(dataset_files, batch_size, hdf5_dest_path=None):
    features, features_len, transcript, transcript_len = zip(*out_data)

    with tables.open_file(hdf5_dest_path, 'w') as file:
-        features_dset = file.create_vlarray(file.root, 'features',
-            tables.Float32Atom(shape=()), filters=tables.Filters(complevel=1))
+        features_dset = file.create_vlarray(file.root,
+                                            'features',
+                                            tables.Float32Atom(),
+                                            filters=tables.Filters(complevel=1))
        # VLArray atoms need to be 1D, so flatten feature array
        for f in features:
            features_dset.append(np.reshape(f, -1))

-        features_len_dset = file.create_array(
-            file.root, 'features_len', features_len)
+        features_len_dset = file.create_array(file.root,
+                                              'features_len',
+                                              features_len)

-        transcript_dset = file.create_vlarray(
-            file.root,
-            'transcript',
-            tables.Int32Atom(),
-            filters=tables.Filters(
-                complevel=1))
+        transcript_dset = file.create_vlarray(file.root,
+                                              'transcript',
+                                              tables.Int32Atom(),
+                                              filters=tables.Filters(complevel=1))
        for t in transcript:
            transcript_dset.append(t)

-        transcript_len_dset = file.create_array(
-            file.root, 'transcript_len', transcript_len)
+        transcript_len_dset = file.create_array(file.root,
+                                                'transcript_len',
+                                                transcript_len)

    return pandas.DataFrame(data=out_data, columns=COLUMNS)
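
The flattening comment is the key constraint here: a PyTables VLArray stores one variable-length 1-D row per append(), so each 2-D feature matrix has to be flattened on the way in and reshaped on the way out. A small round-trip sketch (the feature width of 26 is an assumption for illustration):

    import numpy as np
    import tables

    N_FEATURES = 26  # assumed per-frame feature width

    with tables.open_file('features.hdf5', 'w') as f:
        vla = f.create_vlarray(f.root, 'features', tables.Float32Atom(),
                               filters=tables.Filters(complevel=1))
        feat = np.random.rand(100, N_FEATURES).astype(np.float32)  # 100 frames
        vla.append(np.reshape(feat, -1))  # rows must be 1-D

    with tables.open_file('features.hdf5', 'r') as f:
        flat = f.root.features[0]
        restored = np.reshape(flat, (-1, N_FEATURES))  # recover the 2-D shape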

@@ -159,8 +161,8 @@ def calculate_report(labels, decodings, distances, losses):
    # Order the remaining items by their loss (lowest loss on top)
    samples.sort(key=lambda s: s.loss)

-    # Then order by WER (lowest WER on top)
-    samples.sort(key=lambda s: s.wer)
+    # Then order by WER (highest WER on top)
+    samples.sort(key=lambda s: s.wer, reverse=True)

    return samples_wer, samples
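
The two-pass sort works because Python's list.sort() is stable: the second pass by WER preserves the loss ordering from the first pass among samples with equal WER, so after this change the report comes out worst-WER-first with ties broken by lowest loss. A single-pass equivalent, for illustration:

    # Highest WER first; among equal WERs, lowest loss first.
    samples.sort(key=lambda s: (-s.wer, s.loss))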

22 changes: 6 additions & 16 deletions native_client/args.h
@@ -6,20 +6,15 @@

#include "deepspeech.h"

bool has_model = false;
char* model;
char* model = NULL;

bool has_alphabet = false;
char* alphabet;
char* alphabet = NULL;

bool has_lm = false;
char* lm;
char* lm = NULL;

bool has_trie = false;
char* trie;
char* trie = NULL;

bool has_audio = false;
char* audio;
char* audio = NULL;

bool show_times = false;

@@ -70,27 +65,22 @@ bool ProcessArgs(int argc, char** argv)
  {
    case 'm':
      model = optarg;
-      has_model = true;
      break;

    case 'a':
      alphabet = optarg;
-      has_alphabet = true;
      break;

    case 'l':
      lm = optarg;
-      has_lm = true;
      break;

    case 'r':
      trie = optarg;
-      has_trie = true;
      break;

    case 'w':
      audio = optarg;
-      has_audio = true;
      break;

    case 't':
@@ -114,7 +104,7 @@
    return false;
  }

-  if (!has_model || !has_alphabet || !has_audio || strlen(alphabet) == 0 || strlen(audio) == 0) {
+  if (!model || !alphabet || !audio) {
    PrintHelp(argv[0]);
    return false;
  }
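
Initialising the pointers to NULL lets each pointer double as its own "was this option given?" sentinel, so the parallel has_* booleans (which could drift out of sync with the pointers) disappear, and the strlen() == 0 guards are dropped along with them. The same sentinel idiom in Python, with None standing in for the NULL pointer (hypothetical process_args helper, for comparison):

    def process_args(pairs):
        # None means "not provided" -- no separate has_* booleans needed.
        opts = {'model': None, 'alphabet': None, 'audio': None}
        for name, value in pairs:
            opts[name] = value
        if opts['model'] is None or opts['alphabet'] is None or opts['audio'] is None:
            raise SystemExit('usage: client --model M --alphabet A --audio W')
        return opts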
2 changes: 1 addition & 1 deletion native_client/beam_search.h
@@ -119,7 +119,7 @@ class KenLMBeamScorer : public tensorflow::ctc::BaseBeamScorer<KenLMBeamState> {
    // score to this beam's score.
    state->score += lm_weight_ * state->delta_score;
    if (state->num_words > 0) {
-      float normalized_score = state->score /= (float)state->num_words;
+      float normalized_score = state->score / (float)state->num_words;
      state->delta_score = normalized_score - state->score;
    }
  }
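
The one-character change fixes a real bug: /= is compound assignment, so the old line overwrote state->score with the normalized value before the very next line compared against it, forcing delta_score to 0 every time. A Python sketch of the same trap (illustrative numbers only):

    class BeamState:
        def __init__(self, score, num_words):
            self.score = score
            self.num_words = num_words

    # Buggy: mirrors `state->score /= n` -- score is mutated first.
    state = BeamState(score=-12.0, num_words=3)
    state.score /= state.num_words        # score becomes -4.0
    normalized = state.score
    print(normalized - state.score)       # 0.0 -- the delta is lost

    # Fixed: mirrors `state->score / n` -- score is left untouched.
    state = BeamState(score=-12.0, num_words=3)
    normalized = state.score / state.num_words
    print(normalized - state.score)       # 8.0 -- a meaningful delta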
24 changes: 14 additions & 10 deletions native_client/client.cc
@@ -32,7 +32,7 @@ typedef struct {
} ds_result;

ds_result
-LocalDsSTT(ModelState* aCtx, short* aBuffer, size_t aBufferSize,
+LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize,
           int aSampleRate)
{
  ds_result res = {0};
@@ -170,7 +170,7 @@ ProcessFile(ModelState* context, const char* path, bool show_times)
  // We take half of buffer_size because buffer is a char* while
  // LocalDsSTT() expects a short*
  ds_result result = LocalDsSTT(context,
-                               (short*)audio.buffer,
+                               (const short*)audio.buffer,
                                audio.buffer_size / 2,
                                audio.sample_rate);
  free(audio.buffer);
@@ -197,17 +197,21 @@ main(int argc, char **argv)
  ModelState* ctx;
  int status = DS_CreateModel(model, N_CEP, N_CONTEXT, alphabet, BEAM_WIDTH, &ctx);
  if (status != 0) {
+    fprintf(stderr, "Could not create model.\n");
    return 1;
  }

-  if (has_lm && has_trie) {
-    DS_EnableDecoderWithLM(
-      ctx,
-      alphabet,
-      lm,
-      trie,
-      LM_WEIGHT,
-      VALID_WORD_COUNT_WEIGHT);
+  if (lm && trie) {
+    int status = DS_EnableDecoderWithLM(ctx,
+                                        alphabet,
+                                        lm,
+                                        trie,
+                                        LM_WEIGHT,
+                                        VALID_WORD_COUNT_WEIGHT);
+    if (status != 0) {
+      fprintf(stderr, "Could not enable CTC decoder with LM.\n");
+      return 1;
+    }
  }

  // Initialise SOX
