bug/feature: Fixing mixed precision training #290

Open

isamu-isozaki wants to merge 69 commits into main from grad_scaler

Changes from 63 commits

Commits (69)
e8b084c
Adding grad scaler
isamu-isozaki Feb 16, 2024
dc795ff
Fixed double backward
isamu-isozaki Feb 16, 2024
895c65f
zero grad
isamu-isozaki Feb 16, 2024
6049d40
Merge branch 'grad_scaler' of https://github.com/isamu-isozaki/refine…
isamu-isozaki Feb 16, 2024
7ccb836
Fixed accordign to review
isamu-isozaki Feb 21, 2024
5e539c1
Fixed styles
isamu-isozaki Feb 21, 2024
35401da
Cleaned up code
isamu-isozaki Feb 21, 2024
4ccca8d
Removed diffs
isamu-isozaki Feb 21, 2024
3c7f8f3
Disable grad scaler for bf16
isamu-isozaki Feb 21, 2024
eebaad7
Fixed mixed precision training
isamu-isozaki Feb 21, 2024
d362686
Added option for None as dtype for mixed precision training
isamu-isozaki Feb 21, 2024
6e5bb46
Remove dtype
isamu-isozaki Feb 22, 2024
e2291b9
Testing with no api change
isamu-isozaki Feb 22, 2024
a45c9a4
Fixed typo
isamu-isozaki Feb 23, 2024
65c4195
Amp config
isamu-isozaki Feb 23, 2024
ba438aa
Style fixes
isamu-isozaki Feb 23, 2024
88f002c
Attempt file fix
isamu-isozaki Feb 23, 2024
e9d857d
Remove diff
isamu-isozaki Feb 23, 2024
0803b86
Remove diffs
isamu-isozaki Feb 23, 2024
ca4ad1b
Fixed always true condition
isamu-isozaki Feb 23, 2024
a3674e4
Removed not implemented error
isamu-isozaki Feb 23, 2024
f0bc382
Remove comment
isamu-isozaki Feb 23, 2024
59cf8ab
Remove overflow grad check
isamu-isozaki Feb 23, 2024
73a4692
Removed grad scaler for non-amp training
isamu-isozaki Feb 26, 2024
ab00953
Remove comments
isamu-isozaki Feb 26, 2024
903fbfe
Cleaner amp
isamu-isozaki Feb 26, 2024
e289d64
Fix default
isamu-isozaki Feb 26, 2024
2959c30
More explicit name
isamu-isozaki Feb 26, 2024
362bd2d
Remove accelerate
isamu-isozaki Feb 26, 2024
2d64832
Linting
isamu-isozaki Feb 26, 2024
05e77b6
Adding grad scaler
isamu-isozaki Feb 16, 2024
be99c30
zero grad
isamu-isozaki Feb 16, 2024
c121dfb
Fixed double backward
isamu-isozaki Feb 16, 2024
b4da3ef
Fixed accordign to review
isamu-isozaki Feb 21, 2024
aa6f0a4
Fixed styles
isamu-isozaki Feb 21, 2024
e227f0d
Cleaned up code
isamu-isozaki Feb 21, 2024
e1f309d
Removed diffs
isamu-isozaki Feb 21, 2024
ae08009
Disable grad scaler for bf16
isamu-isozaki Feb 21, 2024
ce364e1
Fixed mixed precision training
isamu-isozaki Feb 21, 2024
5dc798a
Added option for None as dtype for mixed precision training
isamu-isozaki Feb 21, 2024
9876caf
Remove dtype
isamu-isozaki Feb 22, 2024
b547ee1
Testing with no api change
isamu-isozaki Feb 22, 2024
d78525d
Fixed typo
isamu-isozaki Feb 23, 2024
a82daf4
Amp config
isamu-isozaki Feb 23, 2024
1cf64e3
Style fixes
isamu-isozaki Feb 23, 2024
bc20875
Attempt file fix
isamu-isozaki Feb 23, 2024
691aeb5
Remove diff
isamu-isozaki Feb 23, 2024
dce3cd4
Remove diffs
isamu-isozaki Feb 23, 2024
486de2a
Fixed always true condition
isamu-isozaki Feb 23, 2024
9e7b65a
Removed not implemented error
isamu-isozaki Feb 23, 2024
5d3cd51
Remove comment
isamu-isozaki Feb 23, 2024
acba9cf
Remove overflow grad check
isamu-isozaki Feb 23, 2024
78e3a52
Removed grad scaler for non-amp training
isamu-isozaki Feb 26, 2024
03f4352
Remove comments
isamu-isozaki Feb 26, 2024
3ed2c29
Cleaner amp
isamu-isozaki Feb 26, 2024
4df3741
Fix default
isamu-isozaki Feb 26, 2024
123b4d4
More explicit name
isamu-isozaki Feb 26, 2024
72e8216
Remove accelerate
isamu-isozaki Feb 26, 2024
e206b7a
Linting
isamu-isozaki Feb 26, 2024
07185be
Merge branch 'grad_scaler' of https://github.com/isamu-isozaki/refine…
isamu-isozaki Feb 26, 2024
7c7c3ff
Merge branch 'main' into grad_scaler
isamu-isozaki Feb 26, 2024
8a9e034
Merge branch 'grad_scaler' of https://github.com/isamu-isozaki/refine…
isamu-isozaki Feb 26, 2024
a5ff578
Remove diffs
isamu-isozaki Feb 26, 2024
da188f7
Did fixes
isamu-isozaki Mar 1, 2024
49ebf48
Fixed import
isamu-isozaki Mar 8, 2024
58fcb15
Resolve merge conflicts
isamu-isozaki Apr 1, 2024
1d648a2
Delete double optimizer step
isamu-isozaki Apr 1, 2024
4efdffb
Fixed dtype
isamu-isozaki Apr 29, 2024
27cbdba
Merge branch 'main' into grad_scaler
isamu-isozaki Apr 29, 2024
3 changes: 3 additions & 0 deletions src/refiners/training_utils/config.py
@@ -22,6 +22,9 @@

 class TrainingConfig(BaseModel):
     device: str = "cpu"
+    automatic_mixed_precision: bool = (
+        True  # Enables automatic mixed precision which allows float32 gradients while working with lower precision.
+    )
     dtype: str = "float32"
Member

Since AMP is on by default (opt-out), and given that this dtype config is used jointly with it through the autocast(dtype=self.dtype, enabled=self.config.training.automatic_mixed_precision) call, any reason not to pick a sane default here? Namely: either float16 or bfloat16. Might also deserve a comment.

(accelerate even offers to "Choose from ‘no’,‘fp16’,‘bf16 or ‘fp8’.")

Member

We could set it to False by default, and even though float32 + amp doesn't do much, I don't see this as a big issue; I would rather keep the API straightforward.

     duration: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
     seed: int = 0
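
For context on the discussion above, here is a minimal sketch (not code from this PR) of how these two fields end up driving autocast in the trainer changes below; only the `automatic_mixed_precision` and `dtype` field names come from the diff, the rest is illustrative.

```python
import torch

# Values as they would come from TrainingConfig; names other than the two
# config fields are illustrative.
amp_enabled = True        # training.automatic_mixed_precision
config_dtype = "float16"  # training.dtype

dtype = getattr(torch, config_dtype)  # "float16" -> torch.float16

with torch.autocast(device_type="cuda", dtype=dtype, enabled=amp_enabled):
    # Forward pass: matmuls/convs run in `dtype`, precision-sensitive ops stay in float32.
    pass
```

With the current default of dtype="float32", enabling AMP changes little, which is what the default-value discussion above is about.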
35 changes: 30 additions & 5 deletions src/refiners/training_utils/trainer.py
@@ -5,8 +5,9 @@

 import torch
 from loguru import logger
-from torch import Tensor, device as Device, dtype as DType, nn
+from torch import Tensor, device as Device, dtype as DType, float16, float32, nn
 from torch.autograd import backward
+from torch.cuda.amp import GradScaler, autocast
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import (
     CosineAnnealingLR,
@@ -105,6 +106,9 @@ def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfigT) -> fl.Module:
         if config.requires_grad is not None:
            model.requires_grad_(requires_grad=config.requires_grad)
         learnable_parameters = [param for param in model.parameters() if param.requires_grad]
+        if self.config.training.automatic_mixed_precision:
+            for learnable_parameter in learnable_parameters:
+                learnable_parameter.to(dtype=float32)
Member

I guess this is because with AMP some layers/ops behave better (range) with float32, right? Doesn't this deserve a comment?

         self.models[name] = ModelItem(
             name=name, config=config, model=model, learnable_parameters=learnable_parameters
         )
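
On the question above: keeping learnable parameters in float32 matches the standard AMP recipe, where master weights and optimizer state stay in float32 and autocast only downcasts per-op. One caveat: `Tensor.to` is not in-place, so a cast like the one below (assigning to `.data`) is needed for the dtype change to actually stick. A minimal sketch with a placeholder model, not code from this PR:

```python
import torch
from torch import nn

model = nn.Linear(16, 16, device="cuda")  # placeholder model

# Keep master weights in float32; autocast downcasts per-op during the forward pass.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.data.to(dtype=torch.float32)

with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = model(torch.randn(4, 16, device="cuda"))

print(model.weight.dtype)  # torch.float32 -- parameters are not downcast by autocast
print(out.dtype)           # torch.float16 -- the linear layer's output is
```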
@@ -183,6 +187,12 @@ def dtype(self) -> DType:
         logger.info(f"Using dtype: {dtype}")
         return dtype

+    @cached_property
+    def scaler(self) -> GradScaler | None:
+        if self.dtype != float16 or not self.config.training.automatic_mixed_precision:
+            return None
+        return GradScaler()
+
     @property
     def learnable_parameters(self) -> list[nn.Parameter]:
         """Returns a list of learnable parameters in all models"""
@@ -341,15 +351,28 @@ def compute_loss(self, batch: Batch) -> Tensor:
     def compute_evaluation(self) -> None:
         pass

+    def backward_step(self, scaled_loss: Tensor) -> None:
+        if self.scaler is None:
+            backward(tensors=scaled_loss)
+            return
+        self.scaler.scale(scaled_loss).backward()  # type: ignore
+
Member

Note: as discussed offline, a follow-up PR will be needed to properly unscale the gradients before gradient clipping (which happens in between the backward step and the optimizer step, as per https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients)

But also, what about gradient accumulation? https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-accumulation
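
A sketch of what that follow-up could look like, adapted from the linked PyTorch AMP examples; the standalone function, its signature, and the max_norm value are assumptions, not part of this PR:

```python
from __future__ import annotations

import torch
from torch.cuda.amp import GradScaler
from torch.optim import Optimizer


def clipped_optimizer_step(
    optimizer: Optimizer,
    parameters: list[torch.nn.Parameter],
    scaler: GradScaler | None,
    max_norm: float = 1.0,
) -> None:
    if scaler is not None:
        # Unscale in-place first so clipping sees gradients at their true magnitude.
        scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm)
    if scaler is None:
        optimizer.step()
        return
    scaler.step(optimizer)  # skipped internally if infs/NaNs were found after unscaling
    scaler.update()
```

Gradient accumulation keeps the same ordering: call scaler.scale(loss).backward() on every micro-batch, and only unscale/clip/step/update on the iteration where the optimizer actually steps.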


+    def optimizer_step(self) -> None:
+        if self.scaler is None:
+            self.optimizer.step()
+            return
+        self.scaler.step(self.optimizer)  # type: ignore
+        self.scaler.update()  # type: ignore
+
     def backward(self) -> None:
         """Backward pass on the loss."""
         self._call_callbacks(event_name="on_backward_begin")
         scaled_loss = self.loss / self.clock.num_step_per_iteration
-        backward(tensors=scaled_loss)
+        self.backward_step(scaled_loss)
         self._call_callbacks(event_name="on_backward_end")
         if self.clock.is_optimizer_step:
             self._call_callbacks(event_name="on_optimizer_step_begin")
-            self.optimizer.step()
+            self.optimizer_step()
             self.optimizer.zero_grad()
             self._call_callbacks(event_name="on_optimizer_step_end")
         if self.clock.is_lr_scheduler_step:
@@ -362,7 +385,8 @@ def backward(self) -> None:
     def step(self, batch: Batch) -> None:
         """Perform a single training step."""
         self._call_callbacks(event_name="on_compute_loss_begin")
-        loss = self.compute_loss(batch=batch)
+        with autocast(dtype=self.dtype, enabled=self.config.training.automatic_mixed_precision):
+            loss = self.compute_loss(batch=batch)
         self.loss = loss
         self._call_callbacks(event_name="on_compute_loss_end")
         self.backward()
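
For comparison, the reference recipe from the PyTorch AMP docs, with comments mapping each call to where it lands in this PR; the model, optimizer, and data below are placeholders:

```python
import torch
from torch import nn

model = nn.Linear(8, 1, device="cuda")
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    inputs = torch.randn(32, 8, device="cuda")
    targets = torch.randn(32, 1, device="cuda")
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):  # step() wraps compute_loss
        loss = nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()  # backward_step() in this PR
    scaler.step(optimizer)         # optimizer_step() in this PR
    scaler.update()                # also part of optimizer_step()
```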
@@ -403,7 +427,8 @@ def evaluate(self) -> None:
         """Evaluate the model."""
         self.set_models_to_mode(mode="eval")
         self._call_callbacks(event_name="on_evaluate_begin")
-        self.compute_evaluation()
+        with autocast(dtype=self.dtype, enabled=self.config.training.automatic_mixed_precision):
+            self.compute_evaluation()
         self._call_callbacks(event_name="on_evaluate_end")
         self.set_models_to_mode(mode="train")
