bug/feature: Fixing mixed precision training #290

Open
wants to merge 69 commits into base: main
Changes from 8 commits
Commits (69)
e8b084c
Adding grad scaler
isamu-isozaki Feb 16, 2024
dc795ff
Fixed double backward
isamu-isozaki Feb 16, 2024
895c65f
zero grad
isamu-isozaki Feb 16, 2024
6049d40
Merge branch 'grad_scaler' of https://github.com/isamu-isozaki/refine…
isamu-isozaki Feb 16, 2024
7ccb836
Fixed according to review
isamu-isozaki Feb 21, 2024
5e539c1
Fixed styles
isamu-isozaki Feb 21, 2024
35401da
Cleaned up code
isamu-isozaki Feb 21, 2024
4ccca8d
Removed diffs
isamu-isozaki Feb 21, 2024
3c7f8f3
Disable grad scaler for bf16
isamu-isozaki Feb 21, 2024
eebaad7
Fixed mixed precision training
isamu-isozaki Feb 21, 2024
d362686
Added option for None as dtype for mixed precision training
isamu-isozaki Feb 21, 2024
6e5bb46
Remove dtype
isamu-isozaki Feb 22, 2024
e2291b9
Testing with no api change
isamu-isozaki Feb 22, 2024
a45c9a4
Fixed typo
isamu-isozaki Feb 23, 2024
65c4195
Amp config
isamu-isozaki Feb 23, 2024
ba438aa
Style fixes
isamu-isozaki Feb 23, 2024
88f002c
Attempt file fix
isamu-isozaki Feb 23, 2024
e9d857d
Remove diff
isamu-isozaki Feb 23, 2024
0803b86
Remove diffs
isamu-isozaki Feb 23, 2024
ca4ad1b
Fixed always true condition
isamu-isozaki Feb 23, 2024
a3674e4
Removed not implemented error
isamu-isozaki Feb 23, 2024
f0bc382
Remove comment
isamu-isozaki Feb 23, 2024
59cf8ab
Remove overflow grad check
isamu-isozaki Feb 23, 2024
73a4692
Removed grad scaler for non-amp training
isamu-isozaki Feb 26, 2024
ab00953
Remove comments
isamu-isozaki Feb 26, 2024
903fbfe
Cleaner amp
isamu-isozaki Feb 26, 2024
e289d64
Fix default
isamu-isozaki Feb 26, 2024
2959c30
More explicit name
isamu-isozaki Feb 26, 2024
362bd2d
Remove accelerate
isamu-isozaki Feb 26, 2024
2d64832
Linting
isamu-isozaki Feb 26, 2024
05e77b6
Adding grad scaler
isamu-isozaki Feb 16, 2024
be99c30
zero grad
isamu-isozaki Feb 16, 2024
c121dfb
Fixed double backward
isamu-isozaki Feb 16, 2024
b4da3ef
Fixed according to review
isamu-isozaki Feb 21, 2024
aa6f0a4
Fixed styles
isamu-isozaki Feb 21, 2024
e227f0d
Cleaned up code
isamu-isozaki Feb 21, 2024
e1f309d
Removed diffs
isamu-isozaki Feb 21, 2024
ae08009
Disable grad scaler for bf16
isamu-isozaki Feb 21, 2024
ce364e1
Fixed mixed precision training
isamu-isozaki Feb 21, 2024
5dc798a
Added option for None as dtype for mixed precision training
isamu-isozaki Feb 21, 2024
9876caf
Remove dtype
isamu-isozaki Feb 22, 2024
b547ee1
Testing with no api change
isamu-isozaki Feb 22, 2024
d78525d
Fixed typo
isamu-isozaki Feb 23, 2024
a82daf4
Amp config
isamu-isozaki Feb 23, 2024
1cf64e3
Style fixes
isamu-isozaki Feb 23, 2024
bc20875
Attempt file fix
isamu-isozaki Feb 23, 2024
691aeb5
Remove diff
isamu-isozaki Feb 23, 2024
dce3cd4
Remove diffs
isamu-isozaki Feb 23, 2024
486de2a
Fixed always true condition
isamu-isozaki Feb 23, 2024
9e7b65a
Removed not implemented error
isamu-isozaki Feb 23, 2024
5d3cd51
Remove comment
isamu-isozaki Feb 23, 2024
acba9cf
Remove overflow grad check
isamu-isozaki Feb 23, 2024
78e3a52
Removed grad scaler for non-amp training
isamu-isozaki Feb 26, 2024
03f4352
Remove comments
isamu-isozaki Feb 26, 2024
3ed2c29
Cleaner amp
isamu-isozaki Feb 26, 2024
4df3741
Fix default
isamu-isozaki Feb 26, 2024
123b4d4
More explicit name
isamu-isozaki Feb 26, 2024
72e8216
Remove accelerate
isamu-isozaki Feb 26, 2024
e206b7a
Linting
isamu-isozaki Feb 26, 2024
07185be
Merge branch 'grad_scaler' of https://github.com/isamu-isozaki/refine…
isamu-isozaki Feb 26, 2024
7c7c3ff
Merge branch 'main' into grad_scaler
isamu-isozaki Feb 26, 2024
8a9e034
Merge branch 'grad_scaler' of https://github.com/isamu-isozaki/refine…
isamu-isozaki Feb 26, 2024
a5ff578
Remove diffs
isamu-isozaki Feb 26, 2024
da188f7
Did fixes
isamu-isozaki Mar 1, 2024
49ebf48
Fixed import
isamu-isozaki Mar 8, 2024
58fcb15
Resolve merge conflicts
isamu-isozaki Apr 1, 2024
1d648a2
Delete double optimizer step
isamu-isozaki Apr 1, 2024
4efdffb
Fixed dtype
isamu-isozaki Apr 29, 2024
27cbdba
Merge branch 'main' into grad_scaler
isamu-isozaki Apr 29, 2024
src/refiners/training_utils/trainer.py (30 changes: 28 additions & 2 deletions)

@@ -7,6 +7,7 @@
 from loguru import logger
 from torch import Tensor, device as Device, dtype as DType, nn
 from torch.autograd import backward
+from torch.cuda.amp import GradScaler
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import (
     CosineAnnealingLR,
@@ -183,6 +184,12 @@ def dtype(self) -> DType:
logger.info(f"Using dtype: {dtype}")
return dtype

@cached_property
def scaler(self) -> GradScaler | None:
if self.config.training.dtype == "float32":
return None
return GradScaler()

@property
def learnable_parameters(self) -> list[nn.Parameter]:
"""Returns a list of learnable parameters in all models"""
@@ -341,15 +348,34 @@ def compute_loss(self, batch: Batch) -> Tensor:
     def compute_evaluation(self) -> None:
         pass

+    def backward_step(self, scaled_loss: Tensor) -> None:
+        if self.scaler is None:
+            backward(tensors=scaled_loss)
+            return
+        self.scaler.scale(scaled_loss).backward()  # type: ignore
Review comment (Member):
Note: as discussed offline, a follow up PR will be needed to properly unscale the gradients before gradient clipping (which happens in between the backward step and optimizer step, as per https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients)
But also, what about gradient accumulation? https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients
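A minimal sketch of the pattern described in the AMP notes linked above, combining unscaling before gradient clipping with gradient accumulation; the model, accumulation count, and max_norm are illustrative placeholders, not values from this PR:

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()
accumulation_steps = 4  # hypothetical accumulation count

for step in range(100):
    x = torch.randn(8, 16, device="cuda")
    with autocast(dtype=torch.float16):
        loss = model(x).pow(2).mean() / accumulation_steps  # average loss over accumulated steps
    scaler.scale(loss).backward()  # gradients accumulate in scaled form
    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)  # unscale once so clipping sees true gradient magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)      # step() will not unscale a second time
        scaler.update()
        optimizer.zero_grad()

The point is only where unscale_() sits relative to clipping and step(); in the trainer, the same sequence would be driven by the clock and callbacks.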
+
+    # logic from accelerator
+    def optimizer_step(self) -> None:
+        if self.scaler is None:
+            self.optimizer.step()
+            return
+        scale_before = self.scaler.get_scale()  # type: ignore
+        self.scaler.step(self.optimizer)  # type: ignore
+        self.scaler.update()  # type: ignore
+        scale_after = self.scaler.get_scale()  # type: ignore
+        # If we reduced the loss scale, it means the optimizer step was skipped because of gradient overflow.
+        if scale_after < scale_before:
+            logger.info("Overflow in optimizer caused optimizer to skip")
+

Review comment (Member):
Not sure what this custom logic from accelerate does exactly, let's remove it for now (I mean scale_before, scale_after)
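For reference on the check the reviewer is questioning: GradScaler.step() silently skips the optimizer step when the scaled gradients contain inf/NaN, and the following update() then lowers the scale, so a decreased scale is a proxy for a skipped step (the pattern the author credits to Accelerate in the "# logic from accelerator" comment). A standalone sketch of that detection, independent of this Trainer and with a made-up helper name:

import torch
from torch.cuda.amp import GradScaler

def step_with_skip_detection(scaler: GradScaler, optimizer: torch.optim.Optimizer) -> bool:
    """Run scaler.step/update and report whether the optimizer step was skipped."""
    scale_before = scaler.get_scale()
    scaler.step(optimizer)  # skipped internally if unscaled gradients contain inf/NaN
    scaler.update()         # lowers the scale after a skipped step
    return scaler.get_scale() < scale_before  # True means the step was skipped

Whether to keep, log, or drop this detection is exactly what the comment above leaves open.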
     def backward(self) -> None:
         """Backward pass on the loss."""
         self._call_callbacks(event_name="on_backward_begin")
         scaled_loss = self.loss / self.clock.num_step_per_iteration
-        backward(tensors=scaled_loss)
+        self.backward_step(scaled_loss)
         self._call_callbacks(event_name="on_backward_end")
         if self.clock.is_optimizer_step:
             self._call_callbacks(event_name="on_optimizer_step_begin")
-            self.optimizer.step()
+            self.optimizer_step()
             self.optimizer.zero_grad()
             self._call_callbacks(event_name="on_optimizer_step_end")
         if self.clock.is_lr_scheduler_step: