
Commit

Merge pull request #714 from mv1388/amp_scaler_update_grad_scaling
AMP scaler update when grad accumulating
mv1388 committed Jul 31, 2022
2 parents a48333e + 634610c commit a875979
Showing 3 changed files with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion aitoolbox/torchtrain/train_loop/train_loop.py
@@ -246,7 +246,10 @@ def _train(self, num_epochs, num_iterations, callbacks=None, grad_accumulation=1
                 # Optimizer zero grad
                 self._optimizer_zero_grad(optimizer_idx)
 
-                self.amp_scaler.update()
+                # Execute AMP scaler update only when optimizer is stepped and grads are zeroed out
+                # https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation
+                if (self.iteration + 1) % self.grad_accumulation == 0:
+                    self.amp_scaler.update()
 
                 self.callbacks_handler.execute_batch_end()

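For context, the PyTorch AMP docs linked in the diff recommend calling scaler.update() only at the end of iterations where scaler.step(optimizer) actually ran, so the loss scale stays constant while gradients are being accumulated across micro-batches. Below is a minimal standalone sketch of that pattern; it is not AIToolbox code, and the model, optimizer, loader, loss_fn, and accum_steps names are hypothetical placeholders for illustration:

import torch

# Hypothetical minimal setup; any model, optimizer, and data would do (requires CUDA for AMP)
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
loader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]

scaler = torch.cuda.amp.GradScaler()
accum_steps = 4  # assumed gradient accumulation factor

for i, (inputs, targets) in enumerate(loader):
    inputs, targets = inputs.cuda(), targets.cuda()
    with torch.cuda.amp.autocast():
        # Normalize the loss so the accumulated gradient matches a full-batch step
        loss = loss_fn(model(inputs), targets) / accum_steps

    # Scaled gradients are accumulated on every iteration
    scaler.scale(loss).backward()

    # Step and update only once per effective batch, mirroring this commit's fix
    if (i + 1) % accum_steps == 0:
        scaler.step(optimizer)  # unscales gradients, then steps the optimizer
        scaler.update()         # adjust the loss scale only after a real step
        optimizer.zero_grad()

Calling update() on every iteration, as the code did before this fix, would let the scale change mid-accumulation, so later micro-batches could be scaled by a different factor than the gradients already accumulated.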
Binary file modified dist/aitoolbox-1.6.0-py3-none-any.whl
Binary file modified dist/aitoolbox-1.6.0.tar.gz
