Skip to content

Commit

Permalink
[scripts] Make nnet3 logging output look more like chain logging outp…
Browse files Browse the repository at this point in the history
…ut (#2338)
  • Loading branch information
Ore-an authored and danpovey committed Apr 6, 2018
1 parent 1a1e265 commit 60862b0
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 10 deletions.
10 changes: 0 additions & 10 deletions egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ def train_one_iteration(dir, iter, srand, egs_dir,

# Set off jobs doing some diagnostics, in the background.
# Use the egs dir from the previous iteration for the diagnostics
logger.info("Training neural net (pass {0})".format(iter))

# check if different iterations use the same random seed
if os.path.exists('{0}/srand'.format(dir)):
Expand Down Expand Up @@ -257,15 +256,6 @@ def train_one_iteration(dir, iter, srand, egs_dir,
cur_minibatch_size_str = common_train_lib.halve_minibatch_size_str(minibatch_size_str)
cur_max_param_change = float(max_param_change) / math.sqrt(2)

shrink_info_str = ''
if shrinkage_value != 1.0:
shrink_info_str = ' and shrink value is {0}'.format(shrinkage_value)

logger.info("On iteration {0}, learning rate is {1}"
"{shrink_info}.".format(
iter, learning_rate,
shrink_info=shrink_info_str))

train_new_models(dir=dir, iter=iter, srand=srand, num_jobs=num_jobs,
num_archives_processed=num_archives_processed,
num_archives=num_archives,
Expand Down
13 changes: 13 additions & 0 deletions egs/wsj/s5/steps/nnet3/train_dnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,19 @@ def train(args, run_opts):
"shrink-value={1}".format(args.proportional_shrink,
shrinkage_value))

percent = num_archives_processed * 100.0 / num_archives_to_process
epoch = (num_archives_processed * args.num_epochs
/ num_archives_to_process)
shrink_info_str = ''
if shrinkage_value != 1.0:
shrink_info_str = 'shrink: {0:0.5f}'.format(shrinkage_value)
logger.info("Iter: {0}/{1} "
"Epoch: {2:0.2f}/{3:0.1f} ({4:0.1f}% complete) "
"lr: {5:0.6f} {6}".format(iter, num_iters - 1,
epoch, args.num_epochs,
percent,
lrate, shrink_info_str))

train_lib.common.train_one_iteration(
dir=args.dir,
iter=iter,
Expand Down
13 changes: 13 additions & 0 deletions egs/wsj/s5/steps/nnet3/train_raw_dnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,19 @@ def train(args, run_opts):
"shrink-value={1}".format(args.proportional_shrink,
shrinkage_value))

percent = num_archives_processed * 100.0 / num_archives_to_process
epoch = (num_archives_processed * args.num_epochs
/ num_archives_to_process)
shrink_info_str = ''
if shrinkage_value != 1.0:
shrink_info_str = 'shrink: {0:0.5f}'.format(shrinkage_value)
logger.info("Iter: {0}/{1} "
"Epoch: {2:0.2f}/{3:0.1f} ({4:0.1f}% complete) "
"lr: {5:0.6f} {6}".format(iter, num_iters - 1,
epoch, args.num_epochs,
percent,
lrate, shrink_info_str))

train_lib.common.train_one_iteration(
dir=args.dir,
iter=iter,
Expand Down
13 changes: 13 additions & 0 deletions egs/wsj/s5/steps/nnet3/train_raw_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,19 @@ def train(args, run_opts):
get_raw_nnet_from_am=False)
else shrinkage_value)

percent = num_archives_processed * 100.0 / num_archives_to_process
epoch = (num_archives_processed * args.num_epochs
/ num_archives_to_process)
shrink_info_str = ''
if shrinkage_value != 1.0:
shrink_info_str = 'shrink: {0:0.5f}'.format(shrinkage_value)
logger.info("Iter: {0}/{1} "
"Epoch: {2:0.2f}/{3:0.1f} ({4:0.1f}% complete) "
"lr: {5:0.6f} {6}".format(iter, num_iters - 1,
epoch, args.num_epochs,
percent,
lrate, shrink_info_str))

train_lib.common.train_one_iteration(
dir=args.dir,
iter=iter,
Expand Down
13 changes: 13 additions & 0 deletions egs/wsj/s5/steps/nnet3/train_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,19 @@ def train(args, run_opts):
iter, model_file,
args.shrink_saturation_threshold) else 1.0)

percent = num_archives_processed * 100.0 / num_archives_to_process
epoch = (num_archives_processed * args.num_epochs
/ num_archives_to_process)
shrink_info_str = ''
if shrinkage_value != 1.0:
shrink_info_str = 'shrink: {0:0.5f}'.format(shrinkage_value)
logger.info("Iter: {0}/{1} "
"Epoch: {2:0.2f}/{3:0.1f} ({4:0.1f}% complete) "
"lr: {5:0.6f} {6}".format(iter, num_iters - 1,
epoch, args.num_epochs,
percent,
lrate, shrink_info_str))

train_lib.common.train_one_iteration(
dir=args.dir,
iter=iter,
Expand Down

0 comments on commit 60862b0

Please sign in to comment.