Commit 66a8865
Merge pull request #136 from hirofumi0810/transformer
Fix bug in mixed precision training
hirofumi0810 committed Sep 10, 2020
2 parents ff412a2 + e1e768e commit 66a8865
Showing 2 changed files with 59 additions and 57 deletions.
61 changes: 31 additions & 30 deletions neural_sp/bin/asr/train.py
@@ -170,17 +170,17 @@ def main():
logger.info('Overwrite %s' % n)

# Set optimizer
resume_epoch = 0
if args.resume:
resume_epoch = int(args.resume.split('-')[-1])
optimizer = set_optimizer(model, 'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
args.lr, args.weight_decay)
else:
resume_epoch = 0
optimizer = set_optimizer(model, args.optimizer, args.lr, args.weight_decay)

# Wrap optimizer by learning rate scheduler
is_transformer = 'former' in args.enc_type or 'former' in args.dec_type
optimizer = LRScheduler(optimizer, args.lr,
scheduler = LRScheduler(optimizer, args.lr,
decay_type=args.lr_decay_type,
decay_start_epoch=args.lr_decay_start_epoch,
decay_rate=args.lr_decay_rate,
@@ -196,11 +196,11 @@ def main():

if args.resume:
# Restore the last saved model
load_checkpoint(args.resume, model, optimizer)
load_checkpoint(args.resume, model, scheduler)

# Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
if resume_epoch == args.convert_to_sgd_epoch:
optimizer.convert_to_sgd(model, args.lr, args.weight_decay,
scheduler.convert_to_sgd(model, args.lr, args.weight_decay,
decay_type='always', decay_rate=0.5)

# Load the teacher ASR model
@@ -235,10 +235,10 @@ def main():
benchmark=not is_transformer and args.cudnn_benchmark)
model.cuda()

# Mix precision training setting
# Mixed precision training setting
if use_apex:
from apex import amp
model, optimizer.optimizer = amp.initialize(model, optimizer.optimizer,
model, scheduler.optimizer = amp.initialize(model, scheduler.optimizer,
opt_level=args.train_dtype)
from neural_sp.models.seq2seq.decoders.ctc import CTC
amp.register_float_function(CTC, "loss_fn")
@@ -288,7 +288,7 @@ def main():
start_time_epoch = time.time()
start_time_step = time.time()
accum_n_steps = 0
n_steps = optimizer.n_steps * accum_grad_n_steps
n_steps = scheduler.n_steps * accum_grad_n_steps
epoch_detail_prev = 0
for ep in range(resume_epoch, args.n_epochs):
pbar_epoch = tqdm(total=len(train_set))
@@ -303,31 +303,32 @@ def main():

# Change mini-batch depending on task
if accum_n_steps == 1:
loss_train = 0 # moving average over gradient accumulation
loss_train = 0 # average over gradient accumulation
for task in tasks:
loss, observation = model(batch_train, task=task,
teacher=teacher, teacher_lm=teacher_lm)
loss = loss / accum_n_steps
reporter.add(observation)
if use_apex:
with amp.scale_loss(loss, optimizer.optimizer) as scaled_loss:
with amp.scale_loss(loss, scheduler.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
loss.detach() # Trancate the graph
loss_train = (loss_train * (accum_n_steps - 1) + loss.item()) / accum_n_steps
if accum_n_steps >= accum_grad_n_steps or is_new_epoch:
if args.clip_grad_norm > 0:
total_norm = torch.nn.utils.clip_grad_norm_(
model.module.parameters(), args.clip_grad_norm)
reporter.add_tensorboard_scalar('total_norm', total_norm)
optimizer.step()
optimizer.zero_grad()
scheduler.step()
scheduler.zero_grad()
accum_n_steps = 0
# NOTE: parameters are forcibly updated at the end of every epoch
loss_train += loss.item()
del loss

pbar_epoch.update(len(batch_train['utt_ids']))
reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
reporter.add_tensorboard_scalar('learning_rate', scheduler.lr)
# NOTE: loss/acc/ppl are already added in the model
reporter.step()
n_steps += 1
@@ -352,9 +353,9 @@ def main():
xlen = max(len(x) for x in batch_train['ys'])
ylen = max(len(y) for y in batch_train['ys_sub1'])
logger.info("step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.7f/bs:%d/xlen:%d/ylen:%d (%.2f min)" %
(n_steps, optimizer.n_epochs + train_set.epoch_detail,
(n_steps, scheduler.n_epochs + train_set.epoch_detail,
loss_train, loss_dev,
optimizer.lr, len(batch_train['utt_ids']),
scheduler.lr, len(batch_train['utt_ids']),
xlen, ylen, duration_step / 60))
start_time_step = time.time()

@@ -371,7 +372,7 @@ def main():
evaluate([model.module], dev_set, recog_params, args,
int(train_set.epoch_detail * 10) / 10, logger)
# Save the model
optimizer.save_checkpoint(
scheduler.save_checkpoint(
model, save_path, remove_old=False, amp=amp,
epoch_detail=train_set.epoch_detail)
epoch_detail_prev = train_set.epoch_detail
@@ -382,47 +383,47 @@ def main():
# Save checkpoint and evaluate model per epoch
duration_epoch = time.time() - start_time_epoch
logger.info('========== EPOCH:%d (%.2f min) ==========' %
(optimizer.n_epochs + 1, duration_epoch / 60))
(scheduler.n_epochs + 1, duration_epoch / 60))

if optimizer.n_epochs + 1 < args.eval_start_epoch:
optimizer.epoch() # lr decay
if scheduler.n_epochs + 1 < args.eval_start_epoch:
scheduler.epoch() # lr decay
reporter.epoch() # plot

# Save the model
optimizer.save_checkpoint(
scheduler.save_checkpoint(
model, save_path, remove_old=not is_transformer and args.remove_old_checkpoints, amp=amp)
else:
start_time_eval = time.time()
# dev
metric_dev = evaluate([model.module], dev_set, recog_params, args,
optimizer.n_epochs + 1, logger)
optimizer.epoch(metric_dev) # lr decay
scheduler.n_epochs + 1, logger)
scheduler.epoch(metric_dev) # lr decay
reporter.epoch(metric_dev, name=args.metric) # plot

if optimizer.is_topk or is_transformer:
if scheduler.is_topk or is_transformer:
# Save the model
optimizer.save_checkpoint(
scheduler.save_checkpoint(
model, save_path, remove_old=not is_transformer and args.remove_old_checkpoints, amp=amp)

# test
if optimizer.is_topk:
if scheduler.is_topk:
for eval_set in eval_sets:
evaluate([model.module], eval_set, recog_params, args,
optimizer.n_epochs, logger)
scheduler.n_epochs, logger)

duration_eval = time.time() - start_time_eval
logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

# Early stopping
if optimizer.is_early_stop:
if scheduler.is_early_stop:
break

# Convert to fine-tuning stage
if optimizer.n_epochs == args.convert_to_sgd_epoch:
optimizer.convert_to_sgd(model, args.lr, args.weight_decay,
if scheduler.n_epochs == args.convert_to_sgd_epoch:
scheduler.convert_to_sgd(model, args.lr, args.weight_decay,
decay_type='always', decay_rate=0.5)

if optimizer.n_epochs >= args.n_epochs:
if scheduler.n_epochs >= args.n_epochs:
break
# if args.ss_prob > 0:
# model.module.scheduled_sampling_trigger()
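For readers less familiar with apex, the hunks above revolve around one pattern: the raw torch optimizer lives inside the LRScheduler wrapper (now consistently named scheduler), it is that inner scheduler.optimizer that apex initializes and scales, and the loss is scaled down inside the gradient-accumulation loop before backward(). Below is a minimal, self-contained sketch of that conventional pattern, with assumed placeholder names (train_steps, loader, clip_grad_norm); it is an illustration, not the project's actual training loop.

# Minimal sketch (not neural_sp's code): apex mixed precision with gradient accumulation.
import torch
from apex import amp

def train_steps(model, optimizer, loader, accum_grad_n_steps=4,
                clip_grad_norm=5.0, opt_level='O1'):
    # apex wraps the raw torch optimizer; a scheduler wrapper would expose it
    # as scheduler.optimizer, which is what the diff above passes in.
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
    optimizer.zero_grad()
    for step, batch in enumerate(loader, start=1):
        loss = model(**batch)
        loss = loss / accum_grad_n_steps        # normalize over the accumulation window
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()              # backward on the loss-scaled value
        if step % accum_grad_n_steps == 0:
            # clip the unscaled master gradients maintained by amp
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad_norm)
            optimizer.step()
            optimizer.zero_grad()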
55 changes: 28 additions & 27 deletions neural_sp/bin/lm/train.py
@@ -135,17 +135,17 @@ def main():
logger.info(model)

# Set optimizer
resume_epoch = 0
if args.resume:
resume_epoch = int(args.resume.split('-')[-1])
optimizer = set_optimizer(model, 'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
args.lr, args.weight_decay)
else:
resume_epoch = 0
optimizer = set_optimizer(model, args.optimizer, args.lr, args.weight_decay)

# Wrap optimizer by learning rate scheduler
is_transformer = args.lm_type in ['transformer', 'transformer_xl']
optimizer = LRScheduler(optimizer, args.lr,
scheduler = LRScheduler(optimizer, args.lr,
decay_type=args.lr_decay_type,
decay_start_epoch=args.lr_decay_start_epoch,
decay_rate=args.lr_decay_rate,
@@ -160,11 +160,11 @@ def main():

if args.resume:
# Restore the last saved model
load_checkpoint(args.resume, model, optimizer)
load_checkpoint(args.resume, model, scheduler)

# Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
if resume_epoch == args.convert_to_sgd_epoch:
optimizer.convert_to_sgd(model, args.lr, args.weight_decay,
scheduler.convert_to_sgd(model, args.lr, args.weight_decay,
decay_type='always', decay_rate=0.5)

# GPU setting
@@ -178,7 +178,7 @@ def main():
# Mix precision training setting
if use_apex:
from apex import amp
model, optimizer.optimizer = amp.initialize(model, optimizer.optimizer,
model, scheduler.optimizer = amp.initialize(model, scheduler.optimizer,
opt_level=args.train_dtype)
amp.init()
if args.resume:
@@ -201,7 +201,7 @@ def main():
start_time_epoch = time.time()
start_time_step = time.time()
accum_n_steps = 0
n_steps = optimizer.n_steps * accum_grad_n_steps
n_steps = scheduler.n_steps * accum_grad_n_steps
for ep in range(resume_epoch, args.n_epochs):
pbar_epoch = tqdm(total=len(train_set))

@@ -212,28 +212,29 @@ def main():
if accum_n_steps == 1:
loss_train = 0 # moving average over gradient accumulation
loss, hidden, observation = model(ys_train, state=hidden)
loss = loss / accum_n_steps
reporter.add(observation)
if use_apex:
with amp.scale_loss(loss, optimizer.optimizer) as scaled_loss:
with amp.scale_loss(loss, scheduler.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
loss.detach() # Trancate the graph
loss_train = (loss_train * (accum_n_steps - 1) + loss.item()) / accum_n_steps
if accum_n_steps >= accum_grad_n_steps or is_new_epoch:
if args.clip_grad_norm > 0:
total_norm = torch.nn.utils.clip_grad_norm_(
model.module.parameters(), args.clip_grad_norm)
reporter.add_tensorboard_scalar('total_norm', total_norm)
optimizer.step()
optimizer.zero_grad()
scheduler.step()
scheduler.zero_grad()
accum_n_steps = 0
# NOTE: parameters are forcibly updated at the end of every epoch
loss_train += loss.item()
del loss
hidden = model.module.repackage_state(hidden)

pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))
reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
reporter.add_tensorboard_scalar('learning_rate', scheduler.lr)
# NOTE: loss/acc/ppl are already added in the model
reporter.step()
n_steps += 1
@@ -250,9 +251,9 @@ def main():

duration_step = time.time() - start_time_step
logger.info("step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)" %
(n_steps, optimizer.n_epochs + train_set.epoch_detail,
(n_steps, scheduler.n_epochs + train_set.epoch_detail,
loss_train, loss_dev,
optimizer.lr, ys_train.shape[0], duration_step / 60))
scheduler.lr, ys_train.shape[0], duration_step / 60))
start_time_step = time.time()

# Save fugures of loss and accuracy
@@ -266,14 +267,14 @@ def main():
# Save checkpoint and evaluate model per epoch
duration_epoch = time.time() - start_time_epoch
logger.info('========== EPOCH:%d (%.2f min) ==========' %
(optimizer.n_epochs + 1, duration_epoch / 60))
(scheduler.n_epochs + 1, duration_epoch / 60))

if optimizer.n_epochs + 1 < args.eval_start_epoch:
optimizer.epoch() # lr decay
if scheduler.n_epochs + 1 < args.eval_start_epoch:
scheduler.epoch() # lr decay
reporter.epoch() # plot

# Save the model
optimizer.save_checkpoint(
scheduler.save_checkpoint(
model, save_path, remove_old=not is_transformer, amp=amp)
else:
start_time_eval = time.time()
@@ -282,14 +283,14 @@ def main():
ppl_dev, _ = eval_ppl([model.module], dev_set,
batch_size=1, bptt=args.bptt)
model.module.reset_length(args.bptt)
optimizer.epoch(ppl_dev) # lr decay
scheduler.epoch(ppl_dev) # lr decay
reporter.epoch(ppl_dev, name='perplexity') # plot
logger.info('PPL (%s, ep:%d): %.2f' %
(dev_set.set, optimizer.n_epochs, ppl_dev))
(dev_set.set, scheduler.n_epochs, ppl_dev))

if optimizer.is_topk or is_transformer:
if scheduler.is_topk or is_transformer:
# Save the model
optimizer.save_checkpoint(
scheduler.save_checkpoint(
model, save_path, remove_old=not is_transformer, amp=amp)

# test
@@ -300,25 +301,25 @@ def main():
batch_size=1, bptt=args.bptt)
model.module.reset_length(args.bptt)
logger.info('PPL (%s, ep:%d): %.2f' %
(eval_set.set, optimizer.n_epochs, ppl_test))
(eval_set.set, scheduler.n_epochs, ppl_test))
ppl_test_avg += ppl_test
if len(eval_sets) > 0:
logger.info('PPL (avg., ep:%d): %.2f' %
(optimizer.n_epochs, ppl_test_avg / len(eval_sets)))
(scheduler.n_epochs, ppl_test_avg / len(eval_sets)))

duration_eval = time.time() - start_time_eval
logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

# Early stopping
if optimizer.is_early_stop:
if scheduler.is_early_stop:
break

# Convert to fine-tuning stage
if optimizer.n_epochs == args.convert_to_sgd_epoch:
optimizer.convert_to_sgd(model, args.lr, args.weight_decay,
if scheduler.n_epochs == args.convert_to_sgd_epoch:
scheduler.convert_to_sgd(model, args.lr, args.weight_decay,
decay_type='always', decay_rate=0.5)

if optimizer.n_epochs >= args.n_epochs:
if scheduler.n_epochs >= args.n_epochs:
break

start_time_step = time.time()
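The consistent rename from optimizer to scheduler in both files reflects the fact that neural_sp's LRScheduler is a wrapper that owns the torch optimizer rather than being one. A deliberately simplified, hypothetical wrapper in that spirit (not the project's actual class) shows why the code passes scheduler.optimizer to apex but calls scheduler.step(), zero_grad() and epoch() everywhere else.

# Hypothetical, stripped-down wrapper in the spirit of neural_sp's LRScheduler.
class SimpleLRScheduler(object):
    def __init__(self, optimizer, base_lr, decay_rate=0.9, decay_start_epoch=10):
        self.optimizer = optimizer    # the raw torch optimizer, exposed for apex
        self.lr = base_lr
        self.decay_rate = decay_rate
        self.decay_start_epoch = decay_start_epoch
        self.n_steps = 0
        self.n_epochs = 0

    def step(self):
        self.optimizer.step()         # delegate the parameter update
        self.n_steps += 1

    def zero_grad(self):
        self.optimizer.zero_grad()

    def epoch(self, metric=None):
        self.n_epochs += 1
        if self.n_epochs >= self.decay_start_epoch:
            self.lr *= self.decay_rate
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr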

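Both training scripts also call scheduler.convert_to_sgd(...) once n_epochs reaches args.convert_to_sgd_epoch, including when resuming at that boundary. A rough, hypothetical sketch of what such a conversion can look like (it ignores apex re-initialization and other details of the real method, and the attribute names are assumptions):

import torch

def convert_to_sgd(scheduler, model, lr, weight_decay, decay_type='always', decay_rate=0.5):
    # Hypothetical sketch: swap the wrapped optimizer for plain SGD for the
    # fine-tuning stage and switch the scheduler to unconditional decay.
    scheduler.optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                          weight_decay=weight_decay)
    scheduler.lr = lr
    scheduler.decay_type = decay_type
    scheduler.decay_rate = decay_rate
    return scheduler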