
Commit

[TUTORIAL] Use FixedBucketSampler in BERT tutorial for better performance (#506)

* Use FixedBucketSampler for better performance

* Update gradient clipping code
Ishitori authored and eric-haibin-lin committed Jan 8, 2019
1 parent 5577a37 commit f24ecfd
Showing 1 changed file with 8 additions and 7 deletions.

docs/examples/sentence_embedding/bert.md
````diff
@@ -179,8 +179,10 @@ skip validation steps.
 ```{.python .input}
 batch_size = 32
 lr = 5e-6
-bert_dataloader = mx.gluon.data.DataLoader(data_train, batch_size=batch_size,
-                                           shuffle=True, last_batch='rollover')
+train_sampler = nlp.data.FixedBucketSampler(lengths=[int(item[1]) for item in data_train],
+                                            batch_size=batch_size,
+                                            shuffle=True)
+bert_dataloader = mx.gluon.data.DataLoader(data_train, batch_sampler=train_sampler)
 trainer = gluon.Trainer(model.collect_params(), 'adam',
                         {'learning_rate': lr, 'epsilon': 1e-9})
````
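The gain comes from bucketing: `FixedBucketSampler` groups examples of similar length into the same mini-batch, so each batch is padded only to the longest sequence in its bucket rather than to the longest sequence in the whole shuffled dataset. A minimal sketch of the behavior, using a hypothetical toy `lengths` list (the tutorial itself derives lengths from `data_train` as in the diff above):

```python
import gluonnlp as nlp

# Hypothetical sequence lengths with two natural clusters: short and long.
lengths = [5, 42, 7, 40, 6, 38, 8, 41]

# Two length buckets, two samples per batch: each batch is drawn from a
# single bucket, so short sequences are never padded up to length ~42.
sampler = nlp.data.FixedBucketSampler(lengths=lengths,
                                      batch_size=2,
                                      num_buckets=2,
                                      shuffle=True)
print(sampler.stats())  # print a summary of the generated buckets
for batch_indices in sampler:
    # Each yielded batch groups indices of similarly sized samples.
    print(batch_indices, [lengths[i] for i in batch_indices])
```

Passing the sampler to `DataLoader` via `batch_sampler` (as in the diff) hands batching over to the sampler entirely, which is why `batch_size`, `shuffle`, and `last_batch` no longer appear in the `DataLoader` call.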
```diff
@@ -213,11 +215,10 @@ for epoch_id in range(num_epochs):
         ls.backward()
         # gradient clipping
-        grads = [p.grad(c) for p in params for c in [ctx]]
-        gluon.utils.clip_global_norm(grads, grad_clip)
-        # parameter update
-        trainer.step(1)
+        trainer.allreduce_grads()
+        nlp.utils.clip_grad_global_norm(params, 1)
+        trainer.update(1)
         step_loss += ls.asscalar()
         metric.update([label], [out])
         if (batch_id + 1) % (log_interval) == 0:
```
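The clipping change splits `trainer.step(1)` into its two phases, `allreduce_grads()` and `update(1)`, so that the global-norm clip runs after gradients are aggregated but before the optimizer consumes them; `nlp.utils.clip_grad_global_norm` also rescales gradients in place, removing the need for the manual `grads` list. A self-contained sketch of the pattern on a hypothetical toy setup (the `Dense` net, random data, and `L2Loss` are stand-ins for the tutorial's BERT model and classification loss):

```python
import mxnet as mx
from mxnet import autograd, gluon
import gluonnlp as nlp

ctx = mx.cpu()
net = gluon.nn.Dense(1)
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 5e-6})
# Same parameter filtering as the tutorial: skip params with no gradient.
params = [p for p in net.collect_params().values() if p.grad_req != 'null']

x = mx.nd.random.uniform(shape=(4, 3), ctx=ctx)
y = mx.nd.random.uniform(shape=(4, 1), ctx=ctx)
loss_fn = gluon.loss.L2Loss()

with autograd.record():
    ls = loss_fn(net(x), y).mean()
ls.backward()

trainer.allreduce_grads()                    # aggregate grads across devices
nlp.utils.clip_grad_global_norm(params, 1)   # clip global grad norm to 1
trainer.update(1)                            # apply the clipped gradients
```

Note that `allreduce_grads()`/`update()` require the trainer not to update on the KVStore; with a single context, as here and in the tutorial, no KVStore is created, so this works by default.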
