bug/feature: Fixing mixed precision training #290
base: main
Changes from 63 commits
@@ -5,8 +5,9 @@
 import torch
 from loguru import logger
-from torch import Tensor, device as Device, dtype as DType, nn
+from torch import Tensor, device as Device, dtype as DType, float16, float32, nn
 from torch.autograd import backward
+from torch.cuda.amp import GradScaler, autocast
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import (
     CosineAnnealingLR,
@@ -105,6 +106,9 @@ def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfigT) -> fl.Module:
         if config.requires_grad is not None:
             model.requires_grad_(requires_grad=config.requires_grad)
         learnable_parameters = [param for param in model.parameters() if param.requires_grad]
+        if self.config.training.automatic_mixed_precision:
+            for learnable_parameter in learnable_parameters:
+                learnable_parameter.to(dtype=float32)
Review comment: I guess this is because with AMP some layers/ops behave better (range-wise) in float32, right? Doesn't this deserve a comment?
 
         self.models[name] = ModelItem(
             name=name, config=config, model=model, learnable_parameters=learnable_parameters
         )
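The float32 cast of the learnable parameters above follows the usual AMP convention of keeping "master" weights in full precision while autocast executes eligible ops in half precision. A minimal, illustrative sketch of that behavior (not the trainer's code, and assuming a CUDA device is available):

import torch

# Parameters stay in float32 ("master weights"); autocast only affects the
# dtype used to execute eligible ops, not the parameters themselves.
linear = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = linear(x)  # the matmul runs in float16 under autocast

assert linear.weight.dtype == torch.float32  # unchanged by autocast
assert y.dtype == torch.float16              # activation produced in half precision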
@@ -183,6 +187,12 @@ def dtype(self) -> DType:
         logger.info(f"Using dtype: {dtype}")
         return dtype
 
+    @cached_property
+    def scaler(self) -> GradScaler | None:
+        if self.dtype != float16 or not self.config.training.automatic_mixed_precision:
+            return None
+        return GradScaler()
+
     @property
     def learnable_parameters(self) -> list[nn.Parameter]:
         """Returns a list of learnable parameters in all models"""
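The scaler property only returns a GradScaler for float16: gradient scaling exists to keep small float16 gradients from underflowing to zero during the backward pass. A rough sketch of the underlying numerics (illustrative values, not taken from this PR):

import torch

# A gradient magnitude that float32 represents fine...
g = torch.tensor(1e-8, dtype=torch.float32)

# ...underflows to zero in float16 (smallest positive subnormal is ~6e-8).
assert g.to(torch.float16).item() == 0.0

# Scaled by a large factor it stays representable; GradScaler picks and
# adjusts such a factor dynamically, then unscales before the optimizer step.
assert (g * 2**16).to(torch.float16).item() > 0.0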
@@ -341,15 +351,28 @@ def compute_loss(self, batch: Batch) -> Tensor:
     def compute_evaluation(self) -> None:
         pass
 
+    def backward_step(self, scaled_loss: Tensor) -> None:
+        if self.scaler is None:
+            backward(tensors=scaled_loss)
+            return
+        self.scaler.scale(scaled_loss).backward()  # type: ignore
Review comment: Note: as discussed offline, a follow-up PR will be needed to properly unscale the gradients before gradient clipping (which happens in between the backward step and the optimizer step, as per https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients). But also, what about gradient accumulation? https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients
 
+    def optimizer_step(self) -> None:
+        if self.scaler is None:
+            self.optimizer.step()
+            return
+        self.scaler.step(self.optimizer)  # type: ignore
+        self.scaler.update()  # type: ignore
+
     def backward(self) -> None:
         """Backward pass on the loss."""
         self._call_callbacks(event_name="on_backward_begin")
         scaled_loss = self.loss / self.clock.num_step_per_iteration
-        backward(tensors=scaled_loss)
+        self.backward_step(scaled_loss)
         self._call_callbacks(event_name="on_backward_end")
         if self.clock.is_optimizer_step:
             self._call_callbacks(event_name="on_optimizer_step_begin")
-            self.optimizer.step()
+            self.optimizer_step()
             self.optimizer.zero_grad()
             self._call_callbacks(event_name="on_optimizer_step_end")
         if self.clock.is_lr_scheduler_step:
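Regarding the review note above, the pattern in the linked PyTorch AMP examples is to unscale the gradients explicitly before clipping, and, with gradient accumulation, to only step and update the scaler at accumulation boundaries. A hedged sketch of how that could slot into a follow-up (names like model, optimizer, accumulation_steps and max_norm are placeholders, not this trainer's API):

import torch
from torch.cuda.amp import GradScaler, autocast

def train(model, optimizer, data_loader, accumulation_steps=4, max_norm=1.0):
    scaler = GradScaler()
    for i, (inputs, targets) in enumerate(data_loader):
        with autocast(dtype=torch.float16):
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
        # Scale the averaged loss so accumulated float16 gradients don't underflow.
        scaler.scale(loss / accumulation_steps).backward()
        if (i + 1) % accumulation_steps == 0:
            # Unscale once per optimizer step so clipping sees true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
            scaler.step(optimizer)   # skips the step if gradients contain inf/NaN
            scaler.update()
            optimizer.zero_grad()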
@@ -362,7 +385,8 @@ def backward(self) -> None:
     def step(self, batch: Batch) -> None:
         """Perform a single training step."""
         self._call_callbacks(event_name="on_compute_loss_begin")
-        loss = self.compute_loss(batch=batch)
+        with autocast(dtype=self.dtype, enabled=self.config.training.automatic_mixed_precision):
+            loss = self.compute_loss(batch=batch)
         self.loss = loss
         self._call_callbacks(event_name="on_compute_loss_end")
         self.backward()
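Note that step() only wraps the forward/loss computation in autocast; backward() runs outside the context, which matches the PyTorch recommendation that backward passes not be executed under autocast (backward ops run in the dtype autocast chose for the corresponding forward ops). A minimal sketch of that split, with illustrative names and assuming a CUDA device:

import torch

model = torch.nn.Linear(16, 1).cuda()
x = torch.randn(8, 16, device="cuda")
target = torch.randn(8, 1, device="cuda")

# Forward pass and loss under autocast only.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(x), target)

# Backward outside the autocast context; gradients land in float32,
# matching the float32 parameters.
loss.backward()
assert model.weight.grad.dtype == torch.float32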
@@ -403,7 +427,8 @@ def evaluate(self) -> None:
         """Evaluate the model."""
         self.set_models_to_mode(mode="eval")
         self._call_callbacks(event_name="on_evaluate_begin")
-        self.compute_evaluation()
+        with autocast(dtype=self.dtype, enabled=self.config.training.automatic_mixed_precision):
+            self.compute_evaluation()
         self._call_callbacks(event_name="on_evaluate_end")
         self.set_models_to_mode(mode="train")
Review comment: Since AMP is on by default (opt-out), and given that this dtype config is used jointly with the autocast(dtype=self.dtype, enabled=self.config.training.automatic_mixed_precision) call, any reason not to pick a sane default here? Namely: either float16 or bfloat16. Might also deserve a comment. (accelerate even offers to "Choose from 'no', 'fp16', 'bf16' or 'fp8'.")
Reply: We could set it to False by default, and even though float32 + AMP doesn't do much, I don't see this as a big issue; I would rather keep the API straightforward.
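On the default-dtype question: one practical difference between the two suggested defaults is that bfloat16 has the same exponent range as float32, so autocast with bfloat16 is typically used without a GradScaler, whereas float16 generally needs one. A small hedged sketch (illustrative only, not a proposal for this config, and assuming a CUDA device with bfloat16 support):

import torch

model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
x = torch.randn(8, 16, device="cuda")
target = torch.randn(8, 1, device="cuda")

# bfloat16 autocast: no gradient scaling needed thanks to its float32-like range.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = torch.nn.functional.mse_loss(model(x), target)
loss.backward()
optimizer.step()
optimizer.zero_grad()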