Skip to content

Commit

Permalink
[scripts] Documentation fixes. Thanks: Rongjin Li.
Browse files Browse the repository at this point in the history
  • Loading branch information
danpovey committed Nov 11, 2017
1 parent 05b2aed commit b952cf3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
20 changes: 10 additions & 10 deletions egs/wsj/s5/steps/libs/nnet3/report/log_parse.py
Expand Up @@ -358,16 +358,16 @@ def parse_prob_logs(exp_dir, key='accuracy', output="output"):
"nnet.*diagnostics.cc:[0-9]+. Overall ([a-zA-Z\-]+) for "
"'{output}'.*is ([0-9.\-e]+) .*per frame".format(output=output))

train_loss = {}
valid_loss = {}
train_objf = {}
valid_objf = {}

for line in train_prob_strings.split('\n'):
mat_obj = parse_regex.search(line)
if mat_obj is not None:
groups = mat_obj.groups()
if groups[1] == key:
train_loss[int(groups[0])] = groups[2]
if not train_loss:
train_objf[int(groups[0])] = groups[2]
if not train_objf:
raise KaldiLogParseException("Could not find any lines with {k} in "
" {l}".format(k=key, l=train_prob_files))

Expand All @@ -376,20 +376,20 @@ def parse_prob_logs(exp_dir, key='accuracy', output="output"):
if mat_obj is not None:
groups = mat_obj.groups()
if groups[1] == key:
valid_loss[int(groups[0])] = groups[2]
valid_objf[int(groups[0])] = groups[2]

if not valid_loss:
if not valid_objf:
raise KaldiLogParseException("Could not find any lines with {k} in "
" {l}".format(k=key, l=valid_prob_files))

iters = list(set(valid_loss.keys()).intersection(train_loss.keys()))
iters = list(set(valid_objf.keys()).intersection(train_objf.keys()))
if not iters:
raise KaldiLogParseException("Could not any common iterations with"
" key {k} in both {tl} and {vl}".format(
k=key, tl=train_prob_files, vl=valid_prob_files))
iters.sort()
return map(lambda x: (int(x), float(train_loss[x]),
float(valid_loss[x])), iters)
return map(lambda x: (int(x), float(train_objf[x]),
float(valid_objf[x])), iters)



Expand All @@ -402,7 +402,7 @@ def generate_acc_logprob_report(exp_dir, key="accuracy", output="output"):
times = []

report = []
report.append("%Iter\tduration\ttrain_loss\tvalid_loss\tdifference")
report.append("%Iter\tduration\ttrain_objective\tvalid_objective\tdifference")
try:
data = list(parse_prob_logs(exp_dir, key, output))
except:
Expand Down
26 changes: 17 additions & 9 deletions egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py
Expand Up @@ -44,16 +44,24 @@ def train_new_models(dir, iter, srand, num_jobs,
but we use the same script for consistency with FF-DNN code
Selected args:
frames_per_eg: The default value -1 implies chunk_level_training, which
is particularly applicable to RNN training. If it is > 0, then it
implies frame-level training, which is applicable for DNN training.
If it is > 0, then each parallel SGE job created, a different frame
numbered 0..frames_per_eg-1 is used.
frames_per_eg:
The frames_per_eg, in the context of (non-chain) nnet3 training,
is normally the number of output (supervised) frames in each training
example. However, the frames_per_eg argument to this function should
only be set to that number (greater than zero) if you intend to
train on a single frame of each example, on each minibatch. If you
provide this argument >0, then for each training job a different
frame from the dumped example is selected to train on, based on
the option --frame=n to nnet3-copy-egs.
If you leave frames_per_eg at its default value (-1), then the
entire sequence of frames is used for supervision. This is suitable
for RNN training, where it helps to amortize the cost of computing
the activations for the frames of context needed for the recurrence.
use_multitask_egs : True, if different examples used to train multiple
tasks or outputs, e.g.multilingual training.
multilingual egs can be generated using get_egs.sh and
steps/nnet3/multilingual/allocate_multilingual_examples.py,
those are the top-level scripts.
tasks or outputs, e.g.multilingual training. multilingual egs can
be generated using get_egs.sh and
steps/nnet3/multilingual/allocate_multilingual_examples.py, those
are the top-level scripts.
"""

chunk_level_training = False if frames_per_eg > 0 else True
Expand Down

0 comments on commit b952cf3

Please sign in to comment.