bug/feature: Fixing mixed precision training #290
Open
isamu-isozaki wants to merge 69 commits into finegrain-ai:main from isamu-isozaki:grad_scaler
Changes from 8 commits

Commits (69)
All commits authored by isamu-isozaki:

e8b084c  Adding grad scaler
dc795ff  Fixed double backward
895c65f  zero grad
6049d40  Merge branch 'grad_scaler' of https://github.com/isamu-isozaki/refine…
7ccb836  Fixed accordign to review
5e539c1  Fixed styles
35401da  Cleaned up code
4ccca8d  Removed diffs
3c7f8f3  Disable grad scaler for bf16
eebaad7  Fixed mixed precision training
d362686  Added option for None as dtype for mixed precision training
6e5bb46  Remove dtype
e2291b9  Testing with no api change
a45c9a4  Fixed typo
65c4195  Amp config
ba438aa  Style fixes
88f002c  Attempt file fix
e9d857d  Remove diff
0803b86  Remove diffs
ca4ad1b  Fixed always true condition
a3674e4  Removed not implemented error
f0bc382  Remove comment
59cf8ab  Remove overflow grad check
73a4692  Removed grad scaler for non-amp training
ab00953  Remove comments
903fbfe  Cleaner amp
e289d64  Fix default
2959c30  More explicit name
362bd2d  Remove accelerate
2d64832  Linting
05e77b6  Adding grad scaler
be99c30  zero grad
c121dfb  Fixed double backward
b4da3ef  Fixed accordign to review
aa6f0a4  Fixed styles
e227f0d  Cleaned up code
e1f309d  Removed diffs
ae08009  Disable grad scaler for bf16
ce364e1  Fixed mixed precision training
5dc798a  Added option for None as dtype for mixed precision training
9876caf  Remove dtype
b547ee1  Testing with no api change
d78525d  Fixed typo
a82daf4  Amp config
1cf64e3  Style fixes
bc20875  Attempt file fix
691aeb5  Remove diff
dce3cd4  Remove diffs
486de2a  Fixed always true condition
9e7b65a  Removed not implemented error
5d3cd51  Remove comment
acba9cf  Remove overflow grad check
78e3a52  Removed grad scaler for non-amp training
03f4352  Remove comments
3ed2c29  Cleaner amp
4df3741  Fix default
123b4d4  More explicit name
72e8216  Remove accelerate
e206b7a  Linting
07185be  Merge branch 'grad_scaler' of https://github.com/isamu-isozaki/refine…
7c7c3ff  Merge branch 'main' into grad_scaler
8a9e034  Merge branch 'grad_scaler' of https://github.com/isamu-isozaki/refine…
a5ff578  Remove diffs
da188f7  Did fixes
49ebf48  Fixed import
58fcb15  Resolve merge conflicts
1d648a2  Delete double optimizer step
4efdffb  Fixed dtype
27cbdba  Merge branch 'main' into grad_scaler
```diff
@@ -7,6 +7,7 @@
 from loguru import logger
 from torch import Tensor, device as Device, dtype as DType, nn
 from torch.autograd import backward
+from torch.cuda.amp import GradScaler
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import (
     CosineAnnealingLR,
@@ -183,6 +184,12 @@ def dtype(self) -> DType:
         logger.info(f"Using dtype: {dtype}")
         return dtype

+    @cached_property
+    def scaler(self) -> GradScaler | None:
+        if self.config.training.dtype == "float32":
+            return None
+        return GradScaler()
+
     @property
     def learnable_parameters(self) -> list[nn.Parameter]:
         """Returns a list of learnable parameters in all models"""
@@ -341,15 +348,34 @@ def compute_loss(self, batch: Batch) -> Tensor:
     def compute_evaluation(self) -> None:
         pass

+    def backward_step(self, scaled_loss: Tensor) -> None:
+        if self.scaler is None:
+            backward(tensors=scaled_loss)
+            return
+        self.scaler.scale(scaled_loss).backward()  # type: ignore
+
+    # logic from accelerator
+    def optimizer_step(self) -> None:
+        if self.scaler is None:
+            self.optimizer.step()
+            return
+        scale_before = self.scaler.get_scale()  # type: ignore
+        self.scaler.step(self.optimizer)  # type: ignore
+        self.scaler.update()  # type: ignore
+        scale_after = self.scaler.get_scale()  # type: ignore
+        # If we reduced the loss scale, it means the optimizer step was skipped because of gradient overflow.
+        if scale_after < scale_before:
+            logger.info("Overflow in optimizer caused optimizer to skip")
+
     def backward(self) -> None:
         """Backward pass on the loss."""
         self._call_callbacks(event_name="on_backward_begin")
         scaled_loss = self.loss / self.clock.num_step_per_iteration
-        backward(tensors=scaled_loss)
+        self.backward_step(scaled_loss)
         self._call_callbacks(event_name="on_backward_end")
         if self.clock.is_optimizer_step:
             self._call_callbacks(event_name="on_optimizer_step_begin")
-            self.optimizer.step()
+            self.optimizer_step()
             self.optimizer.zero_grad()
             self._call_callbacks(event_name="on_optimizer_step_end")
         if self.clock.is_lr_scheduler_step:
```

Review comment on `optimizer_step` (the `scale_before` / `scale_after` check): Not sure what this custom logic from accelerate does exactly, let's remove it for now (I mean scale_before, scale_after)
Note: as discussed offline, a follow-up PR will be needed to properly unscale the gradients before gradient clipping (which happens in between the backward step and the optimizer step, as per https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients)
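The linked AMP docs show the shape that follow-up would take: call `scaler.unscale_(optimizer)` first so clipping sees gradients at their true magnitude, then step. A hedged sketch with a hypothetical toy model (not this repo's `Trainer`); on CPU-only machines `GradScaler` auto-disables and `unscale_` is effectively a no-op:

```python
# Sketch of gradient clipping under AMP, per the PyTorch AMP examples.
import torch
from torch import nn
from torch.cuda.amp import GradScaler

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()

scaler.unscale_(optimizer)  # gradients are now in real (unscaled) units
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

scaler.step(optimizer)  # records that unscale_ ran; won't unscale twice
scaler.update()
optimizer.zero_grad()
```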
But also, what about gradient accumulation? https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients
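(The AMP examples page also covers accumulation: gradients from each scaled micro-batch backward simply add up, and `step`/`update`/`zero_grad` run only once per effective batch, since calling `update()` mid-accumulation would change the scale the already-accumulated gradients were produced with. A sketch under the same toy-model assumptions as above:)

```python
# Sketch of gradient accumulation under AMP, following the PyTorch AMP
# examples. Toy model and data; GradScaler auto-disables without CUDA,
# so this also runs on CPU.
import torch
from torch import nn
from torch.cuda.amp import GradScaler

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()
accum_steps = 4

batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(8)]
for i, (x, y) in enumerate(batches):
    # On GPU this forward pass would be wrapped in torch.autocast("cuda").
    loss = nn.functional.mse_loss(model(x), y) / accum_steps
    scaler.scale(loss).backward()  # grads accumulate across micro-batches
    if (i + 1) % accum_steps == 0:
        # step/update/zero_grad only once per effective batch
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```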