
Commit

move eval logging into function to unify validation and test, add final checkpoint

Fix #6
8enmann committed May 2, 2019
1 parent e19426f commit adf2f62
Showing 1 changed file with 30 additions and 40 deletions.
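In short, the commit replaces the separate validation and test logging blocks with one evaluate(eval_iter, split, train_step) helper that logs timing, loss, bpc/ppl, and best-checkpoint saving. A minimal sketch of the resulting call sites, using only names that appear in the diff below:

    evaluate(va_iter, 'val', train_step)     # periodic validation inside train()
    evaluate(va_iter, 'val', train_step=-1)  # one more validation pass after the training loop
    evaluate(te_iter, 'test', -1)            # final test evaluation in main()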
70 changes: 30 additions & 40 deletions train.py
@@ -419,12 +419,14 @@ def weights_init(m):
###############################################################################


-def evaluate(eval_iter):
+def evaluate(eval_iter, split, train_step=-1):
+    global best_val_loss
+    eval_start_time = time.time()
    # Turn on evaluation mode which disables dropout.
    model.eval()

    # Have to unwrap twice: DDP & FP16
-    model_to_reset = model.module.module if args.fp16 else args.module
+    model_to_reset = model.module.module if args.fp16 else model.module
    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.
    if args.mem_len == 0:
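The corrected model_to_reset line above has to unwrap the model twice. A minimal sketch of the assumed wrapping order (the wrapper class names below are illustrative, not taken from this diff):

    # Assumed setup: the core network is wrapped for fp16 and then for distributed
    # training, and each wrapper exposes the wrapped object as .module.
    #   core  = build_model(...)                 # underlying transformer (hypothetical helper)
    #   if args.fp16:
    #       core = FP16_Module(core)             # fp16 wrapper -> one .module level (assumed name)
    #   model = DistributedDataParallel(core)    # DDP wrapper -> another .module level
    # Hence the line in the diff:
    model_to_reset = model.module.module if args.fp16 else model.module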
@@ -454,7 +456,26 @@ def evaluate(eval_iter):
    model_to_reset.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    model.train()

-    return total_loss / total_len
+    mean_loss = total_loss / total_len
+    logger.info('-' * 100)
+    log_str = (f'| Eval {train_step // args.eval_interval:3d} at step {train_step:>8d} | ' +
+               f'time: {time.time() - eval_start_time:5.2f}s ' +
+               f'| {split} loss {mean_loss:5.2f}')
+    if args.dataset in ['enwik8', 'text8']:
+        log_str += f' | bpc {mean_loss / math.log(2):9.5f}'
+    else:
+        log_str += f' | {split} ppl {math.exp(mean_loss):9.3f}'
+    logger.info(log_str)
+    logger.info('-' * 100)
+    log_tb(f'loss/{split}_loss', mean_loss)
+    log_tb(f'loss/{split}_ppl', math.exp(mean_loss))
+
+    if split == 'val' and (not best_val_loss or mean_loss < best_val_loss):
+        if not args.debug:
+            logger.info('Saving checkpoint for new best loss')
+            util.dist_save_checkpoint(model, optimizer, args.logdir, suffix='best')
+
+        best_val_loss = mean_loss
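For reference, the bpc and ppl values logged above are simple transforms of the mean loss; a small sketch, assuming mean_loss is the average cross-entropy in nats per token:

    import math

    mean_loss = 0.92                # example value, nats per token/character
    bpc = mean_loss / math.log(2)   # bits per character, reported for enwik8/text8
    ppl = math.exp(mean_loss)       # perplexity, reported for the other datasets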


def train():
@@ -467,9 +488,8 @@ def train():
    log_tb('sizes/seq_size', args.tgt_len)

    mems = tuple()
-    train_iter = tr_iter
    log_start_time = time.time()
-    for batch, (data, target, seq_len) in enumerate(train_iter):
+    for batch, (data, target, seq_len) in enumerate(tr_iter):
        # TODO(y): batch is dimension 1, why?

        assert seq_len == data.shape[0]
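Regarding the TODO above about the batch dimension: the upstream transformer-xl data iterators yield time-major batches, so a sketch of the assumed shapes (an assumption about this fork, consistent with the assert above):

    # Assumed layout (time-major, upstream transformer-xl convention):
    #   data    -> LongTensor of shape [seq_len, batch_size], the input tokens
    #   target  -> LongTensor of shape [seq_len, batch_size], inputs shifted by one step
    #   seq_len -> int equal to data.shape[0], which the assert above checks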
@@ -566,29 +586,7 @@ def train():
            last_log_step = train_step

        if train_step % args.eval_interval == 0:
-            eval_start_time = time.time()
-            val_loss = evaluate(va_iter)
-            if not best_val_loss or val_loss < best_val_loss:
-                if not args.debug:
-                    logger.info('Saving checkpoint for new best loss')
-                    util.dist_save_checkpoint(model, optimizer, args.logdir, suffix='best')
-
-                best_val_loss = val_loss
-
-            logger.info('-' * 100)
-            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
-                      '| valid loss {:5.2f}'.format(train_step // args.eval_interval, train_step,
-                                                    (time.time() - eval_start_time),
-                                                    val_loss)
-            if args.dataset in ['enwik8', 'text8']:
-                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
-            else:
-                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
-            logger.info(log_str)
-            logger.info('-' * 100)
-            log_tb('loss/val_loss', val_loss)
-            log_tb('loss/val_ppl', math.exp(val_loss))
-
+            evaluate(va_iter, 'val', train_step)

        # TODO: instead of stopping training, transition to constant small LR forever
        if global_token_count >= args.max_tokens:
@@ -717,6 +715,9 @@ def main():
    except StopIteration:
        pass

+    # Eval one more time.
+    evaluate(va_iter, 'val', train_step=-1)
+
    # Load the best saved model.
    logger.info("Loading best checkpoint")
    model_file = os.path.join(args.work_dir, 'model-best.pt')
@@ -732,18 +733,7 @@ def main():
        logger.warn('no model file, using current model for loss')

    # Run on test data.
-    test_loss = evaluate(te_iter)
-    logger.info('=' * 100)
-    if args.dataset in ['enwik8', 'text8']:
-        logger.info('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
-            test_loss, test_loss / math.log(2)))
-    else:
-        logger.info('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
-            test_loss, math.exp(test_loss)))
-    log_tb('loss/test_loss', test_loss)
-    log_tb('loss/test_ppl', math.exp(test_loss))
-
-    logger.info('=' * 100)
+    evaluate(te_iter, 'test', -1)


if __name__ == '__main__':
